Coverage for tests / test_dimensions.py: 11%
493 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-06 08:30 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-06 08:30 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28import copy
29import itertools
30import math
31import os
32import pickle
33import unittest
34from collections.abc import Iterator
35from dataclasses import dataclass
36from random import Random
38import pydantic
40import lsst.sphgeom
41from lsst.daf.butler import (
42 Config,
43 DataCoordinate,
44 DataCoordinateSequence,
45 DataCoordinateSet,
46 DatasetType,
47 Dimension,
48 DimensionConfig,
49 DimensionElement,
50 DimensionGroup,
51 DimensionNameError,
52 DimensionPacker,
53 DimensionUniverse,
54 NamedKeyDict,
55 NamedValueSet,
56 ddl,
57)
58from lsst.daf.butler.tests.utils import create_populated_sqlite_registry
59from lsst.daf.butler.timespan_database_representation import TimespanDatabaseRepresentation
# Path to the YAML export of dimension records (an HSC RC2 subset) shipped
# with the test suite; it is relative to this test module's directory.
_DATA_SUBPATH = os.path.join("data", "registry", "hsc-rc2-subset.yaml")
DIMENSION_DATA_FILE = os.path.normpath(os.path.join(os.path.dirname(__file__), _DATA_SUBPATH))
def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all (expanded) data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    with create_populated_sqlite_registry(DIMENSION_DATA_FILE) as butler:
        dimensions = butler.registry.dimensions.conform(["visit", "detector", "tract", "patch"])
        return butler.registry.queryDataIds(dimensions).expanded().toSequence()
class ConcreteTestDimensionPacker(DimensionPacker):
    """A minimal concrete `DimensionPacker` for exercising the base class.

    Packing is trivial: the packed integer is just the detector ID.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGroup):
        super().__init__(fixed, dimensions)
        n_detectors = fixed.records["instrument"].detector_max
        self._n_detectors = n_detectors
        # Number of bits needed to represent the largest possible detector ID.
        self._max_bits = (n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        values = {
            "instrument": self.fixed["instrument"],
            "detector": packedId,
        }
        return DataCoordinate.standardize(values, dimensions=self._dimensions)
class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        # A fresh default universe for every test.
        self.universe = DimensionUniverse()

    def checkGroupInvariants(self, group: DimensionGroup):
        """Assert self-consistency invariants that any `DimensionGroup` from
        ``self.universe`` should satisfy.
        """
        elements = list(group.elements)
        for n, element_name in enumerate(elements):
            element = self.universe[element_name]
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.minimal_group, group)
            # Ordered comparisons on elements correspond to the ordering within
            # a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other_name in elements[:n]:
                other = self.universe[other_name]
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other_name in elements[n + 1 :]:
                other = self.universe[other_name]
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.minimal_group.required, element.required)
        # Conforming the required dimensions alone must reproduce the group.
        self.assertEqual(self.universe.conform(group.required), group)
        # A dimension is required exactly when no element of the group
        # implies it.
        self.assertCountEqual(
            group.required,
            [
                dimension_name
                for dimension_name in group.names
                if not any(
                    dimension_name in self.universe[other_name].minimal_group.implied
                    for other_name in group.elements
                )
            ],
        )
        self.assertCountEqual(group.implied, group.names - group.required)
        self.assertCountEqual(group.names, itertools.chain(group.required, group.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and element that is implied by any other in the graph
        # follow at least one of those.
        seen: set[str] = set()
        for element_name in group.lookup_order:
            element = self.universe[element_name]
            with self.subTest(
                required=repr(group.required), implied=repr(group.implied), element=repr(element)
            ):
                seen.add(element_name)
                self.assertLessEqual(element.minimal_group.required, seen)
                if element_name in group.implied:
                    self.assertTrue(any(element_name in self.universe[s].implied for s in seen))
        self.assertCountEqual(seen, group.elements)

    def testConfigPresent(self):
        """Test that the universe exposes the configuration it was built
        from.
        """
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def test_group_union(self) -> None:
        """Test unions of DimensionGroups."""
        a = self.universe.conform(["visit"])
        b = self.universe.conform(["detector"])
        self.assertIs(a.union(b), self.universe.conform(["visit", "detector"]))
        self.assertIs(DimensionGroup.union(a, b), self.universe.conform(["visit", "detector"]))
        self.assertIs(DimensionGroup.union(universe=self.universe), self.universe.empty)
        # With no operands and no universe there is nothing to build from.
        with self.assertRaises(TypeError):
            DimensionGroup.union()

    def test_group_set_ops(self) -> None:
        """Test set operations on DimensionGroups."""
        a = self.universe.conform(["visit"])
        b = self.universe.conform(["detector"])
        c = self.universe.conform(["visit", "detector"])
        d = self.universe.conform(["physical_filter"])
        e = self.universe.conform(["detector", "physical_filter"])
        self.assertEqual(a.union(b), c)
        self.assertEqual(a.union(d), a)
        self.assertEqual(b.union(d), e)
        self.assertEqual(a.difference(b), a)
        self.assertEqual(a.difference(c), self.universe.empty)
        self.assertEqual(a.difference(d), a)
        self.assertEqual(c.difference(a), b)
        self.assertEqual(a.intersection(c), a)
        self.assertEqual(a.intersection(d), d)
        self.assertEqual(a.intersection(e), d)

    def testCompatibility(self):
        """Test `DimensionUniverse.isCompatibleWith` for identical, slightly
        different, and completely incompatible universes.
        """
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
            self.assertIn("differing versions", "\n".join(cm.output))

        # Create completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        """Test the namespace and version reported by the default universe."""
        self.assertEqual(self.universe.namespace, "daf_butler")
        # Test was added starting at version 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        """Test that the expected static dimensions are read from the default
        configuration.
        """
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "group",
                "day_obs",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
                "ssp_hypothesis_table",
                "ssp_hypothesis_bundle",
                "ssp_balanced_index",
            }
            | {f"htm{level}" for level in range(1, 25)}
            | {f"healpix{level}" for level in range(1, 18)},
        )

    def testGraphs(self):
        """Check group invariants for the empty group and every static
        element's minimal group.
        """
        self.checkGroupInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGroupInvariants(element.minimal_group)

    def testInstrumentDimensions(self):
        """Test the instrument-related dimension group built from exposure,
        detector, and visit.
        """
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band", "group", "day_obs"),
        )
        self.assertCountEqual(group.required, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band", "group", "day_obs"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.governors, {"instrument"})
        for element in group.elements:
            # "band" has no table of its own; its values come from the
            # physical_filter table (implied_union_target).
            self.assertEqual(self.universe[element].has_own_table, element != "band", element)
            self.assertEqual(
                self.universe[element].implied_union_target,
                "physical_filter" if element == "band" else None,
                element,
            )
            self.assertEqual(
                self.universe[element].defines_relationships,
                element
                in ("visit", "exposure", "physical_filter", "visit_definition", "visit_system_membership"),
                element,
            )

    def testCalibrationDimensions(self):
        """Test the dimension group used for calibration datasets."""
        group = self.universe.conform(["physical_filter", "detector"])
        self.assertCountEqual(group.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(group.required, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(group.implied, ("band",))
        self.assertCountEqual(group.elements, group.names)
        self.assertCountEqual(group.governors, {"instrument"})
        # This group carries no spatial or temporal information.
        self.assertIsNone(group.region_dimension)
        self.assertIsNone(group.timespan_dimension)

    def testObservationDimensions(self):
        """Test spatial/temporal attributes and populated_by relationships
        for observation dimensions.
        """
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band", "group", "day_obs"),
        )
        self.assertCountEqual(group.required, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band", "group", "day_obs"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.spatial.names, ("observation_regions",))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))
        self.assertCountEqual(group.governors, {"instrument"})
        self.assertEqual(group.region_dimension, "visit_detector_region")
        self.assertEqual(group.timespan_dimension, "exposure")
        self.assertEqual(group.spatial.names, {"observation_regions"})
        self.assertEqual(group.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(group.temporal)).governor, self.universe["instrument"])
        self.assertEqual(self.universe["visit_definition"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_system_membership"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_detector_region"].populated_by, self.universe["visit"])
        self.assertEqual(
            self.universe.get_elements_populated_by(self.universe["visit"]),
            NamedValueSet(
                {
                    self.universe["visit"],
                    self.universe["visit_definition"],
                    self.universe["visit_system_membership"],
                    self.universe["visit_detector_region"],
                }
            ),
        )

    def testSkyMapDimensions(self):
        """Test the skymap dimension group built from patch."""
        group = self.universe.conform(["patch"])
        self.assertEqual(group.names, {"skymap", "tract", "patch"})
        self.assertEqual(group.required, {"skymap", "tract", "patch"})
        self.assertEqual(group.implied, set())
        self.assertEqual(group.elements, group.names)
        self.assertEqual(group.governors, {"skymap"})
        self.assertEqual(group.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        group = self.universe.conform(["visit", "detector", "tract", "patch", "htm7", "exposure"])
        self.assertCountEqual(group.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))
        self.assertEqual(group.timespan_dimension, "exposure")
        # Can not choose between visit_detector_region or htm7 or tract/patch.
        self.assertIsNone(group.region_dimension)

    def testSchemaGeneration(self):
        """Test that table specs generated for dimension elements have the
        expected fields and foreign keys.
        """
        tableSpecs: NamedKeyDict[DimensionElement, ddl.TableSpec] = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.has_own_table:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            # Every required dependency contributes a non-nullable primary-key
            # field whose dtype/length/nbytes match the dependency's own key.
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        # The element's own key field is named after its
                        # primary key, not after the element.
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            # Implied dependencies are regular (non-key) fields.
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            # Foreign keys must reference other generated tables with matching
            # field definitions on both sides.
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.dimensions)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target, strict=True):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object within
        # a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            group1 = element1.minimal_group
            group2 = pickle.loads(pickle.dumps(group1))
            self.assertIs(group1, group2)

    def testSerialization(self):
        """Test that dataset types round-trip correctly through
        serialization.
        """
        # Check that dataset types round-trip correctly through serialization.
        dataset_type = DatasetType(
            "flat",
            dimensions=["instrument", "detector", "physical_filter", "band"],
            isCalibration=True,
            universe=self.universe,
            storageClass="int",
        )
        roundtripped = DatasetType.from_simple(dataset_type.to_simple(), universe=self.universe)

        self.assertEqual(dataset_type, roundtripped)
@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.dimensions.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `True` for this attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator[DataCoordinate]:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``nth``
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        # A one-element slice (rather than direct indexing) keeps the "yield
        # from" below uniform for both the all-IDs and nth-ID cases.
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    # Fixed seed so the pseudo-random data ID selections are reproducible.
    RANDOM_SEED = 10
    @classmethod
    def setUpClass(cls):
        # Load the dimension data once for the whole test case; the tests
        # only read from it.
        cls.allDataIds = loadDimensionData()
    def setUp(self):
        # Re-seed for every test so each sees the same random sequence.
        self.rng = Random(self.RANDOM_SEED)
530 def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
531 """Select random data IDs from those loaded from test data.
533 Parameters
534 ----------
535 n : `int`
536 Number of data IDs to select.
537 dataIds : `DataCoordinateSequence`, optional
538 Data IDs to select from. Defaults to ``self.allDataIds``.
540 Returns
541 -------
542 selected : `DataCoordinateSequence`
543 ``n`` Data IDs randomly selected from ``dataIds`` with replacement.
544 """
545 if dataIds is None:
546 dataIds = self.allDataIds
547 return DataCoordinateSequence(
548 self.rng.sample(dataIds, n),
549 dimensions=dataIds.dimensions,
550 hasFull=dataIds.hasFull(),
551 hasRecords=dataIds.hasRecords(),
552 check=False,
553 )
555 def randomDimensionSubset(self, n: int = 3, group: DimensionGroup | None = None) -> DimensionGroup:
556 """Generate a random `DimensionGroup` that has a subset of the
557 dimensions in a given one.
559 Parameters
560 ----------
561 n : `int`
562 Number of dimensions to select, before automatic expansion by
563 `DimensionGroup`.
564 group : `DimensionGroup`, optional
565 Dimensions to select from. Defaults to
566 ``self.allDataIds.dimensions``.
568 Returns
569 -------
570 selected : `DimensionGroup`
571 ``n`` or more dimensions randomly selected from ``group`` with
572 replacement.
573 """
574 if group is None:
575 group = self.allDataIds.dimensions
576 return group.universe.conform(self.rng.sample(list(group.names), max(n, len(group))))
    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional.
            Data IDs to start from.  Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both return
            `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDS that only contain values for
            required dimensions, for which ``hasFull()`` may not return `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            # Re-standardizing from the full key-value mapping drops the
            # records but keeps values for all dimensions.
            result.complete = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.mapping, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            # Re-standardizing from only the required values drops implied
            # dimension values as well.
            result.minimal = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.required, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.dimensions.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result
    def testMappingViews(self):
        """Test that the ``mapping`` and ``required`` attributes in
        `DataCoordinate` are self-consistent and consistent with the
        ``dimensions`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            # ``required`` is available in every state.
            for dataId in split.chain():
                with self.subTest(dataId=repr(dataId)):
                    self.assertEqual(dataId.required.keys(), dataId.dimensions.required)
                    self.assertEqual(
                        list(dataId.required.values()), [dataId[d] for d in dataId.dimensions.required]
                    )
                    self.assertEqual(
                        list(dataId.required_values), [dataId[d] for d in dataId.dimensions.required]
                    )
                    self.assertEqual(dataId.required.keys(), dataId.dimensions.required)
            # ``mapping`` covers all dimensions, so it is only checked for
            # data IDs that have full values.
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=repr(dataId)):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.dimensions.names, dataId.mapping.keys())
                    self.assertEqual(
                        list(dataId.mapping.values()), [dataId[k] for k in dataId.mapping.keys()]
                    )
    def test_pickle(self):
        """Test that data IDs round-trip through pickle with their state
        flags (full/records) preserved.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id: DataCoordinate = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.mapping, read_data_id.mapping)
                if data_id.hasRecords():
                    for element_name in data_id.dimensions.elements:
                        self.assertEqual(
                            data_id.records[element_name], read_data_id.records[element_name]
                        )
    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            # Expanded data IDs expose each element's record as an attribute.
            for data_id in split.expanded:
                for element_name in data_id.dimensions.elements:
                    self.assertIs(getattr(data_id, element_name), data_id.records[element_name])
                    self.assertIn(element_name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
            # Non-expanded data IDs raise a more informative error for
            # element names, and the plain AttributeError otherwise.
            for data_id in itertools.chain(split.minimal, split.complete):
                for element_name in data_id.dimensions.elements:
                    with self.assertRaisesRegex(AttributeError, "only available on expanded DataCoordinates"):
                        getattr(data_id, element_name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
706 def testEquality(self):
707 """Test that different `DataCoordinate` instances with different state
708 flags can be compared with each other and other mappings.
709 """
710 dataIds = self.randomDataIds(n=2)
711 split = self.splitByStateFlags(dataIds)
712 # Iterate over all combinations of different states of DataCoordinate,
713 # with the same underlying data ID values.
714 for a0, b0 in itertools.combinations(split.chain(0), 2):
715 self.assertEqual(a0, b0)
716 # Same thing, for a different data ID value.
717 for a1, b1 in itertools.combinations(split.chain(1), 2):
718 self.assertEqual(a1, b1)
719 # Iterate over all combinations of different states of DataCoordinate,
720 # with different underlying data ID values.
721 for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
722 self.assertNotEqual(a0, b1)
723 self.assertNotEqual(a1, b0)
    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(
                    dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions, htm7=12)
                )
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, group=dataId.dimensions)
                if dataId.hasFull() or dataId.dimensions.required >= newDimensions.required:
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=repr(newDataId), type=str(type(dataId))):
                            commonKeys = dataId.dimensions.required & newDataId.dimensions.required
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via several
            # different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(self.rng.sample(list(dataId.dimensions.names), len(dataId.dimensions) // 2))
                keys2 = dataId.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dataId.mapping, dimensions=dataId.dimensions),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe), **dataId.mapping
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe),
                        dimensions=dataId.dimensions,
                        **dataId.mapping,
                    ),
                    DataCoordinate.standardize(**dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dimensions=dataId.dimensions, **dataId.mapping),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        dimensions=dataId.dimensions,
                        **{k: dataId[k] for k in keys2},
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=repr(dataId), newDataId=repr(newDataId), type=str(type(dataId))):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

        # Subsetting to dimensions whose values are not present must raise.
        coord = DataCoordinate.standardize({"instrument": "HSC"}, universe=self.allDataIds.universe)
        with self.assertRaises(DimensionNameError):
            coord.subset(["detector"])
    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test groups to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        groups = [self.randomDimensionSubset(n=2) for i in range(2)]
        groups.append(self.allDataIds.universe["visit"].minimal_group)
        groups.append(self.allDataIds.universe["detector"].minimal_group)
        groups.append(self.allDataIds.universe["physical_filter"].minimal_group)
        groups.append(self.allDataIds.universe["band"].minimal_group)
        # Iterate over all combinations, including the same graph with itself.
        for group1, group2 in itertools.product(groups, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(group1))
            split2 = self.splitByStateFlags(parentDataIds.subset(group2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=repr(lhs), rhs=repr(rhs), unioned=repr(unioned)):
                    # Both operands come from the same parent data ID, so the
                    # union must agree with the parent's subset.
                    self.assertEqual(unioned.dimensions, group1.union(group2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.dimensions))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.dimensions), lhs)
                        self.assertEqual(unioned.subset(rhs.dimensions), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    # If one operand covers all of the union's dimensions, its
                    # state flags carry over to the union.
                    if lhs.dimensions >= unioned.dimensions and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.dimensions >= unioned.dimensions and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.dimensions.required | rhs.dimensions.required >= unioned.dimensions.names:
                        self.assertTrue(unioned.hasFull())
                    if (
                        lhs.hasRecords()
                        and rhs.hasRecords()
                        and lhs.dimensions.elements | rhs.dimensions.elements >= unioned.dimensions.elements
                    ):
                        self.assertTrue(unioned.hasRecords())
852 def testRegions(self):
853 """Test that data IDs for a few known dimensions have the expected
854 regions.
855 """
856 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
857 self.assertIsNotNone(dataId.region)
858 self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
859 self.assertEqual(dataId.region, dataId.records["visit"].region)
860 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit", "detector"])):
861 self.assertIsNotNone(dataId.region)
862 self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
863 self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
864 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["tract"])):
865 self.assertIsNotNone(dataId.region)
866 self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
867 self.assertEqual(dataId.region, dataId.records["tract"].region)
868 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
869 self.assertIsNotNone(dataId.region)
870 self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
871 self.assertEqual(dataId.region, dataId.records["patch"].region)
872 for data_id in self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["visit", "tract"])):
873 self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
874 self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)
876 def testTimespans(self):
877 """Test that data IDs for a few known dimensions have the expected
878 timespans.
879 """
880 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
881 self.assertIsNotNone(dataId.timespan)
882 self.assertEqual(dataId.dimensions.temporal.names, {"observation_timespans"})
883 self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
884 self.assertEqual(dataId.timespan, dataId.visit.timespan)
885 # Also test the case for non-temporal DataIds.
886 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
887 self.assertIsNone(dataId.timespan)
889 def testIterableStatusFlags(self):
890 """Test that DataCoordinateSet and DataCoordinateSequence compute
891 their hasFull and hasRecords flags correctly from their elements.
892 """
893 dataIds = self.randomDataIds(n=10)
894 split = self.splitByStateFlags(dataIds)
895 for cls in (DataCoordinateSet, DataCoordinateSequence):
896 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasFull())
897 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasFull())
898 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasRecords())
899 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasRecords())
900 self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasFull())
901 self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasFull())
902 self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasRecords())
903 self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasRecords())
904 with self.assertRaises(ValueError):
905 cls(split.complete, dimensions=dataIds.dimensions, hasRecords=True, check=True)
906 self.assertEqual(
907 cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasFull(),
908 not dataIds.dimensions.implied,
909 )
910 self.assertEqual(
911 cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasFull(),
912 not dataIds.dimensions.implied,
913 )
914 self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasRecords())
915 self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasRecords())
916 with self.assertRaises(ValueError):
917 cls(split.minimal, dimensions=dataIds.dimensions, hasRecords=True, check=True)
918 if dataIds.dimensions.implied:
919 with self.assertRaises(ValueError):
920 cls(split.minimal, dimensions=dataIds.dimensions, hasFull=True, check=True)
922 def testSetOperations(self):
923 """Test for self-consistency across DataCoordinateSet's operations."""
924 c = self.randomDataIds(n=10).toSet()
925 a = self.randomDataIds(n=20).toSet() | c
926 b = self.randomDataIds(n=20).toSet() | c
927 # Make sure we don't have a particularly unlucky random seed, since
928 # that would make a lot of this test uninteresting.
929 self.assertNotEqual(a, b)
930 self.assertGreater(len(a), 0)
931 self.assertGreater(len(b), 0)
932 # The rest of the tests should not depend on the random seed.
933 self.assertEqual(a, a)
934 self.assertNotEqual(a, a.toSequence())
935 self.assertEqual(a, a.toSequence().toSet())
936 self.assertEqual(a, a.toSequence().toSet())
937 self.assertEqual(b, b)
938 self.assertNotEqual(b, b.toSequence())
939 self.assertEqual(b, b.toSequence().toSet())
940 self.assertEqual(a & b, a.intersection(b))
941 self.assertLessEqual(a & b, a)
942 self.assertLessEqual(a & b, b)
943 self.assertEqual(a | b, a.union(b))
944 self.assertGreaterEqual(a | b, a)
945 self.assertGreaterEqual(a | b, b)
946 self.assertEqual(a - b, a.difference(b))
947 self.assertLessEqual(a - b, a)
948 self.assertLessEqual(b - a, b)
949 self.assertEqual(a ^ b, a.symmetric_difference(b))
950 self.assertGreaterEqual(a ^ b, (a | b) - (a & b))
952 def testPackers(self):
953 (instrument_data_id,) = self.allDataIds.subset(
954 self.allDataIds.universe.conform(["instrument"])
955 ).toSet()
956 (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["detector"]))
957 packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.dimensions)
958 packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
959 self.assertEqual(packed_id, detector_data_id["detector"])
960 self.assertEqual(max_bits, packer.maxBits)
961 self.assertEqual(
962 max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
963 )
964 self.assertEqual(packer.pack(detector_data_id), packed_id)
965 self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
966 self.assertEqual(packer.unpack(packed_id), detector_data_id)
968 def test_dimension_group_pydantic(self):
969 """Test that DimensionGroup round-trips through Pydantic as long as
970 it's given the universe when validated.
971 """
972 dimensions = self.allDataIds.dimensions
973 adapter = pydantic.TypeAdapter(DimensionGroup)
974 json_str = adapter.dump_json(dimensions)
975 python_data = adapter.dump_python(dimensions)
976 self.assertEqual(
977 dimensions, adapter.validate_json(json_str, context=dict(universe=dimensions.universe))
978 )
979 self.assertEqual(
980 dimensions, adapter.validate_python(python_data, context=dict(universe=dimensions.universe))
981 )
982 self.assertEqual(dimensions, adapter.validate_python(dimensions))
984 def test_dimension_element_pydantic(self):
985 """Test that DimensionElement round-trips through Pydantic as long as
986 it's given the universe when validated.
987 """
988 element = self.allDataIds.universe["visit"]
989 adapter = pydantic.TypeAdapter(DimensionElement)
990 json_str = adapter.dump_json(element)
991 python_data = adapter.dump_python(element)
992 self.assertEqual(element, adapter.validate_json(json_str, context=dict(universe=element.universe)))
993 self.assertEqual(
994 element, adapter.validate_python(python_data, context=dict(universe=element.universe))
995 )
996 self.assertEqual(element, adapter.validate_python(element))
# Run the full test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()