Coverage for tests/test_query_interface.py: 10%
936 statements
coverage.py v7.5.1, created at 2024-05-07 02:46 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Tests for the public Butler._query interface and the Pydantic models that
29back it using a mock column-expression visitor and a mock QueryDriver
30implementation.
32These tests are entirely independent of which kind of butler or database
33backend we're using.
35This is a very large test file because many tests make use of those mocks,
36but the mocks are not generally useful enough to be worth moving into the
37library proper.
38"""
40from __future__ import annotations
42import dataclasses
43import itertools
44import unittest
45import uuid
46from collections.abc import Iterable, Iterator, Mapping, Set
47from typing import Any
49import astropy.time
50from lsst.daf.butler import (
51 CollectionType,
52 DataCoordinate,
53 DataIdValue,
54 DatasetRef,
55 DatasetType,
56 DimensionGroup,
57 DimensionRecord,
58 DimensionRecordSet,
59 DimensionUniverse,
60 InvalidQueryError,
61 MissingDatasetTypeError,
62 NamedValueSet,
63 NoDefaultCollectionError,
64 Timespan,
65)
66from lsst.daf.butler.queries import (
67 DataCoordinateQueryResults,
68 DatasetRefQueryResults,
69 DimensionRecordQueryResults,
70 Query,
71)
72from lsst.daf.butler.queries import driver as qd
73from lsst.daf.butler.queries import result_specs as qrs
74from lsst.daf.butler.queries import tree as qt
75from lsst.daf.butler.queries.expression_factory import ExpressionFactory
76from lsst.daf.butler.queries.tree._column_expression import UnaryExpression
77from lsst.daf.butler.queries.tree._predicate import PredicateLeaf, PredicateOperands
78from lsst.daf.butler.queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor
79from lsst.daf.butler.registry import CollectionSummary, DatasetTypeError
80from lsst.daf.butler.registry.interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord
81from lsst.sphgeom import DISJOINT, Mq3cPixelization
84class _TestVisitor(PredicateVisitor[bool, bool, bool], ColumnExpressionVisitor[Any]):
85 """Test visitor for column expressions.
87 This visitor evaluates column expressions using regular Python logic.
89 Parameters
90 ----------
91 dimension_keys : `~collections.abc.Mapping`, optional
92 Mapping from dimension name to the value it should be assigned by the
93 visitor.
94 dimension_fields : `~collections.abc.Mapping`, optional
95 Mapping from ``(dimension element name, field)`` tuple to the value it
96 should be assigned by the visitor.
97 dataset_fields : `~collections.abc.Mapping`, optional
98 Mapping from ``(dataset type name, field)`` tuple to the value it
99 should be assigned by the visitor.
100 query_tree_items : `~collections.abc.Set`, optional
101 Set that should be used as the right-hand side of element-in-query
102 predicates.
103 """
105 def __init__(
106 self,
107 dimension_keys: Mapping[str, Any] | None = None,
108 dimension_fields: Mapping[tuple[str, str], Any] | None = None,
109 dataset_fields: Mapping[tuple[str, str], Any] | None = None,
110 query_tree_items: Set[Any] = frozenset(),
111 ):
112 self.dimension_keys = dimension_keys or {}
113 self.dimension_fields = dimension_fields or {}
114 self.dataset_fields = dataset_fields or {}
115 self.query_tree_items = query_tree_items
117 def visit_binary_expression(self, expression: qt.BinaryExpression) -> Any:
118 match expression.operator:
119 case "+":
120 return expression.a.visit(self) + expression.b.visit(self)
121 case "-":
122 return expression.a.visit(self) - expression.b.visit(self)
123 case "*":
124 return expression.a.visit(self) * expression.b.visit(self)
125 case "/":
126 match expression.column_type:
127 case "int":
128 return expression.a.visit(self) // expression.b.visit(self)
129 case "float":
130 return expression.a.visit(self) / expression.b.visit(self)
131 case "%":
132 return expression.a.visit(self) % expression.b.visit(self)
134 def visit_comparison(
135 self,
136 a: qt.ColumnExpression,
137 operator: qt.ComparisonOperator,
138 b: qt.ColumnExpression,
139 flags: PredicateVisitFlags,
140 ) -> bool:
141 match operator:
142 case "==":
143 return a.visit(self) == b.visit(self)
144 case "!=":
145 return a.visit(self) != b.visit(self)
146 case "<":
147 return a.visit(self) < b.visit(self)
148 case ">":
149 return a.visit(self) > b.visit(self)
150 case "<=":
151 return a.visit(self) <= b.visit(self)
152 case ">=":
153 return a.visit(self) >= b.visit(self)
154 case "overlaps":
155 return not (a.visit(self).relate(b.visit(self)) & DISJOINT)
157 def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> Any:
158 return self.dataset_fields[expression.dataset_type, expression.field]
160 def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> Any:
161 return self.dimension_fields[expression.element.name, expression.field]
163 def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> Any:
164 return self.dimension_keys[expression.dimension.name]
166 def visit_in_container(
167 self,
168 member: qt.ColumnExpression,
169 container: tuple[qt.ColumnExpression, ...],
170 flags: PredicateVisitFlags,
171 ) -> bool:
172 return member.visit(self) in [item.visit(self) for item in container]
174 def visit_in_range(
175 self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags
176 ) -> bool:
177 return member.visit(self) in range(start, stop, step)
179 def visit_in_query_tree(
180 self,
181 member: qt.ColumnExpression,
182 column: qt.ColumnExpression,
183 query_tree: qt.QueryTree,
184 flags: PredicateVisitFlags,
185 ) -> bool:
186 return member.visit(self) in self.query_tree_items
188 def visit_is_null(self, operand: qt.ColumnExpression, flags: PredicateVisitFlags) -> bool:
189 return operand.visit(self) is None
191 def visit_literal(self, expression: qt.ColumnLiteral) -> Any:
192 return expression.get_literal_value()
194 def visit_reversed(self, expression: qt.Reversed) -> Any:
195 return _TestReversed(expression.operand.visit(self))
197 def visit_unary_expression(self, expression: UnaryExpression) -> Any:
198 match expression.operator:
199 case "-":
200 return -expression.operand.visit(self)
201 case "begin_of":
202 return expression.operand.visit(self).begin
203 case "end_of":
204 return expression.operand.visit(self).end
206 def apply_logical_and(self, originals: PredicateOperands, results: tuple[bool, ...]) -> bool:
207 return all(results)
209 def apply_logical_not(self, original: PredicateLeaf, result: bool, flags: PredicateVisitFlags) -> bool:
210 return not result
212 def apply_logical_or(
213 self, originals: tuple[PredicateLeaf, ...], results: tuple[bool, ...], flags: PredicateVisitFlags
214 ) -> bool:
215 return any(results)
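# A minimal usage sketch (not itself a test; it just mirrors how the test
# cases below use this visitor): expressions built with ExpressionFactory are
# unwrapped to tree nodes and evaluated against plain Python values.
#
#     x = ExpressionFactory(DimensionUniverse())
#     maybe = x.detector == 5                                     # qt.Predicate
#     maybe.visit(_TestVisitor(dimension_keys={"detector": 5}))   # -> True
#     maybe.visit(_TestVisitor(dimension_keys={"detector": 4}))   # -> False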
218@dataclasses.dataclass
219class _TestReversed:
220 """Struct used by _TestVisitor" to mark an expression as reversed in sort
221 order.
222 """
224 operand: Any
227class _TestQueryExecution(BaseException):
228 """Exception raised by _TestQueryDriver.execute to communicate its args
229 back to the caller.
230 """
232 def __init__(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree, driver: _TestQueryDriver) -> None:
233 self.result_spec = result_spec
234 self.tree = tree
235 self.driver = driver
238class _TestQueryCount(BaseException):
239 """Exception raised by _TestQueryDriver.count to communicate its args
240 back to the caller.
241 """
243 def __init__(
244 self,
245 result_spec: qrs.ResultSpec,
246 tree: qt.QueryTree,
247 driver: _TestQueryDriver,
248 exact: bool,
249 discard: bool,
250 ) -> None:
251 self.result_spec = result_spec
252 self.tree = tree
253 self.driver = driver
254 self.exact = exact
255 self.discard = discard
258class _TestQueryAny(BaseException):
259 """Exception raised by _TestQueryDriver.any to communicate its args
260 back to the caller.
261 """
263 def __init__(
264 self,
265 tree: qt.QueryTree,
266 driver: _TestQueryDriver,
267 exact: bool,
268 execute: bool,
269 ) -> None:
270 self.tree = tree
271 self.driver = driver
272 self.exact = exact
273 self.execute = execute
276class _TestQueryExplainNoResults(BaseException):
277 """Exception raised by _TestQueryDriver.explain_no_results to communicate
278 its args back to the caller.
279 """
281 def __init__(
282 self,
283 tree: qt.QueryTree,
284 driver: _TestQueryDriver,
285 execute: bool,
286 ) -> None:
287 self.tree = tree
288 self.driver = driver
289 self.execute = execute
292class _TestQueryDriver(qd.QueryDriver):
293 """Mock implementation of `QueryDriver` that mostly raises exceptions that
294 communicate the arguments its methods were called with.
296 Parameters
297 ----------
298 default_collections : `tuple` [ `str`, ... ], optional
299 Default collection the query or parent butler is imagined to have been
300 constructed with.
301 collection_info : `~collections.abc.Mapping`, optional
302 Mapping from collection name to its record and summary, simulating the
303 collections present in the data repository.
304 dataset_types : `~collections.abc.Mapping`, optional
305 Mapping from dataset type to its definition, simulating the dataset
306 types registered in the data repository.
307 result_rows : `tuple` [ `~collections.abc.Iterable`, ... ], optional
308 A tuple of iterables of arbitrary type to use as result rows any time
309 `execute` is called, with each nested iterable considered a separate
310 page. The result type is not checked for consistency with the result
311 spec. If this is not provided, `execute` will instead raise
312 `_TestQueryExecution`, and `fetch_page` will not do anything useful.
313 """
315 def __init__(
316 self,
317 default_collections: tuple[str, ...] | None = None,
318 collection_info: Mapping[str, tuple[CollectionRecord, CollectionSummary]] | None = None,
319 dataset_types: Mapping[str, DatasetType] | None = None,
320 result_rows: tuple[Iterable[Any], ...] | None = None,
321 ) -> None:
322 self._universe = DimensionUniverse()
323 # Mapping of the arguments passed to materialize, keyed by the UUID
324 # that each call returned.
325 self.materializations: dict[
326 qd.MaterializationKey, tuple[qt.QueryTree, DimensionGroup, frozenset[str]]
327 ] = {}
328 # Mapping of the arguments passed to upload_data_coordinates, keyed by
329 # the UUID that each call returned.
330 self.data_coordinate_uploads: dict[
331 qd.DataCoordinateUploadKey, tuple[DimensionGroup, frozenset[tuple[DataIdValue, ...]]]
332 ] = {}
333 self._default_collections = default_collections
334 self._collection_info = collection_info or {}
335 self._dataset_types = dataset_types or {}
336 self._executions: list[tuple[qrs.ResultSpec, qt.QueryTree]] = []
337 self._result_rows = result_rows
338 self._result_iters: dict[qd.PageKey, tuple[Iterable[Any], Iterator[Iterable[Any]]]] = {}
340 @property
341 def universe(self) -> DimensionUniverse:
342 return self._universe
344 def __enter__(self) -> None:
345 pass
347 def __exit__(self, *args: Any, **kwargs: Any) -> None:
348 pass
350 def execute(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree) -> qd.ResultPage:
351 if self._result_rows is not None:
352 iterator = iter(self._result_rows)
353 current_rows = next(iterator, ())
354 return self._make_next_page(result_spec, current_rows, iterator)
355 raise _TestQueryExecution(result_spec, tree, self)
357 def fetch_next_page(self, result_spec: qrs.ResultSpec, key: qd.PageKey) -> qd.ResultPage:
358 if self._result_rows is not None:
359 return self._make_next_page(result_spec, *self._result_iters.pop(key))
360 raise AssertionError("Test query driver not initialized for actual results.")
362 def _make_next_page(
363 self, result_spec: qrs.ResultSpec, current_rows: Iterable[Any], iterator: Iterator[Iterable[Any]]
364 ) -> qd.ResultPage:
365 next_rows = list(next(iterator, ()))
366 if not next_rows:
367 next_key = None
368 else:
369 next_key = uuid.uuid4()
370 self._result_iters[next_key] = (next_rows, iterator)
371 match result_spec:
372 case qrs.DataCoordinateResultSpec():
373 return qd.DataCoordinateResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
374 case qrs.DimensionRecordResultSpec():
375 return qd.DimensionRecordResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
376 case qrs.DatasetRefResultSpec():
377 return qd.DatasetRefResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
378 case _:
379 raise NotImplementedError("Other query types not yet supported.")
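    # Paging sketch (illustrative only): with result_rows=((1, 2), (3,)),
    # execute() returns a page with rows (1, 2) and a non-None next_key;
    # fetch_next_page(next_key) then returns the final page with rows [3]
    # and next_key=None, so result iteration stops there.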
381 def materialize(
382 self,
383 tree: qt.QueryTree,
384 dimensions: DimensionGroup,
385 datasets: frozenset[str],
386 ) -> qd.MaterializationKey:
387 key = uuid.uuid4()
388 self.materializations[key] = (tree, dimensions, datasets)
389 return key
391 def upload_data_coordinates(
392 self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]]
393 ) -> qd.DataCoordinateUploadKey:
394 key = uuid.uuid4()
395 self.data_coordinate_uploads[key] = (dimensions, frozenset(rows))
396 return key
398 def count(
399 self,
400 tree: qt.QueryTree,
401 result_spec: qrs.ResultSpec,
402 *,
403 exact: bool,
404 discard: bool,
405 ) -> int:
406 raise _TestQueryCount(result_spec, tree, self, exact, discard)
408 def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool:
409 raise _TestQueryAny(tree, self, exact, execute)
411 def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]:
412 raise _TestQueryExplainNoResults(tree, self, execute)
414 def get_default_collections(self) -> tuple[str, ...]:
415 if self._default_collections is None:
416 raise NoDefaultCollectionError()
417 return self._default_collections
419 def get_dataset_type(self, name: str) -> DatasetType:
420 try:
421 return self._dataset_types[name]
422 except KeyError:
423 raise MissingDatasetTypeError(name)
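# A minimal sketch of the pattern used throughout the test cases below: most
# driver methods raise BaseException subclasses that carry their arguments, so
# a test can build a Query against the mock driver, trigger a results method,
# and then inspect the result spec and query tree that reached the driver.
#
#     driver = _TestQueryDriver()
#     query = Query(driver, qt.make_identity_query_tree(driver.universe))
#     try:
#         list(query.data_ids(["detector"]))
#     except _TestQueryExecution as err:
#         err.result_spec  # qrs.DataCoordinateResultSpec(...)
#         err.tree         # qt.QueryTree with 'detector' in its dimensions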
426class ColumnExpressionsTestCase(unittest.TestCase):
427 """Tests for column expression objects in lsst.daf.butler.queries.tree."""
429 def setUp(self) -> None:
430 self.universe = DimensionUniverse()
431 self.x = ExpressionFactory(self.universe)
433 def query(self, **kwargs: Any) -> Query:
434 """Make an initial Query object with the given kwargs used to
435 initialize the _TestQueryDriver.
436 """
437 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe))
439 def test_int_literals(self) -> None:
440 expr = self.x.unwrap(self.x.literal(5))
441 self.assertEqual(expr.value, 5)
442 self.assertEqual(expr.get_literal_value(), 5)
443 self.assertEqual(expr.expression_type, "int")
444 self.assertEqual(expr.column_type, "int")
445 self.assertEqual(str(expr), "5")
446 self.assertTrue(expr.is_literal)
447 columns = qt.ColumnSet(self.universe.empty.as_group())
448 expr.gather_required_columns(columns)
449 self.assertFalse(columns)
450 self.assertEqual(expr.visit(_TestVisitor()), 5)
452 def test_string_literals(self) -> None:
453 expr = self.x.unwrap(self.x.literal("five"))
454 self.assertEqual(expr.value, "five")
455 self.assertEqual(expr.get_literal_value(), "five")
456 self.assertEqual(expr.expression_type, "string")
457 self.assertEqual(expr.column_type, "string")
458 self.assertEqual(str(expr), "'five'")
459 self.assertTrue(expr.is_literal)
460 columns = qt.ColumnSet(self.universe.empty.as_group())
461 expr.gather_required_columns(columns)
462 self.assertFalse(columns)
463 self.assertEqual(expr.visit(_TestVisitor()), "five")
465 def test_float_literals(self) -> None:
466 expr = self.x.unwrap(self.x.literal(0.5))
467 self.assertEqual(expr.value, 0.5)
468 self.assertEqual(expr.get_literal_value(), 0.5)
469 self.assertEqual(expr.expression_type, "float")
470 self.assertEqual(expr.column_type, "float")
471 self.assertEqual(str(expr), "0.5")
472 self.assertTrue(expr.is_literal)
473 columns = qt.ColumnSet(self.universe.empty.as_group())
474 expr.gather_required_columns(columns)
475 self.assertFalse(columns)
476 self.assertEqual(expr.visit(_TestVisitor()), 0.5)
478 def test_hash_literals(self) -> None:
479 expr = self.x.unwrap(self.x.literal(b"eleven"))
480 self.assertEqual(expr.value, b"eleven")
481 self.assertEqual(expr.get_literal_value(), b"eleven")
482 self.assertEqual(expr.expression_type, "hash")
483 self.assertEqual(expr.column_type, "hash")
484 self.assertEqual(str(expr), "(bytes)")
485 self.assertTrue(expr.is_literal)
486 columns = qt.ColumnSet(self.universe.empty.as_group())
487 expr.gather_required_columns(columns)
488 self.assertFalse(columns)
489 self.assertEqual(expr.visit(_TestVisitor()), b"eleven")
491 def test_uuid_literals(self) -> None:
492 value = uuid.uuid4()
493 expr = self.x.unwrap(self.x.literal(value))
494 self.assertEqual(expr.value, value)
495 self.assertEqual(expr.get_literal_value(), value)
496 self.assertEqual(expr.expression_type, "uuid")
497 self.assertEqual(expr.column_type, "uuid")
498 self.assertEqual(str(expr), str(value))
499 self.assertTrue(expr.is_literal)
500 columns = qt.ColumnSet(self.universe.empty.as_group())
501 expr.gather_required_columns(columns)
502 self.assertFalse(columns)
503 self.assertEqual(expr.visit(_TestVisitor()), value)
505 def test_datetime_literals(self) -> None:
506 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
507 expr = self.x.unwrap(self.x.literal(value))
508 self.assertEqual(expr.value, value)
509 self.assertEqual(expr.get_literal_value(), value)
510 self.assertEqual(expr.expression_type, "datetime")
511 self.assertEqual(expr.column_type, "datetime")
512 self.assertEqual(str(expr), "2020-01-01T00:00:00")
513 self.assertTrue(expr.is_literal)
514 columns = qt.ColumnSet(self.universe.empty.as_group())
515 expr.gather_required_columns(columns)
516 self.assertFalse(columns)
517 self.assertEqual(expr.visit(_TestVisitor()), value)
519 def test_timespan_literals(self) -> None:
520 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
521 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
522 value = Timespan(begin, end)
523 expr = self.x.unwrap(self.x.literal(value))
524 self.assertEqual(expr.value, value)
525 self.assertEqual(expr.get_literal_value(), value)
526 self.assertEqual(expr.expression_type, "timespan")
527 self.assertEqual(expr.column_type, "timespan")
528 self.assertEqual(str(expr), "[2020-01-01T00:00:00, 2020-01-01T00:01:00)")
529 self.assertTrue(expr.is_literal)
530 columns = qt.ColumnSet(self.universe.empty.as_group())
531 expr.gather_required_columns(columns)
532 self.assertFalse(columns)
533 self.assertEqual(expr.visit(_TestVisitor()), value)
535 def test_region_literals(self) -> None:
536 pixelization = Mq3cPixelization(10)
537 value = pixelization.quad(12058870)
538 expr = self.x.unwrap(self.x.literal(value))
539 self.assertEqual(expr.value, value)
540 self.assertEqual(expr.get_literal_value(), value)
541 self.assertEqual(expr.expression_type, "region")
542 self.assertEqual(expr.column_type, "region")
543 self.assertEqual(str(expr), "(region)")
544 self.assertTrue(expr.is_literal)
545 columns = qt.ColumnSet(self.universe.empty.as_group())
546 expr.gather_required_columns(columns)
547 self.assertFalse(columns)
548 self.assertEqual(expr.visit(_TestVisitor()), value)
550 def test_invalid_literal(self) -> None:
551 with self.assertRaisesRegex(TypeError, "Invalid type 'complex' of value 5j for column literal."):
552 self.x.literal(5j)
554 def test_dimension_key_reference(self) -> None:
555 expr = self.x.unwrap(self.x.detector)
556 self.assertIsNone(expr.get_literal_value())
557 self.assertEqual(expr.expression_type, "dimension_key")
558 self.assertEqual(expr.column_type, "int")
559 self.assertEqual(str(expr), "detector")
560 self.assertFalse(expr.is_literal)
561 columns = qt.ColumnSet(self.universe.empty.as_group())
562 expr.gather_required_columns(columns)
563 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
564 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 3})), 3)
566 def test_dimension_field_reference(self) -> None:
567 expr = self.x.unwrap(self.x.detector.purpose)
568 self.assertIsNone(expr.get_literal_value())
569 self.assertEqual(expr.expression_type, "dimension_field")
570 self.assertEqual(expr.column_type, "string")
571 self.assertEqual(str(expr), "detector.purpose")
572 self.assertFalse(expr.is_literal)
573 columns = qt.ColumnSet(self.universe.empty.as_group())
574 expr.gather_required_columns(columns)
575 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
576 self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
577 with self.assertRaises(InvalidQueryError):
578 qt.DimensionFieldReference(element=self.universe.dimensions["detector"], field="region")
579 self.assertEqual(
580 expr.visit(_TestVisitor(dimension_fields={("detector", "purpose"): "science"})), "science"
581 )
583 def test_dataset_field_reference(self) -> None:
584 expr = self.x.unwrap(self.x["raw"].ingest_date)
585 self.assertIsNone(expr.get_literal_value())
586 self.assertEqual(expr.expression_type, "dataset_field")
587 self.assertEqual(str(expr), "raw.ingest_date")
588 self.assertFalse(expr.is_literal)
589 columns = qt.ColumnSet(self.universe.empty.as_group())
590 expr.gather_required_columns(columns)
591 self.assertEqual(columns.dimensions, self.universe.empty.as_group())
592 self.assertEqual(columns.dataset_fields["raw"], {"ingest_date"})
593 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="dataset_id").column_type, "uuid")
594 self.assertEqual(
595 qt.DatasetFieldReference(dataset_type="raw", field="collection").column_type, "string"
596 )
597 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="run").column_type, "string")
598 self.assertEqual(
599 qt.DatasetFieldReference(dataset_type="raw", field="ingest_date").column_type, "datetime"
600 )
601 self.assertEqual(
602 qt.DatasetFieldReference(dataset_type="raw", field="timespan").column_type, "timespan"
603 )
604 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
605 self.assertEqual(expr.visit(_TestVisitor(dataset_fields={("raw", "ingest_date"): value})), value)
607 def test_unary_negation(self) -> None:
608 expr = self.x.unwrap(-self.x.visit.exposure_time)
609 self.assertIsNone(expr.get_literal_value())
610 self.assertEqual(expr.expression_type, "unary")
611 self.assertEqual(expr.column_type, "float")
612 self.assertEqual(str(expr), "-visit.exposure_time")
613 self.assertFalse(expr.is_literal)
614 columns = qt.ColumnSet(self.universe.empty.as_group())
615 expr.gather_required_columns(columns)
616 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
617 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"})
618 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 2.0})), -2.0)
619 with self.assertRaises(InvalidQueryError):
620 qt.UnaryExpression(
621 operand=qt.DimensionFieldReference(
622 element=self.universe.dimensions["detector"], field="purpose"
623 ),
624 operator="-",
625 )
627 def test_unary_timespan_begin(self) -> None:
628 expr = self.x.unwrap(self.x.visit.timespan.begin)
629 self.assertIsNone(expr.get_literal_value())
630 self.assertEqual(expr.expression_type, "unary")
631 self.assertEqual(expr.column_type, "datetime")
632 self.assertEqual(str(expr), "visit.timespan.begin")
633 self.assertFalse(expr.is_literal)
634 columns = qt.ColumnSet(self.universe.empty.as_group())
635 expr.gather_required_columns(columns)
636 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
637 self.assertEqual(columns.dimension_fields["visit"], {"timespan"})
638 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
639 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
640 value = Timespan(begin, end)
641 self.assertEqual(
642 expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.begin
643 )
644 with self.assertRaises(InvalidQueryError):
645 qt.UnaryExpression(
646 operand=qt.DimensionFieldReference(
647 element=self.universe.dimensions["detector"], field="purpose"
648 ),
649 operator="begin_of",
650 )
652 def test_unary_timespan_end(self) -> None:
653 expr = self.x.unwrap(self.x.visit.timespan.end)
654 self.assertIsNone(expr.get_literal_value())
655 self.assertEqual(expr.expression_type, "unary")
656 self.assertEqual(expr.column_type, "datetime")
657 self.assertEqual(str(expr), "visit.timespan.end")
658 self.assertFalse(expr.is_literal)
659 columns = qt.ColumnSet(self.universe.empty.as_group())
660 expr.gather_required_columns(columns)
661 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
662 self.assertEqual(columns.dimension_fields["visit"], {"timespan"})
663 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
664 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
665 value = Timespan(begin, end)
666 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.end)
667 with self.assertRaises(InvalidQueryError):
668 qt.UnaryExpression(
669 operand=qt.DimensionFieldReference(
670 element=self.universe.dimensions["detector"], field="purpose"
671 ),
672 operator="end_of",
673 )
675 def test_binary_expression_float(self) -> None:
676 for proxy, string, value in [
677 (self.x.visit.exposure_time + 15.0, "visit.exposure_time + 15.0", 45.0),
678 (self.x.visit.exposure_time - 10.0, "visit.exposure_time - 10.0", 20.0),
679 (self.x.visit.exposure_time * 6.0, "visit.exposure_time * 6.0", 180.0),
680 (self.x.visit.exposure_time / 30.0, "visit.exposure_time / 30.0", 1.0),
681 (15.0 + -self.x.visit.exposure_time, "15.0 + (-visit.exposure_time)", -15.0),
682 (10.0 - -self.x.visit.exposure_time, "10.0 - (-visit.exposure_time)", 40.0),
683 (6.0 * -self.x.visit.exposure_time, "6.0 * (-visit.exposure_time)", -180.0),
684 (30.0 / -self.x.visit.exposure_time, "30.0 / (-visit.exposure_time)", -1.0),
685 ((self.x.visit.exposure_time + 15.0) * 6.0, "(visit.exposure_time + 15.0) * 6.0", 270.0),
686 ((self.x.visit.exposure_time + 15.0) + 45.0, "(visit.exposure_time + 15.0) + 45.0", 90.0),
687 ((self.x.visit.exposure_time + 15.0) / 5.0, "(visit.exposure_time + 15.0) / 5.0", 9.0),
688 # We don't need the parentheses we generate in the next one, but
689 # they're not a problem either.
690 ((self.x.visit.exposure_time + 15.0) - 60.0, "(visit.exposure_time + 15.0) - 60.0", -15.0),
691 (6.0 * (-self.x.visit.exposure_time - 15.0), "6.0 * ((-visit.exposure_time) - 15.0)", -270.0),
692 (60.0 + (-self.x.visit.exposure_time - 15.0), "60.0 + ((-visit.exposure_time) - 15.0)", 15.0),
693 (90.0 / (-self.x.visit.exposure_time - 15.0), "90.0 / ((-visit.exposure_time) - 15.0)", -2.0),
694 (60.0 - (-self.x.visit.exposure_time - 15.0), "60.0 - ((-visit.exposure_time) - 15.0)", 105.0),
695 ]:
696 with self.subTest(string=string):
697 expr = self.x.unwrap(proxy)
698 self.assertIsNone(expr.get_literal_value())
699 self.assertEqual(expr.expression_type, "binary")
700 self.assertEqual(expr.column_type, "float")
701 self.assertEqual(str(expr), string)
702 self.assertFalse(expr.is_literal)
703 columns = qt.ColumnSet(self.universe.empty.as_group())
704 expr.gather_required_columns(columns)
705 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
706 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"})
707 self.assertEqual(
708 expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 30.0})), value
709 )
711 def test_binary_modulus(self) -> None:
712 for proxy, string, value in [
713 (self.x.visit.id % 2, "visit % 2", 1),
714 (52 % self.x.visit, "52 % visit", 2),
715 ]:
716 with self.subTest(string=string):
717 expr = self.x.unwrap(proxy)
718 self.assertIsNone(expr.get_literal_value())
719 self.assertEqual(expr.expression_type, "binary")
720 self.assertEqual(expr.column_type, "int")
721 self.assertEqual(str(expr), string)
722 self.assertFalse(expr.is_literal)
723 columns = qt.ColumnSet(self.universe.empty.as_group())
724 expr.gather_required_columns(columns)
725 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
726 self.assertFalse(columns.dimension_fields["visit"])
727 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"visit": 5})), value)
729 def test_binary_expression_validation(self) -> None:
730 with self.assertRaises(InvalidQueryError):
731 # No arithmetic operators on strings (we do not interpret + as
732 # concatenation).
733 self.x.instrument + "suffix"
734 with self.assertRaises(InvalidQueryError):
735 # Mixed types are not supported, even when they both support the
736 # operator.
737 self.x.visit.exposure_time + self.x.detector
738 with self.assertRaises(InvalidQueryError):
739 # No modulus for floats.
740 self.x.visit.exposure_time % 5.0
742 def test_reversed(self) -> None:
743 expr = self.x.detector.desc
744 self.assertIsNone(expr.get_literal_value())
745 self.assertEqual(expr.expression_type, "reversed")
746 self.assertEqual(expr.column_type, "int")
747 self.assertEqual(str(expr), "detector DESC")
748 self.assertFalse(expr.is_literal)
749 columns = qt.ColumnSet(self.universe.empty.as_group())
750 expr.gather_required_columns(columns)
751 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
752 self.assertFalse(columns.dimension_fields["detector"])
753 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 5})), _TestReversed(5))
755 def test_trivial_predicate(self) -> None:
756 """Test logical operations on trivial True/False predicates."""
757 yes = qt.Predicate.from_bool(True)
758 no = qt.Predicate.from_bool(False)
759 maybe: qt.Predicate = self.x.detector == 5
760 for predicate in [
761 yes,
762 yes.logical_or(no),
763 no.logical_or(yes),
764 yes.logical_and(yes),
765 no.logical_not(),
766 yes.logical_or(maybe),
767 maybe.logical_or(yes),
768 ]:
769 self.assertEqual(predicate.column_type, "bool")
770 self.assertEqual(str(predicate), "True")
771 self.assertTrue(predicate.visit(_TestVisitor()))
772 self.assertEqual(predicate.operands, ())
773 for predicate in [
774 no,
775 yes.logical_and(no),
776 no.logical_and(yes),
777 no.logical_or(no),
778 yes.logical_not(),
779 no.logical_and(maybe),
780 maybe.logical_and(no),
781 ]:
782 self.assertEqual(predicate.column_type, "bool")
783 self.assertEqual(str(predicate), "False")
784 self.assertFalse(predicate.visit(_TestVisitor()))
785 self.assertEqual(predicate.operands, ((),))
786 for predicate in [
787 maybe,
788 yes.logical_and(maybe),
789 no.logical_or(maybe),
790 maybe.logical_not().logical_not(),
791 ]:
792 self.assertEqual(predicate.column_type, "bool")
793 self.assertEqual(str(predicate), "detector == 5")
794 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"detector": 5})))
795 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"detector": 4})))
796 self.assertEqual(len(predicate.operands), 1)
797 self.assertEqual(len(predicate.operands[0]), 1)
798 self.assertIs(predicate.operands[0][0], maybe.operands[0][0])
800 def test_comparison(self) -> None:
801 predicate: qt.Predicate
802 string: str
803 value: bool
804 for detector in (4, 5, 6):
805 for predicate, string, value in [
806 (self.x.detector == 5, "detector == 5", detector == 5),
807 (self.x.detector != 5, "detector != 5", detector != 5),
808 (self.x.detector < 5, "detector < 5", detector < 5),
809 (self.x.detector > 5, "detector > 5", detector > 5),
810 (self.x.detector <= 5, "detector <= 5", detector <= 5),
811 (self.x.detector >= 5, "detector >= 5", detector >= 5),
812 (self.x.detector == 5, "detector == 5", detector == 5),
813 (self.x.detector != 5, "detector != 5", detector != 5),
814 (self.x.detector < 5, "detector < 5", detector < 5),
815 (self.x.detector > 5, "detector > 5", detector > 5),
816 (self.x.detector <= 5, "detector <= 5", detector <= 5),
817 (self.x.detector >= 5, "detector >= 5", detector >= 5),
818 ]:
819 with self.subTest(string=string, detector=detector):
820 self.assertEqual(predicate.column_type, "bool")
821 self.assertEqual(str(predicate), string)
822 columns = qt.ColumnSet(self.universe.empty.as_group())
823 predicate.gather_required_columns(columns)
824 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
825 self.assertFalse(columns.dimension_fields["detector"])
826 self.assertEqual(
827 predicate.visit(_TestVisitor(dimension_keys={"detector": detector})), value
828 )
829 inverted = predicate.logical_not()
830 self.assertEqual(inverted.column_type, "bool")
831 self.assertEqual(str(inverted), f"NOT {string}")
832 self.assertEqual(
833 inverted.visit(_TestVisitor(dimension_keys={"detector": detector})), not value
834 )
835 columns = qt.ColumnSet(self.universe.empty.as_group())
836 inverted.gather_required_columns(columns)
837 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
838 self.assertFalse(columns.dimension_fields["detector"])
840 def test_overlap_comparison(self) -> None:
841 pixelization = Mq3cPixelization(10)
842 region1 = pixelization.quad(12058870)
843 predicate = self.x.visit.region.overlaps(region1)
844 self.assertEqual(predicate.column_type, "bool")
845 self.assertEqual(str(predicate), "visit.region OVERLAPS (region)")
846 columns = qt.ColumnSet(self.universe.empty.as_group())
847 predicate.gather_required_columns(columns)
848 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
849 self.assertEqual(columns.dimension_fields["visit"], {"region"})
850 region2 = pixelization.quad(12058857)
851 self.assertFalse(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): region2})))
852 inverted = predicate.logical_not()
853 self.assertEqual(inverted.column_type, "bool")
854 self.assertEqual(str(inverted), "NOT visit.region OVERLAPS (region)")
855 self.assertTrue(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): region2})))
856 columns = qt.ColumnSet(self.universe.empty.as_group())
857 inverted.gather_required_columns(columns)
858 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
859 self.assertEqual(columns.dimension_fields["visit"], {"region"})
861 def test_invalid_comparison(self) -> None:
862 # Mixed type comparisons.
863 with self.assertRaises(InvalidQueryError):
864 self.x.visit > "three"
865 with self.assertRaises(InvalidQueryError):
866 self.x.visit > 3.0
867 # Invalid operator for type.
868 with self.assertRaises(InvalidQueryError):
869 self.x["raw"].dataset_id < uuid.uuid4()
871 def test_is_null(self) -> None:
872 predicate = self.x.visit.region.is_null
873 self.assertEqual(predicate.column_type, "bool")
874 self.assertEqual(str(predicate), "visit.region IS NULL")
875 columns = qt.ColumnSet(self.universe.empty.as_group())
876 predicate.gather_required_columns(columns)
877 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
878 self.assertEqual(columns.dimension_fields["visit"], {"region"})
879 self.assertTrue(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): None})))
880 inverted = predicate.logical_not()
881 self.assertEqual(inverted.column_type, "bool")
882 self.assertEqual(str(inverted), "NOT visit.region IS NULL")
883 self.assertFalse(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): None})))
884 inverted.gather_required_columns(columns)
885 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
886 self.assertEqual(columns.dimension_fields["visit"], {"region"})
888 def test_in_container(self) -> None:
889 predicate: qt.Predicate = self.x.visit.in_iterable([3, 4, self.x.exposure.id])
890 self.assertEqual(predicate.column_type, "bool")
891 self.assertEqual(str(predicate), "visit IN [3, 4, exposure]")
892 columns = qt.ColumnSet(self.universe.empty.as_group())
893 predicate.gather_required_columns(columns)
894 self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
895 self.assertFalse(columns.dimension_fields["visit"])
896 self.assertFalse(columns.dimension_fields["exposure"])
897 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
898 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
899 inverted = predicate.logical_not()
900 self.assertEqual(inverted.column_type, "bool")
901 self.assertEqual(str(inverted), "NOT visit IN [3, 4, exposure]")
902 self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
903 self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
904 columns = qt.ColumnSet(self.universe.empty.as_group())
905 inverted.gather_required_columns(columns)
906 self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
907 self.assertFalse(columns.dimension_fields["visit"])
908 self.assertFalse(columns.dimension_fields["exposure"])
909 with self.assertRaises(InvalidQueryError):
910 # Regions (and timespans) not allowed in IN expressions, since that
911 # suggests topological logic we're not actually doing. We can't
912 # use ExpressionFactory because it prohibits this case with typing.
913 pixelization = Mq3cPixelization(10)
914 region = pixelization.quad(12058870)
915 qt.Predicate.in_container(self.x.unwrap(self.x.visit.region), [qt.make_column_literal(region)])
916 with self.assertRaises(InvalidQueryError):
917 # Mismatched types.
918 self.x.visit.in_iterable([3.5, 2.1])
920 def test_in_range(self) -> None:
921 predicate: qt.Predicate = self.x.visit.in_range(2, 8, 2)
922 self.assertEqual(predicate.column_type, "bool")
923 self.assertEqual(str(predicate), "visit IN 2:8:2")
924 columns = qt.ColumnSet(self.universe.empty.as_group())
925 predicate.gather_required_columns(columns)
926 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
927 self.assertFalse(columns.dimension_fields["visit"])
928 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2})))
929 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 8})))
930 inverted = predicate.logical_not()
931 self.assertEqual(inverted.column_type, "bool")
932 self.assertEqual(str(inverted), "NOT visit IN 2:8:2")
933 self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2})))
934 self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 8})))
935 columns = qt.ColumnSet(self.universe.empty.as_group())
936 inverted.gather_required_columns(columns)
937 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
938 self.assertFalse(columns.dimension_fields["visit"])
939 with self.assertRaises(InvalidQueryError):
940 # Only integer fields allowed.
941 self.x.visit.exposure_time.in_range(2, 4)
942 with self.assertRaises(InvalidQueryError):
943 # Step must be positive.
944 self.x.visit.in_range(2, 4, -1)
945 with self.assertRaises(InvalidQueryError):
946 # Stop must be >= start.
947 self.x.visit.in_range(2, 0)
949 def test_in_query(self) -> None:
950 query = self.query().join_dimensions(["visit", "tract"]).where(skymap="s", tract=3)
951 predicate: qt.Predicate = self.x.exposure.in_query(self.x.visit, query)
952 self.assertEqual(predicate.column_type, "bool")
953 self.assertEqual(str(predicate), "exposure IN (query).visit")
954 columns = qt.ColumnSet(self.universe.empty.as_group())
955 predicate.gather_required_columns(columns)
956 self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
957 self.assertFalse(columns.dimension_fields["exposure"])
958 self.assertTrue(
959 predicate.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
960 )
961 self.assertFalse(
962 predicate.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
963 )
964 inverted = predicate.logical_not()
965 self.assertEqual(inverted.column_type, "bool")
966 self.assertEqual(str(inverted), "NOT exposure IN (query).visit")
967 self.assertFalse(
968 inverted.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
969 )
970 self.assertTrue(
971 inverted.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
972 )
973 columns = qt.ColumnSet(self.universe.empty.as_group())
974 inverted.gather_required_columns(columns)
975 self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
976 self.assertFalse(columns.dimension_fields["exposure"])
977 with self.assertRaises(InvalidQueryError):
978 # Regions (and timespans) not allowed in IN expressions, since that
979 # suggests topological logic we're not actually doing. We can't
980 # use ExpressionFactory because it prohibits this case with typing.
981 qt.Predicate.in_query(
982 self.x.unwrap(self.x.visit.region), self.x.unwrap(self.x.tract.region), query._tree
983 )
984 with self.assertRaises(InvalidQueryError):
985 # Mismatched types.
986 self.x.exposure.in_query(self.x.visit.exposure_time, query)
987 with self.assertRaises(InvalidQueryError):
988 # Query column requires dimensions that are not in the query.
989 self.x.exposure.in_query(self.x.patch, query)
990 with self.assertRaises(InvalidQueryError):
991 # Query column requires dataset type that is not in the query.
992 self.x["raw"].dataset_id.in_query(self.x["raw"].dataset_id, query)
994 def test_complex_predicate(self) -> None:
995 """Test that predicates are converted to conjunctive normal form and
996 get parentheses in the right places when stringified.
997 """
998 visitor = _TestVisitor(dimension_keys={"instrument": "i", "detector": 3, "visit": 6, "band": "r"})
999 a: qt.Predicate = self.x.visit > 5 # will evaluate to True
1000 b: qt.Predicate = self.x.detector != 3 # will evaluate to False
1001 c: qt.Predicate = self.x.instrument == "i" # will evaluate to True
1002 d: qt.Predicate = self.x.band == "g" # will evaluate to False
1003 predicate: qt.Predicate
1004 for predicate, string, value in [
1005 (a.logical_or(b), f"{a} OR {b}", True),
1006 (a.logical_or(c), f"{a} OR {c}", True),
1007 (b.logical_or(d), f"{b} OR {d}", False),
1008 (a.logical_and(b), f"{a} AND {b}", False),
1009 (a.logical_and(c), f"{a} AND {c}", True),
1010 (b.logical_and(d), f"{b} AND {d}", False),
1011 (self.x.any(a, b, c, d), f"{a} OR {b} OR {c} OR {d}", True),
1012 (self.x.all(a, b, c, d), f"{a} AND {b} AND {c} AND {d}", False),
1013 (a.logical_or(b).logical_and(c), f"({a} OR {b}) AND {c}", True),
1014 (a.logical_and(b.logical_or(d)), f"{a} AND ({b} OR {d})", False),
1015 (a.logical_and(b).logical_or(c), f"({a} OR {c}) AND ({b} OR {c})", True),
1016 (
1017 a.logical_and(b).logical_or(c.logical_and(d)),
1018 f"({a} OR {c}) AND ({a} OR {d}) AND ({b} OR {c}) AND ({b} OR {d})",
1019 False,
1020 ),
1021 (a.logical_or(b).logical_not(), f"NOT {a} AND NOT {b}", False),
1022 (a.logical_or(c).logical_not(), f"NOT {a} AND NOT {c}", False),
1023 (b.logical_or(d).logical_not(), f"NOT {b} AND NOT {d}", True),
1024 (a.logical_and(b).logical_not(), f"NOT {a} OR NOT {b}", True),
1025 (a.logical_and(c).logical_not(), f"NOT {a} OR NOT {c}", False),
1026 (b.logical_and(d).logical_not(), f"NOT {b} OR NOT {d}", True),
1027 (
1028 self.x.not_(a.logical_or(b).logical_and(c)),
1029 f"(NOT {a} OR NOT {c}) AND (NOT {b} OR NOT {c})",
1030 False,
1031 ),
1032 (
1033 a.logical_and(b.logical_or(d)).logical_not(),
1034 f"(NOT {a} OR NOT {b}) AND (NOT {a} OR NOT {d})",
1035 True,
1036 ),
1037 ]:
1038 with self.subTest(string=string):
1039 self.assertEqual(str(predicate), string)
1040 self.assertEqual(predicate.visit(visitor), value)
1042 def test_proxy_misc(self) -> None:
1043 """Test miscellaneous things on various ExpressionFactory proxies."""
1044 self.assertEqual(str(self.x.visit_detector_region), "visit_detector_region")
1045 self.assertEqual(str(self.x.visit.instrument), "instrument")
1046 self.assertEqual(str(self.x["raw"]), "raw")
1047 self.assertEqual(str(self.x["raw.ingest_date"]), "raw.ingest_date")
1048 self.assertEqual(
1049 str(self.x.visit.timespan.overlaps(self.x["raw"].timespan)),
1050 "visit.timespan OVERLAPS raw.timespan",
1051 )
1052 self.assertGreater(
1053 set(dir(self.x["raw"])), {"dataset_id", "ingest_date", "collection", "run", "timespan"}
1054 )
1055 self.assertGreater(set(dir(self.x.exposure)), {"seq_num", "science_program", "timespan"})
1056 with self.assertRaises(AttributeError):
1057 self.x["raw"].seq_num
1058 with self.assertRaises(AttributeError):
1059 self.x.visit.horse
1062class QueryTestCase(unittest.TestCase):
1063 """Tests for Query and *QueryResults objects in lsst.daf.butler.queries."""
1065 def setUp(self) -> None:
1066 self.maxDiff = None
1067 self.universe = DimensionUniverse()
1068 # We use ArrowTable as the storage class for all dataset types because
1069 # it has conversions that depend only on third-party packages we
1070 # already require.
1071 self.raw = DatasetType(
1072 "raw", dimensions=self.universe.conform(["detector", "exposure"]), storageClass="ArrowTable"
1073 )
1074 self.refcat = DatasetType(
1075 "refcat", dimensions=self.universe.conform(["htm7"]), storageClass="ArrowTable"
1076 )
1077 self.bias = DatasetType(
1078 "bias",
1079 dimensions=self.universe.conform(["detector"]),
1080 storageClass="ArrowTable",
1081 isCalibration=True,
1082 )
1083 self.default_collections: list[str] | None = ["DummyCam/defaults"]
1084 self.collection_info: dict[str, tuple[CollectionRecord, CollectionSummary]] = {
1085 "DummyCam/raw/all": (
1086 RunRecord[int](1, name="DummyCam/raw/all"),
1087 CollectionSummary(NamedValueSet({self.raw}), governors={"instrument": {"DummyCam"}}),
1088 ),
1089 "DummyCam/calib": (
1090 CollectionRecord[int](2, name="DummyCam/calib", type=CollectionType.CALIBRATION),
1091 CollectionSummary(NamedValueSet({self.bias}), governors={"instrument": {"DummyCam"}}),
1092 ),
1093 "refcats": (
1094 RunRecord[int](3, name="refcats"),
1095 CollectionSummary(NamedValueSet({self.refcat}), governors={}),
1096 ),
1097 "DummyCam/defaults": (
1098 ChainedCollectionRecord[int](
1099 4, name="DummyCam/defaults", children=("DummyCam/raw/all", "DummyCam/calib", "refcats")
1100 ),
1101 CollectionSummary(
1102 NamedValueSet({self.raw, self.refcat, self.bias}), governors={"instrument": {"DummyCam"}}
1103 ),
1104 ),
1105 }
1106 self.dataset_types = {"raw": self.raw, "refcat": self.refcat, "bias": self.bias}
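        # Summary of the simulated repository assembled above (for reference):
        # 'DummyCam/defaults' is a CHAINED collection whose children are the
        # RUN 'DummyCam/raw/all' (holding 'raw'), the CALIBRATION collection
        # 'DummyCam/calib' (holding 'bias'), and the RUN 'refcats' (holding
        # 'refcat').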
1108 def query(self, **kwargs: Any) -> Query:
1109 """Make an initial Query object with the given kwargs used to
1110 initialize the _TestQueryDriver.
1112 The given kwargs override the test-case-attribute defaults.
1113 """
1114 kwargs.setdefault("default_collections", self.default_collections)
1115 kwargs.setdefault("collection_info", self.collection_info)
1116 kwargs.setdefault("dataset_types", self.dataset_types)
1117 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe))
1119 def test_dataset_join(self) -> None:
1120 """Test queries that have had a dataset search explicitly joined in via
1121 Query.join_dataset_search.
1123 Since this kind of query has a moderate amount of complexity, this is
1124 where we get a lot of basic coverage that applies to all kinds of
1125 queries, including:
1127 - getting data ID and dataset results (but not iterating over them);
1128 - the 'any' and 'explain_no_results' methods;
1129 - adding 'where' filters (but not expanding dimensions accordingly);
1130 - materializations.
1131 """
1133 def check(
1134 query: Query,
1135 dimensions: DimensionGroup = self.raw.dimensions.as_group(),
1136 ) -> None:
1137 """Run a battery of tests on one of a set of very similar queries
1138 constructed in different ways (see below).
1139 """
1141 def check_query_tree(
1142 tree: qt.QueryTree,
1143 dimensions: DimensionGroup = dimensions,
1144 ) -> None:
1145 """Check the state of the QueryTree object that backs the Query
1146 or a derived QueryResults object.
1148 Parameters
1149 ----------
1150 tree : `lsst.daf.butler.queries.tree.QueryTree`
1151 Object to test.
1152 dimensions : `DimensionGroup`
1153 Dimensions to expect in the `QueryTree`, not necessarily
1154 including those in the test 'raw' dataset type.
1155 """
1156 self.assertEqual(tree.dimensions, dimensions | self.raw.dimensions.as_group())
1157 self.assertEqual(str(tree.predicate), "raw.run == 'DummyCam/raw/all'")
1158 self.assertFalse(tree.materializations)
1159 self.assertFalse(tree.data_coordinate_uploads)
1160 self.assertEqual(tree.datasets.keys(), {"raw"})
1161 self.assertEqual(tree.datasets["raw"].dimensions, self.raw.dimensions.as_group())
1162 self.assertEqual(tree.datasets["raw"].collections, ("DummyCam/defaults",))
1163 self.assertEqual(
1164 tree.get_joined_dimension_groups(), frozenset({self.raw.dimensions.as_group()})
1165 )
1167 def check_data_id_results(*args, query: Query, dimensions: DimensionGroup = dimensions) -> None:
1168 """Construct a DataCoordinateQueryResults object from the query
1169 with the given arguments and run a battery of tests on it.
1171 Parameters
1172 ----------
1173 *args
1174 Forwarded to `Query.data_ids`.
1175 query : `Query`
1176 Query to start from.
1177 dimensions : `DimensionGroup`, optional
1178 Dimensions the result data IDs should have.
1179 """
1180 with self.assertRaises(_TestQueryExecution) as cm:
1181 list(query.data_ids(*args))
1182 self.assertEqual(
1183 cm.exception.result_spec,
1184 qrs.DataCoordinateResultSpec(dimensions=dimensions),
1185 )
1186 check_query_tree(cm.exception.tree, dimensions=dimensions)
1188 def check_dataset_results(
1189 *args: Any,
1190 query: Query,
1191 find_first: bool = True,
1192 storage_class_name: str = self.raw.storageClass_name,
1193 ) -> None:
1194 """Construct a DatasetRefQueryResults object from the query
1195 with the given arguments and run a battery of tests on it.
1197 Parameters
1198 ----------
1199 *args
1200 Forwarded to `Query.datasets`.
1201 query : `Query`
1202 Query to start from.
1203 find_first : `bool`, optional
1204 Whether to do find-first resolution on the results.
1205 storage_class_name : `str`, optional
1206 Expected name of the storage class for the results.
1207 """
1208 with self.assertRaises(_TestQueryExecution) as cm:
1209 list(query.datasets(*args, find_first=find_first))
1210 self.assertEqual(
1211 cm.exception.result_spec,
1212 qrs.DatasetRefResultSpec(
1213 dataset_type_name="raw",
1214 dimensions=self.raw.dimensions.as_group(),
1215 storage_class_name=storage_class_name,
1216 find_first=find_first,
1217 ),
1218 )
1219 check_query_tree(cm.exception.tree)
1221 def check_materialization(
1222 kwargs: Mapping[str, Any],
1223 query: Query,
1224 dimensions: DimensionGroup = dimensions,
1225 has_dataset: bool = True,
1226 ) -> None:
1227 """Materialize the query with the given arguments and run a
1228 battery of tests on the result.
1230 Parameters
1231 ----------
1232 kwargs
1233 Forwarded as keyword arguments to `Query.materialize`.
1234 query : `Query`
1235 Query to start from.
1236 dimensions : `DimensionGroup`, optional
1237 Dimensions to expect in the materialization and its derived
1238 query.
1239 has_dataset : `bool`, optional
1240 Whether the query backed by the materialization should
1241 still have the test 'raw' dataset joined in.
1242 """
1243 # Materialize the query and check the query tree sent to the
1244 # driver and the one in the materialized query.
1245 with self.assertRaises(_TestQueryExecution) as cm:
1246 list(query.materialize(**kwargs).data_ids())
1247 derived_tree = cm.exception.tree
1248 self.assertEqual(derived_tree.dimensions, dimensions)
1249 # Predicate should be materialized away; it no longer appears
1250 # in the derived query.
1251 self.assertEqual(str(derived_tree.predicate), "True")
1252 self.assertFalse(derived_tree.data_coordinate_uploads)
1253 if has_dataset:
1254 # Dataset search is still there, even though its existence
1255 # constraint is included in the materialization, because we
1256 # might need to re-join for some result columns in a
1257 # derived query.
1258 self.assertEqual(derived_tree.datasets.keys(), {"raw"})
1259 self.assertEqual(derived_tree.datasets["raw"].dimensions, self.raw.dimensions.as_group())
1260 self.assertEqual(derived_tree.datasets["raw"].collections, ("DummyCam/defaults",))
1261 else:
1262 self.assertFalse(derived_tree.datasets)
1263 ((key, derived_tree_materialized_dimensions),) = derived_tree.materializations.items()
1264 self.assertEqual(derived_tree_materialized_dimensions, dimensions)
1265 (
1266 materialized_tree,
1267 materialized_dimensions,
1268 materialized_datasets,
1269 ) = cm.exception.driver.materializations[key]
1270 self.assertEqual(derived_tree_materialized_dimensions, materialized_dimensions)
1271 if has_dataset:
1272 self.assertEqual(materialized_datasets, {"raw"})
1273 else:
1274 self.assertFalse(materialized_datasets)
1275 check_query_tree(materialized_tree)
1277 # Actual logic for the check() function begins here.
1279 self.assertEqual(query.constraint_dataset_types, {"raw"})
1280 self.assertEqual(query.constraint_dimensions, self.raw.dimensions.as_group())
1282 # Adding a constraint on a field for this dataset type should work
1283 # (this constraint will be present in all downstream tests).
1284 query = query.where(query.expression_factory["raw"].run == "DummyCam/raw/all")
1285 with self.assertRaises(InvalidQueryError):
1286 # Adding constraint on a different dataset should not work.
1287 query.where(query.expression_factory["refcat"].run == "refcats")
1289 # Data IDs, with dimensions defaulted.
1290 check_data_id_results(query=query)
1291 # Dimensions for data IDs the same as defaults.
1292 check_data_id_results(["exposure", "detector"], query=query)
1293 # Dimensions are a subset of the query dimensions.
1294 check_data_id_results(["exposure"], query=query, dimensions=self.universe.conform(["exposure"]))
1295 # Dimensions are a superset of the query dimensions.
1296 check_data_id_results(
1297 ["exposure", "detector", "visit"],
1298 query=query,
1299 dimensions=self.universe.conform(["exposure", "detector", "visit"]),
1300 )
1301 # Dimensions are neither a superset nor a subset of the query
1302 # dimensions.
1303 check_data_id_results(
1304 ["detector", "visit"], query=query, dimensions=self.universe.conform(["visit", "detector"])
1305 )
1306 # Dimensions are empty.
1307 check_data_id_results([], query=query, dimensions=self.universe.conform([]))
1309 # Get DatasetRef results, with various arguments and defaulting.
1310 check_dataset_results("raw", query=query)
1311 check_dataset_results("raw", query=query, find_first=True)
1312 check_dataset_results("raw", ["DummyCam/defaults"], query=query)
1313 check_dataset_results("raw", ["DummyCam/defaults"], query=query, find_first=True)
1314 check_dataset_results(self.raw, query=query)
1315 check_dataset_results(self.raw, query=query, find_first=True)
1316 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query)
1317 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query, find_first=True)
1319 # Changing collections at this stage is not allowed.
1320 with self.assertRaises(InvalidQueryError):
1321 query.datasets("raw", collections=["DummyCam/calib"])
1323 # Changing storage classes is allowed, if they're compatible.
1324 check_dataset_results(
1325 self.raw.overrideStorageClass("ArrowNumpy"), query=query, storage_class_name="ArrowNumpy"
1326 )
1327 with self.assertRaises(DatasetTypeError):
1328 # Can't use overrideStorageClass, because it'll raise
1329 # before the code we want to test can.
1330 query.datasets(DatasetType("raw", self.raw.dimensions, "int"))
1332 # Check the 'any' and 'explain_no_results' methods on Query itself.
1333 for execute, exact in itertools.permutations([False, True], 2):
1334 with self.assertRaises(_TestQueryAny) as cm:
1335 query.any(execute=execute, exact=exact)
1336 self.assertEqual(cm.exception.execute, execute)
1337 self.assertEqual(cm.exception.exact, exact)
1338 check_query_tree(cm.exception.tree, dimensions)
1339 with self.assertRaises(_TestQueryExplainNoResults) as cm:
1340 query.explain_no_results()
1341 check_query_tree(cm.exception.tree, dimensions)
1343 # Materialize the query with defaults.
1344 check_materialization({}, query=query)
1345 # Materialize the query with args that match defaults.
1346 check_materialization({"dimensions": ["exposure", "detector"], "datasets": {"raw"}}, query=query)
1347 # Materialize the query with a superset of the original dimensions.
1348 check_materialization(
1349 {"dimensions": ["exposure", "detector", "visit"]},
1350 query=query,
1351 dimensions=self.universe.conform(["exposure", "visit", "detector"]),
1352 )
1353 # Materialize the query with no datasets.
1354 check_materialization(
1355 {"dimensions": ["exposure", "detector"], "datasets": frozenset()},
1356 query=query,
1357 has_dataset=False,
1358 )
1359 # Materialize the query with no datasets and a subset of the
1360 # dimensions.
1361 check_materialization(
1362 {"dimensions": ["exposure"], "datasets": frozenset()},
1363 query=query,
1364 has_dataset=False,
1365 dimensions=self.universe.conform(["exposure"]),
1366 )
1367 # Materializing the query with a dataset that is not in the query
1368 # is an error.
1369 with self.assertRaises(InvalidQueryError):
1370 query.materialize(datasets={"refcat"})
1371 # Materializing the query with dimensions that are not a superset
1372 # of any materialized dataset dimensions is an error.
1373 with self.assertRaises(InvalidQueryError):
1374 query.materialize(dimensions=["exposure"], datasets={"raw"})
1376 # Actual logic for test_dataset_joins starts here.
1378 # Default collections and existing dataset type name.
1379 check(self.query().join_dataset_search("raw"))
1380 # Default collections and existing DatasetType instance.
1381 check(self.query().join_dataset_search(self.raw))
1382 # Manual collections and existing dataset type.
1383 check(
1384 self.query(default_collections=None).join_dataset_search("raw", collections=["DummyCam/defaults"])
1385 )
1386 check(
1387 self.query(default_collections=None).join_dataset_search(
1388 self.raw, collections=["DummyCam/defaults"]
1389 )
1390 )
1391 with self.assertRaises(MissingDatasetTypeError):
1392 # Dataset type does not exist.
1393 self.query(dataset_types={}).join_dataset_search("raw", collections=["DummyCam/raw/all"])
1394 with self.assertRaises(DatasetTypeError):
1395 # Dataset type object with bad dimensions passed.
1396 self.query().join_dataset_search(
1397 DatasetType(
1398 "raw",
1399 dimensions={"detector", "visit"},
1400 storageClass=self.raw.storageClass_name,
1401 universe=self.universe,
1402 )
1403 )
1404 with self.assertRaises(TypeError):
1405 # Bad type for dataset type argument.
1406 self.query().join_dataset_search(3)
1407 with self.assertRaises(InvalidQueryError):
1408 # Cannot pass a storage class override to join_dataset_search,
1409 # because that method has no way to use it.
1410 self.query().join_dataset_search(self.raw.overrideStorageClass("ArrowAstropy"))
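
# Editorial sketch (not part of the original tests; never collected by the
# test runner because its name does not start with "test_"): a compact
# illustration of the join_dataset_search patterns exercised above, using the
# same mock-backed self.query() helper and self.raw dataset type.
def _sketch_join_dataset_search(self) -> None:
    # With default collections configured, only the dataset type is needed;
    # either the registered name or a DatasetType instance is accepted.
    query = self.query().join_dataset_search("raw")
    query = self.query().join_dataset_search(self.raw)
    # Without defaults, the collections to search must be given explicitly.
    query = self.query(default_collections=None).join_dataset_search(
        "raw", collections=["DummyCam/defaults"]
    )
    # The joined dataset search then participates in the query's constraints.
    assert query.constraint_dataset_types == {"raw"}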
1412 def test_dimension_record_results(self) -> None:
1413 """Test queries that return dimension records.
1415 This includes tests for:
1417 - joining against uploaded data coordinates;
1418 - counting result rows;
1419 - expanding dimensions as needed for 'where' conditions;
1420 - order_by and limit.
1422 It does not include the iteration methods of
1423 DimensionRecordQueryResults, since those require a different mock
1424 driver setup (see test_dimension_record_iteration).
1425 """
1426 # Set up the base query-results object to test.
1427 query = self.query()
1428 x = query.expression_factory
1429 self.assertFalse(query.constraint_dimensions)
1430 query = query.where(x.skymap == "m")
1431 self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap"]))
1432 upload_rows = [
1433 DataCoordinate.standardize(instrument="DummyCam", visit=3, universe=self.universe),
1434 DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe),
1435 ]
1436 raw_rows = frozenset([data_id.required_values for data_id in upload_rows])
1437 query = query.join_data_coordinates(upload_rows)
1438 self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap", "visit"]))
1439 results = query.dimension_records("patch")
1440 results = results.where(x.tract == 4)
1442 # Define a closure to run tests on variants of the base query.
1443 def check(
1444 results: DimensionRecordQueryResults,
1445 order_by: Any = (),
1446 limit: int | None = None,
1447 ) -> list[str]:
1448 results = results.order_by(*order_by).limit(limit)
1449 self.assertEqual(results.element.name, "patch")
1450 with self.assertRaises(_TestQueryExecution) as cm:
1451 list(results)
1452 tree = cm.exception.tree
1453 self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4")
1454 self.assertEqual(tree.dimensions, self.universe.conform(["visit", "patch"]))
1455 self.assertFalse(tree.materializations)
1456 self.assertFalse(tree.datasets)
1457 ((key, upload_dimensions),) = tree.data_coordinate_uploads.items()
1458 self.assertEqual(upload_dimensions, self.universe.conform(["visit"]))
1459 self.assertEqual(cm.exception.driver.data_coordinate_uploads[key], (upload_dimensions, raw_rows))
1460 result_spec = cm.exception.result_spec
1461 self.assertEqual(result_spec.result_type, "dimension_record")
1462 self.assertEqual(result_spec.element, self.universe["patch"])
1463 self.assertEqual(result_spec.limit, limit)
1464 for exact, discard in itertools.permutations([False, True], r=2):
1465 with self.assertRaises(_TestQueryCount) as cm:
1466 results.count(exact=exact, discard=discard)
1467 self.assertEqual(cm.exception.result_spec, result_spec)
1468 self.assertEqual(cm.exception.exact, exact)
1469 self.assertEqual(cm.exception.discard, discard)
1470 return [str(term) for term in result_spec.order_by]
1472 # Run the closure's tests on variants of the base query.
1473 self.assertEqual(check(results), [])
1474 self.assertEqual(check(results, limit=2), [])
1475 self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"])
1476 self.assertEqual(
1477 check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]),
1478 ["patch.cell_x", "patch.cell_y DESC"],
1479 )
1480 with self.assertRaises(InvalidQueryError):
1481 # Cannot upload empty list of data IDs.
1482 query.join_data_coordinates([])
1483 with self.assertRaises(InvalidQueryError):
1484 # Cannot upload heterogeneous list of data IDs.
1485 query.join_data_coordinates(
1486 [
1487 DataCoordinate.make_empty(self.universe),
1488 DataCoordinate.standardize(instrument="DummyCam", universe=self.universe),
1489 ]
1490 )
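
# Editorial sketch (not collected by the test runner): the dimension-record
# query pattern exercised by the test above, written as straight-line usage of
# the public interface with the same mock-backed helpers.
def _sketch_dimension_record_query(self) -> None:
    query = self.query()
    x = query.expression_factory
    # Constrain by a dimension value and by uploaded data coordinates.
    query = query.where(x.skymap == "m")
    query = query.join_data_coordinates(
        [
            DataCoordinate.standardize(instrument="DummyCam", visit=3, universe=self.universe),
            DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe),
        ]
    )
    # Ask for patch records, then refine, order, and limit; nothing reaches
    # the driver until the results are iterated or counted.
    results = query.dimension_records("patch").where(x.tract == 4)
    results = results.order_by(x.patch.cell_x, x.patch.cell_y.desc).limit(10)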
1492 def test_dimension_record_iteration(self) -> None:
1493 """Tests for DimensionRecordQueryResult iteration."""
1495 def make_record(n: int) -> DimensionRecord:
1496 return self.universe["patch"].RecordClass(skymap="m", tract=4, patch=n)
1498 result_rows = (
1499 [make_record(n) for n in range(3)],
1500 [make_record(n) for n in range(3, 6)],
1501 [make_record(10)],
1502 )
1503 results = self.query(result_rows=result_rows).dimension_records("patch")
1504 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1505 self.assertEqual(
1506 list(results.iter_set_pages()),
1507 [DimensionRecordSet(self.universe["patch"], rows) for rows in result_rows],
1508 )
1509 self.assertEqual(
1510 [table.column("id").to_pylist() for table in results.iter_table_pages()],
1511 [list(range(3)), list(range(3, 6)), [10]],
1512 )
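
# Editorial sketch (not collected by the test runner): how the paged iteration
# methods shown above can be consumed; the record construction mirrors
# make_record in the test.
def _sketch_dimension_record_pages(self) -> None:
    record = self.universe["patch"].RecordClass(skymap="m", tract=4, patch=0)
    results = self.query(result_rows=([record],)).dimension_records("patch")
    # Plain iteration yields individual DimensionRecord objects across pages;
    # iter_set_pages yields one DimensionRecordSet per page, and
    # iter_table_pages yields one Arrow table per page.
    records = list(results)
    pages = list(results.iter_set_pages())
    tables = list(results.iter_table_pages())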
1514 def test_data_coordinate_results(self) -> None:
1515 """Test queries that return data coordinates.
1517 This includes tests for:
1519 - counting result rows;
1520 - expanding dimensions as needed for 'where' conditions;
1521 - order_by and limit.
1523 It does not include the iteration methods of
1524 DataCoordinateQueryResults, since those require a different mock
1525 driver setup (see test_data_coordinate_iteration). More tests for
1526 different inputs to DataCoordinateQueryResults construction are in
1527 test_dataset_join.
1528 """
1529 # Set up the base query-results object to test.
1530 query = self.query()
1531 x = query.expression_factory
1532 self.assertFalse(query.constraint_dimensions)
1533 query = query.where(x.skymap == "m")
1534 results = query.data_ids(["patch", "band"])
1535 results = results.where(x.tract == 4)
1537 # Define a closure to run tests on variants of the base query.
1538 def check(
1539 results: DataCoordinateQueryResults,
1540 order_by: Any = (),
1541 limit: int | None = None,
1542 include_dimension_records: bool = False,
1543 ) -> list[str]:
1544 results = results.order_by(*order_by).limit(limit)
1545 self.assertEqual(results.dimensions, self.universe.conform(["patch", "band"]))
1546 with self.assertRaises(_TestQueryExecution) as cm:
1547 list(results)
1548 tree = cm.exception.tree
1549 self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4")
1550 self.assertEqual(tree.dimensions, self.universe.conform(["patch", "band"]))
1551 self.assertFalse(tree.materializations)
1552 self.assertFalse(tree.datasets)
1553 self.assertFalse(tree.data_coordinate_uploads)
1554 result_spec = cm.exception.result_spec
1555 self.assertEqual(result_spec.result_type, "data_coordinate")
1556 self.assertEqual(result_spec.dimensions, self.universe.conform(["patch", "band"]))
1557 self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
1558 self.assertEqual(result_spec.limit, limit)
1559 self.assertIsNone(result_spec.find_first_dataset)
1560 for exact, discard in itertools.permutations([False, True], r=2):
1561 with self.assertRaises(_TestQueryCount) as cm:
1562 results.count(exact=exact, discard=discard)
1563 self.assertEqual(cm.exception.result_spec, result_spec)
1564 self.assertEqual(cm.exception.exact, exact)
1565 self.assertEqual(cm.exception.discard, discard)
1566 return [str(term) for term in result_spec.order_by]
1568 # Run the closure's tests on variants of the base query.
1569 self.assertEqual(check(results), [])
1570 self.assertEqual(check(results.with_dimension_records(), include_dimension_records=True), [])
1571 self.assertEqual(
1572 check(results.with_dimension_records().with_dimension_records(), include_dimension_records=True),
1573 [],
1574 )
1575 self.assertEqual(check(results, limit=2), [])
1576 self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"])
1577 self.assertEqual(
1578 check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]),
1579 ["patch.cell_x", "patch.cell_y DESC"],
1580 )
1581 self.assertEqual(
1582 check(results, order_by=["patch.cell_x", "-cell_y"]),
1583 ["patch.cell_x", "patch.cell_y DESC"],
1584 )
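
# Editorial sketch (not collected by the test runner): the two equivalent ways
# of spelling order_by terms exercised above, using the same mock-backed
# helpers.
def _sketch_order_by_spellings(self) -> None:
    query = self.query()
    x = query.expression_factory
    results = query.data_ids(["patch", "band"])
    # ExpressionFactory proxies and plain strings resolve to the same terms;
    # a leading "-" (or the .desc property) requests descending order.
    by_proxy = results.order_by(x.patch.cell_x, x.patch.cell_y.desc)
    by_string = results.order_by("patch.cell_x", "-cell_y")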
1586 def test_data_coordinate_iteration(self) -> None:
1587 """Tests for DataCoordinateQueryResult iteration."""
1589 def make_data_id(n: int) -> DataCoordinate:
1590 return DataCoordinate.standardize(skymap="m", tract=4, patch=n, universe=self.universe)
1592 result_rows = (
1593 [make_data_id(n) for n in range(3)],
1594 [make_data_id(n) for n in range(3, 6)],
1595 [make_data_id(10)],
1596 )
1597 results = self.query(result_rows=result_rows).data_ids(["patch"])
1598 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1600 def test_dataset_results(self) -> None:
1601 """Test queries that return dataset refs.
1603 This includes tests for:
1605 - counting result rows;
1606 - expanding dimensions as needed for 'where' conditions;
1607 - different ways of passing a data ID to 'where' methods;
1608 - order_by and limit.
1610 It does not include the iteration methods of the DatasetRefQueryResults
1611 classes, since those require a different mock driver setup (see
1612 test_dataset_iteration). More tests for different inputs to
1613 DatasetRefQueryResults construction are in test_dataset_join.
1614 """
1615 # Set up a few equivalent base query-results objects to test.
1616 query = self.query()
1617 x = query.expression_factory
1618 self.assertFalse(query.constraint_dimensions)
1619 results1 = query.datasets("raw").where(x.instrument == "DummyCam", visit=4)
1620 results2 = query.datasets("raw", collections=["DummyCam/defaults"]).where(
1621 {"instrument": "DummyCam", "visit": 4}
1622 )
1623 results3 = query.datasets("raw").where(
1624 DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe)
1625 )
1627 # Define a closure to check a DatasetRefQueryResults instance.
1628 def check(
1629 results: DatasetRefQueryResults,
1630 order_by: Any = (),
1631 limit: int | None = None,
1632 include_dimension_records: bool = False,
1633 ) -> list[str]:
1634 results = results.order_by(*order_by).limit(limit)
1635 with self.assertRaises(_TestQueryExecution) as cm:
1636 list(results)
1637 tree = cm.exception.tree
1638 self.assertEqual(str(tree.predicate), "instrument == 'DummyCam' AND visit == 4")
1639 self.assertEqual(
1640 tree.dimensions,
1641 self.universe.conform(["visit"]).union(results.dataset_type.dimensions.as_group()),
1642 )
1643 self.assertFalse(tree.materializations)
1644 self.assertEqual(tree.datasets.keys(), {results.dataset_type.name})
1645 self.assertEqual(tree.datasets[results.dataset_type.name].collections, ("DummyCam/defaults",))
1646 self.assertEqual(
1647 tree.datasets[results.dataset_type.name].dimensions,
1648 results.dataset_type.dimensions.as_group(),
1649 )
1650 self.assertFalse(tree.data_coordinate_uploads)
1651 result_spec = cm.exception.result_spec
1652 self.assertEqual(result_spec.result_type, "dataset_ref")
1653 self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
1654 self.assertEqual(result_spec.limit, limit)
1655 self.assertEqual(result_spec.find_first_dataset, result_spec.dataset_type_name)
1656 for exact, discard in itertools.permutations([False, True], r=2):
1657 with self.assertRaises(_TestQueryCount) as cm:
1658 results.count(exact=exact, discard=discard)
1659 self.assertEqual(cm.exception.result_spec, result_spec)
1660 self.assertEqual(cm.exception.exact, exact)
1661 self.assertEqual(cm.exception.discard, discard)
1662 with self.assertRaises(_TestQueryExecution) as cm:
1663 list(results.data_ids)
1664 self.assertEqual(
1665 cm.exception.result_spec,
1666 qrs.DataCoordinateResultSpec(
1667 dimensions=results.dataset_type.dimensions.as_group(),
1668 include_dimension_records=include_dimension_records,
1669 ),
1670 )
1671 self.assertIs(cm.exception.tree, tree)
1672 return [str(term) for term in result_spec.order_by]
1674 # Run the closure's tests on variants of the base query.
1675 self.assertEqual(check(results1), [])
1676 self.assertEqual(check(results2), [])
1677 self.assertEqual(check(results3), [])
1678 self.assertEqual(check(results1.with_dimension_records(), include_dimension_records=True), [])
1679 self.assertEqual(
1680 check(results1.with_dimension_records().with_dimension_records(), include_dimension_records=True),
1681 [],
1682 )
1683 self.assertEqual(check(results1, limit=2), [])
1684 self.assertEqual(check(results1, order_by=["raw.timespan.begin"]), ["raw.timespan.begin"])
1685 self.assertEqual(check(results1, order_by=["detector"]), ["detector"])
1686 self.assertEqual(check(results1, order_by=["ingest_date"]), ["raw.ingest_date"])
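
# Editorial sketch (not collected by the test runner): the three equivalent
# ways of passing a data ID constraint to a dataset query, as exercised by
# the test above.
def _sketch_dataset_where_variants(self) -> None:
    query = self.query()
    x = query.expression_factory
    # Keyword arguments, a mapping, and a DataCoordinate all express the same
    # constraint.
    r1 = query.datasets("raw").where(x.instrument == "DummyCam", visit=4)
    r2 = query.datasets("raw").where({"instrument": "DummyCam", "visit": 4})
    r3 = query.datasets("raw").where(
        DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe)
    )
    # A data-ID view over the same rows is available from the dataset results.
    data_ids = r1.data_ids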
1688 def test_dataset_iteration(self) -> None:
1689 """Tests for SingleTypeDatasetQueryResult iteration."""
1691 def make_ref(n: int) -> DatasetRef:
1692 return DatasetRef(
1693 self.raw,
1694 DataCoordinate.standardize(
1695 instrument="DummyCam", exposure=4, detector=n, universe=self.universe
1696 ),
1697 run="DummyCam/raw/all",
1698 id=uuid.uuid4(),
1699 )
1701 result_rows = (
1702 [make_ref(n) for n in range(3)],
1703 [make_ref(n) for n in range(3, 6)],
1704 [make_ref(10)],
1705 )
1706 results = self.query(result_rows=result_rows).datasets("raw")
1707 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1709 def test_identifiers(self) -> None:
1710 """Test edge-cases of identifiers in order_by expressions."""
1712 def extract_order_by(results: DataCoordinateQueryResults) -> list[str]:
1713 with self.assertRaises(_TestQueryExecution) as cm:
1714 list(results)
1715 return [str(term) for term in cm.exception.result_spec.order_by]
1717 self.assertEqual(
1718 extract_order_by(self.query().data_ids(["day_obs"]).order_by("-timespan.begin")),
1719 ["day_obs.timespan.begin DESC"],
1720 )
1721 self.assertEqual(
1722 extract_order_by(self.query().data_ids(["day_obs"]).order_by("timespan.end")),
1723 ["day_obs.timespan.end"],
1724 )
1725 self.assertEqual(
1726 extract_order_by(self.query().data_ids(["visit"]).order_by("-visit.timespan.begin")),
1727 ["visit.timespan.begin DESC"],
1728 )
1729 self.assertEqual(
1730 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.timespan.end")),
1731 ["visit.timespan.end"],
1732 )
1733 self.assertEqual(
1734 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.science_program")),
1735 ["visit.science_program"],
1736 )
1737 self.assertEqual(
1738 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.id")),
1739 ["visit"],
1740 )
1741 self.assertEqual(
1742 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.physical_filter")),
1743 ["physical_filter"],
1744 )
1745 with self.assertRaises(TypeError):
1746 self.query().data_ids(["visit"]).order_by(3)
1747 with self.assertRaises(InvalidQueryError):
1748 self.query().data_ids(["visit"]).order_by("visit.region")
1749 with self.assertRaisesRegex(InvalidQueryError, "Ambiguous"):
1750 self.query().data_ids(["visit", "exposure"]).order_by("timespan.begin")
1751 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"):
1752 self.query().data_ids(["visit", "exposure"]).order_by("blarg")
1753 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"):
1754 self.query().data_ids(["visit", "exposure"]).order_by("visit.horse")
1755 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"):
1756 self.query().data_ids(["visit", "exposure"]).order_by("visit.science_program.monkey")
1757 with self.assertRaisesRegex(InvalidQueryError, "not valid for datasets"):
1758 self.query().datasets("raw").order_by("raw.seq_num")
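
# Editorial sketch (not collected by the test runner): how order_by
# identifiers resolve, as exercised above.
def _sketch_order_by_identifiers(self) -> None:
    # A bare field name is accepted when exactly one dimension element in the
    # query has that field; otherwise it must be qualified with the element
    # name, and unknown identifiers raise InvalidQueryError.
    self.query().data_ids(["day_obs"]).order_by("-timespan.begin")
    self.query().data_ids(["visit"]).order_by("visit.timespan.begin")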
1760 def test_invalid_models(self) -> None:
1761 """Test invalid models and combinations of models that cannot be
1762 constructed via the public Query and *QueryResults interfaces.
1763 """
1764 x = ExpressionFactory(self.universe)
1765 with self.assertRaises(InvalidQueryError):
1766 # QueryTree dimensions do not cover dataset dimensions.
1767 qt.QueryTree(
1768 dimensions=self.universe.conform(["visit"]),
1769 datasets={
1770 "raw": qt.DatasetSearch(
1771 collections=("DummyCam/raw/all",),
1772 dimensions=self.raw.dimensions.as_group(),
1773 )
1774 },
1775 )
1776 with self.assertRaises(InvalidQueryError):
1777 # QueryTree dimensions do not cover predicate dimensions.
1778 qt.QueryTree(
1779 dimensions=self.universe.conform(["visit"]),
1780 predicate=(x.detector > 5),
1781 )
1782 with self.assertRaises(InvalidQueryError):
1783 # Predicate references a dataset not in the QueryTree.
1784 qt.QueryTree(
1785 dimensions=self.universe.conform(["exposure", "detector"]),
1786 predicate=(x["raw"].collection == "bird"),
1787 )
1788 with self.assertRaises(InvalidQueryError):
1789 # ResultSpec's dimensions are not a subset of the query tree's.
1790 DimensionRecordQueryResults(
1791 _TestQueryDriver(),
1792 qt.QueryTree(dimensions=self.universe.conform(["tract"])),
1793 qrs.DimensionRecordResultSpec(element=self.universe["detector"]),
1794 )
1795 with self.assertRaises(InvalidQueryError):
1796 # ResultSpec's datasets are not a subset of the query tree's.
1797 DatasetRefQueryResults(
1798 _TestQueryDriver(),
1799 qt.QueryTree(dimensions=self.raw.dimensions.as_group()),
1800 qrs.DatasetRefResultSpec(
1801 dataset_type_name="raw",
1802 dimensions=self.raw.dimensions.as_group(),
1803 storage_class_name=self.raw.storageClass_name,
1804 find_first=True,
1805 ),
1806 )
1807 with self.assertRaises(InvalidQueryError):
1808 # ResultSpec's order_by expression is not related to the dimensions
1809 # we're returning.
1810 x = ExpressionFactory(self.universe)
1811 DimensionRecordQueryResults(
1812 _TestQueryDriver(),
1813 qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])),
1814 qrs.DimensionRecordResultSpec(
1815 element=self.universe["detector"], order_by=(x.unwrap(x.visit),)
1816 ),
1817 )
1818 with self.assertRaises(InvalidQueryError):
1819 # ResultSpec's order_by expression is not related to the datasets
1820 # we're returning.
1821 x = ExpressionFactory(self.universe)
1822 DimensionRecordQueryResults(
1823 _TestQueryDriver(),
1824 qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])),
1825 qrs.DimensionRecordResultSpec(
1826 element=self.universe["detector"], order_by=(x.unwrap(x["raw"].ingest_date),)
1827 ),
1828 )
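
# Editorial sketch (not collected by the test runner): a valid counterpart to
# the invalid model constructions above, built directly from the Pydantic
# models with dimensions that cover both the dataset search and the predicate.
def _sketch_valid_models(self) -> None:
    x = ExpressionFactory(self.universe)
    tree = qt.QueryTree(
        dimensions=self.raw.dimensions.as_group(),
        datasets={
            "raw": qt.DatasetSearch(
                collections=("DummyCam/raw/all",),
                dimensions=self.raw.dimensions.as_group(),
            )
        },
        predicate=(x.detector > 5),
    )
    # A result spec whose dimensions are a subset of the tree's is accepted.
    DimensionRecordQueryResults(
        _TestQueryDriver(),
        qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])),
        qrs.DimensionRecordResultSpec(element=self.universe["detector"]),
    )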
1830 def test_general_result_spec(self) -> None:
1831 """Tests for GeneralResultSpec.
1833 Unlike the other ResultSpec objects, we don't have a *QueryResults
1834 class for GeneralResultSpec yet, so we can't use the higher-level
1835 interfaces to test it like we can the others.
1836 """
1837 a = qrs.GeneralResultSpec(
1838 dimensions=self.universe.conform(["detector"]),
1839 dimension_fields={"detector": {"purpose"}},
1840 dataset_fields={},
1841 find_first=False,
1842 )
1843 self.assertEqual(a.find_first_dataset, None)
1844 a_columns = qt.ColumnSet(self.universe.conform(["detector"]))
1845 a_columns.dimension_fields["detector"].add("purpose")
1846 self.assertEqual(a.get_result_columns(), a_columns)
1847 b = qrs.GeneralResultSpec(
1848 dimensions=self.universe.conform(["detector"]),
1849 dimension_fields={},
1850 dataset_fields={"bias": {"timespan", "dataset_id"}},
1851 find_first=True,
1852 )
1853 self.assertEqual(b.find_first_dataset, "bias")
1854 b_columns = qt.ColumnSet(self.universe.conform(["detector"]))
1855 b_columns.dataset_fields["bias"].add("timespan")
1856 b_columns.dataset_fields["bias"].add("dataset_id")
1857 self.assertEqual(b.get_result_columns(), b_columns)
1858 with self.assertRaises(InvalidQueryError):
1859 # More than one dataset type with find_first.
1860 qrs.GeneralResultSpec(
1861 dimensions=self.universe.conform(["detector", "exposure"]),
1862 dimension_fields={},
1863 dataset_fields={"bias": {"dataset_id"}, "raw": {"dataset_id"}},
1864 find_first=True,
1865 )
1866 with self.assertRaises(InvalidQueryError):
1867 # Out-of-bounds dimension fields.
1868 qrs.GeneralResultSpec(
1869 dimensions=self.universe.conform(["detector"]),
1870 dimension_fields={"visit": {"name"}},
1871 dataset_fields={},
1872 find_first=False,
1873 )
1874 with self.assertRaises(InvalidQueryError):
1875 # No fields for dimension element.
1876 qrs.GeneralResultSpec(
1877 dimensions=self.universe.conform(["detector"]),
1878 dimension_fields={"detector": set()},
1879 dataset_fields={},
1880 find_first=True,
1881 )
1882 with self.assertRaises(InvalidQueryError):
1883 # No fields for dataset.
1884 qrs.GeneralResultSpec(
1885 dimensions=self.universe.conform(["detector"]),
1886 dimension_fields={},
1887 dataset_fields={"bias": set()},
1888 find_first=True,
1889 )
1892if __name__ == "__main__":
1893 unittest.main()