Coverage for python/lsst/daf/butler/registry/queries/butler_sql_engine.py: 29%

85 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-13 09:58 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from ... import ddl 

30 

31__all__ = ("ButlerSqlEngine",) 

32 

33import dataclasses 

34from collections.abc import Iterable, Set 

35from typing import Any, cast 

36 

37import astropy.time 

38import sqlalchemy 

39from lsst.daf.relation import ColumnTag, Relation, Sort, UnaryOperation, UnaryOperationRelation, sql 

40 

41from ..._column_tags import is_timespan_column 

42from ..._column_type_info import ColumnTypeInfo, LogicalColumn 

43from ..._timespan import Timespan 

44from ...name_shrinker import NameShrinker 

45from ...timespan_database_representation import TimespanDatabaseRepresentation 

46from .find_first_dataset import FindFirstDataset 

47 

48 

@dataclasses.dataclass(repr=False, eq=False, kw_only=True)
class ButlerSqlEngine(sql.Engine[LogicalColumn]):
    """An extension of the `lsst.daf.relation.sql.Engine` class to add timespan
    and `FindFirstDataset` operation support.

    Timespans are logically one column but are stored as one or more actual
    database columns (see `TimespanDatabaseRepresentation.from_columns` /
    `flatten` usage below), so every method that converts between `ColumnTag`
    mappings and SQLAlchemy column collections needs special handling for
    them.
    """

    column_types: ColumnTypeInfo
    """Struct containing information about column types that depend on registry
    configuration.
    """

    name_shrinker: NameShrinker
    """Object used to shrink the SQL identifiers that represent a ColumnTag to
    fit within the database's limit (`NameShrinker`).
    """

    def __str__(self) -> str:
        # NOTE(review): `name` is not declared here; presumably provided by
        # the base `sql.Engine` class — confirm there.
        return self.name

    def __repr__(self) -> str:
        # Include id() in the repr: eq=False above means instances compare by
        # identity, so the address disambiguates otherwise-identical engines.
        return f"ButlerSqlEngine({self.name!r})@{id(self):0x}"

    def _append_unary_to_select(self, operation: UnaryOperation, target: sql.Select) -> sql.Select:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match operation:
            case FindFirstDataset():
                if target.has_sort and not target.has_slice:
                    # Existing target is sorted, but not sliced. We want to
                    # move that sort outside (i.e. after) the FindFirstDataset,
                    # since otherwise the FindFirstDataset would put the Sort
                    # into a CTE where it will do nothing.
                    inner = target.reapply_skip(sort=Sort())
                    return sql.Select.apply_skip(operation._finish_apply(inner), sort=target.sort)
                else:
                    # Apply the FindFirstDataset directly to the existing
                    # target, which we've already asserted starts with a
                    # Select. That existing Select will be used for the CTE
                    # that starts the FindFirstDataset implementation (see
                    # to_payload override).
                    return sql.Select.apply_skip(operation._finish_apply(target))
            case _:
                return super()._append_unary_to_select(operation, target)

    def get_identifier(self, tag: ColumnTag) -> str:
        # Docstring inherited.
        # Shrink the base-class identifier so it fits within the database's
        # identifier-length limit.
        return self.name_shrinker.shrink(super().get_identifier(tag))

    def extract_mapping(
        self, tags: Iterable[ColumnTag], sql_columns: sqlalchemy.sql.ColumnCollection
    ) -> dict[ColumnTag, LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        result: dict[ColumnTag, LogicalColumn] = {}
        for tag in tags:
            if is_timespan_column(tag):
                # Reassemble a timespan object from the database column(s)
                # that represent it.
                result[tag] = self.column_types.timespan_cls.from_columns(
                    sql_columns, name=self.get_identifier(tag)
                )
            else:
                result[tag] = sql_columns[self.get_identifier(tag)]
        return result

    def select_items(
        self,
        items: Iterable[tuple[ColumnTag, LogicalColumn]],
        sql_from: sqlalchemy.sql.FromClause,
        *extra: sqlalchemy.sql.ColumnElement,
    ) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag, logical_column in items:
            if is_timespan_column(tag):
                # Flatten a timespan back into the database column(s) that
                # represent it.
                select_columns.extend(
                    cast(TimespanDatabaseRepresentation, logical_column).flatten(
                        name=self.get_identifier(tag)
                    )
                )
            else:
                select_columns.append(
                    cast(sqlalchemy.sql.ColumnElement, logical_column).label(self.get_identifier(tag))
                )
        select_columns.extend(extra)
        # Base-class hook that keeps the SELECT list valid when it would
        # otherwise be empty.
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).select_from(sql_from)

    def make_zero_select(self, tags: Set[ColumnTag]) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        # Builds a SELECT with the requested columns (all NULL literals) that
        # yields no rows, via the always-false WHERE clause at the end.
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag in tags:
            if is_timespan_column(tag):
                select_columns.extend(
                    self.column_types.timespan_cls.fromLiteral(None).flatten(name=self.get_identifier(tag))
                )
            else:
                select_columns.append(sqlalchemy.sql.literal(None).label(self.get_identifier(tag)))
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).where(sqlalchemy.sql.literal(False))

    def convert_column_literal(self, value: Any) -> LogicalColumn:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        if isinstance(value, Timespan):
            return self.column_types.timespan_cls.fromLiteral(value)
        elif isinstance(value, astropy.time.Time):
            # Bind astropy times with the registry's dedicated column type so
            # they are converted consistently with stored values.
            return sqlalchemy.sql.literal(value, type_=ddl.AstropyTimeNsecTai)
        else:
            return super().convert_column_literal(value)

    def to_payload(self, relation: Relation) -> sql.Payload[LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match relation:
            case UnaryOperationRelation(operation=FindFirstDataset() as operation, target=target):
                # We build a subquery of the form below to search the
                # collections in order.
                #
                # WITH {dst}_search AS (
                #     {target}
                #     ...
                # )
                # SELECT
                #     {dst}_window.*
                # FROM (
                #     SELECT
                #         {dst}_search.*,
                #         ROW_NUMBER() OVER (
                #             PARTITION BY {dst_search}.{operation.dimensions}
                #             ORDER BY {operation.rank}
                #         ) AS rownum
                #     FROM {dst}_search
                # ) {dst}_window
                # WHERE
                #     {dst}_window.rownum = 1;
                #
                # We'll start with the Common Table Expression (CTE) at the
                # top, which we mostly get from the target relation.
                search = self.to_executable(target).cte(f"{operation.rank.dataset_type}_search")
                # Now we fill out the SELECT from the CTE, and the subquery it
                # contains (at the same time, since they have the same columns,
                # aside from the special 'rownum' window-function column).
                search_columns = self.extract_mapping(target.columns, search.columns)
                partition_by = [
                    _assert_column_is_directly_usable_by_sqlalchemy(search_columns[tag])
                    for tag in operation.dimensions
                ]
                row_number = sqlalchemy.sql.func.row_number()
                rank_column = _assert_column_is_directly_usable_by_sqlalchemy(search_columns[operation.rank])
                if partition_by:
                    rownum_column = row_number.over(partition_by=partition_by, order_by=rank_column)
                else:
                    # No dimensions to partition on: one global window over
                    # all rows.
                    rownum_column = row_number.over(order_by=rank_column)
                window = self.select_items(
                    search_columns.items(), search, rownum_column.label("rownum")
                ).subquery(f"{operation.rank.dataset_type}_window")
                return sql.Payload(
                    from_clause=window,
                    columns_available=self.extract_mapping(target.columns, window.columns),
                    where=[window.columns["rownum"] == 1],
                )
            case _:
                return super().to_payload(relation)

214 

215 

def _assert_column_is_directly_usable_by_sqlalchemy(column: LogicalColumn) -> sqlalchemy.sql.ColumnElement:
    """Narrow a `LogicalColumn` down to a plain SQLAlchemy column expression.

    This exists to satisfy the type checker at call sites where timespan
    columns are not expected.

    Parameters
    ----------
    column : `LogicalColumn`
        Logical column to narrow.

    Returns
    -------
    element : `sqlalchemy.sql.ColumnElement`
        The same object, now typed as a direct SQLAlchemy column.

    Raises
    ------
    TypeError
        Raised if ``column`` is a `TimespanDatabaseRepresentation`.
    """
    if not isinstance(column, TimespanDatabaseRepresentation):
        return column
    raise TypeError("Timespans not expected here.")