# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import math
import os
import pickle
import unittest
from collections.abc import Iterator
from dataclasses import dataclass
from random import Random

import lsst.sphgeom
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionGraph,
    DimensionPacker,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    Registry,
    TimespanDatabaseRepresentation,
    YamlRepoImportBackend,
)
from lsst.daf.butler.registry import RegistryConfig

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = Registry.createFromConfig(config)
    with open(DIMENSION_DATA_FILE) as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
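

# Illustrative sketch (not executed at import time) of what the loader above
# returns, based on the assertions made later in this file: the sequence is
# fully expanded, so every data ID carries its dimension records.
#
#     data_ids = loadDimensionData()
#     assert data_ids.hasFull() and data_ids.hasRecords()
#     data_id = data_ids[0]
#     region = data_id.records["visit"].region  # records are attached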


class ConcreteTestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class implementations.

    This class just returns the detector ID as-is.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGraph):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        return DataCoordinate.standardize(
            {
                "instrument": self.fixed["instrument"],
                "detector": packedId,
            },
            graph=self.dimensions,
        )
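

# Illustrative sketch of the packer round-trip exercised in testPackers below
# (hypothetical data IDs; not executed at import time):
#
#     packer = ConcreteTestDimensionPacker(instrument_data_id, dimensions)
#     packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
#     assert packed_id == detector_data_id["detector"]
#     assert packer.unpack(packed_id) == detector_data_id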


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any
        # elements it requires, and each element that is implied by another
        # in the graph should follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)
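
    # Illustrative summary of the invariants checked above (a sketch, not an
    # additional test):
    #
    #     graph = DimensionGraph(universe, names=("visit",))
    #     assert all(element.graph <= graph for element in graph.elements)
    #     assert set(graph.dimensions) == set(graph.required) | set(graph.implied)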

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.core.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create a completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # This test was added when the universe version was 2; the version
        # should only grow from there.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

    def testSkyMapDimensions(self):
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that the independent spatial and temporal families of a
        graph are computed correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (default is `None`), iterate over only the ``n``-th
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
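

# Illustrative sketch of how SplitByStateFlags.chain is consumed in the tests
# below (not executed at import time):
#
#     split = ...                      # a SplitByStateFlags, all attributes set
#     every_id = list(split.chain())   # all data IDs, in all three states
#     firsts = list(split.chain(0))    # the 0th data ID from each attribute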


class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from. Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, graph: DimensionGraph | None = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from. Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            ``n`` or more dimensions randomly selected from ``graph`` without
            replacement.
        """
        if graph is None:
            graph = self.allDataIds.graph
        # Use min() so we never ask sample() for more names than exist.
        return DimensionGraph(
            graph.universe,
            names=self.rng.sample(list(graph.dimensions.names), min(n, len(graph.dimensions))),
        )
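
    # Note the "automatic expansion" mentioned above: constructing a
    # DimensionGraph from a few names also pulls in their required and
    # implied dependencies. Illustrative sketch (values mirror
    # testInstrumentDimensions earlier in this file):
    #
    #     graph = DimensionGraph(universe, names=("exposure", "detector", "visit"))
    #     # graph.dimensions.names also includes "instrument",
    #     # "physical_filter", and "band".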

    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from. Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result
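
    # Sketch of the state hierarchy produced above (see SplitByStateFlags):
    #
    #     split = self.splitByStateFlags()
    #     split.expanded   # hasFull() -> True,  hasRecords() -> True
    #     split.complete   # hasFull() -> True,  hasRecords() -> False
    #     split.minimal    # hasFull() True only if no implied dimensions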

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(
                        list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions]
                    )

    def test_pickle(self):
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.full, read_data_id.full)
                if data_id.hasRecords():
                    self.assertEqual(data_id.records, read_data_id.records)

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element in data_id.graph.elements:
                    self.assertIs(getattr(data_id, element.name), data_id.records[element.name])
                    self.assertIn(element.name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    getattr(data_id, "not_a_dimension_name")
            for data_id in itertools.chain(split.minimal, split.complete):
                for element in data_id.graph.elements:
                    with self.assertRaisesRegex(
                        AttributeError, "only available on expanded DataCoordinates"
                    ):
                        getattr(data_id, element.name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    getattr(data_id, "not_a_dimension_name")

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.  Note that ``a1`` and
        # ``b0`` below are the last values bound by the loops above, so these
        # comparisons also mix the two data ID values.
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(a1, b0)
            self.assertNotEqual(a0, b1.byName())
            self.assertNotEqual(a0.byName(), b1)
            self.assertNotEqual(a1, b0.byName())
            self.assertNotEqual(a1.byName(), b0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for m, dataId in enumerate(split.chain()):
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a subset
                # of the dimensions.  This is not possible for some
                # combinations of dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from Complete to
                            # Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets,
                # so we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "tract"])
        ):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal data IDs.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

    def testPackers(self):
        (instrument_data_id,) = self.allDataIds.subset(
            self.allDataIds.universe.extract(["instrument"])
        ).toSet()
        (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.extract(["detector"]))
        packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.graph)
        packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
        self.assertEqual(packed_id, detector_data_id["detector"])
        self.assertEqual(max_bits, packer.maxBits)
        self.assertEqual(
            max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
        )
        self.assertEqual(packer.pack(detector_data_id), packed_id)
        self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
        self.assertEqual(packer.unpack(packed_id), detector_data_id)
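
    # Worked example of the maxBits arithmetic above (illustrative, with a
    # hypothetical detector_max of 200): (200 - 1).bit_length() == 8 and
    # math.ceil(math.log2(200)) == 8, so the two expressions agree whenever
    # detector_max >= 2.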


if __name__ == "__main__":
    unittest.main()