Coverage for tests/test_dimensions.py: 11% (365 statements) — coverage.py v6.4, created at 2022-05-24 02:27 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import copy
23import itertools
24import os
25import pickle
26import unittest
27from dataclasses import dataclass
28from random import Random
29from typing import Iterator, Optional
31from lsst.daf.butler import (
32 DataCoordinate,
33 DataCoordinateSequence,
34 DataCoordinateSet,
35 Dimension,
36 DimensionConfig,
37 DimensionGraph,
38 DimensionUniverse,
39 NamedKeyDict,
40 NamedValueSet,
41 Registry,
42 SpatialRegionDatabaseRepresentation,
43 TimespanDatabaseRepresentation,
44 YamlRepoImportBackend,
45)
46from lsst.daf.butler.registry import RegistryConfig
# Registry export file (YAML) whose contents are loaded to build the test
# data IDs; it lives alongside the tests in data/registry/.
DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)
def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file, expanded to
        include dimension records.
    """
    # Import the YAML export into a throwaway in-memory SQLite registry,
    # then read the data IDs back out as expanded DataCoordinates.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = Registry.createFromConfig(config)
    with open(DIMENSION_DATA_FILE, "r") as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        # Every test in this case works against the default static universe.
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        """Assert the invariants that any `DimensionGraph` must satisfy."""
        elements = list(graph.elements)
        for i, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for earlier in elements[:i]:
                self.assertLess(earlier, element)
                self.assertLessEqual(earlier, element)
            for later in elements[i + 1 :]:
                self.assertGreater(later, element)
                self.assertGreaterEqual(later, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        # A dimension is required exactly when nothing in the graph implies it.
        notImplied = [
            dimension
            for dimension in graph.dimensions
            if not any(dimension in other.graph.implied for other in graph.elements)
        ]
        self.assertCountEqual(graph.required, notImplied)
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and an element that is implied by any other in the graph
        # must follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)

    def testConfigPresent(self):
        # The universe should expose the configuration it was built from.
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testConfigRead(self):
        # Build the expected name set incrementally: the explicitly defined
        # dimensions plus the full families of htm and healpix levels.
        expected = {
            "instrument",
            "visit",
            "visit_system",
            "exposure",
            "detector",
            "physical_filter",
            "band",
            "subfilter",
            "skymap",
            "tract",
            "patch",
        }
        expected.update(f"htm{level}" for level in range(25))
        expected.update(f"healpix{level}" for level in range(18))
        self.assertEqual(set(self.universe.getStaticDimensions().names), expected)

    def testGraphs(self):
        # The empty graph and the graph of every static element must all
        # satisfy the shared invariants.
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band", "visit_system"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band", "visit_system"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        # No extra (non-dimension) elements appear for this combination.
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band", "visit_system"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band", "visit_system"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        # Both the spatial and temporal families are governed by instrument.
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

    def testSkyMapDimensions(self):
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        # Build a table spec for every static element that is backed by a
        # real table (not a view).
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    RegionReprClass=SpatialRegionDatabaseRepresentation,
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    # An element's own key is stored under its primaryKey
                    # name; other required dependencies are stored under the
                    # dependency's dimension name.  Either way the column
                    # must mirror the dependency's primary key definition.
                    columnName = dep.name if dep != element else element.primaryKey.name
                    self.assertIn(columnName, tableSpec.fields)
                    field = tableSpec.fields[columnName]
                    self.assertEqual(field.dtype, dep.primaryKey.dtype)
                    self.assertEqual(field.length, dep.primaryKey.length)
                    self.assertEqual(field.nbytes, dep.primaryKey.nbytes)
                    self.assertFalse(field.nullable)
                    self.assertTrue(field.primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    # Implied dimensions become plain (non-key) columns.
                    self.assertIn(dep.name, tableSpec.fields)
                    field = tableSpec.fields[dep.name]
                    self.assertEqual(field.dtype, dep.primaryKey.dtype)
                    self.assertFalse(field.primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                # Foreign keys must point at other dimension tables in the
                # element's graph, with matching column definitions.
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                targetFields = tableSpecs[foreignKey.table].fields
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, targetFields.names)
                    self.assertEqual(tableSpec.fields[source].dtype, targetFields[target].dtype)
                    self.assertEqual(tableSpec.fields[source].length, targetFields[target].length)
                    self.assertEqual(tableSpec.fields[source].nbytes, targetFields[target].nbytes)

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        for clone in (
            pickle.loads(pickle.dumps(universe1)),
            copy.copy(universe1),
            copy.deepcopy(universe1),
        ):
            self.assertIs(universe1, clone)
        for element1 in universe1.getStaticElements():
            self.assertIs(element1, pickle.loads(pickle.dumps(element1)))
            graph1 = element1.graph
            self.assertIs(graph1, pickle.loads(pickle.dumps(graph1)))
@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    # Data IDs carrying only required-dimension values.  hasFull() is True
    # for these only when the graph implies no additional dimensions, and
    # hasRecords() is always False.
    minimal: Optional["DataCoordinateSequence"] = None

    # Data IDs carrying values for all dimensions: hasFull() is always True
    # and hasRecords() is always False for this attribute.
    complete: Optional["DataCoordinateSequence"] = None

    # Data IDs carrying all dimension values plus their records: hasFull()
    # and hasRecords() are both always True for this attribute.
    expanded: Optional["DataCoordinateSequence"] = None

    def chain(self, n: Optional[int] = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``nth``
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        # Slicing (rather than indexing) keeps empty/short attributes from
        # raising when only the nth entry is requested.
        selection = slice(None) if n is None else slice(n, n + 1)
        for sequence in (self.minimal, self.complete, self.expanded):
            if sequence is not None:
                yield from sequence[selection]
class DataCoordinateTestCase(unittest.TestCase):
    """Tests for `DataCoordinate` and its container classes, using data IDs
    loaded from an exported registry data file.
    """

    # Seed for the deterministic RNG used to sample data IDs and dimensions.
    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        # Loading the registry export is expensive, so do it once per class.
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        # Re-seed per test so each test's random draws are reproducible and
        # independent of test execution order.
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: Optional[DataCoordinateSequence] = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from.  Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` Data IDs randomly selected from ``dataIds`` with replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        # The sample's state flags match the parent sequence's, so skip the
        # (expensive) per-element consistency check.
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, graph: Optional[DimensionGraph] = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from.  Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            ``n`` or more dimensions randomly selected from ``graph`` with
            replacement.
        """
        if graph is None:
            graph = self.allDataIds.graph
        # Sample at most n names (and never more than are available).  Using
        # max() here would always sample every dimension, so this helper
        # would never actually produce a proper subset.
        return DimensionGraph(
            graph.universe,
            names=self.rng.sample(list(graph.dimensions.names), min(n, len(graph.dimensions))),
        )

    def splitByStateFlags(
        self,
        dataIds: Optional[DataCoordinateSequence] = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional.
            Data IDs to start from.  Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain
            all information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values
            for required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            # Re-standardizing from the full value mapping drops the records
            # but keeps every dimension value.
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            # Re-standardizing from the required-only mapping drops implied
            # values as well.
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions])

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.  (The original version of
        # this loop also compared against a1/b0 left over from the loops
        # above, which only worked by accident of loop-variable leakage; the
        # same facts are asserted here with the current pair, in both
        # directions.)
        for x0, x1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(x0, x1)
            self.assertNotEqual(x1, x0)
            self.assertNotEqual(x0, x1.byName())
            self.assertNotEqual(x0.byName(), x1)
            self.assertNotEqual(x1, x0.byName())
            self.assertNotEqual(x1.byName(), x0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets,
                # so we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for _ in range(2)]
        for name in ("visit", "detector", "physical_filter", "band"):
            graphs.append(self.allDataIds.universe[name].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    # If either operand alone covers the whole union graph,
                    # its state flags are inherited by the union.
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        # (dimension names, expected spatial family, record holding the
        # region) for each case under test.
        cases = [
            (["visit"], "observation_regions", "visit"),
            (["visit", "detector"], "observation_regions", "visit_detector_region"),
            (["tract"], "skymap_regions", "tract"),
            (["patch"], "skymap_regions", "patch"),
        ]
        for names, spatialName, recordName in cases:
            for dataId in self.randomDataIds(n=4).subset(
                DimensionGraph(self.allDataIds.universe, names=names)
            ):
                self.assertIsNotNone(dataId.region)
                self.assertEqual(dataId.graph.spatial.names, {spatialName})
                self.assertEqual(dataId.region, dataId.records[recordName].region)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
        # Also test the case for non-temporal DataIds.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            # Expanded data IDs have both full values and records.
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            # Complete data IDs have full values but no records.
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            # Minimal data IDs are full only when nothing is implied.
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        # Operator forms must agree with the named methods, and obey the
        # usual subset/superset relationships.
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()