Coverage for tests/test_dimensions.py: 11%
382 statements
« prev ^ index » next — coverage.py v6.4.2, created at 2022-07-19 12:13 +0000
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
22import copy
23import itertools
24import os
25import pickle
26import unittest
27from dataclasses import dataclass
28from random import Random
29from typing import Iterator, Optional
31from lsst.daf.butler import (
32 Config,
33 DataCoordinate,
34 DataCoordinateSequence,
35 DataCoordinateSet,
36 Dimension,
37 DimensionConfig,
38 DimensionGraph,
39 DimensionUniverse,
40 NamedKeyDict,
41 NamedValueSet,
42 Registry,
43 SpatialRegionDatabaseRepresentation,
44 TimespanDatabaseRepresentation,
45 YamlRepoImportBackend,
46)
47from lsst.daf.butler.registry import RegistryConfig
# Normalized path to the exported registry dimension data shipped with the
# test suite (tests/data/registry/hsc-rc2-subset.yaml).
DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)
def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file, expanded to
        include dimension records.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = Registry.createFromConfig(config)
    with open(DIMENSION_DATA_FILE, "r") as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        # No datastore: we only need the registry content, not any files.
        backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        # A fresh default universe for every test.
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        """Assert the invariants that should hold for any `DimensionGraph`."""
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering within
            # a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        # A graph is fully determined by its required dimensions.
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        # The required dimensions are exactly those not implied by any element.
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and any element that is implied by another in the graph
        # should follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)

    def testConfigPresent(self):
        """Test that the universe exposes the configuration it was built
        from.
        """
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        """Test compatibility comparisons between universes."""
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        # Differing versions should still be compatible, but logged.
        with self.assertLogs("lsst.daf.butler.core.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        """Test the default universe's namespace and minimum version."""
        self.assertEqual(self.universe.namespace, "daf_butler")
        # Test was added starting at version 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        """Test that the expected static dimensions are read from config."""
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        """Test graph invariants for the empty graph and every static
        element's graph.
        """
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        """Test the graph generated by instrument-related dimensions."""
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        """Test the graph generated by calibration-related dimensions."""
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        """Test the spatial/temporal families of observation dimensions."""
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

    def testSkyMapDimensions(self):
        """Test the graph generated by skymap-related dimensions."""
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        """Test that table specs generated for dimension elements are
        self-consistent.
        """
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            # Views have no table of their own to check.
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    RegionReprClass=SpatialRegionDatabaseRepresentation,
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        # Required dependencies appear as primary-key fields
                        # matching the dependency's own primary key.
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        # An element's own key uses its primaryKey name.
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    # Implied dependencies are regular (non-key) fields.
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object within
        # a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)
@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: Optional[DataCoordinateSequence] = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: Optional[int] = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``nth``
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        # Slicing (rather than indexing) lets the same code handle both the
        # "everything" and "just the nth element" cases uniformly.
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
class DataCoordinateTestCase(unittest.TestCase):
    """Tests for `DataCoordinate` and its container classes, using data IDs
    loaded from an export file included in the code repository.
    """

    # Fixed seed so test data selection is deterministic across runs.
    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: Optional[DataCoordinateSequence] = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from.  Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` Data IDs randomly selected from ``dataIds`` with replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        # check=False because the inputs are already-validated data IDs.
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, graph: Optional[DimensionGraph] = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from.  Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            ``n`` or more dimensions randomly selected from ``graph`` with
            replacement.
        """
        if graph is None:
            graph = self.allDataIds.graph
        # NOTE(review): ``max(n, len(...))`` selects *all* dimensions whenever
        # the graph has at least ``n``; ``min`` may have been intended, and
        # ``max`` would raise if ``n`` ever exceeded the graph size — confirm.
        return DimensionGraph(
            graph.universe, names=self.rng.sample(list(graph.dimensions.names), max(n, len(graph.dimensions)))
        )

    def splitByStateFlags(
        self,
        dataIds: Optional[DataCoordinateSequence] = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from.  Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both return
            `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            # Round-trip through plain dicts of all dimension values to drop
            # the records while keeping the full values.
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            # Round-trip through dicts of only the required values.
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions])

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.
        # (Fixed: the original referenced ``a1``/``b0``, stale loop variables
        # leaked from the loops above, so four of the six assertions kept
        # comparing the same stale pair instead of the product pair.)
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(b1, a0)
            self.assertNotEqual(a0, b1.byName())
            self.assertNotEqual(a0.byName(), b1)
            self.assertNotEqual(b1, a0.byName())
            self.assertNotEqual(b1.byName(), a0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via several
            # different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
        # Also test the case for non-temporal DataIds.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        # (Fixed: the original asserted ``a == a.toSequence().toSet()`` twice.)
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))
# Standard unittest entry point (coverage-report annotation text that had been
# fused into this line has been removed).
if __name__ == "__main__":
    unittest.main()