Coverage for tests/test_query_interface.py: 10% (936 statements)

coverage.py v7.4.4, created at 2024-04-04 02:55 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Tests for the public Butler._query interface and the Pydantic models that 

29back it using a mock column-expression visitor and a mock QueryDriver 

30implementation. 

31 

32These tests are entirely independent of which kind of butler or database 

33backend we're using. 

34 

35This is a very large test file because a lot of tests make use of those mocks, 

36but they're not so generally useful that I think they're worth putting in the 

37library proper. 

38""" 

39 

40from __future__ import annotations 

41 

42import dataclasses 

43import itertools 

44import unittest 

45import uuid 

46from collections.abc import Iterable, Iterator, Mapping, Set 

47from typing import Any 

48 

49import astropy.time 

50from lsst.daf.butler import ( 

51 CollectionType, 

52 DataCoordinate, 

53 DataIdValue, 

54 DatasetRef, 

55 DatasetType, 

56 DimensionGroup, 

57 DimensionRecord, 

58 DimensionRecordSet, 

59 DimensionUniverse, 

60 MissingDatasetTypeError, 

61 NamedValueSet, 

62 NoDefaultCollectionError, 

63 Timespan, 

64) 

65from lsst.daf.butler.queries import ( 

66 DataCoordinateQueryResults, 

67 DatasetRefQueryResults, 

68 DimensionRecordQueryResults, 

69 Query, 

70) 

71from lsst.daf.butler.queries import driver as qd 

72from lsst.daf.butler.queries import result_specs as qrs 

73from lsst.daf.butler.queries import tree as qt 

74from lsst.daf.butler.queries.expression_factory import ExpressionFactory 

75from lsst.daf.butler.queries.tree._column_expression import UnaryExpression 

76from lsst.daf.butler.queries.tree._predicate import PredicateLeaf, PredicateOperands 

77from lsst.daf.butler.queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor 

78from lsst.daf.butler.registry import CollectionSummary, DatasetTypeError 

79from lsst.daf.butler.registry.interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord 

80from lsst.sphgeom import DISJOINT, Mq3cPixelization 

81 

82 

83class _TestVisitor(PredicateVisitor[bool, bool, bool], ColumnExpressionVisitor[Any]): 

84 """Test visitor for column expressions. 

85 

86 This visitor evaluates column expressions using regular Python logic. 

87 

88 Parameters 

89 ---------- 

90 dimension_keys : `~collections.abc.Mapping`, optional 

91 Mapping from dimension name to the value it should be assigned by the 

92 visitor. 

93 dimension_fields : `~collections.abc.Mapping`, optional 

94 Mapping from ``(dimension element name, field)`` tuple to the value it 

95 should be assigned by the visitor. 

96 dataset_fields : `~collections.abc.Mapping`, optional 

97 Mapping from ``(dataset type name, field)`` tuple to the value it 

98 should be assigned by the visitor. 

99 query_tree_items : `~collections.abc.Set`, optional 

100 Set that should be used as the right-hand side of element-in-query 

101 predicates. 

102 """ 

103 

104 def __init__( 

105 self, 

106 dimension_keys: Mapping[str, Any] | None = None, 

107 dimension_fields: Mapping[tuple[str, str], Any] | None = None, 

108 dataset_fields: Mapping[tuple[str, str], Any] | None = None, 

109 query_tree_items: Set[Any] = frozenset(), 

110 ): 

111 self.dimension_keys = dimension_keys or {} 

112 self.dimension_fields = dimension_fields or {} 

113 self.dataset_fields = dataset_fields or {} 

114 self.query_tree_items = query_tree_items 

115 

116 def visit_binary_expression(self, expression: qt.BinaryExpression) -> Any: 

117 match expression.operator: 

118 case "+": 

119 return expression.a.visit(self) + expression.b.visit(self) 

120 case "-": 

121 return expression.a.visit(self) - expression.b.visit(self) 

122 case "*": 

123 return expression.a.visit(self) * expression.b.visit(self) 

124 case "/": 

125 match expression.column_type: 

126 case "int": 

127 return expression.a.visit(self) // expression.b.visit(self) 

128 case "float": 

129 return expression.a.visit(self) / expression.b.visit(self) 

130 case "%": 

131 return expression.a.visit(self) % expression.b.visit(self) 

132 

133 def visit_comparison( 

134 self, 

135 a: qt.ColumnExpression, 

136 operator: qt.ComparisonOperator, 

137 b: qt.ColumnExpression, 

138 flags: PredicateVisitFlags, 

139 ) -> bool: 

140 match operator: 

141 case "==": 

142 return a.visit(self) == b.visit(self) 

143 case "!=": 

144 return a.visit(self) != b.visit(self) 

145 case "<": 

146 return a.visit(self) < b.visit(self) 

147 case ">": 

148 return a.visit(self) > b.visit(self) 

149 case "<=": 

150 return a.visit(self) <= b.visit(self) 

151 case ">=": 

152 return a.visit(self) >= b.visit(self) 

153 case "overlaps": 

154 return not (a.visit(self).relate(b.visit(self)) & DISJOINT) 

155 

156 def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> Any: 

157 return self.dataset_fields[expression.dataset_type, expression.field] 

158 

159 def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> Any: 

160 return self.dimension_fields[expression.element.name, expression.field] 

161 

162 def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> Any: 

163 return self.dimension_keys[expression.dimension.name] 

164 

165 def visit_in_container( 

166 self, 

167 member: qt.ColumnExpression, 

168 container: tuple[qt.ColumnExpression, ...], 

169 flags: PredicateVisitFlags, 

170 ) -> bool: 

171 return member.visit(self) in [item.visit(self) for item in container] 

172 

173 def visit_in_range( 

174 self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags 

175 ) -> bool: 

176 return member.visit(self) in range(start, stop, step) 

177 

178 def visit_in_query_tree( 

179 self, 

180 member: qt.ColumnExpression, 

181 column: qt.ColumnExpression, 

182 query_tree: qt.QueryTree, 

183 flags: PredicateVisitFlags, 

184 ) -> bool: 

185 return member.visit(self) in self.query_tree_items 

186 

187 def visit_is_null(self, operand: qt.ColumnExpression, flags: PredicateVisitFlags) -> bool: 

188 return operand.visit(self) is None 

189 

190 def visit_literal(self, expression: qt.ColumnLiteral) -> Any: 

191 return expression.get_literal_value() 

192 

193 def visit_reversed(self, expression: qt.Reversed) -> Any: 

194 return _TestReversed(expression.operand.visit(self)) 

195 

196 def visit_unary_expression(self, expression: UnaryExpression) -> Any: 

197 match expression.operator: 

198 case "-": 

199 return -expression.operand.visit(self) 

200 case "begin_of": 

201 return expression.operand.visit(self).begin 

202 case "end_of": 

203 return expression.operand.visit(self).end 

204 

205 def apply_logical_and(self, originals: PredicateOperands, results: tuple[bool, ...]) -> bool: 

206 return all(results) 

207 

208 def apply_logical_not(self, original: PredicateLeaf, result: bool, flags: PredicateVisitFlags) -> bool: 

209 return not result 

210 

211 def apply_logical_or( 

212 self, originals: tuple[PredicateLeaf, ...], results: tuple[bool, ...], flags: PredicateVisitFlags 

213 ) -> bool: 

214 return any(results) 

215 

216 

217@dataclasses.dataclass 

218class _TestReversed: 

219 """Struct used by _TestVisitor" to mark an expression as reversed in sort 

220 order. 

221 """ 

222 

223 operand: Any 

224 

225 

226class _TestQueryExecution(BaseException): 

227 """Exception raised by _TestQueryDriver.execute to communicate its args 

228 back to the caller. 

229 """ 

230 

231 def __init__(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree, driver: _TestQueryDriver) -> None: 

232 self.result_spec = result_spec 

233 self.tree = tree 

234 self.driver = driver 

235 

236 

237class _TestQueryCount(BaseException): 

238 """Exception raised by _TestQueryDriver.count to communicate its args 

239 back to the caller. 

240 """ 

241 

242 def __init__( 

243 self, 

244 result_spec: qrs.ResultSpec, 

245 tree: qt.QueryTree, 

246 driver: _TestQueryDriver, 

247 exact: bool, 

248 discard: bool, 

249 ) -> None: 

250 self.result_spec = result_spec 

251 self.tree = tree 

252 self.driver = driver 

253 self.exact = exact 

254 self.discard = discard 

255 

256 

257class _TestQueryAny(BaseException): 

258 """Exception raised by _TestQueryDriver.any to communicate its args 

259 back to the caller. 

260 """ 

261 

262 def __init__( 

263 self, 

264 tree: qt.QueryTree, 

265 driver: _TestQueryDriver, 

266 exact: bool, 

267 execute: bool, 

268 ) -> None: 

269 self.tree = tree 

270 self.driver = driver 

271 self.exact = exact 

272 self.execute = execute 

273 

274 

275class _TestQueryExplainNoResults(BaseException): 

276 """Exception raised by _TestQueryDriver.explain_no_results to communicate 

277 its args back to the caller. 

278 """ 

279 

280 def __init__( 

281 self, 

282 tree: qt.QueryTree, 

283 driver: _TestQueryDriver, 

284 execute: bool, 

285 ) -> None: 

286 self.tree = tree 

287 self.driver = driver 

288 self.execute = execute 

289 

290 

291class _TestQueryDriver(qd.QueryDriver): 

292 """Mock implementation of `QueryDriver` that mostly raises exceptions that 

293 communicate the arguments its methods were called with. 

294 

295 Parameters 

296 ---------- 

297 default_collections : `tuple` [ `str`, ... ], optional 

298 Default collections the query or parent butler is imagined to have been

299 constructed with. 

300 collection_info : `~collections.abc.Mapping`, optional 

301 Mapping from collection name to its record and summary, simulating the 

302 collections present in the data repository. 

303 dataset_types : `~collections.abc.Mapping`, optional 

304 Mapping from dataset type to its definition, simulating the dataset 

305 types registered in the data repository. 

306 result_rows : `tuple` [ `~collections.abc.Iterable`, ... ], optional 

307 A tuple of iterables of arbitrary type to use as result rows any time 

308 `execute` is called, with each nested iterable considered a separate 

309 page. The result type is not checked for consistency with the result 

310 spec. If this is not provided, `execute` will instead raise 

311 `_TestQueryExecution`, and `fetch_page` will not do anything useful. 

312 """ 

313 

314 def __init__( 

315 self, 

316 default_collections: tuple[str, ...] | None = None, 

317 collection_info: Mapping[str, tuple[CollectionRecord, CollectionSummary]] | None = None, 

318 dataset_types: Mapping[str, DatasetType] | None = None, 

319 result_rows: tuple[Iterable[Any], ...] | None = None, 

320 ) -> None: 

321 self._universe = DimensionUniverse() 

322 # Mapping of the arguments passed to materialize, keyed by the UUID 

323 # that each call returned.

324 self.materializations: dict[ 

325 qd.MaterializationKey, tuple[qt.QueryTree, DimensionGroup, frozenset[str]] 

326 ] = {} 

327 # Mapping of the arguments passed to upload_data_coordinates, keyed by 

328 # the UUID that each call returned.

329 self.data_coordinate_uploads: dict[ 

330 qd.DataCoordinateUploadKey, tuple[DimensionGroup, list[tuple[DataIdValue, ...]]] 

331 ] = {} 

332 self._default_collections = default_collections 

333 self._collection_info = collection_info or {} 

334 self._dataset_types = dataset_types or {} 

335 self._executions: list[tuple[qrs.ResultSpec, qt.QueryTree]] = [] 

336 self._result_rows = result_rows 

337 self._result_iters: dict[qd.PageKey, tuple[Iterable[Any], Iterator[Iterable[Any]]]] = {} 

338 

339 @property 

340 def universe(self) -> DimensionUniverse: 

341 return self._universe 

342 

343 def __enter__(self) -> None: 

344 pass 

345 

346 def __exit__(self, *args: Any, **kwargs: Any) -> None: 

347 pass 

348 

349 def execute(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree) -> qd.ResultPage: 

350 if self._result_rows is not None: 

351 iterator = iter(self._result_rows) 

352 current_rows = next(iterator, ()) 

353 return self._make_next_page(result_spec, current_rows, iterator) 

354 raise _TestQueryExecution(result_spec, tree, self) 

355 

356 def fetch_next_page(self, result_spec: qrs.ResultSpec, key: qd.PageKey) -> qd.ResultPage: 

357 if self._result_rows is not None: 

358 return self._make_next_page(result_spec, *self._result_iters.pop(key)) 

359 raise AssertionError("Test query driver not initialized for actual results.") 

360 

361 def _make_next_page( 

362 self, result_spec: qrs.ResultSpec, current_rows: Iterable[Any], iterator: Iterator[Iterable[Any]] 

363 ) -> qd.ResultPage: 

364 next_rows = list(next(iterator, ())) 

365 if not next_rows: 

366 next_key = None 

367 else: 

368 next_key = uuid.uuid4() 

369 self._result_iters[next_key] = (next_rows, iterator) 

370 match result_spec: 

371 case qrs.DataCoordinateResultSpec(): 

372 return qd.DataCoordinateResultPage(spec=result_spec, next_key=next_key, rows=current_rows) 

373 case qrs.DimensionRecordResultSpec(): 

374 return qd.DimensionRecordResultPage(spec=result_spec, next_key=next_key, rows=current_rows) 

375 case qrs.DatasetRefResultSpec(): 

376 return qd.DatasetRefResultPage(spec=result_spec, next_key=next_key, rows=current_rows) 

377 case _: 

378 raise NotImplementedError("Other query types not yet supported.") 

379 

380 def materialize( 

381 self, 

382 tree: qt.QueryTree, 

383 dimensions: DimensionGroup, 

384 datasets: frozenset[str], 

385 ) -> qd.MaterializationKey: 

386 key = uuid.uuid4() 

387 self.materializations[key] = (tree, dimensions, datasets) 

388 return key 

389 

390 def upload_data_coordinates( 

391 self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]] 

392 ) -> qd.DataCoordinateUploadKey: 

393 key = uuid.uuid4() 

394 self.data_coordinate_uploads[key] = (dimensions, frozenset(rows)) 

395 return key 

396 

397 def count( 

398 self, 

399 tree: qt.QueryTree, 

400 result_spec: qrs.ResultSpec, 

401 *, 

402 exact: bool, 

403 discard: bool, 

404 ) -> int: 

405 raise _TestQueryCount(result_spec, tree, self, exact, discard) 

406 

407 def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool: 

408 raise _TestQueryAny(tree, self, exact, execute) 

409 

410 def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]: 

411 raise _TestQueryExplainNoResults(tree, self, execute) 

412 

413 def get_default_collections(self) -> tuple[str, ...]: 

414 if self._default_collections is None: 

415 raise NoDefaultCollectionError() 

416 return self._default_collections 

417 

418 def get_dataset_type(self, name: str) -> DatasetType: 

419 try: 

420 return self._dataset_types[name] 

421 except KeyError: 

422 raise MissingDatasetTypeError(name) 

423 

424 

425class ColumnExpressionsTestCase(unittest.TestCase): 

426 """Tests for column expression objects in lsst.daf.butler.queries.tree.""" 

427 

428 def setUp(self) -> None: 

429 self.universe = DimensionUniverse() 

430 self.x = ExpressionFactory(self.universe) 

431 

432 def query(self, **kwargs: Any) -> Query: 

433 """Make an initial Query object with the given kwargs used to 

434 initialize the _TestQueryDriver. 

435 """ 

436 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe)) 

437 

438 def test_int_literals(self) -> None: 

439 expr = self.x.unwrap(self.x.literal(5)) 

440 self.assertEqual(expr.value, 5) 

441 self.assertEqual(expr.get_literal_value(), 5) 

442 self.assertEqual(expr.expression_type, "int") 

443 self.assertEqual(expr.column_type, "int") 

444 self.assertEqual(str(expr), "5") 

445 self.assertTrue(expr.is_literal) 

446 columns = qt.ColumnSet(self.universe.empty.as_group()) 

447 expr.gather_required_columns(columns) 

448 self.assertFalse(columns) 

449 self.assertEqual(expr.visit(_TestVisitor()), 5) 

450 

451 def test_string_literals(self) -> None: 

452 expr = self.x.unwrap(self.x.literal("five")) 

453 self.assertEqual(expr.value, "five") 

454 self.assertEqual(expr.get_literal_value(), "five") 

455 self.assertEqual(expr.expression_type, "string") 

456 self.assertEqual(expr.column_type, "string") 

457 self.assertEqual(str(expr), "'five'") 

458 self.assertTrue(expr.is_literal) 

459 columns = qt.ColumnSet(self.universe.empty.as_group()) 

460 expr.gather_required_columns(columns) 

461 self.assertFalse(columns) 

462 self.assertEqual(expr.visit(_TestVisitor()), "five") 

463 

464 def test_float_literals(self) -> None: 

465 expr = self.x.unwrap(self.x.literal(0.5)) 

466 self.assertEqual(expr.value, 0.5) 

467 self.assertEqual(expr.get_literal_value(), 0.5) 

468 self.assertEqual(expr.expression_type, "float") 

469 self.assertEqual(expr.column_type, "float") 

470 self.assertEqual(str(expr), "0.5") 

471 self.assertTrue(expr.is_literal) 

472 columns = qt.ColumnSet(self.universe.empty.as_group()) 

473 expr.gather_required_columns(columns) 

474 self.assertFalse(columns) 

475 self.assertEqual(expr.visit(_TestVisitor()), 0.5) 

476 

477 def test_hash_literals(self) -> None: 

478 expr = self.x.unwrap(self.x.literal(b"eleven")) 

479 self.assertEqual(expr.value, b"eleven") 

480 self.assertEqual(expr.get_literal_value(), b"eleven") 

481 self.assertEqual(expr.expression_type, "hash") 

482 self.assertEqual(expr.column_type, "hash") 

483 self.assertEqual(str(expr), "(bytes)") 

484 self.assertTrue(expr.is_literal) 

485 columns = qt.ColumnSet(self.universe.empty.as_group()) 

486 expr.gather_required_columns(columns) 

487 self.assertFalse(columns) 

488 self.assertEqual(expr.visit(_TestVisitor()), b"eleven") 

489 

490 def test_uuid_literals(self) -> None: 

491 value = uuid.uuid4() 

492 expr = self.x.unwrap(self.x.literal(value)) 

493 self.assertEqual(expr.value, value) 

494 self.assertEqual(expr.get_literal_value(), value) 

495 self.assertEqual(expr.expression_type, "uuid") 

496 self.assertEqual(expr.column_type, "uuid") 

497 self.assertEqual(str(expr), str(value)) 

498 self.assertTrue(expr.is_literal) 

499 columns = qt.ColumnSet(self.universe.empty.as_group()) 

500 expr.gather_required_columns(columns) 

501 self.assertFalse(columns) 

502 self.assertEqual(expr.visit(_TestVisitor()), value) 

503 

504 def test_datetime_literals(self) -> None: 

505 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

506 expr = self.x.unwrap(self.x.literal(value)) 

507 self.assertEqual(expr.value, value) 

508 self.assertEqual(expr.get_literal_value(), value) 

509 self.assertEqual(expr.expression_type, "datetime") 

510 self.assertEqual(expr.column_type, "datetime") 

511 self.assertEqual(str(expr), "2020-01-01T00:00:00") 

512 self.assertTrue(expr.is_literal) 

513 columns = qt.ColumnSet(self.universe.empty.as_group()) 

514 expr.gather_required_columns(columns) 

515 self.assertFalse(columns) 

516 self.assertEqual(expr.visit(_TestVisitor()), value) 

517 

518 def test_timespan_literals(self) -> None: 

519 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

520 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai") 

521 value = Timespan(begin, end) 

522 expr = self.x.unwrap(self.x.literal(value)) 

523 self.assertEqual(expr.value, value) 

524 self.assertEqual(expr.get_literal_value(), value) 

525 self.assertEqual(expr.expression_type, "timespan") 

526 self.assertEqual(expr.column_type, "timespan") 

527 self.assertEqual(str(expr), "[2020-01-01T00:00:00, 2020-01-01T00:01:00)") 

528 self.assertTrue(expr.is_literal) 

529 columns = qt.ColumnSet(self.universe.empty.as_group()) 

530 expr.gather_required_columns(columns) 

531 self.assertFalse(columns) 

532 self.assertEqual(expr.visit(_TestVisitor()), value) 

533 

534 def test_region_literals(self) -> None: 

535 pixelization = Mq3cPixelization(10) 

536 value = pixelization.quad(12058870) 

537 expr = self.x.unwrap(self.x.literal(value)) 

538 self.assertEqual(expr.value, value) 

539 self.assertEqual(expr.get_literal_value(), value) 

540 self.assertEqual(expr.expression_type, "region") 

541 self.assertEqual(expr.column_type, "region") 

542 self.assertEqual(str(expr), "(region)") 

543 self.assertTrue(expr.is_literal) 

544 columns = qt.ColumnSet(self.universe.empty.as_group()) 

545 expr.gather_required_columns(columns) 

546 self.assertFalse(columns) 

547 self.assertEqual(expr.visit(_TestVisitor()), value) 

548 

549 def test_invalid_literal(self) -> None: 

550 with self.assertRaisesRegex(TypeError, "Invalid type 'complex' of value 5j for column literal."): 

551 self.x.literal(5j) 

552 

553 def test_dimension_key_reference(self) -> None: 

554 expr = self.x.unwrap(self.x.detector) 

555 self.assertIsNone(expr.get_literal_value()) 

556 self.assertEqual(expr.expression_type, "dimension_key") 

557 self.assertEqual(expr.column_type, "int") 

558 self.assertEqual(str(expr), "detector") 

559 self.assertFalse(expr.is_literal) 

560 columns = qt.ColumnSet(self.universe.empty.as_group()) 

561 expr.gather_required_columns(columns) 

562 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

563 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 3})), 3) 

564 

565 def test_dimension_field_reference(self) -> None: 

566 expr = self.x.unwrap(self.x.detector.purpose) 

567 self.assertIsNone(expr.get_literal_value()) 

568 self.assertEqual(expr.expression_type, "dimension_field") 

569 self.assertEqual(expr.column_type, "string") 

570 self.assertEqual(str(expr), "detector.purpose") 

571 self.assertFalse(expr.is_literal) 

572 columns = qt.ColumnSet(self.universe.empty.as_group()) 

573 expr.gather_required_columns(columns) 

574 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

575 self.assertEqual(columns.dimension_fields["detector"], {"purpose"}) 

576 with self.assertRaises(qt.InvalidQueryError): 

577 qt.DimensionFieldReference(element=self.universe.dimensions["detector"], field="region") 

578 self.assertEqual( 

579 expr.visit(_TestVisitor(dimension_fields={("detector", "purpose"): "science"})), "science" 

580 ) 

581 

582 def test_dataset_field_reference(self) -> None: 

583 expr = self.x.unwrap(self.x["raw"].ingest_date) 

584 self.assertIsNone(expr.get_literal_value()) 

585 self.assertEqual(expr.expression_type, "dataset_field") 

586 self.assertEqual(str(expr), "raw.ingest_date") 

587 self.assertFalse(expr.is_literal) 

588 columns = qt.ColumnSet(self.universe.empty.as_group()) 

589 expr.gather_required_columns(columns) 

590 self.assertEqual(columns.dimensions, self.universe.empty.as_group()) 

591 self.assertEqual(columns.dataset_fields["raw"], {"ingest_date"}) 

592 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="dataset_id").column_type, "uuid") 

593 self.assertEqual( 

594 qt.DatasetFieldReference(dataset_type="raw", field="collection").column_type, "string" 

595 ) 

596 self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="run").column_type, "string") 

597 self.assertEqual( 

598 qt.DatasetFieldReference(dataset_type="raw", field="ingest_date").column_type, "datetime" 

599 ) 

600 self.assertEqual( 

601 qt.DatasetFieldReference(dataset_type="raw", field="timespan").column_type, "timespan" 

602 ) 

603 value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

604 self.assertEqual(expr.visit(_TestVisitor(dataset_fields={("raw", "ingest_date"): value})), value) 

605 

606 def test_unary_negation(self) -> None: 

607 expr = self.x.unwrap(-self.x.visit.exposure_time) 

608 self.assertIsNone(expr.get_literal_value()) 

609 self.assertEqual(expr.expression_type, "unary") 

610 self.assertEqual(expr.column_type, "float") 

611 self.assertEqual(str(expr), "-visit.exposure_time") 

612 self.assertFalse(expr.is_literal) 

613 columns = qt.ColumnSet(self.universe.empty.as_group()) 

614 expr.gather_required_columns(columns) 

615 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

616 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"}) 

617 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 2.0})), -2.0) 

618 with self.assertRaises(qt.InvalidQueryError): 

619 qt.UnaryExpression( 

620 operand=qt.DimensionFieldReference( 

621 element=self.universe.dimensions["detector"], field="purpose" 

622 ), 

623 operator="-", 

624 ) 

625 

626 def test_unary_timespan_begin(self) -> None: 

627 expr = self.x.unwrap(self.x.visit.timespan.begin) 

628 self.assertIsNone(expr.get_literal_value()) 

629 self.assertEqual(expr.expression_type, "unary") 

630 self.assertEqual(expr.column_type, "datetime") 

631 self.assertEqual(str(expr), "visit.timespan.begin") 

632 self.assertFalse(expr.is_literal) 

633 columns = qt.ColumnSet(self.universe.empty.as_group()) 

634 expr.gather_required_columns(columns) 

635 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

636 self.assertEqual(columns.dimension_fields["visit"], {"timespan"}) 

637 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

638 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai") 

639 value = Timespan(begin, end) 

640 self.assertEqual( 

641 expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.begin 

642 ) 

643 with self.assertRaises(qt.InvalidQueryError): 

644 qt.UnaryExpression( 

645 operand=qt.DimensionFieldReference( 

646 element=self.universe.dimensions["detector"], field="purpose" 

647 ), 

648 operator="begin_of", 

649 ) 

650 

651 def test_unary_timespan_end(self) -> None: 

652 expr = self.x.unwrap(self.x.visit.timespan.end) 

653 self.assertIsNone(expr.get_literal_value()) 

654 self.assertEqual(expr.expression_type, "unary") 

655 self.assertEqual(expr.column_type, "datetime") 

656 self.assertEqual(str(expr), "visit.timespan.end") 

657 self.assertFalse(expr.is_literal) 

658 columns = qt.ColumnSet(self.universe.empty.as_group()) 

659 expr.gather_required_columns(columns) 

660 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

661 self.assertEqual(columns.dimension_fields["visit"], {"timespan"}) 

662 begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") 

663 end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai") 

664 value = Timespan(begin, end) 

665 self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.end) 

666 with self.assertRaises(qt.InvalidQueryError): 

667 qt.UnaryExpression( 

668 operand=qt.DimensionFieldReference( 

669 element=self.universe.dimensions["detector"], field="purpose" 

670 ), 

671 operator="end_of", 

672 ) 

673 

674 def test_binary_expression_float(self) -> None: 

675 for proxy, string, value in [ 

676 (self.x.visit.exposure_time + 15.0, "visit.exposure_time + 15.0", 45.0), 

677 (self.x.visit.exposure_time - 10.0, "visit.exposure_time - 10.0", 20.0), 

678 (self.x.visit.exposure_time * 6.0, "visit.exposure_time * 6.0", 180.0), 

679 (self.x.visit.exposure_time / 30.0, "visit.exposure_time / 30.0", 1.0), 

680 (15.0 + -self.x.visit.exposure_time, "15.0 + (-visit.exposure_time)", -15.0), 

681 (10.0 - -self.x.visit.exposure_time, "10.0 - (-visit.exposure_time)", 40.0), 

682 (6.0 * -self.x.visit.exposure_time, "6.0 * (-visit.exposure_time)", -180.0), 

683 (30.0 / -self.x.visit.exposure_time, "30.0 / (-visit.exposure_time)", -1.0), 

684 ((self.x.visit.exposure_time + 15.0) * 6.0, "(visit.exposure_time + 15.0) * 6.0", 270.0), 

685 ((self.x.visit.exposure_time + 15.0) + 45.0, "(visit.exposure_time + 15.0) + 45.0", 90.0), 

686 ((self.x.visit.exposure_time + 15.0) / 5.0, "(visit.exposure_time + 15.0) / 5.0", 9.0), 

687 # We don't need the parentheses we generate in the next one, but 

688 # they're not a problem either. 

689 ((self.x.visit.exposure_time + 15.0) - 60.0, "(visit.exposure_time + 15.0) - 60.0", -15.0), 

690 (6.0 * (-self.x.visit.exposure_time - 15.0), "6.0 * ((-visit.exposure_time) - 15.0)", -270.0), 

691 (60.0 + (-self.x.visit.exposure_time - 15.0), "60.0 + ((-visit.exposure_time) - 15.0)", 15.0), 

692 (90.0 / (-self.x.visit.exposure_time - 15.0), "90.0 / ((-visit.exposure_time) - 15.0)", -2.0), 

693 (60.0 - (-self.x.visit.exposure_time - 15.0), "60.0 - ((-visit.exposure_time) - 15.0)", 105.0), 

694 ]: 

695 with self.subTest(string=string): 

696 expr = self.x.unwrap(proxy) 

697 self.assertIsNone(expr.get_literal_value()) 

698 self.assertEqual(expr.expression_type, "binary") 

699 self.assertEqual(expr.column_type, "float") 

700 self.assertEqual(str(expr), string) 

701 self.assertFalse(expr.is_literal) 

702 columns = qt.ColumnSet(self.universe.empty.as_group()) 

703 expr.gather_required_columns(columns) 

704 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

705 self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"}) 

706 self.assertEqual( 

707 expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 30.0})), value 

708 ) 

709 

710 def test_binary_modulus(self) -> None: 

711 for proxy, string, value in [ 

712 (self.x.visit.id % 2, "visit % 2", 1), 

713 (52 % self.x.visit, "52 % visit", 2), 

714 ]: 

715 with self.subTest(string=string): 

716 expr = self.x.unwrap(proxy) 

717 self.assertIsNone(expr.get_literal_value()) 

718 self.assertEqual(expr.expression_type, "binary") 

719 self.assertEqual(expr.column_type, "int") 

720 self.assertEqual(str(expr), string) 

721 self.assertFalse(expr.is_literal) 

722 columns = qt.ColumnSet(self.universe.empty.as_group()) 

723 expr.gather_required_columns(columns) 

724 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

725 self.assertFalse(columns.dimension_fields["visit"]) 

726 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"visit": 5})), value) 

727 

728 def test_binary_expression_validation(self) -> None: 

729 with self.assertRaises(qt.InvalidQueryError): 

730 # No arithmetic operators on strings (we do not interpret + as 

731 # concatenation). 

732 self.x.instrument + "suffix" 

733 with self.assertRaises(qt.InvalidQueryError): 

734 # Mixed types are not supported, even when they both support the 

735 # operator. 

736 self.x.visit.exposure_time + self.x.detector 

737 with self.assertRaises(qt.InvalidQueryError): 

738 # No modulus for floats. 

739 self.x.visit.exposure_time % 5.0 

740 

741 def test_reversed(self) -> None: 

742 expr = self.x.detector.desc 

743 self.assertIsNone(expr.get_literal_value()) 

744 self.assertEqual(expr.expression_type, "reversed") 

745 self.assertEqual(expr.column_type, "int") 

746 self.assertEqual(str(expr), "detector DESC") 

747 self.assertFalse(expr.is_literal) 

748 columns = qt.ColumnSet(self.universe.empty.as_group()) 

749 expr.gather_required_columns(columns) 

750 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

751 self.assertFalse(columns.dimension_fields["detector"]) 

752 self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 5})), _TestReversed(5)) 

753 

754 def test_trivial_predicate(self) -> None: 

755 """Test logical operations on trivial True/False predicates.""" 

756 yes = qt.Predicate.from_bool(True) 

757 no = qt.Predicate.from_bool(False) 

758 maybe: qt.Predicate = self.x.detector == 5 

759 for predicate in [ 

760 yes, 

761 yes.logical_or(no), 

762 no.logical_or(yes), 

763 yes.logical_and(yes), 

764 no.logical_not(), 

765 yes.logical_or(maybe), 

766 maybe.logical_or(yes), 

767 ]: 

768 self.assertEqual(predicate.column_type, "bool") 

769 self.assertEqual(str(predicate), "True") 

770 self.assertTrue(predicate.visit(_TestVisitor())) 

771 self.assertEqual(predicate.operands, ()) 

772 for predicate in [ 

773 no, 

774 yes.logical_and(no), 

775 no.logical_and(yes), 

776 no.logical_or(no), 

777 yes.logical_not(), 

778 no.logical_and(maybe), 

779 maybe.logical_and(no), 

780 ]: 

781 self.assertEqual(predicate.column_type, "bool") 

782 self.assertEqual(str(predicate), "False") 

783 self.assertFalse(predicate.visit(_TestVisitor())) 

784 self.assertEqual(predicate.operands, ((),)) 

785 for predicate in [ 

786 maybe, 

787 yes.logical_and(maybe), 

788 no.logical_or(maybe), 

789 maybe.logical_not().logical_not(), 

790 ]: 

791 self.assertEqual(predicate.column_type, "bool") 

792 self.assertEqual(str(predicate), "detector == 5") 

793 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"detector": 5}))) 

794 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"detector": 4}))) 

795 self.assertEqual(len(predicate.operands), 1) 

796 self.assertEqual(len(predicate.operands[0]), 1) 

797 self.assertIs(predicate.operands[0][0], maybe.operands[0][0]) 

798 

799 def test_comparison(self) -> None: 

800 predicate: qt.Predicate 

801 string: str 

802 value: bool 

803 for detector in (4, 5, 6): 

804 for predicate, string, value in [ 

805 (self.x.detector == 5, "detector == 5", detector == 5), 

806 (self.x.detector != 5, "detector != 5", detector != 5), 

807 (self.x.detector < 5, "detector < 5", detector < 5), 

808 (self.x.detector > 5, "detector > 5", detector > 5), 

809 (self.x.detector <= 5, "detector <= 5", detector <= 5), 

810 (self.x.detector >= 5, "detector >= 5", detector >= 5), 

811 (self.x.detector == 5, "detector == 5", detector == 5), 

812 (self.x.detector != 5, "detector != 5", detector != 5), 

813 (self.x.detector < 5, "detector < 5", detector < 5), 

814 (self.x.detector > 5, "detector > 5", detector > 5), 

815 (self.x.detector <= 5, "detector <= 5", detector <= 5), 

816 (self.x.detector >= 5, "detector >= 5", detector >= 5), 

817 ]: 

818 with self.subTest(string=string, detector=detector): 

819 self.assertEqual(predicate.column_type, "bool") 

820 self.assertEqual(str(predicate), string) 

821 columns = qt.ColumnSet(self.universe.empty.as_group()) 

822 predicate.gather_required_columns(columns) 

823 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

824 self.assertFalse(columns.dimension_fields["detector"]) 

825 self.assertEqual( 

826 predicate.visit(_TestVisitor(dimension_keys={"detector": detector})), value 

827 ) 

828 inverted = predicate.logical_not() 

829 self.assertEqual(inverted.column_type, "bool") 

830 self.assertEqual(str(inverted), f"NOT {string}") 

831 self.assertEqual( 

832 inverted.visit(_TestVisitor(dimension_keys={"detector": detector})), not value 

833 ) 

834 columns = qt.ColumnSet(self.universe.empty.as_group()) 

835 inverted.gather_required_columns(columns) 

836 self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) 

837 self.assertFalse(columns.dimension_fields["detector"]) 

838 

839 def test_overlap_comparison(self) -> None: 

840 pixelization = Mq3cPixelization(10) 

841 region1 = pixelization.quad(12058870) 

842 predicate = self.x.visit.region.overlaps(region1) 

843 self.assertEqual(predicate.column_type, "bool") 

844 self.assertEqual(str(predicate), "visit.region OVERLAPS (region)") 

845 columns = qt.ColumnSet(self.universe.empty.as_group()) 

846 predicate.gather_required_columns(columns) 

847 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

848 self.assertEqual(columns.dimension_fields["visit"], {"region"}) 

849 region2 = pixelization.quad(12058857) 

850 self.assertFalse(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): region2}))) 

851 inverted = predicate.logical_not() 

852 self.assertEqual(inverted.column_type, "bool") 

853 self.assertEqual(str(inverted), "NOT visit.region OVERLAPS (region)") 

854 self.assertTrue(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): region2}))) 

855 columns = qt.ColumnSet(self.universe.empty.as_group()) 

856 inverted.gather_required_columns(columns) 

857 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

858 self.assertEqual(columns.dimension_fields["visit"], {"region"}) 

859 

860 def test_invalid_comparison(self) -> None: 

861 # Mixed type comparisons. 

862 with self.assertRaises(qt.InvalidQueryError): 

863 self.x.visit > "three" 

864 with self.assertRaises(qt.InvalidQueryError): 

865 self.x.visit > 3.0 

866 # Invalid operator for type. 

867 with self.assertRaises(qt.InvalidQueryError): 

868 self.x["raw"].dataset_id < uuid.uuid4() 

869 

870 def test_is_null(self) -> None: 

871 predicate = self.x.visit.region.is_null 

872 self.assertEqual(predicate.column_type, "bool") 

873 self.assertEqual(str(predicate), "visit.region IS NULL") 

874 columns = qt.ColumnSet(self.universe.empty.as_group()) 

875 predicate.gather_required_columns(columns) 

876 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

877 self.assertEqual(columns.dimension_fields["visit"], {"region"}) 

878 self.assertTrue(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): None}))) 

879 inverted = predicate.logical_not() 

880 self.assertEqual(inverted.column_type, "bool") 

881 self.assertEqual(str(inverted), "NOT visit.region IS NULL") 

882 self.assertFalse(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): None}))) 

883 inverted.gather_required_columns(columns) 

884 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

885 self.assertEqual(columns.dimension_fields["visit"], {"region"}) 

886 

887 def test_in_container(self) -> None: 

888 predicate: qt.Predicate = self.x.visit.in_iterable([3, 4, self.x.exposure.id]) 

889 self.assertEqual(predicate.column_type, "bool") 

890 self.assertEqual(str(predicate), "visit IN [3, 4, exposure]") 

891 columns = qt.ColumnSet(self.universe.empty.as_group()) 

892 predicate.gather_required_columns(columns) 

893 self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"])) 

894 self.assertFalse(columns.dimension_fields["visit"]) 

895 self.assertFalse(columns.dimension_fields["exposure"]) 

896 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2}))) 

897 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5}))) 

898 inverted = predicate.logical_not() 

899 self.assertEqual(inverted.column_type, "bool") 

900 self.assertEqual(str(inverted), "NOT visit IN [3, 4, exposure]") 

901 self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2}))) 

902 self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5}))) 

903 columns = qt.ColumnSet(self.universe.empty.as_group()) 

904 inverted.gather_required_columns(columns) 

905 self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"])) 

906 self.assertFalse(columns.dimension_fields["visit"]) 

907 self.assertFalse(columns.dimension_fields["exposure"]) 

908 with self.assertRaises(qt.InvalidQueryError): 

909 # Regions (and timespans) not allowed in IN expressions, since that 

910 # suggests topological logic we're not actually doing. We can't 

911 # use ExpressionFactory because it prohibits this case with typing. 

912 pixelization = Mq3cPixelization(10) 

913 region = pixelization.quad(12058870) 

914 qt.Predicate.in_container(self.x.unwrap(self.x.visit.region), [qt.make_column_literal(region)]) 

915 with self.assertRaises(qt.InvalidQueryError): 

916 # Mismatched types. 

917 self.x.visit.in_iterable([3.5, 2.1]) 

918 

919 def test_in_range(self) -> None: 

920 predicate: qt.Predicate = self.x.visit.in_range(2, 8, 2) 

921 self.assertEqual(predicate.column_type, "bool") 

922 self.assertEqual(str(predicate), "visit IN 2:8:2") 

923 columns = qt.ColumnSet(self.universe.empty.as_group()) 

924 predicate.gather_required_columns(columns) 

925 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

926 self.assertFalse(columns.dimension_fields["visit"]) 

927 self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2}))) 

928 self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 8}))) 

929 inverted = predicate.logical_not() 

930 self.assertEqual(inverted.column_type, "bool") 

931 self.assertEqual(str(inverted), "NOT visit IN 2:8:2") 

932 self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2}))) 

933 self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 8}))) 

934 columns = qt.ColumnSet(self.universe.empty.as_group()) 

935 inverted.gather_required_columns(columns) 

936 self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) 

937 self.assertFalse(columns.dimension_fields["visit"]) 

938 with self.assertRaises(qt.InvalidQueryError): 

939 # Only integer fields allowed. 

940 self.x.visit.exposure_time.in_range(2, 4) 

941 with self.assertRaises(qt.InvalidQueryError): 

942 # Step must be positive. 

943 self.x.visit.in_range(2, 4, -1) 

944 with self.assertRaises(qt.InvalidQueryError): 

945 # Stop must be >= start. 

946 self.x.visit.in_range(2, 0) 

947 

948 def test_in_query(self) -> None: 

949 query = self.query().join_dimensions(["visit", "tract"]).where(skymap="s", tract=3) 

950 predicate: qt.Predicate = self.x.exposure.in_query(self.x.visit, query) 

951 self.assertEqual(predicate.column_type, "bool") 

952 self.assertEqual(str(predicate), "exposure IN (query).visit") 

953 columns = qt.ColumnSet(self.universe.empty.as_group()) 

954 predicate.gather_required_columns(columns) 

955 self.assertEqual(columns.dimensions, self.universe.conform(["exposure"])) 

956 self.assertFalse(columns.dimension_fields["exposure"]) 

957 self.assertTrue( 

958 predicate.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3})) 

959 ) 

960 self.assertFalse( 

961 predicate.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3})) 

962 ) 

963 inverted = predicate.logical_not() 

964 self.assertEqual(inverted.column_type, "bool") 

965 self.assertEqual(str(inverted), "NOT exposure IN (query).visit") 

966 self.assertFalse( 

967 inverted.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3})) 

968 ) 

969 self.assertTrue( 

970 inverted.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3})) 

971 ) 

972 columns = qt.ColumnSet(self.universe.empty.as_group()) 

973 inverted.gather_required_columns(columns) 

974 self.assertEqual(columns.dimensions, self.universe.conform(["exposure"])) 

975 self.assertFalse(columns.dimension_fields["exposure"]) 

976 with self.assertRaises(qt.InvalidQueryError): 

977 # Regions (and timespans) not allowed in IN expressions, since that 

978 # suggests topological logic we're not actually doing. We can't 

979 # use ExpressionFactory because it prohibits this case with typing. 

980 qt.Predicate.in_query( 

981 self.x.unwrap(self.x.visit.region), self.x.unwrap(self.x.tract.region), query._tree 

982 ) 

983 with self.assertRaises(qt.InvalidQueryError): 

984 # Mismatched types. 

985 self.x.exposure.in_query(self.x.visit.exposure_time, query) 

986 with self.assertRaises(qt.InvalidQueryError): 

987 # Query column requires dimensions that are not in the query. 

988 self.x.exposure.in_query(self.x.patch, query) 

989 with self.assertRaises(qt.InvalidQueryError): 

990 # Query column requires dataset type that is not in the query. 

991 self.x["raw"].dataset_id.in_query(self.x["raw"].dataset_id, query) 

992 

993 def test_complex_predicate(self) -> None: 

994 """Test that predicates are converted to conjunctive normal form and 

995 get parentheses in the right places when stringified. 

996 """ 

997 visitor = _TestVisitor(dimension_keys={"instrument": "i", "detector": 3, "visit": 6, "band": "r"}) 

998 a: qt.Predicate = self.x.visit > 5 # will evaluate to True 

999 b: qt.Predicate = self.x.detector != 3 # will evaluate to False 

1000 c: qt.Predicate = self.x.instrument == "i" # will evaluate to True 

1001 d: qt.Predicate = self.x.band == "g" # will evaluate to False 

1002 predicate: qt.Predicate 

1003 for predicate, string, value in [ 

1004 (a.logical_or(b), f"{a} OR {b}", True), 

1005 (a.logical_or(c), f"{a} OR {c}", True), 

1006 (b.logical_or(d), f"{b} OR {d}", False), 

1007 (a.logical_and(b), f"{a} AND {b}", False), 

1008 (a.logical_and(c), f"{a} AND {c}", True), 

1009 (b.logical_and(d), f"{b} AND {d}", False), 

1010 (self.x.any(a, b, c, d), f"{a} OR {b} OR {c} OR {d}", True), 

1011 (self.x.all(a, b, c, d), f"{a} AND {b} AND {c} AND {d}", False), 

1012 (a.logical_or(b).logical_and(c), f"({a} OR {b}) AND {c}", True), 

1013 (a.logical_and(b.logical_or(d)), f"{a} AND ({b} OR {d})", False), 

1014 (a.logical_and(b).logical_or(c), f"({a} OR {c}) AND ({b} OR {c})", True), 

1015 ( 

1016 a.logical_and(b).logical_or(c.logical_and(d)), 

1017 f"({a} OR {c}) AND ({a} OR {d}) AND ({b} OR {c}) AND ({b} OR {d})", 

1018 False, 

1019 ), 

1020 (a.logical_or(b).logical_not(), f"NOT {a} AND NOT {b}", False), 

1021 (a.logical_or(c).logical_not(), f"NOT {a} AND NOT {c}", False), 

1022 (b.logical_or(d).logical_not(), f"NOT {b} AND NOT {d}", True), 

1023 (a.logical_and(b).logical_not(), f"NOT {a} OR NOT {b}", True), 

1024 (a.logical_and(c).logical_not(), f"NOT {a} OR NOT {c}", False), 

1025 (b.logical_and(d).logical_not(), f"NOT {b} OR NOT {d}", True), 

1026 ( 

1027 self.x.not_(a.logical_or(b).logical_and(c)), 

1028 f"(NOT {a} OR NOT {c}) AND (NOT {b} OR NOT {c})", 

1029 False, 

1030 ), 

1031 ( 

1032 a.logical_and(b.logical_or(d)).logical_not(), 

1033 f"(NOT {a} OR NOT {b}) AND (NOT {a} OR NOT {d})", 

1034 True, 

1035 ), 

1036 ]: 

1037 with self.subTest(string=string): 

1038 self.assertEqual(str(predicate), string) 

1039 self.assertEqual(predicate.visit(visitor), value) 

1040 

1041 def test_proxy_misc(self) -> None: 

1042 """Test miscellaneous things on various ExpressionFactory proxies.""" 

1043 self.assertEqual(str(self.x.visit_detector_region), "visit_detector_region") 

1044 self.assertEqual(str(self.x.visit.instrument), "instrument") 

1045 self.assertEqual(str(self.x["raw"]), "raw") 

1046 self.assertEqual(str(self.x["raw.ingest_date"]), "raw.ingest_date") 

1047 self.assertEqual( 

1048 str(self.x.visit.timespan.overlaps(self.x["raw"].timespan)), 

1049 "visit.timespan OVERLAPS raw.timespan", 

1050 ) 

1051 self.assertGreater( 

1052 set(dir(self.x["raw"])), {"dataset_id", "ingest_date", "collection", "run", "timespan"} 

1053 ) 

1054 self.assertGreater(set(dir(self.x.exposure)), {"seq_num", "science_program", "timespan"}) 

1055 with self.assertRaises(AttributeError): 

1056 self.x["raw"].seq_num 

1057 with self.assertRaises(AttributeError): 

1058 self.x.visit.horse 

1059 

1060 

1061class QueryTestCase(unittest.TestCase): 

1062 """Tests for Query and *QueryResults objects in lsst.daf.butler.queries.""" 

1063 

1064 def setUp(self) -> None: 

1065 self.maxDiff = None 

1066 self.universe = DimensionUniverse() 

1067 # We use ArrowTable as the storage class for all dataset types because 

1068 # it has conversions that require only third-party packages we already 

1069 # depend on. 

1070 self.raw = DatasetType( 

1071 "raw", dimensions=self.universe.conform(["detector", "exposure"]), storageClass="ArrowTable" 

1072 ) 

1073 self.refcat = DatasetType( 

1074 "refcat", dimensions=self.universe.conform(["htm7"]), storageClass="ArrowTable" 

1075 ) 

1076 self.bias = DatasetType( 

1077 "bias", 

1078 dimensions=self.universe.conform(["detector"]), 

1079 storageClass="ArrowTable", 

1080 isCalibration=True, 

1081 ) 

1082 self.default_collections: list[str] | None = ["DummyCam/defaults"] 

1083 self.collection_info: dict[str, tuple[CollectionRecord, CollectionSummary]] = { 

1084 "DummyCam/raw/all": ( 

1085 RunRecord[int](1, name="DummyCam/raw/all"), 

1086 CollectionSummary(NamedValueSet({self.raw}), governors={"instrument": {"DummyCam"}}), 

1087 ), 

1088 "DummyCam/calib": ( 

1089 CollectionRecord[int](2, name="DummyCam/calib", type=CollectionType.CALIBRATION), 

1090 CollectionSummary(NamedValueSet({self.bias}), governors={"instrument": {"DummyCam"}}), 

1091 ), 

1092 "refcats": ( 

1093 RunRecord[int](3, name="refcats"), 

1094 CollectionSummary(NamedValueSet({self.refcat}), governors={}), 

1095 ), 

1096 "DummyCam/defaults": ( 

1097 ChainedCollectionRecord[int]( 

1098 4, name="DummyCam/defaults", children=("DummyCam/raw/all", "DummyCam/calib", "refcats") 

1099 ), 

1100 CollectionSummary( 

1101 NamedValueSet({self.raw, self.refcat, self.bias}), governors={"instrument": {"DummyCam"}} 

1102 ), 

1103 ), 

1104 } 

1105 self.dataset_types = {"raw": self.raw, "refcat": self.refcat, "bias": self.bias} 

1106 

1107 def query(self, **kwargs: Any) -> Query: 

1108 """Make an initial Query object with the given kwargs used to 

1109 initialize the _TestQueryDriver. 

1110 

1111 The given kwargs override the test-case-attribute defaults. 

1112 """ 

1113 kwargs.setdefault("default_collections", self.default_collections) 

1114 kwargs.setdefault("collection_info", self.collection_info) 

1115 kwargs.setdefault("dataset_types", self.dataset_types) 

1116 return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe)) 
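# Usage note (a sketch, not from the original file): tests can override the
# per-test-case defaults set in setUp, e.g.
#
#     self.query(default_collections=None)
#
# simulates a repository with no default collections, so the mock driver's
# get_default_collections raises NoDefaultCollectionError.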

1117 

1118 def test_dataset_join(self) -> None: 

1119 """Test queries that have had a dataset search explicitly joined in via 

1120 Query.join_dataset_search. 

1121 

1122 Since this kind of query has a moderate amount of complexity, this is 

1123 where we get a lot of basic coverage that applies to all kinds of 

1124 queries, including: 

1125 

1126 - getting data ID and dataset results (but not iterating over them); 

1127 - the 'any' and 'explain_no_results' methods; 

1128 - adding 'where' filters (but not expanding dimensions accordingly); 

1129 - materializations. 

1130 """ 

1131 

1132 def check( 

1133 query: Query, 

1134 dimensions: DimensionGroup = self.raw.dimensions.as_group(), 

1135 ) -> None: 

1136 """Run a battery of tests on one of a set of very similar queries 

1137 constructed in different ways (see below). 

1138 """ 

1139 

1140 def check_query_tree( 

1141 tree: qt.QueryTree, 

1142 dimensions: DimensionGroup = dimensions, 

1143 ) -> None: 

1144 """Check the state of the QueryTree object that backs the Query 

1145 or a derived QueryResults object. 

1146 

1147 Parameters 

1148 ---------- 

1149 tree : `lsst.daf.butler.queries.tree.QueryTree` 

1150 Object to test. 

1151 dimensions : `DimensionGroup` 

1152 Dimensions to expect in the `QueryTree`, not necessarily 

1153 including those in the test 'raw' dataset type. 

1154 """ 

1155 self.assertEqual(tree.dimensions, dimensions | self.raw.dimensions.as_group()) 

1156 self.assertEqual(str(tree.predicate), "raw.run == 'DummyCam/raw/all'") 

1157 self.assertFalse(tree.materializations) 

1158 self.assertFalse(tree.data_coordinate_uploads) 

1159 self.assertEqual(tree.datasets.keys(), {"raw"}) 

1160 self.assertEqual(tree.datasets["raw"].dimensions, self.raw.dimensions.as_group()) 

1161 self.assertEqual(tree.datasets["raw"].collections, ("DummyCam/defaults",)) 

1162 self.assertEqual( 

1163 tree.get_joined_dimension_groups(), frozenset({self.raw.dimensions.as_group()}) 

1164 ) 

1165 

1166 def check_data_id_results(*args, query: Query, dimensions: DimensionGroup = dimensions) -> None: 

1167 """Construct a DataCoordinateQueryResults object from the query 

1168 with the given arguments and run a battery of tests on it. 

1169 

1170 Parameters 

1171 ---------- 

1172 *args 

1173 Forwarded to `Query.data_ids`. 

1174 query : `Query` 

1175 Query to start from. 

1176 dimensions : `DimensionGroup`, optional 

1177 Dimensions the result data IDs should have. 

1178 """ 

1179 with self.assertRaises(_TestQueryExecution) as cm: 

1180 list(query.data_ids(*args)) 

1181 self.assertEqual( 

1182 cm.exception.result_spec, 

1183 qrs.DataCoordinateResultSpec(dimensions=dimensions), 

1184 ) 

1185 check_query_tree(cm.exception.tree, dimensions=dimensions) 

1186 

1187 def check_dataset_results( 

1188 *args: Any, 

1189 query: Query, 

1190 find_first: bool = True, 

1191 storage_class_name: str = self.raw.storageClass_name, 

1192 ) -> None: 

1193 """Construct a DatasetRefQueryResults object from the query 

1194 with the given arguments and run a battery of tests on it. 

1195 

1196 Parameters 

1197 ---------- 

1198 *args 

1199 Forwarded to `Query.datasets`. 

1200 query : `Query` 

1201 Query to start from. 

1202 find_first : `bool`, optional 

1203 Whether to do find-first resolution on the results. 

1204 storage_class_name : `str`, optional 

1205 Expected name of the storage class for the results. 

1206 """ 

1207 with self.assertRaises(_TestQueryExecution) as cm: 

1208 list(query.datasets(*args, find_first=find_first)) 

1209 self.assertEqual( 

1210 cm.exception.result_spec, 

1211 qrs.DatasetRefResultSpec( 

1212 dataset_type_name="raw", 

1213 dimensions=self.raw.dimensions.as_group(), 

1214 storage_class_name=storage_class_name, 

1215 find_first=find_first, 

1216 ), 

1217 ) 

1218 check_query_tree(cm.exception.tree) 

1219 

1220 def check_materialization( 

1221 kwargs: Mapping[str, Any], 

1222 query: Query, 

1223 dimensions: DimensionGroup = dimensions, 

1224 has_dataset: bool = True, 

1225 ) -> None: 

1226 """Materialize the query with the given arguments and run a 

1227 battery of tests on the result. 

1228 

1229 Parameters 

1230 ---------- 

1231 kwargs 

1232 Forwarded as keyword arguments to `Query.materialize`. 

1233 query : `Query` 

1234 Query to start from. 

1235 dimensions : `DimensionGroup`, optional 

1236 Dimensions to expect in the materialization and its derived 

1237 query. 

1238 has_dataset : `bool`, optional 

1239 Whether the query backed by the materialization should 

1240 still have the test 'raw' dataset joined in. 

1241 """ 

1242 # Materialize the query and check the query tree sent to the 

1243 # driver and the one in the materialized query. 

1244 with self.assertRaises(_TestQueryExecution) as cm: 

1245 list(query.materialize(**kwargs).data_ids()) 

1246 derived_tree = cm.exception.tree 

1247 self.assertEqual(derived_tree.dimensions, dimensions) 

1248 # Predicate should be materialized away; it no longer appears 

1249 # in the derived query. 

1250 self.assertEqual(str(derived_tree.predicate), "True") 

1251 self.assertFalse(derived_tree.data_coordinate_uploads) 

1252 if has_dataset: 

1253 # Dataset search is still there, even though its existence 

1254 # constraint is included in the materialization, because we 

1255 # might need to re-join for some result columns in a 

1256 # derived query. 

1257 self.assertEqual(derived_tree.datasets.keys(), {"raw"}) 

1258 self.assertEqual(derived_tree.datasets["raw"].dimensions, self.raw.dimensions.as_group()) 

1259 self.assertEqual(derived_tree.datasets["raw"].collections, ("DummyCam/defaults",)) 

1260 else: 

1261 self.assertFalse(derived_tree.datasets) 

1262 ((key, derived_tree_materialized_dimensions),) = derived_tree.materializations.items() 

1263 self.assertEqual(derived_tree_materialized_dimensions, dimensions) 

1264 ( 

1265 materialized_tree, 

1266 materialized_dimensions, 

1267 materialized_datasets, 

1268 ) = cm.exception.driver.materializations[key] 

1269 self.assertEqual(derived_tree_materialized_dimensions, materialized_dimensions) 

1270 if has_dataset: 

1271 self.assertEqual(materialized_datasets, {"raw"}) 

1272 else: 

1273 self.assertFalse(materialized_datasets) 

1274 check_query_tree(materialized_tree) 

1275 

1276 # Actual logic for the check() function begins here. 

1277 

1278 self.assertEqual(query.constraint_dataset_types, {"raw"}) 

1279 self.assertEqual(query.constraint_dimensions, self.raw.dimensions.as_group()) 

1280 

1281 # Adding a constraint on a field for this dataset type should work 

1282 # (this constraint will be present in all downstream tests). 

1283 query = query.where(query.expression_factory["raw"].run == "DummyCam/raw/all") 

1284 with self.assertRaises(qt.InvalidQueryError): 

1285 # Adding constraint on a different dataset should not work. 

1286 query.where(query.expression_factory["refcat"].run == "refcats") 

1287 

1288 # Data IDs, with dimensions defaulted. 

1289 check_data_id_results(query=query) 

1290 # Dimensions for data IDs the same as defaults. 

1291 check_data_id_results(["exposure", "detector"], query=query) 

1292 # Dimensions are a subset of the query dimensions. 

1293 check_data_id_results(["exposure"], query=query, dimensions=self.universe.conform(["exposure"])) 

1294 # Dimensions are a superset of the query dimensions. 

1295 check_data_id_results( 

1296 ["exposure", "detector", "visit"], 

1297 query=query, 

1298 dimensions=self.universe.conform(["exposure", "detector", "visit"]), 

1299 ) 

1300 # Dimensions are neither a superset nor a subset of the query 

1301 # dimensions. 

1302 check_data_id_results( 

1303 ["detector", "visit"], query=query, dimensions=self.universe.conform(["visit", "detector"]) 

1304 ) 

1305 # Dimensions are empty. 

1306 check_data_id_results([], query=query, dimensions=self.universe.conform([])) 

1307 

1308 # Get DatasetRef results, with various arguments and defaulting. 

1309 check_dataset_results("raw", query=query) 

1310 check_dataset_results("raw", query=query, find_first=True) 

1311 check_dataset_results("raw", ["DummyCam/defaults"], query=query) 

1312 check_dataset_results("raw", ["DummyCam/defaults"], query=query, find_first=True) 

1313 check_dataset_results(self.raw, query=query) 

1314 check_dataset_results(self.raw, query=query, find_first=True) 

1315 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query) 

1316 check_dataset_results(self.raw, ["DummyCam/defaults"], query=query, find_first=True) 

1317 

1318 # Changing collections at this stage is not allowed. 

1319 with self.assertRaises(qt.InvalidQueryError): 

1320 query.datasets("raw", collections=["DummyCam/calib"]) 

1321 

1322 # Changing storage classes is allowed, if they're compatible. 

1323 check_dataset_results( 

1324 self.raw.overrideStorageClass("ArrowNumpy"), query=query, storage_class_name="ArrowNumpy" 

1325 ) 

1326 with self.assertRaises(DatasetTypeError): 

1327 # Can't use overrideStorageClass, because it'll raise 

1328 # before the code we want to test can. 

1329 query.datasets(DatasetType("raw", self.raw.dimensions, "int")) 

1330 

1331 # Check the 'any' and 'explain_no_results' methods on Query itself. 
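# permutations() of [False, True] yields (False, True) and (True, False),
# so each flag is exercised with both values across the two iterations.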

1332 for execute, exact in itertools.permutations([False, True], 2): 

1333 with self.assertRaises(_TestQueryAny) as cm: 

1334 query.any(execute=execute, exact=exact) 

1335 self.assertEqual(cm.exception.execute, execute) 

1336 self.assertEqual(cm.exception.exact, exact) 

1337 check_query_tree(cm.exception.tree, dimensions) 

1338 with self.assertRaises(_TestQueryExplainNoResults) as cm: 

1339 query.explain_no_results() 

1340 check_query_tree(cm.exception.tree, dimensions) 

1341 

1342 # Materialize the query with defaults. 

1343 check_materialization({}, query=query) 

1344 # Materialize the query with args that match defaults. 

1345 check_materialization({"dimensions": ["exposure", "detector"], "datasets": {"raw"}}, query=query) 

1346 # Materialize the query with a superset of the original dimensions. 

1347 check_materialization( 

1348 {"dimensions": ["exposure", "detector", "visit"]}, 

1349 query=query, 

1350 dimensions=self.universe.conform(["exposure", "visit", "detector"]), 

1351 ) 

1352 # Materialize the query with no datasets. 

1353 check_materialization( 

1354 {"dimensions": ["exposure", "detector"], "datasets": frozenset()}, 

1355 query=query, 

1356 has_dataset=False, 

1357 ) 

1358 # Materialize the query with no datasets and a subset of the 

1359 # dimensions. 

1360 check_materialization( 

1361 {"dimensions": ["exposure"], "datasets": frozenset()}, 

1362 query=query, 

1363 has_dataset=False, 

1364 dimensions=self.universe.conform(["exposure"]), 

1365 ) 

1366 # Materializing the query with a dataset that is not in the query 

1367 # is an error. 

1368 with self.assertRaises(qt.InvalidQueryError): 

1369 query.materialize(datasets={"refcat"}) 

1370 # Materializing the query with dimensions that do not cover the 

1371 # dimensions of a materialized dataset is an error. 

1372 with self.assertRaises(qt.InvalidQueryError): 

1373 query.materialize(dimensions=["exposure"], datasets={"raw"}) 

1374 

1375 # Actual logic for test_dataset_joins starts here. 

1376 

1377 # Default collections and existing dataset type name. 

1378 check(self.query().join_dataset_search("raw")) 

1379 # Default collections and existing DatasetType instance. 

1380 check(self.query().join_dataset_search(self.raw)) 

1381 # Manual collections and existing dataset type. 

1382 check( 

1383 self.query(default_collections=None).join_dataset_search("raw", collections=["DummyCam/defaults"]) 

1384 ) 

1385 check( 

1386 self.query(default_collections=None).join_dataset_search( 

1387 self.raw, collections=["DummyCam/defaults"] 

1388 ) 

1389 ) 

1390 with self.assertRaises(MissingDatasetTypeError): 

1391 # Dataset type does not exist. 

1392 self.query(dataset_types={}).join_dataset_search("raw", collections=["DummyCam/raw/all"]) 

1393 with self.assertRaises(DatasetTypeError): 

1394 # Dataset type object with bad dimensions passed. 

1395 self.query().join_dataset_search( 

1396 DatasetType( 

1397 "raw", 

1398 dimensions={"detector", "visit"}, 

1399 storageClass=self.raw.storageClass_name, 

1400 universe=self.universe, 

1401 ) 

1402 ) 

1403 with self.assertRaises(TypeError): 

1404 # Bad type for dataset type argument. 

1405 self.query().join_dataset_search(3) 

1406 with self.assertRaises(qt.InvalidQueryError): 

1407 # Cannot pass storage class override to join_dataset_search, 

1408 # because we cannot use it there. 

1409 self.query().join_dataset_search(self.raw.overrideStorageClass("ArrowAstropy")) 

1410 

1411 def test_dimension_record_results(self) -> None: 

1412 """Test queries that return dimension records. 

1413 

1414 This includes tests for: 

1415 

1416 - joining against uploaded data coordinates; 

1417 - counting result rows; 

1418 - expanding dimensions as needed for 'where' conditions; 

1419 - order_by and limit. 

1420 

1421 It does not include the iteration methods of 

1422 DimensionRecordQueryResults, since those require a different mock 

1423 driver setup (see test_dimension_record_iteration). 

1424 """ 

1425 # Set up the base query-results object to test. 

1426 query = self.query() 

1427 x = query.expression_factory 

1428 self.assertFalse(query.constraint_dimensions) 

1429 query = query.where(x.skymap == "m") 

1430 self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap"])) 

1431 upload_rows = [ 

1432 DataCoordinate.standardize(instrument="DummyCam", visit=3, universe=self.universe), 

1433 DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe), 

1434 ] 
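# The mock driver records each upload as (dimensions, required-value tuples),
# so keep the raw tuples around for comparison inside check() below.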

1435 raw_rows = frozenset([data_id.required_values for data_id in upload_rows]) 

1436 query = query.join_data_coordinates(upload_rows) 

1437 self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap", "visit"])) 

1438 results = query.dimension_records("patch") 

1439 results = results.where(x.tract == 4) 

1440 

1441 # Define a closure to run tests on variants of the base query. 

1442 def check( 

1443 results: DimensionRecordQueryResults, 

1444 order_by: Any = (), 

1445 limit: int | None = None, 

1446 ) -> list[str]: 

1447 results = results.order_by(*order_by).limit(limit) 

1448 self.assertEqual(results.element.name, "patch") 

1449 with self.assertRaises(_TestQueryExecution) as cm: 

1450 list(results) 

1451 tree = cm.exception.tree 

1452 self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4") 

1453 self.assertEqual(tree.dimensions, self.universe.conform(["visit", "patch"])) 

1454 self.assertFalse(tree.materializations) 

1455 self.assertFalse(tree.datasets) 

1456 ((key, upload_dimensions),) = tree.data_coordinate_uploads.items() 

1457 self.assertEqual(upload_dimensions, self.universe.conform(["visit"])) 

1458 self.assertEqual(cm.exception.driver.data_coordinate_uploads[key], (upload_dimensions, raw_rows)) 

1459 result_spec = cm.exception.result_spec 

1460 self.assertEqual(result_spec.result_type, "dimension_record") 

1461 self.assertEqual(result_spec.element, self.universe["patch"]) 

1462 self.assertEqual(result_spec.limit, limit) 

1463 for exact, discard in itertools.permutations([False, True], r=2): 

1464 with self.assertRaises(_TestQueryCount) as cm: 

1465 results.count(exact=exact, discard=discard) 

1466 self.assertEqual(cm.exception.result_spec, result_spec) 

1467 self.assertEqual(cm.exception.exact, exact) 

1468 self.assertEqual(cm.exception.discard, discard) 

1469 return [str(term) for term in result_spec.order_by] 

1470 

1471 # Run the closure's tests on variants of the base query. 

1472 self.assertEqual(check(results), []) 

1473 self.assertEqual(check(results, limit=2), []) 

1474 self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"]) 

1475 self.assertEqual( 

1476 check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]), 

1477 ["patch.cell_x", "patch.cell_y DESC"], 

1478 ) 

1479 with self.assertRaises(qt.InvalidQueryError): 

1480 # Cannot upload empty list of data IDs. 

1481 query.join_data_coordinates([]) 

1482 with self.assertRaises(qt.InvalidQueryError): 

1483 # Cannot upload heterogeneous list of data IDs. 

1484 query.join_data_coordinates( 

1485 [ 

1486 DataCoordinate.make_empty(self.universe), 

1487 DataCoordinate.standardize(instrument="DummyCam", universe=self.universe), 

1488 ] 

1489 ) 

1490 

1491 def test_dimension_record_iteration(self) -> None: 

1492 """Tests for DimensionRecordQueryResult iteration.""" 

1493 

1494 def make_record(n: int) -> DimensionRecord: 

1495 return self.universe["patch"].RecordClass(skymap="m", tract=4, patch=n) 

1496 

1497 result_rows = ( 

1498 [make_record(n) for n in range(3)], 

1499 [make_record(n) for n in range(3, 6)], 

1500 [make_record(10)], 

1501 ) 
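# Each inner list of records is returned by the mock driver as one result page.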

1502 results = self.query(result_rows=result_rows).dimension_records("patch") 

1503 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) 

1504 self.assertEqual( 

1505 list(results.iter_set_pages()), 

1506 [DimensionRecordSet(self.universe["patch"], rows) for rows in result_rows], 

1507 ) 

1508 self.assertEqual( 

1509 [table.column("id").to_pylist() for table in results.iter_table_pages()], 

1510 [list(range(3)), list(range(3, 6)), [10]], 

1511 ) 

1512 

1513 def test_data_coordinate_results(self) -> None: 

1514 """Test queries that return data coordinates. 

1515 

1516 This includes tests for: 

1517 

1518 - counting result rows; 

1519 - expanding dimensions as needed for 'where' conditions; 

1520 - order_by and limit. 

1521 

1522 It does not include the iteration methods of 

1523 DataCoordinateQueryResults, since those require a different mock 

1524 driver setup (see test_data_coordinate_iteration). More tests for 

1525 different inputs to DataCoordinateQueryResults construction are in 

1526 test_dataset_joins. 

1527 """ 

1528 # Set up the base query-results object to test. 

1529 query = self.query() 

1530 x = query.expression_factory 

1531 self.assertFalse(query.constraint_dimensions) 

1532 query = query.where(x.skymap == "m") 

1533 results = query.data_ids(["patch", "band"]) 

1534 results = results.where(x.tract == 4) 

1535 

1536 # Define a closure to run tests on variants of the base query. 

1537 def check( 

1538 results: DataCoordinateQueryResults, 

1539 order_by: Any = (), 

1540 limit: int | None = None, 

1541 include_dimension_records: bool = False, 

1542 ) -> list[str]: 

1543 results = results.order_by(*order_by).limit(limit) 

1544 self.assertEqual(results.dimensions, self.universe.conform(["patch", "band"])) 

1545 with self.assertRaises(_TestQueryExecution) as cm: 

1546 list(results) 

1547 tree = cm.exception.tree 

1548 self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4") 

1549 self.assertEqual(tree.dimensions, self.universe.conform(["patch", "band"])) 

1550 self.assertFalse(tree.materializations) 

1551 self.assertFalse(tree.datasets) 

1552 self.assertFalse(tree.data_coordinate_uploads) 

1553 result_spec = cm.exception.result_spec 

1554 self.assertEqual(result_spec.result_type, "data_coordinate") 

1555 self.assertEqual(result_spec.dimensions, self.universe.conform(["patch", "band"])) 

1556 self.assertEqual(result_spec.include_dimension_records, include_dimension_records) 

1557 self.assertEqual(result_spec.limit, limit) 

1558 self.assertIsNone(result_spec.find_first_dataset) 

1559 for exact, discard in itertools.permutations([False, True], r=2): 

1560 with self.assertRaises(_TestQueryCount) as cm: 

1561 results.count(exact=exact, discard=discard) 

1562 self.assertEqual(cm.exception.result_spec, result_spec) 

1563 self.assertEqual(cm.exception.exact, exact) 

1564 self.assertEqual(cm.exception.discard, discard) 

1565 return [str(term) for term in result_spec.order_by] 

1566 

1567 # Run the closure's tests on variants of the base query. 

1568 self.assertEqual(check(results), []) 

1569 self.assertEqual(check(results.with_dimension_records(), include_dimension_records=True), []) 

1570 self.assertEqual( 

1571 check(results.with_dimension_records().with_dimension_records(), include_dimension_records=True), 

1572 [], 

1573 ) 

1574 self.assertEqual(check(results, limit=2), []) 

1575 self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"]) 

1576 self.assertEqual( 

1577 check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc]), 

1578 ["patch.cell_x", "patch.cell_y DESC"], 

1579 ) 
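# order_by also accepts plain string identifiers; a leading '-' requests
# descending order.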

1580 self.assertEqual( 

1581 check(results, order_by=["patch.cell_x", "-cell_y"]), 

1582 ["patch.cell_x", "patch.cell_y DESC"], 

1583 ) 

1584 

1585 def test_data_coordinate_iteration(self) -> None: 

1586 """Tests for DataCoordinateQueryResult iteration.""" 

1587 

1588 def make_data_id(n: int) -> DataCoordinate: 

1589 return DataCoordinate.standardize(skymap="m", tract=4, patch=n, universe=self.universe) 

1590 

1591 result_rows = ( 

1592 [make_data_id(n) for n in range(3)], 

1593 [make_data_id(n) for n in range(3, 6)], 

1594 [make_data_id(10)], 

1595 ) 

1596 results = self.query(result_rows=result_rows).data_ids(["patch"]) 

1597 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) 

1598 

1599 def test_dataset_results(self) -> None: 

1600 """Test queries that return dataset refs. 

1601 

1602 This includes tests for: 

1603 

1604 - counting result rows; 

1605 - expanding dimensions as needed for 'where' conditions; 

1606 - different ways of passing a data ID to 'where' methods; 

1607 - order_by and limit. 

1608 

1609 It does not include the iteration methods of the DatasetRefQueryResults 

1610 classes, since those require a different mock driver setup (see 

1611 test_dataset_iteration). More tests for different inputs to 

1612 DatasetRefQueryResults construction are in test_dataset_joins. 

1613 """ 

1614 # Set up a few equivalent base query-results objects to test. 

1615 query = self.query() 

1616 x = query.expression_factory 

1617 self.assertFalse(query.constraint_dimensions) 

1618 results1 = query.datasets("raw").where(x.instrument == "DummyCam", visit=4) 

1619 results2 = query.datasets("raw", collections=["DummyCam/defaults"]).where( 

1620 {"instrument": "DummyCam", "visit": 4} 

1621 ) 

1622 results3 = query.datasets("raw").where( 

1623 DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe) 

1624 ) 
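# All three spellings of the same query should produce identical trees and
# result specs, so they are run through the same checks.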

1625 

1626 # Define a closure to check a DatasetRefQueryResults instance. 

1627 def check( 

1628 results: DatasetRefQueryResults, 

1629 order_by: Any = (), 

1630 limit: int | None = None, 

1631 include_dimension_records: bool = False, 

1632 ) -> list[str]: 

1633 results = results.order_by(*order_by).limit(limit) 

1634 with self.assertRaises(_TestQueryExecution) as cm: 

1635 list(results) 

1636 tree = cm.exception.tree 

1637 self.assertEqual(str(tree.predicate), "instrument == 'DummyCam' AND visit == 4") 

1638 self.assertEqual( 

1639 tree.dimensions, 

1640 self.universe.conform(["visit"]).union(results.dataset_type.dimensions.as_group()), 

1641 ) 

1642 self.assertFalse(tree.materializations) 

1643 self.assertEqual(tree.datasets.keys(), {results.dataset_type.name}) 

1644 self.assertEqual(tree.datasets[results.dataset_type.name].collections, ("DummyCam/defaults",)) 

1645 self.assertEqual( 

1646 tree.datasets[results.dataset_type.name].dimensions, 

1647 results.dataset_type.dimensions.as_group(), 

1648 ) 

1649 self.assertFalse(tree.data_coordinate_uploads) 

1650 result_spec = cm.exception.result_spec 

1651 self.assertEqual(result_spec.result_type, "dataset_ref") 

1652 self.assertEqual(result_spec.include_dimension_records, include_dimension_records) 

1653 self.assertEqual(result_spec.limit, limit) 

1654 self.assertEqual(result_spec.find_first_dataset, result_spec.dataset_type_name) 

1655 for exact, discard in itertools.permutations([False, True], r=2): 

1656 with self.assertRaises(_TestQueryCount) as cm: 

1657 results.count(exact=exact, discard=discard) 

1658 self.assertEqual(cm.exception.result_spec, result_spec) 

1659 self.assertEqual(cm.exception.exact, exact) 

1660 self.assertEqual(cm.exception.discard, discard) 

1661 with self.assertRaises(_TestQueryExecution) as cm: 

1662 list(results.data_ids) 

1663 self.assertEqual( 

1664 cm.exception.result_spec, 

1665 qrs.DataCoordinateResultSpec( 

1666 dimensions=results.dataset_type.dimensions.as_group(), 

1667 include_dimension_records=include_dimension_records, 

1668 ), 

1669 ) 

1670 self.assertIs(cm.exception.tree, tree) 

1671 return [str(term) for term in result_spec.order_by] 

1672 

1673 # Run the closure's tests on variants of the base query. 

1674 self.assertEqual(check(results1), []) 

1675 self.assertEqual(check(results2), []) 

1676 self.assertEqual(check(results3), []) 

1677 self.assertEqual(check(results1.with_dimension_records(), include_dimension_records=True), []) 

1678 self.assertEqual( 

1679 check(results1.with_dimension_records().with_dimension_records(), include_dimension_records=True), 

1680 [], 

1681 ) 

1682 self.assertEqual(check(results1, limit=2), []) 

1683 self.assertEqual(check(results1, order_by=["raw.timespan.begin"]), ["raw.timespan.begin"]) 

1684 self.assertEqual(check(results1, order_by=["detector"]), ["detector"]) 

1685 self.assertEqual(check(results1, order_by=["ingest_date"]), ["raw.ingest_date"]) 

1686 

1687 def test_dataset_iteration(self) -> None: 

1688 """Tests for SingleTypeDatasetQueryResult iteration.""" 

1689 

1690 def make_ref(n: int) -> DatasetRef: 

1691 return DatasetRef( 

1692 self.raw, 

1693 DataCoordinate.standardize( 

1694 instrument="DummyCam", exposure=4, detector=n, universe=self.universe 

1695 ), 

1696 run="DummyCam/raw/all", 

1697 id=uuid.uuid4(), 

1698 ) 

1699 

1700 result_rows = ( 

1701 [make_ref(n) for n in range(3)], 

1702 [make_ref(n) for n in range(3, 6)], 

1703 [make_ref(10)], 

1704 ) 

1705 results = self.query(result_rows=result_rows).datasets("raw") 

1706 self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) 

1707 

1708 def test_identifiers(self) -> None: 

1709 """Test edge-cases of identifiers in order_by expressions.""" 

1710 

1711 def extract_order_by(results: DataCoordinateQueryResults) -> list[str]: 

1712 with self.assertRaises(_TestQueryExecution) as cm: 

1713 list(results) 

1714 return [str(term) for term in cm.exception.result_spec.order_by] 
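# Bare field names are resolved against the query dimensions, 'element.field'
# forms are accepted, and a leading '-' requests descending order.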

1715 

1716 self.assertEqual( 

1717 extract_order_by(self.query().data_ids(["day_obs"]).order_by("-timespan.begin")), 

1718 ["day_obs.timespan.begin DESC"], 

1719 ) 

1720 self.assertEqual( 

1721 extract_order_by(self.query().data_ids(["day_obs"]).order_by("timespan.end")), 

1722 ["day_obs.timespan.end"], 

1723 ) 

1724 self.assertEqual( 

1725 extract_order_by(self.query().data_ids(["visit"]).order_by("-visit.timespan.begin")), 

1726 ["visit.timespan.begin DESC"], 

1727 ) 

1728 self.assertEqual( 

1729 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.timespan.end")), 

1730 ["visit.timespan.end"], 

1731 ) 

1732 self.assertEqual( 

1733 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.science_program")), 

1734 ["visit.science_program"], 

1735 ) 

1736 self.assertEqual( 

1737 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.id")), 

1738 ["visit"], 

1739 ) 

1740 self.assertEqual( 

1741 extract_order_by(self.query().data_ids(["visit"]).order_by("visit.physical_filter")), 

1742 ["physical_filter"], 

1743 ) 

1744 with self.assertRaises(TypeError): 

1745 self.query().data_ids(["visit"]).order_by(3) 

1746 with self.assertRaises(qt.InvalidQueryError): 

1747 self.query().data_ids(["visit"]).order_by("visit.region") 

1748 with self.assertRaisesRegex(qt.InvalidQueryError, "Ambiguous"): 

1749 self.query().data_ids(["visit", "exposure"]).order_by("timespan.begin") 

1750 with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"): 

1751 self.query().data_ids(["visit", "exposure"]).order_by("blarg") 

1752 with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"): 

1753 self.query().data_ids(["visit", "exposure"]).order_by("visit.horse") 

1754 with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"): 

1755 self.query().data_ids(["visit", "exposure"]).order_by("visit.science_program.monkey") 

1756 with self.assertRaisesRegex(qt.InvalidQueryError, "not valid for datasets"): 

1757 self.query().datasets("raw").order_by("raw.seq_num") 

1758 

1759 def test_invalid_models(self) -> None: 

1760 """Test invalid models and combinations of models that cannot be 

1761 constructed via the public Query and *QueryResults interfaces. 

1762 """ 

1763 x = ExpressionFactory(self.universe) 

1764 with self.assertRaises(qt.InvalidQueryError): 

1765 # QueryTree dimensions do not cover dataset dimensions. 

1766 qt.QueryTree( 

1767 dimensions=self.universe.conform(["visit"]), 

1768 datasets={ 

1769 "raw": qt.DatasetSearch( 

1770 collections=("DummyCam/raw/all",), 

1771 dimensions=self.raw.dimensions.as_group(), 

1772 ) 

1773 }, 

1774 ) 

1775 with self.assertRaises(qt.InvalidQueryError): 

1776 # QueryTree dimensions do not cover predicate dimensions. 

1777 qt.QueryTree( 

1778 dimensions=self.universe.conform(["visit"]), 

1779 predicate=(x.detector > 5), 

1780 ) 

1781 with self.assertRaises(qt.InvalidQueryError): 

1782 # Predicate references a dataset not in the QueryTree. 

1783 qt.QueryTree( 

1784 dimensions=self.universe.conform(["exposure", "detector"]), 

1785 predicate=(x["raw"].collection == "bird"), 

1786 ) 

1787 with self.assertRaises(qt.InvalidQueryError): 

1788 # ResultSpec's dimensions are not a subset of the query tree's. 

1789 DimensionRecordQueryResults( 

1790 _TestQueryDriver(), 

1791 qt.QueryTree(dimensions=self.universe.conform(["tract"])), 

1792 qrs.DimensionRecordResultSpec(element=self.universe["detector"]), 

1793 ) 

1794 with self.assertRaises(qt.InvalidQueryError): 

1795 # ResultSpec's datasets are not a subset of the query tree's. 

1796 DatasetRefQueryResults( 

1797 _TestQueryDriver(), 

1798 qt.QueryTree(dimensions=self.raw.dimensions.as_group()), 

1799 qrs.DatasetRefResultSpec( 

1800 dataset_type_name="raw", 

1801 dimensions=self.raw.dimensions.as_group(), 

1802 storage_class_name=self.raw.storageClass_name, 

1803 find_first=True, 

1804 ), 

1805 ) 

1806 with self.assertRaises(qt.InvalidQueryError): 

1807 # ResultSpec's order_by expression is not related to the dimensions 

1808 # we're returning. 

1809 x = ExpressionFactory(self.universe) 

1810 DimensionRecordQueryResults( 

1811 _TestQueryDriver(), 

1812 qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])), 

1813 qrs.DimensionRecordResultSpec( 

1814 element=self.universe["detector"], order_by=(x.unwrap(x.visit),) 

1815 ), 

1816 ) 

1817 with self.assertRaises(qt.InvalidQueryError): 

1818 # ResultSpec's order_by expression is not related to the datasets 

1819 # we're returning. 

1820 x = ExpressionFactory(self.universe) 

1821 DimensionRecordQueryResults( 

1822 _TestQueryDriver(), 

1823 qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])), 

1824 qrs.DimensionRecordResultSpec( 

1825 element=self.universe["detector"], order_by=(x.unwrap(x["raw"].ingest_date),) 

1826 ), 

1827 ) 

1828 

1829 def test_general_result_spec(self) -> None: 

1830 """Tests for GeneralResultSpec. 

1831 

1832 Unlike the other ResultSpec objects, we don't have a *QueryResults 

1833 class for GeneralResultSpec yet, so we can't use the higher-level 

1834 interfaces to test it like we can the others. 

1835 """ 

1836 a = qrs.GeneralResultSpec( 

1837 dimensions=self.universe.conform(["detector"]), 

1838 dimension_fields={"detector": {"purpose"}}, 

1839 dataset_fields={}, 

1840 find_first=False, 

1841 ) 

1842 self.assertEqual(a.find_first_dataset, None) 
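# Build the equivalent ColumnSet by hand to compare with get_result_columns().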

1843 a_columns = qt.ColumnSet(self.universe.conform(["detector"])) 

1844 a_columns.dimension_fields["detector"].add("purpose") 

1845 self.assertEqual(a.get_result_columns(), a_columns) 

1846 b = qrs.GeneralResultSpec( 

1847 dimensions=self.universe.conform(["detector"]), 

1848 dimension_fields={}, 

1849 dataset_fields={"bias": {"timespan", "dataset_id"}}, 

1850 find_first=True, 

1851 ) 

1852 self.assertEqual(b.find_first_dataset, "bias") 

1853 b_columns = qt.ColumnSet(self.universe.conform(["detector"])) 

1854 b_columns.dataset_fields["bias"].add("timespan") 

1855 b_columns.dataset_fields["bias"].add("dataset_id") 

1856 self.assertEqual(b.get_result_columns(), b_columns) 

1857 with self.assertRaises(qt.InvalidQueryError): 

1858 # More than one dataset type with find_first 

1859 qrs.GeneralResultSpec( 

1860 dimensions=self.universe.conform(["detector", "exposure"]), 

1861 dimension_fields={}, 

1862 dataset_fields={"bias": {"dataset_id"}, "raw": {"dataset_id"}}, 

1863 find_first=True, 

1864 ) 

1865 with self.assertRaises(qt.InvalidQueryError): 

1866 # Out-of-bounds dimension fields. 

1867 qrs.GeneralResultSpec( 

1868 dimensions=self.universe.conform(["detector"]), 

1869 dimension_fields={"visit": {"name"}}, 

1870 dataset_fields={}, 

1871 find_first=False, 

1872 ) 

1873 with self.assertRaises(qt.InvalidQueryError): 

1874 # No fields for dimension element. 

1875 qrs.GeneralResultSpec( 

1876 dimensions=self.universe.conform(["detector"]), 

1877 dimension_fields={"detector": set()}, 

1878 dataset_fields={}, 

1879 find_first=True, 

1880 ) 

1881 with self.assertRaises(qt.InvalidQueryError): 

1882 # No fields for dataset. 

1883 qrs.GeneralResultSpec( 

1884 dimensions=self.universe.conform(["detector"]), 

1885 dimension_fields={}, 

1886 dataset_fields={"bias": set()}, 

1887 find_first=True, 

1888 ) 

1889 

1890 

1891if __name__ == "__main__": 

1892 unittest.main()