# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import math
import os
import pickle
import unittest
from collections.abc import Iterator
from dataclasses import dataclass
from random import Random

import lsst.sphgeom
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionGraph,
    DimensionPacker,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    TimespanDatabaseRepresentation,
    YamlRepoImportBackend,
)
from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = _RegistryFactory(config).create_from_config()
    with open(DIMENSION_DATA_FILE) as stream:
        backend = YamlRepoImportBackend(stream, registry)
    backend.register()
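    # No datastore is needed here: the export file is assumed to contain only
    # dimension records, so there are no dataset files to transfer.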
    backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()


class ConcreteTestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class implementations.

    This class just returns the detector ID as-is.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGraph):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
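        # (n - 1).bit_length() is the number of bits needed to represent
        # every detector ID in the range [0, detector_max).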
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        return DataCoordinate.standardize(
            {
                "instrument": self.fixed["instrument"],
                "detector": packedId,
            },
            graph=self.dimensions,
        )
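

# A minimal usage sketch for the packer above (hypothetical data IDs; the
# same ``pack``/``unpack`` API is exercised by ``testPackers`` below):
#
#     packer = ConcreteTestDimensionPacker(instrument_data_id, dimensions)
#     packed_id, max_bits = packer.pack(data_id, returnMaxBits=True)
#     assert packer.unpack(packed_id) == data_id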


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any
        # it requires, and any element implied by another in the graph should
        # follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number.
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create a completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # This test was added when the universe version was 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])
        self.assertEqual(self.universe["visit_definition"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_system_membership"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_detector_region"].populated_by, self.universe["visit"])
        self.assertEqual(
            self.universe.get_elements_populated_by(self.universe["visit"]),
            NamedValueSet(
                {
                    self.universe["visit"],
                    self.universe["visit_definition"],
                    self.universe["visit_system_membership"],
                    self.universe["visit_detector_region"],
                }
            ),
        )

    def testSkyMapDimensions(self):
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal families are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
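        # The htm7 skypix dimension contributes its own "htm" spatial family,
        # independent of the observation and skymap region families.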
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target, strict=True):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (default `None`), iterate over only the ``n``-th
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
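        # Use a slice in both cases so the attributes below can be indexed
        # uniformly; a slice also yields nothing (rather than raising
        # IndexError) if an attribute has fewer than ``n + 1`` data IDs.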
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]


class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from.  Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
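        # ``check=False`` below skips per-element validation; the state flags
        # are copied from the source sequence, whose elements are reused
        # as-is.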
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, graph: DimensionGraph | None = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from.  Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            Up to ``n`` dimensions randomly selected from ``graph`` without
            replacement, expanded with any dimensions they depend on.
        """
        if graph is None:
            graph = self.allDataIds.graph
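        # Clamp the sample size so Random.sample never asks for more names
        # than the graph has (it raises ValueError in that case).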
        return DimensionGraph(
            graph.universe,
            names=self.rng.sample(list(graph.dimensions.names), min(n, len(graph.dimensions))),
        )

    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from.  Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(
                        list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions]
                    )

    def test_pickle(self):
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.full, read_data_id.full)
                if data_id.hasRecords():
                    self.assertEqual(data_id.records, read_data_id.records)

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element in data_id.graph.elements:
                    self.assertIs(getattr(data_id, element.name), data_id.records[element.name])
                    self.assertIn(element.name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
            for data_id in itertools.chain(split.minimal, split.complete):
                for element in data_id.graph.elements:
                    with self.assertRaisesRegex(
                        AttributeError, "only available on expanded DataCoordinates"
                    ):
                        getattr(data_id, element.name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.
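        # Note that ``a1`` and ``b0`` below keep their bindings from the last
        # iterations of the loops above, so they carry the other data ID's
        # values relative to ``a0`` and ``b1``.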
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(a1, b0)
            self.assertNotEqual(a0, b1.byName())
            self.assertNotEqual(a0.byName(), b1)
            self.assertNotEqual(a1, b0.byName())
            self.assertNotEqual(a1.byName(), b0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a subset
                # of the dimensions.  This is not possible for some
                # combinations of dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from Complete to
                            # Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if (
                        lhs.hasRecords()
                        and rhs.hasRecords()
                        and lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements
                    ):
                        self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "tract"])
        ):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal data IDs.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
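        # For any sets, a ^ b == (a | b) - (a & b), so this ordered comparison
        # is effectively an equality check.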
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

    def testPackers(self):
        (instrument_data_id,) = self.allDataIds.subset(
            self.allDataIds.universe.extract(["instrument"])
        ).toSet()
        (detector_data_id,) = self.randomDataIds(n=1).subset(
            self.allDataIds.universe.extract(["detector"])
        )
        packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.graph)
        packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
        self.assertEqual(packed_id, detector_data_id["detector"])
        self.assertEqual(max_bits, packer.maxBits)
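        # (n - 1).bit_length() == ceil(log2(n)) for every integer n >= 1, so
        # the packer's bit count matches the logarithmic formula below.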
        self.assertEqual(
            max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
        )
        self.assertEqual(packer.pack(detector_data_id), packed_id)
        self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
        self.assertEqual(packer.unpack(packed_id), detector_data_id)


if __name__ == "__main__":
    unittest.main()