# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import math
import os
import pickle
import unittest
from collections.abc import Iterator
from dataclasses import dataclass
from random import Random

import pydantic

import lsst.sphgeom
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    DatasetType,
    Dimension,
    DimensionConfig,
    DimensionElement,
    DimensionGroup,
    DimensionNameError,
    DimensionPacker,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    ddl,
)
from lsst.daf.butler.tests.utils import create_populated_sqlite_registry
from lsst.daf.butler.timespan_database_representation import TimespanDatabaseRepresentation

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    with create_populated_sqlite_registry(DIMENSION_DATA_FILE) as butler:
        dimensions = butler.registry.dimensions.conform(["visit", "detector", "tract", "patch"])
        return butler.registry.queryDataIds(dimensions).expanded().toSequence()
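
# A minimal sketch of what loadDimensionData() yields (the specific values are
# hypothetical; they depend on the contents of hsc-rc2-subset.yaml):
#
#     dataIds = loadDimensionData()
#     dataIds[0].dimensions.required  # {"instrument", "visit", "detector", "skymap", "tract", "patch"}
#     dataIds[0]["visit"]             # e.g. an integer visit ID
#     dataIds[0].hasRecords()         # True, because the sequence is expanded()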


class ConcreteTestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class implementations.

    This class just returns the detector ID as-is.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGroup):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        return DataCoordinate.standardize(
            {
                "instrument": self.fixed["instrument"],
                "detector": packedId,
            },
            dimensions=self._dimensions,
        )
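
# Usage sketch (hedged): pack() and unpack() should round-trip a data ID.
# ``instrument_data_id`` must be an expanded data ID (the packer needs the
# instrument record for detector_max) and is hypothetical here:
#
#     packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.dimensions)
#     packed, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
#     assert packer.unpack(packed) == detector_data_id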


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGroupInvariants(self, group: DimensionGroup):
        elements = list(group.elements)
        for n, element_name in enumerate(elements):
            element = self.universe[element_name]
            # Ordered comparisons on groups behave like sets.
            self.assertLessEqual(element.minimal_group, group)
            # Ordered comparisons on elements correspond to the ordering within
            # a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other_name in elements[:n]:
                other = self.universe[other_name]
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other_name in elements[n + 1 :]:
                other = self.universe[other_name]
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.minimal_group.required, element.required)
        self.assertEqual(self.universe.conform(group.required), group)
        self.assertCountEqual(
            group.required,
            [
                dimension_name
                for dimension_name in group.names
                if not any(
                    dimension_name in self.universe[other_name].minimal_group.implied
                    for other_name in group.elements
                )
            ],
        )
        self.assertCountEqual(group.implied, group.names - group.required)
        self.assertCountEqual(group.names, itertools.chain(group.required, group.implied))
        # Check primary key traversal order: each element should follow any
        # elements it requires, and any element that is implied by another in
        # the group should follow at least one of the elements that imply it.
        seen: set[str] = set()
        for element_name in group.lookup_order:
            element = self.universe[element_name]
            with self.subTest(
                required=repr(group.required), implied=repr(group.implied), element=repr(element)
            ):
                seen.add(element_name)
                self.assertLessEqual(element.minimal_group.required, seen)
                if element_name in group.implied:
                    self.assertTrue(any(element_name in self.universe[s].implied for s in seen))
        self.assertCountEqual(seen, group.elements)

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def test_group_union(self) -> None:
        """Test unions of DimensionGroups."""
        a = self.universe.conform(["visit"])
        b = self.universe.conform(["detector"])
        self.assertIs(a.union(b), self.universe.conform(["visit", "detector"]))
        self.assertIs(DimensionGroup.union(a, b), self.universe.conform(["visit", "detector"]))
        self.assertIs(DimensionGroup.union(universe=self.universe), self.universe.empty)
        with self.assertRaises(TypeError):
            DimensionGroup.union()

    def test_group_set_ops(self) -> None:
        """Test set operations on DimensionGroups."""
        a = self.universe.conform(["visit"])
        b = self.universe.conform(["detector"])
        c = self.universe.conform(["visit", "detector"])
        d = self.universe.conform(["physical_filter"])
        e = self.universe.conform(["detector", "physical_filter"])
        self.assertEqual(a.union(b), c)
        self.assertEqual(a.union(d), a)
        self.assertEqual(b.union(d), e)
        self.assertEqual(a.difference(b), a)
        self.assertEqual(a.difference(c), self.universe.empty)
        self.assertEqual(a.difference(d), a)
        self.assertEqual(c.difference(a), b)
        self.assertEqual(a.intersection(c), a)
        self.assertEqual(a.intersection(d), d)
        self.assertEqual(a.intersection(e), d)
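
    # Why ``a.union(d) == a`` above: conform() expands its arguments to
    # include required and implied dependencies, so conform(["visit"]) already
    # contains "physical_filter" and "band" (implied by visit), and
    # conform(["physical_filter"]) adds nothing new.  A hedged illustration:
    #
    #     a = universe.conform(["visit"])
    #     "physical_filter" in a.names  # True (implied by visit)
    #     "band" in a.names             # True (implied by physical_filter)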

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
            self.assertIn("differing versions", "\n".join(cm.output))

        # Create completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # This test was added when the universe was at version 2, so the
        # current version can only be greater than or equal to that.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "group",
                "day_obs",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(1, 25)}
            | {f"healpix{level}" for level in range(1, 18)},
        )

    def testGraphs(self):
        self.checkGroupInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGroupInvariants(element.minimal_group)

    def testInstrumentDimensions(self):
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band", "group", "day_obs"),
        )
        self.assertCountEqual(group.required, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band", "group", "day_obs"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.governors, {"instrument"})
        for element in group.elements:
            self.assertEqual(self.universe[element].has_own_table, element != "band", element)
            self.assertEqual(
                self.universe[element].implied_union_target,
                "physical_filter" if element == "band" else None,
                element,
            )
            self.assertEqual(
                self.universe[element].defines_relationships,
                element
                in ("visit", "exposure", "physical_filter", "visit_definition", "visit_system_membership"),
                element,
            )

    def testCalibrationDimensions(self):
        group = self.universe.conform(["physical_filter", "detector"])
        self.assertCountEqual(group.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(group.required, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(group.implied, ("band",))
        self.assertCountEqual(group.elements, group.names)
        self.assertCountEqual(group.governors, {"instrument"})
        self.assertIsNone(group.region_dimension)
        self.assertIsNone(group.timespan_dimension)

    def testObservationDimensions(self):
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band", "group", "day_obs"),
        )
        self.assertCountEqual(group.required, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band", "group", "day_obs"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.spatial.names, ("observation_regions",))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))
        self.assertCountEqual(group.governors, {"instrument"})
        self.assertEqual(group.region_dimension, "visit_detector_region")
        self.assertEqual(group.timespan_dimension, "exposure")
        self.assertEqual(group.spatial.names, {"observation_regions"})
        self.assertEqual(group.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(group.temporal)).governor, self.universe["instrument"])
        self.assertEqual(self.universe["visit_definition"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_system_membership"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_detector_region"].populated_by, self.universe["visit"])
        self.assertEqual(
            self.universe.get_elements_populated_by(self.universe["visit"]),
            NamedValueSet(
                {
                    self.universe["visit"],
                    self.universe["visit_definition"],
                    self.universe["visit_system_membership"],
                    self.universe["visit_detector_region"],
                }
            ),
        )

    def testSkyMapDimensions(self):
        group = self.universe.conform(["patch"])
        self.assertEqual(group.names, {"skymap", "tract", "patch"})
        self.assertEqual(group.required, {"skymap", "tract", "patch"})
        self.assertEqual(group.implied, set())
        self.assertEqual(group.elements, group.names)
        self.assertEqual(group.governors, {"skymap"})
        self.assertEqual(group.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        group = self.universe.conform(["visit", "detector", "tract", "patch", "htm7", "exposure"])
        self.assertCountEqual(group.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))
        self.assertEqual(group.timespan_dimension, "exposure")
        # Cannot choose among visit_detector_region, htm7, and tract/patch.
        self.assertIsNone(group.region_dimension)

    def testSchemaGeneration(self):
        tableSpecs: NamedKeyDict[DimensionElement, ddl.TableSpec] = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.has_own_table:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.dimensions)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target, strict=True):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            group1 = element1.minimal_group
            group2 = pickle.loads(pickle.dumps(group1))
            self.assertIs(group1, group2)

    def testSerialization(self):
        # Check that dataset types round-trip correctly through serialization.
        dataset_type = DatasetType(
            "flat",
            dimensions=["instrument", "detector", "physical_filter", "band"],
            isCalibration=True,
            universe=self.universe,
            storageClass="int",
        )
        roundtripped = DatasetType.from_simple(dataset_type.to_simple(), universe=self.universe)
        self.assertEqual(dataset_type, roundtripped)


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.dimensions.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator[DataCoordinate]:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (default is `None`), iterate over only the ``n``th
            data ID in each attribute; otherwise iterate over all of them.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
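
# A hedged illustration of SplitByStateFlags.chain(): with all three
# attributes populated, chain() yields minimal, then complete, then expanded
# data IDs, and chain(0) yields just the first data ID from each attribute
# (three data IDs with the same values but different state flags):
#
#     split = SplitByStateFlags(minimal=m, complete=c, expanded=e)
#     len(list(split.chain()))   # len(m) + len(c) + len(e)
#     len(list(split.chain(0)))  # 3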


class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from. Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            dimensions=dataIds.dimensions,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, group: DimensionGroup | None = None) -> DimensionGroup:
        """Generate a random `DimensionGroup` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGroup`.
        group : `DimensionGroup`, optional
            Dimensions to select from. Defaults to
            ``self.allDataIds.dimensions``.

        Returns
        -------
        selected : `DimensionGroup`
            ``n`` or more dimensions randomly selected from ``group`` without
            replacement.
        """
        if group is None:
            group = self.allDataIds.dimensions
        return group.universe.conform(self.rng.sample(list(group.names), min(n, len(group))))

    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from. Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both return
            `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.mapping, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.required, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.dimensions.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result
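
    # A hedged sketch of what splitByStateFlags() produces for a group with
    # implied dimensions (the flags mirror the assertions above):
    #
    #     split = self.splitByStateFlags()
    #     split.expanded.hasRecords()  # True  (original, fully expanded)
    #     split.complete.hasFull()     # True  (values only, no records)
    #     split.minimal.hasFull()      # False when implied dimensions exist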

    def testMappingViews(self):
        """Test that the ``mapping`` and ``required`` attributes in
        `DataCoordinate` are self-consistent and consistent with the
        ``dimensions`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=repr(dataId)):
                    self.assertEqual(dataId.required.keys(), dataId.dimensions.required)
                    self.assertEqual(
                        list(dataId.required.values()), [dataId[d] for d in dataId.dimensions.required]
                    )
                    self.assertEqual(
                        list(dataId.required_values), [dataId[d] for d in dataId.dimensions.required]
                    )
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=repr(dataId)):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.dimensions.names, dataId.mapping.keys())
                    self.assertEqual(
                        list(dataId.mapping.values()), [dataId[k] for k in dataId.mapping.keys()]
                    )

    def test_pickle(self):
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id: DataCoordinate = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.mapping, read_data_id.mapping)
                if data_id.hasRecords():
                    for element_name in data_id.dimensions.elements:
                        self.assertEqual(
                            data_id.records[element_name], read_data_id.records[element_name]
                        )

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element_name in data_id.dimensions.elements:
                    self.assertIs(getattr(data_id, element_name), data_id.records[element_name])
                    self.assertIn(element_name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
            for data_id in itertools.chain(split.minimal, split.complete):
                for element_name in data_id.dimensions.elements:
                    with self.assertRaisesRegex(
                        AttributeError, "only available on expanded DataCoordinates"
                    ):
                        getattr(data_id, element_name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(b1, a0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(
                    dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions, htm7=12)
                )
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, group=dataId.dimensions)
                if dataId.hasFull() or dataId.dimensions.required >= newDimensions.required:
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=repr(newDataId), type=str(type(dataId))):
                            commonKeys = dataId.dimensions.required & newDataId.dimensions.required
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via several
            # different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(self.rng.sample(list(dataId.dimensions.names), len(dataId.dimensions) // 2))
                keys2 = dataId.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dataId.mapping, dimensions=dataId.dimensions),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe), **dataId.mapping
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe),
                        dimensions=dataId.dimensions,
                        **dataId.mapping,
                    ),
                    DataCoordinate.standardize(**dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dimensions=dataId.dimensions, **dataId.mapping),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        dimensions=dataId.dimensions,
                        **{k: dataId[k] for k in keys2},
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(
                        dataId=repr(dataId), newDataId=repr(newDataId), type=str(type(dataId))
                    ):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

        coord = DataCoordinate.standardize({"instrument": "HSC"}, universe=self.allDataIds.universe)
        with self.assertRaises(DimensionNameError):
            coord.subset(["detector"])

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test groups to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        groups = [self.randomDimensionSubset(n=2) for _ in range(2)]
        groups.append(self.allDataIds.universe["visit"].minimal_group)
        groups.append(self.allDataIds.universe["detector"].minimal_group)
        groups.append(self.allDataIds.universe["physical_filter"].minimal_group)
        groups.append(self.allDataIds.universe["band"].minimal_group)
        # Iterate over all combinations, including each group with itself.
        for group1, group2 in itertools.product(groups, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(group1))
            split2 = self.splitByStateFlags(parentDataIds.subset(group2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=repr(lhs), rhs=repr(rhs), unioned=repr(unioned)):
                    self.assertEqual(unioned.dimensions, group1.union(group2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.dimensions))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.dimensions), lhs)
                        self.assertEqual(unioned.subset(rhs.dimensions), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.dimensions >= unioned.dimensions and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.dimensions >= unioned.dimensions and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.dimensions.required | rhs.dimensions.required >= unioned.dimensions.names:
                        self.assertTrue(unioned.hasFull())
                    if (
                        lhs.hasRecords()
                        and rhs.hasRecords()
                        and lhs.dimensions.elements | rhs.dimensions.elements >= unioned.dimensions.elements
                    ):
                        self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            self.allDataIds.universe.conform(["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["tract"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["visit", "tract"])):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.dimensions.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal DataIds.
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasFull())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasFull())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasRecords())
            self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasFull())
            self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasFull())
            self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, dimensions=dataIds.dimensions, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasFull(),
                not dataIds.dimensions.implied,
            )
            self.assertEqual(
                cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasFull(),
                not dataIds.dimensions.implied,
            )
            self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, dimensions=dataIds.dimensions, hasRecords=True, check=True)
            if dataIds.dimensions.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, dimensions=dataIds.dimensions, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

    def testPackers(self):
        (instrument_data_id,) = self.allDataIds.subset(
            self.allDataIds.universe.conform(["instrument"])
        ).toSet()
        (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["detector"]))
        packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.dimensions)
        packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
        self.assertEqual(packed_id, detector_data_id["detector"])
        self.assertEqual(max_bits, packer.maxBits)
        self.assertEqual(
            max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
        )
        self.assertEqual(packer.pack(detector_data_id), packed_id)
        self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
        self.assertEqual(packer.unpack(packed_id), detector_data_id)
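
        # Worked example of the bit-width identity checked above (the value
        # 104 is illustrative, not taken from the test data): for
        # detector_max = 104, (104 - 1).bit_length() == 7 and
        # math.ceil(math.log2(104)) == 7.  The two expressions agree for any
        # positive detector count, since indices 0 .. n - 1 need exactly
        # (n - 1).bit_length() == ceil(log2(n)) bits.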

    def test_dimension_group_pydantic(self):
        """Test that DimensionGroup round-trips through Pydantic as long as
        it's given the universe when validated.
        """
        dimensions = self.allDataIds.dimensions
        adapter = pydantic.TypeAdapter(DimensionGroup)
        json_str = adapter.dump_json(dimensions)
        python_data = adapter.dump_python(dimensions)
        self.assertEqual(
            dimensions, adapter.validate_json(json_str, context=dict(universe=dimensions.universe))
        )
        self.assertEqual(
            dimensions, adapter.validate_python(python_data, context=dict(universe=dimensions.universe))
        )
        self.assertEqual(dimensions, adapter.validate_python(dimensions))
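
    # A hedged sketch of the same round-trip for a DimensionGroup used as a
    # field in a Pydantic model (``Wrapper`` is hypothetical; the pattern
    # mirrors the TypeAdapter calls above):
    #
    #     class Wrapper(pydantic.BaseModel):
    #         dimensions: DimensionGroup
    #
    #     wrapper = Wrapper(dimensions=dimensions)
    #     Wrapper.model_validate(
    #         wrapper.model_dump(), context=dict(universe=dimensions.universe)
    #     )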

    def test_dimension_element_pydantic(self):
        """Test that DimensionElement round-trips through Pydantic as long as
        it's given the universe when validated.
        """
        element = self.allDataIds.universe["visit"]
        adapter = pydantic.TypeAdapter(DimensionElement)
        json_str = adapter.dump_json(element)
        python_data = adapter.dump_python(element)
        self.assertEqual(element, adapter.validate_json(json_str, context=dict(universe=element.universe)))
        self.assertEqual(
            element, adapter.validate_python(python_data, context=dict(universe=element.universe))
        )
        self.assertEqual(element, adapter.validate_python(element))


if __name__ == "__main__":
    unittest.main()