Coverage for tests/test_query_interface.py: 9%
982 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-05 11:36 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-05 11:36 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Tests for the public Butler._query interface and the Pydantic models that
29back it using a mock column-expression visitor and a mock QueryDriver
30implementation.
32These tests are entirely independent of which kind of butler or database
33backend we're using.
35This is a very large test file because a lot of tests make use of those mocks,
36but they're not so generally useful that I think they're worth putting in the
37library proper.
38"""
40from __future__ import annotations
42import dataclasses
43import itertools
44import unittest
45import uuid
46from collections.abc import Iterable, Iterator, Mapping, Sequence, Set
47from typing import Any, cast
49import astropy.time
50from lsst.daf.butler import (
51 CollectionType,
52 DataCoordinate,
53 DataIdValue,
54 DatasetRef,
55 DatasetType,
56 DimensionGroup,
57 DimensionRecord,
58 DimensionRecordSet,
59 DimensionUniverse,
60 MissingDatasetTypeError,
61 NamedValueSet,
62 NoDefaultCollectionError,
63 Timespan,
64)
65from lsst.daf.butler.queries import (
66 DataCoordinateQueryResults,
67 DatasetQueryResults,
68 DimensionRecordQueryResults,
69 Query,
70 SingleTypeDatasetQueryResults,
71)
72from lsst.daf.butler.queries import driver as qd
73from lsst.daf.butler.queries import result_specs as qrs
74from lsst.daf.butler.queries import tree as qt
75from lsst.daf.butler.queries.expression_factory import ExpressionFactory
76from lsst.daf.butler.queries.tree._column_expression import UnaryExpression
77from lsst.daf.butler.queries.tree._predicate import PredicateLeaf, PredicateOperands
78from lsst.daf.butler.queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor
79from lsst.daf.butler.registry import CollectionSummary, DatasetTypeError
80from lsst.daf.butler.registry.interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord
81from lsst.sphgeom import DISJOINT, Mq3cPixelization
84class _TestVisitor(PredicateVisitor[bool, bool, bool], ColumnExpressionVisitor[Any]):
85 """Test visitor for column expressions.
87 This visitor evaluates column expressions using regular Python logic.
89 Parameters
90 ----------
91 dimension_keys : `~collections.abc.Mapping`, optional
92 Mapping from dimension name to the value it should be assigned by the
93 visitor.
94 dimension_fields : `~collections.abc.Mapping`, optional
95 Mapping from ``(dimension element name, field)`` tuple to the value it
96 should be assigned by the visitor.
97 dataset_fields : `~collections.abc.Mapping`, optional
98 Mapping from ``(dataset type name, field)`` tuple to the value it
99 should be assigned by the visitor.
100 query_tree_items : `~collections.abc.Set`, optional
101 Set that should be used as the right-hand side of element-in-query
102 predicates.
103 """
105 def __init__(
106 self,
107 dimension_keys: Mapping[str, Any] | None = None,
108 dimension_fields: Mapping[tuple[str, str], Any] | None = None,
109 dataset_fields: Mapping[tuple[str, str], Any] | None = None,
110 query_tree_items: Set[Any] = frozenset(),
111 ):
112 self.dimension_keys = dimension_keys or {}
113 self.dimension_fields = dimension_fields or {}
114 self.dataset_fields = dataset_fields or {}
115 self.query_tree_items = query_tree_items
117 def visit_binary_expression(self, expression: qt.BinaryExpression) -> Any:
118 match expression.operator:
119 case "+":
120 return expression.a.visit(self) + expression.b.visit(self)
121 case "-":
122 return expression.a.visit(self) - expression.b.visit(self)
123 case "*":
124 return expression.a.visit(self) * expression.b.visit(self)
125 case "/":
126 match expression.column_type:
127 case "int":
128 return expression.a.visit(self) // expression.b.visit(self)
129 case "float":
130 return expression.a.visit(self) / expression.b.visit(self)
131 case "%":
132 return expression.a.visit(self) % expression.b.visit(self)
134 def visit_comparison(
135 self,
136 a: qt.ColumnExpression,
137 operator: qt.ComparisonOperator,
138 b: qt.ColumnExpression,
139 flags: PredicateVisitFlags,
140 ) -> bool:
141 match operator:
142 case "==":
143 return a.visit(self) == b.visit(self)
144 case "!=":
145 return a.visit(self) != b.visit(self)
146 case "<":
147 return a.visit(self) < b.visit(self)
148 case ">":
149 return a.visit(self) > b.visit(self)
150 case "<=":
151 return a.visit(self) <= b.visit(self)
152 case ">=":
153 return a.visit(self) >= b.visit(self)
154 case "overlaps":
155 return not (a.visit(self).relate(b.visit(self)) & DISJOINT)
157 def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> Any:
158 return self.dataset_fields[expression.dataset_type, expression.field]
160 def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> Any:
161 return self.dimension_fields[expression.element.name, expression.field]
163 def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> Any:
164 return self.dimension_keys[expression.dimension.name]
166 def visit_in_container(
167 self,
168 member: qt.ColumnExpression,
169 container: tuple[qt.ColumnExpression, ...],
170 flags: PredicateVisitFlags,
171 ) -> bool:
172 return member.visit(self) in [item.visit(self) for item in container]
174 def visit_in_range(
175 self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags
176 ) -> bool:
177 return member.visit(self) in range(start, stop, step)
179 def visit_in_query_tree(
180 self,
181 member: qt.ColumnExpression,
182 column: qt.ColumnExpression,
183 query_tree: qt.QueryTree,
184 flags: PredicateVisitFlags,
185 ) -> bool:
186 return member.visit(self) in self.query_tree_items
188 def visit_is_null(self, operand: qt.ColumnExpression, flags: PredicateVisitFlags) -> bool:
189 return operand.visit(self) is None
191 def visit_literal(self, expression: qt.ColumnLiteral) -> Any:
192 return expression.get_literal_value()
194 def visit_reversed(self, expression: qt.Reversed) -> Any:
195 return _TestReversed(expression.operand.visit(self))
197 def visit_unary_expression(self, expression: UnaryExpression) -> Any:
198 match expression.operator:
199 case "-":
200 return -expression.operand.visit(self)
201 case "begin_of":
202 return expression.operand.visit(self).begin
203 case "end_of":
204 return expression.operand.visit(self).end
206 def apply_logical_and(self, originals: PredicateOperands, results: tuple[bool, ...]) -> bool:
207 return all(results)
209 def apply_logical_not(self, original: PredicateLeaf, result: bool, flags: PredicateVisitFlags) -> bool:
210 return not result
212 def apply_logical_or(
213 self, originals: tuple[PredicateLeaf, ...], results: tuple[bool, ...], flags: PredicateVisitFlags
214 ) -> bool:
215 return any(results)
218@dataclasses.dataclass
219class _TestReversed:
220 """Struct used by _TestVisitor" to mark an expression as reversed in sort
221 order.
222 """
224 operand: Any
227class _TestQueryExecution(BaseException):
228 """Exception raised by _TestQueryDriver.execute to communicate its args
229 back to the caller.
230 """
232 def __init__(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree, driver: _TestQueryDriver) -> None:
233 self.result_spec = result_spec
234 self.tree = tree
235 self.driver = driver
238class _TestQueryCount(BaseException):
239 """Exception raised by _TestQueryDriver.count to communicate its args
240 back to the caller.
241 """
243 def __init__(
244 self,
245 result_spec: qrs.ResultSpec,
246 tree: qt.QueryTree,
247 driver: _TestQueryDriver,
248 exact: bool,
249 discard: bool,
250 ) -> None:
251 self.result_spec = result_spec
252 self.tree = tree
253 self.driver = driver
254 self.exact = exact
255 self.discard = discard
258class _TestQueryAny(BaseException):
259 """Exception raised by _TestQueryDriver.any to communicate its args
260 back to the caller.
261 """
263 def __init__(
264 self,
265 tree: qt.QueryTree,
266 driver: _TestQueryDriver,
267 exact: bool,
268 execute: bool,
269 ) -> None:
270 self.tree = tree
271 self.driver = driver
272 self.exact = exact
273 self.execute = execute
276class _TestQueryExplainNoResults(BaseException):
277 """Exception raised by _TestQueryDriver.explain_no_results to communicate
278 its args back to the caller.
279 """
281 def __init__(
282 self,
283 tree: qt.QueryTree,
284 driver: _TestQueryDriver,
285 execute: bool,
286 ) -> None:
287 self.tree = tree
288 self.driver = driver
289 self.execute = execute
class _TestQueryDriver(qd.QueryDriver):
    """Mock implementation of `QueryDriver` that mostly raises exceptions that
    communicate the arguments its methods were called with.

    Parameters
    ----------
    default_collections : `tuple` [ `str`, ... ], optional
        Default collection the query or parent butler is imagined to have been
        constructed with.
    collection_info : `~collections.abc.Mapping`, optional
        Mapping from collection name to its record and summary, simulating the
        collections present in the data repository.
    dataset_types : `~collections.abc.Mapping`, optional
        Mapping from dataset type to its definition, simulating the dataset
        types registered in the data repository.
    result_rows : `tuple` [ `~collections.abc.Iterable`, ... ], optional
        A tuple of iterables of arbitrary type to use as result rows any time
        `execute` is called, with each nested iterable considered a separate
        page. The result type is not checked for consistency with the result
        spec. If this is not provided, `execute` will instead raise
        `_TestQueryExecution`, and `fetch_next_page` will not do anything
        useful.
    """

    def __init__(
        self,
        default_collections: tuple[str, ...] | None = None,
        collection_info: Mapping[str, tuple[CollectionRecord, CollectionSummary]] | None = None,
        dataset_types: Mapping[str, DatasetType] | None = None,
        result_rows: tuple[Iterable[Any], ...] | None = None,
    ) -> None:
        self._universe = DimensionUniverse()
        # Mapping of the arguments passed to materialize, keyed by the UUID
        # that each call returned.
        self.materializations: dict[
            qd.MaterializationKey, tuple[qt.QueryTree, DimensionGroup, frozenset[str]]
        ] = {}
        # Mapping of the arguments passed to upload_data_coordinates, keyed by
        # the UUID that each call returned.  Rows are stored as a frozenset
        # (see upload_data_coordinates).
        self.data_coordinate_uploads: dict[
            qd.DataCoordinateUploadKey, tuple[DimensionGroup, frozenset[tuple[DataIdValue, ...]]]
        ] = {}
        self._default_collections = default_collections
        self._collection_info = collection_info or {}
        self._dataset_types = dataset_types or {}
        self._executions: list[tuple[qrs.ResultSpec, qt.QueryTree]] = []
        self._result_rows = result_rows
        # Pending result pages for in-progress iterations, keyed by the page
        # key handed out along with the previous page.
        self._result_iters: dict[qd.PageKey, tuple[Iterable[Any], Iterator[Iterable[Any]]]] = {}

    @property
    def universe(self) -> DimensionUniverse:
        return self._universe

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        pass

    def execute(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree) -> qd.ResultPage:
        # When canned rows were provided at construction, serve the first
        # page; otherwise raise to report the call's arguments to the test.
        if self._result_rows is not None:
            iterator = iter(self._result_rows)
            current_rows = next(iterator, ())
            return self._make_next_page(result_spec, current_rows, iterator)
        raise _TestQueryExecution(result_spec, tree, self)

    def fetch_next_page(self, result_spec: qrs.ResultSpec, key: qd.PageKey) -> qd.ResultPage:
        if self._result_rows is not None:
            # Keys are single-use: pop the pending state saved by
            # _make_next_page for this key.
            return self._make_next_page(result_spec, *self._result_iters.pop(key))
        raise AssertionError("Test query driver not initialized for actual results.")

    def _make_next_page(
        self, result_spec: qrs.ResultSpec, current_rows: Iterable[Any], iterator: Iterator[Iterable[Any]]
    ) -> qd.ResultPage:
        # Peek at the next page of rows so we know whether to hand out a
        # continuation key with this one.
        next_rows = list(next(iterator, ()))
        if not next_rows:
            next_key = None
        else:
            next_key = uuid.uuid4()
            self._result_iters[next_key] = (next_rows, iterator)
        # Wrap the rows in the page type matching the result spec.
        match result_spec:
            case qrs.DataCoordinateResultSpec():
                return qd.DataCoordinateResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
            case qrs.DimensionRecordResultSpec():
                return qd.DimensionRecordResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
            case qrs.DatasetRefResultSpec():
                return qd.DatasetRefResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
            case _:
                raise NotImplementedError("Other query types not yet supported.")

    def materialize(
        self,
        tree: qt.QueryTree,
        dimensions: DimensionGroup,
        datasets: frozenset[str],
    ) -> qd.MaterializationKey:
        # Record the arguments under a fresh UUID instead of materializing.
        key = uuid.uuid4()
        self.materializations[key] = (tree, dimensions, datasets)
        return key

    def upload_data_coordinates(
        self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]]
    ) -> qd.DataCoordinateUploadKey:
        # Record the arguments under a fresh UUID instead of uploading.
        key = uuid.uuid4()
        self.data_coordinate_uploads[key] = (dimensions, frozenset(rows))
        return key

    def count(
        self,
        tree: qt.QueryTree,
        result_spec: qrs.ResultSpec,
        *,
        exact: bool,
        discard: bool,
    ) -> int:
        raise _TestQueryCount(result_spec, tree, self, exact, discard)

    def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool:
        raise _TestQueryAny(tree, self, exact, execute)

    def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]:
        raise _TestQueryExplainNoResults(tree, self, execute)

    def get_default_collections(self) -> tuple[str, ...]:
        if self._default_collections is None:
            raise NoDefaultCollectionError()
        return self._default_collections

    def resolve_collection_path(
        self, collections: Sequence[str], _done: set[str] | None = None
    ) -> list[tuple[CollectionRecord, CollectionSummary]]:
        # Flatten CHAINED collections depth-first, skipping names already
        # visited (``_done`` is shared across the recursive calls).
        if _done is None:
            _done = set()
        result: list[tuple[CollectionRecord, CollectionSummary]] = []
        for name in collections:
            if name in _done:
                continue
            _done.add(name)
            record, summary = self._collection_info[name]
            if record.type is CollectionType.CHAINED:
                result.extend(
                    self.resolve_collection_path(cast(ChainedCollectionRecord, record).children, _done=_done)
                )
            else:
                result.append((record, summary))
        return result

    def get_dataset_type(self, name: str) -> DatasetType:
        try:
            return self._dataset_types[name]
        except KeyError:
            raise MissingDatasetTypeError(name)
445class ColumnExpressionsTestCase(unittest.TestCase):
446 """Tests for column expression objects in lsst.daf.butler.queries.tree."""
448 def setUp(self) -> None:
449 self.universe = DimensionUniverse()
450 self.x = ExpressionFactory(self.universe)
452 def query(self, **kwargs: Any) -> Query:
453 """Make an initial Query object with the given kwargs used to
454 initialize the _TestQueryDriver.
455 """
456 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe))
458 def test_int_literals(self) -> None:
459 expr = self.x.unwrap(self.x.literal(5))
460 self.assertEqual(expr.value, 5)
461 self.assertEqual(expr.get_literal_value(), 5)
462 self.assertEqual(expr.expression_type, "int")
463 self.assertEqual(expr.column_type, "int")
464 self.assertEqual(str(expr), "5")
465 self.assertTrue(expr.is_literal)
466 columns = qt.ColumnSet(self.universe.empty.as_group())
467 expr.gather_required_columns(columns)
468 self.assertFalse(columns)
469 self.assertEqual(expr.visit(_TestVisitor()), 5)
471 def test_string_literals(self) -> None:
472 expr = self.x.unwrap(self.x.literal("five"))
473 self.assertEqual(expr.value, "five")
474 self.assertEqual(expr.get_literal_value(), "five")
475 self.assertEqual(expr.expression_type, "string")
476 self.assertEqual(expr.column_type, "string")
477 self.assertEqual(str(expr), "'five'")
478 self.assertTrue(expr.is_literal)
479 columns = qt.ColumnSet(self.universe.empty.as_group())
480 expr.gather_required_columns(columns)
481 self.assertFalse(columns)
482 self.assertEqual(expr.visit(_TestVisitor()), "five")
484 def test_float_literals(self) -> None:
485 expr = self.x.unwrap(self.x.literal(0.5))
486 self.assertEqual(expr.value, 0.5)
487 self.assertEqual(expr.get_literal_value(), 0.5)
488 self.assertEqual(expr.expression_type, "float")
489 self.assertEqual(expr.column_type, "float")
490 self.assertEqual(str(expr), "0.5")
491 self.assertTrue(expr.is_literal)
492 columns = qt.ColumnSet(self.universe.empty.as_group())
493 expr.gather_required_columns(columns)
494 self.assertFalse(columns)
495 self.assertEqual(expr.visit(_TestVisitor()), 0.5)
497 def test_hash_literals(self) -> None:
498 expr = self.x.unwrap(self.x.literal(b"eleven"))
499 self.assertEqual(expr.value, b"eleven")
500 self.assertEqual(expr.get_literal_value(), b"eleven")
501 self.assertEqual(expr.expression_type, "hash")
502 self.assertEqual(expr.column_type, "hash")
503 self.assertEqual(str(expr), "(bytes)")
504 self.assertTrue(expr.is_literal)
505 columns = qt.ColumnSet(self.universe.empty.as_group())
506 expr.gather_required_columns(columns)
507 self.assertFalse(columns)
508 self.assertEqual(expr.visit(_TestVisitor()), b"eleven")
510 def test_uuid_literals(self) -> None:
511 value = uuid.uuid4()
512 expr = self.x.unwrap(self.x.literal(value))
513 self.assertEqual(expr.value, value)
514 self.assertEqual(expr.get_literal_value(), value)
515 self.assertEqual(expr.expression_type, "uuid")
516 self.assertEqual(expr.column_type, "uuid")
517 self.assertEqual(str(expr), str(value))
518 self.assertTrue(expr.is_literal)
519 columns = qt.ColumnSet(self.universe.empty.as_group())
520 expr.gather_required_columns(columns)
521 self.assertFalse(columns)
522 self.assertEqual(expr.visit(_TestVisitor()), value)
524 def test_datetime_literals(self) -> None:
525 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
526 expr = self.x.unwrap(self.x.literal(value))
527 self.assertEqual(expr.value, value)
528 self.assertEqual(expr.get_literal_value(), value)
529 self.assertEqual(expr.expression_type, "datetime")
530 self.assertEqual(expr.column_type, "datetime")
531 self.assertEqual(str(expr), "2020-01-01T00:00:00")
532 self.assertTrue(expr.is_literal)
533 columns = qt.ColumnSet(self.universe.empty.as_group())
534 expr.gather_required_columns(columns)
535 self.assertFalse(columns)
536 self.assertEqual(expr.visit(_TestVisitor()), value)
538 def test_timespan_literals(self) -> None:
539 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
540 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
541 value = Timespan(begin, end)
542 expr = self.x.unwrap(self.x.literal(value))
543 self.assertEqual(expr.value, value)
544 self.assertEqual(expr.get_literal_value(), value)
545 self.assertEqual(expr.expression_type, "timespan")
546 self.assertEqual(expr.column_type, "timespan")
547 self.assertEqual(str(expr), "[2020-01-01T00:00:00, 2020-01-01T00:01:00)")
548 self.assertTrue(expr.is_literal)
549 columns = qt.ColumnSet(self.universe.empty.as_group())
550 expr.gather_required_columns(columns)
551 self.assertFalse(columns)
552 self.assertEqual(expr.visit(_TestVisitor()), value)
554 def test_region_literals(self) -> None:
555 pixelization = Mq3cPixelization(10)
556 value = pixelization.quad(12058870)
557 expr = self.x.unwrap(self.x.literal(value))
558 self.assertEqual(expr.value, value)
559 self.assertEqual(expr.get_literal_value(), value)
560 self.assertEqual(expr.expression_type, "region")
561 self.assertEqual(expr.column_type, "region")
562 self.assertEqual(str(expr), "(region)")
563 self.assertTrue(expr.is_literal)
564 columns = qt.ColumnSet(self.universe.empty.as_group())
565 expr.gather_required_columns(columns)
566 self.assertFalse(columns)
567 self.assertEqual(expr.visit(_TestVisitor()), value)
569 def test_invalid_literal(self) -> None:
570 with self.assertRaisesRegex(TypeError, "Invalid type 'complex' of value 5j for column literal."):
571 self.x.literal(5j)
573 def test_dimension_key_reference(self) -> None:
574 expr = self.x.unwrap(self.x.detector)
575 self.assertIsNone(expr.get_literal_value())
576 self.assertEqual(expr.expression_type, "dimension_key")
577 self.assertEqual(expr.column_type, "int")
578 self.assertEqual(str(expr), "detector")
579 self.assertFalse(expr.is_literal)
580 columns = qt.ColumnSet(self.universe.empty.as_group())
581 expr.gather_required_columns(columns)
582 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
583 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 3})), 3)
585 def test_dimension_field_reference(self) -> None:
586 expr = self.x.unwrap(self.x.detector.purpose)
587 self.assertIsNone(expr.get_literal_value())
588 self.assertEqual(expr.expression_type, "dimension_field")
589 self.assertEqual(expr.column_type, "string")
590 self.assertEqual(str(expr), "detector.purpose")
591 self.assertFalse(expr.is_literal)
592 columns = qt.ColumnSet(self.universe.empty.as_group())
593 expr.gather_required_columns(columns)
594 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
595 self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
596 with self.assertRaises(qt.InvalidQueryError):
597 qt.DimensionFieldReference(element=self.universe.dimensions["detector"], field="region")
598 self.assertEqual(
599 expr.visit(_TestVisitor(dimension_fields={("detector", "purpose"): "science"})), "science"
600 )
602 def test_dataset_field_reference(self) -> None:
603 expr = self.x.unwrap(self.x["raw"].ingest_date)
604 self.assertIsNone(expr.get_literal_value())
605 self.assertEqual(expr.expression_type, "dataset_field")
606 self.assertEqual(str(expr), "raw.ingest_date")
607 self.assertFalse(expr.is_literal)
608 columns = qt.ColumnSet(self.universe.empty.as_group())
609 expr.gather_required_columns(columns)
610 self.assertEqual(columns.dimensions, self.universe.empty.as_group())
611 self.assertEqual(columns.dataset_fields["raw"], {"ingest_date"})
612 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="dataset_id").column_type, "uuid")
613 self.assertEqual(
614 qt.DatasetFieldReference(dataset_type="raw", field="collection").column_type, "string"
615 )
616 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="run").column_type, "string")
617 self.assertEqual(
618 qt.DatasetFieldReference(dataset_type="raw", field="ingest_date").column_type, "datetime"
619 )
620 self.assertEqual(
621 qt.DatasetFieldReference(dataset_type="raw", field="timespan").column_type, "timespan"
622 )
623 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
624 self.assertEqual(expr.visit(_TestVisitor(dataset_fields={("raw", "ingest_date"): value})), value)
626 def test_unary_negation(self) -> None:
627 expr = self.x.unwrap(-self.x.visit.exposure_time)
628 self.assertIsNone(expr.get_literal_value())
629 self.assertEqual(expr.expression_type, "unary")
630 self.assertEqual(expr.column_type, "float")
631 self.assertEqual(str(expr), "-visit.exposure_time")
632 self.assertFalse(expr.is_literal)
633 columns = qt.ColumnSet(self.universe.empty.as_group())
634 expr.gather_required_columns(columns)
635 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
636 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"})
637 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 2.0})), -2.0)
638 with self.assertRaises(qt.InvalidQueryError):
639 qt.UnaryExpression(
640 operand=qt.DimensionFieldReference(
641 element=self.universe.dimensions["detector"], field="purpose"
642 ),
643 operator="-",
644 )
646 def test_unary_timespan_begin(self) -> None:
647 expr = self.x.unwrap(self.x.visit.timespan.begin)
648 self.assertIsNone(expr.get_literal_value())
649 self.assertEqual(expr.expression_type, "unary")
650 self.assertEqual(expr.column_type, "datetime")
651 self.assertEqual(str(expr), "visit.timespan.begin")
652 self.assertFalse(expr.is_literal)
653 columns = qt.ColumnSet(self.universe.empty.as_group())
654 expr.gather_required_columns(columns)
655 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
656 self.assertEqual(columns.dimension_fields["visit"], {"timespan"})
657 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
658 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
659 value = Timespan(begin, end)
660 self.assertEqual(
661 expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.begin
662 )
663 with self.assertRaises(qt.InvalidQueryError):
664 qt.UnaryExpression(
665 operand=qt.DimensionFieldReference(
666 element=self.universe.dimensions["detector"], field="purpose"
667 ),
668 operator="begin_of",
669 )
671 def test_unary_timespan_end(self) -> None:
672 expr = self.x.unwrap(self.x.visit.timespan.end)
673 self.assertIsNone(expr.get_literal_value())
674 self.assertEqual(expr.expression_type, "unary")
675 self.assertEqual(expr.column_type, "datetime")
676 self.assertEqual(str(expr), "visit.timespan.end")
677 self.assertFalse(expr.is_literal)
678 columns = qt.ColumnSet(self.universe.empty.as_group())
679 expr.gather_required_columns(columns)
680 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
681 self.assertEqual(columns.dimension_fields["visit"], {"timespan"})
682 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
683 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
684 value = Timespan(begin, end)
685 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.end)
686 with self.assertRaises(qt.InvalidQueryError):
687 qt.UnaryExpression(
688 operand=qt.DimensionFieldReference(
689 element=self.universe.dimensions["detector"], field="purpose"
690 ),
691 operator="end_of",
692 )
    def test_binary_expression_float(self) -> None:
        """Test binary arithmetic expressions on float operands, including
        the parenthesization of their string representations.
        """
        # Each case is (expression proxy, expected str(), expected value when
        # visit.exposure_time == 30.0).
        for proxy, string, value in [
            (self.x.visit.exposure_time + 15.0, "visit.exposure_time + 15.0", 45.0),
            (self.x.visit.exposure_time - 10.0, "visit.exposure_time - 10.0", 20.0),
            (self.x.visit.exposure_time * 6.0, "visit.exposure_time * 6.0", 180.0),
            (self.x.visit.exposure_time / 30.0, "visit.exposure_time / 30.0", 1.0),
            (15.0 + -self.x.visit.exposure_time, "15.0 + -visit.exposure_time", -15.0),
            (10.0 - -self.x.visit.exposure_time, "10.0 - -visit.exposure_time", 40.0),
            (6.0 * -self.x.visit.exposure_time, "6.0 * -visit.exposure_time", -180.0),
            (30.0 / -self.x.visit.exposure_time, "30.0 / -visit.exposure_time", -1.0),
            ((self.x.visit.exposure_time + 15.0) * 6.0, "(visit.exposure_time + 15.0) * 6.0", 270.0),
            ((self.x.visit.exposure_time + 15.0) + 45.0, "visit.exposure_time + 15.0 + 45.0", 90.0),
            ((self.x.visit.exposure_time + 15.0) / 5.0, "(visit.exposure_time + 15.0) / 5.0", 9.0),
            # We don't need the parentheses we generate in the next one, but
            # they're not a problem either.
            ((self.x.visit.exposure_time + 15.0) - 60.0, "(visit.exposure_time + 15.0) - 60.0", -15.0),
            (6.0 * (-self.x.visit.exposure_time - 15.0), "6.0 * (-visit.exposure_time - 15.0)", -270.0),
            (60.0 + (-self.x.visit.exposure_time - 15.0), "60.0 + -visit.exposure_time - 15.0", 15.0),
            (90.0 / (-self.x.visit.exposure_time - 15.0), "90.0 / (-visit.exposure_time - 15.0)", -2.0),
            (60.0 - (-self.x.visit.exposure_time - 15.0), "60.0 - (-visit.exposure_time - 15.0)", 105.0),
        ]:
            with self.subTest(string=string):
                expr = self.x.unwrap(proxy)
                self.assertIsNone(expr.get_literal_value())
                self.assertEqual(expr.expression_type, "binary")
                self.assertEqual(expr.column_type, "float")
                self.assertEqual(str(expr), string)
                self.assertFalse(expr.is_literal)
                # Every case requires only visit.exposure_time.
                columns = qt.ColumnSet(self.universe.empty.as_group())
                expr.gather_required_columns(columns)
                self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
                self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"})
                self.assertEqual(
                    expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 30.0})), value
                )
730 def test_binary_modulus(self) -> None:
731 for proxy, string, value in [
732 (self.x.visit.id % 2, "visit % 2", 1),
733 (52 % self.x.visit, "52 % visit", 2),
734 ]:
735 with self.subTest(string=string):
736 expr = self.x.unwrap(proxy)
737 self.assertIsNone(expr.get_literal_value())
738 self.assertEqual(expr.expression_type, "binary")
739 self.assertEqual(expr.column_type, "int")
740 self.assertEqual(str(expr), string)
741 self.assertFalse(expr.is_literal)
742 columns = qt.ColumnSet(self.universe.empty.as_group())
743 expr.gather_required_columns(columns)
744 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
745 self.assertFalse(columns.dimension_fields["visit"])
746 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"visit": 5})), value)
748 def test_binary_expression_validation(self) -> None:
749 with self.assertRaises(qt.InvalidQueryError):
750 # No arithmetic operators on strings (we do not interpret + as
751 # concatenation).
752 self.x.instrument + "suffix"
753 with self.assertRaises(qt.InvalidQueryError):
754 # Mixed types are not supported, even when they both support the
755 # operator.
756 self.x.visit.exposure_time + self.x.detector
757 with self.assertRaises(qt.InvalidQueryError):
758 # No modulus for floats.
759 self.x.visit.exposure_time % 5.0
761 def test_reversed(self) -> None:
762 expr = self.x.detector.desc
763 self.assertIsNone(expr.get_literal_value())
764 self.assertEqual(expr.expression_type, "reversed")
765 self.assertEqual(expr.column_type, "int")
766 self.assertEqual(str(expr), "detector DESC")
767 self.assertFalse(expr.is_literal)
768 columns = qt.ColumnSet(self.universe.empty.as_group())
769 expr.gather_required_columns(columns)
770 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
771 self.assertFalse(columns.dimension_fields["detector"])
772 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 5})), _TestReversed(5))
    def test_trivial_predicate(self) -> None:
        """Test logical operations on trivial True/False predicates."""
        yes = qt.Predicate.from_bool(True)
        no = qt.Predicate.from_bool(False)
        # A non-trivial predicate whose truth depends on the data.
        maybe: qt.Predicate = self.x.detector == 5
        # Combinations that should simplify to True.
        for predicate in [
            yes,
            yes.logical_or(no),
            no.logical_or(yes),
            yes.logical_and(yes),
            no.logical_not(),
            yes.logical_or(maybe),
            maybe.logical_or(yes),
        ]:
            self.assertEqual(predicate.column_type, "bool")
            self.assertEqual(str(predicate), "True")
            self.assertTrue(predicate.visit(_TestVisitor()))
            # True is represented as an empty tuple of OR groups.
            self.assertEqual(predicate.operands, ())
        # Combinations that should simplify to False.
        for predicate in [
            no,
            yes.logical_and(no),
            no.logical_and(yes),
            no.logical_or(no),
            yes.logical_not(),
            no.logical_and(maybe),
            maybe.logical_and(no),
        ]:
            self.assertEqual(predicate.column_type, "bool")
            self.assertEqual(str(predicate), "False")
            self.assertFalse(predicate.visit(_TestVisitor()))
            # False is represented as a single empty OR group.
            self.assertEqual(predicate.operands, ((),))
        # Combinations that should simplify to just 'maybe' itself.
        for predicate in [
            maybe,
            yes.logical_and(maybe),
            no.logical_or(maybe),
            maybe.logical_not().logical_not(),
        ]:
            self.assertEqual(predicate.column_type, "bool")
            self.assertEqual(str(predicate), "detector == 5")
            self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"detector": 5})))
            self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"detector": 4})))
            self.assertEqual(len(predicate.operands), 1)
            self.assertEqual(len(predicate.operands[0]), 1)
            # Simplification should preserve the original leaf by identity.
            self.assertIs(predicate.operands[0][0], maybe.operands[0][0])
    def test_comparison(self) -> None:
        """Test binary comparisons between a dimension key and an integer
        literal, for every supported operator.
        """
        predicate: qt.Predicate
        string: str
        value: bool
        # Try detector values below, at, and above the literal so every
        # operator is exercised in both its True and False outcomes.
        for detector in (4, 5, 6):
            # NOTE(review): the second six rows are exact duplicates of the
            # first six; reflected-operand forms (e.g. ``5 > self.x.detector``)
            # may have been intended — confirm against upstream.
            for predicate, string, value in [
                (self.x.detector == 5, "detector == 5", detector == 5),
                (self.x.detector != 5, "detector != 5", detector != 5),
                (self.x.detector < 5, "detector < 5", detector < 5),
                (self.x.detector > 5, "detector > 5", detector > 5),
                (self.x.detector <= 5, "detector <= 5", detector <= 5),
                (self.x.detector >= 5, "detector >= 5", detector >= 5),
                (self.x.detector == 5, "detector == 5", detector == 5),
                (self.x.detector != 5, "detector != 5", detector != 5),
                (self.x.detector < 5, "detector < 5", detector < 5),
                (self.x.detector > 5, "detector > 5", detector > 5),
                (self.x.detector <= 5, "detector <= 5", detector <= 5),
                (self.x.detector >= 5, "detector >= 5", detector >= 5),
            ]:
                with self.subTest(string=string, detector=detector):
                    self.assertEqual(predicate.column_type, "bool")
                    self.assertEqual(str(predicate), string)
                    # Only the detector dimension key is required; no fields.
                    columns = qt.ColumnSet(self.universe.empty.as_group())
                    predicate.gather_required_columns(columns)
                    self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
                    self.assertFalse(columns.dimension_fields["detector"])
                    self.assertEqual(
                        predicate.visit(_TestVisitor(dimension_keys={"detector": detector})), value
                    )
                    # Negation flips the evaluated value and the string form,
                    # but requires the same columns.
                    inverted = predicate.logical_not()
                    self.assertEqual(inverted.column_type, "bool")
                    self.assertEqual(str(inverted), f"NOT {string}")
                    self.assertEqual(
                        inverted.visit(_TestVisitor(dimension_keys={"detector": detector})), not value
                    )
                    columns = qt.ColumnSet(self.universe.empty.as_group())
                    inverted.gather_required_columns(columns)
                    self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
                    self.assertFalse(columns.dimension_fields["detector"])
859 def test_overlap_comparison(self) -> None:
860 pixelization = Mq3cPixelization(10)
861 region1 = pixelization.quad(12058870)
862 predicate = self.x.visit.region.overlaps(region1)
863 self.assertEqual(predicate.column_type, "bool")
864 self.assertEqual(str(predicate), "visit.region OVERLAPS (region)")
865 columns = qt.ColumnSet(self.universe.empty.as_group())
866 predicate.gather_required_columns(columns)
867 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
868 self.assertEqual(columns.dimension_fields["visit"], {"region"})
869 region2 = pixelization.quad(12058857)
870 self.assertFalse(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): region2})))
871 inverted = predicate.logical_not()
872 self.assertEqual(inverted.column_type, "bool")
873 self.assertEqual(str(inverted), "NOT visit.region OVERLAPS (region)")
874 self.assertTrue(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): region2})))
875 columns = qt.ColumnSet(self.universe.empty.as_group())
876 inverted.gather_required_columns(columns)
877 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
878 self.assertEqual(columns.dimension_fields["visit"], {"region"})
880 def test_invalid_comparison(self) -> None:
881 # Mixed type comparisons.
882 with self.assertRaises(qt.InvalidQueryError):
883 self.x.visit > "three"
884 with self.assertRaises(qt.InvalidQueryError):
885 self.x.visit > 3.0
886 # Invalid operator for type.
887 with self.assertRaises(qt.InvalidQueryError):
888 self.x["raw"].dataset_id < uuid.uuid4()
890 def test_is_null(self) -> None:
891 predicate = self.x.visit.region.is_null
892 self.assertEqual(predicate.column_type, "bool")
893 self.assertEqual(str(predicate), "visit.region IS NULL")
894 columns = qt.ColumnSet(self.universe.empty.as_group())
895 predicate.gather_required_columns(columns)
896 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
897 self.assertEqual(columns.dimension_fields["visit"], {"region"})
898 self.assertTrue(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): None})))
899 inverted = predicate.logical_not()
900 self.assertEqual(inverted.column_type, "bool")
901 self.assertEqual(str(inverted), "NOT visit.region IS NULL")
902 self.assertFalse(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): None})))
903 inverted.gather_required_columns(columns)
904 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
905 self.assertEqual(columns.dimension_fields["visit"], {"region"})
    def test_in_container(self) -> None:
        """Test predicates for membership in a literal container, which may
        itself hold column expressions.
        """
        predicate: qt.Predicate = self.x.visit.in_iterable([3, 4, self.x.exposure.id])
        self.assertEqual(predicate.column_type, "bool")
        self.assertEqual(str(predicate), "visit IN [3, 4, exposure]")
        # Both the tested column and any column expressions inside the
        # container contribute required dimensions.
        columns = qt.ColumnSet(self.universe.empty.as_group())
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
        self.assertFalse(columns.dimension_fields["visit"])
        self.assertFalse(columns.dimension_fields["exposure"])
        # visit=2 matches only via the exposure-expression member evaluating
        # to the same value.
        self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
        self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
        # Negation flips the result but requires the same columns.
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT visit IN [3, 4, exposure]")
        self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
        self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
        columns = qt.ColumnSet(self.universe.empty.as_group())
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
        self.assertFalse(columns.dimension_fields["visit"])
        self.assertFalse(columns.dimension_fields["exposure"])
        with self.assertRaises(qt.InvalidQueryError):
            # Regions (and timespans) not allowed in IN expressions, since that
            # suggests topological logic we're not actually doing. We can't
            # use ExpressionFactory because it prohibits this case with typing.
            pixelization = Mq3cPixelization(10)
            region = pixelization.quad(12058870)
            qt.Predicate.in_container(self.x.unwrap(self.x.visit.region), [qt.make_column_literal(region)])
        with self.assertRaises(qt.InvalidQueryError):
            # Mismatched types.
            self.x.visit.in_iterable([3.5, 2.1])
    def test_in_range(self) -> None:
        """Test predicates for membership in a strided integer range."""
        predicate: qt.Predicate = self.x.visit.in_range(2, 8, 2)
        self.assertEqual(predicate.column_type, "bool")
        # String form mirrors Python slice syntax: start:stop:step.
        self.assertEqual(str(predicate), "visit IN 2:8:2")
        # Only the visit dimension key is required; no fields.
        columns = qt.ColumnSet(self.universe.empty.as_group())
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertFalse(columns.dimension_fields["visit"])
        # Like Python ranges, the stop value is exclusive.
        self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2})))
        self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 8})))
        # Negation flips the result but requires the same columns.
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT visit IN 2:8:2")
        self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2})))
        self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 8})))
        columns = qt.ColumnSet(self.universe.empty.as_group())
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertFalse(columns.dimension_fields["visit"])
        with self.assertRaises(qt.InvalidQueryError):
            # Only integer fields allowed.
            self.x.visit.exposure_time.in_range(2, 4)
        with self.assertRaises(qt.InvalidQueryError):
            # Step must be positive.
            self.x.visit.in_range(2, 4, -1)
        with self.assertRaises(qt.InvalidQueryError):
            # Stop must be >= start.
            self.x.visit.in_range(2, 0)
    def test_in_query(self) -> None:
        """Test predicates for membership in a column of another query."""
        query = self.query().join_dimensions(["visit", "tract"]).where(skymap="s", tract=3)
        predicate: qt.Predicate = self.x.exposure.in_query(self.x.visit, query)
        self.assertEqual(predicate.column_type, "bool")
        self.assertEqual(str(predicate), "exposure IN (query).visit")
        # Only the outer expression's columns are required of the containing
        # query; the subquery resolves its own columns.
        columns = qt.ColumnSet(self.universe.empty.as_group())
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
        self.assertFalse(columns.dimension_fields["exposure"])
        self.assertTrue(
            predicate.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
        )
        self.assertFalse(
            predicate.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
        )
        # Negation flips the result but requires the same columns.
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT exposure IN (query).visit")
        self.assertFalse(
            inverted.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
        )
        self.assertTrue(
            inverted.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
        )
        columns = qt.ColumnSet(self.universe.empty.as_group())
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
        self.assertFalse(columns.dimension_fields["exposure"])
        with self.assertRaises(qt.InvalidQueryError):
            # Regions (and timespans) not allowed in IN expressions, since that
            # suggests topological logic we're not actually doing. We can't
            # use ExpressionFactory because it prohibits this case with typing.
            qt.Predicate.in_query(
                self.x.unwrap(self.x.visit.region), self.x.unwrap(self.x.tract.region), query._tree
            )
        with self.assertRaises(qt.InvalidQueryError):
            # Mismatched types.
            self.x.exposure.in_query(self.x.visit.exposure_time, query)
        with self.assertRaises(qt.InvalidQueryError):
            # Query column requires dimensions that are not in the query.
            self.x.exposure.in_query(self.x.patch, query)
        with self.assertRaises(qt.InvalidQueryError):
            # Query column requires dataset type that is not in the query.
            self.x["raw"].dataset_id.in_query(self.x["raw"].dataset_id, query)
    def test_complex_predicate(self) -> None:
        """Test that predicates are converted to conjunctive normal form and
        get parentheses in the right places when stringified.
        """
        visitor = _TestVisitor(dimension_keys={"instrument": "i", "detector": 3, "visit": 6, "band": "r"})
        a: qt.Predicate = self.x.visit > 5  # will evaluate to True
        b: qt.Predicate = self.x.detector != 3  # will evaluate to False
        c: qt.Predicate = self.x.instrument == "i"  # will evaluate to True
        d: qt.Predicate = self.x.band == "g"  # will evaluate to False
        predicate: qt.Predicate
        # Each row is (predicate, expected CNF string form, expected value
        # under `visitor`).  ORs of ANDs are distributed into ANDs of ORs,
        # and NOT is pushed down onto the leaves (De Morgan).
        for predicate, string, value in [
            (a.logical_or(b), f"{a} OR {b}", True),
            (a.logical_or(c), f"{a} OR {c}", True),
            (b.logical_or(d), f"{b} OR {d}", False),
            (a.logical_and(b), f"{a} AND {b}", False),
            (a.logical_and(c), f"{a} AND {c}", True),
            (b.logical_and(d), f"{b} AND {d}", False),
            (self.x.any(a, b, c, d), f"{a} OR {b} OR {c} OR {d}", True),
            (self.x.all(a, b, c, d), f"{a} AND {b} AND {c} AND {d}", False),
            (a.logical_or(b).logical_and(c), f"({a} OR {b}) AND {c}", True),
            (a.logical_and(b.logical_or(d)), f"{a} AND ({b} OR {d})", False),
            (a.logical_and(b).logical_or(c), f"({a} OR {c}) AND ({b} OR {c})", True),
            (
                a.logical_and(b).logical_or(c.logical_and(d)),
                f"({a} OR {c}) AND ({a} OR {d}) AND ({b} OR {c}) AND ({b} OR {d})",
                False,
            ),
            (a.logical_or(b).logical_not(), f"NOT {a} AND NOT {b}", False),
            (a.logical_or(c).logical_not(), f"NOT {a} AND NOT {c}", False),
            (b.logical_or(d).logical_not(), f"NOT {b} AND NOT {d}", True),
            (a.logical_and(b).logical_not(), f"NOT {a} OR NOT {b}", True),
            (a.logical_and(c).logical_not(), f"NOT {a} OR NOT {c}", False),
            (b.logical_and(d).logical_not(), f"NOT {b} OR NOT {d}", True),
            (
                self.x.not_(a.logical_or(b).logical_and(c)),
                f"(NOT {a} OR NOT {c}) AND (NOT {b} OR NOT {c})",
                False,
            ),
            (
                a.logical_and(b.logical_or(d)).logical_not(),
                f"(NOT {a} OR NOT {b}) AND (NOT {a} OR NOT {d})",
                True,
            ),
        ]:
            with self.subTest(string=string):
                self.assertEqual(str(predicate), string)
                self.assertEqual(predicate.visit(visitor), value)
1061 def test_proxy_misc(self) -> None:
1062 """Test miscellaneous things on various ExpressionFactory proxies."""
1063 self.assertEqual(str(self.x.visit_detector_region), "visit_detector_region")
1064 self.assertEqual(str(self.x.visit.instrument), "instrument")
1065 self.assertEqual(str(self.x["raw"]), "raw")
1066 self.assertEqual(str(self.x["raw.ingest_date"]), "raw.ingest_date")
1067 self.assertEqual(
1068 str(self.x.visit.timespan.overlaps(self.x["raw"].timespan)),
1069 "visit.timespan OVERLAPS raw.timespan",
1070 )
1071 self.assertGreater(
1072 set(dir(self.x["raw"])), {"dataset_id", "ingest_date", "collection", "run", "timespan"}
1073 )
1074 self.assertGreater(set(dir(self.x.exposure)), {"seq_num", "science_program", "timespan"})
1075 with self.assertRaises(AttributeError):
1076 self.x["raw"].seq_num
1077 with self.assertRaises(AttributeError):
1078 self.x.visit.horse
1081class QueryTestCase(unittest.TestCase):
1082 """Tests for Query and *QueryResults objects in lsst.daf.butler.queries."""
    def setUp(self) -> None:
        """Create the dimension universe, dataset types, and mock collection
        records shared by all tests in this case.
        """
        self.maxDiff = None
        self.universe = DimensionUniverse()
        # We use ArrowTable as the storage class for all dataset types because
        # it's got conversions that only require third-party packages we
        # already require.
        self.raw = DatasetType(
            "raw", dimensions=self.universe.conform(["detector", "exposure"]), storageClass="ArrowTable"
        )
        self.refcat = DatasetType(
            "refcat", dimensions=self.universe.conform(["htm7"]), storageClass="ArrowTable"
        )
        self.bias = DatasetType(
            "bias",
            dimensions=self.universe.conform(["detector"]),
            storageClass="ArrowTable",
            isCalibration=True,
        )
        # Collections the mock driver searches when none are given explicitly.
        self.default_collections: list[str] | None = ["DummyCam/defaults"]
        # Mock collection records and summaries: two RUN collections, one
        # CALIBRATION collection, and a CHAINED collection spanning all three.
        self.collection_info: dict[str, tuple[CollectionRecord, CollectionSummary]] = {
            "DummyCam/raw/all": (
                RunRecord[int](1, name="DummyCam/raw/all"),
                CollectionSummary(NamedValueSet({self.raw}), governors={"instrument": {"DummyCam"}}),
            ),
            "DummyCam/calib": (
                CollectionRecord[int](2, name="DummyCam/calib", type=CollectionType.CALIBRATION),
                CollectionSummary(NamedValueSet({self.bias}), governors={"instrument": {"DummyCam"}}),
            ),
            "refcats": (
                RunRecord[int](3, name="refcats"),
                CollectionSummary(NamedValueSet({self.refcat}), governors={}),
            ),
            "DummyCam/defaults": (
                ChainedCollectionRecord[int](
                    4, name="DummyCam/defaults", children=("DummyCam/raw/all", "DummyCam/calib", "refcats")
                ),
                CollectionSummary(
                    NamedValueSet({self.raw, self.refcat, self.bias}), governors={"instrument": {"DummyCam"}}
                ),
            ),
        }
        # Dataset types the mock driver treats as registered.
        self.dataset_types = {"raw": self.raw, "refcat": self.refcat, "bias": self.bias}
1127 def query(self, **kwargs: Any) -> Query:
1128 """Make an initial Query object with the given kwargs used to
1129 initialize the _TestQueryDriver.
1131 The given kwargs override the test-case-attribute defaults.
1132 """
1133 kwargs.setdefault("default_collections", self.default_collections)
1134 kwargs.setdefault("collection_info", self.collection_info)
1135 kwargs.setdefault("dataset_types", self.dataset_types)
1136 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe))
    def test_dataset_join(self) -> None:
        """Test queries that have had a dataset search explicitly joined in via
        Query.join_dataset_search.

        Since this kind of query has a moderate amount of complexity, this is
        where we get a lot of basic coverage that applies to all kinds of
        queries, including:

        - getting data ID and dataset results (but not iterating over them);
        - the 'any' and 'explain_no_results' methods;
        - adding 'where' filters (but not expanding dimensions accordingly);
        - materializations.
        """

        def check(
            query: Query,
            dimensions: DimensionGroup = self.raw.dimensions.as_group(),
            has_storage_class: bool = True,
            dataset_type_registered: bool = True,
        ) -> None:
            """Run a battery of tests on one of a set of very similar queries
            constructed in different ways (see below).
            """

            def check_query_tree(
                tree: qt.QueryTree,
                dimensions: DimensionGroup = dimensions,
                storage_class_name: str | None = self.raw.storageClass_name if has_storage_class else None,
            ) -> None:
                """Check the state of the QueryTree object that backs the Query
                or a derived QueryResults object.

                Parameters
                ----------
                tree : `lsst.daf.butler.queries.tree.QueryTree`
                    Object to test.
                dimensions : `DimensionGroup`
                    Dimensions to expect in the `QueryTree`, not necessarily
                    including those in the test 'raw' dataset type.
                storage_class_name : `str` or `None`, optional
                    The storage class name the query is expected to have for
                    the test 'raw' dataset type.
                """
                self.assertEqual(tree.dimensions, dimensions | self.raw.dimensions.as_group())
                # The 'where' constraint added in check() below.
                self.assertEqual(str(tree.predicate), "raw.run == 'DummyCam/raw/all'")
                self.assertFalse(tree.materializations)
                self.assertFalse(tree.data_coordinate_uploads)
                self.assertEqual(tree.datasets.keys(), {"raw"})
                self.assertEqual(tree.datasets["raw"].dimensions, self.raw.dimensions.as_group())
                self.assertEqual(tree.datasets["raw"].collections, ("DummyCam/defaults",))
                self.assertEqual(tree.datasets["raw"].storage_class_name, storage_class_name)
                self.assertEqual(
                    tree.get_joined_dimension_groups(), frozenset({self.raw.dimensions.as_group()})
                )

            def check_data_id_results(*args, query: Query, dimensions: DimensionGroup = dimensions) -> None:
                """Construct a DataCoordinateQueryResults object from the query
                with the given arguments and run a battery of tests on it.

                Parameters
                ----------
                *args
                    Forwarded to `Query.data_ids`.
                query : `Query`
                    Query to start from.
                dimensions : `DimensionGroup`, optional
                    Dimensions the result data IDs should have.
                """
                # The mock driver raises _TestQueryExecution (carrying the
                # tree and result spec) instead of actually executing.
                with self.assertRaises(_TestQueryExecution) as cm:
                    list(query.data_ids(*args))
                self.assertEqual(
                    cm.exception.result_spec,
                    qrs.DataCoordinateResultSpec(dimensions=dimensions),
                )
                check_query_tree(cm.exception.tree, dimensions=dimensions)

            def check_dataset_results(
                *args: Any,
                query: Query,
                find_first: bool = True,
                storage_class_name: str = self.raw.storageClass_name,
            ) -> None:
                """Construct a DatasetQueryResults object from the query with
                the given arguments and run a battery of tests on it.

                Parameters
                ----------
                *args
                    Forwarded to `Query.datasets`.
                query : `Query`
                    Query to start from.
                find_first : `bool`, optional
                    Whether to do find-first resolution on the results.
                storage_class_name : `str`, optional
                    Expected name of the storage class for the results.
                """
                with self.assertRaises(_TestQueryExecution) as cm:
                    list(query.datasets(*args, find_first=find_first))
                self.assertEqual(
                    cm.exception.result_spec,
                    qrs.DatasetRefResultSpec(
                        dataset_type_name="raw",
                        dimensions=self.raw.dimensions.as_group(),
                        storage_class_name=storage_class_name,
                        find_first=find_first,
                    ),
                )
                check_query_tree(cm.exception.tree, storage_class_name=storage_class_name)

            def check_materialization(
                kwargs: Mapping[str, Any],
                query: Query,
                dimensions: DimensionGroup = dimensions,
                has_dataset: bool = True,
            ) -> None:
                """Materialize the query with the given arguments and run a
                battery of tests on the result.

                Parameters
                ----------
                kwargs
                    Forwarded as keyword arguments to `Query.materialize`.
                query : `Query`
                    Query to start from.
                dimensions : `DimensionGroup`, optional
                    Dimensions to expect in the materialization and its derived
                    query.
                has_dataset : `bool`, optional
                    Whether the query backed by the materialization should
                    still have the test 'raw' dataset joined in.
                """
                # Materialize the query and check the query tree sent to the
                # driver and the one in the materialized query.
                with self.assertRaises(_TestQueryExecution) as cm:
                    list(query.materialize(**kwargs).data_ids())
                derived_tree = cm.exception.tree
                self.assertEqual(derived_tree.dimensions, dimensions)
                # Predicate should be materialized away; it no longer appears
                # in the derived query.
                self.assertEqual(str(derived_tree.predicate), "True")
                self.assertFalse(derived_tree.data_coordinate_uploads)
                if has_dataset:
                    # Dataset search is still there, even though its existence
                    # constraint is included in the materialization, because we
                    # might need to re-join for some result columns in a
                    # derived query.
                    # NOTE(review): assertTrue ignores its second argument
                    # (it is just a failure message); assertEqual was almost
                    # certainly intended here — confirm and fix.
                    self.assertTrue(derived_tree.datasets.keys(), {"raw"})
                    self.assertEqual(derived_tree.datasets["raw"].dimensions, self.raw.dimensions.as_group())
                    self.assertEqual(derived_tree.datasets["raw"].collections, ("DummyCam/defaults",))
                else:
                    self.assertFalse(derived_tree.datasets)
                # Exactly one materialization entry, keyed consistently in the
                # derived tree and in the driver's record of it.
                ((key, derived_tree_materialized_dimensions),) = derived_tree.materializations.items()
                self.assertEqual(derived_tree_materialized_dimensions, dimensions)
                (
                    materialized_tree,
                    materialized_dimensions,
                    materialized_datasets,
                ) = cm.exception.driver.materializations[key]
                self.assertEqual(derived_tree_materialized_dimensions, materialized_dimensions)
                if has_dataset:
                    self.assertEqual(materialized_datasets, {"raw"})
                else:
                    self.assertFalse(materialized_datasets)
                check_query_tree(materialized_tree)

            # Actual logic for the check() function begins here.

            self.assertEqual(query.constraint_dataset_types, {"raw"})
            self.assertEqual(query.constraint_dimensions, self.raw.dimensions.as_group())

            # Adding a constraint on a field for this dataset type should work
            # (this constraint will be present in all downstream tests).
            query = query.where(query.expression_factory["raw"].run == "DummyCam/raw/all")
            with self.assertRaises(qt.InvalidQueryError):
                # Adding constraint on a different dataset should not work.
                query.where(query.expression_factory["refcat"].run == "refcats")

            # Data IDs, with dimensions defaulted.
            check_data_id_results(query=query)
            # Dimensions for data IDs the same as defaults.
            check_data_id_results(["exposure", "detector"], query=query)
            # Dimensions are a subset of the query dimensions.
            check_data_id_results(["exposure"], query=query, dimensions=self.universe.conform(["exposure"]))
            # Dimensions are a superset of the query dimensions.
            check_data_id_results(
                ["exposure", "detector", "visit"],
                query=query,
                dimensions=self.universe.conform(["exposure", "detector", "visit"]),
            )
            # Dimensions are neither a superset nor a subset of the query
            # dimensions.
            check_data_id_results(
                ["detector", "visit"], query=query, dimensions=self.universe.conform(["visit", "detector"])
            )
            # Dimensions are empty.
            check_data_id_results([], query=query, dimensions=self.universe.conform([]))

            # Get DatasetRef results, with various arguments and defaulting.
            if has_storage_class:
                check_dataset_results("raw", query=query)
                check_dataset_results("raw", query=query, find_first=True)
                check_dataset_results("raw", ["DummyCam/defaults"], query=query)
                check_dataset_results("raw", ["DummyCam/defaults"], query=query, find_first=True)
            else:
                # Without a storage class, results cannot be requested by
                # dataset type name alone.
                with self.assertRaises(MissingDatasetTypeError):
                    query.datasets("raw")
                with self.assertRaises(MissingDatasetTypeError):
                    query.datasets("raw", find_first=True)
                with self.assertRaises(MissingDatasetTypeError):
                    query.datasets("raw", ["DummyCam/defaults"])
                with self.assertRaises(MissingDatasetTypeError):
                    query.datasets("raw", ["DummyCam/defaults"], find_first=True)
            # Passing a full DatasetType instance always works.
            check_dataset_results(self.raw, query=query)
            check_dataset_results(self.raw, query=query, find_first=True)
            check_dataset_results(self.raw, ["DummyCam/defaults"], query=query)
            check_dataset_results(self.raw, ["DummyCam/defaults"], query=query, find_first=True)

            # Changing collections at this stage is not allowed.
            with self.assertRaises(qt.InvalidQueryError):
                query.datasets("raw", collections=["DummyCam/calib"])

            # Changing storage classes is allowed, if they're compatible.
            check_dataset_results(
                self.raw.overrideStorageClass("ArrowNumpy"), query=query, storage_class_name="ArrowNumpy"
            )
            if dataset_type_registered:
                with self.assertRaises(DatasetTypeError):
                    # Can't use overrideStorageClass, because it'll raise
                    # before the code we want to test can.
                    query.datasets(DatasetType("raw", self.raw.dimensions, "int"))

            # Check the 'any' and 'explain_no_results' methods on Query itself.
            # NOTE(review): permutations([False, True], 2) yields only the two
            # mixed (execute, exact) combinations; itertools.product would
            # cover all four — confirm whether that was intended.
            for execute, exact in itertools.permutations([False, True], 2):
                with self.assertRaises(_TestQueryAny) as cm:
                    query.any(execute=execute, exact=exact)
                self.assertEqual(cm.exception.execute, execute)
                self.assertEqual(cm.exception.exact, exact)
                check_query_tree(cm.exception.tree, dimensions)
            # NOTE(review): this assertRaises does not bind 'as cm', so the
            # check_query_tree call below re-inspects the stale tree from the
            # last any() iteration above, not the explain_no_results tree —
            # likely 'as cm' was intended here.
            with self.assertRaises(_TestQueryExplainNoResults):
                query.explain_no_results()
            check_query_tree(cm.exception.tree, dimensions)

            # Materialize the query with defaults.
            check_materialization({}, query=query)
            # Materialize the query with args that match defaults.
            check_materialization({"dimensions": ["exposure", "detector"], "datasets": {"raw"}}, query=query)
            # Materialize the query with a superset of the original dimensions.
            check_materialization(
                {"dimensions": ["exposure", "detector", "visit"]},
                query=query,
                dimensions=self.universe.conform(["exposure", "visit", "detector"]),
            )
            # Materialize the query with no datasets.
            check_materialization(
                {"dimensions": ["exposure", "detector"], "datasets": frozenset()},
                query=query,
                has_dataset=False,
            )
            # Materialize the query with no datasets and a subset of the
            # dimensions.
            check_materialization(
                {"dimensions": ["exposure"], "datasets": frozenset()},
                query=query,
                has_dataset=False,
                dimensions=self.universe.conform(["exposure"]),
            )
            # Materializing the query with a dataset that is not in the query
            # is an error.
            with self.assertRaises(qt.InvalidQueryError):
                query.materialize(datasets={"refcat"})
            # Materializing the query with dimensions that are not a superset
            # of any materialized dataset dimensions is an error.
            with self.assertRaises(qt.InvalidQueryError):
                query.materialize(dimensions=["exposure"], datasets={"raw"})

        # Actual logic for test_dataset_joins starts here.

        # Default collections and existing dataset type name.
        check(self.query().join_dataset_search("raw"))
        # Default collections and existing DatasetType instance.
        check(self.query().join_dataset_search(self.raw))
        # Manual collections and existing dataset type.
        check(
            self.query(default_collections=None).join_dataset_search("raw", collections=["DummyCam/defaults"])
        )
        check(
            self.query(default_collections=None).join_dataset_search(
                self.raw, collections=["DummyCam/defaults"]
            )
        )
        # Dataset type does not exist, but dimensions provided. This will
        # prohibit getting results without providing the dataset type
        # later.
        check(
            self.query(dataset_types={}).join_dataset_search(
                "raw", dimensions=self.universe.conform(["detector", "exposure"])
            ),
            has_storage_class=False,
            dataset_type_registered=False,
        )
        # Dataset type does not exist, but a full dataset type was
        # provided up front.
        check(self.query(dataset_types={}).join_dataset_search(self.raw), dataset_type_registered=False)

        with self.assertRaises(MissingDatasetTypeError):
            # Dataset type does not exist and no dimensions passed.
            self.query(dataset_types={}).join_dataset_search("raw", collections=["DummyCam/raw/all"])
        with self.assertRaises(DatasetTypeError):
            # Dataset type does exist and bad dimensions passed.
            self.query().join_dataset_search(
                "raw", collections=["DummyCam/raw/all"], dimensions=self.universe.conform(["visit"])
            )
        with self.assertRaises(TypeError):
            # Dataset type object and dimensions were passed (illegal even if
            # they agree)
            self.query().join_dataset_search(
                self.raw,
                dimensions=self.raw.dimensions.as_group(),
            )
        with self.assertRaises(TypeError):
            # Bad type for dataset type argument.
            self.query().join_dataset_search(3)
        with self.assertRaises(DatasetTypeError):
            # Changing dimensions is an error.
            self.query(dataset_types={}).join_dataset_search(
                "raw", dimensions=self.universe.conform(["patch"])
            ).datasets(self.raw)
    def test_dimension_record_results(self) -> None:
        """Test queries that return dimension records.

        This includes tests for:

        - joining against uploaded data coordinates;
        - counting result rows;
        - expanding dimensions as needed for 'where' conditions;
        - order_by, limit, and offset.

        It does not include the iteration methods of
        DimensionRecordQueryResults, since those require a different mock
        driver setup (see test_dimension_record_iteration).
        """
        # Set up the base query-results object to test.
        query = self.query()
        x = query.expression_factory
        # A fresh query constrains no dimensions at all.
        self.assertFalse(query.constraint_dimensions)
        query = query.where(x.skymap == "m")
        # The 'where' condition expands the constraint dimensions to 'skymap'.
        self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap"]))
        upload_rows = [
            DataCoordinate.standardize(instrument="DummyCam", visit=3, universe=self.universe),
            DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe),
        ]
        # The mock driver records uploads as (dimensions, required-value
        # tuples); precompute the expected raw form for comparison below.
        raw_rows = frozenset([data_id.required_values for data_id in upload_rows])
        query = query.join_data_coordinates(upload_rows)
        # Joining the upload adds 'visit' to the constraint dimensions.
        self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap", "visit"]))
        results = query.dimension_records("patch")
        results = results.where(x.tract == 4)

        # Define a closure to run tests on variants of the base query.
        def check(
            results: DimensionRecordQueryResults,
            order_by: Any = (),
            limit: int | None = None,
            offset: int = 0,
        ) -> list[str]:
            # Apply the variant's sorting/paging, then trigger execution;
            # the mock driver raises _TestQueryExecution carrying the query
            # tree and result spec it was asked to run.
            results = results.order_by(*order_by).limit(limit, offset=offset)
            self.assertEqual(results.element.name, "patch")
            with self.assertRaises(_TestQueryExecution) as cm:
                list(results)
            tree = cm.exception.tree
            # Both 'where' conditions should be ANDed into one predicate.
            self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4")
            self.assertEqual(tree.dimensions, self.universe.conform(["visit", "patch"]))
            self.assertFalse(tree.materializations)
            self.assertFalse(tree.datasets)
            # Exactly one data-coordinate upload should be present.
            ((key, upload_dimensions),) = tree.data_coordinate_uploads.items()
            self.assertEqual(upload_dimensions, self.universe.conform(["visit"]))
            self.assertEqual(cm.exception.driver.data_coordinate_uploads[key], (upload_dimensions, raw_rows))
            result_spec = cm.exception.result_spec
            self.assertEqual(result_spec.result_type, "dimension_record")
            self.assertEqual(result_spec.element, self.universe["patch"])
            self.assertEqual(result_spec.limit, limit)
            self.assertEqual(result_spec.offset, offset)
            # NOTE(review): permutations(..., r=2) only yields (False, True)
            # and (True, False); itertools.product([False, True], repeat=2)
            # would also cover exact == discard — confirm intent.
            for exact, discard in itertools.permutations([False, True], r=2):
                with self.assertRaises(_TestQueryCount) as cm:
                    results.count(exact=exact, discard=discard)
                self.assertEqual(cm.exception.result_spec, result_spec)
                self.assertEqual(cm.exception.exact, exact)
                self.assertEqual(cm.exception.discard, discard)
            # Return the stringified order_by terms for the caller to check.
            return [str(term) for term in result_spec.order_by]

        # Run the closure's tests on variants of the base query.
        self.assertEqual(check(results), [])
        self.assertEqual(check(results, limit=2), [])
        self.assertEqual(check(results, offset=1), [])
        self.assertEqual(check(results, limit=3, offset=3), [])
        self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"])
        self.assertEqual(
            check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc], offset=2),
            ["patch.cell_x", "patch.cell_y DESC"],
        )
        with self.assertRaises(qt.InvalidQueryError):
            # Cannot upload empty list of data IDs.
            query.join_data_coordinates([])
        with self.assertRaises(qt.InvalidQueryError):
            # Cannot upload heterogeneous list of data IDs.
            query.join_data_coordinates(
                [
                    DataCoordinate.make_empty(self.universe),
                    DataCoordinate.standardize(instrument="DummyCam", universe=self.universe),
                ]
            )
1550 def test_dimension_record_iteration(self) -> None:
1551 """Tests for DimensionRecordQueryResult iteration."""
1553 def make_record(n: int) -> DimensionRecord:
1554 return self.universe["patch"].RecordClass(skymap="m", tract=4, patch=n)
1556 result_rows = (
1557 [make_record(n) for n in range(3)],
1558 [make_record(n) for n in range(3, 6)],
1559 [make_record(10)],
1560 )
1561 results = self.query(result_rows=result_rows).dimension_records("patch")
1562 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1563 self.assertEqual(
1564 list(results.iter_set_pages()),
1565 [DimensionRecordSet(self.universe["patch"], rows) for rows in result_rows],
1566 )
1567 self.assertEqual(
1568 [table.column("id").to_pylist() for table in results.iter_table_pages()],
1569 [list(range(3)), list(range(3, 6)), [10]],
1570 )
    def test_data_coordinate_results(self) -> None:
        """Test queries that return data coordinates.

        This includes tests for:

        - counting result rows;
        - expanding dimensions as needed for 'where' conditions;
        - order_by, limit, and offset.

        It does not include the iteration methods of
        DataCoordinateQueryResults, since those require a different mock
        driver setup (see test_data_coordinate_iteration). More tests for
        different inputs to DataCoordinateQueryResults construction are in
        test_dataset_join.
        """
        # Set up the base query-results object to test.
        query = self.query()
        x = query.expression_factory
        # A fresh query constrains no dimensions at all.
        self.assertFalse(query.constraint_dimensions)
        query = query.where(x.skymap == "m")
        results = query.data_ids(["patch", "band"])
        results = results.where(x.tract == 4)

        # Define a closure to run tests on variants of the base query.
        def check(
            results: DataCoordinateQueryResults,
            order_by: Any = (),
            limit: int | None = None,
            offset: int = 0,
            include_dimension_records: bool = False,
        ) -> list[str]:
            # Apply the variant's sorting/paging, then trigger execution;
            # the mock driver raises _TestQueryExecution carrying the query
            # tree and result spec it was asked to run.
            results = results.order_by(*order_by).limit(limit, offset=offset)
            self.assertEqual(results.dimensions, self.universe.conform(["patch", "band"]))
            with self.assertRaises(_TestQueryExecution) as cm:
                list(results)
            tree = cm.exception.tree
            # Both 'where' conditions should be ANDed into one predicate.
            self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4")
            self.assertEqual(tree.dimensions, self.universe.conform(["patch", "band"]))
            self.assertFalse(tree.materializations)
            self.assertFalse(tree.datasets)
            self.assertFalse(tree.data_coordinate_uploads)
            result_spec = cm.exception.result_spec
            self.assertEqual(result_spec.result_type, "data_coordinate")
            self.assertEqual(result_spec.dimensions, self.universe.conform(["patch", "band"]))
            self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
            self.assertEqual(result_spec.limit, limit)
            self.assertEqual(result_spec.offset, offset)
            self.assertIsNone(result_spec.find_first_dataset)
            # NOTE(review): permutations(..., r=2) only yields (False, True)
            # and (True, False); itertools.product([False, True], repeat=2)
            # would also cover exact == discard — confirm intent.
            for exact, discard in itertools.permutations([False, True], r=2):
                with self.assertRaises(_TestQueryCount) as cm:
                    results.count(exact=exact, discard=discard)
                self.assertEqual(cm.exception.result_spec, result_spec)
                self.assertEqual(cm.exception.exact, exact)
                self.assertEqual(cm.exception.discard, discard)
            # Return the stringified order_by terms for the caller to check.
            return [str(term) for term in result_spec.order_by]

        # Run the closure's tests on variants of the base query.
        self.assertEqual(check(results), [])
        self.assertEqual(check(results.with_dimension_records(), include_dimension_records=True), [])
        # with_dimension_records() should be idempotent.
        self.assertEqual(
            check(results.with_dimension_records().with_dimension_records(), include_dimension_records=True),
            [],
        )
        self.assertEqual(check(results, limit=2), [])
        self.assertEqual(check(results, offset=1), [])
        self.assertEqual(check(results, limit=3, offset=3), [])
        self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"])
        self.assertEqual(
            check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc], offset=2),
            ["patch.cell_x", "patch.cell_y DESC"],
        )
        # String order_by terms ('-' prefix for DESC) behave like expressions.
        self.assertEqual(
            check(results, order_by=["patch.cell_x", "-cell_y"], offset=2),
            ["patch.cell_x", "patch.cell_y DESC"],
        )
1648 def test_data_coordinate_iteration(self) -> None:
1649 """Tests for DataCoordinateQueryResult iteration."""
1651 def make_data_id(n: int) -> DimensionRecord:
1652 return DataCoordinate.standardize(skymap="m", tract=4, patch=n, universe=self.universe)
1654 result_rows = (
1655 [make_data_id(n) for n in range(3)],
1656 [make_data_id(n) for n in range(3, 6)],
1657 [make_data_id(10)],
1658 )
1659 results = self.query(result_rows=result_rows).data_ids(["patch"])
1660 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
    def test_dataset_results(self) -> None:
        """Test queries that return dataset refs.

        This includes tests for:

        - counting result rows;
        - expanding dimensions as needed for 'where' conditions;
        - chained results for multiple dataset types;
        - different ways of passing a data ID to 'where' methods;
        - order_by, limit, and offset.

        It does not include the iteration methods of the DatasetQueryResults
        classes, since those require a different mock driver setup (see
        test_dataset_iteration). More tests for different inputs to
        SingleTypeDatasetQueryResults construction are in test_dataset_join.
        """
        # Set up a few equivalent base query-results object to test.
        query = self.query()
        x = query.expression_factory
        self.assertFalse(query.constraint_dimensions)
        # Three equivalent spellings: '...' for all dataset types plus a
        # kwarg data ID; explicit collections plus a mapping data ID; an
        # explicit dataset-type list plus a DataCoordinate data ID.
        results1 = query.datasets(...).where(x.instrument == "DummyCam", visit=4)
        results2 = query.datasets(..., collections=["DummyCam/defaults"]).where(
            {"instrument": "DummyCam", "visit": 4}
        )
        results3 = query.datasets(["raw", "bias", "refcat"]).where(
            DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe)
        )

        # Define a closure to handle single-dataset-type results.
        def check_single_type(
            results: SingleTypeDatasetQueryResults,
            order_by: Any = (),
            limit: int | None = None,
            offset: int = 0,
            include_dimension_records: bool = False,
        ) -> list[str]:
            # Apply the variant's sorting/paging, then trigger execution;
            # the mock driver raises _TestQueryExecution carrying the query
            # tree and result spec it was asked to run.
            results = results.order_by(*order_by).limit(limit, offset=offset)
            # by_dataset_type() on a single-type result yields itself.
            self.assertIs(list(results.by_dataset_type())[0], results)
            with self.assertRaises(_TestQueryExecution) as cm:
                list(results)
            tree = cm.exception.tree
            self.assertEqual(str(tree.predicate), "instrument == 'DummyCam' AND visit == 4")
            # Tree dimensions are the 'where' dimensions plus the dataset
            # type's own dimensions.
            self.assertEqual(
                tree.dimensions,
                self.universe.conform(["visit"]).union(results.dataset_type.dimensions.as_group()),
            )
            self.assertFalse(tree.materializations)
            self.assertEqual(tree.datasets.keys(), {results.dataset_type.name})
            self.assertEqual(tree.datasets[results.dataset_type.name].collections, ("DummyCam/defaults",))
            self.assertEqual(
                tree.datasets[results.dataset_type.name].dimensions,
                results.dataset_type.dimensions.as_group(),
            )
            self.assertEqual(
                tree.datasets[results.dataset_type.name].storage_class_name,
                results.dataset_type.storageClass_name,
            )
            self.assertFalse(tree.data_coordinate_uploads)
            result_spec = cm.exception.result_spec
            self.assertEqual(result_spec.result_type, "dataset_ref")
            self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
            self.assertEqual(result_spec.limit, limit)
            self.assertEqual(result_spec.offset, offset)
            self.assertEqual(result_spec.find_first_dataset, result_spec.dataset_type_name)
            # NOTE(review): permutations(..., r=2) only yields (False, True)
            # and (True, False); itertools.product([False, True], repeat=2)
            # would also cover exact == discard — confirm intent.
            for exact, discard in itertools.permutations([False, True], r=2):
                with self.assertRaises(_TestQueryCount) as cm:
                    results.count(exact=exact, discard=discard)
                self.assertEqual(cm.exception.result_spec, result_spec)
                self.assertEqual(cm.exception.exact, exact)
                self.assertEqual(cm.exception.discard, discard)
            # The data_ids view should run the same tree with a
            # data-coordinate result spec over the dataset's dimensions.
            with self.assertRaises(_TestQueryExecution) as cm:
                list(results.data_ids)
            self.assertEqual(
                cm.exception.result_spec,
                qrs.DataCoordinateResultSpec(
                    dimensions=results.dataset_type.dimensions.as_group(),
                    include_dimension_records=include_dimension_records,
                ),
            )
            self.assertIs(cm.exception.tree, tree)
            # Return the stringified order_by terms for the caller to check.
            return [str(term) for term in result_spec.order_by]

        # Define a closure to run tests on variants of the base query, which
        # is a chain of multiple dataset types.
        def check_chained(
            results: DatasetQueryResults,
            order_by: tuple[Any, Any, Any] = ((), (), ()),
            limit: int | None = None,
            offset: int = 0,
            include_dimension_records: bool = False,
        ) -> list[list[str]]:
            self.assertEqual(results.has_dimension_records, include_dimension_records)
            types_seen: list[str] = []
            order_by_strings: list[list[str]] = []
            # Each per-type result gets its own order_by variant; limit and
            # offset are applied to every one.
            for single_type_results, single_type_order_by in zip(results.by_dataset_type(), order_by):
                order_by_strings.append(
                    check_single_type(
                        single_type_results,
                        order_by=single_type_order_by,
                        limit=limit,
                        offset=offset,
                        include_dimension_records=include_dimension_records,
                    )
                )
                types_seen.append(single_type_results.dataset_type.name)
            # Chained results iterate dataset types in sorted name order.
            self.assertEqual(types_seen, sorted(["raw", "bias", "refcat"]))
            return order_by_strings

        # Run the closure's tests on variants of the base query.
        self.assertEqual(check_chained(results1), [[], [], []])
        self.assertEqual(check_chained(results2), [[], [], []])
        self.assertEqual(check_chained(results3), [[], [], []])
        self.assertEqual(
            check_chained(results1.with_dimension_records(), include_dimension_records=True), [[], [], []]
        )
        # with_dimension_records() should be idempotent.
        self.assertEqual(
            check_chained(
                results1.with_dimension_records().with_dimension_records(), include_dimension_records=True
            ),
            [[], [], []],
        )
        self.assertEqual(check_chained(results1, limit=2), [[], [], []])
        self.assertEqual(check_chained(results1, offset=1), [[], [], []])
        self.assertEqual(check_chained(results1, limit=3, offset=3), [[], [], []])
        self.assertEqual(
            check_chained(
                results1,
                order_by=[
                    ["bias.timespan.begin"],
                    ["ingest_date"],
                    ["htm7"],
                ],
            ),
            [["bias.timespan.begin"], ["raw.ingest_date"], ["htm7"]],
        )
1798 def test_dataset_iteration(self) -> None:
1799 """Tests for SingleTypeDatasetQueryResult iteration."""
1801 def make_ref(n: int) -> DimensionRecord:
1802 return DatasetRef(
1803 self.raw,
1804 DataCoordinate.standardize(
1805 instrument="DummyCam", exposure=4, detector=n, universe=self.universe
1806 ),
1807 run="DummyCam/raw/all",
1808 id=uuid.uuid4(),
1809 )
1811 result_rows = (
1812 [make_ref(n) for n in range(3)],
1813 [make_ref(n) for n in range(3, 6)],
1814 [make_ref(10)],
1815 )
1816 results = self.query(result_rows=result_rows).datasets("raw")
1817 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1819 def test_identifiers(self) -> None:
1820 """Test edge-cases of identifiers in order_by expressions."""
1822 def extract_order_by(results: DataCoordinateQueryResults) -> list[str]:
1823 with self.assertRaises(_TestQueryExecution) as cm:
1824 list(results)
1825 return [str(term) for term in cm.exception.result_spec.order_by]
1827 self.assertEqual(
1828 extract_order_by(self.query().data_ids(["day_obs"]).order_by("-timespan.begin")),
1829 ["day_obs.timespan.begin DESC"],
1830 )
1831 self.assertEqual(
1832 extract_order_by(self.query().data_ids(["day_obs"]).order_by("timespan.end")),
1833 ["day_obs.timespan.end"],
1834 )
1835 self.assertEqual(
1836 extract_order_by(self.query().data_ids(["visit"]).order_by("-visit.timespan.begin")),
1837 ["visit.timespan.begin DESC"],
1838 )
1839 self.assertEqual(
1840 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.timespan.end")),
1841 ["visit.timespan.end"],
1842 )
1843 self.assertEqual(
1844 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.science_program")),
1845 ["visit.science_program"],
1846 )
1847 self.assertEqual(
1848 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.id")),
1849 ["visit"],
1850 )
1851 self.assertEqual(
1852 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.physical_filter")),
1853 ["physical_filter"],
1854 )
1855 with self.assertRaises(TypeError):
1856 self.query().data_ids(["visit"]).order_by(3)
1857 with self.assertRaises(qt.InvalidQueryError):
1858 self.query().data_ids(["visit"]).order_by("visit.region")
1859 with self.assertRaisesRegex(qt.InvalidQueryError, "Ambiguous"):
1860 self.query().data_ids(["visit", "exposure"]).order_by("timespan.begin")
1861 with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"):
1862 self.query().data_ids(["visit", "exposure"]).order_by("blarg")
1863 with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"):
1864 self.query().data_ids(["visit", "exposure"]).order_by("visit.horse")
1865 with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"):
1866 self.query().data_ids(["visit", "exposure"]).order_by("visit.science_program.monkey")
1867 with self.assertRaisesRegex(qt.InvalidQueryError, "not valid for datasets"):
1868 self.query().datasets("raw").order_by("raw.seq_num")
    def test_invalid_models(self) -> None:
        """Test invalid models and combinations of models that cannot be
        constructed via the public Query and *QueryResults interfaces.
        """
        x = ExpressionFactory(self.universe)
        with self.assertRaises(qt.InvalidQueryError):
            # QueryTree dimensions do not cover dataset dimensions.
            qt.QueryTree(
                dimensions=self.universe.conform(["visit"]),
                datasets={
                    "raw": qt.DatasetSearch(
                        collections=("DummyCam/raw/all",),
                        dimensions=self.raw.dimensions.as_group(),
                        storage_class_name=None,
                    )
                },
            )
        with self.assertRaises(qt.InvalidQueryError):
            # QueryTree dimensions do not cover predicate dimensions.
            qt.QueryTree(
                dimensions=self.universe.conform(["visit"]),
                predicate=(x.detector > 5),
            )
        with self.assertRaises(qt.InvalidQueryError):
            # Predicate references a dataset not in the QueryTree.
            qt.QueryTree(
                dimensions=self.universe.conform(["exposure", "detector"]),
                predicate=(x["raw"].collection == "bird"),
            )
        with self.assertRaises(qt.InvalidQueryError):
            # ResultSpec's dimensions are not a subset of the query tree's.
            DimensionRecordQueryResults(
                _TestQueryDriver(),
                qt.QueryTree(dimensions=self.universe.conform(["tract"])),
                qrs.DimensionRecordResultSpec(element=self.universe["detector"]),
            )
        with self.assertRaises(qt.InvalidQueryError):
            # ResultSpec's datasets are not a subset of the query tree's.
            SingleTypeDatasetQueryResults(
                _TestQueryDriver(),
                qt.QueryTree(dimensions=self.raw.dimensions.as_group()),
                qrs.DatasetRefResultSpec(
                    dataset_type_name="raw",
                    dimensions=self.raw.dimensions.as_group(),
                    storage_class_name=self.raw.storageClass_name,
                    find_first=True,
                ),
            )
        with self.assertRaises(qt.InvalidQueryError):
            # ResultSpec's order_by expression is not related to the dimensions
            # we're returning.
            # (Note: 'x' is re-created here redundantly; it is already bound
            # above and unchanged.)
            x = ExpressionFactory(self.universe)
            DimensionRecordQueryResults(
                _TestQueryDriver(),
                qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])),
                qrs.DimensionRecordResultSpec(
                    element=self.universe["detector"], order_by=(x.unwrap(x.visit),)
                ),
            )
        with self.assertRaises(qt.InvalidQueryError):
            # ResultSpec's order_by expression is not related to the datasets
            # we're returning.
            x = ExpressionFactory(self.universe)
            DimensionRecordQueryResults(
                _TestQueryDriver(),
                qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])),
                qrs.DimensionRecordResultSpec(
                    element=self.universe["detector"], order_by=(x.unwrap(x["raw"].ingest_date),)
                ),
            )
1941 def test_general_result_spec(self) -> None:
1942 """Tests for GeneralResultSpec.
1944 Unlike the other ResultSpec objects, we don't have a *QueryResults
1945 class for GeneralResultSpec yet, so we can't use the higher-level
1946 interfaces to test it like we can the others.
1947 """
1948 a = qrs.GeneralResultSpec(
1949 dimensions=self.universe.conform(["detector"]),
1950 dimension_fields={"detector": {"purpose"}},
1951 dataset_fields={},
1952 find_first=False,
1953 )
1954 self.assertEqual(a.find_first_dataset, None)
1955 a_columns = qt.ColumnSet(self.universe.conform(["detector"]))
1956 a_columns.dimension_fields["detector"].add("purpose")
1957 self.assertEqual(a.get_result_columns(), a_columns)
1958 b = qrs.GeneralResultSpec(
1959 dimensions=self.universe.conform(["detector"]),
1960 dimension_fields={},
1961 dataset_fields={"bias": {"timespan", "dataset_id"}},
1962 find_first=True,
1963 )
1964 self.assertEqual(b.find_first_dataset, "bias")
1965 b_columns = qt.ColumnSet(self.universe.conform(["detector"]))
1966 b_columns.dataset_fields["bias"].add("timespan")
1967 b_columns.dataset_fields["bias"].add("dataset_id")
1968 self.assertEqual(b.get_result_columns(), b_columns)
1969 with self.assertRaises(qt.InvalidQueryError):
1970 # More than one dataset type with find_first
1971 qrs.GeneralResultSpec(
1972 dimensions=self.universe.conform(["detector", "exposure"]),
1973 dimension_fields={},
1974 dataset_fields={"bias": {"dataset_id"}, "raw": {"dataset_id"}},
1975 find_first=True,
1976 )
1977 with self.assertRaises(qt.InvalidQueryError):
1978 # Out-of-bounds dimension fields.
1979 qrs.GeneralResultSpec(
1980 dimensions=self.universe.conform(["detector"]),
1981 dimension_fields={"visit": {"name"}},
1982 dataset_fields={},
1983 find_first=False,
1984 )
1985 with self.assertRaises(qt.InvalidQueryError):
1986 # No fields for dimension element.
1987 qrs.GeneralResultSpec(
1988 dimensions=self.universe.conform(["detector"]),
1989 dimension_fields={"detector": set()},
1990 dataset_fields={},
1991 find_first=True,
1992 )
1993 with self.assertRaises(qt.InvalidQueryError):
1994 # No fields for dataset.
1995 qrs.GeneralResultSpec(
1996 dimensions=self.universe.conform(["detector"]),
1997 dimension_fields={},
1998 dataset_fields={"bias": set()},
1999 find_first=True,
2000 )
if __name__ == "__main__":
    # Allow this test module to be run directly as a script.
    unittest.main()