# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import math
import os
import pickle
import unittest
from collections.abc import Iterator
from dataclasses import dataclass
from random import Random

import lsst.sphgeom
import pydantic
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionElement,
    DimensionGroup,
    DimensionPacker,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    YamlRepoImportBackend,
    ddl,
)
from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory
from lsst.daf.butler.timespan_database_representation import TimespanDatabaseRepresentation

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = _RegistryFactory(config).create_from_config()
    with open(DIMENSION_DATA_FILE) as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = registry.dimensions.conform(["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
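

# A usage sketch (how the test cases below consume this helper): the returned
# sequence holds expanded data IDs, so both state flags are true and dimension
# records are attached:
#
#     data_ids = loadDimensionData()
#     assert data_ids.hasFull() and data_ids.hasRecords()
#     visit_record = data_ids[0].records["visit"]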


class ConcreteTestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class implementations.

    This class just returns the detector ID as-is.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGroup):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
        # Number of bits needed to hold any detector ID in [0, detector_max).
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        return DataCoordinate.standardize(
            {
                "instrument": self.fixed["instrument"],
                "detector": packedId,
            },
            dimensions=self._dimensions,
        )
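

# A minimal round-trip sketch for the packer above (hypothetical variable
# names; ``instrument_data_id`` is an expanded data ID whose instrument record
# defines ``detector_max``, as in DataCoordinateTestCase.testPackers below):
#
#     packer = ConcreteTestDimensionPacker(
#         instrument_data_id, instrument_data_id.universe.conform(["detector"])
#     )
#     packed, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
#     assert packer.unpack(packed) == detector_data_id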


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGroupInvariants(self, group: DimensionGroup):
        elements = list(group.elements)
        for n, element_name in enumerate(elements):
            element = self.universe[element_name]
            # Ordered comparisons on groups behave like sets.
            self.assertLessEqual(element.minimal_group, group)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other_name in elements[:n]:
                other = self.universe[other_name]
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other_name in elements[n + 1 :]:
                other = self.universe[other_name]
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.minimal_group.required, element.required)
        self.assertEqual(self.universe.conform(group.required), group)
        self.assertCountEqual(
            group.required,
            [
                dimension_name
                for dimension_name in group.names
                if not any(
                    dimension_name in self.universe[other_name].minimal_group.implied
                    for other_name in group.elements
                )
            ],
        )
        self.assertCountEqual(group.implied, group.names - group.required)
        self.assertCountEqual(group.names, itertools.chain(group.required, group.implied))
        # Check primary key traversal order: each element should follow any
        # elements it requires, and an element that is implied by any other
        # in the group should follow at least one of those.
        seen: set[str] = set()
        for element_name in group.lookup_order:
            element = self.universe[element_name]
            with self.subTest(required=group.required, implied=group.implied, element=element):
                seen.add(element_name)
                self.assertLessEqual(element.minimal_group.required, seen)
                if element_name in group.implied:
                    self.assertTrue(any(element_name in self.universe[s].implied for s in seen))
        self.assertCountEqual(seen, group.elements)
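
    # As a concrete illustration of the invariants checked above (values per
    # config/dimensions.yaml at the time of writing, so treat them as an
    # assumption rather than a guarantee):
    #
    #     group = self.universe.conform(["visit"])
    #     group.required   # {"instrument", "visit"}
    #     group.implied    # {"physical_filter", "band"}
    #
    # and "instrument" precedes "visit" in ``group.lookup_order``.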

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create a completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # Test was added starting at version 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(1, 25)}
            | {f"healpix{level}" for level in range(1, 18)},
        )

    def testGraphs(self):
        self.checkGroupInvariants(self.universe.empty.as_group())
        for element in self.universe.getStaticElements():
            self.checkGroupInvariants(element.minimal_group)

    def testInstrumentDimensions(self):
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(group.required, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.governors, {"instrument"})
        for element in group.elements:
            self.assertEqual(self.universe[element].has_own_table, element != "band", element)
            self.assertEqual(
                self.universe[element].implied_union_target,
                "physical_filter" if element == "band" else None,
                element,
            )
            self.assertEqual(
                self.universe[element].defines_relationships,
                element
                in ("visit", "exposure", "physical_filter", "visit_definition", "visit_system_membership"),
                element,
            )

    def testCalibrationDimensions(self):
        group = self.universe.conform(["physical_filter", "detector"])
        self.assertCountEqual(group.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(group.required, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(group.implied, ("band",))
        self.assertCountEqual(group.elements, group.names)
        self.assertCountEqual(group.governors, {"instrument"})

    def testObservationDimensions(self):
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(group.required, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.spatial.names, ("observation_regions",))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))
        self.assertCountEqual(group.governors, {"instrument"})
        self.assertEqual(group.spatial.names, {"observation_regions"})
        self.assertEqual(group.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(group.temporal)).governor, self.universe["instrument"])
        self.assertEqual(self.universe["visit_definition"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_system_membership"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_detector_region"].populated_by, self.universe["visit"])
        self.assertEqual(
            self.universe.get_elements_populated_by(self.universe["visit"]),
            NamedValueSet(
                {
                    self.universe["visit"],
                    self.universe["visit_definition"],
                    self.universe["visit_system_membership"],
                    self.universe["visit_detector_region"],
                }
            ),
        )

    def testSkyMapDimensions(self):
        group = self.universe.conform(["patch"])
        self.assertEqual(group.names, {"skymap", "tract", "patch"})
        self.assertEqual(group.required, {"skymap", "tract", "patch"})
        self.assertEqual(group.implied, set())
        self.assertEqual(group.elements, group.names)
        self.assertEqual(group.governors, {"skymap"})
        self.assertEqual(group.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        group = self.universe.conform(["visit", "detector", "tract", "patch", "htm7", "exposure"])
        self.assertCountEqual(group.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        tableSpecs: NamedKeyDict[DimensionElement, ddl.TableSpec] = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.has_own_table:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.dimensions)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target, strict=True):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            group1 = element1.minimal_group
            group2 = pickle.loads(pickle.dumps(group1))
            self.assertIs(group1, group2)


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this attribute
    if and only if ``minimal.dimensions.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator[DataCoordinate]:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``n``-th
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
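

# A sketch of how `SplitByStateFlags.chain` is used by the tests below: with
# ``n`` given, it yields the same underlying data ID values once per populated
# attribute, each in a different state:
#
#     split = ...  # a SplitByStateFlags built by splitByStateFlags() below
#     for data_id in split.chain(0):
#         ...  # first data ID as minimal, then complete, then expanded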


class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from. Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            dimensions=dataIds.dimensions,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, group: DimensionGroup | None = None) -> DimensionGroup:
        """Generate a random `DimensionGroup` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGroup`.
        group : `DimensionGroup`, optional
            Dimensions to select from. Defaults to
            ``self.allDataIds.dimensions``.

        Returns
        -------
        selected : `DimensionGroup`
            Up to ``n`` dimensions randomly selected from ``group`` without
            replacement, expanded by `DimensionGroup` to include any
            dependencies.
        """
        if group is None:
            group = self.allDataIds.dimensions
        return group.universe.conform(self.rng.sample(list(group.names), min(n, len(group))))
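
    # For example (values per config/dimensions.yaml at the time of writing),
    # conforming a sample of ["patch"] expands it to a group whose required
    # dimensions are {"skymap", "tract", "patch"}, exactly as
    # testSkyMapDimensions above demonstrates.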

    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from. Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.mapping, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.required, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.dimensions.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result
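
    # A usage sketch for the helper above: the three attributes hold equal
    # data ID values in progressively richer states, e.g.
    #
    #     split = self.splitByStateFlags(self.randomDataIds(n=2))
    #     assert list(split.minimal) == list(split.complete) == list(split.expanded)
    #     assert split.expanded.hasRecords() and not split.complete.hasRecords()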

    def testMappingViews(self):
        """Test that the ``mapping`` and ``required`` attributes in
        `DataCoordinate` are self-consistent and consistent with the
        ``dimensions`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(dataId.required.keys(), dataId.dimensions.required)
                    self.assertEqual(
                        list(dataId.required.values()), [dataId[d] for d in dataId.dimensions.required]
                    )
                    self.assertEqual(
                        list(dataId.required_values), [dataId[d] for d in dataId.dimensions.required]
                    )
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.dimensions.names, dataId.mapping.keys())
                    self.assertEqual(
                        list(dataId.mapping.values()), [dataId[k] for k in dataId.mapping.keys()]
                    )

    def test_pickle(self):
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id: DataCoordinate = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.mapping, read_data_id.mapping)
                if data_id.hasRecords():
                    for element_name in data_id.dimensions.elements:
                        self.assertEqual(
                            data_id.records[element_name], read_data_id.records[element_name]
                        )

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element_name in data_id.dimensions.elements:
                    self.assertIs(getattr(data_id, element_name), data_id.records[element_name])
                    self.assertIn(element_name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
            for data_id in itertools.chain(split.minimal, split.complete):
                for element_name in data_id.dimensions.elements:
                    with self.assertRaisesRegex(
                        AttributeError, "only available on expanded DataCoordinates"
                    ):
                        getattr(data_id, element_name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(b1, a0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions))
                # Same if we pass the dimensions and some irrelevant kwargs.
                self.assertIs(
                    dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions, htm7=12)
                )
                # Test constructing a new data ID from this one with a subset
                # of the dimensions. This is not possible for some
                # combinations of dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, group=dataId.dimensions)
                if dataId.hasFull() or dataId.dimensions.required >= newDimensions.required:
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.dimensions.required & newDataId.dimensions.required
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from Complete to
                            # Minimal or from Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets,
                # so we can pass some as kwargs below.
                keys1 = set(self.rng.sample(list(dataId.dimensions.names), len(dataId.dimensions) // 2))
                keys2 = dataId.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dataId.mapping, dimensions=dataId.dimensions),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe), **dataId.mapping
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe),
                        dimensions=dataId.dimensions,
                        **dataId.mapping,
                    ),
                    DataCoordinate.standardize(**dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dimensions=dataId.dimensions, **dataId.mapping),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        dimensions=dataId.dimensions,
                        **{k: dataId[k] for k in keys2},
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test groups to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        groups = [self.randomDimensionSubset(n=2) for _ in range(2)]
        groups.append(self.allDataIds.universe["visit"].minimal_group)
        groups.append(self.allDataIds.universe["detector"].minimal_group)
        groups.append(self.allDataIds.universe["physical_filter"].minimal_group)
        groups.append(self.allDataIds.universe["band"].minimal_group)
        # Iterate over all combinations, including the same group with itself.
        for group1, group2 in itertools.product(groups, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(group1))
            split2 = self.splitByStateFlags(parentDataIds.subset(group2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.dimensions, group1.union(group2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.dimensions))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.dimensions), lhs)
                        self.assertEqual(unioned.subset(rhs.dimensions), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.dimensions >= unioned.dimensions and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.dimensions >= unioned.dimensions and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.dimensions.required | rhs.dimensions.required >= unioned.dimensions.names:
                        self.assertTrue(unioned.hasFull())
                    if (
                        lhs.hasRecords()
                        and rhs.hasRecords()
                        and lhs.dimensions.elements | rhs.dimensions.elements >= unioned.dimensions.elements
                    ):
                        self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            self.allDataIds.universe.conform(["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["tract"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["visit", "tract"])):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.dimensions.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal data IDs.
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasFull())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasFull())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasRecords())
            self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasFull())
            self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasFull())
            self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, dimensions=dataIds.dimensions, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasFull(),
                not dataIds.dimensions.implied,
            )
            self.assertEqual(
                cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasFull(),
                not dataIds.dimensions.implied,
            )
            self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, dimensions=dataIds.dimensions, hasRecords=True, check=True)
            if dataIds.dimensions.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, dimensions=dataIds.dimensions, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

    def testPackers(self):
        (instrument_data_id,) = self.allDataIds.subset(
            self.allDataIds.universe.conform(["instrument"])
        ).toSet()
        (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["detector"]))
        packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.dimensions)
        packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
        self.assertEqual(packed_id, detector_data_id["detector"])
        self.assertEqual(max_bits, packer.maxBits)
        self.assertEqual(
            max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
        )
        self.assertEqual(packer.pack(detector_data_id), packed_id)
        self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
        self.assertEqual(packer.unpack(packed_id), detector_data_id)

    def test_dimension_group_pydantic(self):
        """Test that DimensionGroup round-trips through Pydantic as long as
        it's given the universe when validated.
        """
        dimensions = self.allDataIds.dimensions
        adapter = pydantic.TypeAdapter(DimensionGroup)
        json_str = adapter.dump_json(dimensions)
        python_data = adapter.dump_python(dimensions)
        self.assertEqual(
            dimensions, adapter.validate_json(json_str, context=dict(universe=dimensions.universe))
        )
        self.assertEqual(
            dimensions, adapter.validate_python(python_data, context=dict(universe=dimensions.universe))
        )
        self.assertEqual(dimensions, adapter.validate_python(dimensions))
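
    # The same context-based validation applies when DimensionGroup is a
    # field of a larger model (a sketch; ``Job`` is a hypothetical model, not
    # part of daf_butler):
    #
    #     class Job(pydantic.BaseModel):
    #         dimensions: DimensionGroup
    #
    #     Job.model_validate_json(payload, context={"universe": universe})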

    def test_dimension_element_pydantic(self):
        """Test that DimensionElement round-trips through Pydantic as long as
        it's given the universe when validated.
        """
        element = self.allDataIds.universe["visit"]
        adapter = pydantic.TypeAdapter(DimensionElement)
        json_str = adapter.dump_json(element)
        python_data = adapter.dump_python(element)
        self.assertEqual(element, adapter.validate_json(json_str, context=dict(universe=element.universe)))
        self.assertEqual(
            element, adapter.validate_python(python_data, context=dict(universe=element.universe))
        )
        self.assertEqual(element, adapter.validate_python(element))


if __name__ == "__main__":
    unittest.main()