Coverage for python/lsst/daf/butler/queries/overlaps.py: 19%

114 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-26 02:48 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("OverlapsVisitor",) 

31 

32import itertools 

33from collections.abc import Hashable, Iterable, Sequence, Set 

34from typing import Generic, Literal, TypeVar, cast 

35 

36from lsst.sphgeom import Region 

37 

38from .._topology import TopologicalFamily 

39from ..dimensions import DimensionElement, DimensionGroup 

40from . import tree 

41from .visitors import PredicateVisitFlags, SimplePredicateVisitor 

42 

_T = TypeVar("_T", bound=Hashable)


class _NaiveDisjointSet(Generic[_T]):
    """A very naive (but simple) implementation of a "disjoint set" data
    structure for hashable elements, with mostly O(N) performance.

    This class should not be used in any context where the number of elements
    in the data structure is large.  It intentionally implements a subset of
    the interface of `scipy.cluster.hierarchy.DisjointSet` so that non-naive
    implementation could be swapped in if desired.

    Parameters
    ----------
    superset : `~collections.abc.Iterable` [ `~collections.abc.Hashable` ]
        Elements to initialize the disjoint set, with each in its own
        single-element subset.  Duplicate elements are collapsed so that
        every element starts out in exactly one subset.
    """

    def __init__(self, superset: Iterable[_T]):
        # ``dict.fromkeys`` drops duplicates while keeping first-seen order.
        # Without this, a repeated element would end up in two different
        # subsets, which breaks disjointness and can never be repaired by
        # `merge` (it would always find the same first subset for both
        # arguments).  All initial subsets have exactly one element, so no
        # size-ordering pass is needed here.
        self._subsets: list[set[_T]] = [{k} for k in dict.fromkeys(superset)]

    def merge(self, a: _T, b: _T) -> bool:  # numpydoc ignore=PR04
        """Merge the subsets containing the given elements.

        Parameters
        ----------
        a :
            Element whose subset should be merged.
        b :
            Element whose subset should be merged.

        Returns
        -------
        merged : `bool`
            `True` if a merge occurred, `False` if the elements were already in
            the same subset.

        Raises
        ------
        KeyError
            Raised if either element is not in any subset.
        """
        # Linear scans for the subsets holding each element; the ``else``
        # clauses only run when the loop finished without finding a match.
        for i, subset in enumerate(self._subsets):
            if a in subset:
                break
        else:
            raise KeyError(f"Merge argument {a!r} not in disjoin set {self._subsets}.")
        for j, subset in enumerate(self._subsets):
            if b in subset:
                break
        else:
            raise KeyError(f"Merge argument {b!r} not in disjoin set {self._subsets}.")
        if i == j:
            return False
        # Fold the later subset into the earlier one so that deleting index
        # ``j`` cannot shift the position of the subset we just updated.
        i, j = sorted((i, j))
        self._subsets[i].update(self._subsets[j])
        del self._subsets[j]
        # Maintain the largest-to-smallest ordering promised by `subsets`.
        self._subsets.sort(key=len, reverse=True)
        return True

    def subsets(self) -> Sequence[Set[_T]]:
        """Return the current subsets, ordered from largest to smallest.

        The returned sequence is the internal list, not a copy; callers
        should treat it as read-only.
        """
        return self._subsets

    @property
    def n_subsets(self) -> int:
        """The number of subsets."""
        return len(self._subsets)

108 

109 

110class OverlapsVisitor(SimplePredicateVisitor): 

111 """A helper class for dealing with spatial and temporal overlaps in a 

112 query. 

113 

114 Parameters 

115 ---------- 

116 dimensions : `DimensionGroup` 

117 Dimensions of the query. 

118 

119 Notes 

120 ----- 

121 This class includes logic for extracting explicit spatial and temporal 

122 joins from a WHERE-clause predicate and computing automatic joins given the 

123 dimensions of the query. It is designed to be subclassed by query driver 

124 implementations that want to rewrite the predicate at the same time. 

125 """ 

126 

127 def __init__(self, dimensions: DimensionGroup): 

128 self.dimensions = dimensions 

129 self._spatial_connections = _NaiveDisjointSet(self.dimensions.spatial) 

130 self._temporal_connections = _NaiveDisjointSet(self.dimensions.temporal) 

131 

132 def run(self, predicate: tree.Predicate, join_operands: Iterable[DimensionGroup]) -> tree.Predicate: 

133 """Process the given predicate to extract spatial and temporal 

134 overlaps. 

135 

136 Parameters 

137 ---------- 

138 predicate : `tree.Predicate` 

139 Predicate to process. 

140 join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ] 

141 The dimensions of logical tables being joined into this query; 

142 these can included embedded spatial and temporal joins that can 

143 make it unnecessary to add new ones. 

144 

145 Returns 

146 ------- 

147 predicate : `tree.Predicate` 

148 A possibly-modified predicate that should replace the original. 

149 """ 

150 result = predicate.visit(self) 

151 if result is None: 

152 result = predicate 

153 for join_operand_dimensions in join_operands: 

154 self.add_join_operand_connections(join_operand_dimensions) 

155 for a, b in self.compute_automatic_spatial_joins(): 

156 join_predicate = self.visit_spatial_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) 

157 if join_predicate is None: 

158 join_predicate = tree.Predicate.compare( 

159 tree.DimensionFieldReference.model_construct(element=a, field="region"), 

160 "overlaps", 

161 tree.DimensionFieldReference.model_construct(element=b, field="region"), 

162 ) 

163 result = result.logical_and(join_predicate) 

164 for a, b in self.compute_automatic_temporal_joins(): 

165 join_predicate = self.visit_temporal_dimension_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) 

166 if join_predicate is None: 

167 join_predicate = tree.Predicate.compare( 

168 tree.DimensionFieldReference.model_construct(element=a, field="timespan"), 

169 "overlaps", 

170 tree.DimensionFieldReference.model_construct(element=b, field="timespan"), 

171 ) 

172 result = result.logical_and(join_predicate) 

173 return result 

174 

175 def visit_comparison( 

176 self, 

177 a: tree.ColumnExpression, 

178 operator: tree.ComparisonOperator, 

179 b: tree.ColumnExpression, 

180 flags: PredicateVisitFlags, 

181 ) -> tree.Predicate | None: 

182 # Docstring inherited. 

183 if operator == "overlaps": 

184 if a.column_type == "region": 

185 return self.visit_spatial_overlap(a, b, flags) 

186 elif b.column_type == "timespan": 

187 return self.visit_temporal_overlap(a, b, flags) 

188 else: 

189 raise AssertionError(f"Unexpected column type {a.column_type} for overlap.") 

190 return None 

191 

192 def add_join_operand_connections(self, operand_dimensions: DimensionGroup) -> None: 

193 """Add overlap connections implied by a table or subquery. 

194 

195 Parameters 

196 ---------- 

197 operand_dimensions : `DimensionGroup` 

198 Dimensions of of the table or subquery. 

199 

200 Notes 

201 ----- 

202 We assume each join operand to a `tree.Select` has its own 

203 complete set of spatial and temporal joins that went into generating 

204 its rows. That will naturally be true for relations originating from 

205 the butler database, like dataset searches and materializations, and if 

206 it isn't true for a data ID upload, that would represent an intentional 

207 association between non-overlapping things that we'd want to respect by 

208 *not* adding a more restrictive automatic join. 

209 """ 

210 for a_family, b_family in itertools.pairwise(operand_dimensions.spatial): 

211 self._spatial_connections.merge(a_family, b_family) 

212 for a_family, b_family in itertools.pairwise(operand_dimensions.temporal): 

213 self._temporal_connections.merge(a_family, b_family) 

214 

215 def compute_automatic_spatial_joins(self) -> list[tuple[DimensionElement, DimensionElement]]: 

216 """Return pairs of dimension elements that should be spatially joined. 

217 

218 Returns 

219 ------- 

220 joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ] 

221 Automatic joins. 

222 

223 Notes 

224 ----- 

225 This takes into account explicit joins extracted by `run` and implicit 

226 joins added by `add_join_operand_connections`, and only returns 

227 additional joins if there is an unambiguous way to spatially connect 

228 any dimensions that are not already spatially connected. Automatic 

229 joins are always the most fine-grained join between sets of dimensions 

230 (i.e. ``visit_detector_region`` and ``patch`` instead of ``visit`` and 

231 ``tract``), but explicitly adding a coarser join between sets of 

232 elements will prevent the fine-grained join from being added. 

233 """ 

234 return self._compute_automatic_joins("spatial", self._spatial_connections) 

235 

236 def compute_automatic_temporal_joins(self) -> list[tuple[DimensionElement, DimensionElement]]: 

237 """Return pairs of dimension elements that should be spatially joined. 

238 

239 Returns 

240 ------- 

241 joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ] 

242 Automatic joins. 

243 

244 Notes 

245 ----- 

246 See `compute_automatic_spatial_joins` for information on how automatic 

247 joins are determined. Joins to dataset validity ranges are never 

248 automatic. 

249 """ 

250 return self._compute_automatic_joins("temporal", self._temporal_connections) 

251 

252 def _compute_automatic_joins( 

253 self, kind: Literal["spatial", "temporal"], connections: _NaiveDisjointSet[TopologicalFamily] 

254 ) -> list[tuple[DimensionElement, DimensionElement]]: 

255 if connections.n_subsets <= 1: 

256 # All of the joins we need are already present. 

257 return [] 

258 if connections.n_subsets > 2: 

259 raise tree.InvalidQueryError( 

260 f"Too many disconnected sets of {kind} families for an automatic " 

261 f"join: {connections.subsets()}. Add explicit {kind} joins to avoid this error." 

262 ) 

263 a_subset, b_subset = connections.subsets() 

264 if len(a_subset) > 1 or len(b_subset) > 1: 

265 raise tree.InvalidQueryError( 

266 f"A {kind} join is needed between {a_subset} and {b_subset}, but which join to " 

267 "add is ambiguous. Add an explicit spatial join to avoid this error." 

268 ) 

269 # We have a pair of families that are not explicitly or implicitly 

270 # connected to any other families; add an automatic join between their 

271 # most fine-grained members. 

272 (a_family,) = a_subset 

273 (b_family,) = b_subset 

274 return [ 

275 ( 

276 cast(DimensionElement, a_family.choose(self.dimensions.elements, self.dimensions.universe)), 

277 cast(DimensionElement, b_family.choose(self.dimensions.elements, self.dimensions.universe)), 

278 ) 

279 ] 

280 

281 def visit_spatial_overlap( 

282 self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags 

283 ) -> tree.Predicate | None: 

284 """Dispatch a spatial overlap comparison predicate to handlers. 

285 

286 This method should rarely (if ever) need to be overridden. 

287 

288 Parameters 

289 ---------- 

290 a : `tree.ColumnExpression` 

291 First operand. 

292 b : `tree.ColumnExpression` 

293 Second operand. 

294 flags : `tree.PredicateLeafFlags` 

295 Information about where this overlap comparison appears in the 

296 larger predicate tree. 

297 

298 Returns 

299 ------- 

300 replaced : `tree.Predicate` or `None` 

301 The predicate to be inserted instead in the processed tree, or 

302 `None` if no substitution is needed. 

303 """ 

304 match a, b: 

305 case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference( 

306 element=b_element 

307 ): 

308 return self.visit_spatial_join(a_element, b_element, flags) 

309 case tree.DimensionFieldReference(element=element), region_expression: 

310 pass 

311 case region_expression, tree.DimensionFieldReference(element=element): 

312 pass 

313 case _: 

314 raise AssertionError(f"Unexpected arguments for spatial overlap: {a}, {b}.") 

315 if region := region_expression.get_literal_value(): 

316 return self.visit_spatial_constraint(element, region, flags) 

317 raise AssertionError(f"Unexpected argument for spatial overlap: {region_expression}.") 

318 

319 def visit_temporal_overlap( 

320 self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags 

321 ) -> tree.Predicate | None: 

322 """Dispatch a temporal overlap comparison predicate to handlers. 

323 

324 This method should rarely (if ever) need to be overridden. 

325 

326 Parameters 

327 ---------- 

328 a : `tree.ColumnExpression`- 

329 First operand. 

330 b : `tree.ColumnExpression` 

331 Second operand. 

332 flags : `tree.PredicateLeafFlags` 

333 Information about where this overlap comparison appears in the 

334 larger predicate tree. 

335 

336 Returns 

337 ------- 

338 replaced : `tree.Predicate` or `None` 

339 The predicate to be inserted instead in the processed tree, or 

340 `None` if no substitution is needed. 

341 """ 

342 match a, b: 

343 case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference( 

344 element=b_element 

345 ): 

346 return self.visit_temporal_dimension_join(a_element, b_element, flags) 

347 case _: 

348 # We don't bother differentiating any other kind of temporal 

349 # comparison, because in all foreseeable database schemas we 

350 # wouldn't have to do anything special with them, since they 

351 # don't participate in automatic join calculations and they 

352 # should be straightforwardly convertible to SQL. 

353 return None 

354 

355 def visit_spatial_join( 

356 self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags 

357 ) -> tree.Predicate | None: 

358 """Handle a spatial overlap comparison between two dimension elements. 

359 

360 The default implementation updates the set of known spatial connections 

361 (for use by `compute_automatic_spatial_joins`) and returns `None`. 

362 

363 Parameters 

364 ---------- 

365 a : `DimensionElement` 

366 One element in the join. 

367 b : `DimensionElement` 

368 The other element in the join. 

369 flags : `tree.PredicateLeafFlags` 

370 Information about where this overlap comparison appears in the 

371 larger predicate tree. 

372 

373 Returns 

374 ------- 

375 replaced : `tree.Predicate` or `None` 

376 The predicate to be inserted instead in the processed tree, or 

377 `None` if no substitution is needed. 

378 """ 

379 if a.spatial == b.spatial: 

380 raise tree.InvalidQueryError(f"Spatial join between {a} and {b} is not necessary.") 

381 self._spatial_connections.merge( 

382 cast(TopologicalFamily, a.spatial), cast(TopologicalFamily, b.spatial) 

383 ) 

384 return None 

385 

386 def visit_spatial_constraint( 

387 self, 

388 element: DimensionElement, 

389 region: Region, 

390 flags: PredicateVisitFlags, 

391 ) -> tree.Predicate | None: 

392 """Handle a spatial overlap comparison between a dimension element and 

393 a literal region. 

394 

395 The default implementation just returns `None`. 

396 

397 Parameters 

398 ---------- 

399 element : `DimensionElement` 

400 The dimension element in the comparison. 

401 region : `lsst.sphgeom.Region` 

402 The literal region in the comparison. 

403 flags : `tree.PredicateLeafFlags` 

404 Information about where this overlap comparison appears in the 

405 larger predicate tree. 

406 

407 Returns 

408 ------- 

409 replaced : `tree.Predicate` or `None` 

410 The predicate to be inserted instead in the processed tree, or 

411 `None` if no substitution is needed. 

412 """ 

413 return None 

414 

415 def visit_temporal_dimension_join( 

416 self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags 

417 ) -> tree.Predicate | None: 

418 """Handle a temporal overlap comparison between two dimension elements. 

419 

420 The default implementation updates the set of known temporal 

421 connections (for use by `compute_automatic_temporal_joins`) and returns 

422 `None`. 

423 

424 Parameters 

425 ---------- 

426 a : `DimensionElement` 

427 One element in the join. 

428 b : `DimensionElement` 

429 The other element in the join. 

430 flags : `tree.PredicateLeafFlags` 

431 Information about where this overlap comparison appears in the 

432 larger predicate tree. 

433 

434 Returns 

435 ------- 

436 replaced : `tree.Predicate` or `None` 

437 The predicate to be inserted instead in the processed tree, or 

438 `None` if no substitution is needed. 

439 """ 

440 if a.temporal == b.temporal: 

441 raise tree.InvalidQueryError(f"Temporal join between {a} and {b} is not necessary.") 

442 self._temporal_connections.merge( 

443 cast(TopologicalFamily, a.temporal), cast(TopologicalFamily, b.temporal) 

444 ) 

445 return None