Coverage for python/lsst/daf/butler/queries/overlaps.py: 20%

115 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 02:51 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Only the visitor is public; the disjoint-set helper below is private.
__all__ = ("OverlapsVisitor",)

31 

32import itertools 

33from collections.abc import Hashable, Iterable, Sequence, Set 

34from typing import Generic, Literal, TypeVar, cast 

35 

36from lsst.sphgeom import Region 

37 

38from .._exceptions import InvalidQueryError 

39from .._topology import TopologicalFamily 

40from ..dimensions import DimensionElement, DimensionGroup 

41from . import tree 

42from .visitors import PredicateVisitFlags, SimplePredicateVisitor 

43 

44_T = TypeVar("_T", bound=Hashable) 

45 

46 

47class _NaiveDisjointSet(Generic[_T]): 

48 """A very naive (but simple) implementation of a "disjoint set" data 

49 structure for strings, with mostly O(N) performance. 

50 

51 This class should not be used in any context where the number of elements 

52 in the data structure is large. It intentionally implements a subset of 

53 the interface of `scipy.cluster.DisJointSet` so that non-naive 

54 implementation could be swapped in if desired. 

55 

56 Parameters 

57 ---------- 

58 superset : `~collections.abc.Iterable` [ `str` ] 

59 Elements to initialize the disjoint set, with each in its own 

60 single-element subset. 

61 """ 

62 

63 def __init__(self, superset: Iterable[_T]): 

64 self._subsets = [{k} for k in superset] 

65 self._subsets.sort(key=len, reverse=True) 

66 

67 def merge(self, a: _T, b: _T) -> bool: # numpydoc ignore=PR04 

68 """Merge the subsets containing the given elements. 

69 

70 Parameters 

71 ---------- 

72 a : 

73 Element whose subset should be merged. 

74 b : 

75 Element whose subset should be merged. 

76 

77 Returns 

78 ------- 

79 merged : `bool` 

80 `True` if a merge occurred, `False` if the elements were already in 

81 the same subset. 

82 """ 

83 for i, subset in enumerate(self._subsets): 

84 if a in subset: 

85 break 

86 else: 

87 raise KeyError(f"Merge argument {a!r} not in disjoin set {self._subsets}.") 

88 for j, subset in enumerate(self._subsets): 

89 if b in subset: 

90 break 

91 else: 

92 raise KeyError(f"Merge argument {b!r} not in disjoin set {self._subsets}.") 

93 if i == j: 

94 return False 

95 i, j = sorted((i, j)) 

96 self._subsets[i].update(self._subsets[j]) 

97 del self._subsets[j] 

98 self._subsets.sort(key=len, reverse=True) 

99 return True 

100 

101 def subsets(self) -> Sequence[Set[_T]]: 

102 """Return the current subsets, ordered from largest to smallest.""" 

103 return self._subsets 

104 

105 @property 

106 def n_subsets(self) -> int: 

107 """The number of subsets.""" 

108 return len(self._subsets) 

109 

110 

111class OverlapsVisitor(SimplePredicateVisitor): 

112 """A helper class for dealing with spatial and temporal overlaps in a 

113 query. 

114 

115 Parameters 

116 ---------- 

117 dimensions : `DimensionGroup` 

118 Dimensions of the query. 

119 

120 Notes 

121 ----- 

122 This class includes logic for extracting explicit spatial and temporal 

123 joins from a WHERE-clause predicate and computing automatic joins given the 

124 dimensions of the query. It is designed to be subclassed by query driver 

125 implementations that want to rewrite the predicate at the same time. 

126 """ 

127 

128 def __init__(self, dimensions: DimensionGroup): 

129 self.dimensions = dimensions 

130 self._spatial_connections = _NaiveDisjointSet(self.dimensions.spatial) 

131 self._temporal_connections = _NaiveDisjointSet(self.dimensions.temporal) 

132 

133 def run(self, predicate: tree.Predicate, join_operands: Iterable[DimensionGroup]) -> tree.Predicate: 

134 """Process the given predicate to extract spatial and temporal 

135 overlaps. 

136 

137 Parameters 

138 ---------- 

139 predicate : `tree.Predicate` 

140 Predicate to process. 

141 join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ] 

142 The dimensions of logical tables being joined into this query; 

143 these can included embedded spatial and temporal joins that can 

144 make it unnecessary to add new ones. 

145 

146 Returns 

147 ------- 

148 predicate : `tree.Predicate` 

149 A possibly-modified predicate that should replace the original. 

150 """ 

151 result = predicate.visit(self) 

152 if result is None: 

153 result = predicate 

154 for join_operand_dimensions in join_operands: 

155 self.add_join_operand_connections(join_operand_dimensions) 

156 for a, b in self.compute_automatic_spatial_joins(): 

157 join_predicate = self.visit_spatial_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) 

158 if join_predicate is None: 

159 join_predicate = tree.Predicate.compare( 

160 tree.DimensionFieldReference.model_construct(element=a, field="region"), 

161 "overlaps", 

162 tree.DimensionFieldReference.model_construct(element=b, field="region"), 

163 ) 

164 result = result.logical_and(join_predicate) 

165 for a, b in self.compute_automatic_temporal_joins(): 

166 join_predicate = self.visit_temporal_dimension_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) 

167 if join_predicate is None: 

168 join_predicate = tree.Predicate.compare( 

169 tree.DimensionFieldReference.model_construct(element=a, field="timespan"), 

170 "overlaps", 

171 tree.DimensionFieldReference.model_construct(element=b, field="timespan"), 

172 ) 

173 result = result.logical_and(join_predicate) 

174 return result 

175 

176 def visit_comparison( 

177 self, 

178 a: tree.ColumnExpression, 

179 operator: tree.ComparisonOperator, 

180 b: tree.ColumnExpression, 

181 flags: PredicateVisitFlags, 

182 ) -> tree.Predicate | None: 

183 # Docstring inherited. 

184 if operator == "overlaps": 

185 if a.column_type == "region": 

186 return self.visit_spatial_overlap(a, b, flags) 

187 elif b.column_type == "timespan": 

188 return self.visit_temporal_overlap(a, b, flags) 

189 else: 

190 raise AssertionError(f"Unexpected column type {a.column_type} for overlap.") 

191 return None 

192 

193 def add_join_operand_connections(self, operand_dimensions: DimensionGroup) -> None: 

194 """Add overlap connections implied by a table or subquery. 

195 

196 Parameters 

197 ---------- 

198 operand_dimensions : `DimensionGroup` 

199 Dimensions of of the table or subquery. 

200 

201 Notes 

202 ----- 

203 We assume each join operand to a `tree.Select` has its own 

204 complete set of spatial and temporal joins that went into generating 

205 its rows. That will naturally be true for relations originating from 

206 the butler database, like dataset searches and materializations, and if 

207 it isn't true for a data ID upload, that would represent an intentional 

208 association between non-overlapping things that we'd want to respect by 

209 *not* adding a more restrictive automatic join. 

210 """ 

211 for a_family, b_family in itertools.pairwise(operand_dimensions.spatial): 

212 self._spatial_connections.merge(a_family, b_family) 

213 for a_family, b_family in itertools.pairwise(operand_dimensions.temporal): 

214 self._temporal_connections.merge(a_family, b_family) 

215 

216 def compute_automatic_spatial_joins(self) -> list[tuple[DimensionElement, DimensionElement]]: 

217 """Return pairs of dimension elements that should be spatially joined. 

218 

219 Returns 

220 ------- 

221 joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ] 

222 Automatic joins. 

223 

224 Notes 

225 ----- 

226 This takes into account explicit joins extracted by `run` and implicit 

227 joins added by `add_join_operand_connections`, and only returns 

228 additional joins if there is an unambiguous way to spatially connect 

229 any dimensions that are not already spatially connected. Automatic 

230 joins are always the most fine-grained join between sets of dimensions 

231 (i.e. ``visit_detector_region`` and ``patch`` instead of ``visit`` and 

232 ``tract``), but explicitly adding a coarser join between sets of 

233 elements will prevent the fine-grained join from being added. 

234 """ 

235 return self._compute_automatic_joins("spatial", self._spatial_connections) 

236 

237 def compute_automatic_temporal_joins(self) -> list[tuple[DimensionElement, DimensionElement]]: 

238 """Return pairs of dimension elements that should be spatially joined. 

239 

240 Returns 

241 ------- 

242 joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ] 

243 Automatic joins. 

244 

245 Notes 

246 ----- 

247 See `compute_automatic_spatial_joins` for information on how automatic 

248 joins are determined. Joins to dataset validity ranges are never 

249 automatic. 

250 """ 

251 return self._compute_automatic_joins("temporal", self._temporal_connections) 

252 

253 def _compute_automatic_joins( 

254 self, kind: Literal["spatial", "temporal"], connections: _NaiveDisjointSet[TopologicalFamily] 

255 ) -> list[tuple[DimensionElement, DimensionElement]]: 

256 if connections.n_subsets <= 1: 

257 # All of the joins we need are already present. 

258 return [] 

259 if connections.n_subsets > 2: 

260 raise InvalidQueryError( 

261 f"Too many disconnected sets of {kind} families for an automatic " 

262 f"join: {connections.subsets()}. Add explicit {kind} joins to avoid this error." 

263 ) 

264 a_subset, b_subset = connections.subsets() 

265 if len(a_subset) > 1 or len(b_subset) > 1: 

266 raise InvalidQueryError( 

267 f"A {kind} join is needed between {a_subset} and {b_subset}, but which join to " 

268 "add is ambiguous. Add an explicit spatial join to avoid this error." 

269 ) 

270 # We have a pair of families that are not explicitly or implicitly 

271 # connected to any other families; add an automatic join between their 

272 # most fine-grained members. 

273 (a_family,) = a_subset 

274 (b_family,) = b_subset 

275 return [ 

276 ( 

277 cast(DimensionElement, a_family.choose(self.dimensions.elements, self.dimensions.universe)), 

278 cast(DimensionElement, b_family.choose(self.dimensions.elements, self.dimensions.universe)), 

279 ) 

280 ] 

281 

282 def visit_spatial_overlap( 

283 self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags 

284 ) -> tree.Predicate | None: 

285 """Dispatch a spatial overlap comparison predicate to handlers. 

286 

287 This method should rarely (if ever) need to be overridden. 

288 

289 Parameters 

290 ---------- 

291 a : `tree.ColumnExpression` 

292 First operand. 

293 b : `tree.ColumnExpression` 

294 Second operand. 

295 flags : `tree.PredicateLeafFlags` 

296 Information about where this overlap comparison appears in the 

297 larger predicate tree. 

298 

299 Returns 

300 ------- 

301 replaced : `tree.Predicate` or `None` 

302 The predicate to be inserted instead in the processed tree, or 

303 `None` if no substitution is needed. 

304 """ 

305 match a, b: 

306 case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference( 

307 element=b_element 

308 ): 

309 return self.visit_spatial_join(a_element, b_element, flags) 

310 case tree.DimensionFieldReference(element=element), region_expression: 

311 pass 

312 case region_expression, tree.DimensionFieldReference(element=element): 

313 pass 

314 case _: 

315 raise AssertionError(f"Unexpected arguments for spatial overlap: {a}, {b}.") 

316 if region := region_expression.get_literal_value(): 

317 return self.visit_spatial_constraint(element, region, flags) 

318 raise AssertionError(f"Unexpected argument for spatial overlap: {region_expression}.") 

319 

320 def visit_temporal_overlap( 

321 self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags 

322 ) -> tree.Predicate | None: 

323 """Dispatch a temporal overlap comparison predicate to handlers. 

324 

325 This method should rarely (if ever) need to be overridden. 

326 

327 Parameters 

328 ---------- 

329 a : `tree.ColumnExpression`- 

330 First operand. 

331 b : `tree.ColumnExpression` 

332 Second operand. 

333 flags : `tree.PredicateLeafFlags` 

334 Information about where this overlap comparison appears in the 

335 larger predicate tree. 

336 

337 Returns 

338 ------- 

339 replaced : `tree.Predicate` or `None` 

340 The predicate to be inserted instead in the processed tree, or 

341 `None` if no substitution is needed. 

342 """ 

343 match a, b: 

344 case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference( 

345 element=b_element 

346 ): 

347 return self.visit_temporal_dimension_join(a_element, b_element, flags) 

348 case _: 

349 # We don't bother differentiating any other kind of temporal 

350 # comparison, because in all foreseeable database schemas we 

351 # wouldn't have to do anything special with them, since they 

352 # don't participate in automatic join calculations and they 

353 # should be straightforwardly convertible to SQL. 

354 return None 

355 

356 def visit_spatial_join( 

357 self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags 

358 ) -> tree.Predicate | None: 

359 """Handle a spatial overlap comparison between two dimension elements. 

360 

361 The default implementation updates the set of known spatial connections 

362 (for use by `compute_automatic_spatial_joins`) and returns `None`. 

363 

364 Parameters 

365 ---------- 

366 a : `DimensionElement` 

367 One element in the join. 

368 b : `DimensionElement` 

369 The other element in the join. 

370 flags : `tree.PredicateLeafFlags` 

371 Information about where this overlap comparison appears in the 

372 larger predicate tree. 

373 

374 Returns 

375 ------- 

376 replaced : `tree.Predicate` or `None` 

377 The predicate to be inserted instead in the processed tree, or 

378 `None` if no substitution is needed. 

379 """ 

380 if a.spatial == b.spatial: 

381 raise InvalidQueryError(f"Spatial join between {a} and {b} is not necessary.") 

382 self._spatial_connections.merge( 

383 cast(TopologicalFamily, a.spatial), cast(TopologicalFamily, b.spatial) 

384 ) 

385 return None 

386 

387 def visit_spatial_constraint( 

388 self, 

389 element: DimensionElement, 

390 region: Region, 

391 flags: PredicateVisitFlags, 

392 ) -> tree.Predicate | None: 

393 """Handle a spatial overlap comparison between a dimension element and 

394 a literal region. 

395 

396 The default implementation just returns `None`. 

397 

398 Parameters 

399 ---------- 

400 element : `DimensionElement` 

401 The dimension element in the comparison. 

402 region : `lsst.sphgeom.Region` 

403 The literal region in the comparison. 

404 flags : `tree.PredicateLeafFlags` 

405 Information about where this overlap comparison appears in the 

406 larger predicate tree. 

407 

408 Returns 

409 ------- 

410 replaced : `tree.Predicate` or `None` 

411 The predicate to be inserted instead in the processed tree, or 

412 `None` if no substitution is needed. 

413 """ 

414 return None 

415 

416 def visit_temporal_dimension_join( 

417 self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags 

418 ) -> tree.Predicate | None: 

419 """Handle a temporal overlap comparison between two dimension elements. 

420 

421 The default implementation updates the set of known temporal 

422 connections (for use by `compute_automatic_temporal_joins`) and returns 

423 `None`. 

424 

425 Parameters 

426 ---------- 

427 a : `DimensionElement` 

428 One element in the join. 

429 b : `DimensionElement` 

430 The other element in the join. 

431 flags : `tree.PredicateLeafFlags` 

432 Information about where this overlap comparison appears in the 

433 larger predicate tree. 

434 

435 Returns 

436 ------- 

437 replaced : `tree.Predicate` or `None` 

438 The predicate to be inserted instead in the processed tree, or 

439 `None` if no substitution is needed. 

440 """ 

441 if a.temporal == b.temporal: 

442 raise InvalidQueryError(f"Temporal join between {a} and {b} is not necessary.") 

443 self._temporal_connections.merge( 

444 cast(TopologicalFamily, a.temporal), cast(TopologicalFamily, b.temporal) 

445 ) 

446 return None