Coverage for python / lsst / daf / butler / registry / collections / nameKey.py: 0%

120 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 08:55 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ["NameKeyCollectionManager"] 

30 

31import logging 

32from collections.abc import Iterable, Mapping 

33from typing import TYPE_CHECKING, Any 

34 

35import sqlalchemy 

36 

37from ... import ddl 

38from ..._collection_type import CollectionType 

39from ...column_spec import COLLECTION_NAME_MAX_LENGTH 

40from ...timespan_database_representation import TimespanDatabaseRepresentation 

41from ..interfaces import ( 

42 ChainedCollectionRecord, 

43 CollectionRecord, 

44 Joinable, 

45 JoinedCollectionsTable, 

46 RunRecord, 

47 VersionTuple, 

48) 

49from ._base import ( 

50 CollectionTablesTuple, 

51 DefaultCollectionManager, 

52 makeCollectionChainTableSpec, 

53 makeRunTableSpec, 

54) 

55 

56if TYPE_CHECKING: 

57 from .._caching_context import CachingContext 

58 from ..interfaces import Database, StaticTablesContext 

59 

60 

_LOG = logging.getLogger(__name__)

# Primary-key column of the ``collection`` table. Columns that reference a
# collection or run from other tables copy its dtype and length (see
# ``addCollectionForeignKey`` / ``addRunForeignKey``).
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "name",
    dtype=sqlalchemy.String,
    length=COLLECTION_NAME_MAX_LENGTH,
    primaryKey=True,
)

# This has to be updated on every schema change
_VERSION = VersionTuple(2, 0, 0)

71 

72 

def _makeTableSpecs(
    TimespanReprClass: type[TimespanDatabaseRepresentation],
) -> CollectionTablesTuple[ddl.TableSpec]:
    """Build the static table specifications for `NameKeyCollectionManager`.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Timespan representation class, forwarded to `makeRunTableSpec`.

    Returns
    -------
    specs : `CollectionTablesTuple` [ `ddl.TableSpec` ]
        Specifications for the ``collection``, ``run`` and
        ``collection_chain`` tables, each built around the ``name`` string
        column.
    """
    collection_fields = [
        _KEY_FIELD_SPEC,
        ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
        ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
    ]
    collection_spec = ddl.TableSpec(fields=collection_fields)
    run_spec = makeRunTableSpec("name", sqlalchemy.String, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec("name", sqlalchemy.String)
    return CollectionTablesTuple(
        collection=collection_spec,
        run=run_spec,
        collection_chain=chain_spec,
    )

87 

88 

class NameKeyCollectionManager(DefaultCollectionManager[str]):
    """A `CollectionManager` implementation that uses collection names for
    primary/foreign keys and aggressively loads all collection/run records in
    the database into memory.

    Most of the logic, including caching policy, is implemented in the base
    class; this class only adds customizations specific to this particular
    table schema, in which the ``name`` column is both the collection key and
    the collection name.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> NameKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="name",
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    def clone(self, db: Database, caching_context: CachingContext) -> NameKeyCollectionManager:
        # Docstring inherited.
        # Use ``type(self)`` rather than the hard-coded class name so that a
        # subclass clones into its own type, just as ``initialize`` uses
        # ``cls``.
        return type(self)(
            db,
            tables=self._tables,
            collectionIdName=self._collectionIdName,
            caching_context=caching_context,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def _add_name_foreign_key(
        cls,
        tableSpec: ddl.TableSpec,
        field_name: str,
        target_table: str,
        onDelete: str | None,
        constraint: bool,
        field_kwargs: dict[str, Any],
    ) -> ddl.FieldSpec:
        """Add a column that mirrors the ``name`` key field to a table spec.

        Shared implementation of `addCollectionForeignKey` and
        `addRunForeignKey`.

        Parameters
        ----------
        tableSpec : `ddl.TableSpec`
            Table specification; modified in place.
        field_name : `str`
            Name of the new column.
        target_table : `str`
            Table name given to the optional `ddl.ForeignKeySpec`.
        onDelete : `str` or `None`
            Passed through as ``onDelete`` to `ddl.ForeignKeySpec`.
        constraint : `bool`
            If `True`, also append a foreign key specification.
        field_kwargs : `dict` [ `str`, `~typing.Any` ]
            Extra keyword arguments forwarded to `ddl.FieldSpec`. Passed as a
            dict so they can never collide with this helper's own parameters.

        Returns
        -------
        fieldSpec : `ddl.FieldSpec`
            The column specification that was added.
        """
        original = _KEY_FIELD_SPEC
        copy = ddl.FieldSpec(field_name, dtype=original.dtype, length=original.length, **field_kwargs)
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    target_table, source=(copy.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return copy

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        return cls._add_name_foreign_key(
            tableSpec, cls.getCollectionForeignKeyName(prefix), "collection", onDelete, constraint, kwargs
        )

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        return cls._add_name_foreign_key(
            tableSpec, cls.getRunForeignKeyName(prefix), "run", onDelete, constraint, kwargs
        )

    def getParentChains(self, key: str) -> set[str]:
        # Docstring inherited from CollectionManager.
        table = self._tables.collection_chain
        sql = (
            sqlalchemy.sql.select(table.columns["parent"])
            .select_from(table)
            .where(table.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def lookup_name_sql(
        self, sql_key: sqlalchemy.ColumnElement[str], sql_from_clause: Joinable
    ) -> tuple[sqlalchemy.ColumnElement[str], Joinable]:
        # Docstring inherited.
        # In this schema the key *is* the name, so no join is needed.
        return sql_key, sql_from_clause

    def join_collections_sql(
        self, sql_key: sqlalchemy.ColumnElement[str], joinable: Joinable
    ) -> JoinedCollectionsTable:
        # Docstring inherited.
        # Join on the name column; the key expression doubles as the name.
        name_column = self._tables.collection.columns["name"]
        return JoinedCollectionsTable(
            joined_sql=joinable.join(self._tables.collection, onclause=name_column == sql_key),
            name_column=sql_key,
            type_column=self._tables.collection.columns["type"],
        )

    def _fetch_by_name(self, names: Iterable[str], flatten_chains: bool) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        if flatten_chains:
            sql_rows = self._query_recursive(names, _KEY_FIELD_SPEC.dtype)

            # There may be duplicates in the result, select unique names.
            unique_rows = {row[self._collectionIdName]: row for row in sql_rows}

            records, chained_ids = self._rows_to_records(unique_rows.values())
            # _rows_to_chains collapses duplicate (position, child) rows itself.
            records += self._rows_to_chains(sql_rows, chained_ids)

            return records
        else:
            return self._fetch_by_key(names)

    def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        # Materialize the requested names once. This makes the debug log show
        # the actual names rather than an iterator object. It also means a
        # one-shot iterator is consumed only once, when building the IN clause.
        ids: list[str] | None = None if collection_ids is None else list(collection_ids)
        _LOG.debug("Fetching collection records using names %s.", ids)
        if ids is not None and not ids:
            # Nothing requested: skip the transaction and the empty IN query.
            return []

        sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from(
            self._tables.collection.join(self._tables.run, isouter=True)
        )

        # "Rename" child column to "name" as expected by _rows_to_chains()
        chain_sql = sqlalchemy.sql.select(
            self._tables.collection_chain.columns["parent"],
            self._tables.collection_chain.columns["position"],
            self._tables.collection_chain.columns["child"].label("name"),
        )

        records: list[CollectionRecord[str]] = []
        # We want to keep transactions as short as possible. When we fetch
        # everything, we quickly read both tables into memory and finish the
        # transaction. When we fetch just a few records, we need to process
        # the result of the first query before we can run the second one.
        if ids is not None:
            sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(ids))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    # Retrieve chained collection compositions
                    chain_sql = chain_sql.where(
                        self._tables.collection_chain.columns["parent"].in_(chained_ids)
                    )
                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]:
        """Convert rows returned from a collection query to records.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
            Rows with at least ``name`` and ``type`` keys. RUN rows must also
            carry the ``run`` table's host column and timespan columns, as in
            the outer join built by `_fetch_by_key`.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Records for every collection that is not CHAINED.
        chained_ids : `list` [ `str` ]
            Names of CHAINED collections. Their records are built by
            `_rows_to_chains` once their children are known.
        """
        records: list[CollectionRecord[str]] = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        chained_ids: list[str] = []
        for row in rows:
            name = row["name"]
            collection_type = CollectionType(row["type"])
            if collection_type is CollectionType.CHAINED:
                # Chained records need their children, which come from a
                # separate query, so only remember the name for now.
                chained_ids.append(name)
            elif collection_type is CollectionType.RUN:
                records.append(
                    RunRecord[str](
                        key=name,
                        name=name,
                        host=row[self._tables.run.columns.host],
                        timespan=TimespanReprClass.extract(row),
                    )
                )
            else:
                records.append(CollectionRecord[str](key=name, name=name, type=collection_type))

        return records, chained_ids

    def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]:
        """Convert rows returned from a collection chain query to records.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
            Rows with ``parent``, ``position`` and ``name`` (child) keys.
            Rows whose ``parent`` is `None` are ignored.
        chained_ids : `list` [ `str` ]
            Names of the CHAINED collections to build records for. A chain
            with no matching rows gets an empty child list.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One `ChainedCollectionRecord` per entry of ``chained_ids``, with
            children ordered by position.
        """
        # Use a set per chain so that duplicate rows, which the flattened
        # query in `_fetch_by_name` may return, cannot list a child twice.
        # Distinct positions are still kept, so a child that legitimately
        # appears at several positions keeps all of its entries.
        chains_defs: dict[str, set[tuple[int, str]]] = {chain_id: set() for chain_id in chained_ids}
        for row in rows:
            parent = row["parent"]
            if parent is not None:
                chains_defs[parent].add((row["position"], row["name"]))

        records: list[CollectionRecord[str]] = []
        for name, children in chains_defs.items():
            children_names = [child for _, child in sorted(children)]
            records.append(
                ChainedCollectionRecord[str](
                    key=name,
                    name=name,
                    children=children_names,
                )
            )

        return records

    def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
        # The primary key is the name itself; label it "key" for callers.
        table = self._tables.collection
        return sqlalchemy.select(table.c.name.label("key"), table.c.type).where(
            table.c.name == collection_name
        )

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]