Coverage for python/lsst/daf/butler/queries/overlaps.py: 19%

120 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-05 11:36 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("OverlapsVisitor",) 

31 

32import itertools 

33from collections.abc import Hashable, Iterable, Sequence, Set 

34from typing import Generic, Literal, TypeVar, cast 

35 

36from lsst.sphgeom import Region 

37 

38from .._topology import TopologicalFamily 

39from ..dimensions import DimensionElement, DimensionGroup 

40from . import tree 

41from .visitors import PredicateVisitFlags, SimplePredicateVisitor 

42 

_T = TypeVar("_T", bound=Hashable)


class _NaiveDisjointSet(Generic[_T]):
    """A very naive (but simple) implementation of a "disjoint set" data
    structure for hashable elements, with mostly O(N) performance.

    This class should not be used in any context where the number of elements
    in the data structure is large.  It intentionally implements a subset of
    the interface of `scipy.cluster.hierarchy.DisjointSet` so that a non-naive
    implementation could be swapped in if desired.

    Parameters
    ----------
    superset : `~collections.abc.Iterable`
        Elements to initialize the disjoint set, with each in its own
        single-element subset.  Elements that appear more than once are only
        added once.
    """

    def __init__(self, superset: Iterable[_T]):
        # De-duplicate (preserving first-seen order) so that a repeated
        # element cannot end up in two different subsets.  All subsets have
        # length one at this point, so they are trivially ordered from
        # largest to smallest already.
        self._subsets: list[set[_T]] = [{k} for k in dict.fromkeys(superset)]

    def _index_of(self, k: _T) -> int:
        """Return the position of the subset containing ``k``.

        Raises
        ------
        KeyError
            Raised if ``k`` is not in any subset.
        """
        for index, subset in enumerate(self._subsets):
            if k in subset:
                return index
        raise KeyError(f"Merge argument {k!r} not in disjoint set {self._subsets}.")

    def add(self, k: _T) -> bool:  # numpydoc ignore=PR04
        """Add a new element as its own single-element subset unless it is
        already present.

        Parameters
        ----------
        k
            Value to add.

        Returns
        -------
        added : `bool`
            `True` if the value was actually added, `False` if it was already
            present.
        """
        if any(k in subset for subset in self._subsets):
            return False
        # A new singleton is never larger than any existing subset, so
        # appending keeps the largest-to-smallest ordering intact.
        self._subsets.append({k})
        return True

    def merge(self, a: _T, b: _T) -> bool:  # numpydoc ignore=PR04
        """Merge the subsets containing the given elements.

        Parameters
        ----------
        a :
            Element whose subset should be merged.
        b :
            Element whose subset should be merged.

        Returns
        -------
        merged : `bool`
            `True` if a merge occurred, `False` if the elements were already in
            the same subset.

        Raises
        ------
        KeyError
            Raised if either element is not present in the disjoint set.
        """
        i = self._index_of(a)
        j = self._index_of(b)
        if i == j:
            return False
        # Merge into the subset that appears first and drop the later one, so
        # the deletion cannot shift the index of the subset we keep.
        i, j = sorted((i, j))
        self._subsets[i].update(self._subsets[j])
        del self._subsets[j]
        # Restore the largest-to-smallest ordering promised by `subsets`.
        self._subsets.sort(key=len, reverse=True)
        return True

    def subsets(self) -> Sequence[Set[_T]]:
        """Return the current subsets, ordered from largest to smallest."""
        return self._subsets

    @property
    def n_subsets(self) -> int:
        """The number of subsets."""
        return len(self._subsets)

129 

130 

131class OverlapsVisitor(SimplePredicateVisitor): 

132 """A helper class for dealing with spatial and temporal overlaps in a 

133 query. 

134 

135 Parameters 

136 ---------- 

137 dimensions : `DimensionGroup` 

138 Dimensions of the query. 

139 

140 Notes 

141 ----- 

142 This class includes logic for extracting explicit spatial and temporal 

143 joins from a WHERE-clause predicate and computing automatic joins given the 

144 dimensions of the query. It is designed to be subclassed by query driver 

145 implementations that want to rewrite the predicate at the same time. 

146 """ 

147 

148 def __init__(self, dimensions: DimensionGroup): 

149 self.dimensions = dimensions 

150 self._spatial_connections = _NaiveDisjointSet(self.dimensions.spatial) 

151 self._temporal_connections = _NaiveDisjointSet(self.dimensions.temporal) 

152 

153 def run(self, predicate: tree.Predicate, join_operands: Iterable[DimensionGroup]) -> tree.Predicate: 

154 """Process the given predicate to extract spatial and temporal 

155 overlaps. 

156 

157 Parameters 

158 ---------- 

159 predicate : `tree.Predicate` 

160 Predicate to process. 

161 join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ] 

162 The dimensions of logical tables being joined into this query; 

163 these can included embedded spatial and temporal joins that can 

164 make it unnecessary to add new ones. 

165 

166 Returns 

167 ------- 

168 predicate : `tree.Predicate` 

169 A possibly-modified predicate that should replace the original. 

170 """ 

171 result = predicate.visit(self) 

172 if result is None: 

173 result = predicate 

174 for join_operand_dimensions in join_operands: 

175 self.add_join_operand_connections(join_operand_dimensions) 

176 for a, b in self.compute_automatic_spatial_joins(): 

177 join_predicate = self.visit_spatial_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) 

178 if join_predicate is None: 

179 join_predicate = tree.Predicate.compare( 

180 tree.DimensionFieldReference.model_construct(element=a, field="region"), 

181 "overlaps", 

182 tree.DimensionFieldReference.model_construct(element=b, field="region"), 

183 ) 

184 result = result.logical_and(join_predicate) 

185 for a, b in self.compute_automatic_temporal_joins(): 

186 join_predicate = self.visit_temporal_dimension_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) 

187 if join_predicate is None: 

188 join_predicate = tree.Predicate.compare( 

189 tree.DimensionFieldReference.model_construct(element=a, field="timespan"), 

190 "overlaps", 

191 tree.DimensionFieldReference.model_construct(element=b, field="timespan"), 

192 ) 

193 result = result.logical_and(join_predicate) 

194 return result 

195 

196 def visit_comparison( 

197 self, 

198 a: tree.ColumnExpression, 

199 operator: tree.ComparisonOperator, 

200 b: tree.ColumnExpression, 

201 flags: PredicateVisitFlags, 

202 ) -> tree.Predicate | None: 

203 # Docstring inherited. 

204 if operator == "overlaps": 

205 if a.column_type == "region": 

206 return self.visit_spatial_overlap(a, b, flags) 

207 elif b.column_type == "timespan": 

208 return self.visit_temporal_overlap(a, b, flags) 

209 else: 

210 raise AssertionError(f"Unexpected column type {a.column_type} for overlap.") 

211 return None 

212 

213 def add_join_operand_connections(self, operand_dimensions: DimensionGroup) -> None: 

214 """Add overlap connections implied by a table or subquery. 

215 

216 Parameters 

217 ---------- 

218 operand_dimensions : `DimensionGroup` 

219 Dimensions of of the table or subquery. 

220 

221 Notes 

222 ----- 

223 We assume each join operand to a `tree.Select` has its own 

224 complete set of spatial and temporal joins that went into generating 

225 its rows. That will naturally be true for relations originating from 

226 the butler database, like dataset searches and materializations, and if 

227 it isn't true for a data ID upload, that would represent an intentional 

228 association between non-overlapping things that we'd want to respect by 

229 *not* adding a more restrictive automatic join. 

230 """ 

231 for a_family, b_family in itertools.pairwise(operand_dimensions.spatial): 

232 self._spatial_connections.merge(a_family, b_family) 

233 for a_family, b_family in itertools.pairwise(operand_dimensions.temporal): 

234 self._temporal_connections.merge(a_family, b_family) 

235 

236 def compute_automatic_spatial_joins(self) -> list[tuple[DimensionElement, DimensionElement]]: 

237 """Return pairs of dimension elements that should be spatially joined. 

238 

239 Returns 

240 ------- 

241 joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ] 

242 Automatic joins. 

243 

244 Notes 

245 ----- 

246 This takes into account explicit joins extracted by `run` and implicit 

247 joins added by `add_join_operand_connections`, and only returns 

248 additional joins if there is an unambiguous way to spatially connect 

249 any dimensions that are not already spatially connected. Automatic 

250 joins are always the most fine-grained join between sets of dimensions 

251 (i.e. ``visit_detector_region`` and ``patch`` instead of ``visit`` and 

252 ``tract``), but explicitly adding a coarser join between sets of 

253 elements will prevent the fine-grained join from being added. 

254 """ 

255 return self._compute_automatic_joins("spatial", self._spatial_connections) 

256 

257 def compute_automatic_temporal_joins(self) -> list[tuple[DimensionElement, DimensionElement]]: 

258 """Return pairs of dimension elements that should be spatially joined. 

259 

260 Returns 

261 ------- 

262 joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ] 

263 Automatic joins. 

264 

265 Notes 

266 ----- 

267 See `compute_automatic_spatial_joins` for information on how automatic 

268 joins are determined. Joins to dataset validity ranges are never 

269 automatic. 

270 """ 

271 return self._compute_automatic_joins("temporal", self._temporal_connections) 

272 

273 def _compute_automatic_joins( 

274 self, kind: Literal["spatial", "temporal"], connections: _NaiveDisjointSet[TopologicalFamily] 

275 ) -> list[tuple[DimensionElement, DimensionElement]]: 

276 if connections.n_subsets <= 1: 

277 # All of the joins we need are already present. 

278 return [] 

279 if connections.n_subsets > 2: 

280 raise tree.InvalidQueryError( 

281 f"Too many disconnected sets of {kind} families for an automatic " 

282 f"join: {connections.subsets()}. Add explicit {kind} joins to avoid this error." 

283 ) 

284 a_subset, b_subset = connections.subsets() 

285 if len(a_subset) > 1 or len(b_subset) > 1: 

286 raise tree.InvalidQueryError( 

287 f"A {kind} join is needed between {a_subset} and {b_subset}, but which join to " 

288 "add is ambiguous. Add an explicit spatial join to avoid this error." 

289 ) 

290 # We have a pair of families that are not explicitly or implicitly 

291 # connected to any other families; add an automatic join between their 

292 # most fine-grained members. 

293 (a_family,) = a_subset 

294 (b_family,) = b_subset 

295 return [ 

296 ( 

297 cast(DimensionElement, a_family.choose(self.dimensions.elements, self.dimensions.universe)), 

298 cast(DimensionElement, b_family.choose(self.dimensions.elements, self.dimensions.universe)), 

299 ) 

300 ] 

301 

302 def visit_spatial_overlap( 

303 self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags 

304 ) -> tree.Predicate | None: 

305 """Dispatch a spatial overlap comparison predicate to handlers. 

306 

307 This method should rarely (if ever) need to be overridden. 

308 

309 Parameters 

310 ---------- 

311 a : `tree.ColumnExpression` 

312 First operand. 

313 b : `tree.ColumnExpression` 

314 Second operand. 

315 flags : `tree.PredicateLeafFlags` 

316 Information about where this overlap comparison appears in the 

317 larger predicate tree. 

318 

319 Returns 

320 ------- 

321 replaced : `tree.Predicate` or `None` 

322 The predicate to be inserted instead in the processed tree, or 

323 `None` if no substitution is needed. 

324 """ 

325 match a, b: 

326 case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference( 

327 element=b_element 

328 ): 

329 return self.visit_spatial_join(a_element, b_element, flags) 

330 case tree.DimensionFieldReference(element=element), region_expression: 

331 pass 

332 case region_expression, tree.DimensionFieldReference(element=element): 

333 pass 

334 case _: 

335 raise AssertionError(f"Unexpected arguments for spatial overlap: {a}, {b}.") 

336 if region := region_expression.get_literal_value(): 

337 return self.visit_spatial_constraint(element, region, flags) 

338 raise AssertionError(f"Unexpected argument for spatial overlap: {region_expression}.") 

339 

340 def visit_temporal_overlap( 

341 self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags 

342 ) -> tree.Predicate | None: 

343 """Dispatch a temporal overlap comparison predicate to handlers. 

344 

345 This method should rarely (if ever) need to be overridden. 

346 

347 Parameters 

348 ---------- 

349 a : `tree.ColumnExpression`- 

350 First operand. 

351 b : `tree.ColumnExpression` 

352 Second operand. 

353 flags : `tree.PredicateLeafFlags` 

354 Information about where this overlap comparison appears in the 

355 larger predicate tree. 

356 

357 Returns 

358 ------- 

359 replaced : `tree.Predicate` or `None` 

360 The predicate to be inserted instead in the processed tree, or 

361 `None` if no substitution is needed. 

362 """ 

363 match a, b: 

364 case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference( 

365 element=b_element 

366 ): 

367 return self.visit_temporal_dimension_join(a_element, b_element, flags) 

368 case _: 

369 # We don't bother differentiating any other kind of temporal 

370 # comparison, because in all foreseeable database schemas we 

371 # wouldn't have to do anything special with them, since they 

372 # don't participate in automatic join calculations and they 

373 # should be straightforwardly convertible to SQL. 

374 return None 

375 

376 def visit_spatial_join( 

377 self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags 

378 ) -> tree.Predicate | None: 

379 """Handle a spatial overlap comparison between two dimension elements. 

380 

381 The default implementation updates the set of known spatial connections 

382 (for use by `compute_automatic_spatial_joins`) and returns `None`. 

383 

384 Parameters 

385 ---------- 

386 a : `DimensionElement` 

387 One element in the join. 

388 b : `DimensionElement` 

389 The other element in the join. 

390 flags : `tree.PredicateLeafFlags` 

391 Information about where this overlap comparison appears in the 

392 larger predicate tree. 

393 

394 Returns 

395 ------- 

396 replaced : `tree.Predicate` or `None` 

397 The predicate to be inserted instead in the processed tree, or 

398 `None` if no substitution is needed. 

399 """ 

400 if a.spatial == b.spatial: 

401 raise tree.InvalidQueryError(f"Spatial join between {a} and {b} is not necessary.") 

402 self._spatial_connections.merge( 

403 cast(TopologicalFamily, a.spatial), cast(TopologicalFamily, b.spatial) 

404 ) 

405 return None 

406 

407 def visit_spatial_constraint( 

408 self, 

409 element: DimensionElement, 

410 region: Region, 

411 flags: PredicateVisitFlags, 

412 ) -> tree.Predicate | None: 

413 """Handle a spatial overlap comparison between a dimension element and 

414 a literal region. 

415 

416 The default implementation just returns `None`. 

417 

418 Parameters 

419 ---------- 

420 element : `DimensionElement` 

421 The dimension element in the comparison. 

422 region : `lsst.sphgeom.Region` 

423 The literal region in the comparison. 

424 flags : `tree.PredicateLeafFlags` 

425 Information about where this overlap comparison appears in the 

426 larger predicate tree. 

427 

428 Returns 

429 ------- 

430 replaced : `tree.Predicate` or `None` 

431 The predicate to be inserted instead in the processed tree, or 

432 `None` if no substitution is needed. 

433 """ 

434 return None 

435 

436 def visit_temporal_dimension_join( 

437 self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags 

438 ) -> tree.Predicate | None: 

439 """Handle a temporal overlap comparison between two dimension elements. 

440 

441 The default implementation updates the set of known temporal 

442 connections (for use by `compute_automatic_temporal_joins`) and returns 

443 `None`. 

444 

445 Parameters 

446 ---------- 

447 a : `DimensionElement` 

448 One element in the join. 

449 b : `DimensionElement` 

450 The other element in the join. 

451 flags : `tree.PredicateLeafFlags` 

452 Information about where this overlap comparison appears in the 

453 larger predicate tree. 

454 

455 Returns 

456 ------- 

457 replaced : `tree.Predicate` or `None` 

458 The predicate to be inserted instead in the processed tree, or 

459 `None` if no substitution is needed. 

460 """ 

461 if a.temporal == b.temporal: 

462 raise tree.InvalidQueryError(f"Temporal join between {a} and {b} is not necessary.") 

463 self._temporal_connections.merge( 

464 cast(TopologicalFamily, a.temporal), cast(TopologicalFamily, b.temporal) 

465 ) 

466 return None