Coverage for tests/test_query_utilities.py: 13%
247 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-19 10:53 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-19 10:53 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Tests for non-public Butler._query functionality that is not specific to
29any Butler or QueryDriver implementation.
30"""
32from __future__ import annotations
34import unittest
35from collections.abc import Iterable
37import astropy.time
38from lsst.daf.butler import DimensionUniverse, Timespan
39from lsst.daf.butler.dimensions import DimensionElement, DimensionGroup
40from lsst.daf.butler.queries import tree as qt
41from lsst.daf.butler.queries.expression_factory import ExpressionFactory
42from lsst.daf.butler.queries.overlaps import OverlapsVisitor, _NaiveDisjointSet
43from lsst.daf.butler.queries.visitors import PredicateVisitFlags
44from lsst.sphgeom import Mq3cPixelization, Region
class ColumnSetTestCase(unittest.TestCase):
    """Tests for ``lsst.daf.butler.queries.tree.ColumnSet``."""

    def setUp(self) -> None:
        self.universe = DimensionUniverse()

    def test_basics(self) -> None:
        """Test construction, field bookkeeping, iteration, string
        formatting, set-like comparisons, copying, and updating.
        """
        columns = qt.ColumnSet(self.universe.conform(["detector"]))
        self.assertNotEqual(columns, columns.dimensions.names)  # intentionally not comparable to other sets
        self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
        # Dataset and dimension fields start out empty and are added in place.
        self.assertFalse(columns.dataset_fields)
        columns.dataset_fields["bias"].add("dataset_id")
        self.assertEqual(dict(columns.dataset_fields), {"bias": {"dataset_id"}})
        columns.dimension_fields["detector"].add("purpose")
        self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
        self.assertTrue(columns)
        # Iteration yields dimension keys first, then dimension fields, then
        # dataset fields.
        self.assertEqual(
            list(columns),
            [(k, None) for k in columns.dimensions.data_coordinate_keys]
            + [("detector", "purpose"), ("bias", "dataset_id")],
        )
        self.assertEqual(str(columns), "{instrument, detector, detector:purpose, bias:dataset_id}")
        # Comparisons against an empty column set.
        empty = qt.ColumnSet(self.universe.empty.as_group())
        self.assertFalse(empty)
        self.assertFalse(columns.issubset(empty))
        self.assertTrue(columns.issuperset(empty))
        self.assertTrue(columns.isdisjoint(empty))
        # Comparisons against an identical copy.
        copy = columns.copy()
        self.assertEqual(columns, copy)
        self.assertTrue(columns.issubset(copy))
        self.assertTrue(columns.issuperset(copy))
        self.assertFalse(columns.isdisjoint(copy))
        # Modifying the copy must not modify the original.
        copy.dataset_fields["bias"].add("timespan")
        copy.dimension_fields["detector"].add("name")
        copy.update_dimensions(self.universe.conform(["band"]))
        self.assertEqual(copy.dataset_fields["bias"], {"dataset_id", "timespan"})
        self.assertEqual(columns.dataset_fields["bias"], {"dataset_id"})
        self.assertEqual(copy.dimension_fields["detector"], {"purpose", "name"})
        self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
        self.assertTrue(columns.issubset(copy))
        self.assertFalse(columns.issuperset(copy))
        self.assertFalse(columns.isdisjoint(copy))
        # Updating the original from the copy makes them equal again.
        columns.update(copy)
        self.assertEqual(columns, copy)
        self.assertTrue(columns.is_timespan("visit", "timespan"))
        self.assertFalse(columns.is_timespan("visit", None))
        self.assertFalse(columns.is_timespan("detector", "purpose"))

    def test_drop_dimension_keys(self) -> None:
        """Test dropping implied dimension keys and restoring them."""
        columns = qt.ColumnSet(self.universe.conform(["physical_filter"]))
        columns.drop_implied_dimension_keys()
        self.assertEqual(list(columns), [("instrument", None), ("physical_filter", None)])
        # A column set built from the same dimensions without dropping
        # anything is a strict superset.
        undropped = qt.ColumnSet(columns.dimensions)
        self.assertTrue(columns.issubset(undropped))
        self.assertFalse(columns.issuperset(undropped))
        self.assertFalse(columns.isdisjoint(undropped))
        band_only = qt.ColumnSet(self.universe.conform(["band"]))
        self.assertFalse(columns.issubset(band_only))
        self.assertFalse(columns.issuperset(band_only))
        self.assertTrue(columns.isdisjoint(band_only))
        # Adding the band column back, or restoring the dropped keys, yields
        # the undropped column set.
        copy = columns.copy()
        copy.update(band_only)
        self.assertEqual(copy, undropped)
        columns.restore_dimension_keys()
        self.assertEqual(columns, undropped)

    def test_get_column_spec(self) -> None:
        """Test the name, type, and nullability reported by
        ``get_column_spec`` for dimension keys, dimension fields, and dataset
        fields.
        """
        columns = qt.ColumnSet(self.universe.conform(["detector"]))
        columns.dimension_fields["detector"].add("purpose")
        columns.dataset_fields["bias"].update(["dataset_id", "run", "collection", "timespan", "ingest_date"])
        # Each row: (parent, field, expected name, expected type, expected nullable).
        expected_specs = [
            ("instrument", None, "instrument", "string", False),
            ("detector", None, "detector", "int", False),
            ("detector", "purpose", "detector:purpose", "string", True),
            ("bias", "dataset_id", "bias:dataset_id", "uuid", False),
            ("bias", "run", "bias:run", "string", False),
            ("bias", "collection", "bias:collection", "string", False),
            ("bias", "timespan", "bias:timespan", "timespan", False),
            ("bias", "ingest_date", "bias:ingest_date", "datetime", True),
        ]
        for parent, field, expected_name, expected_type, expected_nullable in expected_specs:
            # subTest keeps going after a failure and reports which column
            # was being checked.
            with self.subTest(parent=parent, field=field):
                spec = columns.get_column_spec(parent, field)
                self.assertEqual(spec.name, expected_name)
                self.assertEqual(spec.type, expected_type)
                self.assertEqual(spec.nullable, expected_nullable)
class _RecordingOverlapsVisitor(OverlapsVisitor):
    """An `OverlapsVisitor` that remembers which overlap hooks were invoked.

    Every overridden ``visit_*`` hook appends the names of the dimension
    elements it was given, plus the visit flags, to a list attribute, and
    then returns whatever the base-class implementation returns.
    """

    def __init__(self, dimensions: DimensionGroup):
        super().__init__(dimensions)
        # One (element name, flags) entry per visit_spatial_constraint call.
        self.spatial_constraints: list[tuple[str, PredicateVisitFlags]] = []
        # One (a name, b name, flags) entry per visit_spatial_join call.
        self.spatial_joins: list[tuple[str, str, PredicateVisitFlags]] = []
        # One (a name, b name, flags) entry per visit_temporal_dimension_join
        # call.
        self.temporal_dimension_joins: list[tuple[str, str, PredicateVisitFlags]] = []

    def visit_spatial_constraint(
        self, element: DimensionElement, region: Region, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        # Record first, then delegate, so the call is logged even if the
        # base class raises.
        entry = (element.name, flags)
        self.spatial_constraints.append(entry)
        result = super().visit_spatial_constraint(element, region, flags)
        return result

    def visit_spatial_join(
        self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        entry = (a.name, b.name, flags)
        self.spatial_joins.append(entry)
        result = super().visit_spatial_join(a, b, flags)
        return result

    def visit_temporal_dimension_join(
        self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        entry = (a.name, b.name, flags)
        self.temporal_dimension_joins.append(entry)
        result = super().visit_temporal_dimension_join(a, b, flags)
        return result
class OverlapsVisitorTestCase(unittest.TestCase):
    """Tests for lsst.daf.butler.queries.overlaps.OverlapsVisitor, which is
    responsible for validating and inferring spatial and temporal joins and
    constraints.
    """

    def setUp(self) -> None:
        self.universe = DimensionUniverse()

    def run_visitor(
        self,
        dimensions: Iterable[str],
        predicate: qt.Predicate,
        expected: str | None = None,
        join_operands: Iterable[DimensionGroup] = (),
    ) -> _RecordingOverlapsVisitor:
        """Run a recording visitor on a predicate and check the result.

        Parameters
        ----------
        dimensions : `~collections.abc.Iterable` [ `str` ]
            Dimension names, conformed against the test universe to build the
            visitor.
        predicate : `lsst.daf.butler.queries.tree.Predicate`
            Predicate to process.
        expected : `str`, optional
            Expected string form of the predicate returned by the visitor.
            Defaults to ``str(predicate)``, i.e. the predicate is expected to
            come back with an unchanged string form.
        join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ]
            Forwarded to the visitor's ``run`` method.

        Returns
        -------
        visitor : `_RecordingOverlapsVisitor`
            The visitor, for inspection of the calls it recorded.
        """
        visitor = _RecordingOverlapsVisitor(self.universe.conform(dimensions))
        if expected is None:
            expected = str(predicate)
        new_predicate = visitor.run(predicate, join_operands=join_operands)
        self.assertEqual(str(new_predicate), expected)
        return visitor

    def test_trivial(self) -> None:
        """Test the overlaps visitor when there is nothing spatial or temporal
        in the query at all.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate.
        visitor = self.run_visitor(["physical_filter"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["physical_filter"], x.any(x.band == "r", x.band == "i"))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_one_spatial_family(self) -> None:
        """Test the overlaps visitor when there is one spatial family."""
        x = ExpressionFactory(self.universe)
        pixelization = Mq3cPixelization(10)
        region = pixelization.quad(12058870)
        # Trivial predicate.
        visitor = self.run_visitor(["visit"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["visit"], x.any(x.band == "r", x.visit > 2))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Spatial constraint predicate, in various positions relative to other
        # non-overlap predicates.  The recording visitor stores element
        # *names*, so expectations are plain strings (as for spatial joins).
        visitor = self.run_visitor(["visit"], x.visit.region.overlaps(region))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags(0))])
        visitor = self.run_visitor(["visit"], x.all(x.visit.region.overlaps(region), x.band == "r"))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        visitor = self.run_visitor(["visit"], x.any(x.visit.region.overlaps(region), x.band == "r"))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags.HAS_OR_SIBLINGS)])
        visitor = self.run_visitor(
            ["visit"],
            x.all(
                x.any(x.literal(region).overlaps(x.visit.region), x.band == "r"),
                x.visit.observation_reason == "science",
            ),
        )
        self.assertEqual(
            visitor.spatial_constraints,
            [("visit", PredicateVisitFlags.HAS_OR_SIBLINGS | PredicateVisitFlags.HAS_AND_SIBLINGS)],
        )
        visitor = self.run_visitor(
            ["visit"],
            x.any(
                x.all(x.visit.region.overlaps(region), x.band == "r"),
                x.visit.observation_reason == "science",
            ),
        )
        self.assertEqual(
            visitor.spatial_constraints,
            [("visit", PredicateVisitFlags.HAS_OR_SIBLINGS | PredicateVisitFlags.HAS_AND_SIBLINGS)],
        )
        # A spatial join between dimensions in the same family is an error.
        with self.assertRaises(qt.InvalidQueryError):
            self.run_visitor(["patch", "tract"], x.patch.region.overlaps(x.tract.region))

    def test_single_unambiguous_spatial_join(self) -> None:
        """Test the overlaps visitor when there are two spatial families with
        one dimension element in each, and hence exactly one join is needed.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate; an automatic join is added.  Order of elements in
        # automatic joins is lexicographical in order to be deterministic.
        visitor = self.run_visitor(
            ["visit", "tract"], qt.Predicate.from_bool(True), "tract.region OVERLAPS visit.region"
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate; an automatic join is added.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.all(x.band == "r", x.visit > 2),
            "band == 'r' AND visit > 2 AND tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # The same overlap predicate that would be added automatically has been
        # added manually.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.tract.region.overlaps(x.visit.region),
            "tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add the join overlap predicate in an OR expression, which is unusual
        # but enough to block the addition of an automatic join; we assume the
        # user knows what they're doing.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.any(x.visit > 2, x.tract.region.overlaps(x.visit.region)),
            "visit > 2 OR tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_OR_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add the join overlap predicate in a NOT expression, which is unusual
        # but permitted in the same sense as OR expressions.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.not_(x.tract.region.overlaps(x.visit.region)),
            "NOT tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.INVERTED)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include both spatial families.
        # This blocks an automatic join from being created, because we assume
        # that join operand (e.g. a materialization or dataset search) already
        # encodes some spatial join.
        visitor = self.run_visitor(
            ["visit", "tract"],
            qt.Predicate.from_bool(True),
            "True",
            join_operands=[self.universe.conform(["tract", "visit"])],
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_single_flexible_spatial_join(self) -> None:
        """Test the overlaps visitor when there are two spatial families and
        one has multiple dimension elements.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate; an automatic join between the fine-grained
        # elements is added.  Order of elements in automatic joins is
        # lexicographical in order to be deterministic.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            qt.Predicate.from_bool(True),
            "patch.region OVERLAPS visit_detector_region.region",
        )
        self.assertEqual(
            visitor.spatial_joins, [("patch", "visit_detector_region", PredicateVisitFlags.HAS_AND_SIBLINGS)]
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # The same overlap predicate that would be added automatically has been
        # added manually.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            x.patch.region.overlaps(x.visit_detector_region.region),
            "patch.region OVERLAPS visit_detector_region.region",
        )
        self.assertEqual(visitor.spatial_joins, [("patch", "visit_detector_region", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # A coarse overlap join has been added; respect it and do not add an
        # automatic one.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            x.tract.region.overlaps(x.visit.region),
            "tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include both spatial families.
        # This blocks an automatic join from being created, because we assume
        # that join operand (e.g. a materialization or dataset search) already
        # encodes some spatial join.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            qt.Predicate.from_bool(True),
            "True",
            join_operands=[self.universe.conform(["tract", "visit_detector_region"])],
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_multiple_spatial_joins(self) -> None:
        """Test the overlaps visitor when there are >2 spatial families."""
        x = ExpressionFactory(self.universe)
        # Trivial predicate.  This is an error, because we cannot generate
        # automatic spatial joins when there are more than two families.
        with self.assertRaises(qt.InvalidQueryError):
            self.run_visitor(["visit", "patch", "htm7"], qt.Predicate.from_bool(True))
        # Predicate that joins one pair of families but orphans the other;
        # also an error.
        with self.assertRaises(qt.InvalidQueryError):
            self.run_visitor(["visit", "patch", "htm7"], x.visit.region.overlaps(x.htm7.region))
        # A sufficient overlap join predicate has been added; each family is
        # connected to at least one other.
        visitor = self.run_visitor(
            ["visit", "patch", "htm7"],
            x.all(x.tract.region.overlaps(x.visit.region), x.tract.region.overlaps(x.htm7.region)),
            "tract.region OVERLAPS visit.region AND tract.region OVERLAPS htm7.region",
        )
        self.assertEqual(
            visitor.spatial_joins,
            [
                ("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS),
                ("tract", "htm7", PredicateVisitFlags.HAS_AND_SIBLINGS),
            ],
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include two spatial families,
        # with a predicate that joins the third in.
        visitor = self.run_visitor(
            ["visit", "patch", "htm7"],
            x.tract.region.overlaps(x.htm7.region),
            "tract.region OVERLAPS htm7.region",
            join_operands=[self.universe.conform(["visit", "tract"])],
        )
        self.assertEqual(
            visitor.spatial_joins,
            [
                ("tract", "htm7", PredicateVisitFlags(0)),
            ],
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_one_temporal_family(self) -> None:
        """Test the overlaps visitor when there is one temporal family."""
        x = ExpressionFactory(self.universe)
        begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
        timespan = Timespan(begin, end)
        # Trivial predicate.
        visitor = self.run_visitor(["exposure"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["exposure"], x.any(x.band == "r", x.exposure > 2))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Temporal constraint predicate.
        visitor = self.run_visitor(["exposure"], x.exposure.timespan.overlaps(timespan))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # A temporal join between dimensions in the same family is an error.
        with self.assertRaises(qt.InvalidQueryError):
            self.run_visitor(["exposure", "visit"], x.exposure.timespan.overlaps(x.visit.timespan))
        # Overlap join with a calibration dataset's validity ranges.
        visitor = self.run_visitor(["exposure"], x.exposure.timespan.overlaps(x["bias"].timespan))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    # There are no tests for temporal dimension joins, because the default
    # dimension universe only has one temporal family, and the untested logic
    # trivially duplicates the spatial-join logic.
class NaiveDisjointSetTestCase(unittest.TestCase):
    """Exercise the naive disjoint-set helper used when creating automatic
    overlap joins.
    """

    def test_naive_disjoint_set(self) -> None:
        """Check the subsets reported initially and after each merge."""
        disjoint = _NaiveDisjointSet(range(8))
        # Before any merge, every element is its own subset.
        self.assertCountEqual(disjoint.subsets(), [{n} for n in range(8)])
        # Each step: the pair to merge, then the subsets expected afterwards.
        steps = [
            ((3, 4), [{0}, {1}, {2}, {3, 4}, {5}, {6}, {7}]),
            ((2, 1), [{0}, {1, 2}, {3, 4}, {5}, {6}, {7}]),
            ((1, 3), [{0}, {1, 2, 3, 4}, {5}, {6}, {7}]),
        ]
        for (first, second), expected_subsets in steps:
            disjoint.merge(first, second)
            self.assertCountEqual(disjoint.subsets(), expected_subsets)
# Run the tests in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()