Coverage for python/lsst/daf/butler/dimensions/_record_set.py: 23%
106 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-25 10:50 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-25 10:50 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("DimensionRecordSet", "DimensionRecordFactory")
32from collections.abc import Collection, Iterable, Iterator
33from typing import TYPE_CHECKING, Any, Protocol, final
35from ._coordinate import DataCoordinate, DataIdValue
36from ._records import DimensionRecord
38if TYPE_CHECKING:
39 from ._elements import DimensionElement
40 from ._universe import DimensionUniverse
class DimensionRecordFactory(Protocol):
    """Structural type for callbacks that build a missing dimension record.

    A callable of this shape can be handed to a `DimensionRecordSet` lookup
    so that, when no existing record matches, a new one is created and added
    to the set instead.
    """

    def __call__(
        self,
        record_class: type[DimensionRecord],
        required_values: tuple[DataIdValue, ...],
    ) -> DimensionRecord:
        """Construct a `DimensionRecord` for the given required values.

        Parameters
        ----------
        record_class : `type` [ `DimensionRecord` ]
            The concrete `DimensionRecord` subclass to instantiate.
        required_values : `tuple`
            Data ID values, ordered to match
            ``record_class.definition.required``.

        Returns
        -------
        record : `DimensionRecord`
            The newly created record.
        """
        ...  # pragma: no cover
def fail_record_lookup(
    record_class: type[DimensionRecord], required_values: tuple[DataIdValue, ...]
) -> DimensionRecord:
    """Signal that no `DimensionRecord` exists (or can be made) for a data ID.

    This function has the `DimensionRecordFactory` call signature so it can
    serve as the default for arguments that accept such a callback: rather
    than creating a record, it unconditionally raises.

    Parameters
    ----------
    record_class : `type` [ `DimensionRecord` ]
        Record type that was being looked up.
    required_values : `tuple`
        Required data ID values identifying the record that was not found.

    Returns
    -------
    record : `DimensionRecord`
        Never actually returned; the function always raises.

    Raises
    ------
    LookupError
        Always raised, with a message naming the element and the data ID.
    """
    definition = record_class.definition
    # Build a full data ID purely so the error message is readable.
    data_id = DataCoordinate.from_required_values(definition.minimal_group, required_values)
    raise LookupError(f"No {definition.name!r} record with data ID {data_id}.")
92@final
93class DimensionRecordSet(Collection[DimensionRecord]): # numpydoc ignore=PR01
94 """A mutable set-like container specialized for `DimensionRecord` objects.
96 Parameters
97 ----------
98 element : `DimensionElement` or `str`, optional
99 The dimension element that defines the records held by this set. If
100 not a `DimensionElement` instance, ``universe`` must be provided.
101 records : `~collections.abc.Iterable` [ `DimensionRecord` ], optional
102 Dimension records to add to the set.
103 universe : `DimensionUniverse`, optional
104 Object that defines all dimensions. Ignored if ``element`` is a
105 `DimensionElement` instance.
107 Notes
108 -----
109 `DimensionRecordSet` maintains its insertion order (like `dict`, and unlike
110 `set`).
112 `DimensionRecordSet` implements `collections.abc.Collection` but not
113 `collections.abc.Set` because the latter would require interoperability
114 with all other `~collections.abc.Set` implementations rather than just
115 `DimensionRecordSet`, and that adds a lot of complexity without much clear
116 value. To help make this clear to type checkers it implements only the
117 named-method versions of these operations (e.g. `issubset`) rather than the
118 operator special methods (e.g. ``__le__``).
120 `DimensionRecord` equality is defined in terms of a record's data ID fields
121 only, and `DimensionRecordSet` does not generally specify which record
122 "wins" when two records with the same data ID interact (e.g. in
123 `intersection`). The `add` and `update` methods are notable exceptions:
124 they always replace the existing record with the new one.
126 Dimension records can also be held by `DimensionRecordTable`, which
127 provides column-oriented access and Arrow interoperability.
128 """
130 def __init__(
131 self,
132 element: DimensionElement | str,
133 records: Iterable[DimensionRecord] = (),
134 universe: DimensionUniverse | None = None,
135 *,
136 _by_required_values: dict[tuple[DataIdValue, ...], DimensionRecord] | None = None,
137 ):
138 if isinstance(element, str):
139 if universe is None:
140 raise TypeError("'universe' must be provided if 'element' is not a DimensionElement.")
141 element = universe[element]
142 else:
143 universe = element.universe
144 if _by_required_values is None:
145 _by_required_values = {}
146 self._record_type = element.RecordClass
147 self._by_required_values = _by_required_values
148 self._dimensions = element.minimal_group
149 self.update(records)
151 @property
152 def element(self) -> DimensionElement:
153 """Name of the dimension element these records correspond to."""
154 return self._record_type.definition
156 def __contains__(self, key: object) -> bool:
157 match key:
158 case DimensionRecord() if key.definition == self.element:
159 required_values = key.dataId.required_values
160 case DataCoordinate() if key.dimensions == self.element.minimal_group:
161 required_values = key.required_values
162 case _:
163 return False
164 return required_values in self._by_required_values
166 def __len__(self) -> int:
167 return len(self._by_required_values)
169 def __iter__(self) -> Iterator[DimensionRecord]:
170 return iter(self._by_required_values.values())
172 def __eq__(self, other: object) -> bool:
173 if not isinstance(other, DimensionRecordSet):
174 return False
175 return (
176 self._record_type is other._record_type
177 and self._by_required_values.keys() == other._by_required_values.keys()
178 )
180 def issubset(self, other: DimensionRecordSet) -> bool:
181 """Test whether all elements in ``self`` are in ``other``.
183 Parameters
184 ----------
185 other : `DimensionRecordSet`
186 Another record set with the same record type.
188 Returns
189 -------
190 issubset ; `bool`
191 Whether all elements in ``self`` are in ``other``.
192 """
193 if self._record_type is not other._record_type:
194 raise ValueError(
195 "Invalid comparison between dimension record sets for elements "
196 f"{self.element.name!r} and {other.element.name!r}."
197 )
198 return self._by_required_values.keys() <= other._by_required_values.keys()
200 def issuperset(self, other: DimensionRecordSet) -> bool:
201 """Test whether all elements in ``other`` are in ``self``.
203 Parameters
204 ----------
205 other : `DimensionRecordSet`
206 Another record set with the same record type.
208 Returns
209 -------
210 issuperset ; `bool`
211 Whether all elements in ``other`` are in ``self``.
212 """
213 if self._record_type is not other._record_type:
214 raise ValueError(
215 "Invalid comparison between dimension record sets for elements "
216 f"{self.element.name!r} and {other.element.name!r}."
217 )
218 return self._by_required_values.keys() >= other._by_required_values.keys()
220 def isdisjoint(self, other: DimensionRecordSet) -> bool:
221 """Test whether the intersection of ``self`` and ``other`` is empty.
223 Parameters
224 ----------
225 other : `DimensionRecordSet`
226 Another record set with the same record type.
228 Returns
229 -------
230 isdisjoint ; `bool`
231 Whether the intersection of ``self`` and ``other`` is empty.
232 """
233 if self._record_type is not other._record_type:
234 raise ValueError(
235 "Invalid comparison between dimension record sets for elements "
236 f"{self.element.name!r} and {other.element.name!r}."
237 )
238 return self._by_required_values.keys().isdisjoint(other._by_required_values.keys())
240 def intersection(self, other: DimensionRecordSet) -> DimensionRecordSet:
241 """Return a new set with only records that are in both ``self`` and
242 ``other``.
244 Parameters
245 ----------
246 other : `DimensionRecordSet`
247 Another record set with the same record type.
249 Returns
250 -------
251 intersection : `DimensionRecordSet`
252 A new record set with all elements in both sets.
253 """
254 if self._record_type is not other._record_type:
255 raise ValueError(
256 "Invalid intersection between dimension record sets for elements "
257 f"{self.element.name!r} and {other.element.name!r}."
258 )
259 return DimensionRecordSet(
260 self.element,
261 _by_required_values={
262 k: v for k, v in self._by_required_values.items() if k in other._by_required_values
263 },
264 )
266 def difference(self, other: DimensionRecordSet) -> DimensionRecordSet:
267 """Return a new set with only records that are in ``self`` and not in
268 ``other``.
270 Parameters
271 ----------
272 other : `DimensionRecordSet`
273 Another record set with the same record type.
275 Returns
276 -------
277 difference : `DimensionRecordSet`
278 A new record set with all elements ``self`` that are not in
279 ``other``.
280 """
281 if self._record_type is not other._record_type:
282 raise ValueError(
283 "Invalid difference between dimension record sets for elements "
284 f"{self.element.name!r} and {other.element.name!r}."
285 )
286 return DimensionRecordSet(
287 self.element,
288 _by_required_values={
289 k: v for k, v in self._by_required_values.items() if k not in other._by_required_values
290 },
291 )
293 def union(self, other: DimensionRecordSet) -> DimensionRecordSet:
294 """Return a new set with all records that are either in ``self`` or
295 ``other``.
297 Parameters
298 ----------
299 other : `DimensionRecordSet`
300 Another record set with the same record type.
302 Returns
303 -------
304 intersection : `DimensionRecordSet`
305 A new record set with all elements in either set.
306 """
307 if self._record_type is not other._record_type:
308 raise ValueError(
309 "Invalid union between dimension record sets for elements "
310 f"{self.element.name!r} and {other.element.name!r}."
311 )
312 return DimensionRecordSet(
313 self.element,
314 _by_required_values=self._by_required_values | other._by_required_values,
315 )
317 def find(
318 self,
319 data_id: DataCoordinate,
320 or_add: DimensionRecordFactory = fail_record_lookup,
321 ) -> DimensionRecord:
322 """Return the record with the given data ID.
324 Parameters
325 ----------
326 data_id : `DataCoordinate`
327 Data ID to match.
328 or_add : `DimensionRecordFactory`
329 Callback that is invoked if no existing record is found, to create
330 a new record that is added to the set and returned. The return
331 value of this callback is *not* checked to see if it is a valid
332 dimension record with the right element and data ID.
334 Returns
335 -------
336 record : `DimensionRecord`
337 Matching record.
339 Raises
340 ------
341 KeyError
342 Raised if no record with this data ID was found.
343 ValueError
344 Raised if the data ID did not have the right dimensions.
345 """
346 if data_id.dimensions != self._dimensions:
347 raise ValueError(
348 f"data ID {data_id} has incorrect dimensions for dimension records for {self.element!r}."
349 )
350 return self.find_with_required_values(data_id.required_values, or_add)
352 def find_with_required_values(
353 self, required_values: tuple[DataIdValue, ...], or_add: DimensionRecordFactory = fail_record_lookup
354 ) -> DimensionRecord:
355 """Return the record whose data ID has the given required values.
357 Parameters
358 ----------
359 required_values : `tuple` [ `int` or `str` ]
360 Data ID values to match.
361 or_add : `DimensionRecordFactory`
362 Callback that is invoked if no existing record is found, to create
363 a new record that is added to the set and returned. The return
364 value of this callback is *not* checked to see if it is a valid
365 dimension record with the right element and data ID.
367 Returns
368 -------
369 record : `DimensionRecord`
370 Matching record.
372 Raises
373 ------
374 ValueError
375 Raised if the data ID did not have the right dimensions.
376 """
377 if (result := self._by_required_values.get(required_values)) is None:
378 result = or_add(self._record_type, required_values)
379 self._by_required_values[required_values] = result
380 return result
382 def add(self, value: DimensionRecord, replace: bool = True) -> None:
383 """Add a new record to the set.
385 Parameters
386 ----------
387 value : `DimensionRecord`
388 Record to add.
389 replace : `bool`, optional
390 If `True` (default) replace any existing record with the same data
391 ID. If `False` the existing record will be kept.
393 Raises
394 ------
395 ValueError
396 Raised if ``value.element != self.element``.
397 """
398 if value.definition.name != self.element:
399 raise ValueError(
400 f"Cannot add record {value} for {value.definition.name!r} to set for {self.element!r}."
401 )
402 if replace:
403 self._by_required_values[value.dataId.required_values] = value
404 else:
405 self._by_required_values.setdefault(value.dataId.required_values, value)
407 def update(self, values: Iterable[DimensionRecord], replace: bool = True) -> None:
408 """Add new records to the set.
410 Parameters
411 ----------
412 values : `~collections.abc.Iterable` [ `DimensionRecord` ]
413 Records to add.
414 replace : `bool`, optional
415 If `True` (default) replace any existing records with the same data
416 IDs. If `False` the existing records will be kept.
418 Raises
419 ------
420 ValueError
421 Raised if ``value.element != self.element``.
422 """
423 for value in values:
424 self.add(value, replace=replace)
426 def update_from_data_coordinates(self, data_coordinates: Iterable[DataCoordinate]) -> None:
427 """Add records to the set by extracting and deduplicating them from
428 data coordinates.
430 Parameters
431 ----------
432 data_coordinates : `~collections.abc.Iterable` [ `DataCoordinate` ]
433 Data coordinates to extract from. `DataCoordinate.hasRecords` must
434 be `True`.
435 """
436 for data_coordinate in data_coordinates:
437 if record := data_coordinate._record(self.element.name):
438 self._by_required_values[record.dataId.required_values] = record
440 def discard(self, value: DimensionRecord | DataCoordinate) -> None:
441 """Remove a record if it exists.
443 Parameters
444 ----------
445 value : `DimensionRecord` or `DataCoordinate`
446 Record to remove, or its data ID.
447 """
448 if isinstance(value, DimensionRecord):
449 value = value.dataId
450 if value.dimensions != self._dimensions:
451 raise ValueError(f"{value} has incorrect dimensions for dimension records for {self.element!r}.")
452 self._by_required_values.pop(value.required_values, None)
454 def remove(self, value: DimensionRecord | DataCoordinate) -> None:
455 """Remove a record.
457 Parameters
458 ----------
459 value : `DimensionRecord` or `DataCoordinate`
460 Record to remove, or its data ID.
462 Raises
463 ------
464 KeyError
465 Raised if there is no matching record.
466 """
467 if isinstance(value, DimensionRecord):
468 value = value.dataId
469 if value.dimensions != self._dimensions:
470 raise ValueError(f"{value} has incorrect dimensions for dimension records for {self.element!r}.")
471 del self._by_required_values[value.required_values]
473 def pop(self) -> DimensionRecord:
474 """Remove and return an arbitrary record."""
475 return self._by_required_values.popitem()[1]
477 def __deepcopy__(self, memo: dict[str, Any]) -> DimensionRecordSet:
478 return DimensionRecordSet(self.element, _by_required_values=self._by_required_values.copy())