Coverage for tests/test_query_interface.py: 10%
936 statements
coverage.py v7.4.4, created at 2024-04-10 10:14 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Tests for the public Butler._query interface and the Pydantic models that
29back it using a mock column-expression visitor and a mock QueryDriver
30implementation.
32These tests are entirely independent of which kind of butler or database
33backend we're using.
35This is a very large test file because many tests make use of those mocks,
36but the mocks are not generally useful enough to be worth putting in the
37library proper.
38"""
40from __future__ import annotations
42import dataclasses
43import itertools
44import unittest
45import uuid
46from collections.abc import Iterable, Iterator, Mapping, Set
47from typing import Any
49import astropy.time
50from lsst.daf.butler import (
51 CollectionType,
52 DataCoordinate,
53 DataIdValue,
54 DatasetRef,
55 DatasetType,
56 DimensionGroup,
57 DimensionRecord,
58 DimensionRecordSet,
59 DimensionUniverse,
60 MissingDatasetTypeError,
61 NamedValueSet,
62 NoDefaultCollectionError,
63 Timespan,
64)
65from lsst.daf.butler.queries import (
66 DataCoordinateQueryResults,
67 DatasetRefQueryResults,
68 DimensionRecordQueryResults,
69 Query,
70)
71from lsst.daf.butler.queries import driver as qd
72from lsst.daf.butler.queries import result_specs as qrs
73from lsst.daf.butler.queries import tree as qt
74from lsst.daf.butler.queries.expression_factory import ExpressionFactory
75from lsst.daf.butler.queries.tree._column_expression import UnaryExpression
76from lsst.daf.butler.queries.tree._predicate import PredicateLeaf, PredicateOperands
77from lsst.daf.butler.queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor
78from lsst.daf.butler.registry import CollectionSummary, DatasetTypeError
79from lsst.daf.butler.registry.interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord
80from lsst.sphgeom import DISJOINT, Mq3cPixelization
83class _TestVisitor(PredicateVisitor[bool, bool, bool], ColumnExpressionVisitor[Any]):
84 """Test visitor for column expressions.
86 This visitor evaluates column expressions using regular Python logic.
88 Parameters
89 ----------
90 dimension_keys : `~collections.abc.Mapping`, optional
91 Mapping from dimension name to the value it should be assigned by the
92 visitor.
93 dimension_fields : `~collections.abc.Mapping`, optional
94 Mapping from ``(dimension element name, field)`` tuple to the value it
95 should be assigned by the visitor.
96 dataset_fields : `~collections.abc.Mapping`, optional
97 Mapping from ``(dataset type name, field)`` tuple to the value it
98 should be assigned by the visitor.
99 query_tree_items : `~collections.abc.Set`, optional
100 Set that should be used as the right-hand side of element-in-query
101 predicates.
102 """
104 def __init__(
105 self,
106 dimension_keys: Mapping[str, Any] | None = None,
107 dimension_fields: Mapping[tuple[str, str], Any] | None = None,
108 dataset_fields: Mapping[tuple[str, str], Any] | None = None,
109 query_tree_items: Set[Any] = frozenset(),
110 ):
111 self.dimension_keys = dimension_keys or {}
112 self.dimension_fields = dimension_fields or {}
113 self.dataset_fields = dataset_fields or {}
114 self.query_tree_items = query_tree_items
116 def visit_binary_expression(self, expression: qt.BinaryExpression) -> Any:
117 match expression.operator:
118 case "+":
119 return expression.a.visit(self) + expression.b.visit(self)
120 case "-":
121 return expression.a.visit(self) - expression.b.visit(self)
122 case "*":
123 return expression.a.visit(self) * expression.b.visit(self)
124 case "/":
125 match expression.column_type:
126 case "int":
127 return expression.a.visit(self) // expression.b.visit(self)
128 case "float":
129 return expression.a.visit(self) / expression.b.visit(self)
130 case "%":
131 return expression.a.visit(self) % expression.b.visit(self)
133 def visit_comparison(
134 self,
135 a: qt.ColumnExpression,
136 operator: qt.ComparisonOperator,
137 b: qt.ColumnExpression,
138 flags: PredicateVisitFlags,
139 ) -> bool:
140 match operator:
141 case "==":
142 return a.visit(self) == b.visit(self)
143 case "!=":
144 return a.visit(self) != b.visit(self)
145 case "<":
146 return a.visit(self) < b.visit(self)
147 case ">":
148 return a.visit(self) > b.visit(self)
149 case "<=":
150 return a.visit(self) <= b.visit(self)
151 case ">=":
152 return a.visit(self) >= b.visit(self)
153 case "overlaps":
154 return not (a.visit(self).relate(b.visit(self)) & DISJOINT)
156 def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> Any:
157 return self.dataset_fields[expression.dataset_type, expression.field]
159 def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> Any:
160 return self.dimension_fields[expression.element.name, expression.field]
162 def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> Any:
163 return self.dimension_keys[expression.dimension.name]
165 def visit_in_container(
166 self,
167 member: qt.ColumnExpression,
168 container: tuple[qt.ColumnExpression, ...],
169 flags: PredicateVisitFlags,
170 ) -> bool:
171 return member.visit(self) in [item.visit(self) for item in container]
173 def visit_in_range(
174 self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags
175 ) -> bool:
176 return member.visit(self) in range(start, stop, step)
178 def visit_in_query_tree(
179 self,
180 member: qt.ColumnExpression,
181 column: qt.ColumnExpression,
182 query_tree: qt.QueryTree,
183 flags: PredicateVisitFlags,
184 ) -> bool:
185 return member.visit(self) in self.query_tree_items
187 def visit_is_null(self, operand: qt.ColumnExpression, flags: PredicateVisitFlags) -> bool:
188 return operand.visit(self) is None
190 def visit_literal(self, expression: qt.ColumnLiteral) -> Any:
191 return expression.get_literal_value()
193 def visit_reversed(self, expression: qt.Reversed) -> Any:
194 return _TestReversed(expression.operand.visit(self))
196 def visit_unary_expression(self, expression: UnaryExpression) -> Any:
197 match expression.operator:
198 case "-":
199 return -expression.operand.visit(self)
200 case "begin_of":
201 return expression.operand.visit(self).begin
202 case "end_of":
203 return expression.operand.visit(self).end
205 def apply_logical_and(self, originals: PredicateOperands, results: tuple[bool, ...]) -> bool:
206 return all(results)
208 def apply_logical_not(self, original: PredicateLeaf, result: bool, flags: PredicateVisitFlags) -> bool:
209 return not result
211 def apply_logical_or(
212 self, originals: tuple[PredicateLeaf, ...], results: tuple[bool, ...], flags: PredicateVisitFlags
213 ) -> bool:
214 return any(results)
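# Usage sketch (illustrative, assuming the default DimensionUniverse):
# _TestVisitor evaluates expressions and predicates against the plain Python
# values supplied to its constructor, e.g.
#
#     x = ExpressionFactory(DimensionUniverse())
#     predicate = x.detector == 3
#     assert predicate.visit(_TestVisitor(dimension_keys={"detector": 3}))
#     assert not predicate.visit(_TestVisitor(dimension_keys={"detector": 4}))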
217@dataclasses.dataclass
218class _TestReversed:
219 """Struct used by _TestVisitor" to mark an expression as reversed in sort
220 order.
221 """
223 operand: Any
226class _TestQueryExecution(BaseException):
227 """Exception raised by _TestQueryDriver.execute to communicate its args
228 back to the caller.
229 """
231 def __init__(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree, driver: _TestQueryDriver) -> None:
232 self.result_spec = result_spec
233 self.tree = tree
234 self.driver = driver
237class _TestQueryCount(BaseException):
238 """Exception raised by _TestQueryDriver.count to communicate its args
239 back to the caller.
240 """
242 def __init__(
243 self,
244 result_spec: qrs.ResultSpec,
245 tree: qt.QueryTree,
246 driver: _TestQueryDriver,
247 exact: bool,
248 discard: bool,
249 ) -> None:
250 self.result_spec = result_spec
251 self.tree = tree
252 self.driver = driver
253 self.exact = exact
254 self.discard = discard
257class _TestQueryAny(BaseException):
258 """Exception raised by _TestQueryDriver.any to communicate its args
259 back to the caller.
260 """
262 def __init__(
263 self,
264 tree: qt.QueryTree,
265 driver: _TestQueryDriver,
266 exact: bool,
267 execute: bool,
268 ) -> None:
269 self.tree = tree
270 self.driver = driver
271 self.exact = exact
272 self.execute = execute
275class _TestQueryExplainNoResults(BaseException):
276 """Exception raised by _TestQueryDriver.explain_no_results to communicate
277 its args back to the caller.
278 """
280 def __init__(
281 self,
282 tree: qt.QueryTree,
283 driver: _TestQueryDriver,
284 execute: bool,
285 ) -> None:
286 self.tree = tree
287 self.driver = driver
288 self.execute = execute
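# Usage sketch (illustrative): the test cases below recover the arguments the
# driver received by catching these exceptions, along the lines of
#
#     with self.assertRaises(_TestQueryExecution) as cm:
#         list(query.data_ids())
#     tree = cm.exception.tree                # QueryTree passed to execute()
#     result_spec = cm.exception.result_spec  # ResultSpec passed to execute()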
291class _TestQueryDriver(qd.QueryDriver):
292 """Mock implementation of `QueryDriver` that mostly raises exceptions that
293 communicate the arguments its methods were called with.
295 Parameters
296 ----------
297 default_collections : `tuple` [ `str`, ... ], optional
298 Default collection the query or parent butler is imagined to have been
299 constructed with.
300 collection_info : `~collections.abc.Mapping`, optional
301 Mapping from collection name to its record and summary, simulating the
302 collections present in the data repository.
303 dataset_types : `~collections.abc.Mapping`, optional
304 Mapping from dataset type to its definition, simulating the dataset
305 types registered in the data repository.
306 result_rows : `tuple` [ `~collections.abc.Iterable`, ... ], optional
307 A tuple of iterables of arbitrary type to use as result rows any time
308 `execute` is called, with each nested iterable considered a separate
309 page. The result type is not checked for consistency with the result
310 spec. If this is not provided, `execute` will instead raise
311 `_TestQueryExecution`, and `fetch_page` will not do anything useful.
312 """
314 def __init__(
315 self,
316 default_collections: tuple[str, ...] | None = None,
317 collection_info: Mapping[str, tuple[CollectionRecord, CollectionSummary]] | None = None,
318 dataset_types: Mapping[str, DatasetType] | None = None,
319 result_rows: tuple[Iterable[Any], ...] | None = None,
320 ) -> None:
321 self._universe = DimensionUniverse()
322 # Mapping of the arguments passed to materialize, keyed by the UUID
323 # that each call returned.
324 self.materializations: dict[
325 qd.MaterializationKey, tuple[qt.QueryTree, DimensionGroup, frozenset[str]]
326 ] = {}
327 # Mapping of the arguments passed to upload_data_coordinates, keyed by
328 # the UUID that each call returned.
329 self.data_coordinate_uploads: dict[
330 qd.DataCoordinateUploadKey, tuple[DimensionGroup, list[tuple[DataIdValue, ...]]]
331 ] = {}
332 self._default_collections = default_collections
333 self._collection_info = collection_info or {}
334 self._dataset_types = dataset_types or {}
335 self._executions: list[tuple[qrs.ResultSpec, qt.QueryTree]] = []
336 self._result_rows = result_rows
337 self._result_iters: dict[qd.PageKey, tuple[Iterable[Any], Iterator[Iterable[Any]]]] = {}
339 @property
340 def universe(self) -> DimensionUniverse:
341 return self._universe
343 def __enter__(self) -> None:
344 pass
346 def __exit__(self, *args: Any, **kwargs: Any) -> None:
347 pass
349 def execute(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree) -> qd.ResultPage:
350 if self._result_rows is not None:
351 iterator = iter(self._result_rows)
352 current_rows = next(iterator, ())
353 return self._make_next_page(result_spec, current_rows, iterator)
354 raise _TestQueryExecution(result_spec, tree, self)
356 def fetch_next_page(self, result_spec: qrs.ResultSpec, key: qd.PageKey) -> qd.ResultPage:
357 if self._result_rows is not None:
358 return self._make_next_page(result_spec, *self._result_iters.pop(key))
359 raise AssertionError("Test query driver not initialized for actual results.")
361 def _make_next_page(
362 self, result_spec: qrs.ResultSpec, current_rows: Iterable[Any], iterator: Iterator[Iterable[Any]]
363 ) -> qd.ResultPage:
364 next_rows = list(next(iterator, ()))
365 if not next_rows:
366 next_key = None
367 else:
368 next_key = uuid.uuid4()
369 self._result_iters[next_key] = (next_rows, iterator)
370 match result_spec:
371 case qrs.DataCoordinateResultSpec():
372 return qd.DataCoordinateResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
373 case qrs.DimensionRecordResultSpec():
374 return qd.DimensionRecordResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
375 case qrs.DatasetRefResultSpec():
376 return qd.DatasetRefResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
377 case _:
378 raise NotImplementedError("Other query types not yet supported.")
380 def materialize(
381 self,
382 tree: qt.QueryTree,
383 dimensions: DimensionGroup,
384 datasets: frozenset[str],
385 ) -> qd.MaterializationKey:
386 key = uuid.uuid4()
387 self.materializations[key] = (tree, dimensions, datasets)
388 return key
390 def upload_data_coordinates(
391 self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]]
392 ) -> qd.DataCoordinateUploadKey:
393 key = uuid.uuid4()
394 self.data_coordinate_uploads[key] = (dimensions, frozenset(rows))
395 return key
397 def count(
398 self,
399 tree: qt.QueryTree,
400 result_spec: qrs.ResultSpec,
401 *,
402 exact: bool,
403 discard: bool,
404 ) -> int:
405 raise _TestQueryCount(result_spec, tree, self, exact, discard)
407 def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool:
408 raise _TestQueryAny(tree, self, exact, execute)
410 def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]:
411 raise _TestQueryExplainNoResults(tree, self, execute)
413 def get_default_collections(self) -> tuple[str, ...]:
414 if self._default_collections is None:
415 raise NoDefaultCollectionError()
416 return self._default_collections
418 def get_dataset_type(self, name: str) -> DatasetType:
419 try:
420 return self._dataset_types[name]
421 except KeyError:
422 raise MissingDatasetTypeError(name)
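# Usage sketch (illustrative; `page_1_rows` and `page_2_rows` are placeholder
# row iterables): by default the driver raises the _TestQuery* exceptions
# above, but when constructed with `result_rows` it serves those rows back
# one page at a time instead:
#
#     driver = _TestQueryDriver(result_rows=(page_1_rows, page_2_rows))
#     query = Query(driver, qt.make_identity_query_tree(DimensionUniverse()))
#     # Iterating a result object built from `query` then exercises execute()
#     # and fetch_next_page() rather than raising _TestQueryExecution.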
425class ColumnExpressionsTestCase(unittest.TestCase):
426 """Tests for column expression objects in lsst.daf.butler.queries.tree."""
428 def setUp(self) -> None:
429 self.universe = DimensionUniverse()
430 self.x = ExpressionFactory(self.universe)
432 def query(self, **kwargs: Any) -> Query:
433 """Make an initial Query object with the given kwargs used to
434 initialize the _TestQueryDriver.
435 """
436 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe))
438 def test_int_literals(self) -> None:
439 expr = self.x.unwrap(self.x.literal(5))
440 self.assertEqual(expr.value, 5)
441 self.assertEqual(expr.get_literal_value(), 5)
442 self.assertEqual(expr.expression_type, "int")
443 self.assertEqual(expr.column_type, "int")
444 self.assertEqual(str(expr), "5")
445 self.assertTrue(expr.is_literal)
446 columns = qt.ColumnSet(self.universe.empty.as_group())
447 expr.gather_required_columns(columns)
448 self.assertFalse(columns)
449 self.assertEqual(expr.visit(_TestVisitor()), 5)
451 def test_string_literals(self) -> None:
452 expr = self.x.unwrap(self.x.literal("five"))
453 self.assertEqual(expr.value, "five")
454 self.assertEqual(expr.get_literal_value(), "five")
455 self.assertEqual(expr.expression_type, "string")
456 self.assertEqual(expr.column_type, "string")
457 self.assertEqual(str(expr), "'five'")
458 self.assertTrue(expr.is_literal)
459 columns = qt.ColumnSet(self.universe.empty.as_group())
460 expr.gather_required_columns(columns)
461 self.assertFalse(columns)
462 self.assertEqual(expr.visit(_TestVisitor()), "five")
464 def test_float_literals(self) -> None:
465 expr = self.x.unwrap(self.x.literal(0.5))
466 self.assertEqual(expr.value, 0.5)
467 self.assertEqual(expr.get_literal_value(), 0.5)
468 self.assertEqual(expr.expression_type, "float")
469 self.assertEqual(expr.column_type, "float")
470 self.assertEqual(str(expr), "0.5")
471 self.assertTrue(expr.is_literal)
472 columns = qt.ColumnSet(self.universe.empty.as_group())
473 expr.gather_required_columns(columns)
474 self.assertFalse(columns)
475 self.assertEqual(expr.visit(_TestVisitor()), 0.5)
477 def test_hash_literals(self) -> None:
478 expr = self.x.unwrap(self.x.literal(b"eleven"))
479 self.assertEqual(expr.value, b"eleven")
480 self.assertEqual(expr.get_literal_value(), b"eleven")
481 self.assertEqual(expr.expression_type, "hash")
482 self.assertEqual(expr.column_type, "hash")
483 self.assertEqual(str(expr), "(bytes)")
484 self.assertTrue(expr.is_literal)
485 columns = qt.ColumnSet(self.universe.empty.as_group())
486 expr.gather_required_columns(columns)
487 self.assertFalse(columns)
488 self.assertEqual(expr.visit(_TestVisitor()), b"eleven")
490 def test_uuid_literals(self) -> None:
491 value = uuid.uuid4()
492 expr = self.x.unwrap(self.x.literal(value))
493 self.assertEqual(expr.value, value)
494 self.assertEqual(expr.get_literal_value(), value)
495 self.assertEqual(expr.expression_type, "uuid")
496 self.assertEqual(expr.column_type, "uuid")
497 self.assertEqual(str(expr), str(value))
498 self.assertTrue(expr.is_literal)
499 columns = qt.ColumnSet(self.universe.empty.as_group())
500 expr.gather_required_columns(columns)
501 self.assertFalse(columns)
502 self.assertEqual(expr.visit(_TestVisitor()), value)
504 def test_datetime_literals(self) -> None:
505 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
506 expr = self.x.unwrap(self.x.literal(value))
507 self.assertEqual(expr.value, value)
508 self.assertEqual(expr.get_literal_value(), value)
509 self.assertEqual(expr.expression_type, "datetime")
510 self.assertEqual(expr.column_type, "datetime")
511 self.assertEqual(str(expr), "2020-01-01T00:00:00")
512 self.assertTrue(expr.is_literal)
513 columns = qt.ColumnSet(self.universe.empty.as_group())
514 expr.gather_required_columns(columns)
515 self.assertFalse(columns)
516 self.assertEqual(expr.visit(_TestVisitor()), value)
518 def test_timespan_literals(self) -> None:
519 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
520 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
521 value = Timespan(begin, end)
522 expr = self.x.unwrap(self.x.literal(value))
523 self.assertEqual(expr.value, value)
524 self.assertEqual(expr.get_literal_value(), value)
525 self.assertEqual(expr.expression_type, "timespan")
526 self.assertEqual(expr.column_type, "timespan")
527 self.assertEqual(str(expr), "[2020-01-01T00:00:00, 2020-01-01T00:01:00)")
528 self.assertTrue(expr.is_literal)
529 columns = qt.ColumnSet(self.universe.empty.as_group())
530 expr.gather_required_columns(columns)
531 self.assertFalse(columns)
532 self.assertEqual(expr.visit(_TestVisitor()), value)
534 def test_region_literals(self) -> None:
535 pixelization = Mq3cPixelization(10)
536 value = pixelization.quad(12058870)
537 expr = self.x.unwrap(self.x.literal(value))
538 self.assertEqual(expr.value, value)
539 self.assertEqual(expr.get_literal_value(), value)
540 self.assertEqual(expr.expression_type, "region")
541 self.assertEqual(expr.column_type, "region")
542 self.assertEqual(str(expr), "(region)")
543 self.assertTrue(expr.is_literal)
544 columns = qt.ColumnSet(self.universe.empty.as_group())
545 expr.gather_required_columns(columns)
546 self.assertFalse(columns)
547 self.assertEqual(expr.visit(_TestVisitor()), value)
549 def test_invalid_literal(self) -> None:
550 with self.assertRaisesRegex(TypeError, "Invalid type 'complex' of value 5j for column literal."):
551 self.x.literal(5j)
553 def test_dimension_key_reference(self) -> None:
554 expr = self.x.unwrap(self.x.detector)
555 self.assertIsNone(expr.get_literal_value())
556 self.assertEqual(expr.expression_type, "dimension_key")
557 self.assertEqual(expr.column_type, "int")
558 self.assertEqual(str(expr), "detector")
559 self.assertFalse(expr.is_literal)
560 columns = qt.ColumnSet(self.universe.empty.as_group())
561 expr.gather_required_columns(columns)
562 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
563 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 3})), 3)
565 def test_dimension_field_reference(self) -> None:
566 expr = self.x.unwrap(self.x.detector.purpose)
567 self.assertIsNone(expr.get_literal_value())
568 self.assertEqual(expr.expression_type, "dimension_field")
569 self.assertEqual(expr.column_type, "string")
570 self.assertEqual(str(expr), "detector.purpose")
571 self.assertFalse(expr.is_literal)
572 columns = qt.ColumnSet(self.universe.empty.as_group())
573 expr.gather_required_columns(columns)
574 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
575 self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
576 with self.assertRaises(qt.InvalidQueryError):
577 qt.DimensionFieldReference(element=self.universe.dimensions["detector"], field="region")
578 self.assertEqual(
579 expr.visit(_TestVisitor(dimension_fields={("detector", "purpose"): "science"})), "science"
580 )
582 def test_dataset_field_reference(self) -> None:
583 expr = self.x.unwrap(self.x["raw"].ingest_date)
584 self.assertIsNone(expr.get_literal_value())
585 self.assertEqual(expr.expression_type, "dataset_field")
586 self.assertEqual(str(expr), "raw.ingest_date")
587 self.assertFalse(expr.is_literal)
588 columns = qt.ColumnSet(self.universe.empty.as_group())
589 expr.gather_required_columns(columns)
590 self.assertEqual(columns.dimensions, self.universe.empty.as_group())
591 self.assertEqual(columns.dataset_fields["raw"], {"ingest_date"})
592 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="dataset_id").column_type, "uuid")
593 self.assertEqual(
594 qt.DatasetFieldReference(dataset_type="raw", field="collection").column_type, "string"
595 )
596 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="run").column_type, "string")
597 self.assertEqual(
598 qt.DatasetFieldReference(dataset_type="raw", field="ingest_date").column_type, "datetime"
599 )
600 self.assertEqual(
601 qt.DatasetFieldReference(dataset_type="raw", field="timespan").column_type, "timespan"
602 )
603 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
604 self.assertEqual(expr.visit(_TestVisitor(dataset_fields={("raw", "ingest_date"): value})), value)
606 def test_unary_negation(self) -> None:
607 expr = self.x.unwrap(-self.x.visit.exposure_time)
608 self.assertIsNone(expr.get_literal_value())
609 self.assertEqual(expr.expression_type, "unary")
610 self.assertEqual(expr.column_type, "float")
611 self.assertEqual(str(expr), "-visit.exposure_time")
612 self.assertFalse(expr.is_literal)
613 columns = qt.ColumnSet(self.universe.empty.as_group())
614 expr.gather_required_columns(columns)
615 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
616 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"})
617 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 2.0})), -2.0)
618 with self.assertRaises(qt.InvalidQueryError):
619 qt.UnaryExpression(
620 operand=qt.DimensionFieldReference(
621 element=self.universe.dimensions["detector"], field="purpose"
622 ),
623 operator="-",
624 )
626 def test_unary_timespan_begin(self) -> None:
627 expr = self.x.unwrap(self.x.visit.timespan.begin)
628 self.assertIsNone(expr.get_literal_value())
629 self.assertEqual(expr.expression_type, "unary")
630 self.assertEqual(expr.column_type, "datetime")
631 self.assertEqual(str(expr), "visit.timespan.begin")
632 self.assertFalse(expr.is_literal)
633 columns = qt.ColumnSet(self.universe.empty.as_group())
634 expr.gather_required_columns(columns)
635 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
636 self.assertEqual(columns.dimension_fields["visit"], {"timespan"})
637 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
638 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
639 value = Timespan(begin, end)
640 self.assertEqual(
641 expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.begin
642 )
643 with self.assertRaises(qt.InvalidQueryError):
644 qt.UnaryExpression(
645 operand=qt.DimensionFieldReference(
646 element=self.universe.dimensions["detector"], field="purpose"
647 ),
648 operator="begin_of",
649 )
651 def test_unary_timespan_end(self) -> None:
652 expr = self.x.unwrap(self.x.visit.timespan.end)
653 self.assertIsNone(expr.get_literal_value())
654 self.assertEqual(expr.expression_type, "unary")
655 self.assertEqual(expr.column_type, "datetime")
656 self.assertEqual(str(expr), "visit.timespan.end")
657 self.assertFalse(expr.is_literal)
658 columns = qt.ColumnSet(self.universe.empty.as_group())
659 expr.gather_required_columns(columns)
660 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
661 self.assertEqual(columns.dimension_fields["visit"], {"timespan"})
662 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
663 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
664 value = Timespan(begin, end)
665 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.end)
666 with self.assertRaises(qt.InvalidQueryError):
667 qt.UnaryExpression(
668 operand=qt.DimensionFieldReference(
669 element=self.universe.dimensions["detector"], field="purpose"
670 ),
671 operator="end_of",
672 )
674 def test_binary_expression_float(self) -> None:
675 for proxy, string, value in [
676 (self.x.visit.exposure_time + 15.0, "visit.exposure_time + 15.0", 45.0),
677 (self.x.visit.exposure_time - 10.0, "visit.exposure_time - 10.0", 20.0),
678 (self.x.visit.exposure_time * 6.0, "visit.exposure_time * 6.0", 180.0),
679 (self.x.visit.exposure_time / 30.0, "visit.exposure_time / 30.0", 1.0),
680 (15.0 + -self.x.visit.exposure_time, "15.0 + (-visit.exposure_time)", -15.0),
681 (10.0 - -self.x.visit.exposure_time, "10.0 - (-visit.exposure_time)", 40.0),
682 (6.0 * -self.x.visit.exposure_time, "6.0 * (-visit.exposure_time)", -180.0),
683 (30.0 / -self.x.visit.exposure_time, "30.0 / (-visit.exposure_time)", -1.0),
684 ((self.x.visit.exposure_time + 15.0) * 6.0, "(visit.exposure_time + 15.0) * 6.0", 270.0),
685 ((self.x.visit.exposure_time + 15.0) + 45.0, "(visit.exposure_time + 15.0) + 45.0", 90.0),
686 ((self.x.visit.exposure_time + 15.0) / 5.0, "(visit.exposure_time + 15.0) / 5.0", 9.0),
687 # We don't need the parentheses we generate in the next one, but
688 # they're not a problem either.
689 ((self.x.visit.exposure_time + 15.0) - 60.0, "(visit.exposure_time + 15.0) - 60.0", -15.0),
690 (6.0 * (-self.x.visit.exposure_time - 15.0), "6.0 * ((-visit.exposure_time) - 15.0)", -270.0),
691 (60.0 + (-self.x.visit.exposure_time - 15.0), "60.0 + ((-visit.exposure_time) - 15.0)", 15.0),
692 (90.0 / (-self.x.visit.exposure_time - 15.0), "90.0 / ((-visit.exposure_time) - 15.0)", -2.0),
693 (60.0 - (-self.x.visit.exposure_time - 15.0), "60.0 - ((-visit.exposure_time) - 15.0)", 105.0),
694 ]:
695 with self.subTest(string=string):
696 expr = self.x.unwrap(proxy)
697 self.assertIsNone(expr.get_literal_value())
698 self.assertEqual(expr.expression_type, "binary")
699 self.assertEqual(expr.column_type, "float")
700 self.assertEqual(str(expr), string)
701 self.assertFalse(expr.is_literal)
702 columns = qt.ColumnSet(self.universe.empty.as_group())
703 expr.gather_required_columns(columns)
704 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
705 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"})
706 self.assertEqual(
707 expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 30.0})), value
708 )
710 def test_binary_modulus(self) -> None:
711 for proxy, string, value in [
712 (self.x.visit.id % 2, "visit % 2", 1),
713 (52 % self.x.visit, "52 % visit", 2),
714 ]:
715 with self.subTest(string=string):
716 expr = self.x.unwrap(proxy)
717 self.assertIsNone(expr.get_literal_value())
718 self.assertEqual(expr.expression_type, "binary")
719 self.assertEqual(expr.column_type, "int")
720 self.assertEqual(str(expr), string)
721 self.assertFalse(expr.is_literal)
722 columns = qt.ColumnSet(self.universe.empty.as_group())
723 expr.gather_required_columns(columns)
724 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
725 self.assertFalse(columns.dimension_fields["visit"])
726 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"visit": 5})), value)
728 def test_binary_expression_validation(self) -> None:
729 with self.assertRaises(qt.InvalidQueryError):
730 # No arithmetic operators on strings (we do not interpret + as
731 # concatenation).
732 self.x.instrument + "suffix"
733 with self.assertRaises(qt.InvalidQueryError):
734 # Mixed types are not supported, even when they both support the
735 # operator.
736 self.x.visit.exposure_time + self.x.detector
737 with self.assertRaises(qt.InvalidQueryError):
738 # No modulus for floats.
739 self.x.visit.exposure_time % 5.0
741 def test_reversed(self) -> None:
742 expr = self.x.detector.desc
743 self.assertIsNone(expr.get_literal_value())
744 self.assertEqual(expr.expression_type, "reversed")
745 self.assertEqual(expr.column_type, "int")
746 self.assertEqual(str(expr), "detector DESC")
747 self.assertFalse(expr.is_literal)
748 columns = qt.ColumnSet(self.universe.empty.as_group())
749 expr.gather_required_columns(columns)
750 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
751 self.assertFalse(columns.dimension_fields["detector"])
752 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 5})), _TestReversed(5))
754 def test_trivial_predicate(self) -> None:
755 """Test logical operations on trivial True/False predicates."""
756 yes = qt.Predicate.from_bool(True)
757 no = qt.Predicate.from_bool(False)
758 maybe: qt.Predicate = self.x.detector == 5
759 for predicate in [
760 yes,
761 yes.logical_or(no),
762 no.logical_or(yes),
763 yes.logical_and(yes),
764 no.logical_not(),
765 yes.logical_or(maybe),
766 maybe.logical_or(yes),
767 ]:
768 self.assertEqual(predicate.column_type, "bool")
769 self.assertEqual(str(predicate), "True")
770 self.assertTrue(predicate.visit(_TestVisitor()))
771 self.assertEqual(predicate.operands, ())
772 for predicate in [
773 no,
774 yes.logical_and(no),
775 no.logical_and(yes),
776 no.logical_or(no),
777 yes.logical_not(),
778 no.logical_and(maybe),
779 maybe.logical_and(no),
780 ]:
781 self.assertEqual(predicate.column_type, "bool")
782 self.assertEqual(str(predicate), "False")
783 self.assertFalse(predicate.visit(_TestVisitor()))
784 self.assertEqual(predicate.operands, ((),))
785 for predicate in [
786 maybe,
787 yes.logical_and(maybe),
788 no.logical_or(maybe),
789 maybe.logical_not().logical_not(),
790 ]:
791 self.assertEqual(predicate.column_type, "bool")
792 self.assertEqual(str(predicate), "detector == 5")
793 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"detector": 5})))
794 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"detector": 4})))
795 self.assertEqual(len(predicate.operands), 1)
796 self.assertEqual(len(predicate.operands[0]), 1)
797 self.assertIs(predicate.operands[0][0], maybe.operands[0][0])
799 def test_comparison(self) -> None:
800 predicate: qt.Predicate
801 string: str
802 value: bool
803 for detector in (4, 5, 6):
804 for predicate, string, value in [
805 (self.x.detector == 5, "detector == 5", detector == 5),
806 (self.x.detector != 5, "detector != 5", detector != 5),
807 (self.x.detector < 5, "detector < 5", detector < 5),
808 (self.x.detector > 5, "detector > 5", detector > 5),
809 (self.x.detector <= 5, "detector <= 5", detector <= 5),
810 (self.x.detector >= 5, "detector >= 5", detector >= 5),
811 (self.x.detector == 5, "detector == 5", detector == 5),
812 (self.x.detector != 5, "detector != 5", detector != 5),
813 (self.x.detector < 5, "detector < 5", detector < 5),
814 (self.x.detector > 5, "detector > 5", detector > 5),
815 (self.x.detector <= 5, "detector <= 5", detector <= 5),
816 (self.x.detector >= 5, "detector >= 5", detector >= 5),
817 ]:
818 with self.subTest(string=string, detector=detector):
819 self.assertEqual(predicate.column_type, "bool")
820 self.assertEqual(str(predicate), string)
821 columns = qt.ColumnSet(self.universe.empty.as_group())
822 predicate.gather_required_columns(columns)
823 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
824 self.assertFalse(columns.dimension_fields["detector"])
825 self.assertEqual(
826 predicate.visit(_TestVisitor(dimension_keys={"detector": detector})), value
827 )
828 inverted = predicate.logical_not()
829 self.assertEqual(inverted.column_type, "bool")
830 self.assertEqual(str(inverted), f"NOT {string}")
831 self.assertEqual(
832 inverted.visit(_TestVisitor(dimension_keys={"detector": detector})), not value
833 )
834 columns = qt.ColumnSet(self.universe.empty.as_group())
835 inverted.gather_required_columns(columns)
836 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
837 self.assertFalse(columns.dimension_fields["detector"])
839 def test_overlap_comparison(self) -> None:
840 pixelization = Mq3cPixelization(10)
841 region1 = pixelization.quad(12058870)
842 predicate = self.x.visit.region.overlaps(region1)
843 self.assertEqual(predicate.column_type, "bool")
844 self.assertEqual(str(predicate), "visit.region OVERLAPS (region)")
845 columns = qt.ColumnSet(self.universe.empty.as_group())
846 predicate.gather_required_columns(columns)
847 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
848 self.assertEqual(columns.dimension_fields["visit"], {"region"})
849 region2 = pixelization.quad(12058857)
850 self.assertFalse(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): region2})))
851 inverted = predicate.logical_not()
852 self.assertEqual(inverted.column_type, "bool")
853 self.assertEqual(str(inverted), "NOT visit.region OVERLAPS (region)")
854 self.assertTrue(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): region2})))
855 columns = qt.ColumnSet(self.universe.empty.as_group())
856 inverted.gather_required_columns(columns)
857 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
858 self.assertEqual(columns.dimension_fields["visit"], {"region"})
860 def test_invalid_comparison(self) -> None:
861 # Mixed type comparisons.
862 with self.assertRaises(qt.InvalidQueryError):
863 self.x.visit > "three"
864 with self.assertRaises(qt.InvalidQueryError):
865 self.x.visit > 3.0
866 # Invalid operator for type.
867 with self.assertRaises(qt.InvalidQueryError):
868 self.x["raw"].dataset_id < uuid.uuid4()
870 def test_is_null(self) -> None:
871 predicate = self.x.visit.region.is_null
872 self.assertEqual(predicate.column_type, "bool")
873 self.assertEqual(str(predicate), "visit.region IS NULL")
874 columns = qt.ColumnSet(self.universe.empty.as_group())
875 predicate.gather_required_columns(columns)
876 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
877 self.assertEqual(columns.dimension_fields["visit"], {"region"})
878 self.assertTrue(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): None})))
879 inverted = predicate.logical_not()
880 self.assertEqual(inverted.column_type, "bool")
881 self.assertEqual(str(inverted), "NOT visit.region IS NULL")
882 self.assertFalse(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): None})))
883 inverted.gather_required_columns(columns)
884 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
885 self.assertEqual(columns.dimension_fields["visit"], {"region"})
887 def test_in_container(self) -> None:
888 predicate: qt.Predicate = self.x.visit.in_iterable([3, 4, self.x.exposure.id])
889 self.assertEqual(predicate.column_type, "bool")
890 self.assertEqual(str(predicate), "visit IN [3, 4, exposure]")
891 columns = qt.ColumnSet(self.universe.empty.as_group())
892 predicate.gather_required_columns(columns)
893 self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
894 self.assertFalse(columns.dimension_fields["visit"])
895 self.assertFalse(columns.dimension_fields["exposure"])
896 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
897 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
898 inverted = predicate.logical_not()
899 self.assertEqual(inverted.column_type, "bool")
900 self.assertEqual(str(inverted), "NOT visit IN [3, 4, exposure]")
901 self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
902 self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
903 columns = qt.ColumnSet(self.universe.empty.as_group())
904 inverted.gather_required_columns(columns)
905 self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
906 self.assertFalse(columns.dimension_fields["visit"])
907 self.assertFalse(columns.dimension_fields["exposure"])
908 with self.assertRaises(qt.InvalidQueryError):
909 # Regions (and timespans) not allowed in IN expressions, since that
910 # suggests topological logic we're not actually doing. We can't
911 # use ExpressionFactory because it prohibits this case with typing.
912 pixelization = Mq3cPixelization(10)
913 region = pixelization.quad(12058870)
914 qt.Predicate.in_container(self.x.unwrap(self.x.visit.region), [qt.make_column_literal(region)])
915 with self.assertRaises(qt.InvalidQueryError):
916 # Mismatched types.
917 self.x.visit.in_iterable([3.5, 2.1])
919 def test_in_range(self) -> None:
920 predicate: qt.Predicate = self.x.visit.in_range(2, 8, 2)
921 self.assertEqual(predicate.column_type, "bool")
922 self.assertEqual(str(predicate), "visit IN 2:8:2")
923 columns = qt.ColumnSet(self.universe.empty.as_group())
924 predicate.gather_required_columns(columns)
925 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
926 self.assertFalse(columns.dimension_fields["visit"])
927 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2})))
928 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 8})))
929 inverted = predicate.logical_not()
930 self.assertEqual(inverted.column_type, "bool")
931 self.assertEqual(str(inverted), "NOT visit IN 2:8:2")
932 self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2})))
933 self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 8})))
934 columns = qt.ColumnSet(self.universe.empty.as_group())
935 inverted.gather_required_columns(columns)
936 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
937 self.assertFalse(columns.dimension_fields["visit"])
938 with self.assertRaises(qt.InvalidQueryError):
939 # Only integer fields allowed.
940 self.x.visit.exposure_time.in_range(2, 4)
941 with self.assertRaises(qt.InvalidQueryError):
942 # Step must be positive.
943 self.x.visit.in_range(2, 4, -1)
944 with self.assertRaises(qt.InvalidQueryError):
945 # Stop must be >= start.
946 self.x.visit.in_range(2, 0)
948 def test_in_query(self) -> None:
949 query = self.query().join_dimensions(["visit", "tract"]).where(skymap="s", tract=3)
950 predicate: qt.Predicate = self.x.exposure.in_query(self.x.visit, query)
951 self.assertEqual(predicate.column_type, "bool")
952 self.assertEqual(str(predicate), "exposure IN (query).visit")
953 columns = qt.ColumnSet(self.universe.empty.as_group())
954 predicate.gather_required_columns(columns)
955 self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
956 self.assertFalse(columns.dimension_fields["exposure"])
957 self.assertTrue(
958 predicate.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
959 )
960 self.assertFalse(
961 predicate.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
962 )
963 inverted = predicate.logical_not()
964 self.assertEqual(inverted.column_type, "bool")
965 self.assertEqual(str(inverted), "NOT exposure IN (query).visit")
966 self.assertFalse(
967 inverted.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
968 )
969 self.assertTrue(
970 inverted.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
971 )
972 columns = qt.ColumnSet(self.universe.empty.as_group())
973 inverted.gather_required_columns(columns)
974 self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
975 self.assertFalse(columns.dimension_fields["exposure"])
976 with self.assertRaises(qt.InvalidQueryError):
977 # Regions (and timespans) not allowed in IN expressions, since that
978 # suggests topological logic we're not actually doing. We can't
979 # use ExpressionFactory because it prohibits this case with typing.
980 qt.Predicate.in_query(
981 self.x.unwrap(self.x.visit.region), self.x.unwrap(self.x.tract.region), query._tree
982 )
983 with self.assertRaises(qt.InvalidQueryError):
984 # Mismatched types.
985 self.x.exposure.in_query(self.x.visit.exposure_time, query)
986 with self.assertRaises(qt.InvalidQueryError):
987 # Query column requires dimensions that are not in the query.
988 self.x.exposure.in_query(self.x.patch, query)
989 with self.assertRaises(qt.InvalidQueryError):
990 # Query column requires dataset type that is not in the query.
991 self.x["raw"].dataset_id.in_query(self.x["raw"].dataset_id, query)
993 def test_complex_predicate(self) -> None:
994 """Test that predicates are converted to conjunctive normal form and
995 get parentheses in the right places when stringified.
996 """
997 visitor = _TestVisitor(dimension_keys={"instrument": "i", "detector": 3, "visit": 6, "band": "r"})
998 a: qt.Predicate = self.x.visit > 5 # will evaluate to True
999 b: qt.Predicate = self.x.detector != 3 # will evaluate to False
1000 c: qt.Predicate = self.x.instrument == "i" # will evaluate to True
1001 d: qt.Predicate = self.x.band == "g" # will evaluate to False
1002 predicate: qt.Predicate
1003 for predicate, string, value in [
1004 (a.logical_or(b), f"{a} OR {b}", True),
1005 (a.logical_or(c), f"{a} OR {c}", True),
1006 (b.logical_or(d), f"{b} OR {d}", False),
1007 (a.logical_and(b), f"{a} AND {b}", False),
1008 (a.logical_and(c), f"{a} AND {c}", True),
1009 (b.logical_and(d), f"{b} AND {d}", False),
1010 (self.x.any(a, b, c, d), f"{a} OR {b} OR {c} OR {d}", True),
1011 (self.x.all(a, b, c, d), f"{a} AND {b} AND {c} AND {d}", False),
1012 (a.logical_or(b).logical_and(c), f"({a} OR {b}) AND {c}", True),
1013 (a.logical_and(b.logical_or(d)), f"{a} AND ({b} OR {d})", False),
1014 (a.logical_and(b).logical_or(c), f"({a} OR {c}) AND ({b} OR {c})", True),
1015 (
1016 a.logical_and(b).logical_or(c.logical_and(d)),
1017 f"({a} OR {c}) AND ({a} OR {d}) AND ({b} OR {c}) AND ({b} OR {d})",
1018 False,
1019 ),
1020 (a.logical_or(b).logical_not(), f"NOT {a} AND NOT {b}", False),
1021 (a.logical_or(c).logical_not(), f"NOT {a} AND NOT {c}", False),
1022 (b.logical_or(d).logical_not(), f"NOT {b} AND NOT {d}", True),
1023 (a.logical_and(b).logical_not(), f"NOT {a} OR NOT {b}", True),
1024 (a.logical_and(c).logical_not(), f"NOT {a} OR NOT {c}", False),
1025 (b.logical_and(d).logical_not(), f"NOT {b} OR NOT {d}", True),
1026 (
1027 self.x.not_(a.logical_or(b).logical_and(c)),
1028 f"(NOT {a} OR NOT {c}) AND (NOT {b} OR NOT {c})",
1029 False,
1030 ),
1031 (
1032 a.logical_and(b.logical_or(d)).logical_not(),
1033 f"(NOT {a} OR NOT {b}) AND (NOT {a} OR NOT {d})",
1034 True,
1035 ),
1036 ]:
1037 with self.subTest(string=string):
1038 self.assertEqual(str(predicate), string)
1039 self.assertEqual(predicate.visit(visitor), value)
1041 def test_proxy_misc(self) -> None:
1042 """Test miscellaneous things on various ExpressionFactory proxies."""
1043 self.assertEqual(str(self.x.visit_detector_region), "visit_detector_region")
1044 self.assertEqual(str(self.x.visit.instrument), "instrument")
1045 self.assertEqual(str(self.x["raw"]), "raw")
1046 self.assertEqual(str(self.x["raw.ingest_date"]), "raw.ingest_date")
1047 self.assertEqual(
1048 str(self.x.visit.timespan.overlaps(self.x["raw"].timespan)),
1049 "visit.timespan OVERLAPS raw.timespan",
1050 )
1051 self.assertGreater(
1052 set(dir(self.x["raw"])), {"dataset_id", "ingest_date", "collection", "run", "timespan"}
1053 )
1054 self.assertGreater(set(dir(self.x.exposure)), {"seq_num", "science_program", "timespan"})
1055 with self.assertRaises(AttributeError):
1056 self.x["raw"].seq_num
1057 with self.assertRaises(AttributeError):
1058 self.x.visit.horse
1061class QueryTestCase(unittest.TestCase):
1062 """Tests for Query and *QueryResults objects in lsst.daf.butler.queries."""
1064 def setUp(self) -> None:
1065 self.maxDiff = None
1066 self.universe = DimensionUniverse()
1067 # We use ArrowTable as the storage class for all dataset types because
1068 # its conversions only require third-party packages that we already
1069 # depend on.
1070 self.raw = DatasetType(
1071 "raw", dimensions=self.universe.conform(["detector", "exposure"]), storageClass="ArrowTable"
1072 )
1073 self.refcat = DatasetType(
1074 "refcat", dimensions=self.universe.conform(["htm7"]), storageClass="ArrowTable"
1075 )
1076 self.bias = DatasetType(
1077 "bias",
1078 dimensions=self.universe.conform(["detector"]),
1079 storageClass="ArrowTable",
1080 isCalibration=True,
1081 )
1082 self.default_collections: list[str] | None = ["DummyCam/defaults"]
1083 self.collection_info: dict[str, tuple[CollectionRecord, CollectionSummary]] = {
1084 "DummyCam/raw/all": (
1085 RunRecord[int](1, name="DummyCam/raw/all"),
1086 CollectionSummary(NamedValueSet({self.raw}), governors={"instrument": {"DummyCam"}}),
1087 ),
1088 "DummyCam/calib": (
1089 CollectionRecord[int](2, name="DummyCam/calib", type=CollectionType.CALIBRATION),
1090 CollectionSummary(NamedValueSet({self.bias}), governors={"instrument": {"DummyCam"}}),
1091 ),
1092 "refcats": (
1093 RunRecord[int](3, name="refcats"),
1094 CollectionSummary(NamedValueSet({self.refcat}), governors={}),
1095 ),
1096 "DummyCam/defaults": (
1097 ChainedCollectionRecord[int](
1098 4, name="DummyCam/defaults", children=("DummyCam/raw/all", "DummyCam/calib", "refcats")
1099 ),
1100 CollectionSummary(
1101 NamedValueSet({self.raw, self.refcat, self.bias}), governors={"instrument": {"DummyCam"}}
1102 ),
1103 ),
1104 }
1105 self.dataset_types = {"raw": self.raw, "refcat": self.refcat, "bias": self.bias}
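    # Collection layout set up above (sketch):
    #
    #   DummyCam/defaults (CHAINED)
    #     DummyCam/raw/all  (RUN)          -- summary: raw
    #     DummyCam/calib    (CALIBRATION)  -- summary: bias
    #     refcats           (RUN)          -- summary: refcat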
1107 def query(self, **kwargs: Any) -> Query:
1108 """Make an initial Query object with the given kwargs used to
1109 initialize the _TestQueryDriver.
1111 The given kwargs override the test-case-attribute defaults.
1112 """
1113 kwargs.setdefault("default_collections", self.default_collections)
1114 kwargs.setdefault("collection_info", self.collection_info)
1115 kwargs.setdefault("dataset_types", self.dataset_types)
1116 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe))
1118 def test_dataset_join(self) -> None:
1119 """Test queries that have had a dataset search explicitly joined in via
1120 Query.join_dataset_search.
1122 Since this kind of query has a moderate amount of complexity, this is
1123 where we get a lot of basic coverage that applies to all kinds of
1124 queries, including:
1126 - getting data ID and dataset results (but not iterating over them);
1127 - the 'any' and 'explain_no_results' methods;
1128 - adding 'where' filters (but not expanding dimensions accordingly);
1129 - materializations.
1130 """
1132 def check(
1133 query: Query,
1134 dimensions: DimensionGroup = self.raw.dimensions.as_group(),
1135 ) -> None:
1136 """Run a battery of tests on one of a set of very similar queries
1137 constructed in different ways (see below).
1138 """
1140 def check_query_tree(
1141 tree: qt.QueryTree,
1142 dimensions: DimensionGroup = dimensions,
1143 ) -> None:
1144 """Check the state of the QueryTree object that backs the Query
1145 or a derived QueryResults object.
1147 Parameters
1148 ----------
1149 tree : `lsst.daf.butler.queries.tree.QueryTree`
1150 Object to test.
1151 dimensions : `DimensionGroup`
1152 Dimensions to expect in the `QueryTree`, not necessarily
1153 including those in the test 'raw' dataset type.
1154 """
1155 self.assertEqual(tree.dimensions, dimensions | self.raw.dimensions.as_group())
1156 self.assertEqual(str(tree.predicate), "raw.run == 'DummyCam/raw/all'")
1157 self.assertFalse(tree.materializations)
1158 self.assertFalse(tree.data_coordinate_uploads)
1159 self.assertEqual(tree.datasets.keys(), {"raw"})
1160 self.assertEqual(tree.datasets["raw"].dimensions, self.raw.dimensions.as_group())
1161 self.assertEqual(tree.datasets["raw"].collections, ("DummyCam/defaults",))
1162 self.assertEqual(
1163 tree.get_joined_dimension_groups(), frozenset({self.raw.dimensions.as_group()})
1164 )
1166 def check_data_id_results(*args, query: Query, dimensions: DimensionGroup = dimensions) -> None:
1167 """Construct a DataCoordinateQueryResults object from the query
1168 with the given arguments and run a battery of tests on it.
1170 Parameters
1171 ----------
1172 *args
1173 Forwarded to `Query.data_ids`.
1174 query : `Query`
1175 Query to start from.
1176 dimensions : `DimensionGroup`, optional
1177 Dimensions the result data IDs should have.
1178 """
1179 with self.assertRaises(_TestQueryExecution) as cm:
1180 list(query.data_ids(*args))
1181 self.assertEqual(
1182 cm.exception.result_spec,
1183 qrs.DataCoordinateResultSpec(dimensions=dimensions),
1184 )
1185 check_query_tree(cm.exception.tree, dimensions=dimensions)
1187 def check_dataset_results(
1188 *args: Any,
1189 query: Query,
1190 find_first: bool = True,
1191 storage_class_name: str = self.raw.storageClass_name,
1192 ) -> None:
1193 """Construct a DatasetRefQueryResults object from the query
1194 with the given arguments and run a battery of tests on it.
1196 Parameters
1197 ----------
1198 *args
1199 Forwarded to `Query.datasets`.
1200 query : `Query`
1201 Query to start from.
1202 find_first : `bool`, optional
1203 Whether to do find-first resolution on the results.
1204 storage_class_name : `str`, optional
1205 Expected name of the storage class for the results.
1206 """
1207 with self.assertRaises(_TestQueryExecution) as cm:
1208 list(query.datasets(*args, find_first=find_first))
1209 self.assertEqual(
1210 cm.exception.result_spec,
1211 qrs.DatasetRefResultSpec(
1212 dataset_type_name="raw",
1213 dimensions=self.raw.dimensions.as_group(),
1214 storage_class_name=storage_class_name,
1215 find_first=find_first,
1216 ),
1217 )
1218 check_query_tree(cm.exception.tree)
1220 def check_materialization(
1221 kwargs: Mapping[str, Any],
1222 query: Query,
1223 dimensions: DimensionGroup = dimensions,
1224 has_dataset: bool = True,
1225 ) -> None:
1226 """Materialize the query with the given arguments and run a
1227 battery of tests on the result.
1229 Parameters
1230 ----------
1231 kwargs
1232 Forwarded as keyword arguments to `Query.materialize`.
1233 query : `Query`
1234 Query to start from.
1235 dimensions : `DimensionGroup`, optional
1236 Dimensions to expect in the materialization and its derived
1237 query.
1238 has_dataset : `bool`, optional
1239 Whether the query backed by the materialization should
1240 still have the test 'raw' dataset joined in.
1241 """
1242 # Materialize the query and check the query tree sent to the
1243 # driver and the one in the materialized query.
1244 with self.assertRaises(_TestQueryExecution) as cm:
1245 list(query.materialize(**kwargs).data_ids())
1246 derived_tree = cm.exception.tree
1247 self.assertEqual(derived_tree.dimensions, dimensions)
1248 # Predicate should be materialized away; it no longer appears
1249 # in the derived query.
1250 self.assertEqual(str(derived_tree.predicate), "True")
1251 self.assertFalse(derived_tree.data_coordinate_uploads)
1252 if has_dataset:
1253 # Dataset search is still there, even though its existence
1254 # constraint is included in the materialization, because we
1255 # might need to re-join for some result columns in a
1256 # derived query.
1257 self.assertEqual(derived_tree.datasets.keys(), {"raw"})
1258 self.assertEqual(derived_tree.datasets["raw"].dimensions, self.raw.dimensions.as_group())
1259 self.assertEqual(derived_tree.datasets["raw"].collections, ("DummyCam/defaults",))
1260 else:
1261 self.assertFalse(derived_tree.datasets)
1262 ((key, derived_tree_materialized_dimensions),) = derived_tree.materializations.items()
1263 self.assertEqual(derived_tree_materialized_dimensions, dimensions)
1264 (
1265 materialized_tree,
1266 materialized_dimensions,
1267 materialized_datasets,
1268 ) = cm.exception.driver.materializations[key]
1269 self.assertEqual(derived_tree_materialized_dimensions, materialized_dimensions)
1270 if has_dataset:
1271 self.assertEqual(materialized_datasets, {"raw"})
1272 else:
1273 self.assertFalse(materialized_datasets)
1274 check_query_tree(materialized_tree)
1276 # Actual logic for the check() function begins here.
1278 self.assertEqual(query.constraint_dataset_types, {"raw"})
1279 self.assertEqual(query.constraint_dimensions, self.raw.dimensions.as_group())
1281 # Adding a constraint on a field for this dataset type should work
1282 # (this constraint will be present in all downstream tests).
1283 query = query.where(query.expression_factory["raw"].run == "DummyCam/raw/all")
1284 with self.assertRaises(qt.InvalidQueryError):
1285 # Adding constraint on a different dataset should not work.
1286 query.where(query.expression_factory["refcat"].run == "refcats")
1288 # Data IDs, with dimensions defaulted.
1289 check_data_id_results(query=query)
1290 # Dimensions for data IDs the same as defaults.
1291 check_data_id_results(["exposure", "detector"], query=query)
1292 # Dimensions are a subset of the query dimensions.
1293 check_data_id_results(["exposure"], query=query, dimensions=self.universe.conform(["exposure"]))
1294 # Dimensions are a superset of the query dimensions.
1295 check_data_id_results(
1296 ["exposure", "detector", "visit"],
1297 query=query,
1298 dimensions=self.universe.conform(["exposure", "detector", "visit"]),
1299 )
1300 # Dimensions are neither a superset nor a subset of the query
1301 # dimensions.
1302 check_data_id_results(
1303 ["detector", "visit"], query=query, dimensions=self.universe.conform(["visit", "detector"])
1304 )
1305 # Dimensions are empty.
1306 check_data_id_results([], query=query, dimensions=self.universe.conform([]))
1308 # Get DatasetRef results, with various arguments and defaulting.
1309 check_dataset_results("raw", query=query)
1310 check_dataset_results("raw", query=query, find_first=True)
1311 check_dataset_results("raw", ["DummyCam/defaults"], query=query)
1312 check_dataset_results("raw", ["DummyCam/defaults"], query=query, find_first=True)
1313 check_dataset_results(self.raw, query=query)
1314 check_dataset_results(self.raw, query=query, find_first=True)
1315 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query)
1316 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query, find_first=True)
1318 # Changing collections at this stage is not allowed.
1319 with self.assertRaises(qt.InvalidQueryError):
1320 query.datasets("raw", collections=["DummyCam/calib"])
1322 # Changing storage classes is allowed, if they're compatible.
1323 check_dataset_results(
1324 self.raw.overrideStorageClass("ArrowNumpy"), query=query, storage_class_name="ArrowNumpy"
1325 )
1326 with self.assertRaises(DatasetTypeError):
1327 # Can't use overrideStorageClass, because it'll raise
1328 # before the code we want to test can.
1329 query.datasets(DatasetType("raw", self.raw.dimensions, "int"))
1331 # Check the 'any' and 'explain_no_results' methods on Query itself.
1332 for execute, exact in itertools.permutations([False, True], 2):
1333 with self.assertRaises(_TestQueryAny) as cm:
1334 query.any(execute=execute, exact=exact)
1335 self.assertEqual(cm.exception.execute, execute)
1336 self.assertEqual(cm.exception.exact, exact)
1337 check_query_tree(cm.exception.tree, dimensions)
1338        with self.assertRaises(_TestQueryExplainNoResults) as cm:
1339            query.explain_no_results()
1340        check_query_tree(cm.exception.tree, dimensions)
1342 # Materialize the query with defaults.
1343 check_materialization({}, query=query)
1344 # Materialize the query with args that match defaults.
1345 check_materialization({"dimensions": ["exposure", "detector"], "datasets": {"raw"}}, query=query)
1346 # Materialize the query with a superset of the original dimensions.
1347 check_materialization(
1348 {"dimensions": ["exposure", "detector", "visit"]},
1349 query=query,
1350 dimensions=self.universe.conform(["exposure", "visit", "detector"]),
1351 )
1352 # Materialize the query with no datasets.
1353 check_materialization(
1354 {"dimensions": ["exposure", "detector"], "datasets": frozenset()},
1355 query=query,
1356 has_dataset=False,
1357 )
1358 # Materialize the query with no datasets and a subset of the
1359 # dimensions.
1360 check_materialization(
1361 {"dimensions": ["exposure"], "datasets": frozenset()},
1362 query=query,
1363 has_dataset=False,
1364 dimensions=self.universe.conform(["exposure"]),
1365 )
1366 # Materializing the query with a dataset that is not in the query
1367 # is an error.
1368 with self.assertRaises(qt.InvalidQueryError):
1369 query.materialize(datasets={"refcat"})
1370        # Materializing the query with dimensions that do not cover the
1371        # dimensions of a dataset being materialized is an error.
1372 with self.assertRaises(qt.InvalidQueryError):
1373 query.materialize(dimensions=["exposure"], datasets={"raw"})
1375 # Actual logic for test_dataset_joins starts here.
1377 # Default collections and existing dataset type name.
1378 check(self.query().join_dataset_search("raw"))
1379 # Default collections and existing DatasetType instance.
1380 check(self.query().join_dataset_search(self.raw))
1381 # Manual collections and existing dataset type.
1382 check(
1383 self.query(default_collections=None).join_dataset_search("raw", collections=["DummyCam/defaults"])
1384 )
1385 check(
1386 self.query(default_collections=None).join_dataset_search(
1387 self.raw, collections=["DummyCam/defaults"]
1388 )
1389 )
1390 with self.assertRaises(MissingDatasetTypeError):
1391 # Dataset type does not exist.
1392 self.query(dataset_types={}).join_dataset_search("raw", collections=["DummyCam/raw/all"])
1393 with self.assertRaises(DatasetTypeError):
1394 # Dataset type object with bad dimensions passed.
1395 self.query().join_dataset_search(
1396 DatasetType(
1397 "raw",
1398 dimensions={"detector", "visit"},
1399 storageClass=self.raw.storageClass_name,
1400 universe=self.universe,
1401 )
1402 )
1403 with self.assertRaises(TypeError):
1404 # Bad type for dataset type argument.
1405 self.query().join_dataset_search(3)
1406 with self.assertRaises(qt.InvalidQueryError):
1407 # Cannot pass storage class override to join_dataset_search,
1408 # because we cannot use it there.
1409 self.query().join_dataset_search(self.raw.overrideStorageClass("ArrowAstropy"))
1411 def test_dimension_record_results(self) -> None:
1412 """Test queries that return dimension records.
1414 This includes tests for:
1416 - joining against uploaded data coordinates;
1417 - counting result rows;
1418 - expanding dimensions as needed for 'where' conditions;
1419 - order_by and limit.
1421 It does not include the iteration methods of
1422 DimensionRecordQueryResults, since those require a different mock
1423 driver setup (see test_dimension_record_iteration).
1424 """
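        # Illustrative sketch only (not executed, and the exact entry point is
        # an assumption): against a real butler this flow would look roughly
        # like
        #
        #     with butler._query() as query:
        #         x = query.expression_factory
        #         records = (
        #             query.where(x.skymap == "m")
        #             .join_data_coordinates(upload_rows)
        #             .dimension_records("patch")
        #         )
        #
        # Here the same interface is driven against the mock QueryDriver so
        # the query tree it receives can be inspected.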
1425 # Set up the base query-results object to test.
1426 query = self.query()
1427 x = query.expression_factory
1428 self.assertFalse(query.constraint_dimensions)
1429 query = query.where(x.skymap == "m")
1430 self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap"]))
1431 upload_rows = [
1432 DataCoordinate.standardize(instrument="DummyCam", visit=3, universe=self.universe),
1433 DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe),
1434 ]
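        # The mock driver is expected to record each data-coordinate upload as
        # (dimensions, frozenset of required-value tuples), so precompute that
        # payload for the assertion in check() below.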
1435 raw_rows = frozenset([data_id.required_values for data_id in upload_rows])
1436 query = query.join_data_coordinates(upload_rows)
1437 self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap", "visit"]))
1438 results = query.dimension_records("patch")
1439 results = results.where(x.tract == 4)
1441 # Define a closure to run tests on variants of the base query.
1442 def check(
1443 results: DimensionRecordQueryResults,
1444 order_by: Any = (),
1445 limit: int | None = None,
1446 ) -> list[str]:
1447 results = results.order_by(*order_by).limit(limit)
1448 self.assertEqual(results.element.name, "patch")
1449 with self.assertRaises(_TestQueryExecution) as cm:
1450 list(results)
1451 tree = cm.exception.tree
1452 self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4")
1453 self.assertEqual(tree.dimensions, self.universe.conform(["visit", "patch"]))
1454 self.assertFalse(tree.materializations)
1455 self.assertFalse(tree.datasets)
1456 ((key, upload_dimensions),) = tree.data_coordinate_uploads.items()
1457 self.assertEqual(upload_dimensions, self.universe.conform(["visit"]))
1458 self.assertEqual(cm.exception.driver.data_coordinate_uploads[key], (upload_dimensions, raw_rows))
1459 result_spec = cm.exception.result_spec
1460 self.assertEqual(result_spec.result_type, "dimension_record")
1461 self.assertEqual(result_spec.element, self.universe["patch"])
1462 self.assertEqual(result_spec.limit, limit)
1463 for exact, discard in itertools.permutations([False, True], r=2):
1464 with self.assertRaises(_TestQueryCount) as cm:
1465 results.count(exact=exact, discard=discard)
1466 self.assertEqual(cm.exception.result_spec, result_spec)
1467 self.assertEqual(cm.exception.exact, exact)
1468 self.assertEqual(cm.exception.discard, discard)
1469 return [str(term) for term in result_spec.order_by]
1471 # Run the closure's tests on variants of the base query.
1472 self.assertEqual(check(results), [])
1473 self.assertEqual(check(results, limit=2), [])
1474 self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"])
1475 self.assertEqual(
1476 check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]),
1477 ["patch.cell_x", "patch.cell_y DESC"],
1478 )
1479 with self.assertRaises(qt.InvalidQueryError):
1480 # Cannot upload empty list of data IDs.
1481 query.join_data_coordinates([])
1482 with self.assertRaises(qt.InvalidQueryError):
1483 # Cannot upload heterogeneous list of data IDs.
1484 query.join_data_coordinates(
1485 [
1486 DataCoordinate.make_empty(self.universe),
1487 DataCoordinate.standardize(instrument="DummyCam", universe=self.universe),
1488 ]
1489 )
1491 def test_dimension_record_iteration(self) -> None:
1492        """Tests for DimensionRecordQueryResults iteration."""
1494 def make_record(n: int) -> DimensionRecord:
1495 return self.universe["patch"].RecordClass(skymap="m", tract=4, patch=n)
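        # Each inner list below stands in for one page of results from the
        # mock driver, so the page-iteration methods should see three pages.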
1497 result_rows = (
1498 [make_record(n) for n in range(3)],
1499 [make_record(n) for n in range(3, 6)],
1500 [make_record(10)],
1501 )
1502 results = self.query(result_rows=result_rows).dimension_records("patch")
1503 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1504 self.assertEqual(
1505 list(results.iter_set_pages()),
1506 [DimensionRecordSet(self.universe["patch"], rows) for rows in result_rows],
1507 )
1508 self.assertEqual(
1509 [table.column("id").to_pylist() for table in results.iter_table_pages()],
1510 [list(range(3)), list(range(3, 6)), [10]],
1511 )
1513 def test_data_coordinate_results(self) -> None:
1514 """Test queries that return data coordinates.
1516 This includes tests for:
1518 - counting result rows;
1519 - expanding dimensions as needed for 'where' conditions;
1520 - order_by and limit.
1522 It does not include the iteration methods of
1523 DataCoordinateQueryResults, since those require a different mock
1524 driver setup (see test_data_coordinate_iteration). More tests for
1525 different inputs to DataCoordinateQueryResults construction are in
1526 test_dataset_join.
1527 """
1528 # Set up the base query-results object to test.
1529 query = self.query()
1530 x = query.expression_factory
1531 self.assertFalse(query.constraint_dimensions)
1532 query = query.where(x.skymap == "m")
1533 results = query.data_ids(["patch", "band"])
1534 results = results.where(x.tract == 4)
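        # Note that skymap and tract are required dimensions of patch, so the
        # 'where' constraints above should not expand the query dimensions
        # beyond what data_ids(["patch", "band"]) already implies; check()
        # verifies this via tree.dimensions.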
1536 # Define a closure to run tests on variants of the base query.
1537 def check(
1538 results: DataCoordinateQueryResults,
1539 order_by: Any = (),
1540 limit: int | None = None,
1541 include_dimension_records: bool = False,
1542 ) -> list[str]:
1543 results = results.order_by(*order_by).limit(limit)
1544 self.assertEqual(results.dimensions, self.universe.conform(["patch", "band"]))
1545 with self.assertRaises(_TestQueryExecution) as cm:
1546 list(results)
1547 tree = cm.exception.tree
1548 self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4")
1549 self.assertEqual(tree.dimensions, self.universe.conform(["patch", "band"]))
1550 self.assertFalse(tree.materializations)
1551 self.assertFalse(tree.datasets)
1552 self.assertFalse(tree.data_coordinate_uploads)
1553 result_spec = cm.exception.result_spec
1554 self.assertEqual(result_spec.result_type, "data_coordinate")
1555 self.assertEqual(result_spec.dimensions, self.universe.conform(["patch", "band"]))
1556 self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
1557 self.assertEqual(result_spec.limit, limit)
1558 self.assertIsNone(result_spec.find_first_dataset)
1559 for exact, discard in itertools.permutations([False, True], r=2):
1560 with self.assertRaises(_TestQueryCount) as cm:
1561 results.count(exact=exact, discard=discard)
1562 self.assertEqual(cm.exception.result_spec, result_spec)
1563 self.assertEqual(cm.exception.exact, exact)
1564 self.assertEqual(cm.exception.discard, discard)
1565 return [str(term) for term in result_spec.order_by]
1567 # Run the closure's tests on variants of the base query.
1568 self.assertEqual(check(results), [])
1569 self.assertEqual(check(results.with_dimension_records(), include_dimension_records=True), [])
1570 self.assertEqual(
1571 check(results.with_dimension_records().with_dimension_records(), include_dimension_records=True),
1572 [],
1573 )
1574 self.assertEqual(check(results, limit=2), [])
1575 self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"])
1576 self.assertEqual(
1577 check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]),
1578 ["patch.cell_x", "patch.cell_y DESC"],
1579 )
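        # String order_by terms should parse to the same terms as the
        # ExpressionFactory forms above, with a leading '-' meaning descending.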
1580 self.assertEqual(
1581 check(results, order_by=["patch.cell_x", "-cell_y"]),
1582 ["patch.cell_x", "patch.cell_y DESC"],
1583 )
1585 def test_data_coordinate_iteration(self) -> None:
1586        """Tests for DataCoordinateQueryResults iteration."""
1588        def make_data_id(n: int) -> DataCoordinate:
1589 return DataCoordinate.standardize(skymap="m", tract=4, patch=n, universe=self.universe)
1591 result_rows = (
1592 [make_data_id(n) for n in range(3)],
1593 [make_data_id(n) for n in range(3, 6)],
1594 [make_data_id(10)],
1595 )
1596 results = self.query(result_rows=result_rows).data_ids(["patch"])
1597 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1599 def test_dataset_results(self) -> None:
1600 """Test queries that return dataset refs.
1602 This includes tests for:
1604 - counting result rows;
1605 - expanding dimensions as needed for 'where' conditions;
1606 - different ways of passing a data ID to 'where' methods;
1607 - order_by and limit.
1609 It does not include the iteration methods of the DatasetRefQueryResults
1610 classes, since those require a different mock driver setup (see
1611 test_dataset_iteration). More tests for different inputs to
1612 DatasetRefQueryResults construction are in test_dataset_join.
1613 """
1614        # Set up a few equivalent base query-results objects to test.
1615 query = self.query()
1616 x = query.expression_factory
1617 self.assertFalse(query.constraint_dimensions)
1618 results1 = query.datasets("raw").where(x.instrument == "DummyCam", visit=4)
1619 results2 = query.datasets("raw", collections=["DummyCam/defaults"]).where(
1620 {"instrument": "DummyCam", "visit": 4}
1621 )
1622 results3 = query.datasets("raw").where(
1623 DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe)
1624 )
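        # results1, results2, and results3 pass the same data ID constraint as
        # keyword arguments, a mapping, and a DataCoordinate, respectively, so
        # all three should produce identical query trees.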
1626 # Define a closure to check a DatasetRefQueryResults instance.
1627 def check(
1628 results: DatasetRefQueryResults,
1629 order_by: Any = (),
1630 limit: int | None = None,
1631 include_dimension_records: bool = False,
1632 ) -> list[str]:
1633 results = results.order_by(*order_by).limit(limit)
1634 with self.assertRaises(_TestQueryExecution) as cm:
1635 list(results)
1636 tree = cm.exception.tree
1637 self.assertEqual(str(tree.predicate), "instrument == 'DummyCam' AND visit == 4")
1638 self.assertEqual(
1639 tree.dimensions,
1640 self.universe.conform(["visit"]).union(results.dataset_type.dimensions.as_group()),
1641 )
1642 self.assertFalse(tree.materializations)
1643 self.assertEqual(tree.datasets.keys(), {results.dataset_type.name})
1644 self.assertEqual(tree.datasets[results.dataset_type.name].collections, ("DummyCam/defaults",))
1645 self.assertEqual(
1646 tree.datasets[results.dataset_type.name].dimensions,
1647 results.dataset_type.dimensions.as_group(),
1648 )
1649 self.assertFalse(tree.data_coordinate_uploads)
1650 result_spec = cm.exception.result_spec
1651 self.assertEqual(result_spec.result_type, "dataset_ref")
1652 self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
1653 self.assertEqual(result_spec.limit, limit)
1654 self.assertEqual(result_spec.find_first_dataset, result_spec.dataset_type_name)
1655 for exact, discard in itertools.permutations([False, True], r=2):
1656 with self.assertRaises(_TestQueryCount) as cm:
1657 results.count(exact=exact, discard=discard)
1658 self.assertEqual(cm.exception.result_spec, result_spec)
1659 self.assertEqual(cm.exception.exact, exact)
1660 self.assertEqual(cm.exception.discard, discard)
1661 with self.assertRaises(_TestQueryExecution) as cm:
1662 list(results.data_ids)
1663 self.assertEqual(
1664 cm.exception.result_spec,
1665 qrs.DataCoordinateResultSpec(
1666 dimensions=results.dataset_type.dimensions.as_group(),
1667 include_dimension_records=include_dimension_records,
1668 ),
1669 )
1670 self.assertIs(cm.exception.tree, tree)
1671 return [str(term) for term in result_spec.order_by]
1673 # Run the closure's tests on variants of the base query.
1674 self.assertEqual(check(results1), [])
1675 self.assertEqual(check(results2), [])
1676 self.assertEqual(check(results3), [])
1677 self.assertEqual(check(results1.with_dimension_records(), include_dimension_records=True), [])
1678 self.assertEqual(
1679 check(results1.with_dimension_records().with_dimension_records(), include_dimension_records=True),
1680 [],
1681 )
1682 self.assertEqual(check(results1, limit=2), [])
1683 self.assertEqual(check(results1, order_by=["raw.timespan.begin"]), ["raw.timespan.begin"])
1684 self.assertEqual(check(results1, order_by=["detector"]), ["detector"])
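        # A bare "ingest_date" should be resolved against the only dataset
        # type in the query, yielding "raw.ingest_date".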
1685 self.assertEqual(check(results1, order_by=["ingest_date"]), ["raw.ingest_date"])
1687 def test_dataset_iteration(self) -> None:
1688        """Tests for DatasetRefQueryResults iteration."""
1690        def make_ref(n: int) -> DatasetRef:
1691 return DatasetRef(
1692 self.raw,
1693 DataCoordinate.standardize(
1694 instrument="DummyCam", exposure=4, detector=n, universe=self.universe
1695 ),
1696 run="DummyCam/raw/all",
1697 id=uuid.uuid4(),
1698 )
1700 result_rows = (
1701 [make_ref(n) for n in range(3)],
1702 [make_ref(n) for n in range(3, 6)],
1703 [make_ref(10)],
1704 )
1705 results = self.query(result_rows=result_rows).datasets("raw")
1706 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1708 def test_identifiers(self) -> None:
1709 """Test edge-cases of identifiers in order_by expressions."""
1711 def extract_order_by(results: DataCoordinateQueryResults) -> list[str]:
1712 with self.assertRaises(_TestQueryExecution) as cm:
1713 list(results)
1714 return [str(term) for term in cm.exception.result_spec.order_by]
1716 self.assertEqual(
1717 extract_order_by(self.query().data_ids(["day_obs"]).order_by("-timespan.begin")),
1718 ["day_obs.timespan.begin DESC"],
1719 )
1720 self.assertEqual(
1721 extract_order_by(self.query().data_ids(["day_obs"]).order_by("timespan.end")),
1722 ["day_obs.timespan.end"],
1723 )
1724 self.assertEqual(
1725 extract_order_by(self.query().data_ids(["visit"]).order_by("-visit.timespan.begin")),
1726 ["visit.timespan.begin DESC"],
1727 )
1728 self.assertEqual(
1729 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.timespan.end")),
1730 ["visit.timespan.end"],
1731 )
1732 self.assertEqual(
1733 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.science_program")),
1734 ["visit.science_program"],
1735 )
1736 self.assertEqual(
1737 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.id")),
1738 ["visit"],
1739 )
1740 self.assertEqual(
1741 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.physical_filter")),
1742 ["physical_filter"],
1743 )
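        # The remaining order_by arguments are invalid and should raise.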
1744 with self.assertRaises(TypeError):
1745 self.query().data_ids(["visit"]).order_by(3)
1746 with self.assertRaises(qt.InvalidQueryError):
1747 self.query().data_ids(["visit"]).order_by("visit.region")
1748 with self.assertRaisesRegex(qt.InvalidQueryError, "Ambiguous"):
1749 self.query().data_ids(["visit", "exposure"]).order_by("timespan.begin")
1750 with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"):
1751 self.query().data_ids(["visit", "exposure"]).order_by("blarg")
1752 with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"):
1753 self.query().data_ids(["visit", "exposure"]).order_by("visit.horse")
1754 with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"):
1755 self.query().data_ids(["visit", "exposure"]).order_by("visit.science_program.monkey")
1756 with self.assertRaisesRegex(qt.InvalidQueryError, "not valid for datasets"):
1757 self.query().datasets("raw").order_by("raw.seq_num")
1759 def test_invalid_models(self) -> None:
1760 """Test invalid models and combinations of models that cannot be
1761 constructed via the public Query and *QueryResults interfaces.
1762 """
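        # These cases construct the query model objects directly, since the
        # public interfaces are expected to reject the equivalent requests
        # earlier; the point is that the models also validate their own
        # inputs.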
1763 x = ExpressionFactory(self.universe)
1764 with self.assertRaises(qt.InvalidQueryError):
1765 # QueryTree dimensions do not cover dataset dimensions.
1766 qt.QueryTree(
1767 dimensions=self.universe.conform(["visit"]),
1768 datasets={
1769 "raw": qt.DatasetSearch(
1770 collections=("DummyCam/raw/all",),
1771 dimensions=self.raw.dimensions.as_group(),
1772 )
1773 },
1774 )
1775 with self.assertRaises(qt.InvalidQueryError):
1776            # QueryTree dimensions do not cover predicate dimensions.
1777 qt.QueryTree(
1778 dimensions=self.universe.conform(["visit"]),
1779 predicate=(x.detector > 5),
1780 )
1781 with self.assertRaises(qt.InvalidQueryError):
1782 # Predicate references a dataset not in the QueryTree.
1783 qt.QueryTree(
1784 dimensions=self.universe.conform(["exposure", "detector"]),
1785 predicate=(x["raw"].collection == "bird"),
1786 )
1787 with self.assertRaises(qt.InvalidQueryError):
1788 # ResultSpec's dimensions are not a subset of the query tree's.
1789 DimensionRecordQueryResults(
1790 _TestQueryDriver(),
1791 qt.QueryTree(dimensions=self.universe.conform(["tract"])),
1792 qrs.DimensionRecordResultSpec(element=self.universe["detector"]),
1793 )
1794 with self.assertRaises(qt.InvalidQueryError):
1795 # ResultSpec's datasets are not a subset of the query tree's.
1796 DatasetRefQueryResults(
1797 _TestQueryDriver(),
1798 qt.QueryTree(dimensions=self.raw.dimensions.as_group()),
1799 qrs.DatasetRefResultSpec(
1800 dataset_type_name="raw",
1801 dimensions=self.raw.dimensions.as_group(),
1802 storage_class_name=self.raw.storageClass_name,
1803 find_first=True,
1804 ),
1805 )
1806 with self.assertRaises(qt.InvalidQueryError):
1807 # ResultSpec's order_by expression is not related to the dimensions
1808 # we're returning.
1809 x = ExpressionFactory(self.universe)
1810 DimensionRecordQueryResults(
1811 _TestQueryDriver(),
1812 qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])),
1813 qrs.DimensionRecordResultSpec(
1814 element=self.universe["detector"], order_by=(x.unwrap(x.visit),)
1815 ),
1816 )
1817 with self.assertRaises(qt.InvalidQueryError):
1818 # ResultSpec's order_by expression is not related to the datasets
1819 # we're returning.
1820 x = ExpressionFactory(self.universe)
1821 DimensionRecordQueryResults(
1822 _TestQueryDriver(),
1823 qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])),
1824 qrs.DimensionRecordResultSpec(
1825 element=self.universe["detector"], order_by=(x.unwrap(x["raw"].ingest_date),)
1826 ),
1827 )
1829 def test_general_result_spec(self) -> None:
1830 """Tests for GeneralResultSpec.
1832 Unlike the other ResultSpec objects, we don't have a *QueryResults
1833 class for GeneralResultSpec yet, so we can't use the higher-level
1834 interfaces to test it like we can the others.
1835 """
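        # 'a' requests a dimension-record field and 'b' requests dataset
        # fields with find_first=True; both should map onto the expected
        # ColumnSet via get_result_columns.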
1836 a = qrs.GeneralResultSpec(
1837 dimensions=self.universe.conform(["detector"]),
1838 dimension_fields={"detector": {"purpose"}},
1839 dataset_fields={},
1840 find_first=False,
1841 )
1842 self.assertEqual(a.find_first_dataset, None)
1843 a_columns = qt.ColumnSet(self.universe.conform(["detector"]))
1844 a_columns.dimension_fields["detector"].add("purpose")
1845 self.assertEqual(a.get_result_columns(), a_columns)
1846 b = qrs.GeneralResultSpec(
1847 dimensions=self.universe.conform(["detector"]),
1848 dimension_fields={},
1849 dataset_fields={"bias": {"timespan", "dataset_id"}},
1850 find_first=True,
1851 )
1852 self.assertEqual(b.find_first_dataset, "bias")
1853 b_columns = qt.ColumnSet(self.universe.conform(["detector"]))
1854 b_columns.dataset_fields["bias"].add("timespan")
1855 b_columns.dataset_fields["bias"].add("dataset_id")
1856 self.assertEqual(b.get_result_columns(), b_columns)
1857 with self.assertRaises(qt.InvalidQueryError):
1858            # More than one dataset type with find_first is an error.
1859 qrs.GeneralResultSpec(
1860 dimensions=self.universe.conform(["detector", "exposure"]),
1861 dimension_fields={},
1862 dataset_fields={"bias": {"dataset_id"}, "raw": {"dataset_id"}},
1863 find_first=True,
1864 )
1865 with self.assertRaises(qt.InvalidQueryError):
1866 # Out-of-bounds dimension fields.
1867 qrs.GeneralResultSpec(
1868 dimensions=self.universe.conform(["detector"]),
1869 dimension_fields={"visit": {"name"}},
1870 dataset_fields={},
1871 find_first=False,
1872 )
1873 with self.assertRaises(qt.InvalidQueryError):
1874 # No fields for dimension element.
1875 qrs.GeneralResultSpec(
1876 dimensions=self.universe.conform(["detector"]),
1877 dimension_fields={"detector": set()},
1878 dataset_fields={},
1879 find_first=True,
1880 )
1881 with self.assertRaises(qt.InvalidQueryError):
1882 # No fields for dataset.
1883 qrs.GeneralResultSpec(
1884 dimensions=self.universe.conform(["detector"]),
1885 dimension_fields={},
1886 dataset_fields={"bias": set()},
1887 find_first=True,
1888 )
1891if __name__ == "__main__":
1892 unittest.main()