Coverage for python/lsst/daf/butler/registry/collections/synthIntKey.py: 99%

110 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-16 10:43 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from ... import ddl 

30 

31__all__ = ["SynthIntKeyCollectionManager"] 

32 

33import logging 

34from collections.abc import Iterable, Mapping 

35from typing import TYPE_CHECKING, Any 

36 

37import sqlalchemy 

38 

39from ...timespan_database_representation import TimespanDatabaseRepresentation 

40from .._collection_type import CollectionType 

41from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple 

42from ._base import ( 

43 CollectionTablesTuple, 

44 DefaultCollectionManager, 

45 makeCollectionChainTableSpec, 

46 makeRunTableSpec, 

47) 

48 

49if TYPE_CHECKING: 

50 from .._caching_context import CachingContext 

51 from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext 

52 

53 

# Module-level logger.
_LOG = logging.getLogger(__name__)

# Schema version implemented by this manager; it has to be updated on every
# schema change.
_VERSION = VersionTuple(2, 0, 0)

# Specification of the synthetic primary key column of the collection table:
# an auto-incremented `sqlalchemy.BigInteger`. Foreign key columns added by
# the manager copy its ``dtype`` (with autoincrement disabled).
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "collection_id",
    dtype=sqlalchemy.BigInteger,
    primaryKey=True,
    autoincrement=True,
)

63 

64 

def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple:
    """Build the table specifications used by `SynthIntKeyCollectionManager`.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Timespan representation class, forwarded to `makeRunTableSpec`.

    Returns
    -------
    specs : `CollectionTablesTuple`
        Specifications for the ``collection``, ``run`` and
        ``collection_chain`` tables. The collection table is keyed on
        ``_KEY_FIELD_SPEC`` with a unique ``name`` column; the other two are
        built by passing the ``collection_id`` column name and the
        `sqlalchemy.BigInteger` type to their respective helpers.
    """
    key_name = "collection_id"
    key_type = sqlalchemy.BigInteger

    collection_fields = [
        _KEY_FIELD_SPEC,
        ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, nullable=False),
        ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
        ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
    ]
    collection_spec = ddl.TableSpec(fields=collection_fields, unique=[("name",)])
    run_spec = makeRunTableSpec(key_name, key_type, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec(key_name, key_type)

    return CollectionTablesTuple(
        collection=collection_spec,
        run=run_spec,
        collection_chain=chain_spec,
    )

79 

80 

class SynthIntKeyCollectionManager(DefaultCollectionManager[int]):
    """A `CollectionManager` implementation that uses a synthetic primary key
    (an auto-incremented integer) for the collections table.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        dimensions: DimensionRecordStorageManager,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> SynthIntKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="collection_id",
            dimensions=dimensions,
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        return cls._addSynthKeyForeignKey(
            tableSpec, cls.getCollectionForeignKeyName(prefix), "collection", onDelete, constraint, kwargs
        )

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        return cls._addSynthKeyForeignKey(
            tableSpec, cls.getRunForeignKeyName(prefix), "run", onDelete, constraint, kwargs
        )

    @classmethod
    def _addSynthKeyForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        name: str,
        target_table: str,
        onDelete: str | None,
        constraint: bool,
        field_kwargs: dict[str, Any],
    ) -> ddl.FieldSpec:
        """Add a column referencing the synthetic collection key to a table
        specification, optionally with a foreign key constraint.

        Parameters
        ----------
        tableSpec : `ddl.TableSpec`
            Table specification to modify in place.
        name : `str`
            Name of the new column.
        target_table : `str`
            Name of the table the foreign key constraint points at.
        onDelete : `str` or `None`
            ``ON DELETE`` action for the constraint.
        constraint : `bool`
            If `True`, also append a `ddl.ForeignKeySpec` to ``tableSpec``.
        field_kwargs : `dict` [ `str`, `~typing.Any` ]
            Extra keyword arguments forwarded to `ddl.FieldSpec`. They are
            passed as a dict so caller keys cannot clash with this method's
            own parameter names.

        Returns
        -------
        field : `ddl.FieldSpec`
            The column specification that was added.
        """
        original = _KEY_FIELD_SPEC
        # The referencing column copies the key dtype but must never be
        # auto-incremented itself.
        field = ddl.FieldSpec(name, dtype=original.dtype, autoincrement=False, **field_kwargs)
        tableSpec.fields.add(field)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    target_table, source=(field.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return field

    def getParentChains(self, key: int) -> set[str]:
        # Docstring inherited from CollectionManager.
        chain = self._tables.collection_chain
        collection = self._tables.collection
        sql = (
            sqlalchemy.sql.select(collection.columns["name"])
            .select_from(collection)
            .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"])
            .where(chain.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using names %s.", names)
        return self._fetch("name", names)

    def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using IDs %s.", collection_ids)
        return self._fetch(self._collectionIdName, collection_ids)

    def _fetch(
        self, column_name: str, collections: Iterable[int | str] | None
    ) -> list[CollectionRecord[int]]:
        """Fetch collection records together with their run/chain details.

        Parameters
        ----------
        column_name : `str`
            Name of the collection-table column that ``collections`` values
            are matched against.
        collections : `~collections.abc.Iterable` [ `int` | `str` ] or `None`
            Values of ``column_name`` to select; `None` selects all
            collections.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Records for the matching collections.
        """
        values: list[int | str] | None = None
        if collections is not None:
            # Materialize once so that single-pass iterables are safe to hand
            # to ``in_``; an empty selection cannot match anything, so skip
            # the database round trip entirely.
            values = list(collections)
            if not values:
                return []

        collection_chain = self._tables.collection_chain
        collection = self._tables.collection
        sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from(
            collection.join(self._tables.run, isouter=True)
        )

        chain_sql = (
            sqlalchemy.sql.select(
                collection_chain.columns["parent"],
                collection_chain.columns["position"],
                collection.columns["name"].label("child_name"),
            )
            .select_from(collection_chain)
            .join(
                collection,
                onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName],
            )
        )

        records: list[CollectionRecord[int]]
        # We want to keep transactions as short as possible. When we fetch
        # everything, we quickly read both result sets into memory and finish
        # the transaction. When we fetch just a few records, we need to
        # process the first query before we can run the second one, because
        # the chain query is restricted to the chained collections it found.
        if values is not None:
            sql = sql.where(collection.columns[column_name].in_(values))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids)))

                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]:
        """Convert rows returned from the collection query to a list of
        records and a dict of chained collection names.

        RUN rows become `RunRecord` instances (host and timespan read from
        the run-table columns). CHAINED rows are not converted here, because
        their children must be fetched by a separate query; they are returned
        as a mapping from collection key to name instead. All other rows
        become plain `CollectionRecord` instances.
        """
        records: list[CollectionRecord[int]] = []
        chained_ids: dict[int, str] = {}
        TimespanReprClass = self._db.getTimespanRepresentation()
        for row in rows:
            key: int = row[self._collectionIdName]
            name: str = row[self._tables.collection.columns.name]
            collection_type = CollectionType(row["type"])
            record: CollectionRecord[int]
            if collection_type is CollectionType.RUN:
                record = RunRecord[int](
                    key=key,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif collection_type is CollectionType.CHAINED:
                # Need to delay chained collection construction until we
                # fetch their children names.
                chained_ids[key] = name
            else:
                record = CollectionRecord[int](key=key, name=name, type=collection_type)
                records.append(record)
        return records, chained_ids

    def _rows_to_chains(
        self, rows: Iterable[Mapping], chained_ids: dict[int, str]
    ) -> list[CollectionRecord[int]]:
        """Convert rows returned from the collection chain query to a list of
        records.

        Each row must provide ``parent``, ``position`` and ``child_name``.
        One `ChainedCollectionRecord` is produced for every key in
        ``chained_ids``, even a chain with no child rows. Children are ordered
        by ``position``. A row whose ``parent`` is not a key of
        ``chained_ids`` raises `KeyError`.
        """
        chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            chains_defs[row["parent"]].append((row["position"], row["child_name"]))

        records: list[CollectionRecord[int]] = []
        for key, children in chains_defs.items():
            name = chained_ids[key]
            children_names = [child for _, child in sorted(children)]
            record = ChainedCollectionRecord[int](
                key=key,
                name=name,
                children=children_names,
            )
            records.append(record)

        return records

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]
296 return [_VERSION]