Coverage for python/lsst/daf/butler/queries/overlaps.py: 20%
115 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-26 02:47 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-26 02:47 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("OverlapsVisitor",)
32import itertools
33from collections.abc import Hashable, Iterable, Sequence, Set
34from typing import Generic, Literal, TypeVar, cast
36from lsst.sphgeom import Region
38from .._exceptions import InvalidQueryError
39from .._topology import TopologicalFamily
40from ..dimensions import DimensionElement, DimensionGroup
41from . import tree
42from .visitors import PredicateVisitFlags, SimplePredicateVisitor
_T = TypeVar("_T", bound=Hashable)


class _NaiveDisjointSet(Generic[_T]):
    """A very naive (but simple) implementation of a "disjoint set" data
    structure for hashable elements, with mostly O(N) performance.

    This class should not be used in any context where the number of elements
    in the data structure is large.  It intentionally implements a subset of
    the interface of `scipy.cluster.hierarchy.DisjointSet` so that non-naive
    implementation could be swapped in if desired.

    Parameters
    ----------
    superset : `~collections.abc.Iterable` [ `~collections.abc.Hashable` ]
        Elements to initialize the disjoint set, with each in its own
        single-element subset.  Repeated elements are only added once.
    """

    def __init__(self, superset: Iterable[_T]):
        # ``dict.fromkeys`` drops repeated elements while keeping first-seen
        # order, so the subsets really are disjoint from the start.  Every
        # subset has exactly one element here, so there is nothing to sort.
        self._subsets: list[set[_T]] = [{k} for k in dict.fromkeys(superset)]

    def _index_of(self, element: _T) -> int:
        """Return the position of the subset that holds ``element``.

        Parameters
        ----------
        element
            Element to look up.

        Returns
        -------
        index : `int`
            Index into the internal list of subsets.

        Raises
        ------
        KeyError
            Raised if no subset contains ``element``.
        """
        for index, subset in enumerate(self._subsets):
            if element in subset:
                return index
        raise KeyError(f"Merge argument {element!r} not in disjoint set {self._subsets}.")

    def merge(self, a: _T, b: _T) -> bool:  # numpydoc ignore=PR04
        """Merge the subsets containing the given elements.

        Parameters
        ----------
        a :
            Element whose subset should be merged.
        b :
            Element whose subset should be merged.

        Returns
        -------
        merged : `bool`
            `True` if a merge occurred, `False` if the elements were already in
            the same subset.

        Raises
        ------
        KeyError
            Raised if either element is not in any subset (``a`` is checked
            first).
        """
        i = self._index_of(a)
        j = self._index_of(b)
        if i == j:
            return False
        # Fold the later subset into the earlier one so that deleting the
        # absorbed subset cannot shift the index of the one we keep.
        keep, drop = sorted((i, j))
        self._subsets[keep].update(self._subsets[drop])
        del self._subsets[drop]
        # Restore the largest-to-smallest ordering promised by `subsets`.
        self._subsets.sort(key=len, reverse=True)
        return True

    def subsets(self) -> Sequence[Set[_T]]:
        """Return the current subsets, ordered from largest to smallest.

        Returns
        -------
        subsets : `~collections.abc.Sequence` [ `~collections.abc.Set` ]
            The internal list of subsets (not a copy); callers should treat
            it as read-only.
        """
        return self._subsets

    @property
    def n_subsets(self) -> int:
        """The number of subsets."""
        return len(self._subsets)
111class OverlapsVisitor(SimplePredicateVisitor):
112 """A helper class for dealing with spatial and temporal overlaps in a
113 query.
115 Parameters
116 ----------
117 dimensions : `DimensionGroup`
118 Dimensions of the query.
120 Notes
121 -----
122 This class includes logic for extracting explicit spatial and temporal
123 joins from a WHERE-clause predicate and computing automatic joins given the
124 dimensions of the query. It is designed to be subclassed by query driver
125 implementations that want to rewrite the predicate at the same time.
126 """
128 def __init__(self, dimensions: DimensionGroup):
129 self.dimensions = dimensions
130 self._spatial_connections = _NaiveDisjointSet(self.dimensions.spatial)
131 self._temporal_connections = _NaiveDisjointSet(self.dimensions.temporal)
133 def run(self, predicate: tree.Predicate, join_operands: Iterable[DimensionGroup]) -> tree.Predicate:
134 """Process the given predicate to extract spatial and temporal
135 overlaps.
137 Parameters
138 ----------
139 predicate : `tree.Predicate`
140 Predicate to process.
141 join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ]
142 The dimensions of logical tables being joined into this query;
143 these can included embedded spatial and temporal joins that can
144 make it unnecessary to add new ones.
146 Returns
147 -------
148 predicate : `tree.Predicate`
149 A possibly-modified predicate that should replace the original.
150 """
151 result = predicate.visit(self)
152 if result is None:
153 result = predicate
154 for join_operand_dimensions in join_operands:
155 self.add_join_operand_connections(join_operand_dimensions)
156 for a, b in self.compute_automatic_spatial_joins():
157 join_predicate = self.visit_spatial_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS)
158 if join_predicate is None:
159 join_predicate = tree.Predicate.compare(
160 tree.DimensionFieldReference.model_construct(element=a, field="region"),
161 "overlaps",
162 tree.DimensionFieldReference.model_construct(element=b, field="region"),
163 )
164 result = result.logical_and(join_predicate)
165 for a, b in self.compute_automatic_temporal_joins():
166 join_predicate = self.visit_temporal_dimension_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS)
167 if join_predicate is None:
168 join_predicate = tree.Predicate.compare(
169 tree.DimensionFieldReference.model_construct(element=a, field="timespan"),
170 "overlaps",
171 tree.DimensionFieldReference.model_construct(element=b, field="timespan"),
172 )
173 result = result.logical_and(join_predicate)
174 return result
176 def visit_comparison(
177 self,
178 a: tree.ColumnExpression,
179 operator: tree.ComparisonOperator,
180 b: tree.ColumnExpression,
181 flags: PredicateVisitFlags,
182 ) -> tree.Predicate | None:
183 # Docstring inherited.
184 if operator == "overlaps":
185 if a.column_type == "region":
186 return self.visit_spatial_overlap(a, b, flags)
187 elif b.column_type == "timespan":
188 return self.visit_temporal_overlap(a, b, flags)
189 else:
190 raise AssertionError(f"Unexpected column type {a.column_type} for overlap.")
191 return None
193 def add_join_operand_connections(self, operand_dimensions: DimensionGroup) -> None:
194 """Add overlap connections implied by a table or subquery.
196 Parameters
197 ----------
198 operand_dimensions : `DimensionGroup`
199 Dimensions of of the table or subquery.
201 Notes
202 -----
203 We assume each join operand to a `tree.Select` has its own
204 complete set of spatial and temporal joins that went into generating
205 its rows. That will naturally be true for relations originating from
206 the butler database, like dataset searches and materializations, and if
207 it isn't true for a data ID upload, that would represent an intentional
208 association between non-overlapping things that we'd want to respect by
209 *not* adding a more restrictive automatic join.
210 """
211 for a_family, b_family in itertools.pairwise(operand_dimensions.spatial):
212 self._spatial_connections.merge(a_family, b_family)
213 for a_family, b_family in itertools.pairwise(operand_dimensions.temporal):
214 self._temporal_connections.merge(a_family, b_family)
216 def compute_automatic_spatial_joins(self) -> list[tuple[DimensionElement, DimensionElement]]:
217 """Return pairs of dimension elements that should be spatially joined.
219 Returns
220 -------
221 joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ]
222 Automatic joins.
224 Notes
225 -----
226 This takes into account explicit joins extracted by `run` and implicit
227 joins added by `add_join_operand_connections`, and only returns
228 additional joins if there is an unambiguous way to spatially connect
229 any dimensions that are not already spatially connected. Automatic
230 joins are always the most fine-grained join between sets of dimensions
231 (i.e. ``visit_detector_region`` and ``patch`` instead of ``visit`` and
232 ``tract``), but explicitly adding a coarser join between sets of
233 elements will prevent the fine-grained join from being added.
234 """
235 return self._compute_automatic_joins("spatial", self._spatial_connections)
237 def compute_automatic_temporal_joins(self) -> list[tuple[DimensionElement, DimensionElement]]:
238 """Return pairs of dimension elements that should be spatially joined.
240 Returns
241 -------
242 joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ]
243 Automatic joins.
245 Notes
246 -----
247 See `compute_automatic_spatial_joins` for information on how automatic
248 joins are determined. Joins to dataset validity ranges are never
249 automatic.
250 """
251 return self._compute_automatic_joins("temporal", self._temporal_connections)
253 def _compute_automatic_joins(
254 self, kind: Literal["spatial", "temporal"], connections: _NaiveDisjointSet[TopologicalFamily]
255 ) -> list[tuple[DimensionElement, DimensionElement]]:
256 if connections.n_subsets <= 1:
257 # All of the joins we need are already present.
258 return []
259 if connections.n_subsets > 2:
260 raise InvalidQueryError(
261 f"Too many disconnected sets of {kind} families for an automatic "
262 f"join: {connections.subsets()}. Add explicit {kind} joins to avoid this error."
263 )
264 a_subset, b_subset = connections.subsets()
265 if len(a_subset) > 1 or len(b_subset) > 1:
266 raise InvalidQueryError(
267 f"A {kind} join is needed between {a_subset} and {b_subset}, but which join to "
268 "add is ambiguous. Add an explicit spatial join to avoid this error."
269 )
270 # We have a pair of families that are not explicitly or implicitly
271 # connected to any other families; add an automatic join between their
272 # most fine-grained members.
273 (a_family,) = a_subset
274 (b_family,) = b_subset
275 return [
276 (
277 cast(DimensionElement, a_family.choose(self.dimensions.elements, self.dimensions.universe)),
278 cast(DimensionElement, b_family.choose(self.dimensions.elements, self.dimensions.universe)),
279 )
280 ]
282 def visit_spatial_overlap(
283 self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags
284 ) -> tree.Predicate | None:
285 """Dispatch a spatial overlap comparison predicate to handlers.
287 This method should rarely (if ever) need to be overridden.
289 Parameters
290 ----------
291 a : `tree.ColumnExpression`
292 First operand.
293 b : `tree.ColumnExpression`
294 Second operand.
295 flags : `tree.PredicateLeafFlags`
296 Information about where this overlap comparison appears in the
297 larger predicate tree.
299 Returns
300 -------
301 replaced : `tree.Predicate` or `None`
302 The predicate to be inserted instead in the processed tree, or
303 `None` if no substitution is needed.
304 """
305 match a, b:
306 case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference(
307 element=b_element
308 ):
309 return self.visit_spatial_join(a_element, b_element, flags)
310 case tree.DimensionFieldReference(element=element), region_expression:
311 pass
312 case region_expression, tree.DimensionFieldReference(element=element):
313 pass
314 case _:
315 raise AssertionError(f"Unexpected arguments for spatial overlap: {a}, {b}.")
316 if region := region_expression.get_literal_value():
317 return self.visit_spatial_constraint(element, region, flags)
318 raise AssertionError(f"Unexpected argument for spatial overlap: {region_expression}.")
320 def visit_temporal_overlap(
321 self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags
322 ) -> tree.Predicate | None:
323 """Dispatch a temporal overlap comparison predicate to handlers.
325 This method should rarely (if ever) need to be overridden.
327 Parameters
328 ----------
329 a : `tree.ColumnExpression`-
330 First operand.
331 b : `tree.ColumnExpression`
332 Second operand.
333 flags : `tree.PredicateLeafFlags`
334 Information about where this overlap comparison appears in the
335 larger predicate tree.
337 Returns
338 -------
339 replaced : `tree.Predicate` or `None`
340 The predicate to be inserted instead in the processed tree, or
341 `None` if no substitution is needed.
342 """
343 match a, b:
344 case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference(
345 element=b_element
346 ):
347 return self.visit_temporal_dimension_join(a_element, b_element, flags)
348 case _:
349 # We don't bother differentiating any other kind of temporal
350 # comparison, because in all foreseeable database schemas we
351 # wouldn't have to do anything special with them, since they
352 # don't participate in automatic join calculations and they
353 # should be straightforwardly convertible to SQL.
354 return None
356 def visit_spatial_join(
357 self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
358 ) -> tree.Predicate | None:
359 """Handle a spatial overlap comparison between two dimension elements.
361 The default implementation updates the set of known spatial connections
362 (for use by `compute_automatic_spatial_joins`) and returns `None`.
364 Parameters
365 ----------
366 a : `DimensionElement`
367 One element in the join.
368 b : `DimensionElement`
369 The other element in the join.
370 flags : `tree.PredicateLeafFlags`
371 Information about where this overlap comparison appears in the
372 larger predicate tree.
374 Returns
375 -------
376 replaced : `tree.Predicate` or `None`
377 The predicate to be inserted instead in the processed tree, or
378 `None` if no substitution is needed.
379 """
380 if a.spatial == b.spatial:
381 raise InvalidQueryError(f"Spatial join between {a} and {b} is not necessary.")
382 self._spatial_connections.merge(
383 cast(TopologicalFamily, a.spatial), cast(TopologicalFamily, b.spatial)
384 )
385 return None
387 def visit_spatial_constraint(
388 self,
389 element: DimensionElement,
390 region: Region,
391 flags: PredicateVisitFlags,
392 ) -> tree.Predicate | None:
393 """Handle a spatial overlap comparison between a dimension element and
394 a literal region.
396 The default implementation just returns `None`.
398 Parameters
399 ----------
400 element : `DimensionElement`
401 The dimension element in the comparison.
402 region : `lsst.sphgeom.Region`
403 The literal region in the comparison.
404 flags : `tree.PredicateLeafFlags`
405 Information about where this overlap comparison appears in the
406 larger predicate tree.
408 Returns
409 -------
410 replaced : `tree.Predicate` or `None`
411 The predicate to be inserted instead in the processed tree, or
412 `None` if no substitution is needed.
413 """
414 return None
416 def visit_temporal_dimension_join(
417 self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
418 ) -> tree.Predicate | None:
419 """Handle a temporal overlap comparison between two dimension elements.
421 The default implementation updates the set of known temporal
422 connections (for use by `compute_automatic_temporal_joins`) and returns
423 `None`.
425 Parameters
426 ----------
427 a : `DimensionElement`
428 One element in the join.
429 b : `DimensionElement`
430 The other element in the join.
431 flags : `tree.PredicateLeafFlags`
432 Information about where this overlap comparison appears in the
433 larger predicate tree.
435 Returns
436 -------
437 replaced : `tree.Predicate` or `None`
438 The predicate to be inserted instead in the processed tree, or
439 `None` if no substitution is needed.
440 """
441 if a.temporal == b.temporal:
442 raise InvalidQueryError(f"Temporal join between {a} and {b} is not necessary.")
443 self._temporal_connections.merge(
444 cast(TopologicalFamily, a.temporal), cast(TopologicalFamily, b.temporal)
445 )
446 return None