Coverage for tests / test_query_utilities.py: 13%
262 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 08:43 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 08:43 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Tests for non-public Butler._query functionality that is not specific to
29any Butler or QueryDriver implementation.
30"""
32from __future__ import annotations
34import unittest
35from collections.abc import Iterable, Set
37import astropy.time
39from lsst.daf.butler import DimensionUniverse, InvalidQueryError, Timespan
40from lsst.daf.butler.dimensions import DimensionElement, DimensionGroup
41from lsst.daf.butler.queries import tree as qt
42from lsst.daf.butler.queries.expression_factory import ExpressionFactory
43from lsst.daf.butler.queries.overlaps import OverlapsVisitor, _NaiveDisjointSet
44from lsst.daf.butler.queries.visitors import PredicateVisitFlags
45from lsst.sphgeom import Mq3cPixelization, Region
class ColumnSetTestCase(unittest.TestCase):
    """Tests for lsst.daf.butler.queries.ColumnSet."""

    def setUp(self) -> None:
        self.universe = DimensionUniverse()

    def test_basics(self) -> None:
        """Test construction, field mutation, iteration, string formatting,
        set-like comparisons, copying, and updating.
        """
        columns = qt.ColumnSet(self.universe.conform(["detector"]))
        self.assertNotEqual(columns, columns.dimensions.names)  # intentionally not comparable to other sets
        self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
        self.assertFalse(columns.dataset_fields)
        columns.dataset_fields["bias"].add("dataset_id")
        self.assertEqual(dict(columns.dataset_fields), {"bias": {"dataset_id"}})
        columns.dimension_fields["detector"].add("purpose")
        self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
        self.assertTrue(columns)
        # Iteration is expected to yield dimension keys first, then dimension
        # fields, then dataset fields.
        self.assertEqual(
            list(columns),
            [(k, None) for k in columns.dimensions.data_coordinate_keys]
            + [("detector", "purpose"), ("bias", "dataset_id")],
        )
        self.assertEqual(str(columns), "{instrument, detector, detector:purpose, bias:dataset_id}")
        # Comparisons against an empty column set.
        empty = qt.ColumnSet(self.universe.empty)
        self.assertFalse(empty)
        self.assertFalse(columns.issubset(empty))
        self.assertTrue(columns.issuperset(empty))
        self.assertTrue(columns.isdisjoint(empty))
        # Comparisons against an identical copy.
        copy = columns.copy()
        self.assertEqual(columns, copy)
        self.assertTrue(columns.issubset(copy))
        self.assertTrue(columns.issuperset(copy))
        self.assertFalse(columns.isdisjoint(copy))
        # Growing the copy must leave the original untouched.
        copy.dataset_fields["bias"].add("timespan")
        copy.dimension_fields["detector"].add("name")
        copy.update_dimensions(self.universe.conform(["band"]))
        self.assertEqual(copy.dataset_fields["bias"], {"dataset_id", "timespan"})
        self.assertEqual(columns.dataset_fields["bias"], {"dataset_id"})
        self.assertEqual(copy.dimension_fields["detector"], {"purpose", "name"})
        self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
        self.assertTrue(columns.issubset(copy))
        self.assertFalse(columns.issuperset(copy))
        self.assertFalse(columns.isdisjoint(copy))
        # Updating the original from the copy makes them equal again.
        columns.update(copy)
        self.assertEqual(columns, copy)
        # Timespan detection for (logical table, field) pairs.
        self.assertTrue(columns.is_timespan("visit", "timespan"))
        self.assertFalse(columns.is_timespan("visit", None))
        self.assertFalse(columns.is_timespan("detector", "purpose"))

    def test_drop_dimension_keys(self) -> None:
        """Test dropping and restoring implied dimension key columns."""
        columns = qt.ColumnSet(self.universe.conform(["physical_filter"]))
        columns.drop_implied_dimension_keys()
        self.assertEqual(list(columns), [("instrument", None), ("physical_filter", None)])
        # A fresh column set over the same dimensions still has every key.
        undropped = qt.ColumnSet(columns.dimensions)
        self.assertTrue(columns.issubset(undropped))
        self.assertFalse(columns.issuperset(undropped))
        self.assertFalse(columns.isdisjoint(undropped))
        # The band key is exactly what was dropped: it is disjoint from the
        # dropped set, and adding it back reproduces the undropped set.
        band_only = qt.ColumnSet(self.universe.conform(["band"]))
        self.assertFalse(columns.issubset(band_only))
        self.assertFalse(columns.issuperset(band_only))
        self.assertTrue(columns.isdisjoint(band_only))
        copy = columns.copy()
        copy.update(band_only)
        self.assertEqual(copy, undropped)
        columns.restore_dimension_keys()
        self.assertEqual(columns, undropped)

    def test_get_column_spec(self) -> None:
        """Test the name, type, and nullability reported for each column."""
        columns = qt.ColumnSet(self.universe.conform(["detector"]))
        columns.dimension_fields["detector"].add("purpose")
        columns.dataset_fields["bias"].update(["dataset_id", "run", "collection", "timespan", "ingest_date"])
        # Each row is (logical table, field, expected name, expected type,
        # expected nullable).
        expected_specs: list[tuple[str, str | None, str, str, bool]] = [
            ("instrument", None, "instrument", "string", False),
            ("detector", None, "detector", "int", False),
            ("detector", "purpose", "detector:purpose", "string", True),
            ("bias", "dataset_id", "bias:dataset_id", "uuid", False),
            ("bias", "run", "bias:run", "string", False),
            ("bias", "collection", "bias:collection", "string", False),
            ("bias", "timespan", "bias:timespan", "timespan", True),
            ("bias", "ingest_date", "bias:ingest_date", "datetime", True),
        ]
        for logical_table, field, name, type_, nullable in expected_specs:
            # subTest keeps checking the remaining columns after a failure
            # and labels each failure with the column it belongs to.
            with self.subTest(logical_table=logical_table, field=field):
                spec = columns.get_column_spec(logical_table, field)
                self.assertEqual(spec.name, name)
                self.assertEqual(spec.type, type_)
                self.assertEqual(spec.nullable, nullable)
class _RecordingOverlapsVisitor(OverlapsVisitor):
    """An `OverlapsVisitor` that records every overlap hook invocation.

    Each overridden ``visit_*`` hook appends a tuple describing its arguments
    (dimension elements are recorded by name) to a per-hook list, and then
    delegates to the base-class implementation.
    """

    def __init__(self, dimensions: DimensionGroup, calibration_dataset_types: Set[str] = frozenset()):
        super().__init__(dimensions, calibration_dataset_types)
        # One list per hook, in call order.
        self.spatial_constraints: list[tuple[str, PredicateVisitFlags]] = []
        self.spatial_joins: list[tuple[str, str, PredicateVisitFlags]] = []
        self.temporal_dimension_joins: list[tuple[str, str, PredicateVisitFlags]] = []
        self.validity_range_dimension_joins: list[tuple[str, str, PredicateVisitFlags]] = []
        self.validity_range_joins: list[tuple[str, str, PredicateVisitFlags]] = []

    def visit_spatial_constraint(
        self, element: DimensionElement, region: Region, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        record = (element.name, flags)
        self.spatial_constraints.append(record)
        result = super().visit_spatial_constraint(element, region, flags)
        return result

    def visit_spatial_join(
        self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        record = (a.name, b.name, flags)
        self.spatial_joins.append(record)
        result = super().visit_spatial_join(a, b, flags)
        return result

    def visit_temporal_dimension_join(
        self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        record = (a.name, b.name, flags)
        self.temporal_dimension_joins.append(record)
        result = super().visit_temporal_dimension_join(a, b, flags)
        return result

    def visit_validity_range_dimension_join(
        self, a: str, b: DimensionElement, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        # Here ``a`` is already a string; only ``b`` is an element.
        record = (a, b.name, flags)
        self.validity_range_dimension_joins.append(record)
        result = super().visit_validity_range_dimension_join(a, b, flags)
        return result

    def visit_validity_range_join(self, a: str, b: str, flags: PredicateVisitFlags) -> qt.Predicate | None:
        record = (a, b, flags)
        self.validity_range_joins.append(record)
        result = super().visit_validity_range_join(a, b, flags)
        return result
class OverlapsVisitorTestCase(unittest.TestCase):
    """Tests for lsst.daf.butler.queries.overlaps.OverlapsVisitor, which is
    responsible for validating and inferring spatial and temporal joins and
    constraints.
    """

    def setUp(self) -> None:
        self.universe = DimensionUniverse()

    def run_visitor(
        self,
        dimensions: Iterable[str],
        predicate: qt.Predicate,
        expected: str | None = None,
        join_operands: Iterable[DimensionGroup] = (),
        calibration_dataset_types: Set[str] = frozenset(),
    ) -> _RecordingOverlapsVisitor:
        """Run a recording visitor over a predicate and check its output.

        Parameters
        ----------
        dimensions : `~collections.abc.Iterable` [ `str` ]
            Dimension names, conformed with ``self.universe`` to build the
            visitor.
        predicate : `lsst.daf.butler.queries.tree.Predicate`
            Predicate passed to the visitor's ``run`` method.
        expected : `str`, optional
            Expected string form of the predicate returned by ``run``.
            Defaults to ``str(predicate)``, i.e. the predicate is expected to
            come back unchanged.
        join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ], \
                optional
            Forwarded to the visitor's ``run`` method.
        calibration_dataset_types : `~collections.abc.Set` [ `str` ], optional
            Forwarded to the visitor's constructor.

        Returns
        -------
        visitor : `_RecordingOverlapsVisitor`
            The visitor, after running, so callers can inspect what it
            recorded.
        """
        visitor = _RecordingOverlapsVisitor(self.universe.conform(dimensions), calibration_dataset_types)
        if expected is None:
            expected = str(predicate)
        new_predicate = visitor.run(predicate, join_operands=join_operands)
        self.assertEqual(str(new_predicate), expected)
        return visitor

    def test_trivial(self) -> None:
        """Test the overlaps visitor when there is nothing spatial or temporal
        in the query at all.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate.
        visitor = self.run_visitor(["physical_filter"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["physical_filter"], x.any(x.band == "r", x.band == "i"))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_one_spatial_family(self) -> None:
        """Test the overlaps visitor when there is one spatial family."""
        x = ExpressionFactory(self.universe)
        pixelization = Mq3cPixelization(10)
        region = pixelization.quad(12058870)
        # Trivial predicate.
        visitor = self.run_visitor(["visit"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["visit"], x.any(x.band == "r", x.visit > 2))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Spatial constraint predicate, in various positions relative to other
        # non-overlap predicates.  The recording visitor stores element names,
        # so the expected values are plain strings rather than dimension
        # element objects.
        visitor = self.run_visitor(["visit"], x.visit.region.overlaps(region))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags(0))])
        visitor = self.run_visitor(["visit"], x.all(x.visit.region.overlaps(region), x.band == "r"))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        visitor = self.run_visitor(["visit"], x.any(x.visit.region.overlaps(region), x.band == "r"))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags.HAS_OR_SIBLINGS)])
        visitor = self.run_visitor(
            ["visit"],
            x.all(
                x.any(x.literal(region).overlaps(x.visit.region), x.band == "r"),
                x.visit.observation_reason == "science",
            ),
        )
        self.assertEqual(
            visitor.spatial_constraints,
            [
                (
                    "visit",
                    PredicateVisitFlags.HAS_OR_SIBLINGS | PredicateVisitFlags.HAS_AND_SIBLINGS,
                )
            ],
        )
        visitor = self.run_visitor(
            ["visit"],
            x.any(
                x.all(x.visit.region.overlaps(region), x.band == "r"),
                x.visit.observation_reason == "science",
            ),
        )
        self.assertEqual(
            visitor.spatial_constraints,
            [
                (
                    "visit",
                    PredicateVisitFlags.HAS_OR_SIBLINGS | PredicateVisitFlags.HAS_AND_SIBLINGS,
                )
            ],
        )
        # A spatial join between dimensions in the same family is an error.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["patch", "tract"], x.patch.region.overlaps(x.tract.region))

    def test_single_unambiguous_spatial_join(self) -> None:
        """Test the overlaps visitor when there are two spatial families with
        one dimension element in each, and hence exactly one join is needed.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate; an automatic join is added.  Order of elements in
        # automatic joins is lexicographical in order to be deterministic.
        visitor = self.run_visitor(
            ["visit", "tract"], qt.Predicate.from_bool(True), "tract.region OVERLAPS visit.region"
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate; an automatic join is added.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.all(x.band == "r", x.visit > 2),
            "band == 'r' AND visit > 2 AND tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # The same overlap predicate that would be added automatically has been
        # added manually.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.tract.region.overlaps(x.visit.region),
            "tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add the join overlap predicate in an OR expression, which is unusual
        # but enough to block the addition of an automatic join; we assume the
        # user knows what they're doing.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.any(x.visit > 2, x.tract.region.overlaps(x.visit.region)),
            "visit > 2 OR tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_OR_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add the join overlap predicate in a NOT expression, which is unusual
        # but permitted in the same sense as OR expressions.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.not_(x.tract.region.overlaps(x.visit.region)),
            "NOT tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.INVERTED)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include both spatial families.
        # This blocks an automatic join from being created, because we assume
        # that join operand (e.g. a materialization or dataset search) already
        # encodes some spatial join.
        visitor = self.run_visitor(
            ["visit", "tract"],
            qt.Predicate.from_bool(True),
            "True",
            join_operands=[self.universe.conform(["tract", "visit"])],
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_single_flexible_spatial_join(self) -> None:
        """Test the overlaps visitor when there are two spatial families and
        one has multiple dimension elements.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate; an automatic join between the fine-grained
        # elements is added.  Order of elements in automatic joins is
        # lexicographical in order to be deterministic.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            qt.Predicate.from_bool(True),
            "patch.region OVERLAPS visit_detector_region.region",
        )
        self.assertEqual(
            visitor.spatial_joins, [("patch", "visit_detector_region", PredicateVisitFlags.HAS_AND_SIBLINGS)]
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # The same overlap predicate that would be added automatically has been
        # added manually.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            x.patch.region.overlaps(x.visit_detector_region.region),
            "patch.region OVERLAPS visit_detector_region.region",
        )
        self.assertEqual(visitor.spatial_joins, [("patch", "visit_detector_region", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # A coarse overlap join has been added; respect it and do not add an
        # automatic one.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            x.tract.region.overlaps(x.visit.region),
            "tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include both spatial families
        # with the most fine-grained dimensions in the query.
        # This blocks an automatic join from being created, because we assume
        # that join operand (e.g. a materialization or dataset search) already
        # encodes some spatial join.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            qt.Predicate.from_bool(True),
            "True",
            join_operands=[self.universe.conform(["patch", "visit_detector_region"])],
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_multiple_spatial_joins(self) -> None:
        """Test the overlaps visitor when there are >2 spatial families."""
        x = ExpressionFactory(self.universe)
        # Trivial predicate.  This is an error, because we cannot generate
        # automatic spatial joins when there are more than two families.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["visit", "patch", "htm7"], qt.Predicate.from_bool(True))
        # Predicate that joins one pair of families but orphans the other;
        # also an error.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["visit", "patch", "htm7"], x.visit.region.overlaps(x.htm7.region))
        # A sufficient overlap join predicate has been added; each family is
        # connected to at least one other.
        visitor = self.run_visitor(
            ["visit", "patch", "htm7"],
            x.all(x.tract.region.overlaps(x.visit.region), x.tract.region.overlaps(x.htm7.region)),
            "tract.region OVERLAPS visit.region AND tract.region OVERLAPS htm7.region",
        )
        self.assertEqual(
            visitor.spatial_joins,
            [
                ("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS),
                ("tract", "htm7", PredicateVisitFlags.HAS_AND_SIBLINGS),
            ],
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions includes two spatial families,
        # with the most fine-grained dimensions in the query, and a predicate
        # that joins the third in.
        visitor = self.run_visitor(
            ["visit", "patch", "htm7"],
            x.tract.region.overlaps(x.htm7.region),
            "tract.region OVERLAPS htm7.region",
            join_operands=[self.universe.conform(["visit", "patch"])],
        )
        self.assertEqual(
            visitor.spatial_joins,
            [
                ("tract", "htm7", PredicateVisitFlags(0)),
            ],
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_one_temporal_family(self) -> None:
        """Test the overlaps visitor when there is one temporal family."""
        x = ExpressionFactory(self.universe)
        begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
        timespan = Timespan(begin, end)
        # Trivial predicate.
        visitor = self.run_visitor(["exposure"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["exposure"], x.any(x.band == "r", x.exposure > 2))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Temporal constraint predicate.
        visitor = self.run_visitor(["exposure"], x.exposure.timespan.overlaps(timespan))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # A temporal join between dimensions in the same family is an error.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["exposure", "visit"], x.exposure.timespan.overlaps(x.visit.timespan))
        # Overlap join with a calibration dataset's validity ranges.
        visitor = self.run_visitor(
            ["exposure"], x.exposure.timespan.overlaps(x["bias"].timespan), calibration_dataset_types={"bias"}
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        self.assertEqual(
            visitor.validity_range_dimension_joins, [("bias", "exposure", PredicateVisitFlags(0))]
        )
        self.assertFalse(visitor.validity_range_joins)
        # Overlap join between two calibration dataset validity ranges.
        # (It's not clear this kind of query is ever useful in practice, but
        # there's a good consistency argument for what it ought to do).
        visitor = self.run_visitor(
            [], x["flat"].timespan.overlaps(x["bias"].timespan), calibration_dataset_types={"bias", "flat"}
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        self.assertFalse(visitor.validity_range_dimension_joins)
        self.assertEqual(visitor.validity_range_joins, [("flat", "bias", PredicateVisitFlags(0))])

    # There are no tests for temporal dimension joins, because the default
    # dimension universe only has one temporal family, and the untested logic
    # trivially duplicates the spatial-join logic.
class NaiveDisjointSetTestCase(unittest.TestCase):
    """Test the naive disjoint-set implementation that backs automatic overlap
    join creation.
    """

    def test_naive_disjoint_set(self) -> None:
        """Merge elements step by step and check the subsets after each."""
        disjoint = _NaiveDisjointSet(range(8))
        # Before any merge, every element is expected in its own subset.
        self.assertCountEqual(disjoint.subsets(), [{n} for n in range(8)])
        # Each step is (arguments to merge, subsets expected afterwards).
        steps = [
            ((3, 4), [{0}, {1}, {2}, {3, 4}, {5}, {6}, {7}]),
            ((2, 1), [{0}, {1, 2}, {3, 4}, {5}, {6}, {7}]),
            ((1, 3), [{0}, {1, 2, 3, 4}, {5}, {6}, {7}]),
        ]
        for (first, second), expected_subsets in steps:
            disjoint.merge(first, second)
            self.assertCountEqual(disjoint.subsets(), expected_subsets)
# Run the tests in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()