Coverage for python / lsst / daf / butler / registry / collections / synthIntKey.py: 0%

128 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 08:49 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from ... import ddl 

30 

31__all__ = ["SynthIntKeyCollectionManager"] 

32 

33import logging 

34from collections.abc import Iterable, Mapping 

35from typing import TYPE_CHECKING, Any 

36 

37import sqlalchemy 

38 

39from ..._collection_type import CollectionType 

40from ...column_spec import COLLECTION_NAME_MAX_LENGTH 

41from ...timespan_database_representation import TimespanDatabaseRepresentation 

42from ..interfaces import ( 

43 ChainedCollectionRecord, 

44 CollectionRecord, 

45 Joinable, 

46 JoinedCollectionsTable, 

47 RunRecord, 

48 VersionTuple, 

49) 

50from ._base import ( 

51 CollectionTablesTuple, 

52 DefaultCollectionManager, 

53 makeCollectionChainTableSpec, 

54 makeRunTableSpec, 

55) 

56 

57if TYPE_CHECKING: 

58 from .._caching_context import CachingContext 

59 from ..interfaces import Database, StaticTablesContext 

60 

61 

# Specification for the synthetic integer primary key column shared by the
# collection tables; foreign keys in other tables copy its dtype.
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "collection_id", dtype=sqlalchemy.BigInteger, primaryKey=True, autoincrement=True
)


# This has to be updated on every schema change
_VERSION = VersionTuple(2, 0, 0)

# Module-level logger for debug output from fetch operations.
_LOG = logging.getLogger(__name__)

71 

72 

def _makeTableSpecs(
    TimespanReprClass: type[TimespanDatabaseRepresentation],
) -> CollectionTablesTuple[ddl.TableSpec]:
    """Construct specifications for the tables used by this manager.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Database representation class used for timespan columns in the
        ``run`` table.

    Returns
    -------
    specs : `CollectionTablesTuple` [ `ddl.TableSpec` ]
        Specifications for the ``collection``, ``run``, and
        ``collection_chain`` tables.
    """
    # Collection names must be unique; the synthetic integer key is the
    # primary key (see _KEY_FIELD_SPEC).
    name_field = ddl.FieldSpec(
        "name", dtype=sqlalchemy.String, length=COLLECTION_NAME_MAX_LENGTH, nullable=False
    )
    type_field = ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False)
    doc_field = ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True)
    collection_spec = ddl.TableSpec(
        fields=[_KEY_FIELD_SPEC, name_field, type_field, doc_field],
        unique=[("name",)],
    )
    # The run and chain tables reference the collection table through the
    # same BigInteger key column.
    run_spec = makeRunTableSpec("collection_id", sqlalchemy.BigInteger, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec("collection_id", sqlalchemy.BigInteger)
    return CollectionTablesTuple(
        collection=collection_spec,
        run=run_spec,
        collection_chain=chain_spec,
    )

91 

92 

class SynthIntKeyCollectionManager(DefaultCollectionManager[int]):
    """A `CollectionManager` implementation that uses synthetic primary key
    (auto-incremented integer) for collections table.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> SynthIntKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            # Register the collection/run/collection_chain table specs as
            # part of the static schema.
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="collection_id",
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    def clone(self, db: Database, caching_context: CachingContext) -> SynthIntKeyCollectionManager:
        """Make a copy of this manager bound to a different database
        connection and caching context, reusing the existing table
        definitions and schema version.
        """
        return SynthIntKeyCollectionManager(
            db,
            tables=self._tables,
            collectionIdName=self._collectionIdName,
            caching_context=caching_context,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        # Copy the key column's dtype but not its autoincrement/primary-key
        # properties; this is a plain foreign-key column in the target table.
        original = _KEY_FIELD_SPEC
        copy = ddl.FieldSpec(
            cls.getCollectionForeignKeyName(prefix), dtype=original.dtype, autoincrement=False, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    "collection", source=(copy.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return copy

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        # Same pattern as addCollectionForeignKey, but the constraint targets
        # the "run" table instead of "collection".
        original = _KEY_FIELD_SPEC
        copy = ddl.FieldSpec(
            cls.getRunForeignKeyName(prefix), dtype=original.dtype, autoincrement=False, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec("run", source=(copy.name,), target=(original.name,), onDelete=onDelete)
            )
        return copy

    def getParentChains(self, key: int) -> set[str]:
        # Docstring inherited from CollectionManager.
        chain = self._tables.collection_chain
        collection = self._tables.collection
        # Find the names of all CHAINED collections whose membership rows
        # list this collection key as a child.
        sql = (
            sqlalchemy.sql.select(collection.columns["name"])
            .select_from(collection)
            .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"])
            .where(chain.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def lookup_name_sql(
        self, sql_key: sqlalchemy.ColumnElement[int], sql_from_clause: Joinable
    ) -> tuple[sqlalchemy.ColumnElement[str], Joinable]:
        # Docstring inherited.
        joined = self.join_collections_sql(sql_key, sql_from_clause)
        return (joined.name_column, joined.joined_sql)

    def join_collections_sql(
        self, sql_key: sqlalchemy.ColumnElement[int], joinable: Joinable
    ) -> JoinedCollectionsTable:
        """Join the collection table onto an existing FROM clause via the
        given key expression, exposing the name and type columns.
        """
        return JoinedCollectionsTable(
            joined_sql=joinable.join(
                self._tables.collection, onclause=self._tables.collection.c[_KEY_FIELD_SPEC.name] == sql_key
            ),
            name_column=self._tables.collection.columns["name"],
            type_column=self._tables.collection.columns["type"],
        )

    def _fetch_by_name(self, names: Iterable[str], flatten_chains: bool) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using names %s.", names)
        if flatten_chains:
            # Recursively expand CHAINED collections so the result includes
            # all (transitive) children as well.
            sql_rows = self._query_recursive(names, _KEY_FIELD_SPEC.dtype)

            # There may be duplicates in the result, select unique IDs.
            unique_rows = {row[self._collectionIdName]: row for row in sql_rows}

            records, chained_ids = self._rows_to_records(unique_rows.values())
            records += self._rows_to_chains(sql_rows, chained_ids)

            return records
        else:
            return self._fetch("name", names)

    def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using IDs %s.", collection_ids)
        return self._fetch(self._collectionIdName, collection_ids)

    def _fetch(
        self, column_name: str, collections: Iterable[int | str] | None
    ) -> list[CollectionRecord[int]]:
        """Fetch collection records matched by the given column.

        Parameters
        ----------
        column_name : `str`
            Name of the collection-table column to filter on (``name`` or
            the key column).
        collections : `~collections.abc.Iterable` [ `int` | `str` ] or `None`
            Values to match against ``column_name``; `None` fetches all
            collections.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Matching collection records, including chained collections with
            their resolved children.
        """
        collection_chain = self._tables.collection_chain
        collection = self._tables.collection
        # Main query: collection rows with run metadata attached via an
        # outer join (non-RUN collections have NULL run columns).
        sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from(
            collection.join(self._tables.run, isouter=True)
        )

        # Secondary query: chain membership rows with child names resolved.
        chain_sql = (
            sqlalchemy.sql.select(
                collection_chain.columns["parent"],
                collection_chain.columns["position"],
                collection.columns["name"],
            )
            .select_from(collection_chain)
            .join(
                collection,
                onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName],
            )
        )

        records: list[CollectionRecord[int]] = []
        # We want to keep transactions as short as possible. When we fetch
        # everything we want to quickly fetch things into memory and finish
        # the transaction. When we fetch just a few records we need to
        # process the first query before we can run the second one.
        if collections is not None:
            sql = sql.where(collection.columns[column_name].in_(collections))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    # Only query chain membership for the chains we found.
                    chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids)))

                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            # Row processing happens outside the transaction here.
            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]:
        """Convert rows returned from collection query to a list of records
        and a dict chained collection names.
        """
        records: list[CollectionRecord[int]] = []
        chained_ids: dict[int, str] = {}
        TimespanReprClass = self._db.getTimespanRepresentation()
        for row in rows:
            key: int = row[self._collectionIdName]
            name: str = row["name"]
            type = CollectionType(row["type"])
            record: CollectionRecord[int]
            if type is CollectionType.RUN:
                record = RunRecord[int](
                    key=key,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif type is CollectionType.CHAINED:
                # Need to delay chained collection construction until we
                # fetch their children's names.
                chained_ids[key] = name
            else:
                record = CollectionRecord[int](key=key, name=name, type=type)
                records.append(record)
        return records, chained_ids

    def _rows_to_chains(
        self, rows: Iterable[Mapping], chained_ids: dict[int, str]
    ) -> list[CollectionRecord[int]]:
        """Convert rows returned from collection chain query to a list of
        records.
        """
        # Pre-seed with every known chain so chains with no members still
        # produce a record with an empty child list.
        chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            if row["parent"] is not None:
                chains_defs[row["parent"]].append((row["position"], row["name"]))

        records: list[CollectionRecord[int]] = []
        for key, children in chains_defs.items():
            name = chained_ids[key]
            # Sorting by position restores the chain's defined child order.
            children_names = [child for _, child in sorted(children)]
            record = ChainedCollectionRecord[int](
                key=key,
                name=name,
                children=children_names,
            )
            records.append(record)

        return records

    def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
        """Return a SELECT that fetches the key (labeled ``key``) and type
        of the collection with the given name.
        """
        table = self._tables.collection
        return sqlalchemy.select(table.c.collection_id.label("key"), table.c.type).where(
            table.c.name == collection_name
        )

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]