Coverage for python / lsst / daf / butler / queries / overlaps.py: 21%

143 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-24 08:17 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Names exported by ``from <module> import *``; only the visitor is listed,
# so the other classes defined in this module are not part of that export.
__all__ = ("OverlapsVisitor",)

31 

32import itertools 

33from collections.abc import Hashable, Iterable, Mapping, Sequence, Set 

34from typing import Generic, Literal, TypeVar, cast 

35 

36from lsst.sphgeom import Region 

37 

38from .._exceptions import InvalidQueryError 

39from .._topology import TopologicalFamily, TopologicalRelationshipEndpoint, TopologicalSpace 

40from ..dimensions import DimensionElement, DimensionGroup 

41from . import tree 

42from .visitors import PredicateVisitFlags, SimplePredicateVisitor 

43 

44_T = TypeVar("_T", bound=Hashable) 

45 

46 

47class _NaiveDisjointSet(Generic[_T]): 

48 """A very naive (but simple) implementation of a "disjoint set" data 

49 structure for strings, with mostly O(N) performance. 

50 

51 This class should not be used in any context where the number of elements 

52 in the data structure is large. It intentionally implements a subset of 

53 the interface of `scipy.cluster.DisJointSet` so that non-naive 

54 implementation could be swapped in if desired. 

55 

56 Parameters 

57 ---------- 

58 superset : `~collections.abc.Iterable` [ `str` ] 

59 Elements to initialize the disjoint set, with each in its own 

60 single-element subset. 

61 """ 

62 

63 def __init__(self, superset: Iterable[_T]): 

64 self._subsets = [{k} for k in superset] 

65 self._subsets.sort(key=len, reverse=True) 

66 

67 def merge(self, a: _T, b: _T) -> bool: # numpydoc ignore=PR04 

68 """Merge the subsets containing the given elements. 

69 

70 Parameters 

71 ---------- 

72 a 

73 Element whose subset should be merged. 

74 b 

75 Element whose subset should be merged. 

76 

77 Returns 

78 ------- 

79 merged : `bool` 

80 `True` if a merge occurred, `False` if the elements were already in 

81 the same subset. 

82 """ 

83 for i, subset in enumerate(self._subsets): 

84 if a in subset: 

85 break 

86 else: 

87 raise KeyError(f"Merge argument {a!r} not in disjoin set {self._subsets}.") 

88 for j, subset in enumerate(self._subsets): 

89 if b in subset: 

90 break 

91 else: 

92 raise KeyError(f"Merge argument {b!r} not in disjoin set {self._subsets}.") 

93 if i == j: 

94 return False 

95 i, j = sorted((i, j)) 

96 self._subsets[i].update(self._subsets[j]) 

97 del self._subsets[j] 

98 self._subsets.sort(key=len, reverse=True) 

99 return True 

100 

101 def subsets(self) -> Sequence[Set[_T]]: 

102 """Return the current subsets, ordered from largest to smallest.""" 

103 return self._subsets 

104 

105 @property 

106 def n_subsets(self) -> int: 

107 """The number of subsets.""" 

108 return len(self._subsets) 

109 

110 

class CalibrationTemporalEndpoint(TopologicalRelationshipEndpoint):
    """A "topological relationship endpoint" that stands in for the validity
    range of a calibration dataset search.

    Parameters
    ----------
    dataset_type_name : `str` or ``ANY_DATASET``
        Name of the dataset type.

    Notes
    -----
    This lets validity range lookups participate in the logic that checks to
    see if an explicit spatial/temporal join in the WHERE expression is present
    and hence an automatic join is unnecessary.  That logic is simple for
    datasets, since each "family" is a single dataset type that only has one
    endpoint (whereas different dimensions like tract and patch can belong to
    the same family).
    """

    def __init__(self, dataset_type_name: str | tree.AnyDatasetType):
        self.dataset_type_name: str | tree.AnyDatasetType = dataset_type_name

    @property
    def name(self) -> str:
        # The ANY_DATASET sentinel is not a string, so it is reported under a
        # fixed placeholder name instead.
        if self.dataset_type_name is tree.ANY_DATASET:
            return "<calibrations>"
        return self.dataset_type_name

    @property
    def topology(self) -> Mapping[TopologicalSpace, TopologicalFamily]:
        # This endpoint only takes part in temporal relationships, via a
        # family dedicated to its own dataset type.
        family = CalibrationTemporalFamily(self.dataset_type_name)
        return {TopologicalSpace.TEMPORAL: family}

140 

141 

class CalibrationTemporalFamily(TopologicalFamily):
    """A "topological family" for the validity range of a calibration
    dataset search, with `CalibrationTemporalEndpoint` as its only endpoint.

    See `CalibrationTemporalEndpoint` for rationale.

    Parameters
    ----------
    dataset_type_name : `str` or ``ANY_DATASET``
        Name of the dataset type.
    """

    def __init__(self, dataset_type_name: str | tree.AnyDatasetType):
        # The ANY_DATASET sentinel is not a string, so the family is
        # registered under a fixed placeholder name in that case.
        family_name: str
        if dataset_type_name is tree.ANY_DATASET:
            family_name = "<calibrations>"
        else:
            family_name = dataset_type_name
        super().__init__(family_name, TopologicalSpace.TEMPORAL)
        self.dataset_type_name: str | tree.AnyDatasetType = dataset_type_name

    def choose(self, dimensions: DimensionGroup) -> CalibrationTemporalEndpoint:
        # The given dimensions are not consulted: the endpoint is always the
        # one for this family's dataset type.
        endpoint = CalibrationTemporalEndpoint(self.dataset_type_name)
        return endpoint

    def make_column_reference(self, endpoint: TopologicalRelationshipEndpoint) -> tree.DatasetFieldReference:
        # The given endpoint is not consulted: the reference always points at
        # the ``timespan`` field of this family's dataset type.
        reference = tree.DatasetFieldReference(dataset_type=self.dataset_type_name, field="timespan")
        return reference

166 

167 

168class OverlapsVisitor(SimplePredicateVisitor): 

169 """A helper class for dealing with spatial and temporal overlaps in a 

170 query. 

171 

172 Parameters 

173 ---------- 

174 dimensions : `DimensionGroup` 

175 Dimensions of the query. 

176 calibration_dataset_types : `~collections.abc.Set` [ `str` ] 

177 The names of dataset types that have been joined into the query via 

178 a search that includes at least one calibration collection. 

179 

180 Notes 

181 ----- 

182 This class includes logic for extracting explicit spatial and temporal 

183 joins from a WHERE-clause predicate and computing automatic joins given the 

184 dimensions of the query. It is designed to be subclassed by query driver 

185 implementations that want to rewrite the predicate at the same time. 

186 """ 

187 

188 def __init__(self, dimensions: DimensionGroup, calibration_dataset_types: Set[str | tree.AnyDatasetType]): 

189 self.dimensions = dimensions 

190 self._spatial_connections = _NaiveDisjointSet(self.dimensions.spatial) 

191 temporal_families: list[TopologicalFamily] = [ 

192 CalibrationTemporalFamily(name) for name in calibration_dataset_types 

193 ] 

194 temporal_families.extend(self.dimensions.temporal) 

195 self._temporal_connections = _NaiveDisjointSet(temporal_families) 

196 

197 def run(self, predicate: tree.Predicate, join_operands: Iterable[DimensionGroup]) -> tree.Predicate: 

198 """Process the given predicate to extract spatial and temporal 

199 overlaps. 

200 

201 Parameters 

202 ---------- 

203 predicate : `tree.Predicate` 

204 Predicate to process. 

205 join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ] 

206 The dimensions of logical tables being joined into this query; 

207 these can included embedded spatial and temporal joins that can 

208 make it unnecessary to add new ones. 

209 

210 Returns 

211 ------- 

212 predicate : `tree.Predicate` 

213 A possibly-modified predicate that should replace the original. 

214 """ 

215 result = predicate.visit(self) 

216 if result is None: 

217 result = predicate 

218 for join_operand_dimensions in join_operands: 

219 self._add_join_operand_connections( 

220 join_operand_dimensions.spatial, 

221 self._spatial_connections, 

222 join_operand_dimensions, 

223 ) 

224 self._add_join_operand_connections( 

225 join_operand_dimensions.temporal, 

226 self._temporal_connections, 

227 join_operand_dimensions, 

228 ) 

229 result = result.logical_and(self._add_automatic_joins("spatial", self._spatial_connections)) 

230 result = result.logical_and(self._add_automatic_joins("temporal", self._temporal_connections)) 

231 return result 

232 

233 def visit_comparison( 

234 self, 

235 a: tree.ColumnExpression, 

236 operator: tree.ComparisonOperator, 

237 b: tree.ColumnExpression, 

238 flags: PredicateVisitFlags, 

239 ) -> tree.Predicate | None: 

240 # Docstring inherited. 

241 if operator == "overlaps": 

242 if tree.is_one_timespan_and_one_datetime(a, b) or tree.is_one_timespan_and_one_ingest_date(a, b): 

243 # Can be transformed directly without special handling here. 

244 return None 

245 elif a.column_type == "region": 

246 return self.visit_spatial_overlap(a, b, flags) 

247 elif b.column_type == "timespan": 

248 return self.visit_temporal_overlap(a, b, flags) 

249 else: 

250 raise AssertionError(f"Unexpected column type {a.column_type} for overlap.") 

251 return None 

252 

253 def _add_join_operand_connections( 

254 self, 

255 families: Iterable[TopologicalFamily], 

256 connections: _NaiveDisjointSet[TopologicalFamily], 

257 operand_dimensions: DimensionGroup, 

258 ) -> None: 

259 """Add overlap connections implied by a table or subquery. 

260 

261 Parameters 

262 ---------- 

263 families : `~collections.abc.Iterable` [ `TpologicalFamily` ] 

264 Iterable of spatial or temporal families in this operand's 

265 dimensions. 

266 connections : `_NaiveDisjointSet` 

267 Relationships between spatial or temporal families to update. 

268 operand_dimensions : `DimensionGroup` 

269 Dimensions of of the table or subquery. 

270 

271 Notes 

272 ----- 

273 We assume each join operand to a `tree.Select` has its own 

274 complete set of spatial and temporal joins that went into generating 

275 its rows. That will naturally be true for relations originating from 

276 the butler database, like dataset searches and materializations, and if 

277 it isn't true for a data ID upload, that would represent an intentional 

278 association between non-overlapping things that we'd want to respect by 

279 *not* adding a more restrictive automatic join. 

280 """ 

281 for a_family, b_family in itertools.combinations(families, 2): 

282 a_element = a_family.choose(self.dimensions) 

283 b_element = b_family.choose(self.dimensions) 

284 if ( 

285 a_element.name in operand_dimensions.elements 

286 and b_element.name in operand_dimensions.elements 

287 ): 

288 connections.merge(a_family, b_family) 

289 

290 def _add_automatic_joins( 

291 self, 

292 kind: Literal["spatial", "temporal"], 

293 connections: _NaiveDisjointSet[TopologicalFamily], 

294 ) -> tree.Predicate: 

295 if connections.n_subsets <= 1: 

296 # All of the joins we need are already present. 

297 return tree.Predicate.from_bool(True) 

298 if connections.n_subsets > 2: 

299 raise InvalidQueryError( 

300 f"Too many disconnected sets of {kind} families for an automatic " 

301 f"join: {connections.subsets()}. Add explicit {kind} joins to avoid this error." 

302 ) 

303 a_subset, b_subset = connections.subsets() 

304 if len(a_subset) > 1 or len(b_subset) > 1: 

305 raise InvalidQueryError( 

306 f"A {kind} join is needed between {a_subset} and {b_subset}, but which join to " 

307 "add is ambiguous. Add an explicit spatial or temporal join to avoid this error." 

308 ) 

309 # We have a pair of families that are not explicitly or implicitly 

310 # connected to any other families; add an automatic join between their 

311 # most fine-grained members. 

312 (a_family,) = a_subset 

313 (b_family,) = b_subset 

314 a = a_family.make_column_reference(a_family.choose(self.dimensions)) 

315 b = b_family.make_column_reference(b_family.choose(self.dimensions)) 

316 join_predicate = self.visit_comparison(a, "overlaps", b, PredicateVisitFlags.HAS_AND_SIBLINGS) 

317 if join_predicate is None: 

318 join_predicate = tree.Predicate.compare(a, "overlaps", b) 

319 return join_predicate 

320 

321 def visit_spatial_overlap( 

322 self, 

323 a: tree.ColumnExpression, 

324 b: tree.ColumnExpression, 

325 flags: PredicateVisitFlags, 

326 ) -> tree.Predicate | None: 

327 """Dispatch a spatial overlap comparison predicate to handlers. 

328 

329 This method should rarely (if ever) need to be overridden. 

330 

331 Parameters 

332 ---------- 

333 a : `tree.ColumnExpression` 

334 First operand. 

335 b : `tree.ColumnExpression` 

336 Second operand. 

337 flags : `tree.PredicateLeafFlags` 

338 Information about where this overlap comparison appears in the 

339 larger predicate tree. 

340 

341 Returns 

342 ------- 

343 replaced : `tree.Predicate` or `None` 

344 The predicate to be inserted instead in the processed tree, or 

345 `None` if no substitution is needed. 

346 """ 

347 match a, b: 

348 case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference( 

349 element=b_element 

350 ): 

351 return self.visit_spatial_join(a_element, b_element, flags) 

352 case tree.DimensionFieldReference(element=element), region_expression: 

353 pass 

354 case region_expression, tree.DimensionFieldReference(element=element): 

355 pass 

356 case _: 

357 raise AssertionError(f"Unexpected arguments for spatial overlap: {a}, {b}.") 

358 if region := region_expression.get_literal_value(): 

359 return self.visit_spatial_constraint(element, region, flags) 

360 raise AssertionError(f"Unexpected argument for spatial overlap: {region_expression}.") 

361 

362 def visit_temporal_overlap( 

363 self, 

364 a: tree.ColumnExpression, 

365 b: tree.ColumnExpression, 

366 flags: PredicateVisitFlags, 

367 ) -> tree.Predicate | None: 

368 """Dispatch a temporal overlap comparison predicate to handlers. 

369 

370 This method should rarely (if ever) need to be overridden. 

371 

372 Parameters 

373 ---------- 

374 a : `tree.ColumnExpression`- 

375 First operand. 

376 b : `tree.ColumnExpression` 

377 Second operand. 

378 flags : `tree.PredicateLeafFlags` 

379 Information about where this overlap comparison appears in the 

380 larger predicate tree. 

381 

382 Returns 

383 ------- 

384 replaced : `tree.Predicate` or `None` 

385 The predicate to be inserted instead in the processed tree, or 

386 `None` if no substitution is needed. 

387 """ 

388 match a, b: 

389 case ( 

390 tree.DimensionFieldReference(element=a_element), 

391 tree.DimensionFieldReference(element=b_element), 

392 ): 

393 return self.visit_temporal_dimension_join(a_element, b_element, flags) 

394 case ( 

395 tree.DatasetFieldReference(dataset_type=a_dataset), 

396 tree.DimensionFieldReference(element=b_element), 

397 ): 

398 return self.visit_validity_range_dimension_join(a_dataset, b_element, flags) 

399 case ( 

400 tree.DimensionFieldReference(element=a_element), 

401 tree.DatasetFieldReference(dataset_type=b_dataset), 

402 ): 

403 return self.visit_validity_range_dimension_join(b_dataset, a_element, flags) 

404 case ( 

405 tree.DatasetFieldReference(dataset_type=a_dataset), 

406 tree.DatasetFieldReference(dataset_type=b_dataset), 

407 ): 

408 return self.visit_validity_range_join(a_dataset, b_dataset, flags) 

409 case _: 

410 # Other cases do not participate in automatic join logic and 

411 # do not require the predicate to be rewritten. 

412 return None 

413 

414 def visit_spatial_join( 

415 self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags 

416 ) -> tree.Predicate | None: 

417 """Handle a spatial overlap comparison between two dimension elements. 

418 

419 The default implementation updates the set of known spatial connections 

420 (for use by `compute_automatic_spatial_joins`) and returns `None`. 

421 

422 Parameters 

423 ---------- 

424 a : `DimensionElement` 

425 One element in the join. 

426 b : `DimensionElement` 

427 The other element in the join. 

428 flags : `tree.PredicateLeafFlags` 

429 Information about where this overlap comparison appears in the 

430 larger predicate tree. 

431 

432 Returns 

433 ------- 

434 replaced : `tree.Predicate` or `None` 

435 The predicate to be inserted instead in the processed tree, or 

436 `None` if no substitution is needed. 

437 """ 

438 if a.spatial == b.spatial: 

439 raise InvalidQueryError(f"Spatial join between {a} and {b} is not necessary.") 

440 self._spatial_connections.merge( 

441 cast(TopologicalFamily, a.spatial), cast(TopologicalFamily, b.spatial) 

442 ) 

443 return None 

444 

445 def visit_spatial_constraint( 

446 self, 

447 element: DimensionElement, 

448 region: Region, 

449 flags: PredicateVisitFlags, 

450 ) -> tree.Predicate | None: 

451 """Handle a spatial overlap comparison between a dimension element and 

452 a literal region. 

453 

454 The default implementation just returns `None`. 

455 

456 Parameters 

457 ---------- 

458 element : `DimensionElement` 

459 The dimension element in the comparison. 

460 region : `lsst.sphgeom.Region` 

461 The literal region in the comparison. 

462 flags : `tree.PredicateLeafFlags` 

463 Information about where this overlap comparison appears in the 

464 larger predicate tree. 

465 

466 Returns 

467 ------- 

468 replaced : `tree.Predicate` or `None` 

469 The predicate to be inserted instead in the processed tree, or 

470 `None` if no substitution is needed. 

471 """ 

472 return None 

473 

474 def visit_temporal_dimension_join( 

475 self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags 

476 ) -> tree.Predicate | None: 

477 """Handle a temporal overlap comparison between two dimension elements. 

478 

479 The default implementation updates the set of known temporal 

480 connections (for use by `compute_automatic_temporal_joins`) and returns 

481 `None`. 

482 

483 Parameters 

484 ---------- 

485 a : `DimensionElement` 

486 One element in the join. 

487 b : `DimensionElement` 

488 The other element in the join. 

489 flags : `tree.PredicateLeafFlags` 

490 Information about where this overlap comparison appears in the 

491 larger predicate tree. 

492 

493 Returns 

494 ------- 

495 replaced : `tree.Predicate` or `None` 

496 The predicate to be inserted instead in the processed tree, or 

497 `None` if no substitution is needed. 

498 """ 

499 if a.temporal == b.temporal: 

500 raise InvalidQueryError(f"Temporal join between {a} and {b} is not necessary.") 

501 self._temporal_connections.merge( 

502 cast(TopologicalFamily, a.temporal), cast(TopologicalFamily, b.temporal) 

503 ) 

504 return None 

505 

506 def visit_validity_range_dimension_join( 

507 self, a: str | tree.AnyDatasetType, b: DimensionElement, flags: PredicateVisitFlags 

508 ) -> tree.Predicate | None: 

509 """Handle a temporal overlap comparison between two dimension elements. 

510 

511 The default implementation updates the set of known temporal 

512 connections (for use by `compute_automatic_temporal_joins`) and returns 

513 `None`. 

514 

515 Parameters 

516 ---------- 

517 a : `str` or ``tree.AnyDatasetType`` 

518 Name of a calibration dataset type. 

519 b : `DimensionElement` 

520 The dimension element to join the dataset validity range to. 

521 flags : `tree.PredicateLeafFlags` 

522 Information about where this overlap comparison appears in the 

523 larger predicate tree. 

524 

525 Returns 

526 ------- 

527 replaced : `tree.Predicate` or `None` 

528 The predicate to be inserted instead in the processed tree, or 

529 `None` if no substitution is needed. 

530 """ 

531 self._temporal_connections.merge(CalibrationTemporalFamily(a), cast(TopologicalFamily, b.temporal)) 

532 return None 

533 

534 def visit_validity_range_join( 

535 self, a: str | tree.AnyDatasetType, b: str | tree.AnyDatasetType, flags: PredicateVisitFlags 

536 ) -> tree.Predicate | None: 

537 """Handle a temporal overlap comparison between two dimension elements. 

538 

539 The default implementation updates the set of known temporal 

540 connections (for use by `compute_automatic_temporal_joins`) and returns 

541 `None`. 

542 

543 Parameters 

544 ---------- 

545 a : `str` or ``tree.AnyDatasetType`` 

546 Name of a calibration dataset type. 

547 b : `str` or ``tree.AnyDatasetType`` 

548 Another claibration dataset type to join to. 

549 flags : `tree.PredicateLeafFlags` 

550 Information about where this overlap comparison appears in the 

551 larger predicate tree. 

552 

553 Returns 

554 ------- 

555 replaced : `tree.Predicate` or `None` 

556 The predicate to be inserted instead in the processed tree, or 

557 `None` if no substitution is needed. 

558 """ 

559 self._temporal_connections.merge(CalibrationTemporalFamily(a), CalibrationTemporalFamily(b)) 

560 return None