Coverage for python/lsst/daf/butler/registry/queries/butler_sql_engine.py: 27%

76 statements  

coverage.py v7.2.5, created at 2023-05-17 02:31 -0700

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("ButlerSqlEngine",)

import dataclasses
from collections.abc import Iterable, Set
from typing import Any, cast

import astropy.time
import sqlalchemy
from lsst.daf.relation import ColumnTag, Relation, Sort, UnaryOperation, UnaryOperationRelation, sql

from ...core import (
    ColumnTypeInfo,
    LogicalColumn,
    Timespan,
    TimespanDatabaseRepresentation,
    ddl,
    is_timespan_column,
)
from ..nameShrinker import NameShrinker
from .find_first_dataset import FindFirstDataset

@dataclasses.dataclass(repr=False, eq=False, kw_only=True)
class ButlerSqlEngine(sql.Engine[LogicalColumn]):
    """An extension of the `lsst.daf.relation.sql.Engine` class to add timespan
    and `FindFirstDataset` operation support.
    """

    column_types: ColumnTypeInfo
    """Struct containing information about column types that depend on registry
    configuration.
    """

    name_shrinker: NameShrinker
    """Object used to shrink the SQL identifiers that represent a ColumnTag to
    fit within the database's limit (`NameShrinker`).
    """

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"ButlerSqlEngine({self.name!r})@{id(self):0x}"

    def _append_unary_to_select(self, operation: UnaryOperation, target: sql.Select) -> sql.Select:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match operation:
            case FindFirstDataset():
                if target.has_sort and not target.has_slice:
                    # Existing target is sorted, but not sliced.  We want to
                    # move that sort outside (i.e. after) the FindFirstDataset,
                    # since otherwise the FindFirstDataset would put the Sort
                    # into a CTE where it will do nothing.
                    inner = target.reapply_skip(sort=Sort())
                    return sql.Select.apply_skip(operation._finish_apply(inner), sort=target.sort)
                else:
                    # Apply the FindFirstDataset directly to the existing
                    # target, which we've already asserted starts with a
                    # Select.  That existing Select will be used for the CTE
                    # that starts the FindFirstDataset implementation (see
                    # to_payload override).
                    return sql.Select.apply_skip(operation._finish_apply(target))
            case _:
                return super()._append_unary_to_select(operation, target)
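    # Editorial illustration (not part of the original source): the sorted
    # branch above is effectively the rewrite
    #
    #     FindFirstDataset(Sort(target))  -->  Sort(FindFirstDataset(target))
    #
    # so the ORDER BY ends up on the outermost SELECT instead of being buried
    # (and ignored) inside the CTE that the to_payload override builds.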

    def get_identifier(self, tag: ColumnTag) -> str:
        # Docstring inherited.
        return self.name_shrinker.shrink(super().get_identifier(tag))

    def extract_mapping(
        self, tags: Iterable[ColumnTag], sql_columns: sqlalchemy.sql.ColumnCollection
    ) -> dict[ColumnTag, LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        result: dict[ColumnTag, LogicalColumn] = {}
        for tag in tags:
            if is_timespan_column(tag):
                result[tag] = self.column_types.timespan_cls.from_columns(
                    sql_columns, name=self.get_identifier(tag)
                )
            else:
                result[tag] = sql_columns[self.get_identifier(tag)]
        return result
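    # Editorial illustration (not part of the original source): for a timespan
    # tag, the mapping built above holds a TimespanDatabaseRepresentation
    # backed by one or more real SQL columns rather than a single
    # ColumnElement, e.g.
    #
    #     logical = engine.extract_mapping([tag], subquery.columns)[tag]
    #     # select_items() later calls logical.flatten(name=...) to expand it
    #     # back into its underlying columns; the exact column layout depends
    #     # on the configured timespan_cls.
    #
    # Here "engine", "tag", and "subquery" are hypothetical names.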

    def select_items(
        self,
        items: Iterable[tuple[ColumnTag, LogicalColumn]],
        sql_from: sqlalchemy.sql.FromClause,
        *extra: sqlalchemy.sql.ColumnElement,
    ) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag, logical_column in items:
            if is_timespan_column(tag):
                select_columns.extend(
                    cast(TimespanDatabaseRepresentation, logical_column).flatten(
                        name=self.get_identifier(tag)
                    )
                )
            else:
                select_columns.append(
                    cast(sqlalchemy.sql.ColumnElement, logical_column).label(self.get_identifier(tag))
                )
        select_columns.extend(extra)
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).select_from(sql_from)

    def make_zero_select(self, tags: Set[ColumnTag]) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag in tags:
            if is_timespan_column(tag):
                select_columns.extend(
                    self.column_types.timespan_cls.fromLiteral(None).flatten(name=self.get_identifier(tag))
                )
            else:
                select_columns.append(sqlalchemy.sql.literal(None).label(self.get_identifier(tag)))
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).where(sqlalchemy.sql.literal(False))
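    # Editorial illustration (not part of the original source): the statement
    # built above selects NULL literals under the shrunken identifiers and can
    # never yield rows; with hypothetical column names it compiles to roughly
    #
    #     SELECT NULL AS instrument, NULL AS day_obs
    #     WHERE false
    #
    # i.e. a relation with the right column set and no rows.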

    def convert_column_literal(self, value: Any) -> LogicalColumn:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        if isinstance(value, Timespan):
            return self.column_types.timespan_cls.fromLiteral(value)
        elif isinstance(value, astropy.time.Time):
            return sqlalchemy.sql.literal(value, type_=ddl.AstropyTimeNsecTai)
        else:
            return super().convert_column_literal(value)
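    # Editorial illustration (not part of the original source): given some
    # hypothetical engine instance, the dispatch above behaves like
    #
    #     engine.convert_column_literal(Timespan(begin=t1, end=t2))
    #     # -> timespan_cls literal (possibly spanning several SQL columns)
    #     engine.convert_column_literal(astropy.time.Time("2023-05-17", scale="tai"))
    #     # -> SQLAlchemy literal bound with the ddl.AstropyTimeNsecTai type
    #     engine.convert_column_literal(42)
    #     # -> delegated to the base class as an ordinary literal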

    def to_payload(self, relation: Relation) -> sql.Payload[LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match relation:
            case UnaryOperationRelation(operation=FindFirstDataset() as operation, target=target):
                # We build a subquery of the form below to search the
                # collections in order.
                #
                # WITH {dst}_search AS (
                #     {target}
                #     ...
                # )
                # SELECT
                #     {dst}_window.*
                # FROM (
                #     SELECT
                #         {dst}_search.*,
                #         ROW_NUMBER() OVER (
                #             PARTITION BY {dst}_search.{operation.dimensions}
                #             ORDER BY {operation.rank}
                #         ) AS rownum
                #     FROM {dst}_search
                # ) {dst}_window
                # WHERE
                #     {dst}_window.rownum = 1;
                #
                # We'll start with the Common Table Expression (CTE) at the
                # top, which we mostly get from the target relation.
                search = self.to_executable(target).cte(f"{operation.rank.dataset_type}_search")
                # Now we fill out the SELECT from the CTE, and the subquery it
                # contains (at the same time, since they have the same columns,
                # aside from the special 'rownum' window-function column).
                search_columns = self.extract_mapping(target.columns, search.columns)
                partition_by = [search_columns[tag] for tag in operation.dimensions]
                rownum_column = sqlalchemy.sql.func.row_number()
                if partition_by:
                    rownum_column = rownum_column.over(
                        partition_by=partition_by, order_by=search_columns[operation.rank]
                    )
                else:
                    rownum_column = rownum_column.over(order_by=search_columns[operation.rank])
                window = self.select_items(
                    search_columns.items(), search, rownum_column.label("rownum")
                ).subquery(f"{operation.rank.dataset_type}_window")
                return sql.Payload(
                    from_clause=window,
                    columns_available=self.extract_mapping(target.columns, window.columns),
                    where=[window.columns["rownum"] == 1],
                )
            case _:
                return super().to_payload(relation)
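
The FindFirstDataset translation above is the standard ROW_NUMBER window-function idiom for keeping only the best-ranked row in each group. A minimal, self-contained SQLAlchemy sketch of the same pattern, using hypothetical table and column names (data_id as the group key, rank as the collection search order), looks like this:

    import sqlalchemy

    metadata = sqlalchemy.MetaData()
    # Hypothetical table standing in for the searched target relation.
    dataset_search = sqlalchemy.Table(
        "dataset_search",
        metadata,
        sqlalchemy.Column("data_id", sqlalchemy.Integer),
        sqlalchemy.Column("rank", sqlalchemy.Integer),
        sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    )

    # CTE playing the role of "{dst}_search" in the comment above.
    search = sqlalchemy.select(dataset_search).cte("d_search")

    # Rank the rows within each data_id group by collection search order.
    rownum = (
        sqlalchemy.func.row_number()
        .over(partition_by=[search.c.data_id], order_by=search.c.rank)
        .label("rownum")
    )
    window = sqlalchemy.select(search, rownum).subquery("d_window")

    # Keep only the first-ranked row per group, mirroring "rownum = 1".
    find_first = sqlalchemy.select(window).where(window.c.rownum == 1)
    print(find_first)  # compiles to a WITH ... ROW_NUMBER() OVER (...) query

The real engine additionally routes every column through extract_mapping and select_items so that timespan columns survive the round trip through the CTE and the window subquery.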