Coverage for tests/test_query_interface.py: 10% (936 statements)
coverage.py v7.5.1, created at 2024-05-07 02:46 -0700

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Tests for the public Butler._query interface and the Pydantic models that
back it, using a mock column-expression visitor and a mock QueryDriver
implementation.

These tests are entirely independent of which kind of butler or database
backend we're using.

This is a very large test file because many of the tests make use of those
mocks, but the mocks are not generally useful enough to be worth putting in
the library proper.
"""

from __future__ import annotations

import dataclasses
import itertools
import unittest
import uuid
from collections.abc import Iterable, Iterator, Mapping, Set
from typing import Any

import astropy.time
from lsst.daf.butler import (
    CollectionType,
    DataCoordinate,
    DataIdValue,
    DatasetRef,
    DatasetType,
    DimensionGroup,
    DimensionRecord,
    DimensionRecordSet,
    DimensionUniverse,
    InvalidQueryError,
    MissingDatasetTypeError,
    NamedValueSet,
    NoDefaultCollectionError,
    Timespan,
)
from lsst.daf.butler.queries import (
    DataCoordinateQueryResults,
    DatasetRefQueryResults,
    DimensionRecordQueryResults,
    Query,
)
from lsst.daf.butler.queries import driver as qd
from lsst.daf.butler.queries import result_specs as qrs
from lsst.daf.butler.queries import tree as qt
from lsst.daf.butler.queries.expression_factory import ExpressionFactory
from lsst.daf.butler.queries.tree._column_expression import UnaryExpression
from lsst.daf.butler.queries.tree._predicate import PredicateLeaf, PredicateOperands
from lsst.daf.butler.queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor
from lsst.daf.butler.registry import CollectionSummary, DatasetTypeError
from lsst.daf.butler.registry.interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord
from lsst.sphgeom import DISJOINT, Mq3cPixelization


class _TestVisitor(PredicateVisitor[bool, bool, bool], ColumnExpressionVisitor[Any]):
    """Test visitor for column expressions.

    This visitor evaluates column expressions using regular Python logic.

    Parameters
    ----------
    dimension_keys : `~collections.abc.Mapping`, optional
        Mapping from dimension name to the value it should be assigned by the
        visitor.
    dimension_fields : `~collections.abc.Mapping`, optional
        Mapping from ``(dimension element name, field)`` tuple to the value it
        should be assigned by the visitor.
    dataset_fields : `~collections.abc.Mapping`, optional
        Mapping from ``(dataset type name, field)`` tuple to the value it
        should be assigned by the visitor.
    query_tree_items : `~collections.abc.Set`, optional
        Set that should be used as the right-hand side of element-in-query
        predicates.
    """
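    # Illustrative sketch (not executed here) of how the tests below use this
    # visitor: values for dimension keys and fields are injected via the
    # constructor, and visiting an expression evaluates it with plain Python,
    # e.g.::
    #
    #     x = ExpressionFactory(DimensionUniverse())
    #     expr = x.unwrap(x.detector)
    #     assert expr.visit(_TestVisitor(dimension_keys={"detector": 3})) == 3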

    def __init__(
        self,
        dimension_keys: Mapping[str, Any] | None = None,
        dimension_fields: Mapping[tuple[str, str], Any] | None = None,
        dataset_fields: Mapping[tuple[str, str], Any] | None = None,
        query_tree_items: Set[Any] = frozenset(),
    ):
        self.dimension_keys = dimension_keys or {}
        self.dimension_fields = dimension_fields or {}
        self.dataset_fields = dataset_fields or {}
        self.query_tree_items = query_tree_items

    def visit_binary_expression(self, expression: qt.BinaryExpression) -> Any:
        match expression.operator:
            case "+":
                return expression.a.visit(self) + expression.b.visit(self)
            case "-":
                return expression.a.visit(self) - expression.b.visit(self)
            case "*":
                return expression.a.visit(self) * expression.b.visit(self)
            case "/":
                match expression.column_type:
                    case "int":
                        return expression.a.visit(self) // expression.b.visit(self)
                    case "float":
                        return expression.a.visit(self) / expression.b.visit(self)
            case "%":
                return expression.a.visit(self) % expression.b.visit(self)

    def visit_comparison(
        self,
        a: qt.ColumnExpression,
        operator: qt.ComparisonOperator,
        b: qt.ColumnExpression,
        flags: PredicateVisitFlags,
    ) -> bool:
        match operator:
            case "==":
                return a.visit(self) == b.visit(self)
            case "!=":
                return a.visit(self) != b.visit(self)
            case "<":
                return a.visit(self) < b.visit(self)
            case ">":
                return a.visit(self) > b.visit(self)
            case "<=":
                return a.visit(self) <= b.visit(self)
            case ">=":
                return a.visit(self) >= b.visit(self)
            case "overlaps":
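                # In the "overlaps" case these tests exercise, both operands
                # evaluate to lsst.sphgeom regions: relate() returns a
                # relationship bit mask, and the operands overlap unless the
                # DISJOINT bit is set.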

                return not (a.visit(self).relate(b.visit(self)) & DISJOINT)

    def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> Any:
        return self.dataset_fields[expression.dataset_type, expression.field]

    def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> Any:
        return self.dimension_fields[expression.element.name, expression.field]

    def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> Any:
        return self.dimension_keys[expression.dimension.name]

    def visit_in_container(
        self,
        member: qt.ColumnExpression,
        container: tuple[qt.ColumnExpression, ...],
        flags: PredicateVisitFlags,
    ) -> bool:
        return member.visit(self) in [item.visit(self) for item in container]

    def visit_in_range(
        self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags
    ) -> bool:
        return member.visit(self) in range(start, stop, step)

    def visit_in_query_tree(
        self,
        member: qt.ColumnExpression,
        column: qt.ColumnExpression,
        query_tree: qt.QueryTree,
        flags: PredicateVisitFlags,
    ) -> bool:
        return member.visit(self) in self.query_tree_items

    def visit_is_null(self, operand: qt.ColumnExpression, flags: PredicateVisitFlags) -> bool:
        return operand.visit(self) is None

    def visit_literal(self, expression: qt.ColumnLiteral) -> Any:
        return expression.get_literal_value()

    def visit_reversed(self, expression: qt.Reversed) -> Any:
        return _TestReversed(expression.operand.visit(self))

    def visit_unary_expression(self, expression: UnaryExpression) -> Any:
        match expression.operator:
            case "-":
                return -expression.operand.visit(self)
            case "begin_of":
                return expression.operand.visit(self).begin
            case "end_of":
                return expression.operand.visit(self).end

    def apply_logical_and(self, originals: PredicateOperands, results: tuple[bool, ...]) -> bool:
        return all(results)

    def apply_logical_not(self, original: PredicateLeaf, result: bool, flags: PredicateVisitFlags) -> bool:
        return not result

    def apply_logical_or(
        self, originals: tuple[PredicateLeaf, ...], results: tuple[bool, ...], flags: PredicateVisitFlags
    ) -> bool:
        return any(results)


@dataclasses.dataclass
class _TestReversed:
    """Struct used by _TestVisitor to mark an expression as reversed in sort
    order.
    """

    operand: Any


class _TestQueryExecution(BaseException):
    """Exception raised by _TestQueryDriver.execute to communicate its args
    back to the caller.
    """
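    # Typical pattern in the tests below (sketch): because the mock driver
    # raises this exception instead of executing anything, a test can trigger
    # a query and then inspect exactly what the Query object asked the driver
    # for, e.g.::
    #
    #     with self.assertRaises(_TestQueryExecution) as cm:
    #         list(query.data_ids())
    #     cm.exception.result_spec  # the ResultSpec the Query built
    #     cm.exception.tree         # the QueryTree the Query built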

    def __init__(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree, driver: _TestQueryDriver) -> None:
        self.result_spec = result_spec
        self.tree = tree
        self.driver = driver


class _TestQueryCount(BaseException):
    """Exception raised by _TestQueryDriver.count to communicate its args
    back to the caller.
    """

    def __init__(
        self,
        result_spec: qrs.ResultSpec,
        tree: qt.QueryTree,
        driver: _TestQueryDriver,
        exact: bool,
        discard: bool,
    ) -> None:
        self.result_spec = result_spec
        self.tree = tree
        self.driver = driver
        self.exact = exact
        self.discard = discard


class _TestQueryAny(BaseException):
    """Exception raised by _TestQueryDriver.any to communicate its args
    back to the caller.
    """

    def __init__(
        self,
        tree: qt.QueryTree,
        driver: _TestQueryDriver,
        exact: bool,
        execute: bool,
    ) -> None:
        self.tree = tree
        self.driver = driver
        self.exact = exact
        self.execute = execute


class _TestQueryExplainNoResults(BaseException):
    """Exception raised by _TestQueryDriver.explain_no_results to communicate
    its args back to the caller.
    """

    def __init__(
        self,
        tree: qt.QueryTree,
        driver: _TestQueryDriver,
        execute: bool,
    ) -> None:
        self.tree = tree
        self.driver = driver
        self.execute = execute


class _TestQueryDriver(qd.QueryDriver):
    """Mock implementation of `QueryDriver` that mostly raises exceptions that
    communicate the arguments its methods were called with.

    Parameters
    ----------
    default_collections : `tuple` [ `str`, ... ], optional
        Default collections the query or parent butler is imagined to have
        been constructed with.
    collection_info : `~collections.abc.Mapping`, optional
        Mapping from collection name to its record and summary, simulating the
        collections present in the data repository.
    dataset_types : `~collections.abc.Mapping`, optional
        Mapping from dataset type name to its definition, simulating the
        dataset types registered in the data repository.
    result_rows : `tuple` [ `~collections.abc.Iterable`, ... ], optional
        A tuple of iterables of arbitrary type to use as result rows any time
        `execute` is called, with each nested iterable considered a separate
        page. The result type is not checked for consistency with the result
        spec. If this is not provided, `execute` will instead raise
        `_TestQueryExecution`, and `fetch_next_page` will not do anything
        useful.
    """
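    # Illustrative sketch (not executed here) of the paging behaviour the
    # tests rely on when ``result_rows`` is provided: each nested iterable
    # becomes one result page (``result_spec`` and ``tree`` stand for whatever
    # the calling test built)::
    #
    #     driver = _TestQueryDriver(result_rows=([1, 2], [3]))
    #     page = driver.execute(result_spec, tree)                   # rows [1, 2]
    #     page = driver.fetch_next_page(result_spec, page.next_key)  # rows [3]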

    def __init__(
        self,
        default_collections: tuple[str, ...] | None = None,
        collection_info: Mapping[str, tuple[CollectionRecord, CollectionSummary]] | None = None,
        dataset_types: Mapping[str, DatasetType] | None = None,
        result_rows: tuple[Iterable[Any], ...] | None = None,
    ) -> None:
        self._universe = DimensionUniverse()
        # Mapping of the arguments passed to materialize, keyed by the UUID
        # that each call returned.
        self.materializations: dict[
            qd.MaterializationKey, tuple[qt.QueryTree, DimensionGroup, frozenset[str]]
        ] = {}
        # Mapping of the arguments passed to upload_data_coordinates, keyed by
        # the UUID that each call returned.
        self.data_coordinate_uploads: dict[
            qd.DataCoordinateUploadKey, tuple[DimensionGroup, list[tuple[DataIdValue, ...]]]
        ] = {}
        self._default_collections = default_collections
        self._collection_info = collection_info or {}
        self._dataset_types = dataset_types or {}
        self._executions: list[tuple[qrs.ResultSpec, qt.QueryTree]] = []
        self._result_rows = result_rows
        self._result_iters: dict[qd.PageKey, tuple[Iterable[Any], Iterator[Iterable[Any]]]] = {}

    @property
    def universe(self) -> DimensionUniverse:
        return self._universe

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        pass

    def execute(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree) -> qd.ResultPage:
        if self._result_rows is not None:
            iterator = iter(self._result_rows)
            current_rows = next(iterator, ())
            return self._make_next_page(result_spec, current_rows, iterator)
        raise _TestQueryExecution(result_spec, tree, self)

    def fetch_next_page(self, result_spec: qrs.ResultSpec, key: qd.PageKey) -> qd.ResultPage:
        if self._result_rows is not None:
            return self._make_next_page(result_spec, *self._result_iters.pop(key))
        raise AssertionError("Test query driver not initialized for actual results.")

    def _make_next_page(
        self, result_spec: qrs.ResultSpec, current_rows: Iterable[Any], iterator: Iterator[Iterable[Any]]
    ) -> qd.ResultPage:
        next_rows = list(next(iterator, ()))
        if not next_rows:
            next_key = None
        else:
            next_key = uuid.uuid4()
            self._result_iters[next_key] = (next_rows, iterator)
        match result_spec:
            case qrs.DataCoordinateResultSpec():
                return qd.DataCoordinateResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
            case qrs.DimensionRecordResultSpec():
                return qd.DimensionRecordResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
            case qrs.DatasetRefResultSpec():
                return qd.DatasetRefResultPage(spec=result_spec, next_key=next_key, rows=current_rows)
            case _:
                raise NotImplementedError("Other query types not yet supported.")

    def materialize(
        self,
        tree: qt.QueryTree,
        dimensions: DimensionGroup,
        datasets: frozenset[str],
    ) -> qd.MaterializationKey:
        key = uuid.uuid4()
        self.materializations[key] = (tree, dimensions, datasets)
        return key

    def upload_data_coordinates(
        self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]]
    ) -> qd.DataCoordinateUploadKey:
        key = uuid.uuid4()
        self.data_coordinate_uploads[key] = (dimensions, frozenset(rows))
        return key

    def count(
        self,
        tree: qt.QueryTree,
        result_spec: qrs.ResultSpec,
        *,
        exact: bool,
        discard: bool,
    ) -> int:
        raise _TestQueryCount(result_spec, tree, self, exact, discard)

    def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool:
        raise _TestQueryAny(tree, self, exact, execute)

    def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]:
        raise _TestQueryExplainNoResults(tree, self, execute)

    def get_default_collections(self) -> tuple[str, ...]:
        if self._default_collections is None:
            raise NoDefaultCollectionError()
        return self._default_collections

    def get_dataset_type(self, name: str) -> DatasetType:
        try:
            return self._dataset_types[name]
        except KeyError:
            raise MissingDatasetTypeError(name)


class ColumnExpressionsTestCase(unittest.TestCase):
    """Tests for column expression objects in lsst.daf.butler.queries.tree."""

    def setUp(self) -> None:
        self.universe = DimensionUniverse()
        self.x = ExpressionFactory(self.universe)

433 def query(self, **kwargs: Any) -> Query: 

434 """Make an initial Query object with the given kwargs used to 

435 initialize the _TestQueryDriver. 

436 """ 

437 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe)) 

438 

439 def test_int_literals(self) -> None: 

440 expr = self.x.unwrap(self.x.literal(5)) 

441 self.assertEqual(expr.value, 5) 

442 self.assertEqual(expr.get_literal_value(), 5) 

443 self.assertEqual(expr.expression_type, "int") 

444 self.assertEqual(expr.column_type, "int") 

445 self.assertEqual(str(expr), "5") 

446 self.assertTrue(expr.is_literal) 

447 columns = qt.ColumnSet(self.universe.empty.as_group()) 

448 expr.gather_required_columns(columns) 

449 self.assertFalse(columns) 

450 self.assertEqual(expr.visit(_TestVisitor()), 5) 

451 

452 def test_string_literals(self) -> None: 

453 expr = self.x.unwrap(self.x.literal("five")) 

454 self.assertEqual(expr.value, "five") 

455 self.assertEqual(expr.get_literal_value(), "five") 

456 self.assertEqual(expr.expression_type, "string") 

457 self.assertEqual(expr.column_type, "string") 

458 self.assertEqual(str(expr), "'five'") 

459 self.assertTrue(expr.is_literal) 

460 columns = qt.ColumnSet(self.universe.empty.as_group()) 

461 expr.gather_required_columns(columns) 

462 self.assertFalse(columns) 

463 self.assertEqual(expr.visit(_TestVisitor()), "five") 

464 

465 def test_float_literals(self) -> None: 

466 expr = self.x.unwrap(self.x.literal(0.5)) 

467 self.assertEqual(expr.value, 0.5) 

468 self.assertEqual(expr.get_literal_value(), 0.5) 

469 self.assertEqual(expr.expression_type, "float") 

470 self.assertEqual(expr.column_type, "float") 

471 self.assertEqual(str(expr), "0.5") 

472 self.assertTrue(expr.is_literal) 

473 columns = qt.ColumnSet(self.universe.empty.as_group()) 

474 expr.gather_required_columns(columns) 

475 self.assertFalse(columns) 

476 self.assertEqual(expr.visit(_TestVisitor()), 0.5) 

477 

478 def test_hash_literals(self) -> None: 

479 expr = self.x.unwrap(self.x.literal(b"eleven")) 

480 self.assertEqual(expr.value, b"eleven") 

481 self.assertEqual(expr.get_literal_value(), b"eleven") 

482 self.assertEqual(expr.expression_type, "hash") 

483 self.assertEqual(expr.column_type, "hash") 

484 self.assertEqual(str(expr), "(bytes)") 

485 self.assertTrue(expr.is_literal) 

486 columns = qt.ColumnSet(self.universe.empty.as_group()) 

487 expr.gather_required_columns(columns) 

488 self.assertFalse(columns) 

489 self.assertEqual(expr.visit(_TestVisitor()), b"eleven") 

490 

491 def test_uuid_literals(self) -> None: 

492 value = uuid.uuid4() 

493 expr = self.x.unwrap(self.x.literal(value)) 

494 self.assertEqual(expr.value, value) 

495 self.assertEqual(expr.get_literal_value(), value) 

496 self.assertEqual(expr.expression_type, "uuid") 

497 self.assertEqual(expr.column_type, "uuid") 

498 self.assertEqual(str(expr), str(value)) 

499 self.assertTrue(expr.is_literal) 

500 columns = qt.ColumnSet(self.universe.empty.as_group()) 

501 expr.gather_required_columns(columns) 

502 self.assertFalse(columns) 

503 self.assertEqual(expr.visit(_TestVisitor()), value) 

504 

505 def test_datetime_literals(self) -> None: 

506 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

507 expr = self.x.unwrap(self.x.literal(value)) 

508 self.assertEqual(expr.value, value) 

509 self.assertEqual(expr.get_literal_value(), value) 

510 self.assertEqual(expr.expression_type, "datetime") 

511 self.assertEqual(expr.column_type, "datetime") 

512 self.assertEqual(str(expr), "2020-01-01T00:00:00") 

513 self.assertTrue(expr.is_literal) 

514 columns = qt.ColumnSet(self.universe.empty.as_group()) 

515 expr.gather_required_columns(columns) 

516 self.assertFalse(columns) 

517 self.assertEqual(expr.visit(_TestVisitor()), value) 

518 

519 def test_timespan_literals(self) -> None: 

520 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

521 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai") 

522 value = Timespan(begin, end) 

523 expr = self.x.unwrap(self.x.literal(value)) 

524 self.assertEqual(expr.value, value) 

525 self.assertEqual(expr.get_literal_value(), value) 

526 self.assertEqual(expr.expression_type, "timespan") 

527 self.assertEqual(expr.column_type, "timespan") 

528 self.assertEqual(str(expr), "[2020-01-01T00:00:00, 2020-01-01T00:01:00)") 

529 self.assertTrue(expr.is_literal) 

530 columns = qt.ColumnSet(self.universe.empty.as_group()) 

531 expr.gather_required_columns(columns) 

532 self.assertFalse(columns) 

533 self.assertEqual(expr.visit(_TestVisitor()), value) 

534 

535 def test_region_literals(self) -> None: 

536 pixelization = Mq3cPixelization(10) 

537 value = pixelization.quad(12058870) 

538 expr = self.x.unwrap(self.x.literal(value)) 

539 self.assertEqual(expr.value, value) 

540 self.assertEqual(expr.get_literal_value(), value) 

541 self.assertEqual(expr.expression_type, "region") 

542 self.assertEqual(expr.column_type, "region") 

543 self.assertEqual(str(expr), "(region)") 

544 self.assertTrue(expr.is_literal) 

545 columns = qt.ColumnSet(self.universe.empty.as_group()) 

546 expr.gather_required_columns(columns) 

547 self.assertFalse(columns) 

548 self.assertEqual(expr.visit(_TestVisitor()), value) 

549 

550 def test_invalid_literal(self) -> None: 

551 with self.assertRaisesRegex(TypeError, "Invalid type 'complex' of value 5j for column literal."): 

552 self.x.literal(5j) 

553 

554 def test_dimension_key_reference(self) -> None: 

555 expr = self.x.unwrap(self.x.detector) 

556 self.assertIsNone(expr.get_literal_value()) 

557 self.assertEqual(expr.expression_type, "dimension_key") 

558 self.assertEqual(expr.column_type, "int") 

559 self.assertEqual(str(expr), "detector") 

560 self.assertFalse(expr.is_literal) 

561 columns = qt.ColumnSet(self.universe.empty.as_group()) 

562 expr.gather_required_columns(columns) 

563 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

564 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 3})), 3) 

565 

566 def test_dimension_field_reference(self) -> None: 

567 expr = self.x.unwrap(self.x.detector.purpose) 

568 self.assertIsNone(expr.get_literal_value()) 

569 self.assertEqual(expr.expression_type, "dimension_field") 

570 self.assertEqual(expr.column_type, "string") 

571 self.assertEqual(str(expr), "detector.purpose") 

572 self.assertFalse(expr.is_literal) 

573 columns = qt.ColumnSet(self.universe.empty.as_group()) 

574 expr.gather_required_columns(columns) 

575 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

576 self.assertEqual(columns.dimension_fields["detector"], {"purpose"}) 

577 with self.assertRaises(InvalidQueryError): 

578 qt.DimensionFieldReference(element=self.universe.dimensions["detector"], field="region") 

579 self.assertEqual( 

580 expr.visit(_TestVisitor(dimension_fields={("detector", "purpose"): "science"})), "science" 

581 ) 

582 

583 def test_dataset_field_reference(self) -> None: 

584 expr = self.x.unwrap(self.x["raw"].ingest_date) 

585 self.assertIsNone(expr.get_literal_value()) 

586 self.assertEqual(expr.expression_type, "dataset_field") 

587 self.assertEqual(str(expr), "raw.ingest_date") 

588 self.assertFalse(expr.is_literal) 

589 columns = qt.ColumnSet(self.universe.empty.as_group()) 

590 expr.gather_required_columns(columns) 

591 self.assertEqual(columns.dimensions, self.universe.empty.as_group()) 

592 self.assertEqual(columns.dataset_fields["raw"], {"ingest_date"}) 

593 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="dataset_id").column_type, "uuid") 

594 self.assertEqual( 

595 qt.DatasetFieldReference(dataset_type="raw", field="collection").column_type, "string" 

596 ) 

597 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="run").column_type, "string") 

598 self.assertEqual( 

599 qt.DatasetFieldReference(dataset_type="raw", field="ingest_date").column_type, "datetime" 

600 ) 

601 self.assertEqual( 

602 qt.DatasetFieldReference(dataset_type="raw", field="timespan").column_type, "timespan" 

603 ) 

604 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

605 self.assertEqual(expr.visit(_TestVisitor(dataset_fields={("raw", "ingest_date"): value})), value) 

606 

607 def test_unary_negation(self) -> None: 

608 expr = self.x.unwrap(-self.x.visit.exposure_time) 

609 self.assertIsNone(expr.get_literal_value()) 

610 self.assertEqual(expr.expression_type, "unary") 

611 self.assertEqual(expr.column_type, "float") 

612 self.assertEqual(str(expr), "-visit.exposure_time") 

613 self.assertFalse(expr.is_literal) 

614 columns = qt.ColumnSet(self.universe.empty.as_group()) 

615 expr.gather_required_columns(columns) 

616 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

617 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"}) 

618 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 2.0})), -2.0) 

619 with self.assertRaises(InvalidQueryError): 

620 qt.UnaryExpression( 

621 operand=qt.DimensionFieldReference( 

622 element=self.universe.dimensions["detector"], field="purpose" 

623 ), 

624 operator="-", 

625 ) 

626 

627 def test_unary_timespan_begin(self) -> None: 

628 expr = self.x.unwrap(self.x.visit.timespan.begin) 

629 self.assertIsNone(expr.get_literal_value()) 

630 self.assertEqual(expr.expression_type, "unary") 

631 self.assertEqual(expr.column_type, "datetime") 

632 self.assertEqual(str(expr), "visit.timespan.begin") 

633 self.assertFalse(expr.is_literal) 

634 columns = qt.ColumnSet(self.universe.empty.as_group()) 

635 expr.gather_required_columns(columns) 

636 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

637 self.assertEqual(columns.dimension_fields["visit"], {"timespan"}) 

638 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

639 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai") 

640 value = Timespan(begin, end) 

641 self.assertEqual( 

642 expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.begin 

643 ) 

644 with self.assertRaises(InvalidQueryError): 

645 qt.UnaryExpression( 

646 operand=qt.DimensionFieldReference( 

647 element=self.universe.dimensions["detector"], field="purpose" 

648 ), 

649 operator="begin_of", 

650 ) 

651 

652 def test_unary_timespan_end(self) -> None: 

653 expr = self.x.unwrap(self.x.visit.timespan.end) 

654 self.assertIsNone(expr.get_literal_value()) 

655 self.assertEqual(expr.expression_type, "unary") 

656 self.assertEqual(expr.column_type, "datetime") 

657 self.assertEqual(str(expr), "visit.timespan.end") 

658 self.assertFalse(expr.is_literal) 

659 columns = qt.ColumnSet(self.universe.empty.as_group()) 

660 expr.gather_required_columns(columns) 

661 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

662 self.assertEqual(columns.dimension_fields["visit"], {"timespan"}) 

663 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

664 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai") 

665 value = Timespan(begin, end) 

666 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.end) 

667 with self.assertRaises(InvalidQueryError): 

668 qt.UnaryExpression( 

669 operand=qt.DimensionFieldReference( 

670 element=self.universe.dimensions["detector"], field="purpose" 

671 ), 

672 operator="end_of", 

673 ) 

674 

675 def test_binary_expression_float(self) -> None: 

676 for proxy, string, value in [ 

677 (self.x.visit.exposure_time + 15.0, "visit.exposure_time + 15.0", 45.0), 

678 (self.x.visit.exposure_time - 10.0, "visit.exposure_time - 10.0", 20.0), 

679 (self.x.visit.exposure_time * 6.0, "visit.exposure_time * 6.0", 180.0), 

680 (self.x.visit.exposure_time / 30.0, "visit.exposure_time / 30.0", 1.0), 

681 (15.0 + -self.x.visit.exposure_time, "15.0 + (-visit.exposure_time)", -15.0), 

682 (10.0 - -self.x.visit.exposure_time, "10.0 - (-visit.exposure_time)", 40.0), 

683 (6.0 * -self.x.visit.exposure_time, "6.0 * (-visit.exposure_time)", -180.0), 

684 (30.0 / -self.x.visit.exposure_time, "30.0 / (-visit.exposure_time)", -1.0), 

685 ((self.x.visit.exposure_time + 15.0) * 6.0, "(visit.exposure_time + 15.0) * 6.0", 270.0), 

686 ((self.x.visit.exposure_time + 15.0) + 45.0, "(visit.exposure_time + 15.0) + 45.0", 90.0), 

687 ((self.x.visit.exposure_time + 15.0) / 5.0, "(visit.exposure_time + 15.0) / 5.0", 9.0), 

688 # We don't need the parentheses we generate in the next one, but 

689 # they're not a problem either. 

690 ((self.x.visit.exposure_time + 15.0) - 60.0, "(visit.exposure_time + 15.0) - 60.0", -15.0), 

691 (6.0 * (-self.x.visit.exposure_time - 15.0), "6.0 * ((-visit.exposure_time) - 15.0)", -270.0), 

692 (60.0 + (-self.x.visit.exposure_time - 15.0), "60.0 + ((-visit.exposure_time) - 15.0)", 15.0), 

693 (90.0 / (-self.x.visit.exposure_time - 15.0), "90.0 / ((-visit.exposure_time) - 15.0)", -2.0), 

694 (60.0 - (-self.x.visit.exposure_time - 15.0), "60.0 - ((-visit.exposure_time) - 15.0)", 105.0), 

695 ]: 

696 with self.subTest(string=string): 

697 expr = self.x.unwrap(proxy) 

698 self.assertIsNone(expr.get_literal_value()) 

699 self.assertEqual(expr.expression_type, "binary") 

700 self.assertEqual(expr.column_type, "float") 

701 self.assertEqual(str(expr), string) 

702 self.assertFalse(expr.is_literal) 

703 columns = qt.ColumnSet(self.universe.empty.as_group()) 

704 expr.gather_required_columns(columns) 

705 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

706 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"}) 

707 self.assertEqual( 

708 expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 30.0})), value 

709 ) 

710 

711 def test_binary_modulus(self) -> None: 

712 for proxy, string, value in [ 

713 (self.x.visit.id % 2, "visit % 2", 1), 

714 (52 % self.x.visit, "52 % visit", 2), 

715 ]: 

716 with self.subTest(string=string): 

717 expr = self.x.unwrap(proxy) 

718 self.assertIsNone(expr.get_literal_value()) 

719 self.assertEqual(expr.expression_type, "binary") 

720 self.assertEqual(expr.column_type, "int") 

721 self.assertEqual(str(expr), string) 

722 self.assertFalse(expr.is_literal) 

723 columns = qt.ColumnSet(self.universe.empty.as_group()) 

724 expr.gather_required_columns(columns) 

725 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

726 self.assertFalse(columns.dimension_fields["visit"]) 

727 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"visit": 5})), value) 

728 

729 def test_binary_expression_validation(self) -> None: 

730 with self.assertRaises(InvalidQueryError): 

731 # No arithmetic operators on strings (we do not interpret + as 

732 # concatenation). 

733 self.x.instrument + "suffix" 

734 with self.assertRaises(InvalidQueryError): 

735 # Mixed types are not supported, even when they both support the 

736 # operator. 

737 self.x.visit.exposure_time + self.x.detector 

738 with self.assertRaises(InvalidQueryError): 

739 # No modulus for floats. 

740 self.x.visit.exposure_time % 5.0 

741 

742 def test_reversed(self) -> None: 

743 expr = self.x.detector.desc 

744 self.assertIsNone(expr.get_literal_value()) 

745 self.assertEqual(expr.expression_type, "reversed") 

746 self.assertEqual(expr.column_type, "int") 

747 self.assertEqual(str(expr), "detector DESC") 

748 self.assertFalse(expr.is_literal) 

749 columns = qt.ColumnSet(self.universe.empty.as_group()) 

750 expr.gather_required_columns(columns) 

751 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

752 self.assertFalse(columns.dimension_fields["detector"]) 

753 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 5})), _TestReversed(5)) 

754 

755 def test_trivial_predicate(self) -> None: 

756 """Test logical operations on trivial True/False predicates.""" 
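        # Note (from the assertions below): combining with a trivially true or
        # false predicate short-circuits, e.g. ``yes.logical_or(maybe)`` is
        # just "True" and ``no.logical_and(maybe)`` is just "False", while
        # ``yes.logical_and(maybe)`` reduces to ``maybe`` itself.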

        yes = qt.Predicate.from_bool(True)
        no = qt.Predicate.from_bool(False)
        maybe: qt.Predicate = self.x.detector == 5
        for predicate in [
            yes,
            yes.logical_or(no),
            no.logical_or(yes),
            yes.logical_and(yes),
            no.logical_not(),
            yes.logical_or(maybe),
            maybe.logical_or(yes),
        ]:
            self.assertEqual(predicate.column_type, "bool")
            self.assertEqual(str(predicate), "True")
            self.assertTrue(predicate.visit(_TestVisitor()))
            self.assertEqual(predicate.operands, ())
        for predicate in [
            no,
            yes.logical_and(no),
            no.logical_and(yes),
            no.logical_or(no),
            yes.logical_not(),
            no.logical_and(maybe),
            maybe.logical_and(no),
        ]:
            self.assertEqual(predicate.column_type, "bool")
            self.assertEqual(str(predicate), "False")
            self.assertFalse(predicate.visit(_TestVisitor()))
            self.assertEqual(predicate.operands, ((),))
        for predicate in [
            maybe,
            yes.logical_and(maybe),
            no.logical_or(maybe),
            maybe.logical_not().logical_not(),
        ]:
            self.assertEqual(predicate.column_type, "bool")
            self.assertEqual(str(predicate), "detector == 5")
            self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"detector": 5})))
            self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"detector": 4})))
            self.assertEqual(len(predicate.operands), 1)
            self.assertEqual(len(predicate.operands[0]), 1)
            self.assertIs(predicate.operands[0][0], maybe.operands[0][0])

    def test_comparison(self) -> None:
        predicate: qt.Predicate
        string: str
        value: bool
        for detector in (4, 5, 6):
            for predicate, string, value in [
                (self.x.detector == 5, "detector == 5", detector == 5),
                (self.x.detector != 5, "detector != 5", detector != 5),
                (self.x.detector < 5, "detector < 5", detector < 5),
                (self.x.detector > 5, "detector > 5", detector > 5),
                (self.x.detector <= 5, "detector <= 5", detector <= 5),
                (self.x.detector >= 5, "detector >= 5", detector >= 5),
                (self.x.detector == 5, "detector == 5", detector == 5),
                (self.x.detector != 5, "detector != 5", detector != 5),
                (self.x.detector < 5, "detector < 5", detector < 5),
                (self.x.detector > 5, "detector > 5", detector > 5),
                (self.x.detector <= 5, "detector <= 5", detector <= 5),
                (self.x.detector >= 5, "detector >= 5", detector >= 5),
            ]:
                with self.subTest(string=string, detector=detector):
                    self.assertEqual(predicate.column_type, "bool")
                    self.assertEqual(str(predicate), string)
                    columns = qt.ColumnSet(self.universe.empty.as_group())
                    predicate.gather_required_columns(columns)
                    self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
                    self.assertFalse(columns.dimension_fields["detector"])
                    self.assertEqual(
                        predicate.visit(_TestVisitor(dimension_keys={"detector": detector})), value
                    )
                    inverted = predicate.logical_not()
                    self.assertEqual(inverted.column_type, "bool")
                    self.assertEqual(str(inverted), f"NOT {string}")
                    self.assertEqual(
                        inverted.visit(_TestVisitor(dimension_keys={"detector": detector})), not value
                    )
                    columns = qt.ColumnSet(self.universe.empty.as_group())
                    inverted.gather_required_columns(columns)
                    self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
                    self.assertFalse(columns.dimension_fields["detector"])

    def test_overlap_comparison(self) -> None:
        pixelization = Mq3cPixelization(10)
        region1 = pixelization.quad(12058870)
        predicate = self.x.visit.region.overlaps(region1)
        self.assertEqual(predicate.column_type, "bool")
        self.assertEqual(str(predicate), "visit.region OVERLAPS (region)")
        columns = qt.ColumnSet(self.universe.empty.as_group())
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertEqual(columns.dimension_fields["visit"], {"region"})
        region2 = pixelization.quad(12058857)
        self.assertFalse(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): region2})))
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT visit.region OVERLAPS (region)")
        self.assertTrue(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): region2})))
        columns = qt.ColumnSet(self.universe.empty.as_group())
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertEqual(columns.dimension_fields["visit"], {"region"})

    def test_invalid_comparison(self) -> None:
        # Mixed type comparisons.
        with self.assertRaises(InvalidQueryError):
            self.x.visit > "three"
        with self.assertRaises(InvalidQueryError):
            self.x.visit > 3.0
        # Invalid operator for type.
        with self.assertRaises(InvalidQueryError):
            self.x["raw"].dataset_id < uuid.uuid4()

    def test_is_null(self) -> None:
        predicate = self.x.visit.region.is_null
        self.assertEqual(predicate.column_type, "bool")
        self.assertEqual(str(predicate), "visit.region IS NULL")
        columns = qt.ColumnSet(self.universe.empty.as_group())
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertEqual(columns.dimension_fields["visit"], {"region"})
        self.assertTrue(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): None})))
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT visit.region IS NULL")
        self.assertFalse(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): None})))
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertEqual(columns.dimension_fields["visit"], {"region"})

    def test_in_container(self) -> None:
        predicate: qt.Predicate = self.x.visit.in_iterable([3, 4, self.x.exposure.id])
        self.assertEqual(predicate.column_type, "bool")
        self.assertEqual(str(predicate), "visit IN [3, 4, exposure]")
        columns = qt.ColumnSet(self.universe.empty.as_group())
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
        self.assertFalse(columns.dimension_fields["visit"])
        self.assertFalse(columns.dimension_fields["exposure"])
        self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
        self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT visit IN [3, 4, exposure]")
        self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2})))
        self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5})))
        columns = qt.ColumnSet(self.universe.empty.as_group())
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"]))
        self.assertFalse(columns.dimension_fields["visit"])
        self.assertFalse(columns.dimension_fields["exposure"])
        with self.assertRaises(InvalidQueryError):
            # Regions (and timespans) not allowed in IN expressions, since that
            # suggests topological logic we're not actually doing. We can't
            # use ExpressionFactory because it prohibits this case with typing.
            pixelization = Mq3cPixelization(10)
            region = pixelization.quad(12058870)
            qt.Predicate.in_container(self.x.unwrap(self.x.visit.region), [qt.make_column_literal(region)])
        with self.assertRaises(InvalidQueryError):
            # Mismatched types.
            self.x.visit.in_iterable([3.5, 2.1])

    def test_in_range(self) -> None:
        predicate: qt.Predicate = self.x.visit.in_range(2, 8, 2)
        self.assertEqual(predicate.column_type, "bool")
        self.assertEqual(str(predicate), "visit IN 2:8:2")
        columns = qt.ColumnSet(self.universe.empty.as_group())
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertFalse(columns.dimension_fields["visit"])
        self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2})))
        self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 8})))
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT visit IN 2:8:2")
        self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2})))
        self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 8})))
        columns = qt.ColumnSet(self.universe.empty.as_group())
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["visit"]))
        self.assertFalse(columns.dimension_fields["visit"])
        with self.assertRaises(InvalidQueryError):
            # Only integer fields allowed.
            self.x.visit.exposure_time.in_range(2, 4)
        with self.assertRaises(InvalidQueryError):
            # Step must be positive.
            self.x.visit.in_range(2, 4, -1)
        with self.assertRaises(InvalidQueryError):
            # Stop must be >= start.
            self.x.visit.in_range(2, 0)

    def test_in_query(self) -> None:
        query = self.query().join_dimensions(["visit", "tract"]).where(skymap="s", tract=3)
        predicate: qt.Predicate = self.x.exposure.in_query(self.x.visit, query)
        self.assertEqual(predicate.column_type, "bool")
        self.assertEqual(str(predicate), "exposure IN (query).visit")
        columns = qt.ColumnSet(self.universe.empty.as_group())
        predicate.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
        self.assertFalse(columns.dimension_fields["exposure"])
        self.assertTrue(
            predicate.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
        )
        self.assertFalse(
            predicate.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
        )
        inverted = predicate.logical_not()
        self.assertEqual(inverted.column_type, "bool")
        self.assertEqual(str(inverted), "NOT exposure IN (query).visit")
        self.assertFalse(
            inverted.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3}))
        )
        self.assertTrue(
            inverted.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3}))
        )
        columns = qt.ColumnSet(self.universe.empty.as_group())
        inverted.gather_required_columns(columns)
        self.assertEqual(columns.dimensions, self.universe.conform(["exposure"]))
        self.assertFalse(columns.dimension_fields["exposure"])
        with self.assertRaises(InvalidQueryError):
            # Regions (and timespans) not allowed in IN expressions, since that
            # suggests topological logic we're not actually doing. We can't
            # use ExpressionFactory because it prohibits this case with typing.
            qt.Predicate.in_query(
                self.x.unwrap(self.x.visit.region), self.x.unwrap(self.x.tract.region), query._tree
            )
        with self.assertRaises(InvalidQueryError):
            # Mismatched types.
            self.x.exposure.in_query(self.x.visit.exposure_time, query)
        with self.assertRaises(InvalidQueryError):
            # Query column requires dimensions that are not in the query.
            self.x.exposure.in_query(self.x.patch, query)
        with self.assertRaises(InvalidQueryError):
            # Query column requires dataset type that is not in the query.
            self.x["raw"].dataset_id.in_query(self.x["raw"].dataset_id, query)

    def test_complex_predicate(self) -> None:
        """Test that predicates are converted to conjunctive normal form and
        get parentheses in the right places when stringified.
        """
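        # For example (per the table below), OR distributes over AND when
        # converting to conjunctive normal form:
        # ``a.logical_and(b).logical_or(c)`` stringifies as
        # "({a} OR {c}) AND ({b} OR {c})".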

        visitor = _TestVisitor(dimension_keys={"instrument": "i", "detector": 3, "visit": 6, "band": "r"})
        a: qt.Predicate = self.x.visit > 5  # will evaluate to True
        b: qt.Predicate = self.x.detector != 3  # will evaluate to False
        c: qt.Predicate = self.x.instrument == "i"  # will evaluate to True
        d: qt.Predicate = self.x.band == "g"  # will evaluate to False
        predicate: qt.Predicate
        for predicate, string, value in [
            (a.logical_or(b), f"{a} OR {b}", True),
            (a.logical_or(c), f"{a} OR {c}", True),
            (b.logical_or(d), f"{b} OR {d}", False),
            (a.logical_and(b), f"{a} AND {b}", False),
            (a.logical_and(c), f"{a} AND {c}", True),
            (b.logical_and(d), f"{b} AND {d}", False),
            (self.x.any(a, b, c, d), f"{a} OR {b} OR {c} OR {d}", True),
            (self.x.all(a, b, c, d), f"{a} AND {b} AND {c} AND {d}", False),
            (a.logical_or(b).logical_and(c), f"({a} OR {b}) AND {c}", True),
            (a.logical_and(b.logical_or(d)), f"{a} AND ({b} OR {d})", False),
            (a.logical_and(b).logical_or(c), f"({a} OR {c}) AND ({b} OR {c})", True),
            (
                a.logical_and(b).logical_or(c.logical_and(d)),
                f"({a} OR {c}) AND ({a} OR {d}) AND ({b} OR {c}) AND ({b} OR {d})",
                False,
            ),
            (a.logical_or(b).logical_not(), f"NOT {a} AND NOT {b}", False),
            (a.logical_or(c).logical_not(), f"NOT {a} AND NOT {c}", False),
            (b.logical_or(d).logical_not(), f"NOT {b} AND NOT {d}", True),
            (a.logical_and(b).logical_not(), f"NOT {a} OR NOT {b}", True),
            (a.logical_and(c).logical_not(), f"NOT {a} OR NOT {c}", False),
            (b.logical_and(d).logical_not(), f"NOT {b} OR NOT {d}", True),
            (
                self.x.not_(a.logical_or(b).logical_and(c)),
                f"(NOT {a} OR NOT {c}) AND (NOT {b} OR NOT {c})",
                False,
            ),
            (
                a.logical_and(b.logical_or(d)).logical_not(),
                f"(NOT {a} OR NOT {b}) AND (NOT {a} OR NOT {d})",
                True,
            ),
        ]:
            with self.subTest(string=string):
                self.assertEqual(str(predicate), string)
                self.assertEqual(predicate.visit(visitor), value)

    def test_proxy_misc(self) -> None:
        """Test miscellaneous things on various ExpressionFactory proxies."""
        self.assertEqual(str(self.x.visit_detector_region), "visit_detector_region")
        self.assertEqual(str(self.x.visit.instrument), "instrument")
        self.assertEqual(str(self.x["raw"]), "raw")
        self.assertEqual(str(self.x["raw.ingest_date"]), "raw.ingest_date")
        self.assertEqual(
            str(self.x.visit.timespan.overlaps(self.x["raw"].timespan)),
            "visit.timespan OVERLAPS raw.timespan",
        )
        self.assertGreater(
            set(dir(self.x["raw"])), {"dataset_id", "ingest_date", "collection", "run", "timespan"}
        )
        self.assertGreater(set(dir(self.x.exposure)), {"seq_num", "science_program", "timespan"})
        with self.assertRaises(AttributeError):
            self.x["raw"].seq_num
        with self.assertRaises(AttributeError):
            self.x.visit.horse


class QueryTestCase(unittest.TestCase):
    """Tests for Query and *QueryResults objects in lsst.daf.butler.queries."""

    def setUp(self) -> None:
        self.maxDiff = None
        self.universe = DimensionUniverse()
        # We use ArrowTable as the storage class for all dataset types because
        # it's got conversions that only require third-party packages we
        # already require.
        self.raw = DatasetType(
            "raw", dimensions=self.universe.conform(["detector", "exposure"]), storageClass="ArrowTable"
        )
        self.refcat = DatasetType(
            "refcat", dimensions=self.universe.conform(["htm7"]), storageClass="ArrowTable"
        )
        self.bias = DatasetType(
            "bias",
            dimensions=self.universe.conform(["detector"]),
            storageClass="ArrowTable",
            isCalibration=True,
        )
        self.default_collections: list[str] | None = ["DummyCam/defaults"]
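        # The simulated repository below: "DummyCam/defaults" is a CHAINED
        # collection whose children are a RUN collection with raw datasets, a
        # CALIBRATION collection with bias datasets, and a RUN collection with
        # refcat datasets.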

1084 self.collection_info: dict[str, tuple[CollectionRecord, CollectionSummary]] = { 

1085 "DummyCam/raw/all": ( 

1086 RunRecord[int](1, name="DummyCam/raw/all"), 

1087 CollectionSummary(NamedValueSet({self.raw}), governors={"instrument": {"DummyCam"}}), 

1088 ), 

1089 "DummyCam/calib": ( 

1090 CollectionRecord[int](2, name="DummyCam/calib", type=CollectionType.CALIBRATION), 

1091 CollectionSummary(NamedValueSet({self.bias}), governors={"instrument": {"DummyCam"}}), 

1092 ), 

1093 "refcats": ( 

1094 RunRecord[int](3, name="refcats"), 

1095 CollectionSummary(NamedValueSet({self.refcat}), governors={}), 

1096 ), 

1097 "DummyCam/defaults": ( 

1098 ChainedCollectionRecord[int]( 

1099 4, name="DummyCam/defaults", children=("DummyCam/raw/all", "DummyCam/calib", "refcats") 

1100 ), 

1101 CollectionSummary( 

1102 NamedValueSet({self.raw, self.refcat, self.bias}), governors={"instrument": {"DummyCam"}} 

1103 ), 

1104 ), 

1105 } 

1106 self.dataset_types = {"raw": self.raw, "refcat": self.refcat, "bias": self.bias} 

1107 

1108 def query(self, **kwargs: Any) -> Query: 

1109 """Make an initial Query object with the given kwargs used to 

1110 initialize the _TestQueryDriver. 

1111 

1112 The given kwargs override the test-case-attribute defaults. 

1113 """ 

1114 kwargs.setdefault("default_collections", self.default_collections) 

1115 kwargs.setdefault("collection_info", self.collection_info) 

1116 kwargs.setdefault("dataset_types", self.dataset_types) 

1117 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe)) 

1118 

1119 def test_dataset_join(self) -> None: 

1120 """Test queries that have had a dataset search explicitly joined in via 

1121 Query.join_dataset_search. 

1122 

1123 Since this kind of query has a moderate amount of complexity, this is 

1124 where we get a lot of basic coverage that applies to all kinds of 

1125 queries, including: 

1126 

1127 - getting data ID and dataset results (but not iterating over them); 

1128 - the 'any' and 'explain_no_results' methods; 

1129 - adding 'where' filters (but not expanding dimensions accordingly); 

1130 - materializations. 

1131 """ 

1132 

1133 def check( 

1134 query: Query, 

1135 dimensions: DimensionGroup = self.raw.dimensions.as_group(), 

1136 ) -> None: 

1137 """Run a battery of tests on one of a set of very similar queries 

1138 constructed in different ways (see below). 

1139 """ 

1140 

1141 def check_query_tree( 

1142 tree: qt.QueryTree, 

1143 dimensions: DimensionGroup = dimensions, 

1144 ) -> None: 

1145 """Check the state of the QueryTree object that backs the Query 

1146 or a derived QueryResults object. 

1147 

1148 Parameters 

1149 ---------- 

1150 tree : `lsst.daf.butler.queries.tree.QueryTree` 

1151 Object to test. 

1152 dimensions : `DimensionGroup` 

1153 Dimensions to expect in the `QueryTree`, not necessarily 

1154 including those in the test 'raw' dataset type. 

1155 """ 

1156 self.assertEqual(tree.dimensions, dimensions | self.raw.dimensions.as_group()) 

1157 self.assertEqual(str(tree.predicate), "raw.run == 'DummyCam/raw/all'") 

1158 self.assertFalse(tree.materializations) 

1159 self.assertFalse(tree.data_coordinate_uploads) 

1160 self.assertEqual(tree.datasets.keys(), {"raw"}) 

1161 self.assertEqual(tree.datasets["raw"].dimensions, self.raw.dimensions.as_group()) 

1162 self.assertEqual(tree.datasets["raw"].collections, ("DummyCam/defaults",)) 

1163 self.assertEqual( 

1164 tree.get_joined_dimension_groups(), frozenset({self.raw.dimensions.as_group()}) 

1165 ) 

1166 

1167 def check_data_id_results(*args, query: Query, dimensions: DimensionGroup = dimensions) -> None: 

1168 """Construct a DataCoordinateQueryResults object from the query 

1169 with the given arguments and run a battery of tests on it. 

1170 

1171 Parameters 

1172 ---------- 

1173 *args 

1174 Forwarded to `Query.data_ids`. 

1175 query : `Query` 

1176 Query to start from. 

1177 dimensions : `DimensionGroup`, optional 

1178 Dimensions the result data IDs should have. 

1179 """ 

1180 with self.assertRaises(_TestQueryExecution) as cm: 

1181 list(query.data_ids(*args)) 

1182 self.assertEqual( 

1183 cm.exception.result_spec, 

1184 qrs.DataCoordinateResultSpec(dimensions=dimensions), 

1185 ) 

1186 check_query_tree(cm.exception.tree, dimensions=dimensions) 

1187 

1188 def check_dataset_results( 

1189 *args: Any, 

1190 query: Query, 

1191 find_first: bool = True, 

1192 storage_class_name: str = self.raw.storageClass_name, 

1193 ) -> None: 

1194 """Construct a DatasetRefQueryResults object from the query 

1195 with the given arguments and run a battery of tests on it. 

1196 

1197 Parameters 

1198 ---------- 

1199 *args 

1200 Forwarded to `Query.datasets`. 

1201 query : `Query` 

1202 Query to start from. 

1203 find_first : `bool`, optional 

1204 Whether to do find-first resolution on the results. 

1205 storage_class_name : `str`, optional 

1206 Expected name of the storage class for the results. 

1207 """ 

1208 with self.assertRaises(_TestQueryExecution) as cm: 

1209 list(query.datasets(*args, find_first=find_first)) 

1210 self.assertEqual( 

1211 cm.exception.result_spec, 

1212 qrs.DatasetRefResultSpec( 

1213 dataset_type_name="raw", 

1214 dimensions=self.raw.dimensions.as_group(), 

1215 storage_class_name=storage_class_name, 

1216 find_first=find_first, 

1217 ), 

1218 ) 

1219 check_query_tree(cm.exception.tree) 

1220 

1221 def check_materialization( 

1222 kwargs: Mapping[str, Any], 

1223 query: Query, 

1224 dimensions: DimensionGroup = dimensions, 

1225 has_dataset: bool = True, 

1226 ) -> None: 

1227 """Materialize the query with the given arguments and run a 

1228 battery of tests on the result. 

1229 

1230 Parameters 

1231 ---------- 

1232 kwargs 

1233 Forwarded as keyword arguments to `Query.materialize`. 

1234 query : `Query` 

1235 Query to start from. 

1236 dimensions : `DimensionGroup`, optional 

1237 Dimensions to expect in the materialization and its derived 

1238 query. 

1239 has_dataset : `bool`, optional 

1240 Whether the query backed by the materialization should 

1241 still have the test 'raw' dataset joined in. 

1242 """ 

1243 # Materialize the query and check the query tree sent to the 

1244 # driver and the one in the materialized query. 

1245 with self.assertRaises(_TestQueryExecution) as cm: 

1246 list(query.materialize(**kwargs).data_ids()) 

1247 derived_tree = cm.exception.tree 

1248 self.assertEqual(derived_tree.dimensions, dimensions) 

1249 # Predicate should be materialized away; it no longer appears 

1250 # in the derived query. 

1251 self.assertEqual(str(derived_tree.predicate), "True") 

1252 self.assertFalse(derived_tree.data_coordinate_uploads) 

1253 if has_dataset: 

1254 # Dataset search is still there, even though its existence 

1255 # constraint is included in the materialization, because we 

1256 # might need to re-join for some result columns in a 

1257 # derived query. 

1258 self.assertTrue(derived_tree.datasets.keys(), {"raw"}) 

1259 self.assertEqual(derived_tree.datasets["raw"].dimensions, self.raw.dimensions.as_group()) 

1260 self.assertEqual(derived_tree.datasets["raw"].collections, ("DummyCam/defaults",)) 

1261 else: 

1262 self.assertFalse(derived_tree.datasets) 

1263 ((key, derived_tree_materialized_dimensions),) = derived_tree.materializations.items() 

1264 self.assertEqual(derived_tree_materialized_dimensions, dimensions) 

1265 ( 

1266 materialized_tree, 

1267 materialized_dimensions, 

1268 materialized_datasets, 

1269 ) = cm.exception.driver.materializations[key] 

1270 self.assertEqual(derived_tree_materialized_dimensions, materialized_dimensions) 

1271 if has_dataset: 

1272 self.assertEqual(materialized_datasets, {"raw"}) 

1273 else: 

1274 self.assertFalse(materialized_datasets) 

1275 check_query_tree(materialized_tree) 

1276 

1277 # Actual logic for the check() function begins here. 

1278 

1279 self.assertEqual(query.constraint_dataset_types, {"raw"}) 

1280 self.assertEqual(query.constraint_dimensions, self.raw.dimensions.as_group()) 

1281 

1282 # Adding a constraint on a field for this dataset type should work 

1283 # (this constraint will be present in all downstream tests). 

1284 query = query.where(query.expression_factory["raw"].run == "DummyCam/raw/all") 

1285 with self.assertRaises(InvalidQueryError): 

1286 # Adding constraint on a different dataset should not work. 

1287 query.where(query.expression_factory["refcat"].run == "refcats") 

1288 

1289 # Data IDs, with dimensions defaulted. 

1290 check_data_id_results(query=query) 

1291 # Dimensions for data IDs the same as defaults. 

1292 check_data_id_results(["exposure", "detector"], query=query) 

1293 # Dimensions are a subset of the query dimensions. 

1294 check_data_id_results(["exposure"], query=query, dimensions=self.universe.conform(["exposure"])) 

1295 # Dimensions are a superset of the query dimensions. 

1296 check_data_id_results( 

1297 ["exposure", "detector", "visit"], 

1298 query=query, 

1299 dimensions=self.universe.conform(["exposure", "detector", "visit"]), 

1300 ) 

1301 # Dimensions are neither a superset nor a subset of the query 

1302 # dimensions. 

1303 check_data_id_results( 

1304 ["detector", "visit"], query=query, dimensions=self.universe.conform(["visit", "detector"]) 

1305 ) 

1306 # Dimensions are empty. 

1307 check_data_id_results([], query=query, dimensions=self.universe.conform([])) 

1308 

1309 # Get DatasetRef results, with various arguments and defaulting. 

1310 check_dataset_results("raw", query=query) 

1311 check_dataset_results("raw", query=query, find_first=True) 

1312 check_dataset_results("raw", ["DummyCam/defaults"], query=query) 

1313 check_dataset_results("raw", ["DummyCam/defaults"], query=query, find_first=True) 

1314 check_dataset_results(self.raw, query=query) 

1315 check_dataset_results(self.raw, query=query, find_first=True) 

1316 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query) 

1317 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query, find_first=True) 

1318 

1319 # Changing collections at this stage is not allowed. 

1320 with self.assertRaises(InvalidQueryError): 

1321 query.datasets("raw", collections=["DummyCam/calib"]) 

1322 

1323 # Changing storage classes is allowed, if they're compatible. 

1324 check_dataset_results( 

1325 self.raw.overrideStorageClass("ArrowNumpy"), query=query, storage_class_name="ArrowNumpy" 

1326 ) 

1327 with self.assertRaises(DatasetTypeError): 

1328 # Can't use overrideStorageClass, because it'll raise 

1329 # before the code we want to test can. 

1330 query.datasets(DatasetType("raw", self.raw.dimensions, "int")) 

1331 

1332 # Check the 'any' and 'explain_no_results' methods on Query itself. 

1333 for execute, exact in itertools.permutations([False, True], 2): 

1334 with self.assertRaises(_TestQueryAny) as cm: 

1335 query.any(execute=execute, exact=exact) 

1336 self.assertEqual(cm.exception.execute, execute) 

1337 self.assertEqual(cm.exception.exact, exact) 

1338 check_query_tree(cm.exception.tree, dimensions) 

1339 with self.assertRaises(_TestQueryExplainNoResults) as cm: 

1340 query.explain_no_results() 

1341 check_query_tree(cm.exception.tree, dimensions) 

1342 

1343 # Materialize the query with defaults. 

1344 check_materialization({}, query=query) 

1345 # Materialize the query with args that match defaults. 

1346 check_materialization({"dimensions": ["exposure", "detector"], "datasets": {"raw"}}, query=query) 

1347 # Materialize the query with a superset of the original dimensions. 

1348 check_materialization( 

1349 {"dimensions": ["exposure", "detector", "visit"]}, 

1350 query=query, 

1351 dimensions=self.universe.conform(["exposure", "visit", "detector"]), 

1352 ) 

1353 # Materialize the query with no datasets. 

1354 check_materialization( 

1355 {"dimensions": ["exposure", "detector"], "datasets": frozenset()}, 

1356 query=query, 

1357 has_dataset=False, 

1358 ) 

1359 # Materialize the query with no datasets and a subset of the 

1360 # dimensions. 

1361 check_materialization( 

1362 {"dimensions": ["exposure"], "datasets": frozenset()}, 

1363 query=query, 

1364 has_dataset=False, 

1365 dimensions=self.universe.conform(["exposure"]), 

1366 ) 

1367 # Materializing the query with a dataset that is not in the query 

1368 # is an error. 

1369 with self.assertRaises(InvalidQueryError): 

1370 query.materialize(datasets={"refcat"}) 

1371 # Materializing the query with dimensions that are not a superset 

1372 # of any materialized dataset dimensions is an error. 

1373 with self.assertRaises(InvalidQueryError): 

1374 query.materialize(dimensions=["exposure"], datasets={"raw"}) 

1375 

1376 # Actual logic for test_dataset_joins starts here. 

1377 

1378 # Default collections and existing dataset type name. 

1379 check(self.query().join_dataset_search("raw")) 

1380 # Default collections and existing DatasetType instance. 

1381 check(self.query().join_dataset_search(self.raw)) 

1382 # Manual collections and existing dataset type. 

1383 check( 

1384 self.query(default_collections=None).join_dataset_search("raw", collections=["DummyCam/defaults"]) 

1385 ) 

1386 check( 

1387 self.query(default_collections=None).join_dataset_search( 

1388 self.raw, collections=["DummyCam/defaults"] 

1389 ) 

1390 ) 

1391 with self.assertRaises(MissingDatasetTypeError): 

1392 # Dataset type does not exist. 

1393 self.query(dataset_types={}).join_dataset_search("raw", collections=["DummyCam/raw/all"]) 

1394 with self.assertRaises(DatasetTypeError): 

1395 # Dataset type object with bad dimensions passed. 

1396 self.query().join_dataset_search( 

1397 DatasetType( 

1398 "raw", 

1399 dimensions={"detector", "visit"}, 

1400 storageClass=self.raw.storageClass_name, 

1401 universe=self.universe, 

1402 ) 

1403 ) 

1404 with self.assertRaises(TypeError): 

1405 # Bad type for dataset type argument. 

1406 self.query().join_dataset_search(3) 

1407 with self.assertRaises(InvalidQueryError): 

1408 # Cannot pass storage class override to join_dataset_search, 

1409 # because we cannot use it there. 

1410 self.query().join_dataset_search(self.raw.overrideStorageClass("ArrowAstropy")) 

1411 

1412 def test_dimension_record_results(self) -> None: 

1413 """Test queries that return dimension records. 

1414 

1415 This includes tests for: 

1416 

1417 - joining against uploaded data coordinates; 

1418 - counting result rows; 

1419 - expanding dimensions as needed for 'where' conditions; 

1420 - order_by and limit. 

1421 

1422 It does not include the iteration methods of 

1423 DimensionRecordQueryResults, since those require a different mock 

1424 driver setup (see test_dimension_record_iteration). 

1425 """ 

1426 # Set up the base query-results object to test. 

1427 query = self.query() 

1428 x = query.expression_factory 

1429 self.assertFalse(query.constraint_dimensions) 

1430 query = query.where(x.skymap == "m") 

1431 self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap"])) 
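
# Join two explicit visit data IDs; join_data_coordinates turns them into
# a data-coordinate upload in the query tree, and the driver records their
# required values (both are checked in the closure below).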

1432 upload_rows = [ 

1433 DataCoordinate.standardize(instrument="DummyCam", visit=3, universe=self.universe), 

1434 DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe), 

1435 ] 

1436 raw_rows = frozenset([data_id.required_values for data_id in upload_rows]) 

1437 query = query.join_data_coordinates(upload_rows) 

1438 self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap", "visit"])) 

1439 results = query.dimension_records("patch") 

1440 results = results.where(x.tract == 4) 

1441 

1442 # Define a closure to run tests on variants of the base query. 

1443 def check( 

1444 results: DimensionRecordQueryResults, 

1445 order_by: Any = (), 

1446 limit: int | None = None, 

1447 ) -> list[str]: 

1448 results = results.order_by(*order_by).limit(limit) 

1449 self.assertEqual(results.element.name, "patch") 

1450 with self.assertRaises(_TestQueryExecution) as cm: 

1451 list(results) 

1452 tree = cm.exception.tree 

1453 self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4") 

1454 self.assertEqual(tree.dimensions, self.universe.conform(["visit", "patch"])) 

1455 self.assertFalse(tree.materializations) 

1456 self.assertFalse(tree.datasets) 

1457 ((key, upload_dimensions),) = tree.data_coordinate_uploads.items() 

1458 self.assertEqual(upload_dimensions, self.universe.conform(["visit"])) 

1459 self.assertEqual(cm.exception.driver.data_coordinate_uploads[key], (upload_dimensions, raw_rows)) 

1460 result_spec = cm.exception.result_spec 

1461 self.assertEqual(result_spec.result_type, "dimension_record") 

1462 self.assertEqual(result_spec.element, self.universe["patch"]) 

1463 self.assertEqual(result_spec.limit, limit) 

1464 for exact, discard in itertools.permutations([False, True], r=2): 

1465 with self.assertRaises(_TestQueryCount) as cm: 

1466 results.count(exact=exact, discard=discard) 

1467 self.assertEqual(cm.exception.result_spec, result_spec) 

1468 self.assertEqual(cm.exception.exact, exact) 

1469 self.assertEqual(cm.exception.discard, discard) 

1470 return [str(term) for term in result_spec.order_by] 

1471 

1472 # Run the closure's tests on variants of the base query. 

1473 self.assertEqual(check(results), []) 

1474 self.assertEqual(check(results, limit=2), []) 

1475 self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"]) 

1476 self.assertEqual( 

1477 check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]), 

1478 ["patch.cell_x", "patch.cell_y DESC"], 

1479 ) 

1480 with self.assertRaises(InvalidQueryError): 

1481 # Cannot upload empty list of data IDs. 

1482 query.join_data_coordinates([]) 

1483 with self.assertRaises(InvalidQueryError): 

1484 # Cannot upload heterogeneous list of data IDs. 

1485 query.join_data_coordinates( 

1486 [ 

1487 DataCoordinate.make_empty(self.universe), 

1488 DataCoordinate.standardize(instrument="DummyCam", universe=self.universe), 

1489 ] 

1490 ) 

1491 

1492 def test_dimension_record_iteration(self) -> None: 

1493 """Tests for DimensionRecordQueryResult iteration.""" 

1494 

1495 def make_record(n: int) -> DimensionRecord: 

1496 return self.universe["patch"].RecordClass(skymap="m", tract=4, patch=n) 

1497 

1498 result_rows = ( 

1499 [make_record(n) for n in range(3)], 

1500 [make_record(n) for n in range(3, 6)], 

1501 [make_record(10)], 

1502 ) 

1503 results = self.query(result_rows=result_rows).dimension_records("patch") 

1504 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) 
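
# Page-based iteration should yield one DimensionRecordSet and one table
# per page of mock driver results.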

1505 self.assertEqual( 

1506 list(results.iter_set_pages()), 

1507 [DimensionRecordSet(self.universe["patch"], rows) for rows in result_rows], 

1508 ) 

1509 self.assertEqual( 

1510 [table.column("id").to_pylist() for table in results.iter_table_pages()], 

1511 [list(range(3)), list(range(3, 6)), [10]], 

1512 ) 

1513 

1514 def test_data_coordinate_results(self) -> None: 

1515 """Test queries that return data coordinates. 

1516 

1517 This includes tests for: 

1518 

1519 - counting result rows; 

1520 - expanding dimensions as needed for 'where' conditions; 

1521 - order_by and limit. 

1522 

1523 It does not include the iteration methods of 

1524 DataCoordinateQueryResults, since those require a different mock 

1525 driver setup (see test_data_coordinate_iteration). More tests for 

1526 different inputs to DataCoordinateQueryResults construction are in 

1527 test_dataset_join. 

1528 """ 

1529 # Set up the base query-results object to test. 

1530 query = self.query() 

1531 x = query.expression_factory 

1532 self.assertFalse(query.constraint_dimensions) 

1533 query = query.where(x.skymap == "m") 

1534 results = query.data_ids(["patch", "band"]) 

1535 results = results.where(x.tract == 4) 

1536 

1537 # Define a closure to run tests on variants of the base query. 

1538 def check( 

1539 results: DataCoordinateQueryResults, 

1540 order_by: Any = (), 

1541 limit: int | None = None, 

1542 include_dimension_records: bool = False, 

1543 ) -> list[str]: 

1544 results = results.order_by(*order_by).limit(limit) 

1545 self.assertEqual(results.dimensions, self.universe.conform(["patch", "band"])) 

1546 with self.assertRaises(_TestQueryExecution) as cm: 

1547 list(results) 

1548 tree = cm.exception.tree 

1549 self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4") 

1550 self.assertEqual(tree.dimensions, self.universe.conform(["patch", "band"])) 

1551 self.assertFalse(tree.materializations) 

1552 self.assertFalse(tree.datasets) 

1553 self.assertFalse(tree.data_coordinate_uploads) 

1554 result_spec = cm.exception.result_spec 

1555 self.assertEqual(result_spec.result_type, "data_coordinate") 

1556 self.assertEqual(result_spec.dimensions, self.universe.conform(["patch", "band"])) 

1557 self.assertEqual(result_spec.include_dimension_records, include_dimension_records) 

1558 self.assertEqual(result_spec.limit, limit) 

1559 self.assertIsNone(result_spec.find_first_dataset) 

1560 for exact, discard in itertools.permutations([False, True], r=2): 

1561 with self.assertRaises(_TestQueryCount) as cm: 

1562 results.count(exact=exact, discard=discard) 

1563 self.assertEqual(cm.exception.result_spec, result_spec) 

1564 self.assertEqual(cm.exception.exact, exact) 

1565 self.assertEqual(cm.exception.discard, discard) 

1566 return [str(term) for term in result_spec.order_by] 

1567 

1568 # Run the closure's tests on variants of the base query. 

1569 self.assertEqual(check(results), []) 

1570 self.assertEqual(check(results.with_dimension_records(), include_dimension_records=True), []) 

1571 self.assertEqual( 

1572 check(results.with_dimension_records().with_dimension_records(), include_dimension_records=True), 

1573 [], 

1574 ) 

1575 self.assertEqual(check(results, limit=2), []) 

1576 self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"]) 

1577 self.assertEqual( 

1578 check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]), 

1579 ["patch.cell_x", "patch.cell_y DESC"], 

1580 ) 

1581 self.assertEqual( 

1582 check(results, order_by=["patch.cell_x", "-cell_y"]), 

1583 ["patch.cell_x", "patch.cell_y DESC"], 

1584 ) 

1585 

1586 def test_data_coordinate_iteration(self) -> None: 

1587 """Tests for DataCoordinateQueryResult iteration.""" 

1588 

1589 def make_data_id(n: int) -> DataCoordinate: 

1590 return DataCoordinate.standardize(skymap="m", tract=4, patch=n, universe=self.universe) 

1591 

1592 result_rows = ( 

1593 [make_data_id(n) for n in range(3)], 

1594 [make_data_id(n) for n in range(3, 6)], 

1595 [make_data_id(10)], 

1596 ) 

1597 results = self.query(result_rows=result_rows).data_ids(["patch"]) 

1598 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) 

1599 

1600 def test_dataset_results(self) -> None: 

1601 """Test queries that return dataset refs. 

1602 

1603 This includes tests for: 

1604 

1605 - counting result rows; 

1606 - expanding dimensions as needed for 'where' conditions; 

1607 - different ways of passing a data ID to 'where' methods; 

1608 - order_by and limit. 

1609 

1610 It does not include the iteration methods of the DatasetRefQueryResults 

1611 classes, since those require a different mock driver setup (see 

1612 test_dataset_iteration). More tests for different inputs to 

1613 DatasetRefQueryResults construction are in test_dataset_join. 

1614 """ 

1615 # Set up a few equivalent base query-results objects to test. 

1616 query = self.query() 

1617 x = query.expression_factory 

1618 self.assertFalse(query.constraint_dimensions) 
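
# Build the same constrained results three ways: 'where' with keyword
# arguments, with a mapping (plus explicit collections), and with a
# DataCoordinate.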

1619 results1 = query.datasets("raw").where(x.instrument == "DummyCam", visit=4) 

1620 results2 = query.datasets("raw", collections=["DummyCam/defaults"]).where( 

1621 {"instrument": "DummyCam", "visit": 4} 

1622 ) 

1623 results3 = query.datasets("raw").where( 

1624 DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe) 

1625 ) 

1626 

1627 # Define a closure to check a DatasetRefQueryResults instance. 

1628 def check( 

1629 results: DatasetRefQueryResults, 

1630 order_by: Any = (), 

1631 limit: int | None = None, 

1632 include_dimension_records: bool = False, 

1633 ) -> list[str]: 

1634 results = results.order_by(*order_by).limit(limit) 

1635 with self.assertRaises(_TestQueryExecution) as cm: 

1636 list(results) 

1637 tree = cm.exception.tree 

1638 self.assertEqual(str(tree.predicate), "instrument == 'DummyCam' AND visit == 4") 
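
# The tree's dimensions should be the union of the 'where' constraint
# dimensions and the dataset type's own dimensions.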

1639 self.assertEqual( 

1640 tree.dimensions, 

1641 self.universe.conform(["visit"]).union(results.dataset_type.dimensions.as_group()), 

1642 ) 

1643 self.assertFalse(tree.materializations) 

1644 self.assertEqual(tree.datasets.keys(), {results.dataset_type.name}) 

1645 self.assertEqual(tree.datasets[results.dataset_type.name].collections, ("DummyCam/defaults",)) 

1646 self.assertEqual( 

1647 tree.datasets[results.dataset_type.name].dimensions, 

1648 results.dataset_type.dimensions.as_group(), 

1649 ) 

1650 self.assertFalse(tree.data_coordinate_uploads) 

1651 result_spec = cm.exception.result_spec 

1652 self.assertEqual(result_spec.result_type, "dataset_ref") 

1653 self.assertEqual(result_spec.include_dimension_records, include_dimension_records) 

1654 self.assertEqual(result_spec.limit, limit) 

1655 self.assertEqual(result_spec.find_first_dataset, result_spec.dataset_type_name) 

1656 for exact, discard in itertools.permutations([False, True], r=2): 

1657 with self.assertRaises(_TestQueryCount) as cm: 

1658 results.count(exact=exact, discard=discard) 

1659 self.assertEqual(cm.exception.result_spec, result_spec) 

1660 self.assertEqual(cm.exception.exact, exact) 

1661 self.assertEqual(cm.exception.discard, discard) 
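
# The data_ids adapter should reuse the same query tree with a
# data-coordinate result spec.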

1662 with self.assertRaises(_TestQueryExecution) as cm: 

1663 list(results.data_ids) 

1664 self.assertEqual( 

1665 cm.exception.result_spec, 

1666 qrs.DataCoordinateResultSpec( 

1667 dimensions=results.dataset_type.dimensions.as_group(), 

1668 include_dimension_records=include_dimension_records, 

1669 ), 

1670 ) 

1671 self.assertIs(cm.exception.tree, tree) 

1672 return [str(term) for term in result_spec.order_by] 

1673 

1674 # Run the closure's tests on variants of the base query. 

1675 self.assertEqual(check(results1), []) 

1676 self.assertEqual(check(results2), []) 

1677 self.assertEqual(check(results3), []) 

1678 self.assertEqual(check(results1.with_dimension_records(), include_dimension_records=True), []) 

1679 self.assertEqual( 

1680 check(results1.with_dimension_records().with_dimension_records(), include_dimension_records=True), 

1681 [], 

1682 ) 

1683 self.assertEqual(check(results1, limit=2), []) 

1684 self.assertEqual(check(results1, order_by=["raw.timespan.begin"]), ["raw.timespan.begin"]) 

1685 self.assertEqual(check(results1, order_by=["detector"]), ["detector"]) 

1686 self.assertEqual(check(results1, order_by=["ingest_date"]), ["raw.ingest_date"]) 

1687 

1688 def test_dataset_iteration(self) -> None: 

1689 """Tests for SingleTypeDatasetQueryResult iteration.""" 

1690 

1691 def make_ref(n: int) -> DatasetRef: 

1692 return DatasetRef( 

1693 self.raw, 

1694 DataCoordinate.standardize( 

1695 instrument="DummyCam", exposure=4, detector=n, universe=self.universe 

1696 ), 

1697 run="DummyCam/raw/all", 

1698 id=uuid.uuid4(), 

1699 ) 

1700 

1701 result_rows = ( 

1702 [make_ref(n) for n in range(3)], 

1703 [make_ref(n) for n in range(3, 6)], 

1704 [make_ref(10)], 

1705 ) 

1706 results = self.query(result_rows=result_rows).datasets("raw") 

1707 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) 

1708 

1709 def test_identifiers(self) -> None: 

1710 """Test edge-cases of identifiers in order_by expressions.""" 

1711 

1712 def extract_order_by(results: DataCoordinateQueryResults) -> list[str]: 

1713 with self.assertRaises(_TestQueryExecution) as cm: 

1714 list(results) 

1715 return [str(term) for term in cm.exception.result_spec.order_by] 

1716 
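
# Unqualified field names resolve against the only element that has them;
# 'visit.id' collapses to the dimension itself and 'visit.physical_filter'
# to the implied dimension.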

1717 self.assertEqual( 

1718 extract_order_by(self.query().data_ids(["day_obs"]).order_by("-timespan.begin")), 

1719 ["day_obs.timespan.begin DESC"], 

1720 ) 

1721 self.assertEqual( 

1722 extract_order_by(self.query().data_ids(["day_obs"]).order_by("timespan.end")), 

1723 ["day_obs.timespan.end"], 

1724 ) 

1725 self.assertEqual( 

1726 extract_order_by(self.query().data_ids(["visit"]).order_by("-visit.timespan.begin")), 

1727 ["visit.timespan.begin DESC"], 

1728 ) 

1729 self.assertEqual( 

1730 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.timespan.end")), 

1731 ["visit.timespan.end"], 

1732 ) 

1733 self.assertEqual( 

1734 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.science_program")), 

1735 ["visit.science_program"], 

1736 ) 

1737 self.assertEqual( 

1738 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.id")), 

1739 ["visit"], 

1740 ) 

1741 self.assertEqual( 

1742 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.physical_filter")), 

1743 ["physical_filter"], 

1744 ) 

1745 with self.assertRaises(TypeError): 

1746 self.query().data_ids(["visit"]).order_by(3) 

1747 with self.assertRaises(InvalidQueryError): 

1748 self.query().data_ids(["visit"]).order_by("visit.region") 

1749 with self.assertRaisesRegex(InvalidQueryError, "Ambiguous"): 

1750 self.query().data_ids(["visit", "exposure"]).order_by("timespan.begin") 

1751 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"): 

1752 self.query().data_ids(["visit", "exposure"]).order_by("blarg") 

1753 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"): 

1754 self.query().data_ids(["visit", "exposure"]).order_by("visit.horse") 

1755 with self.assertRaisesRegex(InvalidQueryError, "Unrecognized"): 

1756 self.query().data_ids(["visit", "exposure"]).order_by("visit.science_program.monkey") 

1757 with self.assertRaisesRegex(InvalidQueryError, "not valid for datasets"): 

1758 self.query().datasets("raw").order_by("raw.seq_num") 

1759 

1760 def test_invalid_models(self) -> None: 

1761 """Test invalid models and combinations of models that cannot be 

1762 constructed via the public Query and *QueryResults interfaces. 

1763 """ 

1764 x = ExpressionFactory(self.universe) 

1765 with self.assertRaises(InvalidQueryError): 

1766 # QueryTree dimensions do not cover dataset dimensions. 

1767 qt.QueryTree( 

1768 dimensions=self.universe.conform(["visit"]), 

1769 datasets={ 

1770 "raw": qt.DatasetSearch( 

1771 collections=("DummyCam/raw/all",), 

1772 dimensions=self.raw.dimensions.as_group(), 

1773 ) 

1774 }, 

1775 ) 

1776 with self.assertRaises(InvalidQueryError): 

1777 # QueryTree dimensions do not cover predicate dimensions. 

1778 qt.QueryTree( 

1779 dimensions=self.universe.conform(["visit"]), 

1780 predicate=(x.detector > 5), 

1781 ) 

1782 with self.assertRaises(InvalidQueryError): 

1783 # Predicate references a dataset not in the QueryTree. 

1784 qt.QueryTree( 

1785 dimensions=self.universe.conform(["exposure", "detector"]), 

1786 predicate=(x["raw"].collection == "bird"), 

1787 ) 

1788 with self.assertRaises(InvalidQueryError): 

1789 # ResultSpec's dimensions are not a subset of the query tree's. 

1790 DimensionRecordQueryResults( 

1791 _TestQueryDriver(), 

1792 qt.QueryTree(dimensions=self.universe.conform(["tract"])), 

1793 qrs.DimensionRecordResultSpec(element=self.universe["detector"]), 

1794 ) 

1795 with self.assertRaises(InvalidQueryError): 

1796 # ResultSpec's datasets are not a subset of the query tree's. 

1797 DatasetRefQueryResults( 

1798 _TestQueryDriver(), 

1799 qt.QueryTree(dimensions=self.raw.dimensions.as_group()), 

1800 qrs.DatasetRefResultSpec( 

1801 dataset_type_name="raw", 

1802 dimensions=self.raw.dimensions.as_group(), 

1803 storage_class_name=self.raw.storageClass_name, 

1804 find_first=True, 

1805 ), 

1806 ) 

1807 with self.assertRaises(InvalidQueryError): 

1808 # ResultSpec's order_by expression is not related to the dimensions 

1809 # we're returning. 

1810 x = ExpressionFactory(self.universe) 

1811 DimensionRecordQueryResults( 

1812 _TestQueryDriver(), 

1813 qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])), 

1814 qrs.DimensionRecordResultSpec( 

1815 element=self.universe["detector"], order_by=(x.unwrap(x.visit),) 

1816 ), 

1817 ) 

1818 with self.assertRaises(InvalidQueryError): 

1819 # ResultSpec's order_by expression is not related to the datasets 

1820 # we're returning. 

1821 x = ExpressionFactory(self.universe) 

1822 DimensionRecordQueryResults( 

1823 _TestQueryDriver(), 

1824 qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])), 

1825 qrs.DimensionRecordResultSpec( 

1826 element=self.universe["detector"], order_by=(x.unwrap(x["raw"].ingest_date),) 

1827 ), 

1828 ) 

1829 

1830 def test_general_result_spec(self) -> None: 

1831 """Tests for GeneralResultSpec. 

1832 

1833 Unlike the other ResultSpec objects, we don't have a *QueryResults 

1834 class for GeneralResultSpec yet, so we can't use the higher-level 

1835 interfaces to test it like we can the others. 

1836 """ 
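
# Spec 'a' requests an extra dimension-record field; spec 'b' requests
# dataset fields with find-first resolution.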

1837 a = qrs.GeneralResultSpec( 

1838 dimensions=self.universe.conform(["detector"]), 

1839 dimension_fields={"detector": {"purpose"}}, 

1840 dataset_fields={}, 

1841 find_first=False, 

1842 ) 

1843 self.assertEqual(a.find_first_dataset, None) 

1844 a_columns = qt.ColumnSet(self.universe.conform(["detector"])) 

1845 a_columns.dimension_fields["detector"].add("purpose") 

1846 self.assertEqual(a.get_result_columns(), a_columns) 

1847 b = qrs.GeneralResultSpec( 

1848 dimensions=self.universe.conform(["detector"]), 

1849 dimension_fields={}, 

1850 dataset_fields={"bias": {"timespan", "dataset_id"}}, 

1851 find_first=True, 

1852 ) 

1853 self.assertEqual(b.find_first_dataset, "bias") 

1854 b_columns = qt.ColumnSet(self.universe.conform(["detector"])) 

1855 b_columns.dataset_fields["bias"].add("timespan") 

1856 b_columns.dataset_fields["bias"].add("dataset_id") 

1857 self.assertEqual(b.get_result_columns(), b_columns) 

1858 with self.assertRaises(InvalidQueryError): 

1859 # More than one dataset type with find_first=True. 

1860 qrs.GeneralResultSpec( 

1861 dimensions=self.universe.conform(["detector", "exposure"]), 

1862 dimension_fields={}, 

1863 dataset_fields={"bias": {"dataset_id"}, "raw": {"dataset_id"}}, 

1864 find_first=True, 

1865 ) 

1866 with self.assertRaises(InvalidQueryError): 

1867 # Out-of-bounds dimension fields. 

1868 qrs.GeneralResultSpec( 

1869 dimensions=self.universe.conform(["detector"]), 

1870 dimension_fields={"visit": {"name"}}, 

1871 dataset_fields={}, 

1872 find_first=False, 

1873 ) 

1874 with self.assertRaises(InvalidQueryError): 

1875 # No fields for dimension element. 

1876 qrs.GeneralResultSpec( 

1877 dimensions=self.universe.conform(["detector"]), 

1878 dimension_fields={"detector": set()}, 

1879 dataset_fields={}, 

1880 find_first=True, 

1881 ) 

1882 with self.assertRaises(InvalidQueryError): 

1883 # No fields for dataset. 

1884 qrs.GeneralResultSpec( 

1885 dimensions=self.universe.conform(["detector"]), 

1886 dimension_fields={}, 

1887 dataset_fields={"bias": set()}, 

1888 find_first=True, 

1889 ) 

1890 

1891 

1892if __name__ == "__main__": 

1893 unittest.main()