Coverage for tests / test_query_interface.py: 11%

954 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:37 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Tests for the public Butler.query interface and the Pydantic models that 

29back it using a mock column-expression visitor and a mock QueryDriver 

30implementation. 

31 

32These tests are entirely independent of which kind of butler or database 

33backend we're using. 

34 

This is a very large test file because many tests make use of those mocks,
but the mocks are not generally useful enough to be worth putting in the
library proper.

38""" 

39 

40from __future__ import annotations 

41 

42import dataclasses 

43import itertools 

44import unittest 

45import uuid 

46from collections.abc import Iterable, Iterator, Mapping, Set 

47from typing import Any 

48 

49import astropy.table 

50import astropy.time 

51import numpy as np 

52 

53from lsst.daf.butler import ( 

54 CollectionType, 

55 DataCoordinate, 

56 DataIdValue, 

57 DatasetRef, 

58 DatasetType, 

59 DimensionGroup, 

60 DimensionRecord, 

61 DimensionRecordSet, 

62 DimensionUniverse, 

63 InvalidQueryError, 

64 MissingDatasetTypeError, 

65 NamedValueSet, 

66 NoDefaultCollectionError, 

67 Timespan, 

68) 

69from lsst.daf.butler.queries import ( 

70 DataCoordinateQueryResults, 

71 DatasetRefQueryResults, 

72 DimensionRecordQueryResults, 

73 Query, 

74) 

75from lsst.daf.butler.queries import driver as qd 

76from lsst.daf.butler.queries import result_specs as qrs 

77from lsst.daf.butler.queries import tree as qt 

78from lsst.daf.butler.queries.expression_factory import ExpressionFactory 

79from lsst.daf.butler.queries.tree._column_expression import UnaryExpression 

80from lsst.daf.butler.queries.tree._predicate import PredicateLeaf, PredicateOperands 

81from lsst.daf.butler.queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor 

82from lsst.daf.butler.registry import CollectionSummary, DatasetTypeError 

83from lsst.daf.butler.registry.interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord 

84from lsst.sphgeom import DISJOINT, Mq3cPixelization 

85 

86 

87class _TestVisitor(PredicateVisitor[bool, bool, bool], ColumnExpressionVisitor[Any]): 

88 """Test visitor for column expressions. 

89 

90 This visitor evaluates column expressions using regular Python logic. 

91 

92 Parameters 

93 ---------- 

94 dimension_keys : `~collections.abc.Mapping`, optional 

95 Mapping from dimension name to the value it should be assigned by the 

96 visitor. 

97 dimension_fields : `~collections.abc.Mapping`, optional 

98 Mapping from ``(dimension element name, field)`` tuple to the value it 

99 should be assigned by the visitor. 

100 dataset_fields : `~collections.abc.Mapping`, optional 

101 Mapping from ``(dataset type name, field)`` tuple to the value it 

102 should be assigned by the visitor. 

103 query_tree_items : `~collections.abc.Set`, optional 

104 Set that should be used as the right-hand side of element-in-query 

105 predicates. 

106 """ 

107 

108 def __init__( 

109 self, 

110 dimension_keys: Mapping[str, Any] | None = None, 

111 dimension_fields: Mapping[tuple[str, str], Any] | None = None, 

112 dataset_fields: Mapping[tuple[str, str], Any] | None = None, 

113 query_tree_items: Set[Any] = frozenset(), 

114 ): 

115 self.dimension_keys = dimension_keys or {} 

116 self.dimension_fields = dimension_fields or {} 

117 self.dataset_fields = dataset_fields or {} 

118 self.query_tree_items = query_tree_items 

119 

120 def visit_binary_expression(self, expression: qt.BinaryExpression) -> Any: 

121 match expression.operator: 

122 case "+": 

123 return expression.a.visit(self) + expression.b.visit(self) 

124 case "-": 

125 return expression.a.visit(self) - expression.b.visit(self) 

126 case "*": 

127 return expression.a.visit(self) * expression.b.visit(self) 

128 case "/": 

129 match expression.column_type: 

130 case "int": 

131 return expression.a.visit(self) // expression.b.visit(self) 

132 case "float": 

133 return expression.a.visit(self) / expression.b.visit(self) 

134 case "%": 

135 return expression.a.visit(self) % expression.b.visit(self) 

136 

137 def visit_comparison( 

138 self, 

139 a: qt.ColumnExpression, 

140 operator: qt.ComparisonOperator, 

141 b: qt.ColumnExpression, 

142 flags: PredicateVisitFlags, 

143 ) -> bool: 

144 match operator: 

145 case "==": 

146 return a.visit(self) == b.visit(self) 

147 case "!=": 

148 return a.visit(self) != b.visit(self) 

149 case "<": 

150 return a.visit(self) < b.visit(self) 

151 case ">": 

152 return a.visit(self) > b.visit(self) 

153 case "<=": 

154 return a.visit(self) <= b.visit(self) 

155 case ">=": 

156 return a.visit(self) >= b.visit(self) 

157 case "overlaps": 

158 return not (a.visit(self).relate(b.visit(self)) & DISJOINT) 

159 

160 def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> Any: 

161 return self.dataset_fields[expression.dataset_type, expression.field] 

162 

163 def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> Any: 

164 return self.dimension_fields[expression.element.name, expression.field] 

165 

166 def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> Any: 

167 return self.dimension_keys[expression.dimension.name] 

168 

169 def visit_in_container( 

170 self, 

171 member: qt.ColumnExpression, 

172 container: tuple[qt.ColumnExpression, ...], 

173 flags: PredicateVisitFlags, 

174 ) -> bool: 

175 return member.visit(self) in [item.visit(self) for item in container] 

176 

177 def visit_in_range( 

178 self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags 

179 ) -> bool: 

180 return member.visit(self) in range(start, stop, step) 

181 

182 def visit_in_query_tree( 

183 self, 

184 member: qt.ColumnExpression, 

185 column: qt.ColumnExpression, 

186 query_tree: qt.QueryTree, 

187 flags: PredicateVisitFlags, 

188 ) -> bool: 

189 return member.visit(self) in self.query_tree_items 

190 

191 def visit_is_null(self, operand: qt.ColumnExpression, flags: PredicateVisitFlags) -> bool: 

192 return operand.visit(self) is None 

193 

194 def visit_literal(self, expression: qt.ColumnLiteral) -> Any: 

195 return expression.get_literal_value() 

196 

197 def visit_reversed(self, expression: qt.Reversed) -> Any: 

198 return _TestReversed(expression.operand.visit(self)) 

199 

200 def visit_unary_expression(self, expression: UnaryExpression) -> Any: 

201 match expression.operator: 

202 case "-": 

203 return -expression.operand.visit(self) 

204 case "begin_of": 

205 return expression.operand.visit(self).begin 

206 case "end_of": 

207 return expression.operand.visit(self).end 

208 

209 def apply_logical_and(self, originals: PredicateOperands, results: tuple[bool, ...]) -> bool: 

210 return all(results) 

211 

212 def apply_logical_not(self, original: PredicateLeaf, result: bool, flags: PredicateVisitFlags) -> bool: 

213 return not result 

214 

215 def apply_logical_or( 

216 self, originals: tuple[PredicateLeaf, ...], results: tuple[bool, ...], flags: PredicateVisitFlags 

217 ) -> bool: 

218 return any(results) 

219 

220 

221@dataclasses.dataclass 

222class _TestReversed: 

223 """Struct used by _TestVisitor" to mark an expression as reversed in sort 

224 order. 

225 """ 

226 

227 operand: Any 

228 

229 

230class _TestQueryExecution(BaseException): 

231 """Exception raised by _TestQueryDriver.execute to communicate its args 

232 back to the caller. 

233 """ 

234 

235 def __init__(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree, driver: _TestQueryDriver) -> None: 

236 self.result_spec = result_spec 

237 self.tree = tree 

238 self.driver = driver 

239 

240 

241class _TestQueryCount(BaseException): 

242 """Exception raised by _TestQueryDriver.count to communicate its args 

243 back to the caller. 

244 """ 

245 

246 def __init__( 

247 self, 

248 result_spec: qrs.ResultSpec, 

249 tree: qt.QueryTree, 

250 driver: _TestQueryDriver, 

251 exact: bool, 

252 discard: bool, 

253 ) -> None: 

254 self.result_spec = result_spec 

255 self.tree = tree 

256 self.driver = driver 

257 self.exact = exact 

258 self.discard = discard 

259 

260 

261class _TestQueryAny(BaseException): 

262 """Exception raised by _TestQueryDriver.any to communicate its args 

263 back to the caller. 

264 """ 

265 

266 def __init__( 

267 self, 

268 tree: qt.QueryTree, 

269 driver: _TestQueryDriver, 

270 exact: bool, 

271 execute: bool, 

272 ) -> None: 

273 self.tree = tree 

274 self.driver = driver 

275 self.exact = exact 

276 self.execute = execute 

277 

278 

279class _TestQueryExplainNoResults(BaseException): 

280 """Exception raised by _TestQueryDriver.explain_no_results to communicate 

281 its args back to the caller. 

282 """ 

283 

284 def __init__( 

285 self, 

286 tree: qt.QueryTree, 

287 driver: _TestQueryDriver, 

288 execute: bool, 

289 ) -> None: 

290 self.tree = tree 

291 self.driver = driver 

292 self.execute = execute 

293 

294 

class _TestQueryDriver(qd.QueryDriver):
    """Mock implementation of `QueryDriver` that mostly raises exceptions that
    communicate the arguments its methods were called with.

    Parameters
    ----------
    default_collections : `tuple` [ `str`, ... ], optional
        Default collection the query or parent butler is imagined to have been
        constructed with.
    collection_info : `~collections.abc.Mapping`, optional
        Mapping from collection name to its record and summary, simulating the
        collections present in the data repository.
    dataset_types : `~collections.abc.Mapping`, optional
        Mapping from dataset type to its definition, simulating the dataset
        types registered in the data repository.
    result_rows : `tuple` [ `~collections.abc.Iterable`, ... ], optional
        A tuple of iterables of arbitrary type to use as result rows any time
        `execute` is called, with each nested iterable considered a separate
        page.  The result type is not checked for consistency with the result
        spec.  If this is not provided, `execute` will instead raise
        `_TestQueryExecution`, and `fetch_page` will not do anything useful.
    """

    def __init__(
        self,
        default_collections: tuple[str, ...] | None = None,
        collection_info: Mapping[str, tuple[CollectionRecord, CollectionSummary]] | None = None,
        dataset_types: Mapping[str, DatasetType] | None = None,
        result_rows: tuple[Iterable[Any], ...] | None = None,
    ) -> None:
        self._universe = DimensionUniverse()
        # Mapping of the arguments passed to materialize, keyed by the UUID
        # that each call returned.
        self.materializations: dict[
            qd.MaterializationKey, tuple[qt.QueryTree, DimensionGroup, frozenset[str]]
        ] = {}
        # Mapping of the arguments passed to upload_data_coordinates, keyed by
        # the UUID that each call returned; the uploaded rows are stored as a
        # frozenset (see upload_data_coordinates).
        self.data_coordinate_uploads: dict[
            qd.DataCoordinateUploadKey, tuple[DimensionGroup, frozenset[tuple[DataIdValue, ...]]]
        ] = {}
        self._default_collections = default_collections
        self._collection_info = collection_info or {}
        self._dataset_types = dataset_types or {}
        # Record of (result_spec, tree) execution pairs; not written to in
        # this chunk — presumably reserved for introspection by tests.
        self._executions: list[tuple[qrs.ResultSpec, qt.QueryTree]] = []
        self._result_rows = result_rows

    @property
    def universe(self) -> DimensionUniverse:
        # The default dimension universe constructed in __init__.
        return self._universe

    def __enter__(self) -> None:
        # Context management is a no-op for this mock.
        pass

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        pass

    def execute(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree) -> Iterator[qd.ResultPage]:
        # NOTE: this is a generator, so _TestQueryExecution is raised only
        # when the returned iterator is first advanced, not at call time.
        if self._result_rows is None:
            raise _TestQueryExecution(result_spec, tree, self)

        # Each nested iterable of configured rows becomes its own page.
        for rows in self._result_rows:
            yield self._make_next_page(result_spec, rows)

    def _make_next_page(self, result_spec: qrs.ResultSpec, current_rows: Iterable[Any]) -> qd.ResultPage:
        # Wrap the rows in the page type matching the result spec; the rows
        # themselves are not validated against the spec.
        match result_spec:
            case qrs.DataCoordinateResultSpec():
                return qd.DataCoordinateResultPage(spec=result_spec, rows=current_rows)
            case qrs.DimensionRecordResultSpec():
                return qd.DimensionRecordResultPage(spec=result_spec, rows=current_rows)
            case qrs.DatasetRefResultSpec():
                return qd.DatasetRefResultPage(spec=result_spec, rows=current_rows)
            case _:
                raise NotImplementedError("Other query types not yet supported.")

    def materialize(
        self,
        tree: qt.QueryTree,
        dimensions: DimensionGroup,
        datasets: frozenset[str],
        allow_duplicate_overlaps: bool = False,
    ) -> qd.MaterializationKey:
        # Record the call under a fresh UUID and hand that key back.
        key = uuid.uuid4()
        self.materializations[key] = (tree, dimensions, datasets)
        return key

    def upload_data_coordinates(
        self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]]
    ) -> qd.DataCoordinateUploadKey:
        # Record the call under a fresh UUID; rows are frozen so the stored
        # value is order-independent and comparable.
        key = uuid.uuid4()
        self.data_coordinate_uploads[key] = (dimensions, frozenset(rows))
        return key

    def count(
        self,
        tree: qt.QueryTree,
        result_spec: qrs.ResultSpec,
        *,
        exact: bool,
        discard: bool,
    ) -> int:
        # Never returns: communicates its arguments via the exception.
        raise _TestQueryCount(result_spec, tree, self, exact, discard)

    def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool:
        # Never returns: communicates its arguments via the exception.
        raise _TestQueryAny(tree, self, exact, execute)

    def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]:
        # Never returns: communicates its arguments via the exception.
        raise _TestQueryExplainNoResults(tree, self, execute)

    def get_default_collections(self) -> tuple[str, ...]:
        # Simulate a butler constructed without default collections when none
        # were configured.
        if self._default_collections is None:
            raise NoDefaultCollectionError()
        return self._default_collections

    def get_dataset_type(self, name: str) -> DatasetType:
        # Look up a simulated registered dataset type, translating the
        # KeyError into the public butler exception.
        try:
            return self._dataset_types[name]
        except KeyError:
            raise MissingDatasetTypeError(name)

414 

415 

416class ColumnExpressionsTestCase(unittest.TestCase): 

417 """Tests for column expression objects in lsst.daf.butler.queries.tree.""" 

418 

    def setUp(self) -> None:
        # Fresh default universe and expression factory for each test.
        self.universe = DimensionUniverse()
        self.x = ExpressionFactory(self.universe)

422 

423 def query(self, **kwargs: Any) -> Query: 

424 """Make an initial Query object with the given kwargs used to 

425 initialize the _TestQueryDriver. 

426 """ 

427 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe)) 

428 

429 def test_int_literals(self) -> None: 

430 expr = self.x.unwrap(self.x.literal(5)) 

431 self.assertEqual(expr.value, 5) 

432 self.assertEqual(expr.get_literal_value(), 5) 

433 self.assertEqual(expr.expression_type, "int") 

434 self.assertEqual(expr.column_type, "int") 

435 self.assertEqual(str(expr), "5") 

436 self.assertTrue(expr.is_literal) 

437 columns = qt.ColumnSet(self.universe.empty) 

438 expr.gather_required_columns(columns) 

439 self.assertFalse(columns) 

440 self.assertEqual(expr.visit(_TestVisitor()), 5) 

441 

442 def test_string_literals(self) -> None: 

443 expr = self.x.unwrap(self.x.literal("five")) 

444 self.assertEqual(expr.value, "five") 

445 self.assertEqual(expr.get_literal_value(), "five") 

446 self.assertEqual(expr.expression_type, "string") 

447 self.assertEqual(expr.column_type, "string") 

448 self.assertEqual(str(expr), "'five'") 

449 self.assertTrue(expr.is_literal) 

450 columns = qt.ColumnSet(self.universe.empty) 

451 expr.gather_required_columns(columns) 

452 self.assertFalse(columns) 

453 self.assertEqual(expr.visit(_TestVisitor()), "five") 

454 

455 def test_float_literals(self) -> None: 

456 expr = self.x.unwrap(self.x.literal(0.5)) 

457 self.assertEqual(expr.value, 0.5) 

458 self.assertEqual(expr.get_literal_value(), 0.5) 

459 self.assertEqual(expr.expression_type, "float") 

460 self.assertEqual(expr.column_type, "float") 

461 self.assertEqual(str(expr), "0.5") 

462 self.assertTrue(expr.is_literal) 

463 columns = qt.ColumnSet(self.universe.empty) 

464 expr.gather_required_columns(columns) 

465 self.assertFalse(columns) 

466 self.assertEqual(expr.visit(_TestVisitor()), 0.5) 

467 

468 def test_hash_literals(self) -> None: 

469 expr = self.x.unwrap(self.x.literal(b"eleven")) 

470 self.assertEqual(expr.value, b"eleven") 

471 self.assertEqual(expr.get_literal_value(), b"eleven") 

472 self.assertEqual(expr.expression_type, "hash") 

473 self.assertEqual(expr.column_type, "hash") 

474 self.assertEqual(str(expr), "(bytes)") 

475 self.assertTrue(expr.is_literal) 

476 columns = qt.ColumnSet(self.universe.empty) 

477 expr.gather_required_columns(columns) 

478 self.assertFalse(columns) 

479 self.assertEqual(expr.visit(_TestVisitor()), b"eleven") 

480 

481 def test_uuid_literals(self) -> None: 

482 value = uuid.uuid4() 

483 expr = self.x.unwrap(self.x.literal(value)) 

484 self.assertEqual(expr.value, value) 

485 self.assertEqual(expr.get_literal_value(), value) 

486 self.assertEqual(expr.expression_type, "uuid") 

487 self.assertEqual(expr.column_type, "uuid") 

488 self.assertEqual(str(expr), str(value)) 

489 self.assertTrue(expr.is_literal) 

490 columns = qt.ColumnSet(self.universe.empty) 

491 expr.gather_required_columns(columns) 

492 self.assertFalse(columns) 

493 self.assertEqual(expr.visit(_TestVisitor()), value) 

494 

495 def test_datetime_literals(self) -> None: 

496 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

497 expr = self.x.unwrap(self.x.literal(value)) 

498 self.assertEqual(expr.value, value) 

499 self.assertEqual(expr.get_literal_value(), value) 

500 self.assertEqual(expr.expression_type, "datetime") 

501 self.assertEqual(expr.column_type, "datetime") 

502 self.assertEqual(str(expr), "2020-01-01T00:00:00") 

503 self.assertTrue(expr.is_literal) 

504 columns = qt.ColumnSet(self.universe.empty) 

505 expr.gather_required_columns(columns) 

506 self.assertFalse(columns) 

507 self.assertEqual(expr.visit(_TestVisitor()), value) 

508 

509 def test_timespan_literals(self) -> None: 

510 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

511 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai") 

512 value = Timespan(begin, end) 

513 expr = self.x.unwrap(self.x.literal(value)) 

514 self.assertEqual(expr.value, value) 

515 self.assertEqual(expr.get_literal_value(), value) 

516 self.assertEqual(expr.expression_type, "timespan") 

517 self.assertEqual(expr.column_type, "timespan") 

518 self.assertEqual(str(expr), "[2020-01-01T00:00:00, 2020-01-01T00:01:00)") 

519 self.assertTrue(expr.is_literal) 

520 columns = qt.ColumnSet(self.universe.empty) 

521 expr.gather_required_columns(columns) 

522 self.assertFalse(columns) 

523 self.assertEqual(expr.visit(_TestVisitor()), value) 

524 

525 def test_region_literals(self) -> None: 

526 pixelization = Mq3cPixelization(10) 

527 value = pixelization.quad(12058870) 

528 expr = self.x.unwrap(self.x.literal(value)) 

529 self.assertEqual(expr.value, value) 

530 self.assertEqual(expr.get_literal_value(), value) 

531 self.assertEqual(expr.expression_type, "region") 

532 self.assertEqual(expr.column_type, "region") 

533 self.assertEqual(str(expr), "(region)") 

534 self.assertTrue(expr.is_literal) 

535 columns = qt.ColumnSet(self.universe.empty) 

536 expr.gather_required_columns(columns) 

537 self.assertFalse(columns) 

538 self.assertEqual(expr.visit(_TestVisitor()), value) 

539 

540 def test_invalid_literal(self) -> None: 

541 with self.assertRaisesRegex(TypeError, "Invalid type 'complex' of value 5j for column literal."): 

542 self.x.literal(5j) 

543 

544 def test_dimension_key_reference(self) -> None: 

545 expr = self.x.unwrap(self.x.detector) 

546 self.assertIsNone(expr.get_literal_value()) 

547 self.assertEqual(expr.expression_type, "dimension_key") 

548 self.assertEqual(expr.column_type, "int") 

549 self.assertEqual(str(expr), "detector") 

550 self.assertFalse(expr.is_literal) 

551 columns = qt.ColumnSet(self.universe.empty) 

552 expr.gather_required_columns(columns) 

553 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

554 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 3})), 3) 

555 

    def test_dimension_field_reference(self) -> None:
        """Check the behavior of a reference to a non-key dimension-record
        field.
        """
        expr = self.x.unwrap(self.x.detector.purpose)
        self.assertIsNone(expr.get_literal_value())
        self.assertEqual(expr.expression_type, "dimension_field")
        self.assertEqual(expr.column_type, "string")
        self.assertEqual(str(expr), "detector.purpose")
        self.assertFalse(expr.is_literal)
        columns = qt.ColumnSet(self.universe.empty)
        expr.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
        self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
        # Constructing a reference to the "region" field is rejected.
        with self.assertRaises(InvalidQueryError):
            qt.DimensionFieldReference(element=self.universe.dimensions["detector"], field="region")
        self.assertEqual(
            expr.visit(_TestVisitor(dimension_fields={("detector", "purpose"): "science"})), "science"
        )

572 

    def test_dataset_field_reference(self) -> None:
        """Check the behavior of references to dataset fields, including the
        column type inferred for each known field name.
        """
        expr = self.x.unwrap(self.x["raw"].ingest_date)
        self.assertIsNone(expr.get_literal_value())
        self.assertEqual(expr.expression_type, "dataset_field")
        self.assertEqual(str(expr), "raw.ingest_date")
        self.assertFalse(expr.is_literal)
        columns = qt.ColumnSet(self.universe.empty)
        expr.gather_required_columns(columns)
        # A dataset field alone requires no dimensions, just the dataset
        # field itself.
        self.assertEqual(columns.dimensions, self.universe.empty)
        self.assertEqual(columns.dataset_fields["raw"], {"ingest_date"})
        # Each well-known dataset field maps to a fixed column type.
        self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="dataset_id").column_type, "uuid")
        self.assertEqual(
            qt.DatasetFieldReference(dataset_type="raw", field="collection").column_type, "string"
        )
        self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="run").column_type, "string")
        self.assertEqual(
            qt.DatasetFieldReference(dataset_type="raw", field="ingest_date").column_type, "ingest_date"
        )
        self.assertEqual(
            qt.DatasetFieldReference(dataset_type="raw", field="timespan").column_type, "timespan"
        )
        value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        self.assertEqual(expr.visit(_TestVisitor(dataset_fields={("raw", "ingest_date"): value})), value)

596 

    def test_unary_negation(self) -> None:
        """Check unary negation of a numeric dimension field."""
        expr = self.x.unwrap(-self.x.visit.exposure_time)
        self.assertIsNone(expr.get_literal_value())
        self.assertEqual(expr.expression_type, "unary")
        self.assertEqual(expr.column_type, "float")
        self.assertEqual(str(expr), "-visit.exposure_time")
        self.assertFalse(expr.is_literal)
        columns = qt.ColumnSet(self.universe.empty)
        expr.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"})
        self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 2.0})), -2.0)
        # Negating a string field is rejected at construction.
        with self.assertRaises(InvalidQueryError):
            qt.UnaryExpression(
                operand=qt.DimensionFieldReference(
                    element=self.universe.dimensions["detector"], field="purpose"
                ),
                operator="-",
            )

616 

    def test_unary_timespan_begin(self) -> None:
        """Check the 'begin_of' unary expression on a timespan field."""
        expr = self.x.unwrap(self.x.visit.timespan.begin)
        self.assertIsNone(expr.get_literal_value())
        self.assertEqual(expr.expression_type, "unary")
        self.assertEqual(expr.column_type, "datetime")
        self.assertEqual(str(expr), "visit.timespan.begin")
        self.assertFalse(expr.is_literal)
        columns = qt.ColumnSet(self.universe.empty)
        expr.gather_required_columns(columns)
        # Requires the full timespan field, not a separate 'begin' column.
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertEqual(columns.dimension_fields["visit"], {"timespan"})
        begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
        value = Timespan(begin, end)
        self.assertEqual(
            expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.begin
        )
        # 'begin_of' applied to a non-timespan field is rejected.
        with self.assertRaises(InvalidQueryError):
            qt.UnaryExpression(
                operand=qt.DimensionFieldReference(
                    element=self.universe.dimensions["detector"], field="purpose"
                ),
                operator="begin_of",
            )

641 

    def test_unary_timespan_end(self) -> None:
        """Check the 'end_of' unary expression on a timespan field."""
        expr = self.x.unwrap(self.x.visit.timespan.end)
        self.assertIsNone(expr.get_literal_value())
        self.assertEqual(expr.expression_type, "unary")
        self.assertEqual(expr.column_type, "datetime")
        self.assertEqual(str(expr), "visit.timespan.end")
        self.assertFalse(expr.is_literal)
        columns = qt.ColumnSet(self.universe.empty)
        expr.gather_required_columns(columns)
        # Requires the full timespan field, not a separate 'end' column.
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertEqual(columns.dimension_fields["visit"], {"timespan"})
        begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
        value = Timespan(begin, end)
        self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.end)
        # 'end_of' applied to a non-timespan field is rejected.
        with self.assertRaises(InvalidQueryError):
            qt.UnaryExpression(
                operand=qt.DimensionFieldReference(
                    element=self.universe.dimensions["detector"], field="purpose"
                ),
                operator="end_of",
            )

664 

    def test_binary_expression_float(self) -> None:
        """Check float arithmetic expressions: string rendering (including
        parenthesization) and evaluation with exposure_time = 30.0.
        """
        for proxy, string, value in [
            (self.x.visit.exposure_time + 15.0, "visit.exposure_time + 15.0", 45.0),
            (self.x.visit.exposure_time - 10.0, "visit.exposure_time - 10.0", 20.0),
            (self.x.visit.exposure_time * 6.0, "visit.exposure_time * 6.0", 180.0),
            (self.x.visit.exposure_time / 30.0, "visit.exposure_time / 30.0", 1.0),
            (15.0 + -self.x.visit.exposure_time, "15.0 + (-visit.exposure_time)", -15.0),
            (10.0 - -self.x.visit.exposure_time, "10.0 - (-visit.exposure_time)", 40.0),
            (6.0 * -self.x.visit.exposure_time, "6.0 * (-visit.exposure_time)", -180.0),
            (30.0 / -self.x.visit.exposure_time, "30.0 / (-visit.exposure_time)", -1.0),
            ((self.x.visit.exposure_time + 15.0) * 6.0, "(visit.exposure_time + 15.0) * 6.0", 270.0),
            ((self.x.visit.exposure_time + 15.0) + 45.0, "(visit.exposure_time + 15.0) + 45.0", 90.0),
            ((self.x.visit.exposure_time + 15.0) / 5.0, "(visit.exposure_time + 15.0) / 5.0", 9.0),
            # We don't need the parentheses we generate in the next one, but
            # they're not a problem either.
            ((self.x.visit.exposure_time + 15.0) - 60.0, "(visit.exposure_time + 15.0) - 60.0", -15.0),
            (6.0 * (-self.x.visit.exposure_time - 15.0), "6.0 * ((-visit.exposure_time) - 15.0)", -270.0),
            (60.0 + (-self.x.visit.exposure_time - 15.0), "60.0 + ((-visit.exposure_time) - 15.0)", 15.0),
            (90.0 / (-self.x.visit.exposure_time - 15.0), "90.0 / ((-visit.exposure_time) - 15.0)", -2.0),
            (60.0 - (-self.x.visit.exposure_time - 15.0), "60.0 - ((-visit.exposure_time) - 15.0)", 105.0),
        ]:
            with self.subTest(string=string):
                expr = self.x.unwrap(proxy)
                self.assertIsNone(expr.get_literal_value())
                self.assertEqual(expr.expression_type, "binary")
                self.assertEqual(expr.column_type, "float")
                self.assertEqual(str(expr), string)
                self.assertFalse(expr.is_literal)
                columns = qt.ColumnSet(self.universe.empty)
                expr.gather_required_columns(columns)
                self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
                self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"})
                self.assertEqual(
                    expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 30.0})), value
                )

700 

    def test_binary_modulus(self) -> None:
        """Check the integer modulus operator, evaluated with visit = 5.

        The visit key can be referenced either via its ``.id`` attribute or
        directly; both render as just ``visit``.
        """
        for proxy, string, value in [
            (self.x.visit.id % 2, "visit % 2", 1),
            (52 % self.x.visit, "52 % visit", 2),
        ]:
            with self.subTest(string=string):
                expr = self.x.unwrap(proxy)
                self.assertIsNone(expr.get_literal_value())
                self.assertEqual(expr.expression_type, "binary")
                self.assertEqual(expr.column_type, "int")
                self.assertEqual(str(expr), string)
                self.assertFalse(expr.is_literal)
                columns = qt.ColumnSet(self.universe.empty)
                expr.gather_required_columns(columns)
                # Only the visit key is required; no record fields.
                self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
                self.assertFalse(columns.dimension_fields["visit"])
                self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"visit": 5})), value)

718 

    def test_binary_expression_validation(self) -> None:
        """Invalid binary expressions are rejected at construction."""
        with self.assertRaises(InvalidQueryError):
            # No arithmetic operators on strings (we do not interpret + as
            # concatenation).
            self.x.instrument + "suffix"
        with self.assertRaises(InvalidQueryError):
            # Mixed types are not supported, even when they both support the
            # operator.
            self.x.visit.exposure_time + self.x.detector
        with self.assertRaises(InvalidQueryError):
            # No modulus for floats.
            self.x.visit.exposure_time % 5.0

731 

    def test_reversed(self) -> None:
        """Check descending-sort wrappers around order-by expressions."""
        # Note: .desc yields the expression directly, so no unwrap is needed.
        expr = self.x.detector.desc
        self.assertIsNone(expr.get_literal_value())
        self.assertEqual(expr.expression_type, "reversed")
        self.assertEqual(expr.column_type, "int")
        self.assertEqual(str(expr), "detector DESC")
        self.assertFalse(expr.is_literal)
        columns = qt.ColumnSet(self.universe.empty)
        expr.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
        self.assertFalse(columns.dimension_fields["detector"])
        # _TestVisitor marks reversed values with the _TestReversed struct.
        self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 5})), _TestReversed(5))

744 

    def test_trivial_predicate(self) -> None:
        """Test logical operations on trivial True/False predicates."""
        yes = qt.Predicate.from_bool(True)
        no = qt.Predicate.from_bool(False)
        maybe: qt.Predicate = self.x.detector == 5
        # All of these simplify to the trivially-true predicate, represented
        # as an empty conjunction (operands == ()).
        for predicate in [
            yes,
            yes.logical_or(no),
            no.logical_or(yes),
            yes.logical_and(yes),
            no.logical_not(),
            yes.logical_or(maybe),
            maybe.logical_or(yes),
        ]:
            self.assertEqual(predicate.column_type, "bool")
            self.assertEqual(str(predicate), "True")
            self.assertTrue(predicate.visit(_TestVisitor()))
            self.assertEqual(predicate.operands, ())
        # All of these simplify to the trivially-false predicate, represented
        # as a conjunction of one empty disjunction (operands == ((),)).
        for predicate in [
            no,
            yes.logical_and(no),
            no.logical_and(yes),
            no.logical_or(no),
            yes.logical_not(),
            no.logical_and(maybe),
            maybe.logical_and(no),
        ]:
            self.assertEqual(predicate.column_type, "bool")
            self.assertEqual(str(predicate), "False")
            self.assertFalse(predicate.visit(_TestVisitor()))
            self.assertEqual(predicate.operands, ((),))
        # All of these simplify to exactly the non-trivial 'maybe' predicate,
        # preserving the identity of its single leaf.
        for predicate in [
            maybe,
            yes.logical_and(maybe),
            no.logical_or(maybe),
            maybe.logical_not().logical_not(),
        ]:
            self.assertEqual(predicate.column_type, "bool")
            self.assertEqual(str(predicate), "detector == 5")
            self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"detector": 5})))
            self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"detector": 4})))
            self.assertEqual(len(predicate.operands), 1)
            self.assertEqual(len(predicate.operands[0]), 1)
            self.assertIs(predicate.operands[0][0], maybe.operands[0][0])

789 

790 def test_comparison(self) -> None: 

791 predicate: qt.Predicate 

792 string: str 

793 value: bool 

794 for detector in (4, 5, 6): 

795 for predicate, string, value in [ 

796 (self.x.detector == 5, "detector == 5", detector == 5), 

797 (self.x.detector != 5, "detector != 5", detector != 5), 

798 (self.x.detector < 5, "detector < 5", detector < 5), 

799 (self.x.detector > 5, "detector > 5", detector > 5), 

800 (self.x.detector <= 5, "detector <= 5", detector <= 5), 

801 (self.x.detector >= 5, "detector >= 5", detector >= 5), 

802 (self.x.detector == 5, "detector == 5", detector == 5), 

803 (self.x.detector != 5, "detector != 5", detector != 5), 

804 (self.x.detector < 5, "detector < 5", detector < 5), 

805 (self.x.detector > 5, "detector > 5", detector > 5), 

806 (self.x.detector <= 5, "detector <= 5", detector <= 5), 

807 (self.x.detector >= 5, "detector >= 5", detector >= 5), 

808 ]: 

809 with self.subTest(string=string, detector=detector): 

810 self.assertEqual(predicate.column_type, "bool") 

811 self.assertEqual(str(predicate), string) 

812 columns = qt.ColumnSet(self.universe.empty) 

813 predicate.gather_required_columns(columns) 

814 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

815 self.assertFalse(columns.dimension_fields["detector"]) 

816 self.assertEqual( 

817 predicate.visit(_TestVisitor(dimension_keys={"detector": detector})), value 

818 ) 

819 inverted = predicate.logical_not() 

820 self.assertEqual(inverted.column_type, "bool") 

821 self.assertEqual(str(inverted), f"NOT {string}") 

822 self.assertEqual( 

823 inverted.visit(_TestVisitor(dimension_keys={"detector": detector})), not value 

824 ) 

825 columns = qt.ColumnSet(self.universe.empty) 

826 inverted.gather_required_columns(columns) 

827 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

828 self.assertFalse(columns.dimension_fields["detector"]) 

829 

830 def test_overlap_comparison(self) -> None: 

831 pixelization = Mq3cPixelization(10) 

832 region1 = pixelization.quad(12058870) 

833 predicate = self.x.visit.region.overlaps(region1) 

834 self.assertEqual(predicate.column_type, "bool") 

835 self.assertEqual(str(predicate), "visit.region OVERLAPS (region)") 

836 columns = qt.ColumnSet(self.universe.empty) 

837 predicate.gather_required_columns(columns) 

838 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

839 self.assertEqual(columns.dimension_fields["visit"], {"region"}) 

840 region2 = pixelization.quad(12058857) 

841 self.assertFalse(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): region2}))) 

842 inverted = predicate.logical_not() 

843 self.assertEqual(inverted.column_type, "bool") 

844 self.assertEqual(str(inverted), "NOT visit.region OVERLAPS (region)") 

845 self.assertTrue(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): region2}))) 

846 columns = qt.ColumnSet(self.universe.empty) 

847 inverted.gather_required_columns(columns) 

848 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

849 self.assertEqual(columns.dimension_fields["visit"], {"region"}) 

850 

851 def test_invalid_comparison(self) -> None: 

852 # Mixed type comparisons. 

853 with self.assertRaises(InvalidQueryError): 

854 self.x.visit > "three" 

855 # Invalid operator for type. 

856 with self.assertRaises(InvalidQueryError): 

857 self.x["raw"].dataset_id < uuid.uuid4() 

858 

859 def test_is_null(self) -> None: 

860 predicate = self.x.visit.region.is_null 

861 self.assertEqual(predicate.column_type, "bool") 

862 self.assertEqual(str(predicate), "visit.region IS NULL") 

863 columns = qt.ColumnSet(self.universe.empty) 

864 predicate.gather_required_columns(columns) 

865 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

866 self.assertEqual(columns.dimension_fields["visit"], {"region"}) 

867 self.assertTrue(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): None}))) 

868 inverted = predicate.logical_not() 

869 self.assertEqual(inverted.column_type, "bool") 

870 self.assertEqual(str(inverted), "NOT visit.region IS NULL") 

871 self.assertFalse(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): None}))) 

872 inverted.gather_required_columns(columns) 

873 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

874 self.assertEqual(columns.dimension_fields["visit"], {"region"}) 

875 

    def test_in_container(self) -> None:
        """Test container-membership predicates (``IN [a, b, ...]``), whose
        members may mix literals and column expressions, and their logical
        inverses.
        """
        predicate: qt.Predicate = self.x.visit.in_iterable([3, 4, self.x.exposure.id])
        self.assertEqual(predicate.column_type, "bool")
        self.assertEqual(str(predicate), "visit IN [3, 4, exposure]")
        # Required columns include dimensions referenced by container members.
        columns = qt.ColumnSet(self.universe.empty)
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
        self.assertFalse(columns.dimension_fields["visit"])
        self.assertFalse(columns.dimension_fields["exposure"])
        # visit=2 is in [3, 4, 2] when exposure=2, but not in [3, 4, 5].
        self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
        self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
        # The logical inverse evaluates to the opposite values...
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT visit IN [3, 4, exposure]")
        self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
        self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
        # ...but requires the same columns.
        columns = qt.ColumnSet(self.universe.empty)
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
        self.assertFalse(columns.dimension_fields["visit"])
        self.assertFalse(columns.dimension_fields["exposure"])
        with self.assertRaises(InvalidQueryError):
            # Regions (and timespans) not allowed in IN expressions, since that
            # suggests topological logic we're not actually doing. We can't
            # use ExpressionFactory because it prohibits this case with typing.
            pixelization = Mq3cPixelization(10)
            region = pixelization.quad(12058870)
            qt.Predicate.in_container(self.x.unwrap(self.x.visit.region), [qt.make_column_literal(region)])
        with self.assertRaises(InvalidQueryError):
            # Mismatched types.
            self.x.visit.in_iterable([3.5, 2.1])

907 

    def test_in_range(self) -> None:
        """Test range-membership predicates (``IN start:stop:step``) and
        their logical inverses.
        """
        # visit IN range(2, 8, 2), i.e. {2, 4, 6}.
        predicate: qt.Predicate = self.x.visit.in_range(2, 8, 2)
        self.assertEqual(predicate.column_type, "bool")
        self.assertEqual(str(predicate), "visit IN 2:8:2")
        # Only the 'visit' dimension key is required; no extra fields.
        columns = qt.ColumnSet(self.universe.empty)
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertFalse(columns.dimension_fields["visit"])
        # 2 is in the range; 8 is not (the stop value is exclusive).
        self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2})))
        self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 8})))
        # The logical inverse evaluates to the opposite values...
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT visit IN 2:8:2")
        self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2})))
        self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 8})))
        # ...but requires the same columns.
        columns = qt.ColumnSet(self.universe.empty)
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertFalse(columns.dimension_fields["visit"])
        with self.assertRaises(InvalidQueryError):
            # Only integer fields allowed.
            self.x.visit.exposure_time.in_range(2, 4)
        with self.assertRaises(InvalidQueryError):
            # Step must be positive.
            self.x.visit.in_range(2, 4, -1)
        with self.assertRaises(InvalidQueryError):
            # Stop must be >= start.
            self.x.visit.in_range(2, 0)

936 

    def test_in_query(self) -> None:
        """Test membership predicates whose container is a column of another
        query (``IN (query).column``), and their logical inverses.
        """
        # Subquery yielding 'visit' values constrained to a single tract.
        query = self.query().join_dimensions(["visit", "tract"]).where(skymap="s", tract=3)
        predicate: qt.Predicate = self.x.exposure.in_query(self.x.visit, query)
        self.assertEqual(predicate.column_type, "bool")
        self.assertEqual(str(predicate), "exposure IN (query).visit")
        # Only the outer query's 'exposure' key is required; the subquery's
        # own columns are not gathered into the outer column set.
        columns = qt.ColumnSet(self.universe.empty)
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
        self.assertFalse(columns.dimension_fields["exposure"])
        # The mock visitor uses query_tree_items as the subquery's values:
        # exposure=2 is a member of {1, 2, 3}; exposure=8 is not.
        self.assertTrue(
            predicate.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
        )
        self.assertFalse(
            predicate.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
        )
        # The logical inverse evaluates to the opposite values...
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT exposure IN (query).visit")
        self.assertFalse(
            inverted.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
        )
        self.assertTrue(
            inverted.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
        )
        # ...but requires the same columns.
        columns = qt.ColumnSet(self.universe.empty)
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
        self.assertFalse(columns.dimension_fields["exposure"])
        with self.assertRaises(InvalidQueryError):
            # Regions (and timespans) not allowed in IN expressions, since that
            # suggests topological logic we're not actually doing. We can't
            # use ExpressionFactory because it prohibits this case with typing.
            qt.Predicate.in_query(
                self.x.unwrap(self.x.visit.region), self.x.unwrap(self.x.tract.region), query._tree
            )
        with self.assertRaises(InvalidQueryError):
            # Mismatched types.
            self.x.exposure.in_query(self.x.visit.exposure_time, query)
        with self.assertRaises(InvalidQueryError):
            # Query column requires dimensions that are not in the query.
            self.x.exposure.in_query(self.x.patch, query)
        with self.assertRaises(InvalidQueryError):
            # Query column requires dataset type that is not in the query.
            self.x["raw"].dataset_id.in_query(self.x["raw"].dataset_id, query)

981 

    def test_complex_predicate(self) -> None:
        """Test that predicates are converted to conjunctive normal form and
        get parentheses in the right places when stringified.
        """
        visitor = _TestVisitor(dimension_keys={"instrument": "i", "detector": 3, "visit": 6, "band": "r"})
        a: qt.Predicate = self.x.visit > 5  # will evaluate to True
        b: qt.Predicate = self.x.detector != 3  # will evaluate to False
        c: qt.Predicate = self.x.instrument == "i"  # will evaluate to True
        d: qt.Predicate = self.x.band == "g"  # will evaluate to False
        predicate: qt.Predicate
        # Each row pairs a composed predicate with its expected CNF
        # stringification and its expected value under 'visitor'.  Note how
        # OR-of-ANDs rows are distributed into AND-of-ORs, and how NOT is
        # pushed down onto the leaves by De Morgan's laws.
        for predicate, string, value in [
            (a.logical_or(b), f"{a} OR {b}", True),
            (a.logical_or(c), f"{a} OR {c}", True),
            (b.logical_or(d), f"{b} OR {d}", False),
            (a.logical_and(b), f"{a} AND {b}", False),
            (a.logical_and(c), f"{a} AND {c}", True),
            (b.logical_and(d), f"{b} AND {d}", False),
            (self.x.any(a, b, c, d), f"{a} OR {b} OR {c} OR {d}", True),
            (self.x.all(a, b, c, d), f"{a} AND {b} AND {c} AND {d}", False),
            (a.logical_or(b).logical_and(c), f"({a} OR {b}) AND {c}", True),
            (a.logical_and(b.logical_or(d)), f"{a} AND ({b} OR {d})", False),
            (a.logical_and(b).logical_or(c), f"({a} OR {c}) AND ({b} OR {c})", True),
            (
                a.logical_and(b).logical_or(c.logical_and(d)),
                f"({a} OR {c}) AND ({a} OR {d}) AND ({b} OR {c}) AND ({b} OR {d})",
                False,
            ),
            (a.logical_or(b).logical_not(), f"NOT {a} AND NOT {b}", False),
            (a.logical_or(c).logical_not(), f"NOT {a} AND NOT {c}", False),
            (b.logical_or(d).logical_not(), f"NOT {b} AND NOT {d}", True),
            (a.logical_and(b).logical_not(), f"NOT {a} OR NOT {b}", True),
            (a.logical_and(c).logical_not(), f"NOT {a} OR NOT {c}", False),
            (b.logical_and(d).logical_not(), f"NOT {b} OR NOT {d}", True),
            (
                self.x.not_(a.logical_or(b).logical_and(c)),
                f"(NOT {a} OR NOT {c}) AND (NOT {b} OR NOT {c})",
                False,
            ),
            (
                a.logical_and(b.logical_or(d)).logical_not(),
                f"(NOT {a} OR NOT {b}) AND (NOT {a} OR NOT {d})",
                True,
            ),
        ]:
            with self.subTest(string=string):
                self.assertEqual(str(predicate), string)
                self.assertEqual(predicate.visit(visitor), value)

1029 

1030 def test_proxy_misc(self) -> None: 

1031 """Test miscellaneous things on various ExpressionFactory proxies.""" 

1032 self.assertEqual(str(self.x.visit_detector_region), "visit_detector_region") 

1033 self.assertEqual(str(self.x.visit.instrument), "instrument") 

1034 self.assertEqual(str(self.x["raw"]), "raw") 

1035 self.assertEqual(str(self.x["raw.ingest_date"]), "raw.ingest_date") 

1036 self.assertEqual( 

1037 str(self.x.visit.timespan.overlaps(self.x["raw"].timespan)), 

1038 "visit.timespan OVERLAPS raw.timespan", 

1039 ) 

1040 self.assertGreater( 

1041 set(dir(self.x["raw"])), {"dataset_id", "ingest_date", "collection", "run", "timespan"} 

1042 ) 

1043 self.assertGreater(set(dir(self.x.exposure)), {"seq_num", "science_program", "timespan"}) 

1044 with self.assertRaises(AttributeError): 

1045 self.x["raw"].seq_num 

1046 with self.assertRaises(AttributeError): 

1047 self.x.visit.horse 

1048 

1049 

1050class QueryTestCase(unittest.TestCase): 

1051 """Tests for Query and *QueryResults objects in lsst.daf.butler.queries.""" 

1052 

    def setUp(self) -> None:
        """Set up shared test fixtures: a dimension universe, dataset types,
        and a mock collection hierarchy used to initialize _TestQueryDriver.
        """
        self.maxDiff = None
        self.universe = DimensionUniverse()
        # We use ArrowTable as the storage class for all dataset types because
        # it's got conversions that only require third-party packages we
        # already require.
        self.raw = DatasetType(
            "raw", dimensions=self.universe.conform(["detector", "exposure"]), storageClass="ArrowTable"
        )
        self.refcat = DatasetType(
            "refcat", dimensions=self.universe.conform(["htm7"]), storageClass="ArrowTable"
        )
        self.bias = DatasetType(
            "bias",
            dimensions=self.universe.conform(["detector"]),
            storageClass="ArrowTable",
            isCalibration=True,
        )
        # Default collection search path used when a test does not override
        # it via self.query(default_collections=...).
        self.default_collections: list[str] | None = ["DummyCam/defaults"]
        # Mock collection registry: two RUN collections, a CALIBRATION
        # collection, and a CHAINED collection over all three, each paired
        # with a summary of the dataset types and governor values it holds.
        self.collection_info: dict[str, tuple[CollectionRecord, CollectionSummary]] = {
            "DummyCam/raw/all": (
                RunRecord[int](1, name="DummyCam/raw/all"),
                CollectionSummary(NamedValueSet({self.raw}), governors={"instrument": {"DummyCam"}}),
            ),
            "DummyCam/calib": (
                CollectionRecord[int](2, name="DummyCam/calib", type=CollectionType.CALIBRATION),
                CollectionSummary(NamedValueSet({self.bias}), governors={"instrument": {"DummyCam"}}),
            ),
            "refcats": (
                RunRecord[int](3, name="refcats"),
                CollectionSummary(NamedValueSet({self.refcat}), governors={}),
            ),
            "DummyCam/defaults": (
                ChainedCollectionRecord[int](
                    4, name="DummyCam/defaults", children=("DummyCam/raw/all", "DummyCam/calib", "refcats")
                ),
                CollectionSummary(
                    NamedValueSet({self.raw, self.refcat, self.bias}), governors={"instrument": {"DummyCam"}}
                ),
            ),
        }
        # Dataset types known to the mock driver, by name.
        self.dataset_types = {"raw": self.raw, "refcat": self.refcat, "bias": self.bias}

1095 

1096 def query(self, **kwargs: Any) -> Query: 

1097 """Make an initial Query object with the given kwargs used to 

1098 initialize the _TestQueryDriver. 

1099 

1100 The given kwargs override the test-case-attribute defaults. 

1101 """ 

1102 kwargs.setdefault("default_collections", self.default_collections) 

1103 kwargs.setdefault("collection_info", self.collection_info) 

1104 kwargs.setdefault("dataset_types", self.dataset_types) 

1105 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe)) 

1106 

1107 def test_dataset_join(self) -> None: 

1108 """Test queries that have had a dataset search explicitly joined in via 

1109 Query.join_dataset_search. 

1110 

1111 Since this kind of query has a moderate amount of complexity, this is 

1112 where we get a lot of basic coverage that applies to all kinds of 

1113 queries, including: 

1114 

1115 - getting data ID and dataset results (but not iterating over them); 

1116 - the 'any' and 'explain_no_results' methods; 

1117 - adding 'where' filters (but not expanding dimensions accordingly); 

1118 - materializations. 

1119 """ 

1120 

1121 def check( 

1122 query: Query, 

1123 dimensions: DimensionGroup = self.raw.dimensions, 

1124 ) -> None: 

1125 """Run a battery of tests on one of a set of very similar queries 

1126 constructed in different ways (see below). 

1127 """ 

1128 

1129 def check_query_tree( 

1130 tree: qt.QueryTree, 

1131 dimensions: DimensionGroup = dimensions, 

1132 ) -> None: 

1133 """Check the state of the QueryTree object that backs the Query 

1134 or a derived QueryResults object. 

1135 

1136 Parameters 

1137 ---------- 

1138 tree : `lsst.daf.butler.queries.tree.QueryTree` 

1139 Object to test. 

1140 dimensions : `DimensionGroup` 

1141 Dimensions to expect in the `QueryTree`, not necessarily 

1142 including those in the test 'raw' dataset type. 

1143 """ 

1144 self.assertEqual(tree.dimensions, dimensions | self.raw.dimensions) 

1145 self.assertEqual(str(tree.predicate), "raw.run == 'DummyCam/raw/all'") 

1146 self.assertFalse(tree.materializations) 

1147 self.assertFalse(tree.data_coordinate_uploads) 

1148 self.assertEqual(tree.datasets.keys(), {"raw"}) 

1149 self.assertEqual(tree.datasets["raw"].dimensions, self.raw.dimensions) 

1150 self.assertEqual(tree.datasets["raw"].collections, ("DummyCam/defaults",)) 

1151 self.assertEqual(tree.get_joined_dimension_groups(), frozenset({self.raw.dimensions})) 

1152 

1153 def check_data_id_results(*args, query: Query, dimensions: DimensionGroup = dimensions) -> None: 

1154 """Construct a DataCoordinateQueryResults object from the query 

1155 with the given arguments and run a battery of tests on it. 

1156 

1157 Parameters 

1158 ---------- 

1159 *args 

1160 Forwarded to `Query.data_ids`. 

1161 query : `Query` 

1162 Query to start from. 

1163 dimensions : `DimensionGroup`, optional 

1164 Dimensions the result data IDs should have. 

1165 """ 

1166 with self.assertRaises(_TestQueryExecution) as cm: 

1167 list(query.data_ids(*args)) 

1168 self.assertEqual( 

1169 cm.exception.result_spec, 

1170 qrs.DataCoordinateResultSpec(dimensions=dimensions), 

1171 ) 

1172 check_query_tree(cm.exception.tree, dimensions=dimensions) 

1173 

1174 def check_dataset_results( 

1175 *args: Any, 

1176 query: Query, 

1177 find_first: bool = True, 

1178 storage_class_name: str = self.raw.storageClass_name, 

1179 ) -> None: 

1180 """Construct a DatasetRefQueryResults object from the query 

1181 with the given arguments and run a battery of tests on it. 

1182 

1183 Parameters 

1184 ---------- 

1185 *args 

1186 Forwarded to `Query.datasets`. 

1187 query : `Query` 

1188 Query to start from. 

1189 find_first : `bool`, optional 

1190 Whether to do find-first resolution on the results. 

1191 storage_class_name : `str`, optional 

1192 Expected name of the storage class for the results. 

1193 """ 

1194 with self.assertRaises(_TestQueryExecution) as cm: 

1195 list(query.datasets(*args, find_first=find_first)) 

1196 self.assertEqual( 

1197 cm.exception.result_spec, 

1198 qrs.DatasetRefResultSpec( 

1199 dataset_type_name="raw", 

1200 dimensions=self.raw.dimensions, 

1201 storage_class_name=storage_class_name, 

1202 find_first=find_first, 

1203 ), 

1204 ) 

1205 check_query_tree(cm.exception.tree) 

1206 

1207 def check_materialization( 

1208 kwargs: Mapping[str, Any], 

1209 query: Query, 

1210 dimensions: DimensionGroup = dimensions, 

1211 has_dataset: bool = True, 

1212 ) -> None: 

1213 """Materialize the query with the given arguments and run a 

1214 battery of tests on the result. 

1215 

1216 Parameters 

1217 ---------- 

1218 kwargs 

1219 Forwarded as keyword arguments to `Query.materialize`. 

1220 query : `Query` 

1221 Query to start from. 

1222 dimensions : `DimensionGroup`, optional 

1223 Dimensions to expect in the materialization and its derived 

1224 query. 

1225 has_dataset : `bool`, optional 

1226 Whether the query backed by the materialization should 

1227 still have the test 'raw' dataset joined in. 

1228 """ 

1229 # Materialize the query and check the query tree sent to the 

1230 # driver and the one in the materialized query. 

1231 with self.assertRaises(_TestQueryExecution) as cm: 

1232 list(query.materialize(**kwargs).data_ids()) 

1233 derived_tree = cm.exception.tree 

1234 self.assertEqual(derived_tree.dimensions, dimensions) 

1235 # Predicate should be materialized away; it no longer appears 

1236 # in the derived query. 

1237 self.assertEqual(str(derived_tree.predicate), "True") 

1238 self.assertFalse(derived_tree.data_coordinate_uploads) 

1239 if has_dataset: 

1240 # Dataset search is still there, even though its existence 

1241 # constraint is included in the materialization, because we 

1242 # might need to re-join for some result columns in a 

1243 # derived query. 

1244 self.assertTrue(derived_tree.datasets.keys(), {"raw"}) 

1245 self.assertEqual(derived_tree.datasets["raw"].dimensions, self.raw.dimensions) 

1246 self.assertEqual(derived_tree.datasets["raw"].collections, ("DummyCam/defaults",)) 

1247 else: 

1248 self.assertFalse(derived_tree.datasets) 

1249 ((key, derived_tree_materialized_dimensions),) = derived_tree.materializations.items() 

1250 self.assertEqual(derived_tree_materialized_dimensions, dimensions) 

1251 ( 

1252 materialized_tree, 

1253 materialized_dimensions, 

1254 materialized_datasets, 

1255 ) = cm.exception.driver.materializations[key] 

1256 self.assertEqual(derived_tree_materialized_dimensions, materialized_dimensions) 

1257 if has_dataset: 

1258 self.assertEqual(materialized_datasets, {"raw"}) 

1259 else: 

1260 self.assertFalse(materialized_datasets) 

1261 check_query_tree(materialized_tree) 

1262 

1263 # Actual logic for the check() function begins here. 

1264 

1265 self.assertEqual(query.constraint_dataset_types, {"raw"}) 

1266 self.assertEqual(query.constraint_dimensions, self.raw.dimensions) 

1267 

1268 # Adding a constraint on a field for this dataset type should work 

1269 # (this constraint will be present in all downstream tests). 

1270 query = query.where(query.expression_factory["raw"].run == "DummyCam/raw/all") 

1271 with self.assertRaises(InvalidQueryError): 

1272 # Adding constraint on a different dataset should not work. 

1273 query.where(query.expression_factory["refcat"].run == "refcats") 

1274 

1275 # Data IDs, with dimensions defaulted. 

1276 check_data_id_results(query=query) 

1277 # Dimensions for data IDs the same as defaults. 

1278 check_data_id_results(["exposure", "detector"], query=query) 

1279 # Dimensions are a subset of the query dimensions. 

1280 check_data_id_results(["exposure"], query=query, dimensions=self.universe.conform(["exposure"])) 

1281 # Dimensions are a superset of the query dimensions. 

1282 check_data_id_results( 

1283 ["exposure", "detector", "visit"], 

1284 query=query, 

1285 dimensions=self.universe.conform(["exposure", "detector", "visit"]), 

1286 ) 

1287 # Dimensions are neither a superset nor a subset of the query 

1288 # dimensions. 

1289 check_data_id_results( 

1290 ["detector", "visit"], query=query, dimensions=self.universe.conform(["visit", "detector"]) 

1291 ) 

1292 # Dimensions are empty. 

1293 check_data_id_results([], query=query, dimensions=self.universe.conform([])) 

1294 

1295 # Get DatasetRef results, with various arguments and defaulting. 

1296 check_dataset_results("raw", query=query) 

1297 check_dataset_results("raw", query=query, find_first=True) 

1298 check_dataset_results("raw", ["DummyCam/defaults"], query=query) 

1299 check_dataset_results("raw", ["DummyCam/defaults"], query=query, find_first=True) 

1300 check_dataset_results(self.raw, query=query) 

1301 check_dataset_results(self.raw, query=query, find_first=True) 

1302 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query) 

1303 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query, find_first=True) 

1304 

1305 # Changing collections at this stage is not allowed. 

1306 with self.assertRaises(InvalidQueryError): 

1307 query.datasets("raw", collections=["DummyCam/calib"]) 

1308 

1309 # Changing storage classes is allowed, if they're compatible. 

1310 check_dataset_results( 

1311 self.raw.overrideStorageClass("ArrowNumpy"), query=query, storage_class_name="ArrowNumpy" 

1312 ) 

1313 with self.assertRaises(DatasetTypeError): 

1314 # Can't use overrideStorageClass, because it'll raise 

1315 # before the code we want to test can. 

1316 query.datasets(DatasetType("raw", self.raw.dimensions, "int")) 

1317 

1318 # Check the 'any' and 'explain_no_results' methods on Query itself. 

1319 for execute, exact in itertools.permutations([False, True], 2): 

1320 with self.assertRaises(_TestQueryAny) as cm: 

1321 query.any(execute=execute, exact=exact) 

1322 self.assertEqual(cm.exception.execute, execute) 

1323 self.assertEqual(cm.exception.exact, exact) 

1324 check_query_tree(cm.exception.tree, dimensions) 

1325 with self.assertRaises(_TestQueryExplainNoResults): 

1326 query.explain_no_results() 

1327 check_query_tree(cm.exception.tree, dimensions) 

1328 

1329 # Materialize the query with defaults. 

1330 check_materialization({}, query=query) 

1331 # Materialize the query with args that match defaults. 

1332 check_materialization({"dimensions": ["exposure", "detector"], "datasets": {"raw"}}, query=query) 

1333 # Materialize the query with a superset of the original dimensions. 

1334 check_materialization( 

1335 {"dimensions": ["exposure", "detector", "visit"]}, 

1336 query=query, 

1337 dimensions=self.universe.conform(["exposure", "visit", "detector"]), 

1338 ) 

1339 # Materialize the query with no datasets. 

1340 check_materialization( 

1341 {"dimensions": ["exposure", "detector"], "datasets": frozenset()}, 

1342 query=query, 

1343 has_dataset=False, 

1344 ) 

1345 # Materialize the query with no datasets and a subset of the 

1346 # dimensions. 

1347 check_materialization( 

1348 {"dimensions": ["exposure"], "datasets": frozenset()}, 

1349 query=query, 

1350 has_dataset=False, 

1351 dimensions=self.universe.conform(["exposure"]), 

1352 ) 

1353 # Materializing the query with a dataset that is not in the query 

1354 # is an error. 

1355 with self.assertRaises(InvalidQueryError): 

1356 query.materialize(datasets={"refcat"}) 

1357 # Materializing the query with dimensions that are not a superset 

1358 # of any materialized dataset dimensions is an error. 

1359 with self.assertRaises(InvalidQueryError): 

1360 query.materialize(dimensions=["exposure"], datasets={"raw"}) 

1361 

1362 # Actual logic for test_dataset_joins starts here. 

1363 

1364 # Default collections and existing dataset type name. 

1365 check(self.query().join_dataset_search("raw")) 

1366 # Default collections and existing DatasetType instance. 

1367 check(self.query().join_dataset_search(self.raw)) 

1368 # Manual collections and existing dataset type. 

1369 check( 

1370 self.query(default_collections=None).join_dataset_search("raw", collections=["DummyCam/defaults"]) 

1371 ) 

1372 check( 

1373 self.query(default_collections=None).join_dataset_search( 

1374 self.raw, collections=["DummyCam/defaults"] 

1375 ) 

1376 ) 

1377 with self.assertRaises(MissingDatasetTypeError): 

1378 # Dataset type does not exist. 

1379 self.query(dataset_types={}).join_dataset_search("raw", collections=["DummyCam/raw/all"]) 

1380 with self.assertRaises(DatasetTypeError): 

1381 # Dataset type object with bad dimensions passed. 

1382 self.query().join_dataset_search( 

1383 DatasetType( 

1384 "raw", 

1385 dimensions={"detector", "visit"}, 

1386 storageClass=self.raw.storageClass_name, 

1387 universe=self.universe, 

1388 ) 

1389 ) 

1390 with self.assertRaises(TypeError): 

1391 # Bad type for dataset type argument. 

1392 self.query().join_dataset_search(3) 

1393 with self.assertRaises(InvalidQueryError): 

1394 # Cannot pass storage class override to join_dataset_search, 

1395 # because we cannot use it there. 

1396 self.query().join_dataset_search(self.raw.overrideStorageClass("ArrowAstropy")) 

1397 

    def test_dimension_record_results(self) -> None:
        """Test queries that return dimension records.

        This includes tests for:

        - joining against uploaded data coordinates;
        - counting result rows;
        - expanding dimensions as needed for 'where' conditions;
        - order_by and limit.

        It does not include the iteration methods of
        DimensionRecordQueryResults, since those require a different mock
        driver setup (see test_dimension_record_iteration).
        """
        # Set up the base query-results object to test.
        query = self.query()
        x = query.expression_factory
        self.assertFalse(query.constraint_dimensions)
        query = query.where(x.skymap == "m")
        self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap"]))
        upload_rows = [
            DataCoordinate.standardize(instrument="DummyCam", visit=3, universe=self.universe),
            DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe),
        ]
        # The driver stores uploads as frozensets of required-value tuples.
        raw_rows = frozenset([data_id.required_values for data_id in upload_rows])
        query = query.join_data_coordinates(upload_rows)
        # Joining the upload should expand the constraint dimensions.
        self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap", "visit"]))
        results = query.dimension_records("patch")
        results = results.where(x.tract == 4)

        # Define a closure to run tests on variants of the base query.
        def check(
            results: DimensionRecordQueryResults,
            order_by: Any = (),
            limit: int | None = None,
        ) -> list[str]:
            results = results.order_by(*order_by).limit(limit)
            self.assertEqual(results.element.name, "patch")
            # The mock driver raises _TestQueryExecution carrying the tree
            # and result spec it was asked to execute.
            with self.assertRaises(_TestQueryExecution) as cm:
                list(results)
            tree = cm.exception.tree
            self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4")
            self.assertEqual(tree.dimensions, self.universe.conform(["visit", "patch"]))
            self.assertFalse(tree.materializations)
            self.assertFalse(tree.datasets)
            # Exactly one data-coordinate upload should be present.
            ((key, upload_dimensions),) = tree.data_coordinate_uploads.items()
            self.assertEqual(upload_dimensions, self.universe.conform(["visit"]))
            self.assertEqual(cm.exception.driver.data_coordinate_uploads[key], (upload_dimensions, raw_rows))
            result_spec = cm.exception.result_spec
            self.assertEqual(result_spec.result_type, "dimension_record")
            self.assertEqual(result_spec.element, self.universe["patch"])
            self.assertEqual(result_spec.limit, limit)
            # count() should forward the same result spec along with its
            # own exact/discard flags to the driver.
            for exact, discard in itertools.permutations([False, True], r=2):
                with self.assertRaises(_TestQueryCount) as cm:
                    results.count(exact=exact, discard=discard)
                self.assertEqual(cm.exception.result_spec, result_spec)
                self.assertEqual(cm.exception.exact, exact)
                self.assertEqual(cm.exception.discard, discard)
            return [str(term) for term in result_spec.order_by]

        # Run the closure's tests on variants of the base query.
        self.assertEqual(check(results), [])
        self.assertEqual(check(results, limit=2), [])
        self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"])
        self.assertEqual(
            check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]),
            ["patch.cell_x", "patch.cell_y DESC"],
        )
        with self.assertRaises(InvalidQueryError):
            # Cannot upload empty list of data IDs.
            query.join_data_coordinates([])
        with self.assertRaises(InvalidQueryError):
            # Cannot upload heterogeneous list of data IDs.
            query.join_data_coordinates(
                [
                    DataCoordinate.make_empty(self.universe),
                    DataCoordinate.standardize(instrument="DummyCam", universe=self.universe),
                ]
            )

1477 

1478 def test_join_data_coordinate_table(self) -> None: 

1479 """Test joining a data coordinates via a table.""" 

1480 query = self.query() 

1481 x = query.expression_factory 

1482 self.assertFalse(query.constraint_dimensions) 

1483 query = query.where(x.skymap == "m") 

1484 self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap"])) 

1485 with self.assertRaises(InvalidQueryError): 

1486 # Cannot join a table with no rows. 

1487 query.join_data_coordinate_table( 

1488 astropy.table.Table({"instrument": np.array([]), "visit": np.array([])}) 

1489 ) 

1490 with self.assertRaises(InvalidQueryError): 

1491 # Cannot join a table with missing required dimensions. 

1492 query.join_data_coordinate_table(astropy.table.Table({"visit": np.array([1, 3, 5])})) 

1493 upload_table = astropy.table.Table({"tract": np.array([2, 5, 7], dtype=np.int64)}) 

1494 raw_rows = frozenset([("m", 2), ("m", 5), ("m", 7)]) 

1495 query = query.join_data_coordinate_table(upload_table) 

1496 self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap", "tract"])) 

1497 results = query.dimension_records("patch") 

1498 with self.assertRaises(_TestQueryExecution) as cm: 

1499 list(results) 

1500 tree = cm.exception.tree 

1501 driver = cm.exception.driver 

1502 self.assertEqual(str(tree.predicate), "skymap == 'm'") 

1503 self.assertEqual(tree.dimensions, self.universe.conform(["patch"])) 

1504 self.assertFalse(tree.materializations) 

1505 self.assertFalse(tree.datasets) 

1506 ((key, upload_dimensions),) = tree.data_coordinate_uploads.items() 

1507 self.assertEqual(upload_dimensions, self.universe.conform(["tract"])) 

1508 self.assertEqual(driver.data_coordinate_uploads[key], (upload_dimensions, raw_rows)) 

1509 for row in driver.data_coordinate_uploads[key][1]: 

1510 for value in row: 

1511 # Make sure numpy scalar types have been cast to builtins. 

1512 self.assertIsInstance(value, (int, str)) 

1513 

1514 def test_dimension_record_iteration(self) -> None: 

1515 """Tests for DimensionRecordQueryResult iteration.""" 

1516 

1517 def make_record(n: int) -> DimensionRecord: 

1518 return self.universe["patch"].RecordClass(skymap="m", tract=4, patch=n) 

1519 

1520 result_rows = ( 

1521 [make_record(n) for n in range(3)], 

1522 [make_record(n) for n in range(3, 6)], 

1523 [make_record(10)], 

1524 ) 

1525 results = self.query(result_rows=result_rows).dimension_records("patch") 

1526 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) 

1527 self.assertEqual( 

1528 list(results.iter_set_pages()), 

1529 [DimensionRecordSet(self.universe["patch"], rows) for rows in result_rows], 

1530 ) 

1531 self.assertEqual( 

1532 [table.column("id").to_pylist() for table in results.iter_table_pages()], 

1533 [list(range(3)), list(range(3, 6)), [10]], 

1534 ) 

1535 

    def test_data_coordinate_results(self) -> None:
        """Test queries that return data coordinates.

        This includes tests for:

        - counting result rows;
        - expanding dimensions as needed for 'where' conditions;
        - order_by and limit.

        It does not include the iteration methods of
        DataCoordinateQueryResults, since those require a different mock
        driver setup (see test_data_coordinate_iteration). More tests for
        different inputs to DataCoordinateQueryResults construction are in
        test_dataset_join.
        """
        # Set up the base query-results object to test.
        query = self.query()
        x = query.expression_factory
        self.assertFalse(query.constraint_dimensions)
        query = query.where(x.skymap == "m")
        results = query.data_ids(["patch", "band"])
        results = results.where(x.tract == 4)

        # Define a closure to run tests on variants of the base query.
        def check(
            results: DataCoordinateQueryResults,
            order_by: Any = (),
            limit: int | None = None,
            include_dimension_records: bool = False,
        ) -> list[str]:
            results = results.order_by(*order_by).limit(limit)
            self.assertEqual(results.dimensions, self.universe.conform(["patch", "band"]))
            # The mock driver raises _TestQueryExecution carrying the tree
            # and result spec it was asked to execute.
            with self.assertRaises(_TestQueryExecution) as cm:
                list(results)
            tree = cm.exception.tree
            self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4")
            self.assertEqual(tree.dimensions, self.universe.conform(["patch", "band"]))
            self.assertFalse(tree.materializations)
            self.assertFalse(tree.datasets)
            self.assertFalse(tree.data_coordinate_uploads)
            result_spec = cm.exception.result_spec
            self.assertEqual(result_spec.result_type, "data_coordinate")
            self.assertEqual(result_spec.dimensions, self.universe.conform(["patch", "band"]))
            self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
            self.assertEqual(result_spec.limit, limit)
            self.assertIsNone(result_spec.find_first_dataset)
            # count() should forward the same result spec along with its
            # own exact/discard flags to the driver.
            for exact, discard in itertools.permutations([False, True], r=2):
                with self.assertRaises(_TestQueryCount) as cm:
                    results.count(exact=exact, discard=discard)
                self.assertEqual(cm.exception.result_spec, result_spec)
                self.assertEqual(cm.exception.exact, exact)
                self.assertEqual(cm.exception.discard, discard)
            return [str(term) for term in result_spec.order_by]

        # Run the closure's tests on variants of the base query.
        self.assertEqual(check(results), [])
        self.assertEqual(check(results.with_dimension_records(), include_dimension_records=True), [])
        # with_dimension_records() should be idempotent.
        self.assertEqual(
            check(results.with_dimension_records().with_dimension_records(), include_dimension_records=True),
            [],
        )
        self.assertEqual(check(results, limit=2), [])
        self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"])
        self.assertEqual(
            check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]),
            ["patch.cell_x", "patch.cell_y DESC"],
        )
        # String order_by terms ('-' prefix means descending) should be
        # resolved the same way as expression-factory terms.
        self.assertEqual(
            check(results, order_by=["patch.cell_x", "-cell_y"]),
            ["patch.cell_x", "patch.cell_y DESC"],
        )

1607 

1608 def test_data_coordinate_iteration(self) -> None: 

1609 """Tests for DataCoordinateQueryResult iteration.""" 

1610 

1611 def make_data_id(n: int) -> DimensionRecord: 

1612 return DataCoordinate.standardize(skymap="m", tract=4, patch=n, universe=self.universe) 

1613 

1614 result_rows = ( 

1615 [make_data_id(n) for n in range(3)], 

1616 [make_data_id(n) for n in range(3, 6)], 

1617 [make_data_id(10)], 

1618 ) 

1619 results = self.query(result_rows=result_rows).data_ids(["patch"]) 

1620 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) 

1621 

    def test_dataset_results(self) -> None:
        """Test queries that return dataset refs.

        This includes tests for:

        - counting result rows;
        - expanding dimensions as needed for 'where' conditions;
        - different ways of passing a data ID to 'where' methods;
        - order_by and limit.

        It does not include the iteration methods of the DatasetRefQueryResults
        classes, since those require a different mock driver setup (see
        test_dataset_iteration). More tests for different inputs to
        DatasetRefQueryResults construction are in test_dataset_join.
        """
        # Set up a few equivalent base query-results object to test.
        query = self.query()
        x = query.expression_factory
        self.assertFalse(query.constraint_dimensions)
        # Three equivalent ways to express the same constraint: expression
        # plus keyword args...
        results1 = query.datasets("raw").where(x.instrument == "DummyCam", visit=4)
        # ...a mapping data ID (with explicit collections)...
        results2 = query.datasets("raw", collections=["DummyCam/defaults"]).where(
            {"instrument": "DummyCam", "visit": 4}
        )
        # ...and a full DataCoordinate.
        results3 = query.datasets("raw").where(
            DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe)
        )

        # Define a closure to check a DatasetRefQueryResults instance.
        def check(
            results: DatasetRefQueryResults,
            order_by: Any = (),
            limit: int | None = None,
            include_dimension_records: bool = False,
        ) -> list[str]:
            results = results.order_by(*order_by).limit(limit)
            # The mock driver raises _TestQueryExecution carrying the tree
            # and result spec it was asked to execute.
            with self.assertRaises(_TestQueryExecution) as cm:
                list(results)
            tree = cm.exception.tree
            self.assertEqual(str(tree.predicate), "instrument == 'DummyCam' AND visit == 4")
            self.assertEqual(
                tree.dimensions,
                self.universe.conform(["visit"]).union(results.dataset_type.dimensions),
            )
            self.assertFalse(tree.materializations)
            self.assertEqual(tree.datasets.keys(), {results.dataset_type.name})
            self.assertEqual(tree.datasets[results.dataset_type.name].collections, ("DummyCam/defaults",))
            self.assertEqual(
                tree.datasets[results.dataset_type.name].dimensions,
                results.dataset_type.dimensions,
            )
            self.assertFalse(tree.data_coordinate_uploads)
            result_spec = cm.exception.result_spec
            self.assertEqual(result_spec.result_type, "dataset_ref")
            self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
            self.assertEqual(result_spec.limit, limit)
            self.assertEqual(result_spec.find_first_dataset, result_spec.dataset_type_name)
            # count() should forward the same result spec along with its
            # own exact/discard flags to the driver.
            for exact, discard in itertools.permutations([False, True], r=2):
                with self.assertRaises(_TestQueryCount) as cm:
                    results.count(exact=exact, discard=discard)
                self.assertEqual(cm.exception.result_spec, result_spec)
                self.assertEqual(cm.exception.exact, exact)
                self.assertEqual(cm.exception.discard, discard)
            # The data_ids view should reuse the same tree with a
            # data-coordinate result spec.
            with self.assertRaises(_TestQueryExecution) as cm:
                list(results.data_ids)
            self.assertEqual(
                cm.exception.result_spec,
                qrs.DataCoordinateResultSpec(
                    dimensions=results.dataset_type.dimensions,
                    include_dimension_records=include_dimension_records,
                ),
            )
            self.assertIs(cm.exception.tree, tree)
            return [str(term) for term in result_spec.order_by]

        # Run the closure's tests on variants of the base query.
        self.assertEqual(check(results1), [])
        self.assertEqual(check(results2), [])
        self.assertEqual(check(results3), [])
        self.assertEqual(check(results1.with_dimension_records(), include_dimension_records=True), [])
        # with_dimension_records() should be idempotent.
        self.assertEqual(
            check(results1.with_dimension_records().with_dimension_records(), include_dimension_records=True),
            [],
        )
        self.assertEqual(check(results1, limit=2), [])
        self.assertEqual(check(results1, order_by=["raw.timespan.begin"]), ["raw.timespan.begin"])
        self.assertEqual(check(results1, order_by=["detector"]), ["detector"])
        # Bare 'ingest_date' should be resolved to the queried dataset type.
        self.assertEqual(check(results1, order_by=["ingest_date"]), ["raw.ingest_date"])

1709 

1710 def test_dataset_iteration(self) -> None: 

1711 """Tests for SingleTypeDatasetQueryResult iteration.""" 

1712 

1713 def make_ref(n: int) -> DimensionRecord: 

1714 return DatasetRef( 

1715 self.raw, 

1716 DataCoordinate.standardize( 

1717 instrument="DummyCam", exposure=4, detector=n, universe=self.universe 

1718 ), 

1719 run="DummyCam/raw/all", 

1720 id=uuid.uuid4(), 

1721 ) 

1722 

1723 result_rows = ( 

1724 [make_ref(n) for n in range(3)], 

1725 [make_ref(n) for n in range(3, 6)], 

1726 [make_ref(10)], 

1727 ) 

1728 results = self.query(result_rows=result_rows).datasets("raw") 

1729 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) 

1730 

1731 def test_identifiers(self) -> None: 

1732 """Test edge-cases of identifiers in order_by expressions.""" 

1733 

1734 def extract_order_by(results: DataCoordinateQueryResults) -> list[str]: 

1735 with self.assertRaises(_TestQueryExecution) as cm: 

1736 list(results) 

1737 return [str(term) for term in cm.exception.result_spec.order_by] 

1738 

1739 self.assertEqual( 

1740 extract_order_by(self.query().data_ids(["day_obs"]).order_by("-timespan.begin")), 

1741 ["day_obs.timespan.begin DESC"], 

1742 ) 

1743 self.assertEqual( 

1744 extract_order_by(self.query().data_ids(["day_obs"]).order_by("timespan.end")), 

1745 ["day_obs.timespan.end"], 

1746 ) 

1747 self.assertEqual( 

1748 extract_order_by(self.query().data_ids(["visit"]).order_by("-visit.timespan.begin")), 

1749 ["visit.timespan.begin DESC"], 

1750 ) 

1751 self.assertEqual( 

1752 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.timespan.end")), 

1753 ["visit.timespan.end"], 

1754 ) 

1755 self.assertEqual( 

1756 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.science_program")), 

1757 ["visit.science_program"], 

1758 ) 

1759 self.assertEqual( 

1760 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.id")), 

1761 ["visit"], 

1762 ) 

1763 self.assertEqual( 

1764 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.physical_filter")), 

1765 ["physical_filter"], 

1766 ) 

1767 with self.assertRaises(TypeError): 

1768 self.query().data_ids(["visit"]).order_by(3) 

1769 with self.assertRaises(InvalidQueryError): 

1770 self.query().data_ids(["visit"]).order_by("visit.region") 

1771 with self.assertRaisesRegex(InvalidQueryError, "Ambiguous"): 

1772 self.query().data_ids(["visit", "exposure"]).order_by("timespan.begin") 

1773 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"): 

1774 self.query().data_ids(["visit", "exposure"]).order_by("blarg") 

1775 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"): 

1776 self.query().data_ids(["visit", "exposure"]).order_by("visit.horse") 

1777 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"): 

1778 self.query().data_ids(["visit", "exposure"]).order_by("visit.science_program.monkey") 

1779 with self.assertRaisesRegex(InvalidQueryError, "not valid for datasets"): 

1780 self.query().datasets("raw").order_by("raw.seq_num") 

1781 

1782 def test_invalid_models(self) -> None: 

1783 """Test invalid models and combinations of models that cannot be 

1784 constructed via the public Query and *QueryResults interfaces. 

1785 """ 

1786 x = ExpressionFactory(self.universe) 

1787 with self.assertRaises(InvalidQueryError): 

1788 # QueryTree dimensions do not cover dataset dimensions. 

1789 qt.QueryTree( 

1790 dimensions=self.universe.conform(["visit"]), 

1791 datasets={ 

1792 "raw": qt.DatasetSearch( 

1793 collections=("DummyCam/raw/all",), 

1794 dimensions=self.raw.dimensions, 

1795 ) 

1796 }, 

1797 ) 

1798 with self.assertRaises(InvalidQueryError): 

1799 # QueryTree dimensions do no cover predicate dimensions. 

1800 qt.QueryTree( 

1801 dimensions=self.universe.conform(["visit"]), 

1802 predicate=(x.detector > 5), 

1803 ) 

1804 with self.assertRaises(InvalidQueryError): 

1805 # Predicate references a dataset not in the QueryTree. 

1806 qt.QueryTree( 

1807 dimensions=self.universe.conform(["exposure", "detector"]), 

1808 predicate=(x["raw"].collection == "bird"), 

1809 ) 

1810 with self.assertRaises(InvalidQueryError): 

1811 # ResultSpec's dimensions are not a subset of the query tree's. 

1812 DimensionRecordQueryResults( 

1813 _TestQueryDriver(), 

1814 qt.QueryTree(dimensions=self.universe.conform(["tract"])), 

1815 qrs.DimensionRecordResultSpec(element=self.universe["detector"]), 

1816 ) 

1817 with self.assertRaises(InvalidQueryError): 

1818 # ResultSpec's datasets are not a subset of the query tree's. 

1819 DatasetRefQueryResults( 

1820 _TestQueryDriver(), 

1821 qt.QueryTree(dimensions=self.raw.dimensions), 

1822 qrs.DatasetRefResultSpec( 

1823 dataset_type_name="raw", 

1824 dimensions=self.raw.dimensions, 

1825 storage_class_name=self.raw.storageClass_name, 

1826 find_first=True, 

1827 ), 

1828 ) 

1829 with self.assertRaises(InvalidQueryError): 

1830 # ResultSpec's order_by expression is not related to the dimensions 

1831 # we're returning. 

1832 x = ExpressionFactory(self.universe) 

1833 DimensionRecordQueryResults( 

1834 _TestQueryDriver(), 

1835 qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])), 

1836 qrs.DimensionRecordResultSpec( 

1837 element=self.universe["detector"], order_by=(x.unwrap(x.visit),) 

1838 ), 

1839 ) 

1840 with self.assertRaises(InvalidQueryError): 

1841 # ResultSpec's order_by expression is not related to the datasets 

1842 # we're returning. 

1843 x = ExpressionFactory(self.universe) 

1844 DimensionRecordQueryResults( 

1845 _TestQueryDriver(), 

1846 qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])), 

1847 qrs.DimensionRecordResultSpec( 

1848 element=self.universe["detector"], order_by=(x.unwrap(x["raw"].ingest_date),) 

1849 ), 

1850 ) 

1851 

1852 def test_general_result_spec(self) -> None: 

1853 """Tests for GeneralResultSpec. 

1854 

1855 Unlike the other ResultSpec objects, we don't have a *QueryResults 

1856 class for GeneralResultSpec yet, so we can't use the higher-level 

1857 interfaces to test it like we can the others. 

1858 """ 

1859 a = qrs.GeneralResultSpec( 

1860 dimensions=self.universe.conform(["detector"]), 

1861 dimension_fields={"detector": {"purpose"}}, 

1862 dataset_fields={}, 

1863 find_first=False, 

1864 ) 

1865 self.assertEqual(a.find_first_dataset, None) 

1866 a_columns = qt.ColumnSet(self.universe.conform(["detector"])) 

1867 a_columns.dimension_fields["detector"].add("purpose") 

1868 self.assertEqual(a.get_result_columns(), a_columns) 

1869 b = qrs.GeneralResultSpec( 

1870 dimensions=self.universe.conform(["detector"]), 

1871 dimension_fields={}, 

1872 dataset_fields={"bias": {"timespan", "dataset_id"}}, 

1873 find_first=True, 

1874 ) 

1875 self.assertEqual(b.find_first_dataset, "bias") 

1876 b_columns = qt.ColumnSet(self.universe.conform(["detector"])) 

1877 b_columns.dataset_fields["bias"].add("timespan") 

1878 b_columns.dataset_fields["bias"].add("dataset_id") 

1879 self.assertEqual(b.get_result_columns(), b_columns) 

1880 with self.assertRaises(InvalidQueryError): 

1881 # More than one dataset type with find_first 

1882 qrs.GeneralResultSpec( 

1883 dimensions=self.universe.conform(["detector", "exposure"]), 

1884 dimension_fields={}, 

1885 dataset_fields={"bias": {"dataset_id"}, "raw": {"dataset_id"}}, 

1886 find_first=True, 

1887 ) 

1888 with self.assertRaises(InvalidQueryError): 

1889 # Out-of-bounds dimension fields. 

1890 qrs.GeneralResultSpec( 

1891 dimensions=self.universe.conform(["detector"]), 

1892 dimension_fields={"visit": {"name"}}, 

1893 dataset_fields={}, 

1894 find_first=False, 

1895 ) 

1896 with self.assertRaises(InvalidQueryError): 

1897 # No fields for dimension element. 

1898 qrs.GeneralResultSpec( 

1899 dimensions=self.universe.conform(["detector"]), 

1900 dimension_fields={"detector": set()}, 

1901 dataset_fields={}, 

1902 find_first=True, 

1903 ) 

1904 with self.assertRaises(InvalidQueryError): 

1905 # No fields for dataset. 

1906 qrs.GeneralResultSpec( 

1907 dimensions=self.universe.conform(["detector"]), 

1908 dimension_fields={}, 

1909 dataset_fields={"bias": set()}, 

1910 find_first=True, 

1911 ) 

1912 

1913 

1914if __name__ == "__main__": 

1915 unittest.main()