Coverage for tests/test_dimensions.py: 11%
462 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-12 10:07 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-12 10:07 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28import copy
29import itertools
30import math
31import os
32import pickle
33import unittest
34from collections.abc import Iterator
35from dataclasses import dataclass
36from random import Random
38import lsst.sphgeom
39import pydantic
40from lsst.daf.butler import (
41 Config,
42 DataCoordinate,
43 DataCoordinateSequence,
44 DataCoordinateSet,
45 Dimension,
46 DimensionConfig,
47 DimensionElement,
48 DimensionGroup,
49 DimensionPacker,
50 DimensionUniverse,
51 NamedKeyDict,
52 NamedValueSet,
53 YamlRepoImportBackend,
54 ddl,
55)
56from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory
57from lsst.daf.butler.timespan_database_representation import TimespanDatabaseRepresentation
# Path to the exported dimension data (YAML) used to populate the in-memory
# test registry; shipped with the tests under data/registry/.
DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)
def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file, expanded to
        include dimension records.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = _RegistryFactory(config).create_from_config()
    with open(DIMENSION_DATA_FILE) as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = registry.dimensions.conform(["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
class ConcreteTestDimensionPacker(DimensionPacker):
    """A minimal concrete `DimensionPacker` used to exercise the base class
    implementations.

    The packed integer is simply the detector ID, unchanged.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGroup):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
        # Smallest number of bits that can represent any valid detector ID.
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        values = {
            "instrument": self.fixed["instrument"],
            "detector": packedId,
        }
        return DataCoordinate.standardize(values, dimensions=self._dimensions)
class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        # Default dimension universe, shared by all the checks below.
        self.universe = DimensionUniverse()
    def checkGroupInvariants(self, group: DimensionGroup):
        """Check invariants that must hold for any `DimensionGroup`.

        Parameters
        ----------
        group : `DimensionGroup`
            The group whose element ordering, required/implied partition, and
            lookup order are checked.
        """
        elements = list(group.elements)
        for n, element_name in enumerate(elements):
            element = self.universe[element_name]
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.minimal_group, group)
            # Ordered comparisons on elements correspond to the ordering within
            # a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other_name in elements[:n]:
                other = self.universe[other_name]
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other_name in elements[n + 1 :]:
                other = self.universe[other_name]
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.minimal_group.required, element.required)
        # Conforming the required dimensions must reproduce the group exactly.
        self.assertEqual(self.universe.conform(group.required), group)
        # The required dimensions are exactly those not implied by any element.
        self.assertCountEqual(
            group.required,
            [
                dimension_name
                for dimension_name in group.names
                if not any(
                    dimension_name in self.universe[other_name].minimal_group.implied
                    for other_name in group.elements
                )
            ],
        )
        self.assertCountEqual(group.implied, group.names - group.required)
        self.assertCountEqual(group.names, itertools.chain(group.required, group.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and element that is implied by any other in the graph
        # follow at least one of those.
        seen: set[str] = set()
        for element_name in group.lookup_order:
            element = self.universe[element_name]
            with self.subTest(required=group.required, implied=group.implied, element=element):
                seen.add(element_name)
                self.assertLessEqual(element.minimal_group.required, seen)
                if element_name in group.implied:
                    self.assertTrue(any(element_name in self.universe[s].implied for s in seen))
        # The lookup order must visit every element exactly once.
        self.assertCountEqual(seen, group.elements)
173 def testConfigPresent(self):
174 config = self.universe.dimensionConfig
175 self.assertIsInstance(config, DimensionConfig)
    def testCompatibility(self):
        """Test `DimensionUniverse.isCompatibleWith` for identical,
        version-shifted, and structurally different universes.
        """
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        # A version difference alone is compatible, but should be logged.
        with self.assertLogs("lsst.daf.butler.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))
232 def testVersion(self):
233 self.assertEqual(self.universe.namespace, "daf_butler")
234 # Test was added starting at version 2.
235 self.assertGreaterEqual(self.universe.version, 2)
    def testConfigRead(self):
        """Check that the default configuration defines exactly the expected
        set of static dimensions, including the skypix systems.
        """
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "group",
                "day_obs",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            # One dimension per level for each configured skypix pixelization.
            | {f"htm{level}" for level in range(1, 25)}
            | {f"healpix{level}" for level in range(1, 18)},
        )
259 def testGraphs(self):
260 self.checkGroupInvariants(self.universe.empty.as_group())
261 for element in self.universe.getStaticElements():
262 self.checkGroupInvariants(element.minimal_group)
    def testInstrumentDimensions(self):
        """Check the group conformed from instrument-centric dimensions."""
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band", "group", "day_obs"),
        )
        self.assertCountEqual(group.required, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band", "group", "day_obs"))
        # Elements that are part of the group but are not dimensions.
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.governors, {"instrument"})
        for element in group.elements:
            # Every element except band has a table of its own; band values
            # are stored in the physical_filter table (its implied-union
            # target, checked just below).
            self.assertEqual(self.universe[element].has_own_table, element != "band", element)
            self.assertEqual(
                self.universe[element].implied_union_target,
                "physical_filter" if element == "band" else None,
                element,
            )
            self.assertEqual(
                self.universe[element].defines_relationships,
                element
                in ("visit", "exposure", "physical_filter", "visit_definition", "visit_system_membership"),
                element,
            )
288 def testCalibrationDimensions(self):
289 group = self.universe.conform(["physical_filter", "detector"])
290 self.assertCountEqual(group.names, ("instrument", "detector", "physical_filter", "band"))
291 self.assertCountEqual(group.required, ("instrument", "detector", "physical_filter"))
292 self.assertCountEqual(group.implied, ("band",))
293 self.assertCountEqual(group.elements, group.names)
294 self.assertCountEqual(group.governors, {"instrument"})
    def testObservationDimensions(self):
        """Check the spatial/temporal families and populated_by relationships
        for the observation dimensions.
        """
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band", "group", "day_obs"),
        )
        self.assertCountEqual(group.required, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band", "group", "day_obs"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.spatial.names, ("observation_regions",))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))
        self.assertCountEqual(group.governors, {"instrument"})
        self.assertEqual(group.spatial.names, {"observation_regions"})
        self.assertEqual(group.temporal.names, {"observation_timespans"})
        # Both families are governed by the instrument dimension.
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(group.temporal)).governor, self.universe["instrument"])
        # The visit-related join elements are populated alongside visit
        # records.
        self.assertEqual(self.universe["visit_definition"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_system_membership"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_detector_region"].populated_by, self.universe["visit"])
        self.assertEqual(
            self.universe.get_elements_populated_by(self.universe["visit"]),
            NamedValueSet(
                {
                    self.universe["visit"],
                    self.universe["visit_definition"],
                    self.universe["visit_system_membership"],
                    self.universe["visit_detector_region"],
                }
            ),
        )
327 def testSkyMapDimensions(self):
328 group = self.universe.conform(["patch"])
329 self.assertEqual(group.names, {"skymap", "tract", "patch"})
330 self.assertEqual(group.required, {"skymap", "tract", "patch"})
331 self.assertEqual(group.implied, set())
332 self.assertEqual(group.elements, group.names)
333 self.assertEqual(group.governors, {"skymap"})
334 self.assertEqual(group.spatial.names, {"skymap_regions"})
335 self.assertEqual(next(iter(group.spatial)).governor, self.universe["skymap"])
337 def testSubsetCalculation(self):
338 """Test that independent spatial and temporal options are computed
339 correctly.
340 """
341 group = self.universe.conform(["visit", "detector", "tract", "patch", "htm7", "exposure"])
342 self.assertCountEqual(group.spatial.names, ("observation_regions", "skymap_regions", "htm"))
343 self.assertCountEqual(group.temporal.names, ("observation_timespans",))
    def testSchemaGeneration(self):
        """Test that every element with its own table can generate a table
        spec whose fields are consistent with the element's keys and
        relationships.
        """
        tableSpecs: NamedKeyDict[DimensionElement, ddl.TableSpec] = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.has_own_table:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            # Required dependencies appear as non-nullable primary-key fields;
            # the element itself contributes a field named after its own
            # primary key.
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            # Implied dependencies appear as ordinary (non-key) fields.
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            # Foreign keys must point at other generated dimension tables,
            # with source/target fields that agree on type and size.
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.dimensions)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target, strict=True):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )
397 def testPickling(self):
398 # Pickling and copying should always yield the exact same object within
399 # a single process (cross-process is impossible to test here).
400 universe1 = DimensionUniverse()
401 universe2 = pickle.loads(pickle.dumps(universe1))
402 universe3 = copy.copy(universe1)
403 universe4 = copy.deepcopy(universe1)
404 self.assertIs(universe1, universe2)
405 self.assertIs(universe1, universe3)
406 self.assertIs(universe1, universe4)
407 for element1 in universe1.getStaticElements():
408 element2 = pickle.loads(pickle.dumps(element1))
409 self.assertIs(element1, element2)
410 group1 = element1.minimal_group
411 group2 = pickle.loads(pickle.dumps(group1))
412 self.assertIs(group1, group2)
@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.dimensions.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator[DataCoordinate]:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``nth``
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        # Slice rather than index so attributes that are unset contribute
        # nothing instead of raising.
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    # Fixed seed so the "random" data ID selections are reproducible.
    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        # Importing the YAML export is expensive, so do it once per class.
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        # Fresh RNG per test so tests do not affect each other's draws.
        self.rng = Random(self.RANDOM_SEED)
    def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from.  Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` Data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        # The sample preserves the input's state flags, so skip re-validation.
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            dimensions=dataIds.dimensions,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )
506 def randomDimensionSubset(self, n: int = 3, group: DimensionGroup | None = None) -> DimensionGroup:
507 """Generate a random `DimensionGroup` that has a subset of the
508 dimensions in a given one.
510 Parameters
511 ----------
512 n : `int`
513 Number of dimensions to select, before automatic expansion by
514 `DimensionGroup`.
515 group : `DimensionGroup`, optional
516 Dimensions to select from. Defaults to
517 ``self.allDataIds.dimensions``.
519 Returns
520 -------
521 selected : `DimensionGroup`
522 ``n`` or more dimensions randomly selected from ``group`` with
523 replacement.
524 """
525 if group is None:
526 group = self.allDataIds.dimensions
527 return group.universe.conform(self.rng.sample(list(group.names), max(n, len(group))))
    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from.  Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            # Re-standardizing from the full value mapping keeps values for
            # all dimensions but drops the records.
            result.complete = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.mapping, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            # Standardizing from only the required values also drops implied
            # dimension values.
            result.minimal = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.required, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.dimensions.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result
    def testMappingViews(self):
        """Test that the ``mapping`` and ``required`` attributes in
        `DataCoordinate` are self-consistent and consistent with the
        ``dimensions`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            # ``required`` is available regardless of state flags.
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(dataId.required.keys(), dataId.dimensions.required)
                    self.assertEqual(
                        list(dataId.required.values()), [dataId[d] for d in dataId.dimensions.required]
                    )
                    self.assertEqual(
                        list(dataId.required_values), [dataId[d] for d in dataId.dimensions.required]
                    )
                    self.assertEqual(dataId.required.keys(), dataId.dimensions.required)
            # ``mapping`` covers all dimensions, so only check it on data IDs
            # that have full values.
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.dimensions.names, dataId.mapping.keys())
                    self.assertEqual(
                        list(dataId.mapping.values()), [dataId[k] for k in dataId.mapping.keys()]
                    )
617 def test_pickle(self):
618 for _ in range(5):
619 dimensions = self.randomDimensionSubset()
620 dataIds = self.randomDataIds(n=1).subset(dimensions)
621 split = self.splitByStateFlags(dataIds)
622 for data_id in split.chain():
623 s = pickle.dumps(data_id)
624 read_data_id: DataCoordinate = pickle.loads(s)
625 self.assertEqual(data_id, read_data_id)
626 self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
627 self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
628 if data_id.hasFull():
629 self.assertEqual(data_id.mapping, read_data_id.mapping)
630 if data_id.hasRecords():
631 for element_name in data_id.dimensions.elements:
632 self.assertEqual(
633 data_id.records[element_name], read_data_id.records[element_name]
634 )
    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            # Expanded data IDs expose records as attributes...
            for data_id in split.expanded:
                for element_name in data_id.dimensions.elements:
                    self.assertIs(getattr(data_id, element_name), data_id.records[element_name])
                    self.assertIn(element_name, dir(data_id))
                # ...while unknown names still raise a plain AttributeError.
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
            # Non-expanded data IDs raise an informative message for element
            # names, and a plain AttributeError for unknown names.
            for data_id in itertools.chain(split.minimal, split.complete):
                for element_name in data_id.dimensions.elements:
                    with self.assertRaisesRegex(AttributeError, "only available on expanded DataCoordinates"):
                        getattr(data_id, element_name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
657 def testEquality(self):
658 """Test that different `DataCoordinate` instances with different state
659 flags can be compared with each other and other mappings.
660 """
661 dataIds = self.randomDataIds(n=2)
662 split = self.splitByStateFlags(dataIds)
663 # Iterate over all combinations of different states of DataCoordinate,
664 # with the same underlying data ID values.
665 for a0, b0 in itertools.combinations(split.chain(0), 2):
666 self.assertEqual(a0, b0)
667 # Same thing, for a different data ID value.
668 for a1, b1 in itertools.combinations(split.chain(1), 2):
669 self.assertEqual(a1, b1)
670 # Iterate over all combinations of different states of DataCoordinate,
671 # with different underlying data ID values.
672 for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
673 self.assertNotEqual(a0, b1)
674 self.assertNotEqual(a1, b0)
    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(
                    dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions, htm7=12)
                )
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, group=dataId.dimensions)
                if dataId.hasFull() or dataId.dimensions.required >= newDimensions.required:
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.dimensions.required & newDataId.dimensions.required
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via several
            # different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(self.rng.sample(list(dataId.dimensions.names), len(dataId.dimensions) // 2))
                keys2 = dataId.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dataId.mapping, dimensions=dataId.dimensions),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe), **dataId.mapping
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe),
                        dimensions=dataId.dimensions,
                        **dataId.mapping,
                    ),
                    DataCoordinate.standardize(**dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dimensions=dataId.dimensions, **dataId.mapping),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        dimensions=dataId.dimensions,
                        **{k: dataId[k] for k in keys2},
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())
    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test groups to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        groups = [self.randomDimensionSubset(n=2) for i in range(2)]
        groups.append(self.allDataIds.universe["visit"].minimal_group)
        groups.append(self.allDataIds.universe["detector"].minimal_group)
        groups.append(self.allDataIds.universe["physical_filter"].minimal_group)
        groups.append(self.allDataIds.universe["band"].minimal_group)
        # Iterate over all combinations, including the same graph with itself.
        for group1, group2 in itertools.product(groups, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(group1))
            split2 = self.splitByStateFlags(parentDataIds.subset(group2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    # The union's dimensions are the union of the operands'.
                    self.assertEqual(unioned.dimensions, group1.union(group2))
                    # Values must agree with the common parent data ID.
                    self.assertEqual(unioned, parentDataId.subset(unioned.dimensions))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.dimensions), lhs)
                        self.assertEqual(unioned.subset(rhs.dimensions), rhs)
                    # State flags should be as complete as the operands allow.
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.dimensions >= unioned.dimensions and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.dimensions >= unioned.dimensions and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.dimensions.required | rhs.dimensions.required >= unioned.dimensions.names:
                        self.assertTrue(unioned.hasFull())
                    if (
                        lhs.hasRecords()
                        and rhs.hasRecords()
                        and lhs.dimensions.elements | rhs.dimensions.elements >= unioned.dimensions.elements
                    ):
                        self.assertTrue(unioned.hasRecords())
799 def testRegions(self):
800 """Test that data IDs for a few known dimensions have the expected
801 regions.
802 """
803 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
804 self.assertIsNotNone(dataId.region)
805 self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
806 self.assertEqual(dataId.region, dataId.records["visit"].region)
807 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit", "detector"])):
808 self.assertIsNotNone(dataId.region)
809 self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
810 self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
811 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["tract"])):
812 self.assertIsNotNone(dataId.region)
813 self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
814 self.assertEqual(dataId.region, dataId.records["tract"].region)
815 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
816 self.assertIsNotNone(dataId.region)
817 self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
818 self.assertEqual(dataId.region, dataId.records["patch"].region)
819 for data_id in self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["visit", "tract"])):
820 self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
821 self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)
823 def testTimespans(self):
824 """Test that data IDs for a few known dimensions have the expected
825 timespans.
826 """
827 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
828 self.assertIsNotNone(dataId.timespan)
829 self.assertEqual(dataId.dimensions.temporal.names, {"observation_timespans"})
830 self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
831 self.assertEqual(dataId.timespan, dataId.visit.timespan)
832 # Also test the case for non-temporal DataIds.
833 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
834 self.assertIsNone(dataId.timespan)
836 def testIterableStatusFlags(self):
837 """Test that DataCoordinateSet and DataCoordinateSequence compute
838 their hasFull and hasRecords flags correctly from their elements.
839 """
840 dataIds = self.randomDataIds(n=10)
841 split = self.splitByStateFlags(dataIds)
842 for cls in (DataCoordinateSet, DataCoordinateSequence):
843 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasFull())
844 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasFull())
845 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasRecords())
846 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasRecords())
847 self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasFull())
848 self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasFull())
849 self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasRecords())
850 self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasRecords())
851 with self.assertRaises(ValueError):
852 cls(split.complete, dimensions=dataIds.dimensions, hasRecords=True, check=True)
853 self.assertEqual(
854 cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasFull(),
855 not dataIds.dimensions.implied,
856 )
857 self.assertEqual(
858 cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasFull(),
859 not dataIds.dimensions.implied,
860 )
861 self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasRecords())
862 self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasRecords())
863 with self.assertRaises(ValueError):
864 cls(split.minimal, dimensions=dataIds.dimensions, hasRecords=True, check=True)
865 if dataIds.dimensions.implied:
866 with self.assertRaises(ValueError):
867 cls(split.minimal, dimensions=dataIds.dimensions, hasFull=True, check=True)
869 def testSetOperations(self):
870 """Test for self-consistency across DataCoordinateSet's operations."""
871 c = self.randomDataIds(n=10).toSet()
872 a = self.randomDataIds(n=20).toSet() | c
873 b = self.randomDataIds(n=20).toSet() | c
874 # Make sure we don't have a particularly unlucky random seed, since
875 # that would make a lot of this test uninteresting.
876 self.assertNotEqual(a, b)
877 self.assertGreater(len(a), 0)
878 self.assertGreater(len(b), 0)
879 # The rest of the tests should not depend on the random seed.
880 self.assertEqual(a, a)
881 self.assertNotEqual(a, a.toSequence())
882 self.assertEqual(a, a.toSequence().toSet())
883 self.assertEqual(a, a.toSequence().toSet())
884 self.assertEqual(b, b)
885 self.assertNotEqual(b, b.toSequence())
886 self.assertEqual(b, b.toSequence().toSet())
887 self.assertEqual(a & b, a.intersection(b))
888 self.assertLessEqual(a & b, a)
889 self.assertLessEqual(a & b, b)
890 self.assertEqual(a | b, a.union(b))
891 self.assertGreaterEqual(a | b, a)
892 self.assertGreaterEqual(a | b, b)
893 self.assertEqual(a - b, a.difference(b))
894 self.assertLessEqual(a - b, a)
895 self.assertLessEqual(b - a, b)
896 self.assertEqual(a ^ b, a.symmetric_difference(b))
897 self.assertGreaterEqual(a ^ b, (a | b) - (a & b))
899 def testPackers(self):
900 (instrument_data_id,) = self.allDataIds.subset(
901 self.allDataIds.universe.conform(["instrument"])
902 ).toSet()
903 (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["detector"]))
904 packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.dimensions)
905 packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
906 self.assertEqual(packed_id, detector_data_id["detector"])
907 self.assertEqual(max_bits, packer.maxBits)
908 self.assertEqual(
909 max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
910 )
911 self.assertEqual(packer.pack(detector_data_id), packed_id)
912 self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
913 self.assertEqual(packer.unpack(packed_id), detector_data_id)
915 def test_dimension_group_pydantic(self):
916 """Test that DimensionGroup round-trips through Pydantic as long as
917 it's given the universe when validated.
918 """
919 dimensions = self.allDataIds.dimensions
920 adapter = pydantic.TypeAdapter(DimensionGroup)
921 json_str = adapter.dump_json(dimensions)
922 python_data = adapter.dump_python(dimensions)
923 self.assertEqual(
924 dimensions, adapter.validate_json(json_str, context=dict(universe=dimensions.universe))
925 )
926 self.assertEqual(
927 dimensions, adapter.validate_python(python_data, context=dict(universe=dimensions.universe))
928 )
929 self.assertEqual(dimensions, adapter.validate_python(dimensions))
931 def test_dimension_element_pydantic(self):
932 """Test that DimensionElement round-trips through Pydantic as long as
933 it's given the universe when validated.
934 """
935 element = self.allDataIds.universe["visit"]
936 adapter = pydantic.TypeAdapter(DimensionElement)
937 json_str = adapter.dump_json(element)
938 python_data = adapter.dump_python(element)
939 self.assertEqual(element, adapter.validate_json(json_str, context=dict(universe=element.universe)))
940 self.assertEqual(
941 element, adapter.validate_python(python_data, context=dict(universe=element.universe))
942 )
943 self.assertEqual(element, adapter.validate_python(element))
# Support running this test module directly (e.g. ``python test_dimensions.py``)
# in addition to discovery via pytest/unittest.
if __name__ == "__main__":
    unittest.main()