Coverage for python/lsst/daf/butler/registry/collections/_base.py: 96%

176 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-16 10:43 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively.  If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

27from __future__ import annotations 

28 

29from ... import ddl 

30 

31__all__ = () 

32 

33import itertools 

34from abc import abstractmethod 

35from collections import namedtuple 

36from collections.abc import Iterable, Iterator, Set 

37from typing import TYPE_CHECKING, Any, TypeVar, cast 

38 

39import sqlalchemy 

40 

41from ...timespan_database_representation import TimespanDatabaseRepresentation 

42from .._collection_type import CollectionType 

43from .._exceptions import MissingCollectionError 

44from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple 

45from ..wildcards import CollectionWildcard 

46 

47if TYPE_CHECKING: 

48 from .._caching_context import CachingContext 

49 from ..interfaces import Database, DimensionRecordStorageManager 

50 

51 

def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Define foreign key specification that refers to collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Name of the column in the referring table.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    **kwargs
        Additional keyword arguments passed directly to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    This method assumes fixed name ("collection") of a collections table.
    There is also a general assumption that collection primary key consists
    of a single column.
    """
    # Both sides of the key are single-column tuples, per the single-column
    # primary key assumption documented above.
    return ddl.ForeignKeySpec("collection", source=(sourceColumnName,), target=(collectionIdName,), **kwargs)

78 

79 

# Bundle of the three table objects used by the collection manager, in the
# fixed order (collection, run, collection_chain).
CollectionTablesTuple = namedtuple("CollectionTablesTuple", ["collection", "run", "collection_chain"])

81 

82 

def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Define specification for "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType : `type`
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    Assumption here and in the code below is that the name of the identifying
    column is the same in both collections and run tables. The names of
    non-identifying columns containing run metadata are fixed.
    """
    result = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True),
            ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128),
        ],
        foreignKeys=[
            # The run row shares its key with the collection row and is
            # removed together with it.
            _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE"),
        ],
    )
    # Timespan columns depend on the database-specific representation, so
    # they are appended after the fixed columns; they are nullable.
    for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
        result.fields.add(fieldSpec)
    return result

122 

123 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Define specification for "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType : `type`
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    Collection chain is simply an ordered one-to-many relation between
    collections. The names of the columns in the table are fixed and
    also hardcoded in the code below.
    """
    return ddl.TableSpec(
        fields=[
            # (parent, position) is the composite primary key, which gives
            # each parent an ordered list of children.
            ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True),
            ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
            ddl.FieldSpec("child", dtype=collectionIdType, nullable=False),
        ],
        foreignKeys=[
            # Only the parent link cascades on delete; the child link is
            # declared without an onDelete action.
            _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE"),
            _makeCollectionForeignKey("child", collectionIdName),
        ],
    )

157 

158 

# Type of the collection primary key; fixed by concrete manager subclasses.
K = TypeVar("K")


class DefaultCollectionManager(CollectionManager[K]):
    """Default `CollectionManager` implementation.

    This implementation uses record classes defined in this module and is
    based on the same assumptions about schema outlined in the record classes.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.
    caching_context : `CachingContext`
        Caching context to use.
    registry_schema_version : `VersionTuple` or `None`, optional
        The version of the registry schema.

    Notes
    -----
    Records are cached in memory only while the caching context exposes a
    ``collection_records`` cache (i.e. it is not `None`); otherwise every
    lookup goes to the database. Calling `refresh` clears that cache.
    Subclasses provide the actual database access by implementing
    `_fetch_by_name` and `_fetch_by_key`.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._dimensions = dimensions
        self._caching_context = caching_context

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.clear()

    def _fetch_all(self) -> list[CollectionRecord[K]]:
        """Retrieve all records into cache if not done so yet.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            All collection records, taken from the cache when it is marked
            as full and from the database otherwise.
        """
        if self._caching_context.collection_records is not None:
            if self._caching_context.collection_records.full:
                return list(self._caching_context.collection_records.records())
        # Passing `None` asks the subclass to fetch every collection.
        records = self._fetch_by_key(None)
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.set(records, full=True)
        return records

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord[K], bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        if record is None:
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = cast(K, row[self._collectionIdName])
            if type is CollectionType.RUN:
                # RUN collections have a companion row in the "run" table
                # holding host and timespan metadata.
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = RunRecord[K](
                    key=collection_id,
                    name=name,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                # A newly registered chain starts out with no children.
                record = ChainedCollectionRecord[K](
                    key=collection_id,
                    name=name,
                    children=[],
                )
            else:
                record = CollectionRecord[K](key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        # Only drop the cached record once the database delete went through.
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Return multiple records given their names.

        Parameters
        ----------
        names : `~collections.abc.Iterable` [ `str` ]
            Names of the collections to look up; duplicates are allowed.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One record per input name, in input order.

        Raises
        ------
        MissingCollectionError
            Raised if any of the names does not correspond to a collection.
        """
        names = list(names)
        # Results are gathered in a local mapping instead of being re-read
        # from the cache at the end, to protect against potential races in
        # cache updates.
        records: dict[str, CollectionRecord | None]
        if self._caching_context.collection_records is not None:
            cache = self._caching_context.collection_records
            records = {name: cache.get_by_name(name) for name in names}
        else:
            records = dict.fromkeys(names)
        # Keys of ``records`` are unique, so each missing name is queried
        # only once even when ``names`` contains duplicates.
        fetch_names = [name for name, record in records.items() if record is None]
        if fetch_names:
            for record in self._fetch_by_name(fetch_names):
                records[record.name] = record
                self._addCachedRecord(record)
        missing_names = [name for name, record in records.items() if record is None]
        if len(missing_names) == 1:
            raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
        elif len(missing_names) > 1:
            raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
        return [cast(CollectionRecord[K], records[name]) for name in names]

    def __getitem__(self, key: Any) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
                return record
        if records := self._fetch_by_key([key]):
            record = records[0]
            if self._caching_context.collection_records is not None:
                self._caching_context.collection_records.add(record)
            return record
        else:
            raise MissingCollectionError(f"Collection with key '{key}' not found.")

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord[K]]:
        # Docstring inherited
        if done is None:
            done = set()
        # By default chains themselves are reported only when they are not
        # being flattened into their children.
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]:
            # ``done`` holds names already visited, so each collection is
            # yielded at most once and chains are expanded at most once.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(child, done)  # noqa: F821

        result: list[CollectionRecord[K]] = []

        if wildcard.patterns is ...:
            # Ellipsis means "match everything".
            for record in self._fetch_all():
                result.extend(resolve_nested(record, done))
            del resolve_nested
            return result
        if wildcard.strings:
            for record in self._find_many(wildcard.strings):
                result.extend(resolve_nested(record, done))
        if wildcard.patterns:
            for record in self._fetch_all():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        del resolve_nested
        return result

    def getDocumentation(self, key: K) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: K, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # The ID column is matched against the "key" entry of the row dict;
        # the remaining entry ("doc") is the value being written.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _addCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Add single record to cache.

        Does nothing when the caching context has no collection record cache.
        """
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.add(record)

    def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Remove single record from cache.

        Does nothing when the caching context has no collection record cache.
        """
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.discard(record)

    def _getByName(self, name: str) -> CollectionRecord[K] | None:
        """Find collection record given collection name.

        Returns
        -------
        record : `CollectionRecord` or `None`
            Matching record from the cache or the database, or `None` if no
            collection with this name was found.
        """
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
                return record
        records = self._fetch_by_name([name])
        for record in records:
            self._addCachedRecord(record)
        return records[0] if records else None

    @abstractmethod
    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Fetch collection record from database given its name."""
        raise NotImplementedError()

    @abstractmethod
    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
        """Fetch collection record from database given its key, or fetch all
        collections if argument is None.
        """
        raise NotImplementedError()

    def update_chain(
        self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
    ) -> ChainedCollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        # ``children`` is read more than once below (cycle check, then the
        # final resolution); materialize it so that a one-shot iterator is
        # not exhausted by the first pass, leaving the chain empty.
        children = tuple(children)
        children_as_wildcard = CollectionWildcard.from_names(children)
        # Walk every chain reachable from the proposed children; finding
        # ``chain`` itself among them means the update would create a cycle.
        for record in self.resolve_wildcard(
            children_as_wildcard,
            flatten_chains=True,
            include_chains=True,
            collection_types={CollectionType.CHAINED},
        ):
            if record == chain:
                raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.")
        if flatten:
            children = tuple(
                record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True)
            )

        rows = []
        position = itertools.count()
        names = []
        for child in self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False):
            rows.append(
                {
                    "parent": chain.key,
                    "child": child.key,
                    "position": next(position),
                }
            )
            names.append(child.name)
        # Replace the whole child list in one transaction.
        with self._db.transaction():
            self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key})
            self._db.insert(self._tables.collection_chain, *rows)

        record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names))
        self._addCachedRecord(record)
        return record

445 self._addCachedRecord(record) 

446 return record