# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import math
import os
import pickle
import unittest
from dataclasses import dataclass
from random import Random
from typing import Iterator, Optional

import lsst.sphgeom
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionGraph,
    DimensionPacker,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    Registry,
    TimespanDatabaseRepresentation,
    YamlRepoImportBackend,
)
from lsst.daf.butler.registry import RegistryConfig

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code
    repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = Registry.createFromConfig(config)
    with open(DIMENSION_DATA_FILE, "r") as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
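
# A minimal sketch (not executed as part of the tests) of how the helper above
# is typically consumed; it returns *expanded* data IDs, so both completeness
# checks hold:
#
#     dataIds = loadDimensionData()
#     assert dataIds.hasFull() and dataIds.hasRecords()
#     first = dataIds[0]  # an expanded DataCoordinate with records attached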


class TestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class
    implementations.

    This class just returns the detector ID as-is.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGraph):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        return DataCoordinate.standardize(
            {
                "instrument": self.fixed["instrument"],
                "detector": packedId,
            },
            graph=self.dimensions,
        )
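
# A minimal usage sketch for the packer above, assuming ``instrument_data_id``
# is an expanded data ID for a single instrument and ``detector_data_id``
# covers the "detector" dimensions (both are set up exactly this way in
# testPackers below):
#
#     packer = TestDimensionPacker(instrument_data_id, detector_data_id.graph)
#     packed = packer.pack(detector_data_id)  # == detector_data_id["detector"]
#     assert packer.unpack(packed) == detector_data_id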


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any
        # it requires, and any element implied by another in the graph should
        # follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)
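
    # As a concrete instance of the required/implied partition checked above
    # (values from testInstrumentDimensions below): the graph built from
    # ("exposure", "detector", "visit") has required dimensions
    # {instrument, exposure, detector, visit} and implied dimensions
    # {physical_filter, band}, and the two sets partition graph.dimensions.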

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.core.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # Test was added starting at version 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

    def testSkyMapDimensions(self):
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal families are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: Optional[DataCoordinateSequence] = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this attribute
    if and only if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: Optional[int] = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``nth``
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
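
# A minimal sketch of how SplitByStateFlags is consumed by the tests below,
# assuming ``split`` came from DataCoordinateTestCase.splitByStateFlags:
#
#     for data_id in split.chain():  # every state, every underlying value
#         ...
#     for a, b in itertools.combinations(split.chain(0), 2):
#         assert a == b  # same values compare equal across states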


class DataCoordinateTestCase(unittest.TestCase):
    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: Optional[DataCoordinateSequence] = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from.  Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement (`random.Random.sample` never repeats an element).
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, graph: Optional[DimensionGraph] = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from.  Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            Up to ``n`` dimensions randomly selected from ``graph``, expanded
            by `DimensionGraph` to include any dimensions they require (so
            the result may hold more than ``n``).
        """
        if graph is None:
            graph = self.allDataIds.graph
        return DimensionGraph(
            graph.universe,
            names=self.rng.sample(list(graph.dimensions.names), min(n, len(graph.dimensions))),
        )
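
    # For example (a hypothetical draw), sampling just "visit" expands to a
    # graph whose required dimensions are {instrument, visit}, because
    # DimensionGraph automatically adds required dependencies; that is why
    # the result can hold more dimensions than were sampled.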

    def splitByStateFlags(
        self,
        dataIds: Optional[DataCoordinateSequence] = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from.  Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain
            all information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result
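
    # Sketch of the three states produced above for a single data ID whose
    # graph has both required and implied dimensions (illustrative only):
    #
    #     expanded: required + implied values, records attached
    #     complete: required + implied values, no records
    #     minimal:  required values only, so hasFull() is False whenever the
    #               graph has implied dimensions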

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(
                        list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions]
                    )

    def test_pickle(self):
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.full, read_data_id.full)
                if data_id.hasRecords():
                    self.assertEqual(data_id.records, read_data_id.records)

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element in data_id.graph.elements:
                    self.assertIs(getattr(data_id, element.name), data_id.records[element.name])
                    self.assertIn(element.name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    getattr(data_id, "not_a_dimension_name")
            for data_id in itertools.chain(split.minimal, split.complete):
                for element in data_id.graph.elements:
                    with self.assertRaisesRegex(
                        AttributeError, "only available on expanded DataCoordinates"
                    ):
                        getattr(data_id, element.name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    getattr(data_id, "not_a_dimension_name")

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.  (``a1`` and ``b0`` below
        # retain their final values from the loops above.)
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(a1, b0)
            self.assertNotEqual(a0, b1.byName())
            self.assertNotEqual(a0.byName(), b1)
            self.assertNotEqual(a1, b0.byName())
            self.assertNotEqual(a1.byName(), b0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for m, dataId in enumerate(split.chain()):
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a subset
                # of the dimensions.  This is not possible for some
                # combinations of dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from Complete to
                            # Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets,
                # so we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "tract"])
        ):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal data IDs.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

    def testPackers(self):
        (instrument_data_id,) = self.allDataIds.subset(
            self.allDataIds.universe.extract(["instrument"])
        ).toSet()
        (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.extract(["detector"]))
        packer = TestDimensionPacker(instrument_data_id, detector_data_id.graph)
        packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
        self.assertEqual(packed_id, detector_data_id["detector"])
        self.assertEqual(max_bits, packer.maxBits)
        self.assertEqual(
            max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
        )
        self.assertEqual(packer.pack(detector_data_id), packed_id)
        self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
        self.assertEqual(packer.unpack(packed_id), detector_data_id)
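
    # The max_bits assertion above holds because, for any detector_max >= 1,
    # (detector_max - 1).bit_length() == ceil(log2(detector_max)); e.g. with a
    # hypothetical detector_max of 112, the largest ID (111) fits in 7 bits
    # and ceil(log2(112)) == 7.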


if __name__ == "__main__":
    unittest.main()