Coverage for python/lsst/daf/butler/registry/collections/synthIntKey.py: 99%

106 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-05 11:05 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from ... import ddl 

30 

31__all__ = ["SynthIntKeyCollectionManager"] 

32 

33from collections.abc import Iterable, Mapping 

34from typing import TYPE_CHECKING, Any 

35 

36import sqlalchemy 

37 

38from ..._timespan import TimespanDatabaseRepresentation 

39from .._collection_type import CollectionType 

40from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple 

41from ._base import ( 

42 CollectionTablesTuple, 

43 DefaultCollectionManager, 

44 makeCollectionChainTableSpec, 

45 makeRunTableSpec, 

46) 

47 

48if TYPE_CHECKING: 

49 from .._caching_context import CachingContext 

50 from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext 

51 

52 

# Specification of the primary key column of the collection table: an
# auto-incremented big integer named "collection_id".  Foreign keys created
# by the manager below copy their dtype from this spec.
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "collection_id",
    dtype=sqlalchemy.BigInteger,
    primaryKey=True,
    autoincrement=True,
)


# Schema version reported by this manager; it must be updated on every
# schema change.
_VERSION = VersionTuple(2, 0, 0)

60 

61 

def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple:
    """Build the specifications of the tables used by this manager.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Timespan representation class, forwarded to `makeRunTableSpec`.

    Returns
    -------
    specs : `CollectionTablesTuple`
        Tuple holding the ``collection``, ``run`` and ``collection_chain``
        table specifications.
    """
    # The collection table: synthetic integer key plus a unique name, an
    # integer type code and an optional free-form description.
    collection_spec = ddl.TableSpec(
        fields=[
            _KEY_FIELD_SPEC,
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, nullable=False),
            ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
            ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
        ],
        unique=[("name",)],
    )
    # Run and chain tables are produced by shared helpers which are told the
    # name and type of the collection key column.
    run_spec = makeRunTableSpec("collection_id", sqlalchemy.BigInteger, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec("collection_id", sqlalchemy.BigInteger)
    return CollectionTablesTuple(
        collection=collection_spec,
        run=run_spec,
        collection_chain=chain_spec,
    )

76 

77 

class SynthIntKeyCollectionManager(DefaultCollectionManager[int]):
    """A `CollectionManager` implementation that uses synthetic primary key
    (auto-incremented integer) for collections table.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        dimensions: DimensionRecordStorageManager,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> SynthIntKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        # Table specs depend on the timespan representation of the database,
        # so they are generated here rather than at module level.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="collection_id",
            dimensions=dimensions,
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        # The new column copies the dtype of the collection key, but it is a
        # plain reference and so must never auto-increment.
        copy = ddl.FieldSpec(
            cls.getCollectionForeignKeyName(prefix), dtype=original.dtype, autoincrement=False, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            # Constrain the new column against the key column of "collection".
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    "collection", source=(copy.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return copy

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        # Same dtype as the collection key, without auto-increment.
        copy = ddl.FieldSpec(
            cls.getRunForeignKeyName(prefix), dtype=original.dtype, autoincrement=False, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            # Constrain the new column against the key column of "run".
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec("run", source=(copy.name,), target=(original.name,), onDelete=onDelete)
            )
        return copy

    def getParentChains(self, key: int) -> set[str]:
        # Docstring inherited from CollectionManager.
        chain = self._tables.collection_chain
        collection = self._tables.collection
        # Join the collection table to the chain table on the parent column,
        # then keep only the chain rows whose child is ``key``; what remains
        # are the names of the parent collections.
        sql = (
            sqlalchemy.sql.select(collection.columns["name"])
            .select_from(collection)
            .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"])
            .where(chain.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        return self._fetch("name", names)

    def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        return self._fetch(self._collectionIdName, collection_ids)

    def _fetch(
        self, column_name: str, collections: Iterable[int | str] | None
    ) -> list[CollectionRecord[int]]:
        """Fetch collection records, optionally restricted to given values.

        Parameters
        ----------
        column_name : `str`
            Name of the column of the collection table that ``collections``
            is matched against.
        collections : `~collections.abc.Iterable` [ `int` or `str` ] or `None`
            Values selected with an ``IN`` clause on ``column_name``. If
            `None`, no restriction is applied and all rows are fetched.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Records built from the matching rows. Records of chained
            collections are appended after all other records.
        """
        collection_chain = self._tables.collection_chain
        collection = self._tables.collection
        # Outer join so that non-RUN collections (which have no row in the
        # run table) are still returned.
        sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from(
            collection.join(self._tables.run, isouter=True)
        )

        # Chain definitions: one row per (parent, position) with the child
        # key resolved to the child collection name.
        chain_sql = (
            sqlalchemy.sql.select(
                collection_chain.columns["parent"],
                collection_chain.columns["position"],
                collection.columns["name"].label("child_name"),
            )
            .select_from(collection_chain)
            .join(
                collection,
                onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName],
            )
        )

        records: list[CollectionRecord[int]] = []
        # We want to keep transactions as short as possible. When we fetch
        # everything we want to quickly fetch things into memory and finish
        # the transaction. When we fetch just a few records we need to process
        # the first query before we can run the second one.
        if collections is not None:
            # Materialize the iterable once: it may be a one-shot iterator or
            # a non-sequence container, while the IN clause needs a concrete
            # list (same treatment as ``chained_ids`` below).
            values = list(collections)
            sql = sql.where(collection.columns[column_name].in_(values))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                # The chain table is only queried when some of the selected
                # collections are actually chains.
                if chained_ids:
                    chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids)))

                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            # Rows are already in memory, so conversion happens outside of
            # the transaction.
            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]:
        """Convert rows returned from collection query to a list of records
        and a dict of chained collection names.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
            Rows of the query joining the collection and run tables.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Records for all rows that are not chained collections.
        chained_ids : `dict` [ `int`, `str` ]
            Mapping from collection key to collection name for rows that are
            chained collections; their records are not constructed here.
        """
        records: list[CollectionRecord[int]] = []
        chained_ids: dict[int, str] = {}
        TimespanReprClass = self._db.getTimespanRepresentation()
        for row in rows:
            key: int = row[self._collectionIdName]
            name: str = row[self._tables.collection.columns.name]
            # Named ``collection_type`` to avoid shadowing the ``type``
            # builtin.
            collection_type = CollectionType(row["type"])
            record: CollectionRecord[int]
            if collection_type is CollectionType.RUN:
                record = RunRecord[int](
                    key=key,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif collection_type is CollectionType.CHAINED:
                # Need to delay chained collection construction until we
                # fetch their children names.
                chained_ids[key] = name
            else:
                record = CollectionRecord[int](key=key, name=name, type=collection_type)
                records.append(record)
        return records, chained_ids

    def _rows_to_chains(
        self, rows: Iterable[Mapping], chained_ids: dict[int, str]
    ) -> list[CollectionRecord[int]]:
        """Convert rows returned from collection chain query to a list of
        records.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
            Rows with ``parent``, ``position`` and ``child_name`` keys.
        chained_ids : `dict` [ `int`, `str` ]
            Mapping from key to name for every chained collection to build.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One `ChainedCollectionRecord` per entry of ``chained_ids``, with
            children ordered by position. A chain without any row in
            ``rows`` gets an empty list of children.
        """
        # Start from an empty definition for every chain so that chains
        # without children still produce a record.
        chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            chains_defs[row["parent"]].append((row["position"], row["child_name"]))

        # Sorting the (position, child_name) tuples orders children by their
        # position in the chain.
        return [
            ChainedCollectionRecord[int](
                key=key,
                name=chained_ids[key],
                children=[child for _, child in sorted(children)],
            )
            for key, children in chains_defs.items()
        ]

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]