Coverage for python/lsst/daf/butler/dimensions/_record_set.py: 23%
112 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-11 03:16 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-11 03:16 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("DimensionRecordSet", "DimensionRecordFactory")
32from collections.abc import Collection, Iterable, Iterator
33from typing import TYPE_CHECKING, Any, Protocol, final
35from ._coordinate import DataCoordinate, DataIdValue
36from ._records import DimensionRecord
38if TYPE_CHECKING:
39 from ._elements import DimensionElement
40 from ._universe import DimensionUniverse
class DimensionRecordFactory(Protocol):
    """Callback protocol used to build a `DimensionRecord` on demand.

    Objects satisfying this protocol are passed to `DimensionRecordSet`
    lookup methods, which invoke them when no existing record matches and
    add the resulting record to the set.
    """

    def __call__(
        self,
        record_class: type[DimensionRecord],
        required_values: tuple[DataIdValue, ...],
    ) -> DimensionRecord:
        """Create a `DimensionRecord` for the given required data ID values.

        Parameters
        ----------
        record_class : `type` [ `DimensionRecord` ]
            Concrete `DimensionRecord` subclass to instantiate.
        required_values : `tuple`
            Data ID values, ordered to match
            ``record_class.definition.required``.

        Returns
        -------
        record : `DimensionRecord`
            The newly created record.
        """
        ...  # pragma: no cover
def fail_record_lookup(
    record_class: type[DimensionRecord], required_values: tuple[DataIdValue, ...]
) -> DimensionRecord:
    """Always raise `LookupError`, signalling that no `DimensionRecord` could
    be found or created.

    Use this as the default for arguments that accept a
    `DimensionRecordFactory` callback when a missing record is an error.

    Parameters
    ----------
    record_class : `type` [ `DimensionRecord` ]
        Type of the record that was requested.
    required_values : `tuple`
        Required data ID values identifying the requested record.

    Returns
    -------
    record : `DimensionRecord`
        Never returned; this function unconditionally raises.

    Raises
    ------
    LookupError
        Always raised, with a message naming the element and data ID.
    """
    definition = record_class.definition
    element_name = definition.name
    # Build a full data ID purely so the error message is readable.
    data_id = DataCoordinate.from_required_values(definition.minimal_group, required_values)
    raise LookupError(f"No {element_name!r} record with data ID {data_id}.")
92@final
93class DimensionRecordSet(Collection[DimensionRecord]): # numpydoc ignore=PR01
94 """A mutable set-like container specialized for `DimensionRecord` objects.
96 Parameters
97 ----------
98 element : `DimensionElement` or `str`, optional
99 The dimension element that defines the records held by this set. If
100 not a `DimensionElement` instance, ``universe`` must be provided.
101 records : `~collections.abc.Iterable` [ `DimensionRecord` ], optional
102 Dimension records to add to the set.
103 universe : `DimensionUniverse`, optional
104 Object that defines all dimensions. Ignored if ``element`` is a
105 `DimensionElement` instance.
107 Notes
108 -----
109 `DimensionRecordSet` maintains its insertion order (like `dict`, and unlike
110 `set`).
112 `DimensionRecordSet` implements `collections.abc.Collection` but not
113 `collections.abc.Set` because the latter would require interoperability
114 with all other `~collections.abc.Set` implementations rather than just
115 `DimensionRecordSet`, and that adds a lot of complexity without much clear
116 value. To help make this clear to type checkers it implements only the
117 named-method versions of these operations (e.g. `issubset`) rather than the
118 operator special methods (e.g. ``__le__``).
120 `DimensionRecord` equality is defined in terms of a record's data ID fields
121 only, and `DimensionRecordSet` does not generally specify which record
122 "wins" when two records with the same data ID interact (e.g. in
123 `intersection`). The `add` and `update` methods are notable exceptions:
124 they always replace the existing record with the new one.
126 Dimension records can also be held by `DimensionRecordTable`, which
127 provides column-oriented access and Arrow interoperability.
128 """
130 def __init__(
131 self,
132 element: DimensionElement | str,
133 records: Iterable[DimensionRecord] = (),
134 universe: DimensionUniverse | None = None,
135 *,
136 _by_required_values: dict[tuple[DataIdValue, ...], DimensionRecord] | None = None,
137 ):
138 if isinstance(element, str):
139 if universe is None:
140 raise TypeError("'universe' must be provided if 'element' is not a DimensionElement.")
141 element = universe[element]
142 else:
143 universe = element.universe
144 if _by_required_values is None:
145 _by_required_values = {}
146 self._record_type = element.RecordClass
147 self._by_required_values = _by_required_values
148 self._dimensions = element.minimal_group
149 self.update(records)
151 @property
152 def element(self) -> DimensionElement:
153 """Name of the dimension element these records correspond to."""
154 return self._record_type.definition
156 def __contains__(self, key: object) -> bool:
157 match key:
158 case DimensionRecord() if key.definition == self.element:
159 required_values = key.dataId.required_values
160 case DataCoordinate() if key.dimensions == self.element.minimal_group:
161 required_values = key.required_values
162 case _:
163 return False
164 return required_values in self._by_required_values
166 def __len__(self) -> int:
167 return len(self._by_required_values)
169 def __iter__(self) -> Iterator[DimensionRecord]:
170 return iter(self._by_required_values.values())
172 def __eq__(self, other: object) -> bool:
173 if not isinstance(other, DimensionRecordSet):
174 return False
175 return (
176 self._record_type is other._record_type
177 and self._by_required_values.keys() == other._by_required_values.keys()
178 )
180 def __repr__(self) -> str:
181 lines = [f"DimensionRecordSet({self.element.name}, {{"]
182 for record in self:
183 lines.append(f" {record!r},")
184 lines.append("})")
185 return "\n".join(lines)
187 def issubset(self, other: DimensionRecordSet) -> bool:
188 """Test whether all elements in ``self`` are in ``other``.
190 Parameters
191 ----------
192 other : `DimensionRecordSet`
193 Another record set with the same record type.
195 Returns
196 -------
197 issubset ; `bool`
198 Whether all elements in ``self`` are in ``other``.
199 """
200 if self._record_type is not other._record_type:
201 raise ValueError(
202 "Invalid comparison between dimension record sets for elements "
203 f"{self.element.name!r} and {other.element.name!r}."
204 )
205 return self._by_required_values.keys() <= other._by_required_values.keys()
207 def issuperset(self, other: DimensionRecordSet) -> bool:
208 """Test whether all elements in ``other`` are in ``self``.
210 Parameters
211 ----------
212 other : `DimensionRecordSet`
213 Another record set with the same record type.
215 Returns
216 -------
217 issuperset ; `bool`
218 Whether all elements in ``other`` are in ``self``.
219 """
220 if self._record_type is not other._record_type:
221 raise ValueError(
222 "Invalid comparison between dimension record sets for elements "
223 f"{self.element.name!r} and {other.element.name!r}."
224 )
225 return self._by_required_values.keys() >= other._by_required_values.keys()
227 def isdisjoint(self, other: DimensionRecordSet) -> bool:
228 """Test whether the intersection of ``self`` and ``other`` is empty.
230 Parameters
231 ----------
232 other : `DimensionRecordSet`
233 Another record set with the same record type.
235 Returns
236 -------
237 isdisjoint ; `bool`
238 Whether the intersection of ``self`` and ``other`` is empty.
239 """
240 if self._record_type is not other._record_type:
241 raise ValueError(
242 "Invalid comparison between dimension record sets for elements "
243 f"{self.element.name!r} and {other.element.name!r}."
244 )
245 return self._by_required_values.keys().isdisjoint(other._by_required_values.keys())
247 def intersection(self, other: DimensionRecordSet) -> DimensionRecordSet:
248 """Return a new set with only records that are in both ``self`` and
249 ``other``.
251 Parameters
252 ----------
253 other : `DimensionRecordSet`
254 Another record set with the same record type.
256 Returns
257 -------
258 intersection : `DimensionRecordSet`
259 A new record set with all elements in both sets.
260 """
261 if self._record_type is not other._record_type:
262 raise ValueError(
263 "Invalid intersection between dimension record sets for elements "
264 f"{self.element.name!r} and {other.element.name!r}."
265 )
266 return DimensionRecordSet(
267 self.element,
268 _by_required_values={
269 k: v for k, v in self._by_required_values.items() if k in other._by_required_values
270 },
271 )
273 def difference(self, other: DimensionRecordSet) -> DimensionRecordSet:
274 """Return a new set with only records that are in ``self`` and not in
275 ``other``.
277 Parameters
278 ----------
279 other : `DimensionRecordSet`
280 Another record set with the same record type.
282 Returns
283 -------
284 difference : `DimensionRecordSet`
285 A new record set with all elements ``self`` that are not in
286 ``other``.
287 """
288 if self._record_type is not other._record_type:
289 raise ValueError(
290 "Invalid difference between dimension record sets for elements "
291 f"{self.element.name!r} and {other.element.name!r}."
292 )
293 return DimensionRecordSet(
294 self.element,
295 _by_required_values={
296 k: v for k, v in self._by_required_values.items() if k not in other._by_required_values
297 },
298 )
300 def union(self, other: DimensionRecordSet) -> DimensionRecordSet:
301 """Return a new set with all records that are either in ``self`` or
302 ``other``.
304 Parameters
305 ----------
306 other : `DimensionRecordSet`
307 Another record set with the same record type.
309 Returns
310 -------
311 intersection : `DimensionRecordSet`
312 A new record set with all elements in either set.
313 """
314 if self._record_type is not other._record_type:
315 raise ValueError(
316 "Invalid union between dimension record sets for elements "
317 f"{self.element.name!r} and {other.element.name!r}."
318 )
319 return DimensionRecordSet(
320 self.element,
321 _by_required_values=self._by_required_values | other._by_required_values,
322 )
324 def find(
325 self,
326 data_id: DataCoordinate,
327 or_add: DimensionRecordFactory = fail_record_lookup,
328 ) -> DimensionRecord:
329 """Return the record with the given data ID.
331 Parameters
332 ----------
333 data_id : `DataCoordinate`
334 Data ID to match.
335 or_add : `DimensionRecordFactory`
336 Callback that is invoked if no existing record is found, to create
337 a new record that is added to the set and returned. The return
338 value of this callback is *not* checked to see if it is a valid
339 dimension record with the right element and data ID.
341 Returns
342 -------
343 record : `DimensionRecord`
344 Matching record.
346 Raises
347 ------
348 KeyError
349 Raised if no record with this data ID was found.
350 ValueError
351 Raised if the data ID did not have the right dimensions.
352 """
353 if data_id.dimensions != self._dimensions:
354 raise ValueError(
355 f"data ID {data_id} has incorrect dimensions for dimension records for {self.element!r}."
356 )
357 return self.find_with_required_values(data_id.required_values, or_add)
359 def find_with_required_values(
360 self, required_values: tuple[DataIdValue, ...], or_add: DimensionRecordFactory = fail_record_lookup
361 ) -> DimensionRecord:
362 """Return the record whose data ID has the given required values.
364 Parameters
365 ----------
366 required_values : `tuple` [ `int` or `str` ]
367 Data ID values to match.
368 or_add : `DimensionRecordFactory`
369 Callback that is invoked if no existing record is found, to create
370 a new record that is added to the set and returned. The return
371 value of this callback is *not* checked to see if it is a valid
372 dimension record with the right element and data ID.
374 Returns
375 -------
376 record : `DimensionRecord`
377 Matching record.
379 Raises
380 ------
381 ValueError
382 Raised if the data ID did not have the right dimensions.
383 """
384 if (result := self._by_required_values.get(required_values)) is None:
385 result = or_add(self._record_type, required_values)
386 self._by_required_values[required_values] = result
387 return result
389 def add(self, value: DimensionRecord, replace: bool = True) -> None:
390 """Add a new record to the set.
392 Parameters
393 ----------
394 value : `DimensionRecord`
395 Record to add.
396 replace : `bool`, optional
397 If `True` (default) replace any existing record with the same data
398 ID. If `False` the existing record will be kept.
400 Raises
401 ------
402 ValueError
403 Raised if ``value.element != self.element``.
404 """
405 if value.definition.name != self.element:
406 raise ValueError(
407 f"Cannot add record {value} for {value.definition.name!r} to set for {self.element!r}."
408 )
409 if replace:
410 self._by_required_values[value.dataId.required_values] = value
411 else:
412 self._by_required_values.setdefault(value.dataId.required_values, value)
414 def update(self, values: Iterable[DimensionRecord], replace: bool = True) -> None:
415 """Add new records to the set.
417 Parameters
418 ----------
419 values : `~collections.abc.Iterable` [ `DimensionRecord` ]
420 Records to add.
421 replace : `bool`, optional
422 If `True` (default) replace any existing records with the same data
423 IDs. If `False` the existing records will be kept.
425 Raises
426 ------
427 ValueError
428 Raised if ``value.element != self.element``.
429 """
430 for value in values:
431 self.add(value, replace=replace)
433 def update_from_data_coordinates(self, data_coordinates: Iterable[DataCoordinate]) -> None:
434 """Add records to the set by extracting and deduplicating them from
435 data coordinates.
437 Parameters
438 ----------
439 data_coordinates : `~collections.abc.Iterable` [ `DataCoordinate` ]
440 Data coordinates to extract from. `DataCoordinate.hasRecords` must
441 be `True`.
442 """
443 for data_coordinate in data_coordinates:
444 if record := data_coordinate._record(self.element.name):
445 self._by_required_values[record.dataId.required_values] = record
447 def discard(self, value: DimensionRecord | DataCoordinate) -> None:
448 """Remove a record if it exists.
450 Parameters
451 ----------
452 value : `DimensionRecord` or `DataCoordinate`
453 Record to remove, or its data ID.
454 """
455 if isinstance(value, DimensionRecord):
456 value = value.dataId
457 if value.dimensions != self._dimensions:
458 raise ValueError(f"{value} has incorrect dimensions for dimension records for {self.element!r}.")
459 self._by_required_values.pop(value.required_values, None)
461 def remove(self, value: DimensionRecord | DataCoordinate) -> None:
462 """Remove a record.
464 Parameters
465 ----------
466 value : `DimensionRecord` or `DataCoordinate`
467 Record to remove, or its data ID.
469 Raises
470 ------
471 KeyError
472 Raised if there is no matching record.
473 """
474 if isinstance(value, DimensionRecord):
475 value = value.dataId
476 if value.dimensions != self._dimensions:
477 raise ValueError(f"{value} has incorrect dimensions for dimension records for {self.element!r}.")
478 del self._by_required_values[value.required_values]
480 def pop(self) -> DimensionRecord:
481 """Remove and return an arbitrary record."""
482 return self._by_required_values.popitem()[1]
484 def __deepcopy__(self, memo: dict[str, Any]) -> DimensionRecordSet:
485 return DimensionRecordSet(self.element, _by_required_values=self._by_required_values.copy())