Coverage for tests / test_query_interface.py: 11%
954 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 08:36 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Tests for the public Butler.query interface and the Pydantic models that
29back it using a mock column-expression visitor and a mock QueryDriver
30implementation.
32These tests are entirely independent of which kind of butler or database
33backend we're using.
35This is a very large test file because a lot of tests make use of those mocks,
36but they're not so generally useful that I think they're worth putting in the
37library proper.
38"""
40from __future__ import annotations
42import dataclasses
43import itertools
44import unittest
45import uuid
46from collections.abc import Iterable, Iterator, Mapping, Set
47from typing import Any
49import astropy.table
50import astropy.time
51import numpy as np
53from lsst.daf.butler import (
54 CollectionType,
55 DataCoordinate,
56 DataIdValue,
57 DatasetRef,
58 DatasetType,
59 DimensionGroup,
60 DimensionRecord,
61 DimensionRecordSet,
62 DimensionUniverse,
63 InvalidQueryError,
64 MissingDatasetTypeError,
65 NamedValueSet,
66 NoDefaultCollectionError,
67 Timespan,
68)
69from lsst.daf.butler.queries import (
70 DataCoordinateQueryResults,
71 DatasetRefQueryResults,
72 DimensionRecordQueryResults,
73 Query,
74)
75from lsst.daf.butler.queries import driver as qd
76from lsst.daf.butler.queries import result_specs as qrs
77from lsst.daf.butler.queries import tree as qt
78from lsst.daf.butler.queries.expression_factory import ExpressionFactory
79from lsst.daf.butler.queries.tree._column_expression import UnaryExpression
80from lsst.daf.butler.queries.tree._predicate import PredicateLeaf, PredicateOperands
81from lsst.daf.butler.queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor
82from lsst.daf.butler.registry import CollectionSummary, DatasetTypeError
83from lsst.daf.butler.registry.interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord
84from lsst.sphgeom import DISJOINT, Mq3cPixelization
87class _TestVisitor(PredicateVisitor[bool, bool, bool], ColumnExpressionVisitor[Any]):
88 """Test visitor for column expressions.
90 This visitor evaluates column expressions using regular Python logic.
92 Parameters
93 ----------
94 dimension_keys : `~collections.abc.Mapping`, optional
95 Mapping from dimension name to the value it should be assigned by the
96 visitor.
97 dimension_fields : `~collections.abc.Mapping`, optional
98 Mapping from ``(dimension element name, field)`` tuple to the value it
99 should be assigned by the visitor.
100 dataset_fields : `~collections.abc.Mapping`, optional
101 Mapping from ``(dataset type name, field)`` tuple to the value it
102 should be assigned by the visitor.
103 query_tree_items : `~collections.abc.Set`, optional
104 Set that should be used as the right-hand side of element-in-query
105 predicates.
106 """
108 def __init__(
109 self,
110 dimension_keys: Mapping[str, Any] | None = None,
111 dimension_fields: Mapping[tuple[str, str], Any] | None = None,
112 dataset_fields: Mapping[tuple[str, str], Any] | None = None,
113 query_tree_items: Set[Any] = frozenset(),
114 ):
115 self.dimension_keys = dimension_keys or {}
116 self.dimension_fields = dimension_fields or {}
117 self.dataset_fields = dataset_fields or {}
118 self.query_tree_items = query_tree_items
120 def visit_binary_expression(self, expression: qt.BinaryExpression) -> Any:
121 match expression.operator:
122 case "+":
123 return expression.a.visit(self) + expression.b.visit(self)
124 case "-":
125 return expression.a.visit(self) - expression.b.visit(self)
126 case "*":
127 return expression.a.visit(self) * expression.b.visit(self)
128 case "/":
129 match expression.column_type:
130 case "int":
131 return expression.a.visit(self) // expression.b.visit(self)
132 case "float":
133 return expression.a.visit(self) / expression.b.visit(self)
134 case "%":
135 return expression.a.visit(self) % expression.b.visit(self)
137 def visit_comparison(
138 self,
139 a: qt.ColumnExpression,
140 operator: qt.ComparisonOperator,
141 b: qt.ColumnExpression,
142 flags: PredicateVisitFlags,
143 ) -> bool:
144 match operator:
145 case "==":
146 return a.visit(self) == b.visit(self)
147 case "!=":
148 return a.visit(self) != b.visit(self)
149 case "<":
150 return a.visit(self) < b.visit(self)
151 case ">":
152 return a.visit(self) > b.visit(self)
153 case "<=":
154 return a.visit(self) <= b.visit(self)
155 case ">=":
156 return a.visit(self) >= b.visit(self)
157 case "overlaps":
158 return not (a.visit(self).relate(b.visit(self)) & DISJOINT)
160 def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> Any:
161 return self.dataset_fields[expression.dataset_type, expression.field]
163 def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> Any:
164 return self.dimension_fields[expression.element.name, expression.field]
166 def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> Any:
167 return self.dimension_keys[expression.dimension.name]
169 def visit_in_container(
170 self,
171 member: qt.ColumnExpression,
172 container: tuple[qt.ColumnExpression, ...],
173 flags: PredicateVisitFlags,
174 ) -> bool:
175 return member.visit(self) in [item.visit(self) for item in container]
177 def visit_in_range(
178 self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags
179 ) -> bool:
180 return member.visit(self) in range(start, stop, step)
182 def visit_in_query_tree(
183 self,
184 member: qt.ColumnExpression,
185 column: qt.ColumnExpression,
186 query_tree: qt.QueryTree,
187 flags: PredicateVisitFlags,
188 ) -> bool:
189 return member.visit(self) in self.query_tree_items
191 def visit_is_null(self, operand: qt.ColumnExpression, flags: PredicateVisitFlags) -> bool:
192 return operand.visit(self) is None
194 def visit_literal(self, expression: qt.ColumnLiteral) -> Any:
195 return expression.get_literal_value()
197 def visit_reversed(self, expression: qt.Reversed) -> Any:
198 return _TestReversed(expression.operand.visit(self))
200 def visit_unary_expression(self, expression: UnaryExpression) -> Any:
201 match expression.operator:
202 case "-":
203 return -expression.operand.visit(self)
204 case "begin_of":
205 return expression.operand.visit(self).begin
206 case "end_of":
207 return expression.operand.visit(self).end
209 def apply_logical_and(self, originals: PredicateOperands, results: tuple[bool, ...]) -> bool:
210 return all(results)
212 def apply_logical_not(self, original: PredicateLeaf, result: bool, flags: PredicateVisitFlags) -> bool:
213 return not result
215 def apply_logical_or(
216 self, originals: tuple[PredicateLeaf, ...], results: tuple[bool, ...], flags: PredicateVisitFlags
217 ) -> bool:
218 return any(results)
@dataclasses.dataclass
class _TestReversed:
    """Struct used by `_TestVisitor` to mark an expression as reversed in sort
    order.
    """

    # The evaluated value of the wrapped (reversed) expression.
    operand: Any
230class _TestQueryExecution(BaseException):
231 """Exception raised by _TestQueryDriver.execute to communicate its args
232 back to the caller.
233 """
235 def __init__(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree, driver: _TestQueryDriver) -> None:
236 self.result_spec = result_spec
237 self.tree = tree
238 self.driver = driver
241class _TestQueryCount(BaseException):
242 """Exception raised by _TestQueryDriver.count to communicate its args
243 back to the caller.
244 """
246 def __init__(
247 self,
248 result_spec: qrs.ResultSpec,
249 tree: qt.QueryTree,
250 driver: _TestQueryDriver,
251 exact: bool,
252 discard: bool,
253 ) -> None:
254 self.result_spec = result_spec
255 self.tree = tree
256 self.driver = driver
257 self.exact = exact
258 self.discard = discard
261class _TestQueryAny(BaseException):
262 """Exception raised by _TestQueryDriver.any to communicate its args
263 back to the caller.
264 """
266 def __init__(
267 self,
268 tree: qt.QueryTree,
269 driver: _TestQueryDriver,
270 exact: bool,
271 execute: bool,
272 ) -> None:
273 self.tree = tree
274 self.driver = driver
275 self.exact = exact
276 self.execute = execute
279class _TestQueryExplainNoResults(BaseException):
280 """Exception raised by _TestQueryDriver.explain_no_results to communicate
281 its args back to the caller.
282 """
284 def __init__(
285 self,
286 tree: qt.QueryTree,
287 driver: _TestQueryDriver,
288 execute: bool,
289 ) -> None:
290 self.tree = tree
291 self.driver = driver
292 self.execute = execute
class _TestQueryDriver(qd.QueryDriver):
    """Mock implementation of `QueryDriver` that mostly raises exceptions that
    communicate the arguments its methods were called with.

    Parameters
    ----------
    default_collections : `tuple` [ `str`, ... ], optional
        Default collection the query or parent butler is imagined to have been
        constructed with.
    collection_info : `~collections.abc.Mapping`, optional
        Mapping from collection name to its record and summary, simulating the
        collections present in the data repository.
    dataset_types : `~collections.abc.Mapping`, optional
        Mapping from dataset type to its definition, simulating the dataset
        types registered in the data repository.
    result_rows : `tuple` [ `~collections.abc.Iterable`, ... ], optional
        A tuple of iterables of arbitrary type to use as result rows any time
        `execute` is called, with each nested iterable considered a separate
        page. The result type is not checked for consistency with the result
        spec. If this is not provided, `execute` will instead raise
        `_TestQueryExecution`, and `fetch_page` will not do anything useful.
    """

    def __init__(
        self,
        default_collections: tuple[str, ...] | None = None,
        collection_info: Mapping[str, tuple[CollectionRecord, CollectionSummary]] | None = None,
        dataset_types: Mapping[str, DatasetType] | None = None,
        result_rows: tuple[Iterable[Any], ...] | None = None,
    ) -> None:
        self._universe = DimensionUniverse()
        # Mapping of the arguments passed to materialize, keyed by the UUID
        # that each call returned.
        self.materializations: dict[
            qd.MaterializationKey, tuple[qt.QueryTree, DimensionGroup, frozenset[str]]
        ] = {}
        # Mapping of the arguments passed to upload_data_coordinates, keyed by
        # the UUID that each call returned.  Rows are stored as a frozenset,
        # discarding their original order.
        self.data_coordinate_uploads: dict[
            qd.DataCoordinateUploadKey, tuple[DimensionGroup, frozenset[tuple[DataIdValue, ...]]]
        ] = {}
        self._default_collections = default_collections
        self._collection_info = collection_info or {}
        self._dataset_types = dataset_types or {}
        self._executions: list[tuple[qrs.ResultSpec, qt.QueryTree]] = []
        self._result_rows = result_rows

    @property
    def universe(self) -> DimensionUniverse:
        return self._universe

    def __enter__(self) -> None:
        # The mock has no connection state to set up.
        pass

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        # The mock has no connection state to tear down.
        pass

    def execute(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree) -> Iterator[qd.ResultPage]:
        # Without canned rows, report the call arguments via an exception.
        if self._result_rows is None:
            raise _TestQueryExecution(result_spec, tree, self)

        # Each nested iterable in result_rows becomes its own result page.
        for rows in self._result_rows:
            yield self._make_next_page(result_spec, rows)

    def _make_next_page(self, result_spec: qrs.ResultSpec, current_rows: Iterable[Any]) -> qd.ResultPage:
        # Wrap the rows in the page type matching the result spec; the rows
        # themselves are not checked for consistency with the spec.
        match result_spec:
            case qrs.DataCoordinateResultSpec():
                return qd.DataCoordinateResultPage(spec=result_spec, rows=current_rows)
            case qrs.DimensionRecordResultSpec():
                return qd.DimensionRecordResultPage(spec=result_spec, rows=current_rows)
            case qrs.DatasetRefResultSpec():
                return qd.DatasetRefResultPage(spec=result_spec, rows=current_rows)
            case _:
                raise NotImplementedError("Other query types not yet supported.")

    def materialize(
        self,
        tree: qt.QueryTree,
        dimensions: DimensionGroup,
        datasets: frozenset[str],
        allow_duplicate_overlaps: bool = False,
    ) -> qd.MaterializationKey:
        # Record the arguments (allow_duplicate_overlaps is not recorded)
        # under a fresh UUID so tests can look them up later.
        key = uuid.uuid4()
        self.materializations[key] = (tree, dimensions, datasets)
        return key

    def upload_data_coordinates(
        self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]]
    ) -> qd.DataCoordinateUploadKey:
        # Record the arguments under a fresh UUID so tests can look them up
        # later; row order is discarded.
        key = uuid.uuid4()
        self.data_coordinate_uploads[key] = (dimensions, frozenset(rows))
        return key

    def count(
        self,
        tree: qt.QueryTree,
        result_spec: qrs.ResultSpec,
        *,
        exact: bool,
        discard: bool,
    ) -> int:
        # Communicate the arguments back to the caller via an exception.
        raise _TestQueryCount(result_spec, tree, self, exact, discard)

    def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool:
        # Communicate the arguments back to the caller via an exception.
        raise _TestQueryAny(tree, self, exact, execute)

    def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]:
        # Communicate the arguments back to the caller via an exception.
        raise _TestQueryExplainNoResults(tree, self, execute)

    def get_default_collections(self) -> tuple[str, ...]:
        # Mimic a butler constructed without default collections.
        if self._default_collections is None:
            raise NoDefaultCollectionError()
        return self._default_collections

    def get_dataset_type(self, name: str) -> DatasetType:
        try:
            return self._dataset_types[name]
        except KeyError:
            raise MissingDatasetTypeError(name)
416class ColumnExpressionsTestCase(unittest.TestCase):
417 """Tests for column expression objects in lsst.daf.butler.queries.tree."""
419 def setUp(self) -> None:
420 self.universe = DimensionUniverse()
421 self.x = ExpressionFactory(self.universe)
423 def query(self, **kwargs: Any) -> Query:
424 """Make an initial Query object with the given kwargs used to
425 initialize the _TestQueryDriver.
426 """
427 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe))
429 def test_int_literals(self) -> None:
430 expr = self.x.unwrap(self.x.literal(5))
431 self.assertEqual(expr.value, 5)
432 self.assertEqual(expr.get_literal_value(), 5)
433 self.assertEqual(expr.expression_type, "int")
434 self.assertEqual(expr.column_type, "int")
435 self.assertEqual(str(expr), "5")
436 self.assertTrue(expr.is_literal)
437 columns = qt.ColumnSet(self.universe.empty)
438 expr.gather_required_columns(columns)
439 self.assertFalse(columns)
440 self.assertEqual(expr.visit(_TestVisitor()), 5)
442 def test_string_literals(self) -> None:
443 expr = self.x.unwrap(self.x.literal("five"))
444 self.assertEqual(expr.value, "five")
445 self.assertEqual(expr.get_literal_value(), "five")
446 self.assertEqual(expr.expression_type, "string")
447 self.assertEqual(expr.column_type, "string")
448 self.assertEqual(str(expr), "'five'")
449 self.assertTrue(expr.is_literal)
450 columns = qt.ColumnSet(self.universe.empty)
451 expr.gather_required_columns(columns)
452 self.assertFalse(columns)
453 self.assertEqual(expr.visit(_TestVisitor()), "five")
455 def test_float_literals(self) -> None:
456 expr = self.x.unwrap(self.x.literal(0.5))
457 self.assertEqual(expr.value, 0.5)
458 self.assertEqual(expr.get_literal_value(), 0.5)
459 self.assertEqual(expr.expression_type, "float")
460 self.assertEqual(expr.column_type, "float")
461 self.assertEqual(str(expr), "0.5")
462 self.assertTrue(expr.is_literal)
463 columns = qt.ColumnSet(self.universe.empty)
464 expr.gather_required_columns(columns)
465 self.assertFalse(columns)
466 self.assertEqual(expr.visit(_TestVisitor()), 0.5)
468 def test_hash_literals(self) -> None:
469 expr = self.x.unwrap(self.x.literal(b"eleven"))
470 self.assertEqual(expr.value, b"eleven")
471 self.assertEqual(expr.get_literal_value(), b"eleven")
472 self.assertEqual(expr.expression_type, "hash")
473 self.assertEqual(expr.column_type, "hash")
474 self.assertEqual(str(expr), "(bytes)")
475 self.assertTrue(expr.is_literal)
476 columns = qt.ColumnSet(self.universe.empty)
477 expr.gather_required_columns(columns)
478 self.assertFalse(columns)
479 self.assertEqual(expr.visit(_TestVisitor()), b"eleven")
481 def test_uuid_literals(self) -> None:
482 value = uuid.uuid4()
483 expr = self.x.unwrap(self.x.literal(value))
484 self.assertEqual(expr.value, value)
485 self.assertEqual(expr.get_literal_value(), value)
486 self.assertEqual(expr.expression_type, "uuid")
487 self.assertEqual(expr.column_type, "uuid")
488 self.assertEqual(str(expr), str(value))
489 self.assertTrue(expr.is_literal)
490 columns = qt.ColumnSet(self.universe.empty)
491 expr.gather_required_columns(columns)
492 self.assertFalse(columns)
493 self.assertEqual(expr.visit(_TestVisitor()), value)
495 def test_datetime_literals(self) -> None:
496 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
497 expr = self.x.unwrap(self.x.literal(value))
498 self.assertEqual(expr.value, value)
499 self.assertEqual(expr.get_literal_value(), value)
500 self.assertEqual(expr.expression_type, "datetime")
501 self.assertEqual(expr.column_type, "datetime")
502 self.assertEqual(str(expr), "2020-01-01T00:00:00")
503 self.assertTrue(expr.is_literal)
504 columns = qt.ColumnSet(self.universe.empty)
505 expr.gather_required_columns(columns)
506 self.assertFalse(columns)
507 self.assertEqual(expr.visit(_TestVisitor()), value)
509 def test_timespan_literals(self) -> None:
510 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
511 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
512 value = Timespan(begin, end)
513 expr = self.x.unwrap(self.x.literal(value))
514 self.assertEqual(expr.value, value)
515 self.assertEqual(expr.get_literal_value(), value)
516 self.assertEqual(expr.expression_type, "timespan")
517 self.assertEqual(expr.column_type, "timespan")
518 self.assertEqual(str(expr), "[2020-01-01T00:00:00, 2020-01-01T00:01:00)")
519 self.assertTrue(expr.is_literal)
520 columns = qt.ColumnSet(self.universe.empty)
521 expr.gather_required_columns(columns)
522 self.assertFalse(columns)
523 self.assertEqual(expr.visit(_TestVisitor()), value)
525 def test_region_literals(self) -> None:
526 pixelization = Mq3cPixelization(10)
527 value = pixelization.quad(12058870)
528 expr = self.x.unwrap(self.x.literal(value))
529 self.assertEqual(expr.value, value)
530 self.assertEqual(expr.get_literal_value(), value)
531 self.assertEqual(expr.expression_type, "region")
532 self.assertEqual(expr.column_type, "region")
533 self.assertEqual(str(expr), "(region)")
534 self.assertTrue(expr.is_literal)
535 columns = qt.ColumnSet(self.universe.empty)
536 expr.gather_required_columns(columns)
537 self.assertFalse(columns)
538 self.assertEqual(expr.visit(_TestVisitor()), value)
540 def test_invalid_literal(self) -> None:
541 with self.assertRaisesRegex(TypeError, "Invalid type 'complex' of value 5j for column literal."):
542 self.x.literal(5j)
544 def test_dimension_key_reference(self) -> None:
545 expr = self.x.unwrap(self.x.detector)
546 self.assertIsNone(expr.get_literal_value())
547 self.assertEqual(expr.expression_type, "dimension_key")
548 self.assertEqual(expr.column_type, "int")
549 self.assertEqual(str(expr), "detector")
550 self.assertFalse(expr.is_literal)
551 columns = qt.ColumnSet(self.universe.empty)
552 expr.gather_required_columns(columns)
553 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
554 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 3})), 3)
556 def test_dimension_field_reference(self) -> None:
557 expr = self.x.unwrap(self.x.detector.purpose)
558 self.assertIsNone(expr.get_literal_value())
559 self.assertEqual(expr.expression_type, "dimension_field")
560 self.assertEqual(expr.column_type, "string")
561 self.assertEqual(str(expr), "detector.purpose")
562 self.assertFalse(expr.is_literal)
563 columns = qt.ColumnSet(self.universe.empty)
564 expr.gather_required_columns(columns)
565 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
566 self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
567 with self.assertRaises(InvalidQueryError):
568 qt.DimensionFieldReference(element=self.universe.dimensions["detector"], field="region")
569 self.assertEqual(
570 expr.visit(_TestVisitor(dimension_fields={("detector", "purpose"): "science"})), "science"
571 )
573 def test_dataset_field_reference(self) -> None:
574 expr = self.x.unwrap(self.x["raw"].ingest_date)
575 self.assertIsNone(expr.get_literal_value())
576 self.assertEqual(expr.expression_type, "dataset_field")
577 self.assertEqual(str(expr), "raw.ingest_date")
578 self.assertFalse(expr.is_literal)
579 columns = qt.ColumnSet(self.universe.empty)
580 expr.gather_required_columns(columns)
581 self.assertEqual(columns.dimensions, self.universe.empty)
582 self.assertEqual(columns.dataset_fields["raw"], {"ingest_date"})
583 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="dataset_id").column_type, "uuid")
584 self.assertEqual(
585 qt.DatasetFieldReference(dataset_type="raw", field="collection").column_type, "string"
586 )
587 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="run").column_type, "string")
588 self.assertEqual(
589 qt.DatasetFieldReference(dataset_type="raw", field="ingest_date").column_type, "ingest_date"
590 )
591 self.assertEqual(
592 qt.DatasetFieldReference(dataset_type="raw", field="timespan").column_type, "timespan"
593 )
594 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
595 self.assertEqual(expr.visit(_TestVisitor(dataset_fields={("raw", "ingest_date"): value})), value)
597 def test_unary_negation(self) -> None:
598 expr = self.x.unwrap(-self.x.visit.exposure_time)
599 self.assertIsNone(expr.get_literal_value())
600 self.assertEqual(expr.expression_type, "unary")
601 self.assertEqual(expr.column_type, "float")
602 self.assertEqual(str(expr), "-visit.exposure_time")
603 self.assertFalse(expr.is_literal)
604 columns = qt.ColumnSet(self.universe.empty)
605 expr.gather_required_columns(columns)
606 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
607 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"})
608 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 2.0})), -2.0)
609 with self.assertRaises(InvalidQueryError):
610 qt.UnaryExpression(
611 operand=qt.DimensionFieldReference(
612 element=self.universe.dimensions["detector"], field="purpose"
613 ),
614 operator="-",
615 )
617 def test_unary_timespan_begin(self) -> None:
618 expr = self.x.unwrap(self.x.visit.timespan.begin)
619 self.assertIsNone(expr.get_literal_value())
620 self.assertEqual(expr.expression_type, "unary")
621 self.assertEqual(expr.column_type, "datetime")
622 self.assertEqual(str(expr), "visit.timespan.begin")
623 self.assertFalse(expr.is_literal)
624 columns = qt.ColumnSet(self.universe.empty)
625 expr.gather_required_columns(columns)
626 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
627 self.assertEqual(columns.dimension_fields["visit"], {"timespan"})
628 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
629 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
630 value = Timespan(begin, end)
631 self.assertEqual(
632 expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.begin
633 )
634 with self.assertRaises(InvalidQueryError):
635 qt.UnaryExpression(
636 operand=qt.DimensionFieldReference(
637 element=self.universe.dimensions["detector"], field="purpose"
638 ),
639 operator="begin_of",
640 )
642 def test_unary_timespan_end(self) -> None:
643 expr = self.x.unwrap(self.x.visit.timespan.end)
644 self.assertIsNone(expr.get_literal_value())
645 self.assertEqual(expr.expression_type, "unary")
646 self.assertEqual(expr.column_type, "datetime")
647 self.assertEqual(str(expr), "visit.timespan.end")
648 self.assertFalse(expr.is_literal)
649 columns = qt.ColumnSet(self.universe.empty)
650 expr.gather_required_columns(columns)
651 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
652 self.assertEqual(columns.dimension_fields["visit"], {"timespan"})
653 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
654 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
655 value = Timespan(begin, end)
656 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.end)
657 with self.assertRaises(InvalidQueryError):
658 qt.UnaryExpression(
659 operand=qt.DimensionFieldReference(
660 element=self.universe.dimensions["detector"], field="purpose"
661 ),
662 operator="end_of",
663 )
665 def test_binary_expression_float(self) -> None:
666 for proxy, string, value in [
667 (self.x.visit.exposure_time + 15.0, "visit.exposure_time + 15.0", 45.0),
668 (self.x.visit.exposure_time - 10.0, "visit.exposure_time - 10.0", 20.0),
669 (self.x.visit.exposure_time * 6.0, "visit.exposure_time * 6.0", 180.0),
670 (self.x.visit.exposure_time / 30.0, "visit.exposure_time / 30.0", 1.0),
671 (15.0 + -self.x.visit.exposure_time, "15.0 + (-visit.exposure_time)", -15.0),
672 (10.0 - -self.x.visit.exposure_time, "10.0 - (-visit.exposure_time)", 40.0),
673 (6.0 * -self.x.visit.exposure_time, "6.0 * (-visit.exposure_time)", -180.0),
674 (30.0 / -self.x.visit.exposure_time, "30.0 / (-visit.exposure_time)", -1.0),
675 ((self.x.visit.exposure_time + 15.0) * 6.0, "(visit.exposure_time + 15.0) * 6.0", 270.0),
676 ((self.x.visit.exposure_time + 15.0) + 45.0, "(visit.exposure_time + 15.0) + 45.0", 90.0),
677 ((self.x.visit.exposure_time + 15.0) / 5.0, "(visit.exposure_time + 15.0) / 5.0", 9.0),
678 # We don't need the parentheses we generate in the next one, but
679 # they're not a problem either.
680 ((self.x.visit.exposure_time + 15.0) - 60.0, "(visit.exposure_time + 15.0) - 60.0", -15.0),
681 (6.0 * (-self.x.visit.exposure_time - 15.0), "6.0 * ((-visit.exposure_time) - 15.0)", -270.0),
682 (60.0 + (-self.x.visit.exposure_time - 15.0), "60.0 + ((-visit.exposure_time) - 15.0)", 15.0),
683 (90.0 / (-self.x.visit.exposure_time - 15.0), "90.0 / ((-visit.exposure_time) - 15.0)", -2.0),
684 (60.0 - (-self.x.visit.exposure_time - 15.0), "60.0 - ((-visit.exposure_time) - 15.0)", 105.0),
685 ]:
686 with self.subTest(string=string):
687 expr = self.x.unwrap(proxy)
688 self.assertIsNone(expr.get_literal_value())
689 self.assertEqual(expr.expression_type, "binary")
690 self.assertEqual(expr.column_type, "float")
691 self.assertEqual(str(expr), string)
692 self.assertFalse(expr.is_literal)
693 columns = qt.ColumnSet(self.universe.empty)
694 expr.gather_required_columns(columns)
695 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
696 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"})
697 self.assertEqual(
698 expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 30.0})), value
699 )
701 def test_binary_modulus(self) -> None:
702 for proxy, string, value in [
703 (self.x.visit.id % 2, "visit % 2", 1),
704 (52 % self.x.visit, "52 % visit", 2),
705 ]:
706 with self.subTest(string=string):
707 expr = self.x.unwrap(proxy)
708 self.assertIsNone(expr.get_literal_value())
709 self.assertEqual(expr.expression_type, "binary")
710 self.assertEqual(expr.column_type, "int")
711 self.assertEqual(str(expr), string)
712 self.assertFalse(expr.is_literal)
713 columns = qt.ColumnSet(self.universe.empty)
714 expr.gather_required_columns(columns)
715 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
716 self.assertFalse(columns.dimension_fields["visit"])
717 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"visit": 5})), value)
719 def test_binary_expression_validation(self) -> None:
720 with self.assertRaises(InvalidQueryError):
721 # No arithmetic operators on strings (we do not interpret + as
722 # concatenation).
723 self.x.instrument + "suffix"
724 with self.assertRaises(InvalidQueryError):
725 # Mixed types are not supported, even when they both support the
726 # operator.
727 self.x.visit.exposure_time + self.x.detector
728 with self.assertRaises(InvalidQueryError):
729 # No modulus for floats.
730 self.x.visit.exposure_time % 5.0
732 def test_reversed(self) -> None:
733 expr = self.x.detector.desc
734 self.assertIsNone(expr.get_literal_value())
735 self.assertEqual(expr.expression_type, "reversed")
736 self.assertEqual(expr.column_type, "int")
737 self.assertEqual(str(expr), "detector DESC")
738 self.assertFalse(expr.is_literal)
739 columns = qt.ColumnSet(self.universe.empty)
740 expr.gather_required_columns(columns)
741 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
742 self.assertFalse(columns.dimension_fields["detector"])
743 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 5})), _TestReversed(5))
745 def test_trivial_predicate(self) -> None:
746 """Test logical operations on trivial True/False predicates."""
747 yes = qt.Predicate.from_bool(True)
748 no = qt.Predicate.from_bool(False)
749 maybe: qt.Predicate = self.x.detector == 5
750 for predicate in [
751 yes,
752 yes.logical_or(no),
753 no.logical_or(yes),
754 yes.logical_and(yes),
755 no.logical_not(),
756 yes.logical_or(maybe),
757 maybe.logical_or(yes),
758 ]:
759 self.assertEqual(predicate.column_type, "bool")
760 self.assertEqual(str(predicate), "True")
761 self.assertTrue(predicate.visit(_TestVisitor()))
762 self.assertEqual(predicate.operands, ())
763 for predicate in [
764 no,
765 yes.logical_and(no),
766 no.logical_and(yes),
767 no.logical_or(no),
768 yes.logical_not(),
769 no.logical_and(maybe),
770 maybe.logical_and(no),
771 ]:
772 self.assertEqual(predicate.column_type, "bool")
773 self.assertEqual(str(predicate), "False")
774 self.assertFalse(predicate.visit(_TestVisitor()))
775 self.assertEqual(predicate.operands, ((),))
776 for predicate in [
777 maybe,
778 yes.logical_and(maybe),
779 no.logical_or(maybe),
780 maybe.logical_not().logical_not(),
781 ]:
782 self.assertEqual(predicate.column_type, "bool")
783 self.assertEqual(str(predicate), "detector == 5")
784 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"detector": 5})))
785 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"detector": 4})))
786 self.assertEqual(len(predicate.operands), 1)
787 self.assertEqual(len(predicate.operands[0]), 1)
788 self.assertIs(predicate.operands[0][0], maybe.operands[0][0])
790 def test_comparison(self) -> None:
791 predicate: qt.Predicate
792 string: str
793 value: bool
794 for detector in (4, 5, 6):
795 for predicate, string, value in [
796 (self.x.detector == 5, "detector == 5", detector == 5),
797 (self.x.detector != 5, "detector != 5", detector != 5),
798 (self.x.detector < 5, "detector < 5", detector < 5),
799 (self.x.detector > 5, "detector > 5", detector > 5),
800 (self.x.detector <= 5, "detector <= 5", detector <= 5),
801 (self.x.detector >= 5, "detector >= 5", detector >= 5),
802 (self.x.detector == 5, "detector == 5", detector == 5),
803 (self.x.detector != 5, "detector != 5", detector != 5),
804 (self.x.detector < 5, "detector < 5", detector < 5),
805 (self.x.detector > 5, "detector > 5", detector > 5),
806 (self.x.detector <= 5, "detector <= 5", detector <= 5),
807 (self.x.detector >= 5, "detector >= 5", detector >= 5),
808 ]:
809 with self.subTest(string=string, detector=detector):
810 self.assertEqual(predicate.column_type, "bool")
811 self.assertEqual(str(predicate), string)
812 columns = qt.ColumnSet(self.universe.empty)
813 predicate.gather_required_columns(columns)
814 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
815 self.assertFalse(columns.dimension_fields["detector"])
816 self.assertEqual(
817 predicate.visit(_TestVisitor(dimension_keys={"detector": detector})), value
818 )
819 inverted = predicate.logical_not()
820 self.assertEqual(inverted.column_type, "bool")
821 self.assertEqual(str(inverted), f"NOT {string}")
822 self.assertEqual(
823 inverted.visit(_TestVisitor(dimension_keys={"detector": detector})), not value
824 )
825 columns = qt.ColumnSet(self.universe.empty)
826 inverted.gather_required_columns(columns)
827 self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
828 self.assertFalse(columns.dimension_fields["detector"])
830 def test_overlap_comparison(self) -> None:
831 pixelization = Mq3cPixelization(10)
832 region1 = pixelization.quad(12058870)
833 predicate = self.x.visit.region.overlaps(region1)
834 self.assertEqual(predicate.column_type, "bool")
835 self.assertEqual(str(predicate), "visit.region OVERLAPS (region)")
836 columns = qt.ColumnSet(self.universe.empty)
837 predicate.gather_required_columns(columns)
838 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
839 self.assertEqual(columns.dimension_fields["visit"], {"region"})
840 region2 = pixelization.quad(12058857)
841 self.assertFalse(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): region2})))
842 inverted = predicate.logical_not()
843 self.assertEqual(inverted.column_type, "bool")
844 self.assertEqual(str(inverted), "NOT visit.region OVERLAPS (region)")
845 self.assertTrue(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): region2})))
846 columns = qt.ColumnSet(self.universe.empty)
847 inverted.gather_required_columns(columns)
848 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
849 self.assertEqual(columns.dimension_fields["visit"], {"region"})
851 def test_invalid_comparison(self) -> None:
852 # Mixed type comparisons.
853 with self.assertRaises(InvalidQueryError):
854 self.x.visit > "three"
855 # Invalid operator for type.
856 with self.assertRaises(InvalidQueryError):
857 self.x["raw"].dataset_id < uuid.uuid4()
859 def test_is_null(self) -> None:
860 predicate = self.x.visit.region.is_null
861 self.assertEqual(predicate.column_type, "bool")
862 self.assertEqual(str(predicate), "visit.region IS NULL")
863 columns = qt.ColumnSet(self.universe.empty)
864 predicate.gather_required_columns(columns)
865 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
866 self.assertEqual(columns.dimension_fields["visit"], {"region"})
867 self.assertTrue(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): None})))
868 inverted = predicate.logical_not()
869 self.assertEqual(inverted.column_type, "bool")
870 self.assertEqual(str(inverted), "NOT visit.region IS NULL")
871 self.assertFalse(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): None})))
872 inverted.gather_required_columns(columns)
873 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
874 self.assertEqual(columns.dimension_fields["visit"], {"region"})
876 def test_in_container(self) -> None:
877 predicate: qt.Predicate = self.x.visit.in_iterable([3, 4, self.x.exposure.id])
878 self.assertEqual(predicate.column_type, "bool")
879 self.assertEqual(str(predicate), "visit IN [3, 4, exposure]")
880 columns = qt.ColumnSet(self.universe.empty)
881 predicate.gather_required_columns(columns)
882 self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
883 self.assertFalse(columns.dimension_fields["visit"])
884 self.assertFalse(columns.dimension_fields["exposure"])
885 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
886 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
887 inverted = predicate.logical_not()
888 self.assertEqual(inverted.column_type, "bool")
889 self.assertEqual(str(inverted), "NOT visit IN [3, 4, exposure]")
890 self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
891 self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
892 columns = qt.ColumnSet(self.universe.empty)
893 inverted.gather_required_columns(columns)
894 self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
895 self.assertFalse(columns.dimension_fields["visit"])
896 self.assertFalse(columns.dimension_fields["exposure"])
897 with self.assertRaises(InvalidQueryError):
898 # Regions (and timespans) not allowed in IN expressions, since that
899 # suggests topological logic we're not actually doing. We can't
900 # use ExpressionFactory because it prohibits this case with typing.
901 pixelization = Mq3cPixelization(10)
902 region = pixelization.quad(12058870)
903 qt.Predicate.in_container(self.x.unwrap(self.x.visit.region), [qt.make_column_literal(region)])
904 with self.assertRaises(InvalidQueryError):
905 # Mismatched types.
906 self.x.visit.in_iterable([3.5, 2.1])
908 def test_in_range(self) -> None:
909 predicate: qt.Predicate = self.x.visit.in_range(2, 8, 2)
910 self.assertEqual(predicate.column_type, "bool")
911 self.assertEqual(str(predicate), "visit IN 2:8:2")
912 columns = qt.ColumnSet(self.universe.empty)
913 predicate.gather_required_columns(columns)
914 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
915 self.assertFalse(columns.dimension_fields["visit"])
916 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2})))
917 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 8})))
918 inverted = predicate.logical_not()
919 self.assertEqual(inverted.column_type, "bool")
920 self.assertEqual(str(inverted), "NOT visit IN 2:8:2")
921 self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2})))
922 self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 8})))
923 columns = qt.ColumnSet(self.universe.empty)
924 inverted.gather_required_columns(columns)
925 self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
926 self.assertFalse(columns.dimension_fields["visit"])
927 with self.assertRaises(InvalidQueryError):
928 # Only integer fields allowed.
929 self.x.visit.exposure_time.in_range(2, 4)
930 with self.assertRaises(InvalidQueryError):
931 # Step must be positive.
932 self.x.visit.in_range(2, 4, -1)
933 with self.assertRaises(InvalidQueryError):
934 # Stop must be >= start.
935 self.x.visit.in_range(2, 0)
937 def test_in_query(self) -> None:
938 query = self.query().join_dimensions(["visit", "tract"]).where(skymap="s", tract=3)
939 predicate: qt.Predicate = self.x.exposure.in_query(self.x.visit, query)
940 self.assertEqual(predicate.column_type, "bool")
941 self.assertEqual(str(predicate), "exposure IN (query).visit")
942 columns = qt.ColumnSet(self.universe.empty)
943 predicate.gather_required_columns(columns)
944 self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
945 self.assertFalse(columns.dimension_fields["exposure"])
946 self.assertTrue(
947 predicate.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
948 )
949 self.assertFalse(
950 predicate.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
951 )
952 inverted = predicate.logical_not()
953 self.assertEqual(inverted.column_type, "bool")
954 self.assertEqual(str(inverted), "NOT exposure IN (query).visit")
955 self.assertFalse(
956 inverted.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
957 )
958 self.assertTrue(
959 inverted.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
960 )
961 columns = qt.ColumnSet(self.universe.empty)
962 inverted.gather_required_columns(columns)
963 self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
964 self.assertFalse(columns.dimension_fields["exposure"])
965 with self.assertRaises(InvalidQueryError):
966 # Regions (and timespans) not allowed in IN expressions, since that
967 # suggests topological logic we're not actually doing. We can't
968 # use ExpressionFactory because it prohibits this case with typing.
969 qt.Predicate.in_query(
970 self.x.unwrap(self.x.visit.region), self.x.unwrap(self.x.tract.region), query._tree
971 )
972 with self.assertRaises(InvalidQueryError):
973 # Mismatched types.
974 self.x.exposure.in_query(self.x.visit.exposure_time, query)
975 with self.assertRaises(InvalidQueryError):
976 # Query column requires dimensions that are not in the query.
977 self.x.exposure.in_query(self.x.patch, query)
978 with self.assertRaises(InvalidQueryError):
979 # Query column requires dataset type that is not in the query.
980 self.x["raw"].dataset_id.in_query(self.x["raw"].dataset_id, query)
    def test_complex_predicate(self) -> None:
        """Test that predicates are converted to conjunctive normal form and
        get parentheses in the right places when stringified.
        """
        visitor = _TestVisitor(dimension_keys={"instrument": "i", "detector": 3, "visit": 6, "band": "r"})
        a: qt.Predicate = self.x.visit > 5  # will evaluate to True
        b: qt.Predicate = self.x.detector != 3  # will evaluate to False
        c: qt.Predicate = self.x.instrument == "i"  # will evaluate to True
        d: qt.Predicate = self.x.band == "g"  # will evaluate to False
        predicate: qt.Predicate
        # Each entry is (predicate, expected CNF string, expected truth value
        # under `visitor`).  The expected strings show OR being distributed
        # over AND and De Morgan's laws being applied during normalization.
        for predicate, string, value in [
            (a.logical_or(b), f"{a} OR {b}", True),
            (a.logical_or(c), f"{a} OR {c}", True),
            (b.logical_or(d), f"{b} OR {d}", False),
            (a.logical_and(b), f"{a} AND {b}", False),
            (a.logical_and(c), f"{a} AND {c}", True),
            (b.logical_and(d), f"{b} AND {d}", False),
            (self.x.any(a, b, c, d), f"{a} OR {b} OR {c} OR {d}", True),
            (self.x.all(a, b, c, d), f"{a} AND {b} AND {c} AND {d}", False),
            (a.logical_or(b).logical_and(c), f"({a} OR {b}) AND {c}", True),
            (a.logical_and(b.logical_or(d)), f"{a} AND ({b} OR {d})", False),
            # OR over AND is distributed into CNF clauses below.
            (a.logical_and(b).logical_or(c), f"({a} OR {c}) AND ({b} OR {c})", True),
            (
                a.logical_and(b).logical_or(c.logical_and(d)),
                f"({a} OR {c}) AND ({a} OR {d}) AND ({b} OR {c}) AND ({b} OR {d})",
                False,
            ),
            # De Morgan: NOT (x OR y) == NOT x AND NOT y, and vice versa.
            (a.logical_or(b).logical_not(), f"NOT {a} AND NOT {b}", False),
            (a.logical_or(c).logical_not(), f"NOT {a} AND NOT {c}", False),
            (b.logical_or(d).logical_not(), f"NOT {b} AND NOT {d}", True),
            (a.logical_and(b).logical_not(), f"NOT {a} OR NOT {b}", True),
            (a.logical_and(c).logical_not(), f"NOT {a} OR NOT {c}", False),
            (b.logical_and(d).logical_not(), f"NOT {b} OR NOT {d}", True),
            (
                self.x.not_(a.logical_or(b).logical_and(c)),
                f"(NOT {a} OR NOT {c}) AND (NOT {b} OR NOT {c})",
                False,
            ),
            (
                a.logical_and(b.logical_or(d)).logical_not(),
                f"(NOT {a} OR NOT {b}) AND (NOT {a} OR NOT {d})",
                True,
            ),
        ]:
            with self.subTest(string=string):
                self.assertEqual(str(predicate), string)
                self.assertEqual(predicate.visit(visitor), value)
1030 def test_proxy_misc(self) -> None:
1031 """Test miscellaneous things on various ExpressionFactory proxies."""
1032 self.assertEqual(str(self.x.visit_detector_region), "visit_detector_region")
1033 self.assertEqual(str(self.x.visit.instrument), "instrument")
1034 self.assertEqual(str(self.x["raw"]), "raw")
1035 self.assertEqual(str(self.x["raw.ingest_date"]), "raw.ingest_date")
1036 self.assertEqual(
1037 str(self.x.visit.timespan.overlaps(self.x["raw"].timespan)),
1038 "visit.timespan OVERLAPS raw.timespan",
1039 )
1040 self.assertGreater(
1041 set(dir(self.x["raw"])), {"dataset_id", "ingest_date", "collection", "run", "timespan"}
1042 )
1043 self.assertGreater(set(dir(self.x.exposure)), {"seq_num", "science_program", "timespan"})
1044 with self.assertRaises(AttributeError):
1045 self.x["raw"].seq_num
1046 with self.assertRaises(AttributeError):
1047 self.x.visit.horse
1050class QueryTestCase(unittest.TestCase):
1051 """Tests for Query and *QueryResults objects in lsst.daf.butler.queries."""
    def setUp(self) -> None:
        """Build the dataset types, mock collection registry, and defaults
        shared by all QueryTestCase tests.
        """
        self.maxDiff = None
        self.universe = DimensionUniverse()
        # We use ArrowTable as the storage class for all dataset types because
        # it's got conversions that only require third-party packages we
        # already require.
        self.raw = DatasetType(
            "raw", dimensions=self.universe.conform(["detector", "exposure"]), storageClass="ArrowTable"
        )
        self.refcat = DatasetType(
            "refcat", dimensions=self.universe.conform(["htm7"]), storageClass="ArrowTable"
        )
        self.bias = DatasetType(
            "bias",
            dimensions=self.universe.conform(["detector"]),
            storageClass="ArrowTable",
            isCalibration=True,
        )
        # Default collection search path handed to the mock driver.
        self.default_collections: list[str] | None = ["DummyCam/defaults"]
        # Mock collection registry: name -> (record, summary).  The chained
        # collection "DummyCam/defaults" aggregates the other three, and its
        # summary is the union of its children's dataset types.
        self.collection_info: dict[str, tuple[CollectionRecord, CollectionSummary]] = {
            "DummyCam/raw/all": (
                RunRecord[int](1, name="DummyCam/raw/all"),
                CollectionSummary(NamedValueSet({self.raw}), governors={"instrument": {"DummyCam"}}),
            ),
            "DummyCam/calib": (
                CollectionRecord[int](2, name="DummyCam/calib", type=CollectionType.CALIBRATION),
                CollectionSummary(NamedValueSet({self.bias}), governors={"instrument": {"DummyCam"}}),
            ),
            "refcats": (
                RunRecord[int](3, name="refcats"),
                CollectionSummary(NamedValueSet({self.refcat}), governors={}),
            ),
            "DummyCam/defaults": (
                ChainedCollectionRecord[int](
                    4, name="DummyCam/defaults", children=("DummyCam/raw/all", "DummyCam/calib", "refcats")
                ),
                CollectionSummary(
                    NamedValueSet({self.raw, self.refcat, self.bias}), governors={"instrument": {"DummyCam"}}
                ),
            ),
        }
        # Dataset types known to the mock driver, keyed by name.
        self.dataset_types = {"raw": self.raw, "refcat": self.refcat, "bias": self.bias}
1096 def query(self, **kwargs: Any) -> Query:
1097 """Make an initial Query object with the given kwargs used to
1098 initialize the _TestQueryDriver.
1100 The given kwargs override the test-case-attribute defaults.
1101 """
1102 kwargs.setdefault("default_collections", self.default_collections)
1103 kwargs.setdefault("collection_info", self.collection_info)
1104 kwargs.setdefault("dataset_types", self.dataset_types)
1105 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe))
1107 def test_dataset_join(self) -> None:
1108 """Test queries that have had a dataset search explicitly joined in via
1109 Query.join_dataset_search.
1111 Since this kind of query has a moderate amount of complexity, this is
1112 where we get a lot of basic coverage that applies to all kinds of
1113 queries, including:
1115 - getting data ID and dataset results (but not iterating over them);
1116 - the 'any' and 'explain_no_results' methods;
1117 - adding 'where' filters (but not expanding dimensions accordingly);
1118 - materializations.
1119 """
1121 def check(
1122 query: Query,
1123 dimensions: DimensionGroup = self.raw.dimensions,
1124 ) -> None:
1125 """Run a battery of tests on one of a set of very similar queries
1126 constructed in different ways (see below).
1127 """
1129 def check_query_tree(
1130 tree: qt.QueryTree,
1131 dimensions: DimensionGroup = dimensions,
1132 ) -> None:
1133 """Check the state of the QueryTree object that backs the Query
1134 or a derived QueryResults object.
1136 Parameters
1137 ----------
1138 tree : `lsst.daf.butler.queries.tree.QueryTree`
1139 Object to test.
1140 dimensions : `DimensionGroup`
1141 Dimensions to expect in the `QueryTree`, not necessarily
1142 including those in the test 'raw' dataset type.
1143 """
1144 self.assertEqual(tree.dimensions, dimensions | self.raw.dimensions)
1145 self.assertEqual(str(tree.predicate), "raw.run == 'DummyCam/raw/all'")
1146 self.assertFalse(tree.materializations)
1147 self.assertFalse(tree.data_coordinate_uploads)
1148 self.assertEqual(tree.datasets.keys(), {"raw"})
1149 self.assertEqual(tree.datasets["raw"].dimensions, self.raw.dimensions)
1150 self.assertEqual(tree.datasets["raw"].collections, ("DummyCam/defaults",))
1151 self.assertEqual(tree.get_joined_dimension_groups(), frozenset({self.raw.dimensions}))
1153 def check_data_id_results(*args, query: Query, dimensions: DimensionGroup = dimensions) -> None:
1154 """Construct a DataCoordinateQueryResults object from the query
1155 with the given arguments and run a battery of tests on it.
1157 Parameters
1158 ----------
1159 *args
1160 Forwarded to `Query.data_ids`.
1161 query : `Query`
1162 Query to start from.
1163 dimensions : `DimensionGroup`, optional
1164 Dimensions the result data IDs should have.
1165 """
1166 with self.assertRaises(_TestQueryExecution) as cm:
1167 list(query.data_ids(*args))
1168 self.assertEqual(
1169 cm.exception.result_spec,
1170 qrs.DataCoordinateResultSpec(dimensions=dimensions),
1171 )
1172 check_query_tree(cm.exception.tree, dimensions=dimensions)
1174 def check_dataset_results(
1175 *args: Any,
1176 query: Query,
1177 find_first: bool = True,
1178 storage_class_name: str = self.raw.storageClass_name,
1179 ) -> None:
1180 """Construct a DatasetRefQueryResults object from the query
1181 with the given arguments and run a battery of tests on it.
1183 Parameters
1184 ----------
1185 *args
1186 Forwarded to `Query.datasets`.
1187 query : `Query`
1188 Query to start from.
1189 find_first : `bool`, optional
1190 Whether to do find-first resolution on the results.
1191 storage_class_name : `str`, optional
1192 Expected name of the storage class for the results.
1193 """
1194 with self.assertRaises(_TestQueryExecution) as cm:
1195 list(query.datasets(*args, find_first=find_first))
1196 self.assertEqual(
1197 cm.exception.result_spec,
1198 qrs.DatasetRefResultSpec(
1199 dataset_type_name="raw",
1200 dimensions=self.raw.dimensions,
1201 storage_class_name=storage_class_name,
1202 find_first=find_first,
1203 ),
1204 )
1205 check_query_tree(cm.exception.tree)
1207 def check_materialization(
1208 kwargs: Mapping[str, Any],
1209 query: Query,
1210 dimensions: DimensionGroup = dimensions,
1211 has_dataset: bool = True,
1212 ) -> None:
1213 """Materialize the query with the given arguments and run a
1214 battery of tests on the result.
1216 Parameters
1217 ----------
1218 kwargs
1219 Forwarded as keyword arguments to `Query.materialize`.
1220 query : `Query`
1221 Query to start from.
1222 dimensions : `DimensionGroup`, optional
1223 Dimensions to expect in the materialization and its derived
1224 query.
1225 has_dataset : `bool`, optional
1226 Whether the query backed by the materialization should
1227 still have the test 'raw' dataset joined in.
1228 """
1229 # Materialize the query and check the query tree sent to the
1230 # driver and the one in the materialized query.
1231 with self.assertRaises(_TestQueryExecution) as cm:
1232 list(query.materialize(**kwargs).data_ids())
1233 derived_tree = cm.exception.tree
1234 self.assertEqual(derived_tree.dimensions, dimensions)
1235 # Predicate should be materialized away; it no longer appears
1236 # in the derived query.
1237 self.assertEqual(str(derived_tree.predicate), "True")
1238 self.assertFalse(derived_tree.data_coordinate_uploads)
1239 if has_dataset:
1240 # Dataset search is still there, even though its existence
1241 # constraint is included in the materialization, because we
1242 # might need to re-join for some result columns in a
1243 # derived query.
1244 self.assertTrue(derived_tree.datasets.keys(), {"raw"})
1245 self.assertEqual(derived_tree.datasets["raw"].dimensions, self.raw.dimensions)
1246 self.assertEqual(derived_tree.datasets["raw"].collections, ("DummyCam/defaults",))
1247 else:
1248 self.assertFalse(derived_tree.datasets)
1249 ((key, derived_tree_materialized_dimensions),) = derived_tree.materializations.items()
1250 self.assertEqual(derived_tree_materialized_dimensions, dimensions)
1251 (
1252 materialized_tree,
1253 materialized_dimensions,
1254 materialized_datasets,
1255 ) = cm.exception.driver.materializations[key]
1256 self.assertEqual(derived_tree_materialized_dimensions, materialized_dimensions)
1257 if has_dataset:
1258 self.assertEqual(materialized_datasets, {"raw"})
1259 else:
1260 self.assertFalse(materialized_datasets)
1261 check_query_tree(materialized_tree)
1263 # Actual logic for the check() function begins here.
1265 self.assertEqual(query.constraint_dataset_types, {"raw"})
1266 self.assertEqual(query.constraint_dimensions, self.raw.dimensions)
1268 # Adding a constraint on a field for this dataset type should work
1269 # (this constraint will be present in all downstream tests).
1270 query = query.where(query.expression_factory["raw"].run == "DummyCam/raw/all")
1271 with self.assertRaises(InvalidQueryError):
1272 # Adding constraint on a different dataset should not work.
1273 query.where(query.expression_factory["refcat"].run == "refcats")
1275 # Data IDs, with dimensions defaulted.
1276 check_data_id_results(query=query)
1277 # Dimensions for data IDs the same as defaults.
1278 check_data_id_results(["exposure", "detector"], query=query)
1279 # Dimensions are a subset of the query dimensions.
1280 check_data_id_results(["exposure"], query=query, dimensions=self.universe.conform(["exposure"]))
1281 # Dimensions are a superset of the query dimensions.
1282 check_data_id_results(
1283 ["exposure", "detector", "visit"],
1284 query=query,
1285 dimensions=self.universe.conform(["exposure", "detector", "visit"]),
1286 )
1287 # Dimensions are neither a superset nor a subset of the query
1288 # dimensions.
1289 check_data_id_results(
1290 ["detector", "visit"], query=query, dimensions=self.universe.conform(["visit", "detector"])
1291 )
1292 # Dimensions are empty.
1293 check_data_id_results([], query=query, dimensions=self.universe.conform([]))
1295 # Get DatasetRef results, with various arguments and defaulting.
1296 check_dataset_results("raw", query=query)
1297 check_dataset_results("raw", query=query, find_first=True)
1298 check_dataset_results("raw", ["DummyCam/defaults"], query=query)
1299 check_dataset_results("raw", ["DummyCam/defaults"], query=query, find_first=True)
1300 check_dataset_results(self.raw, query=query)
1301 check_dataset_results(self.raw, query=query, find_first=True)
1302 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query)
1303 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query, find_first=True)
1305 # Changing collections at this stage is not allowed.
1306 with self.assertRaises(InvalidQueryError):
1307 query.datasets("raw", collections=["DummyCam/calib"])
1309 # Changing storage classes is allowed, if they're compatible.
1310 check_dataset_results(
1311 self.raw.overrideStorageClass("ArrowNumpy"), query=query, storage_class_name="ArrowNumpy"
1312 )
1313 with self.assertRaises(DatasetTypeError):
1314 # Can't use overrideStorageClass, because it'll raise
1315 # before the code we want to test can.
1316 query.datasets(DatasetType("raw", self.raw.dimensions, "int"))
1318 # Check the 'any' and 'explain_no_results' methods on Query itself.
1319 for execute, exact in itertools.permutations([False, True], 2):
1320 with self.assertRaises(_TestQueryAny) as cm:
1321 query.any(execute=execute, exact=exact)
1322 self.assertEqual(cm.exception.execute, execute)
1323 self.assertEqual(cm.exception.exact, exact)
1324 check_query_tree(cm.exception.tree, dimensions)
1325 with self.assertRaises(_TestQueryExplainNoResults):
1326 query.explain_no_results()
1327 check_query_tree(cm.exception.tree, dimensions)
1329 # Materialize the query with defaults.
1330 check_materialization({}, query=query)
1331 # Materialize the query with args that match defaults.
1332 check_materialization({"dimensions": ["exposure", "detector"], "datasets": {"raw"}}, query=query)
1333 # Materialize the query with a superset of the original dimensions.
1334 check_materialization(
1335 {"dimensions": ["exposure", "detector", "visit"]},
1336 query=query,
1337 dimensions=self.universe.conform(["exposure", "visit", "detector"]),
1338 )
1339 # Materialize the query with no datasets.
1340 check_materialization(
1341 {"dimensions": ["exposure", "detector"], "datasets": frozenset()},
1342 query=query,
1343 has_dataset=False,
1344 )
1345 # Materialize the query with no datasets and a subset of the
1346 # dimensions.
1347 check_materialization(
1348 {"dimensions": ["exposure"], "datasets": frozenset()},
1349 query=query,
1350 has_dataset=False,
1351 dimensions=self.universe.conform(["exposure"]),
1352 )
1353 # Materializing the query with a dataset that is not in the query
1354 # is an error.
1355 with self.assertRaises(InvalidQueryError):
1356 query.materialize(datasets={"refcat"})
1357 # Materializing the query with dimensions that are not a superset
1358 # of any materialized dataset dimensions is an error.
1359 with self.assertRaises(InvalidQueryError):
1360 query.materialize(dimensions=["exposure"], datasets={"raw"})
1362 # Actual logic for test_dataset_joins starts here.
1364 # Default collections and existing dataset type name.
1365 check(self.query().join_dataset_search("raw"))
1366 # Default collections and existing DatasetType instance.
1367 check(self.query().join_dataset_search(self.raw))
1368 # Manual collections and existing dataset type.
1369 check(
1370 self.query(default_collections=None).join_dataset_search("raw", collections=["DummyCam/defaults"])
1371 )
1372 check(
1373 self.query(default_collections=None).join_dataset_search(
1374 self.raw, collections=["DummyCam/defaults"]
1375 )
1376 )
1377 with self.assertRaises(MissingDatasetTypeError):
1378 # Dataset type does not exist.
1379 self.query(dataset_types={}).join_dataset_search("raw", collections=["DummyCam/raw/all"])
1380 with self.assertRaises(DatasetTypeError):
1381 # Dataset type object with bad dimensions passed.
1382 self.query().join_dataset_search(
1383 DatasetType(
1384 "raw",
1385 dimensions={"detector", "visit"},
1386 storageClass=self.raw.storageClass_name,
1387 universe=self.universe,
1388 )
1389 )
1390 with self.assertRaises(TypeError):
1391 # Bad type for dataset type argument.
1392 self.query().join_dataset_search(3)
1393 with self.assertRaises(InvalidQueryError):
1394 # Cannot pass storage class override to join_dataset_search,
1395 # because we cannot use it there.
1396 self.query().join_dataset_search(self.raw.overrideStorageClass("ArrowAstropy"))
    def test_dimension_record_results(self) -> None:
        """Test queries that return dimension records.

        This includes tests for:

        - joining against uploaded data coordinates;
        - counting result rows;
        - expanding dimensions as needed for 'where' conditions;
        - order_by and limit.

        It does not include the iteration methods of
        DimensionRecordQueryResults, since those require a different mock
        driver setup (see test_dimension_record_iteration).
        """
        # Set up the base query-results object to test.
        query = self.query()
        x = query.expression_factory
        self.assertFalse(query.constraint_dimensions)
        query = query.where(x.skymap == "m")
        self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap"]))
        upload_rows = [
            DataCoordinate.standardize(instrument="DummyCam", visit=3, universe=self.universe),
            DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe),
        ]
        # The tuple-of-required-values form the driver should receive.
        raw_rows = frozenset([data_id.required_values for data_id in upload_rows])
        query = query.join_data_coordinates(upload_rows)
        self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap", "visit"]))
        results = query.dimension_records("patch")
        results = results.where(x.tract == 4)

        # Define a closure to run tests on variants of the base query.
        def check(
            results: DimensionRecordQueryResults,
            order_by: Any = (),
            limit: int | None = None,
        ) -> list[str]:
            # Returns the stringified order_by terms from the result spec so
            # callers can assert on them.
            results = results.order_by(*order_by).limit(limit)
            self.assertEqual(results.element.name, "patch")
            with self.assertRaises(_TestQueryExecution) as cm:
                list(results)
            tree = cm.exception.tree
            self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4")
            self.assertEqual(tree.dimensions, self.universe.conform(["visit", "patch"]))
            self.assertFalse(tree.materializations)
            self.assertFalse(tree.datasets)
            ((key, upload_dimensions),) = tree.data_coordinate_uploads.items()
            self.assertEqual(upload_dimensions, self.universe.conform(["visit"]))
            self.assertEqual(cm.exception.driver.data_coordinate_uploads[key], (upload_dimensions, raw_rows))
            result_spec = cm.exception.result_spec
            self.assertEqual(result_spec.result_type, "dimension_record")
            self.assertEqual(result_spec.element, self.universe["patch"])
            self.assertEqual(result_spec.limit, limit)
            # NOTE(review): permutations only yields the two mixed
            # (exact, discard) combinations, not all four; confirm whether
            # itertools.product was intended here.
            for exact, discard in itertools.permutations([False, True], r=2):
                with self.assertRaises(_TestQueryCount) as cm:
                    results.count(exact=exact, discard=discard)
                self.assertEqual(cm.exception.result_spec, result_spec)
                self.assertEqual(cm.exception.exact, exact)
                self.assertEqual(cm.exception.discard, discard)
            return [str(term) for term in result_spec.order_by]

        # Run the closure's tests on variants of the base query.
        self.assertEqual(check(results), [])
        self.assertEqual(check(results, limit=2), [])
        self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"])
        self.assertEqual(
            check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]),
            ["patch.cell_x", "patch.cell_y DESC"],
        )
        with self.assertRaises(InvalidQueryError):
            # Cannot upload empty list of data IDs.
            query.join_data_coordinates([])
        with self.assertRaises(InvalidQueryError):
            # Cannot upload heterogeneous list of data IDs.
            query.join_data_coordinates(
                [
                    DataCoordinate.make_empty(self.universe),
                    DataCoordinate.standardize(instrument="DummyCam", universe=self.universe),
                ]
            )
    def test_join_data_coordinate_table(self) -> None:
        """Test joining data coordinates into a query via an astropy table
        upload, including validation of bad uploads.
        """
        query = self.query()
        x = query.expression_factory
        # Before any 'where', the query should have no constraint dimensions.
        self.assertFalse(query.constraint_dimensions)
        query = query.where(x.skymap == "m")
        # The 'where' on skymap adds it to the constraint dimensions.
        self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap"]))
        with self.assertRaises(InvalidQueryError):
            # Cannot join a table with no rows.
            query.join_data_coordinate_table(
                astropy.table.Table({"instrument": np.array([]), "visit": np.array([])})
            )
        with self.assertRaises(InvalidQueryError):
            # Cannot join a table with missing required dimensions.
            query.join_data_coordinate_table(astropy.table.Table({"visit": np.array([1, 3, 5])}))
        upload_table = astropy.table.Table({"tract": np.array([2, 5, 7], dtype=np.int64)})
        # Expected raw upload rows: the constraint's skymap value is combined
        # with each uploaded tract.
        raw_rows = frozenset([("m", 2), ("m", 5), ("m", 7)])
        query = query.join_data_coordinate_table(upload_table)
        self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap", "tract"]))
        results = query.dimension_records("patch")
        # The mock driver raises _TestQueryExecution to expose the query tree
        # and driver state instead of actually executing anything.
        with self.assertRaises(_TestQueryExecution) as cm:
            list(results)
        tree = cm.exception.tree
        driver = cm.exception.driver
        self.assertEqual(str(tree.predicate), "skymap == 'm'")
        self.assertEqual(tree.dimensions, self.universe.conform(["patch"]))
        self.assertFalse(tree.materializations)
        self.assertFalse(tree.datasets)
        # Exactly one upload should have been registered with the tree.
        ((key, upload_dimensions),) = tree.data_coordinate_uploads.items()
        self.assertEqual(upload_dimensions, self.universe.conform(["tract"]))
        self.assertEqual(driver.data_coordinate_uploads[key], (upload_dimensions, raw_rows))
        for row in driver.data_coordinate_uploads[key][1]:
            for value in row:
                # Make sure numpy scalar types have been cast to builtins.
                self.assertIsInstance(value, (int, str))
1514 def test_dimension_record_iteration(self) -> None:
1515 """Tests for DimensionRecordQueryResult iteration."""
1517 def make_record(n: int) -> DimensionRecord:
1518 return self.universe["patch"].RecordClass(skymap="m", tract=4, patch=n)
1520 result_rows = (
1521 [make_record(n) for n in range(3)],
1522 [make_record(n) for n in range(3, 6)],
1523 [make_record(10)],
1524 )
1525 results = self.query(result_rows=result_rows).dimension_records("patch")
1526 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1527 self.assertEqual(
1528 list(results.iter_set_pages()),
1529 [DimensionRecordSet(self.universe["patch"], rows) for rows in result_rows],
1530 )
1531 self.assertEqual(
1532 [table.column("id").to_pylist() for table in results.iter_table_pages()],
1533 [list(range(3)), list(range(3, 6)), [10]],
1534 )
1536 def test_data_coordinate_results(self) -> None:
1537 """Test queries that return data coordinates.
1539 This includes tests for:
1541 - counting result rows;
1542 - expanding dimensions as needed for 'where' conditions;
1543 - order_by and limit.
1545 It does not include the iteration methods of
1546 DataCoordinateQueryResults, since those require a different mock
1547 driver setup (see test_data_coordinate_iteration). More tests for
1548 different inputs to DataCoordinateQueryResults construction are in
1549 test_dataset_join.
1550 """
1551 # Set up the base query-results object to test.
1552 query = self.query()
1553 x = query.expression_factory
1554 self.assertFalse(query.constraint_dimensions)
1555 query = query.where(x.skymap == "m")
1556 results = query.data_ids(["patch", "band"])
1557 results = results.where(x.tract == 4)
1559 # Define a closure to run tests on variants of the base query.
1560 def check(
1561 results: DataCoordinateQueryResults,
1562 order_by: Any = (),
1563 limit: int | None = None,
1564 include_dimension_records: bool = False,
1565 ) -> list[str]:
1566 results = results.order_by(*order_by).limit(limit)
1567 self.assertEqual(results.dimensions, self.universe.conform(["patch", "band"]))
1568 with self.assertRaises(_TestQueryExecution) as cm:
1569 list(results)
1570 tree = cm.exception.tree
1571 self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4")
1572 self.assertEqual(tree.dimensions, self.universe.conform(["patch", "band"]))
1573 self.assertFalse(tree.materializations)
1574 self.assertFalse(tree.datasets)
1575 self.assertFalse(tree.data_coordinate_uploads)
1576 result_spec = cm.exception.result_spec
1577 self.assertEqual(result_spec.result_type, "data_coordinate")
1578 self.assertEqual(result_spec.dimensions, self.universe.conform(["patch", "band"]))
1579 self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
1580 self.assertEqual(result_spec.limit, limit)
1581 self.assertIsNone(result_spec.find_first_dataset)
1582 for exact, discard in itertools.permutations([False, True], r=2):
1583 with self.assertRaises(_TestQueryCount) as cm:
1584 results.count(exact=exact, discard=discard)
1585 self.assertEqual(cm.exception.result_spec, result_spec)
1586 self.assertEqual(cm.exception.exact, exact)
1587 self.assertEqual(cm.exception.discard, discard)
1588 return [str(term) for term in result_spec.order_by]
1590 # Run the closure's tests on variants of the base query.
1591 self.assertEqual(check(results), [])
1592 self.assertEqual(check(results.with_dimension_records(), include_dimension_records=True), [])
1593 self.assertEqual(
1594 check(results.with_dimension_records().with_dimension_records(), include_dimension_records=True),
1595 [],
1596 )
1597 self.assertEqual(check(results, limit=2), [])
1598 self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"])
1599 self.assertEqual(
1600 check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]),
1601 ["patch.cell_x", "patch.cell_y DESC"],
1602 )
1603 self.assertEqual(
1604 check(results, order_by=["patch.cell_x", "-cell_y"]),
1605 ["patch.cell_x", "patch.cell_y DESC"],
1606 )
    def test_data_coordinate_iteration(self) -> None:
        """Tests for DataCoordinateQueryResult iteration."""

        def make_data_id(n: int) -> DataCoordinate:
            # Fixed skymap/tract; 'n' selects the patch.
            return DataCoordinate.standardize(skymap="m", tract=4, patch=n, universe=self.universe)

        result_rows = (
            [make_data_id(n) for n in range(3)],
            [make_data_id(n) for n in range(3, 6)],
            [make_data_id(10)],
        )
        results = self.query(result_rows=result_rows).data_ids(["patch"])
        # Iterating the results flattens the driver's result pages in order.
        self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1622 def test_dataset_results(self) -> None:
1623 """Test queries that return dataset refs.
1625 This includes tests for:
1627 - counting result rows;
1628 - expanding dimensions as needed for 'where' conditions;
1629 - different ways of passing a data ID to 'where' methods;
1630 - order_by and limit.
1632 It does not include the iteration methods of the DatasetRefQueryResults
1633 classes, since those require a different mock driver setup (see
1634 test_dataset_iteration). More tests for different inputs to
1635 DatasetRefQueryResults construction are in test_dataset_join.
1636 """
1637 # Set up a few equivalent base query-results object to test.
1638 query = self.query()
1639 x = query.expression_factory
1640 self.assertFalse(query.constraint_dimensions)
1641 results1 = query.datasets("raw").where(x.instrument == "DummyCam", visit=4)
1642 results2 = query.datasets("raw", collections=["DummyCam/defaults"]).where(
1643 {"instrument": "DummyCam", "visit": 4}
1644 )
1645 results3 = query.datasets("raw").where(
1646 DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe)
1647 )
1649 # Define a closure to check a DatasetRefQueryResults instance.
1650 def check(
1651 results: DatasetRefQueryResults,
1652 order_by: Any = (),
1653 limit: int | None = None,
1654 include_dimension_records: bool = False,
1655 ) -> list[str]:
1656 results = results.order_by(*order_by).limit(limit)
1657 with self.assertRaises(_TestQueryExecution) as cm:
1658 list(results)
1659 tree = cm.exception.tree
1660 self.assertEqual(str(tree.predicate), "instrument == 'DummyCam' AND visit == 4")
1661 self.assertEqual(
1662 tree.dimensions,
1663 self.universe.conform(["visit"]).union(results.dataset_type.dimensions),
1664 )
1665 self.assertFalse(tree.materializations)
1666 self.assertEqual(tree.datasets.keys(), {results.dataset_type.name})
1667 self.assertEqual(tree.datasets[results.dataset_type.name].collections, ("DummyCam/defaults",))
1668 self.assertEqual(
1669 tree.datasets[results.dataset_type.name].dimensions,
1670 results.dataset_type.dimensions,
1671 )
1672 self.assertFalse(tree.data_coordinate_uploads)
1673 result_spec = cm.exception.result_spec
1674 self.assertEqual(result_spec.result_type, "dataset_ref")
1675 self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
1676 self.assertEqual(result_spec.limit, limit)
1677 self.assertEqual(result_spec.find_first_dataset, result_spec.dataset_type_name)
1678 for exact, discard in itertools.permutations([False, True], r=2):
1679 with self.assertRaises(_TestQueryCount) as cm:
1680 results.count(exact=exact, discard=discard)
1681 self.assertEqual(cm.exception.result_spec, result_spec)
1682 self.assertEqual(cm.exception.exact, exact)
1683 self.assertEqual(cm.exception.discard, discard)
1684 with self.assertRaises(_TestQueryExecution) as cm:
1685 list(results.data_ids)
1686 self.assertEqual(
1687 cm.exception.result_spec,
1688 qrs.DataCoordinateResultSpec(
1689 dimensions=results.dataset_type.dimensions,
1690 include_dimension_records=include_dimension_records,
1691 ),
1692 )
1693 self.assertIs(cm.exception.tree, tree)
1694 return [str(term) for term in result_spec.order_by]
1696 # Run the closure's tests on variants of the base query.
1697 self.assertEqual(check(results1), [])
1698 self.assertEqual(check(results2), [])
1699 self.assertEqual(check(results3), [])
1700 self.assertEqual(check(results1.with_dimension_records(), include_dimension_records=True), [])
1701 self.assertEqual(
1702 check(results1.with_dimension_records().with_dimension_records(), include_dimension_records=True),
1703 [],
1704 )
1705 self.assertEqual(check(results1, limit=2), [])
1706 self.assertEqual(check(results1, order_by=["raw.timespan.begin"]), ["raw.timespan.begin"])
1707 self.assertEqual(check(results1, order_by=["detector"]), ["detector"])
1708 self.assertEqual(check(results1, order_by=["ingest_date"]), ["raw.ingest_date"])
    def test_dataset_iteration(self) -> None:
        """Tests for SingleTypeDatasetQueryResult iteration."""

        def make_ref(n: int) -> DatasetRef:
            # Fixed instrument/exposure; 'n' selects the detector.  Each ref
            # gets a fresh random UUID.
            return DatasetRef(
                self.raw,
                DataCoordinate.standardize(
                    instrument="DummyCam", exposure=4, detector=n, universe=self.universe
                ),
                run="DummyCam/raw/all",
                id=uuid.uuid4(),
            )

        result_rows = (
            [make_ref(n) for n in range(3)],
            [make_ref(n) for n in range(3, 6)],
            [make_ref(10)],
        )
        results = self.query(result_rows=result_rows).datasets("raw")
        # Iterating the results flattens the driver's result pages in order.
        self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
1731 def test_identifiers(self) -> None:
1732 """Test edge-cases of identifiers in order_by expressions."""
1734 def extract_order_by(results: DataCoordinateQueryResults) -> list[str]:
1735 with self.assertRaises(_TestQueryExecution) as cm:
1736 list(results)
1737 return [str(term) for term in cm.exception.result_spec.order_by]
1739 self.assertEqual(
1740 extract_order_by(self.query().data_ids(["day_obs"]).order_by("-timespan.begin")),
1741 ["day_obs.timespan.begin DESC"],
1742 )
1743 self.assertEqual(
1744 extract_order_by(self.query().data_ids(["day_obs"]).order_by("timespan.end")),
1745 ["day_obs.timespan.end"],
1746 )
1747 self.assertEqual(
1748 extract_order_by(self.query().data_ids(["visit"]).order_by("-visit.timespan.begin")),
1749 ["visit.timespan.begin DESC"],
1750 )
1751 self.assertEqual(
1752 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.timespan.end")),
1753 ["visit.timespan.end"],
1754 )
1755 self.assertEqual(
1756 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.science_program")),
1757 ["visit.science_program"],
1758 )
1759 self.assertEqual(
1760 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.id")),
1761 ["visit"],
1762 )
1763 self.assertEqual(
1764 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.physical_filter")),
1765 ["physical_filter"],
1766 )
1767 with self.assertRaises(TypeError):
1768 self.query().data_ids(["visit"]).order_by(3)
1769 with self.assertRaises(InvalidQueryError):
1770 self.query().data_ids(["visit"]).order_by("visit.region")
1771 with self.assertRaisesRegex(InvalidQueryError, "Ambiguous"):
1772 self.query().data_ids(["visit", "exposure"]).order_by("timespan.begin")
1773 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"):
1774 self.query().data_ids(["visit", "exposure"]).order_by("blarg")
1775 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"):
1776 self.query().data_ids(["visit", "exposure"]).order_by("visit.horse")
1777 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"):
1778 self.query().data_ids(["visit", "exposure"]).order_by("visit.science_program.monkey")
1779 with self.assertRaisesRegex(InvalidQueryError, "not valid for datasets"):
1780 self.query().datasets("raw").order_by("raw.seq_num")
    def test_invalid_models(self) -> None:
        """Test invalid models and combinations of models that cannot be
        constructed via the public Query and *QueryResults interfaces.
        """
        x = ExpressionFactory(self.universe)
        with self.assertRaises(InvalidQueryError):
            # QueryTree dimensions do not cover dataset dimensions.
            qt.QueryTree(
                dimensions=self.universe.conform(["visit"]),
                datasets={
                    "raw": qt.DatasetSearch(
                        collections=("DummyCam/raw/all",),
                        dimensions=self.raw.dimensions,
                    )
                },
            )
        with self.assertRaises(InvalidQueryError):
            # QueryTree dimensions do not cover predicate dimensions.
            qt.QueryTree(
                dimensions=self.universe.conform(["visit"]),
                predicate=(x.detector > 5),
            )
        with self.assertRaises(InvalidQueryError):
            # Predicate references a dataset not in the QueryTree.
            qt.QueryTree(
                dimensions=self.universe.conform(["exposure", "detector"]),
                predicate=(x["raw"].collection == "bird"),
            )
        with self.assertRaises(InvalidQueryError):
            # ResultSpec's dimensions are not a subset of the query tree's.
            DimensionRecordQueryResults(
                _TestQueryDriver(),
                qt.QueryTree(dimensions=self.universe.conform(["tract"])),
                qrs.DimensionRecordResultSpec(element=self.universe["detector"]),
            )
        with self.assertRaises(InvalidQueryError):
            # ResultSpec's datasets are not a subset of the query tree's.
            DatasetRefQueryResults(
                _TestQueryDriver(),
                qt.QueryTree(dimensions=self.raw.dimensions),
                qrs.DatasetRefResultSpec(
                    dataset_type_name="raw",
                    dimensions=self.raw.dimensions,
                    storage_class_name=self.raw.storageClass_name,
                    find_first=True,
                ),
            )
        with self.assertRaises(InvalidQueryError):
            # ResultSpec's order_by expression is not related to the dimensions
            # we're returning.
            x = ExpressionFactory(self.universe)
            DimensionRecordQueryResults(
                _TestQueryDriver(),
                qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])),
                qrs.DimensionRecordResultSpec(
                    element=self.universe["detector"], order_by=(x.unwrap(x.visit),)
                ),
            )
        with self.assertRaises(InvalidQueryError):
            # ResultSpec's order_by expression is not related to the datasets
            # we're returning.
            x = ExpressionFactory(self.universe)
            DimensionRecordQueryResults(
                _TestQueryDriver(),
                qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])),
                qrs.DimensionRecordResultSpec(
                    element=self.universe["detector"], order_by=(x.unwrap(x["raw"].ingest_date),)
                ),
            )
1852 def test_general_result_spec(self) -> None:
1853 """Tests for GeneralResultSpec.
1855 Unlike the other ResultSpec objects, we don't have a *QueryResults
1856 class for GeneralResultSpec yet, so we can't use the higher-level
1857 interfaces to test it like we can the others.
1858 """
1859 a = qrs.GeneralResultSpec(
1860 dimensions=self.universe.conform(["detector"]),
1861 dimension_fields={"detector": {"purpose"}},
1862 dataset_fields={},
1863 find_first=False,
1864 )
1865 self.assertEqual(a.find_first_dataset, None)
1866 a_columns = qt.ColumnSet(self.universe.conform(["detector"]))
1867 a_columns.dimension_fields["detector"].add("purpose")
1868 self.assertEqual(a.get_result_columns(), a_columns)
1869 b = qrs.GeneralResultSpec(
1870 dimensions=self.universe.conform(["detector"]),
1871 dimension_fields={},
1872 dataset_fields={"bias": {"timespan", "dataset_id"}},
1873 find_first=True,
1874 )
1875 self.assertEqual(b.find_first_dataset, "bias")
1876 b_columns = qt.ColumnSet(self.universe.conform(["detector"]))
1877 b_columns.dataset_fields["bias"].add("timespan")
1878 b_columns.dataset_fields["bias"].add("dataset_id")
1879 self.assertEqual(b.get_result_columns(), b_columns)
1880 with self.assertRaises(InvalidQueryError):
1881 # More than one dataset type with find_first
1882 qrs.GeneralResultSpec(
1883 dimensions=self.universe.conform(["detector", "exposure"]),
1884 dimension_fields={},
1885 dataset_fields={"bias": {"dataset_id"}, "raw": {"dataset_id"}},
1886 find_first=True,
1887 )
1888 with self.assertRaises(InvalidQueryError):
1889 # Out-of-bounds dimension fields.
1890 qrs.GeneralResultSpec(
1891 dimensions=self.universe.conform(["detector"]),
1892 dimension_fields={"visit": {"name"}},
1893 dataset_fields={},
1894 find_first=False,
1895 )
1896 with self.assertRaises(InvalidQueryError):
1897 # No fields for dimension element.
1898 qrs.GeneralResultSpec(
1899 dimensions=self.universe.conform(["detector"]),
1900 dimension_fields={"detector": set()},
1901 dataset_fields={},
1902 find_first=True,
1903 )
1904 with self.assertRaises(InvalidQueryError):
1905 # No fields for dataset.
1906 qrs.GeneralResultSpec(
1907 dimensions=self.universe.conform(["detector"]),
1908 dimension_fields={},
1909 dataset_fields={"bias": set()},
1910 find_first=True,
1911 )
if __name__ == "__main__":
    # Allow running this test module directly (e.g. `python test_query_interface.py`).
    unittest.main()