Coverage for python/lsst/daf/butler/direct_query_driver/_postprocessing.py: 30%
67 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-02 10:24 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("Postprocessing",)
32from collections.abc import Iterable, Iterator
33from typing import TYPE_CHECKING, ClassVar
35import sqlalchemy
36from lsst.sphgeom import DISJOINT, Region
38from ..queries import ValidityRangeMatchError
39from ..queries import tree as qt
41if TYPE_CHECKING:
42 from ..dimensions import DimensionElement
class Postprocessing:
    """A helper object that filters and checks SQL-query result rows to
    perform operations we can't [fully] perform in the SQL query.

    Notes
    -----
    Postprocessing objects are initialized with no parameters to do nothing
    when applied; they are modified as needed in place as the query is built.

    Postprocessing objects evaluate to `True` in a boolean context only when
    they might perform actual row filtering.  They may still perform checks
    (e.g. the validity-range match count) when they evaluate to `False`.
    """

    def __init__(self) -> None:
        self.spatial_join_filtering = []
        self.spatial_where_filtering = []
        self.check_validity_match_count: bool = False
        self._limit: int | None = None

    VALIDITY_MATCH_COUNT: ClassVar[str] = "_VALIDITY_MATCH_COUNT"
    """The field name used for the special result column that holds the number
    of matching find-first calibration datasets for each data ID.

    When present, the value of this column must be one for all rows.
    """

    spatial_join_filtering: list[tuple[DimensionElement, DimensionElement]]
    """Pairs of dimension elements whose regions must overlap; rows with
    any non-overlap pair will be filtered out.
    """

    spatial_where_filtering: list[tuple[DimensionElement, Region]]
    """Dimension elements and regions that must overlap; rows with any
    non-overlap pair will be filtered out.
    """

    check_validity_match_count: bool
    """If `True`, result rows will include a special column that counts the
    number of matching datasets in each collection for each data ID, and
    postprocessing should check that the value of this column is one for
    every row (and raise `.queries.ValidityRangeMatchError` if it is not).
    """

    @property
    def limit(self) -> int | None:
        """The maximum number of rows to return, or `None` for no limit.

        This is only set when other postprocess filtering makes it impossible
        to apply directly in SQL.
        """
        return self._limit

    @limit.setter
    def limit(self, value: int | None) -> None:
        # A nonzero limit is only legal when this object actually filters
        # rows (i.e. bool(self) is True); otherwise the limit belongs in SQL.
        if value and not self:
            raise RuntimeError(
                "Postprocessing should only implement 'limit' if it needs to do spatial filtering."
            )
        self._limit = value

    def __bool__(self) -> bool:
        # True only when row filtering may actually occur; the validity-match
        # check alone does not make this object truthy.
        return bool(self.spatial_join_filtering or self.spatial_where_filtering)

    def gather_columns_required(self, columns: qt.ColumnSet) -> None:
        """Add all columns required to perform postprocessing to the given
        column set.

        Parameters
        ----------
        columns : `.queries.tree.ColumnSet`
            Column set to modify in place.
        """
        for element in self.iter_region_dimension_elements():
            columns.update_dimensions(element.minimal_group)
            columns.dimension_fields[element.name].add("region")

    def iter_region_dimension_elements(self) -> Iterator[DimensionElement]:
        """Iterate over the dimension elements whose regions are needed for
        postprocessing.

        Returns
        -------
        elements : `~collections.abc.Iterator` [ `DimensionElement` ]
            Iterator over dimension element objects.  May contain duplicates
            if an element participates in multiple filters.
        """
        for a, b in self.spatial_join_filtering:
            yield a
            yield b
        for element, _ in self.spatial_where_filtering:
            yield element

    def iter_missing(self, columns: qt.ColumnSet) -> Iterator[DimensionElement]:
        """Iterate over the columns needed for postprocessing that are not in
        the given `.queries.tree.ColumnSet`.

        Parameters
        ----------
        columns : `.queries.tree.ColumnSet`
            Columns that should not be returned by this method.  These are
            typically the columns included in a query even in the absence of
            postprocessing.

        Returns
        -------
        elements : `~collections.abc.Iterator` [ `DimensionElement` ]
            Iterator over dimension element objects, deduplicated.
        """
        done: set[DimensionElement] = set()
        for element in self.iter_region_dimension_elements():
            if element not in done:
                if "region" not in columns.dimension_fields.get(element.name, frozenset()):
                    yield element
                done.add(element)

    def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]:
        """Apply the postprocessing to an iterable of SQL result rows.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `sqlalchemy.Row` ]
            Rows to process.

        Returns
        -------
        processed : `~collections.abc.Iterable` [ `sqlalchemy.Row` ]
            Rows that pass the postprocessing filters and checks.

        Raises
        ------
        ValidityRangeMatchError
            Raised if `check_validity_match_count` is set and any row's
            validity-match-count column exceeds one.

        Notes
        -----
        This method decreases `limit` in place if it is not `None`, such that
        the same `Postprocessing` instance can be applied to each page in a
        sequence of result pages.  This means a single `Postprocessing` object
        can only be used for a single SQL query, and should be discarded when
        iteration over the results of that query is complete.
        """
        if not (self or self.check_validity_match_count):
            # Nothing to filter or check; pass rows through unchanged.  The
            # explicit 'return' is a bug fix: without it, execution fell
            # through to the loop below, and for a re-iterable 'rows' (such
            # as a list) every row was yielded a second time.
            yield from rows
            return
        if self._limit == 0:
            return
        # Precompute the qualified column names whose regions participate in
        # join-overlap tests (pairs of columns) ...
        joins = [
            (
                qt.ColumnSet.get_qualified_name(a.name, "region"),
                qt.ColumnSet.get_qualified_name(b.name, "region"),
            )
            for a, b in self.spatial_join_filtering
        ]
        # ... and in where-overlap tests (column vs. fixed region).
        where = [
            (qt.ColumnSet.get_qualified_name(element.name, "region"), region)
            for element, region in self.spatial_where_filtering
        ]
        for row in rows:
            m = row._mapping
            # Drop the row if any required pair of regions is disjoint
            # (sphgeom 'relate' returns a relationship bitmask).
            if any(m[a].relate(m[b]) & DISJOINT for a, b in joins) or any(
                m[field].relate(region) & DISJOINT for field, region in where
            ):
                continue
            if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1:
                raise ValidityRangeMatchError(
                    "Ambiguous calibration validity range match. This usually means a temporal join or "
                    "'where' needs to be added, but it could also mean that multiple validity ranges "
                    "overlap a single output data ID."
                )
            yield row
            # Count down the remaining quota in place so the next page picks
            # up where this one left off; stop once it hits zero.
            if self._limit is not None:
                self._limit -= 1
                if self._limit == 0:
                    return