Coverage for python/lsst/daf/butler/registry/collections/synthIntKey.py: 99%

118 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-10 10:13 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from ... import ddl 

30 

31__all__ = ["SynthIntKeyCollectionManager"] 

32 

33import logging 

34from collections.abc import Iterable, Mapping 

35from typing import TYPE_CHECKING, Any 

36 

37import sqlalchemy 

38 

39from ...column_spec import COLLECTION_NAME_MAX_LENGTH 

40from ...timespan_database_representation import TimespanDatabaseRepresentation 

41from .._collection_type import CollectionType 

42from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple 

43from ._base import ( 

44 CollectionTablesTuple, 

45 DefaultCollectionManager, 

46 makeCollectionChainTableSpec, 

47 makeRunTableSpec, 

48) 

49 

50if TYPE_CHECKING: 

51 from .._caching_context import CachingContext 

52 from ..interfaces import Database, StaticTablesContext 

53 

54 

# Specification of the synthetic primary key column of the collection table.
# The same object is reused as the template (dtype and target column name)
# for foreign key fields that point back at that table.
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "collection_id",
    dtype=sqlalchemy.BigInteger,
    primaryKey=True,
    autoincrement=True,
)


# Schema version of the tables defined in this module; this has to be
# updated on every schema change.
_VERSION = VersionTuple(2, 0, 0)

# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)

64 

65 

def _makeTableSpecs(
    TimespanReprClass: type[TimespanDatabaseRepresentation],
) -> CollectionTablesTuple[ddl.TableSpec]:
    """Build the specifications of the tables used by this manager.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Timespan representation class, forwarded to `makeRunTableSpec`.

    Returns
    -------
    specs : `CollectionTablesTuple` [ `ddl.TableSpec` ]
        Specifications for the ``collection``, ``run`` and
        ``collection_chain`` tables.
    """
    # Columns of the main collection table, in declaration order; the
    # synthetic integer key comes first.
    name_field = ddl.FieldSpec(
        "name", dtype=sqlalchemy.String, length=COLLECTION_NAME_MAX_LENGTH, nullable=False
    )
    type_field = ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False)
    doc_field = ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True)
    collection_spec = ddl.TableSpec(
        fields=[_KEY_FIELD_SPEC, name_field, type_field, doc_field],
        # Collection names are unique even though they are not the key.
        unique=[("name",)],
    )

    # The run and chain tables refer to collections through the integer key.
    run_spec = makeRunTableSpec("collection_id", sqlalchemy.BigInteger, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec("collection_id", sqlalchemy.BigInteger)

    return CollectionTablesTuple(
        collection=collection_spec,
        run=run_spec,
        collection_chain=chain_spec,
    )

84 

85 

class SynthIntKeyCollectionManager(DefaultCollectionManager[int]):
    """A `CollectionManager` implementation that uses synthetic primary key
    (auto-incremented integer) for collections table.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> SynthIntKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        # The table specs are registered with the static-tables context and
        # the resulting table tuple is handed to the constructor.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="collection_id",
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    def clone(self, db: Database, caching_context: CachingContext) -> SynthIntKeyCollectionManager:
        # Build a new manager around the given database and caching context
        # that reuses this instance's tables, key column name and schema
        # version.
        return SynthIntKeyCollectionManager(
            db,
            tables=self._tables,
            collectionIdName=self._collectionIdName,
            caching_context=caching_context,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        # The new field shares the key's dtype but must not auto-increment;
        # any extra keyword arguments are forwarded to the field spec.
        copy = ddl.FieldSpec(
            cls.getCollectionForeignKeyName(prefix), dtype=original.dtype, autoincrement=False, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    "collection", source=(copy.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return copy

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        # Same construction as for collection foreign keys, but the
        # constraint (when requested) targets the "run" table.
        copy = ddl.FieldSpec(
            cls.getRunForeignKeyName(prefix), dtype=original.dtype, autoincrement=False, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec("run", source=(copy.name,), target=(original.name,), onDelete=onDelete)
            )
        return copy

    def getParentChains(self, key: int) -> set[str]:
        # Docstring inherited from CollectionManager.
        chain = self._tables.collection_chain
        collection = self._tables.collection
        # Select names of the collections that appear as "parent" in chain
        # rows whose "child" is the given key.
        sql = (
            sqlalchemy.sql.select(collection.columns["name"])
            .select_from(collection)
            .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"])
            .where(chain.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def lookup_name_sql(
        self, sql_key: sqlalchemy.ColumnElement[int], sql_from_clause: sqlalchemy.FromClause
    ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]:
        # Docstring inherited.
        collection = self._tables.collection
        # Use the configured key column name, as the other queries in this
        # class do, instead of reaching for the module-level field spec.
        return (
            collection.c.name,
            sql_from_clause.join(collection, onclause=collection.c[self._collectionIdName] == sql_key),
        )

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using names %s.", names)
        return self._fetch("name", names)

    def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using IDs %s.", collection_ids)
        return self._fetch(self._collectionIdName, collection_ids)

    def _fetch(
        self, column_name: str, collections: Iterable[int | str] | None
    ) -> list[CollectionRecord[int]]:
        """Fetch collection records, optionally restricted by a column value.

        Parameters
        ----------
        column_name : `str`
            Name of the collection table column that ``collections`` values
            are matched against.
        collections : `~collections.abc.Iterable` [ `int` | `str` ] or `None`
            Values to select on; `None` fetches every collection.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Matching records; chained collection records come after all
            other records.
        """
        selection: list[int | str] | None = None
        if collections is not None:
            # Materialize once so that one-shot iterables are supported, and
            # skip the database entirely when nothing was asked for.
            selection = list(collections)
            if not selection:
                return []

        collection_chain = self._tables.collection_chain
        collection = self._tables.collection
        sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from(
            collection.join(self._tables.run, isouter=True)
        )

        chain_sql = (
            sqlalchemy.sql.select(
                collection_chain.columns["parent"],
                collection_chain.columns["position"],
                collection.columns["name"].label("child_name"),
            )
            .select_from(collection_chain)
            .join(
                collection,
                onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName],
            )
        )

        records: list[CollectionRecord[int]] = []
        # We want to keep transactions as short as possible. When we fetch
        # everything we want to quickly fetch things into memory and finish
        # the transaction. When we fetch just a few records we need to
        # process the first query before we can run the second one.
        if selection is not None:
            sql = sql.where(collection.columns[column_name].in_(selection))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    # Only chains whose parent was selected are of interest.
                    chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids)))

                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]:
        """Convert rows returned from collection query to a list of records
        and a dict mapping chained collection IDs to their names.
        """
        records: list[CollectionRecord[int]] = []
        chained_ids: dict[int, str] = {}
        TimespanReprClass = self._db.getTimespanRepresentation()
        for row in rows:
            key: int = row[self._collectionIdName]
            name: str = row[self._tables.collection.columns.name]
            # Named ``collection_type`` so the ``type`` builtin is not
            # shadowed.
            collection_type = CollectionType(row["type"])
            record: CollectionRecord[int]
            if collection_type is CollectionType.RUN:
                record = RunRecord[int](
                    key=key,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif collection_type is CollectionType.CHAINED:
                # Construction of chained collection records is delayed
                # until the names of their children have been fetched.
                chained_ids[key] = name
            else:
                record = CollectionRecord[int](key=key, name=name, type=collection_type)
                records.append(record)
        return records, chained_ids

    def _rows_to_chains(
        self, rows: Iterable[Mapping], chained_ids: dict[int, str]
    ) -> list[CollectionRecord[int]]:
        """Convert rows returned from collection chain query to a list of
        records.
        """
        # Start with an empty definition for every known chain so that
        # chains without any children still produce a record.
        chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            chains_defs[row["parent"]].append((row["position"], row["child_name"]))

        records: list[CollectionRecord[int]] = []
        for key, children in chains_defs.items():
            name = chained_ids[key]
            # Sorting the (position, name) tuples orders children by their
            # position in the chain.
            children_names = [child for _, child in sorted(children)]
            record = ChainedCollectionRecord[int](
                key=key,
                name=name,
                children=children_names,
            )
            records.append(record)

        return records

    def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
        # Select the key (labelled "key") and type of the collection with
        # the given name.
        table = self._tables.collection
        return sqlalchemy.select(table.columns[self._collectionIdName].label("key"), table.c.type).where(
            table.c.name == collection_name
        )

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]