Coverage for python/lsst/daf/butler/registry/queries/butler_sql_engine.py: 29%

79 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-01 11:00 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ("ButlerSqlEngine",)

import dataclasses
from collections.abc import Iterable, Set
from typing import Any, cast

import astropy.time
import sqlalchemy
from lsst.daf.relation import ColumnTag, Relation, Sort, UnaryOperation, UnaryOperationRelation, sql

from ... import ddl
from ..._column_tags import is_timespan_column
from ..._column_type_info import ColumnTypeInfo, LogicalColumn
from ..._timespan import Timespan, TimespanDatabaseRepresentation
from ..nameShrinker import NameShrinker
from .find_first_dataset import FindFirstDataset

46 

47 

@dataclasses.dataclass(repr=False, eq=False, kw_only=True)
class ButlerSqlEngine(sql.Engine[LogicalColumn]):
    """An extension of the `lsst.daf.relation.sql.Engine` class to add timespan
    and `FindFirstDataset` operation support.
    """

    column_types: ColumnTypeInfo
    """Struct containing information about column types that depend on registry
    configuration.
    """

    name_shrinker: NameShrinker
    """Object used to shrink the SQL identifiers that represent a ColumnTag to
    fit within the database's limit (`NameShrinker`).
    """

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        # Include id() so distinct engine instances with the same name can be
        # told apart in debugging output.
        return f"ButlerSqlEngine({self.name!r})@{id(self):0x}"

    def _append_unary_to_select(self, operation: UnaryOperation, target: sql.Select) -> sql.Select:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match operation:
            case FindFirstDataset():
                if target.has_sort and not target.has_slice:
                    # Existing target is sorted, but not sliced. We want to
                    # move that sort outside (i.e. after) the FindFirstDataset,
                    # since otherwise the FindFirstDataset would put the Sort
                    # into a CTE where it will do nothing.
                    inner = target.reapply_skip(sort=Sort())
                    return sql.Select.apply_skip(operation._finish_apply(inner), sort=target.sort)
                else:
                    # Apply the FindFirstDataset directly to the existing
                    # target, which we've already asserted starts with a
                    # Select. That existing Select will be used for the CTE
                    # that starts the FindFirstDataset implementation (see
                    # to_payload override).
                    return sql.Select.apply_skip(operation._finish_apply(target))
            case _:
                # All other unary operations are handled by the base class.
                return super()._append_unary_to_select(operation, target)

    def get_identifier(self, tag: ColumnTag) -> str:
        # Docstring inherited.
        # Shrink the base class's identifier so it fits within the database's
        # identifier-length limit.
        return self.name_shrinker.shrink(super().get_identifier(tag))

    def extract_mapping(
        self, tags: Iterable[ColumnTag], sql_columns: sqlalchemy.sql.ColumnCollection
    ) -> dict[ColumnTag, LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        result: dict[ColumnTag, LogicalColumn] = {}
        for tag in tags:
            if is_timespan_column(tag):
                # A timespan is stored as multiple SQL columns; bundle them
                # into the registry's timespan database representation.
                result[tag] = self.column_types.timespan_cls.from_columns(
                    sql_columns, name=self.get_identifier(tag)
                )
            else:
                result[tag] = sql_columns[self.get_identifier(tag)]
        return result

    def select_items(
        self,
        items: Iterable[tuple[ColumnTag, LogicalColumn]],
        sql_from: sqlalchemy.sql.FromClause,
        *extra: sqlalchemy.sql.ColumnElement,
    ) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag, logical_column in items:
            if is_timespan_column(tag):
                # Expand the timespan representation back into the multiple
                # SQL columns it is stored as.
                select_columns.extend(
                    cast(TimespanDatabaseRepresentation, logical_column).flatten(
                        name=self.get_identifier(tag)
                    )
                )
            else:
                select_columns.append(
                    cast(sqlalchemy.sql.ColumnElement, logical_column).label(self.get_identifier(tag))
                )
        select_columns.extend(extra)
        # Guard against a SELECT with no columns at all (not valid SQL).
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).select_from(sql_from)

    def make_zero_select(self, tags: Set[ColumnTag]) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag in tags:
            if is_timespan_column(tag):
                # NULL timespan, flattened into its component SQL columns.
                select_columns.extend(
                    self.column_types.timespan_cls.fromLiteral(None).flatten(name=self.get_identifier(tag))
                )
            else:
                select_columns.append(sqlalchemy.sql.literal(None).label(self.get_identifier(tag)))
        self.handle_empty_columns(select_columns)
        # WHERE false guarantees the select yields no rows while still
        # defining the right set of columns.
        return sqlalchemy.sql.select(*select_columns).where(sqlalchemy.sql.literal(False))

    def convert_column_literal(self, value: Any) -> LogicalColumn:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        if isinstance(value, Timespan):
            return self.column_types.timespan_cls.fromLiteral(value)
        elif isinstance(value, astropy.time.Time):
            # Times are stored as TAI nanoseconds; use the custom column type
            # so the literal is bound consistently with stored values.
            return sqlalchemy.sql.literal(value, type_=ddl.AstropyTimeNsecTai)
        else:
            return super().convert_column_literal(value)

    def to_payload(self, relation: Relation) -> sql.Payload[LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match relation:
            case UnaryOperationRelation(operation=FindFirstDataset() as operation, target=target):
                # We build a subquery of the form below to search the
                # collections in order.
                #
                # WITH {dst}_search AS (
                #     {target}
                #     ...
                # )
                # SELECT
                #     {dst}_window.*
                # FROM (
                #     SELECT
                #         {dst}_search.*,
                #         ROW_NUMBER() OVER (
                #             PARTITION BY {dst_search}.{operation.dimensions}
                #             ORDER BY {operation.rank}
                #         ) AS rownum
                #     ) {dst}_window
                # WHERE
                #     {dst}_window.rownum = 1;
                #
                # We'll start with the Common Table Expression (CTE) at the
                # top, which we mostly get from the target relation.
                search = self.to_executable(target).cte(f"{operation.rank.dataset_type}_search")
                # Now we fill out the SELECT from the CTE, and the subquery it
                # contains (at the same time, since they have the same columns,
                # aside from the special 'rownum' window-function column).
                search_columns = self.extract_mapping(target.columns, search.columns)
                partition_by = [search_columns[tag] for tag in operation.dimensions]
                rownum_column = sqlalchemy.sql.func.row_number()
                if partition_by:
                    rownum_column = rownum_column.over(
                        partition_by=partition_by, order_by=search_columns[operation.rank]
                    )
                else:
                    # No dimensions to partition on: a single global ranking.
                    rownum_column = rownum_column.over(order_by=search_columns[operation.rank])
                window = self.select_items(
                    search_columns.items(), search, rownum_column.label("rownum")
                ).subquery(f"{operation.rank.dataset_type}_window")
                # Keep only the best-ranked row per partition.
                return sql.Payload(
                    from_clause=window,
                    columns_available=self.extract_mapping(target.columns, window.columns),
                    where=[window.columns["rownum"] == 1],
                )
            case _:
                return super().to_payload(relation)