Coverage for python/lsst/daf/butler/registry/queries/butler_sql_engine.py: 27%

76 statements  

coverage.py v7.3.1, created at 2023-10-02 08:00 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("ButlerSqlEngine",)

import dataclasses
from collections.abc import Iterable, Set
from typing import Any, cast

import astropy.time
import sqlalchemy
from lsst.daf.relation import ColumnTag, Relation, Sort, UnaryOperation, UnaryOperationRelation, sql

from ...core import (
    ColumnTypeInfo,
    LogicalColumn,
    Timespan,
    TimespanDatabaseRepresentation,
    ddl,
    is_timespan_column,
)
from ..nameShrinker import NameShrinker
from .find_first_dataset import FindFirstDataset


@dataclasses.dataclass(repr=False, eq=False, kw_only=True)
class ButlerSqlEngine(sql.Engine[LogicalColumn]):
    """An extension of the `lsst.daf.relation.sql.Engine` class to add timespan
    and `FindFirstDataset` operation support.
    """

    column_types: ColumnTypeInfo
    """Struct containing information about column types that depend on registry
    configuration.
    """

    name_shrinker: NameShrinker
    """Object used to shrink the SQL identifiers that represent a ColumnTag to
    fit within the database's limit (`NameShrinker`).
    """

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"ButlerSqlEngine({self.name!r})@{id(self):0x}"

    def _append_unary_to_select(self, operation: UnaryOperation, target: sql.Select) -> sql.Select:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match operation:
            case FindFirstDataset():
                if target.has_sort and not target.has_slice:
                    # Existing target is sorted, but not sliced. We want to
                    # move that sort outside (i.e. after) the FindFirstDataset,
                    # since otherwise the FindFirstDataset would put the Sort
                    # into a CTE where it will do nothing.
                    inner = target.reapply_skip(sort=Sort())
                    return sql.Select.apply_skip(operation._finish_apply(inner), sort=target.sort)
                else:
                    # Apply the FindFirstDataset directly to the existing
                    # target, which we've already asserted starts with a
                    # Select. That existing Select will be used for the CTE
                    # that starts the FindFirstDataset implementation (see
                    # to_payload override).
                    return sql.Select.apply_skip(operation._finish_apply(target))
            case _:
                return super()._append_unary_to_select(operation, target)

    def get_identifier(self, tag: ColumnTag) -> str:
        # Docstring inherited.
        return self.name_shrinker.shrink(super().get_identifier(tag))

    def extract_mapping(
        self, tags: Iterable[ColumnTag], sql_columns: sqlalchemy.sql.ColumnCollection
    ) -> dict[ColumnTag, LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        result: dict[ColumnTag, LogicalColumn] = {}
        for tag in tags:
            if is_timespan_column(tag):
                result[tag] = self.column_types.timespan_cls.from_columns(
                    sql_columns, name=self.get_identifier(tag)
                )
            else:
                result[tag] = sql_columns[self.get_identifier(tag)]
        return result

    def select_items(
        self,
        items: Iterable[tuple[ColumnTag, LogicalColumn]],
        sql_from: sqlalchemy.sql.FromClause,
        *extra: sqlalchemy.sql.ColumnElement,
    ) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag, logical_column in items:
            if is_timespan_column(tag):
                select_columns.extend(
                    cast(TimespanDatabaseRepresentation, logical_column).flatten(
                        name=self.get_identifier(tag)
                    )
                )
            else:
                select_columns.append(
                    cast(sqlalchemy.sql.ColumnElement, logical_column).label(self.get_identifier(tag))
                )
        select_columns.extend(extra)
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).select_from(sql_from)

    def make_zero_select(self, tags: Set[ColumnTag]) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag in tags:
            if is_timespan_column(tag):
                select_columns.extend(
                    self.column_types.timespan_cls.fromLiteral(None).flatten(name=self.get_identifier(tag))
                )
            else:
                select_columns.append(sqlalchemy.sql.literal(None).label(self.get_identifier(tag)))
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).where(sqlalchemy.sql.literal(False))

    def convert_column_literal(self, value: Any) -> LogicalColumn:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        if isinstance(value, Timespan):
            return self.column_types.timespan_cls.fromLiteral(value)
        elif isinstance(value, astropy.time.Time):
            return sqlalchemy.sql.literal(value, type_=ddl.AstropyTimeNsecTai)
        else:
            return super().convert_column_literal(value)

    def to_payload(self, relation: Relation) -> sql.Payload[LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match relation:
            case UnaryOperationRelation(operation=FindFirstDataset() as operation, target=target):
                # We build a subquery of the form below to search the
                # collections in order.
                #
                # WITH {dst}_search AS (
                #     {target}
                #     ...
                # )
                # SELECT
                #     {dst}_window.*
                # FROM (
                #     SELECT
                #         {dst}_search.*,
                #         ROW_NUMBER() OVER (
                #             PARTITION BY {dst}_search.{operation.dimensions}
                #             ORDER BY {operation.rank}
                #         ) AS rownum
                #     FROM {dst}_search
                # ) {dst}_window
                # WHERE
                #     {dst}_window.rownum = 1;
                #
                # We'll start with the Common Table Expression (CTE) at the
                # top, which we mostly get from the target relation.
                search = self.to_executable(target).cte(f"{operation.rank.dataset_type}_search")
                # Now we fill out the SELECT from the CTE, and the subquery it
                # contains (at the same time, since they have the same columns,
                # aside from the special 'rownum' window-function column).
                search_columns = self.extract_mapping(target.columns, search.columns)
                partition_by = [search_columns[tag] for tag in operation.dimensions]
                rownum_column = sqlalchemy.sql.func.row_number()
                if partition_by:
                    rownum_column = rownum_column.over(
                        partition_by=partition_by, order_by=search_columns[operation.rank]
                    )
                else:
                    rownum_column = rownum_column.over(order_by=search_columns[operation.rank])
                window = self.select_items(
                    search_columns.items(), search, rownum_column.label("rownum")
                ).subquery(f"{operation.rank.dataset_type}_window")
                return sql.Payload(
                    from_clause=window,
                    columns_available=self.extract_mapping(target.columns, window.columns),
                    where=[window.columns["rownum"] == 1],
                )
            case _:
                return super().to_payload(relation)
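
The to_payload override above compiles the custom FindFirstDataset operation
into a CTE plus a ROW_NUMBER() window query. The following minimal,
self-contained sketch (not part of the module) reproduces that pattern in
plain SQLAlchemy; the table and column names ("dataset_search", "detector",
"rank", "dataset_id", "raw") are hypothetical stand-ins for the relation
machinery:

import sqlalchemy

metadata = sqlalchemy.MetaData()
dataset_search = sqlalchemy.Table(
    "dataset_search",
    metadata,
    sqlalchemy.Column("detector", sqlalchemy.Integer),  # a "dimension" column
    sqlalchemy.Column("rank", sqlalchemy.Integer),  # collection search position
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
)

# The CTE plays the role of {dst}_search in the comment above.
search = sqlalchemy.select(dataset_search).cte("raw_search")
# Number the rows within each dimension group by collection search position.
rownum = (
    sqlalchemy.func.row_number()
    .over(partition_by=[search.c.detector], order_by=search.c.rank)
    .label("rownum")
)
window = sqlalchemy.select(search, rownum).subquery("raw_window")
# Keep only the first-ranked row in each group, as in the returned Payload.
query = sqlalchemy.select(window).where(window.c.rownum == 1)
print(query)

Printing query yields the WITH ... ROW_NUMBER() OVER (PARTITION BY ... ORDER
BY ...) ... WHERE rownum = 1 shape sketched in the comment: within each
dimension group, only the row from the highest-priority collection survives.

The timespan overrides, in turn, exist because a logical timespan column is
not a single SQL column. A brief sketch, assuming the compound two-column
representation (TimespanDatabaseRepresentation.Compound); the "validity" name
is arbitrary:

import astropy.time
from lsst.daf.butler import Timespan, TimespanDatabaseRepresentation

ts_cls = TimespanDatabaseRepresentation.Compound
begin = astropy.time.Time("2023-01-01T00:00:00", scale="tai")
end = astropy.time.Time("2023-01-02T00:00:00", scale="tai")
logical = ts_cls.fromLiteral(Timespan(begin, end))
# flatten() expands the one logical column into its real SQL column elements,
# which is why select_items and extract_mapping special-case timespan tags.
for column in logical.flatten(name="validity"):
    print(column)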