Coverage for python / lsst / daf / butler / direct_query_driver / _postprocessing.py: 29%

71 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-28 08:36 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

27 

from __future__ import annotations

__all__ = ("Postprocessing",)

from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING, ClassVar

import sqlalchemy

from lsst.sphgeom import Region

from .._exceptions import CalibrationLookupError
from ..queries import tree as qt

if TYPE_CHECKING:
    # Imported only for annotations to avoid an import cycle at runtime.
    from ..dimensions import DimensionElement

45 

class Postprocessing:
    """A helper object that filters and checks SQL-query result rows to
    perform operations we can't [fully] perform in the SQL query.

    Notes
    -----
    Postprocessing objects are initialized with no parameters to do nothing
    when applied; they are modified as needed in place as the query is built.

    Postprocessing objects evaluate to `True` in a boolean context only when
    they might perform actual row filtering.  They may still perform checks
    (see `check_validity_match_count`) when they evaluate to `False`.
    """

    def __init__(self) -> None:
        self.spatial_join_filtering = []
        self.spatial_where_filtering = []
        self.spatial_expression_filtering = []
        self.check_validity_match_count: bool = False
        self._limit: int | None = None

    VALIDITY_MATCH_COUNT: ClassVar[str] = "_VALIDITY_MATCH_COUNT"
    """The field name used for the special result column that holds the number
    of matching find-first calibration datasets for each data ID.

    When present, the value of this column must be one for all rows.
    """

    spatial_join_filtering: list[tuple[DimensionElement, DimensionElement]]
    """Pairs of dimension elements whose regions must overlap; rows with
    any non-overlap pair will be filtered out.
    """

    spatial_where_filtering: list[tuple[DimensionElement, Region]]
    """Dimension elements and regions that must overlap; rows with any
    non-overlap pair will be filtered out.
    """

    spatial_expression_filtering: list[str]
    """The names of calculated columns that can be parsed by
    `lsst.sphgeom.Region.decodeOverlapsBase64` into a `bool` or `None` that
    indicates whether regions definitely overlap.
    """

    check_validity_match_count: bool
    """If `True`, result rows will include a special column that counts the
    number of matching datasets in each collection for each data ID, and
    postprocessing should check that the value of this column is one for
    every row (and raise `CalibrationLookupError` if it is not).
    """

    @property
    def limit(self) -> int | None:
        """The maximum number of rows to return, or `None` for no limit.

        This is only set when other postprocess filtering makes it impossible
        to apply directly in SQL.
        """
        return self._limit

    @limit.setter
    def limit(self, value: int | None) -> None:
        # A nonzero limit is only meaningful here when postprocessing does
        # spatial filtering; otherwise LIMIT belongs in the SQL query itself.
        if value and not self:
            raise RuntimeError(
                "Postprocessing should only implement 'limit' if it needs to do spatial filtering."
            )
        self._limit = value

    def __bool__(self) -> bool:
        # True only when this object may actually drop rows; validity-match
        # checking alone does not count (it checks, but never filters).
        return bool(
            self.spatial_join_filtering or self.spatial_where_filtering or self.spatial_expression_filtering
        )

    def gather_columns_required(self, columns: qt.ColumnSet) -> None:
        """Add all columns required to perform postprocessing to the given
        column set.

        Parameters
        ----------
        columns : `.queries.tree.ColumnSet`
            Column set to modify in place.
        """
        for element in self.iter_region_dimension_elements():
            columns.update_dimensions(element.minimal_group)
            columns.dimension_fields[element.name].add("region")

    def iter_region_dimension_elements(self) -> Iterator[DimensionElement]:
        """Iterate over the dimension elements whose regions are needed for
        postprocessing.

        Returns
        -------
        elements : `~collections.abc.Iterator` [ `DimensionElement` ]
            Iterator over dimension element objects.  Elements may be
            repeated if they participate in multiple filters.
        """
        for a, b in self.spatial_join_filtering:
            yield a
            yield b
        for element, _ in self.spatial_where_filtering:
            yield element

    def iter_missing(self, columns: qt.ColumnSet) -> Iterator[DimensionElement]:
        """Iterate over the columns needed for postprocessing that are not in
        the given `.queries.tree.ColumnSet`.

        Parameters
        ----------
        columns : `.queries.tree.ColumnSet`
            Columns that should not be returned by this method.  These are
            typically the columns included in a query even in the absence of
            postprocessing.

        Returns
        -------
        elements : `~collections.abc.Iterator` [ `DimensionElement` ]
            Iterator over dimension element objects (no duplicates).
        """
        # Deduplicate, since iter_region_dimension_elements may repeat.
        done: set[DimensionElement] = set()
        for element in self.iter_region_dimension_elements():
            if element not in done:
                if "region" not in columns.dimension_fields.get(element.name, frozenset()):
                    yield element
                done.add(element)

    def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]:
        """Apply the postprocessing to an iterable of SQL result rows.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `sqlalchemy.Row` ]
            Rows to process.

        Returns
        -------
        processed : `~collections.abc.Iterable` [ `sqlalchemy.Row` ]
            Rows that pass the postprocessing filters and checks.

        Raises
        ------
        CalibrationLookupError
            Raised if `check_validity_match_count` is set and a row's
            validity-match count exceeds one.

        Notes
        -----
        This method decreases `limit` in place if it is not `None`, such that
        the same `Postprocessing` instance can be applied to each page in a
        sequence of result pages.  This means a single `Postprocessing` object
        can only be used for a single SQL query, and should be discarded when
        iteration over the results of that query is complete.
        """
        if not (self or self.check_validity_match_count):
            # Nothing to filter or check: pass rows through untouched.
            yield from rows
            return
        if self._limit == 0:
            # Limit already exhausted by a previous page.
            return
        # Hoist qualified-name construction out of the row loop.
        joins = [
            (
                qt.ColumnSet.get_qualified_name(a.name, "region"),
                qt.ColumnSet.get_qualified_name(b.name, "region"),
            )
            for a, b in self.spatial_join_filtering
        ]
        where = [
            (qt.ColumnSet.get_qualified_name(element.name, "region"), region)
            for element, region in self.spatial_where_filtering
        ]

        for row in rows:
            m = row._mapping
            # Skip rows where at least one pair of regions is known not to
            # overlap.  `overlaps(...) is False` and `decode... is False`
            # deliberately treat `None` ("unknown") as "keep the row".
            if (
                any(Region.decodeOverlapsBase64(m[c]) is False for c in self.spatial_expression_filtering)
                or any(m[a].overlaps(m[b]) is False for a, b in joins)
                or any(m[field].overlaps(region) is False for field, region in where)
            ):
                continue
            if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1:
                raise CalibrationLookupError(
                    "Ambiguous calibration validity range match. This usually means a temporal join or "
                    "'where' needs to be added, but it could also mean that multiple validity ranges "
                    "overlap a single output data ID."
                )
            yield row
            if self._limit is not None:
                self._limit -= 1
                if self._limit == 0:
                    return