Coverage for python/lsst/daf/butler/direct_query_driver/_postprocessing.py: 30%

67 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-10 10:14 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("Postprocessing",) 

31 

32from collections.abc import Iterable, Iterator 

33from typing import TYPE_CHECKING, ClassVar 

34 

35import sqlalchemy 

36from lsst.sphgeom import DISJOINT, Region 

37 

38from ..queries import ValidityRangeMatchError 

39from ..queries import tree as qt 

40 

41if TYPE_CHECKING: 

42 from ..dimensions import DimensionElement 

43 

44 

class Postprocessing:
    """A helper object that filters and checks SQL-query result rows to perform
    operations we can't [fully] perform in the SQL query.

    Notes
    -----
    Postprocessing objects are initialized with no parameters to do nothing
    when applied; they are modified as needed in place as the query is built.

    Postprocessing objects evaluate to `True` in a boolean context only when
    they might perform actual row filtering.  They may still perform checks
    when they evaluate to `False`.
    """

    def __init__(self) -> None:
        self.spatial_join_filtering = []
        self.spatial_where_filtering = []
        self.check_validity_match_count: bool = False
        self._limit: int | None = None

    VALIDITY_MATCH_COUNT: ClassVar[str] = "_VALIDITY_MATCH_COUNT"
    """The field name used for the special result column that holds the number
    of matching find-first calibration datasets for each data ID.

    When present, the value of this column must be one for all rows.
    """

    spatial_join_filtering: list[tuple[DimensionElement, DimensionElement]]
    """Pairs of dimension elements whose regions must overlap; rows with
    any non-overlap pair will be filtered out.
    """

    spatial_where_filtering: list[tuple[DimensionElement, Region]]
    """Dimension elements and regions that must overlap; rows with any
    non-overlap pair will be filtered out.
    """

    check_validity_match_count: bool
    """If `True`, result rows will include a special column that counts the
    number of matching datasets in each collection for each data ID, and
    postprocessing should check that the value of this column is one for
    every row (and raise `.queries.ValidityRangeMatchError` if it is not).
    """

    @property
    def limit(self) -> int | None:
        """The maximum number of rows to return, or `None` for no limit.

        This is only set when other postprocess filtering makes it impossible
        to apply directly in SQL.

        Setting a nonzero limit on an object that evaluates to `False` (no
        spatial filtering configured) raises `RuntimeError`.
        """
        return self._limit

    @limit.setter
    def limit(self, value: int | None) -> None:
        if value and not self:
            raise RuntimeError(
                "Postprocessing should only implement 'limit' if it needs to do spatial filtering."
            )
        self._limit = value

    def __bool__(self) -> bool:
        # Only spatial filtering counts as "doing something" here; the
        # validity-match-count check is deliberately not included.
        return bool(self.spatial_join_filtering or self.spatial_where_filtering)

    def gather_columns_required(self, columns: qt.ColumnSet) -> None:
        """Add all columns required to perform postprocessing to the given
        column set.

        Parameters
        ----------
        columns : `.queries.tree.ColumnSet`
            Column set to modify in place.
        """
        for element in self.iter_region_dimension_elements():
            columns.update_dimensions(element.minimal_group)
            columns.dimension_fields[element.name].add("region")

    def iter_region_dimension_elements(self) -> Iterator[DimensionElement]:
        """Iterate over the dimension elements whose regions are needed for
        postprocessing.

        Returns
        -------
        elements : `~collections.abc.Iterator` [ `DimensionElement` ]
            Iterator over dimension element objects.  Elements are not
            deduplicated: one that appears in several filters is yielded
            once per appearance.
        """
        for a, b in self.spatial_join_filtering:
            yield a
            yield b
        for element, _ in self.spatial_where_filtering:
            yield element

    def iter_missing(self, columns: qt.ColumnSet) -> Iterator[DimensionElement]:
        """Iterate over the columns needed for postprocessing that are not in
        the given `.queries.tree.ColumnSet`.

        Parameters
        ----------
        columns : `.queries.tree.ColumnSet`
            Columns that should not be returned by this method.  These are
            typically the columns included in a query even in the absence of
            postprocessing.

        Returns
        -------
        elements : `~collections.abc.Iterator` [ `DimensionElement` ]
            Iterator over dimension element objects; each element is yielded
            at most once.
        """
        done: set[DimensionElement] = set()
        for element in self.iter_region_dimension_elements():
            if element not in done:
                if "region" not in columns.dimension_fields.get(element.name, frozenset()):
                    yield element
                done.add(element)

    def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]:
        """Apply the postprocessing to an iterable of SQL result rows.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `sqlalchemy.Row` ]
            Rows to process.

        Returns
        -------
        processed : `~collections.abc.Iterable` [ `sqlalchemy.Row` ]
            Rows that pass the postprocessing filters and checks.

        Raises
        ------
        ValidityRangeMatchError
            Raised if `check_validity_match_count` is `True` and a row that
            passes the spatial filters has a match count greater than one.

        Notes
        -----
        This method decreases `limit` in place if it is not `None`, such that
        the same `Postprocessing` instance can be applied to each page in a
        sequence of result pages.  This means a single `Postprocessing` object
        can only be used for a single SQL query, and should be discarded when
        iteration over the results of that query is complete.
        """
        if not (self or self.check_validity_match_count):
            # Nothing to filter or check: pass the rows through untouched.
            yield from rows
            # Stop here; falling through would iterate ``rows`` a second time
            # and emit every row twice for re-iterable inputs such as lists.
            return
        if self._limit == 0:
            return
        # Resolve the qualified region column names once, outside the row loop.
        joins = [
            (
                qt.ColumnSet.get_qualified_name(a.name, "region"),
                qt.ColumnSet.get_qualified_name(b.name, "region"),
            )
            for a, b in self.spatial_join_filtering
        ]
        where = [
            (qt.ColumnSet.get_qualified_name(element.name, "region"), region)
            for element, region in self.spatial_where_filtering
        ]

        for row in rows:
            m = row._mapping
            # Drop the row if any required pair of regions is disjoint.
            if any(m[a].relate(m[b]) & DISJOINT for a, b in joins) or any(
                m[field].relate(region) & DISJOINT for field, region in where
            ):
                continue
            if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1:
                raise ValidityRangeMatchError(
                    "Ambiguous calibration validity range match. This usually means a temporal join or "
                    "'where' needs to be added, but it could also mean that multiple validity ranges "
                    "overlap a single output data ID."
                )
            yield row
            if self._limit is not None:
                # Only rows actually yielded count against the limit.
                self._limit -= 1
                if self._limit == 0:
                    return

211 self._limit -= 1 

212 if self._limit == 0: 

213 return