Coverage for python/lsst/daf/butler/registry/collections/_base.py: 97%

219 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-05 02:52 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from ... import ddl 

30 

# Empty on purpose: ``from <this module> import *`` exports no names.
__all__: tuple[str, ...] = ()

32 

33from abc import abstractmethod 

34from collections.abc import Iterable, Iterator, Set 

35from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar, cast 

36 

37import sqlalchemy 

38 

39from ..._exceptions import CollectionCycleError, CollectionTypeError, MissingCollectionError 

40from ...timespan_database_representation import TimespanDatabaseRepresentation 

41from .._collection_type import CollectionType 

42from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple 

43from ..wildcards import CollectionWildcard 

44 

45if TYPE_CHECKING: 

46 from .._caching_context import CachingContext 

47 from ..interfaces import Database 

48 

49 

def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column of the referring table that holds the collection identifier.
    collectionIdName : `str`
        Primary-key column of the collections table that is referenced.
    **kwargs
        Forwarded unchanged to `ddl.ForeignKeySpec` (e.g. ``onDelete``).

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        The foreign key specification.

    Notes
    -----
    The target table name is hard-coded as ``"collection"``, and the
    collections primary key is assumed to be a single column.
    """
    # Single-column key on both sides of the relationship.
    referring_columns = (sourceColumnName,)
    referenced_columns = (collectionIdName,)
    return ddl.ForeignKeySpec(
        "collection",
        source=referring_columns,
        target=referenced_columns,
        **kwargs,
    )

76 

77 

78_T = TypeVar("_T") 

79 

80 

81class CollectionTablesTuple(NamedTuple, Generic[_T]): 

82 collection: _T 

83 run: _T 

84 collection_chain: _T 

85 

86 

def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table. The run
        table uses a column of the same name as its own primary key.
    collectionIdType : `type`
        `sqlalchemy` type of that primary-key column.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` describing how this
        database stores timespans. It supplies the timespan columns.

    Returns
    -------
    spec : `ddl.TableSpec`
        Run table specification: the identifying column, a ``host`` string
        column (length 128), and nullable timespan columns.

    Notes
    -----
    The identifying column has the same name in the collections and run
    tables, and references the former with ``ON DELETE CASCADE``. The
    metadata column names (``host`` and the timespan columns) are fixed,
    and other code reads them by those names.
    """
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # Deleting a collection row also deletes its run row.
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[collection_fk])
    # Append the timespan columns (nullable), as defined by the database's
    # timespan representation.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec

126 

127 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table.
    collectionIdType : `type`
        `sqlalchemy` type of that primary-key column.

    Returns
    -------
    spec : `ddl.TableSpec`
        Collection chain table specification.

    Notes
    -----
    A chain is an ordered one-to-many relation between collections: each
    row places one ``child`` at one ``position`` of one ``parent``. The
    column names ``parent``, ``position`` and ``child`` are fixed, and other
    code uses them by name.
    """
    # (parent, position) is the composite primary key.
    chain_fields = [
        ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True),
        ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
        ddl.FieldSpec("child", dtype=collectionIdType, nullable=False),
    ]
    # Deleting the parent collection deletes its chain rows. No ``onDelete``
    # is given for ``child``, so the database's default action applies.
    chain_foreign_keys = [
        _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE"),
        _makeCollectionForeignKey("child", collectionIdName),
    ]
    return ddl.TableSpec(fields=chain_fields, foreignKeys=chain_foreign_keys)

161 

162 

K = TypeVar("K")


class DefaultCollectionManager(CollectionManager[K]):
    """Default `CollectionManager` implementation.

    Subclasses determine the collection key type ``K`` and supply the
    database lookups through the abstract methods `_fetch_by_name`,
    `_fetch_by_key` and `_select_pkey_by_name`. This class implements the
    rest of the manager on top of those, using the generic record classes
    `CollectionRecord`, `RunRecord` and `ChainedCollectionRecord`.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    caching_context : `CachingContext`
        Caching context to use.
    registry_schema_version : `VersionTuple` or `None`, optional
        The version of the registry schema.

    Notes
    -----
    When the caching context provides a collection-record cache
    (``caching_context.collection_records`` is not `None`), records that
    this manager fetches or creates are stored in it, and later lookups are
    served from it. `refresh` only clears that cache. Records are then
    re-read from the database on demand.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple[sqlalchemy.Table],
        collectionIdName: str,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._caching_context = caching_context

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        # Only the in-memory cache is dropped. Records are re-read from the
        # database lazily on the next lookup.
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.clear()

    def _fetch_all(self) -> list[CollectionRecord[K]]:
        """Return records for all collections.

        If the cache already holds the full set, it is used. Otherwise all
        records are fetched from the database and, when caching is active,
        stored in the cache as the full set.
        """
        if self._caching_context.collection_records is not None:
            if self._caching_context.collection_records.full:
                return list(self._caching_context.collection_records.records())
        records = self._fetch_by_key(None)
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.set(records, full=True)
        return records

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord[K], bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        # NOTE(review): an existing collection (from cache or database) is
        # returned as-is. Its type is not compared with ``type`` here.
        record = self._getByName(name)
        if record is None:
            # ``compared`` asks ``sync`` to check the ``type`` column of any
            # row that already exists (e.g. one inserted concurrently).
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = cast(K, row[self._collectionIdName])
            if type is CollectionType.RUN:
                # RUN collections also need a row in the run table.
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = RunRecord[K](
                    key=collection_id,
                    name=name,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                # A newly registered chain has no children yet.
                record = ChainedCollectionRecord[K](
                    key=collection_id,
                    name=name,
                    children=[],
                )
            else:
                record = CollectionRecord[K](key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Return multiple records given their names.

        Cached records are used where available. The remaining names are
        fetched from the database in a single call and the results are
        added to the cache.

        Parameters
        ----------
        names : `~collections.abc.Iterable` [ `str` ]
            Collection names. Order and repetitions are preserved in the
            result.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One record per entry of ``names``.

        Raises
        ------
        MissingCollectionError
            Raised if any of the names does not match a collection.
        """
        names = list(names)
        # Collect results in a local dict rather than re-reading the cache at
        # the end, to protect against potential races in cache updates.
        records: dict[str, CollectionRecord[K] | None] = {}
        if self._caching_context.collection_records is not None:
            for name in names:
                records[name] = self._caching_context.collection_records.get_by_name(name)
            fetch_names = [name for name, record in records.items() if record is None]
        else:
            fetch_names = list(names)
            records = {name: None for name in fetch_names}
        if fetch_names:
            for record in self._fetch_by_name(fetch_names):
                records[record.name] = record
                self._addCachedRecord(record)
        missing_names = [name for name, record in records.items() if record is None]
        if len(missing_names) == 1:
            raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
        elif len(missing_names) > 1:
            # Quote each name separately so multiple names are unambiguous.
            quoted_names = ", ".join(f"'{name}'" for name in missing_names)
            raise MissingCollectionError(f"No collections with names {quoted_names} found.")
        return [cast(CollectionRecord[K], records[name]) for name in names]

    def __getitem__(self, key: Any) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
                return record
        if records := self._fetch_by_key([key]):
            record = records[0]
            self._addCachedRecord(record)
            return record
        else:
            raise MissingCollectionError(f"Collection with key '{key}' not found.")

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord[K]]:
        # Docstring inherited
        if done is None:
            done = set()
        # By default, chains are reported themselves only when they are not
        # being flattened into their children.
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]:
            # Yield ``record`` if it passes the type/chain filters. When
            # flattening, also yield the recursively resolved children of a
            # chain. Names already in ``done`` are skipped, so each collection
            # is yielded at most once.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(child, done)  # noqa: F821

        result: list[CollectionRecord[K]] = []

        if wildcard.patterns is ...:
            # Match-everything wildcard: consider every collection.
            for record in self._fetch_all():
                result.extend(resolve_nested(record, done))
            # Drop the recursive closure, which refers to itself.
            del resolve_nested
            return result
        if wildcard.strings:
            for record in self._find_many(wildcard.strings):
                result.extend(resolve_nested(record, done))
        if wildcard.patterns:
            for record in self._fetch_all():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        # Drop the recursive closure, which refers to itself.
        del resolve_nested
        return result

    def getDocumentation(self, key: K) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: K, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # The ``where`` mapping matches the ID column against the row's "key".
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _addCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Add single record to cache (no-op when caching is disabled)."""
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.add(record)

    def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Remove single record from cache (no-op when caching is disabled)."""
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.discard(record)

    def _getByName(self, name: str) -> CollectionRecord[K] | None:
        """Find collection record given collection name.

        Returns `None` if no such collection exists. A record fetched from
        the database is added to the cache.
        """
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
                return record
        records = self._fetch_by_name([name])
        for record in records:
            self._addCachedRecord(record)
        return records[0] if records else None

    @abstractmethod
    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their names."""
        raise NotImplementedError()

    @abstractmethod
    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their keys, or fetch
        all collections if argument is None.
        """
        raise NotImplementedError()

    def update_chain(
        self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
    ) -> ChainedCollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        children = list(children)
        self._sanity_check_collection_cycles(chain.name, children)

        if flatten:
            # Replace nested chains by the non-chain collections they contain.
            children = tuple(
                record.name
                for record in self.resolve_wildcard(
                    CollectionWildcard.from_names(children), flatten_chains=True
                )
            )

        child_records = self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False)
        names = [child.name for child in child_records]
        with self._db.transaction():
            # Lock the parent row, then replace the whole chain.
            self._find_and_lock_collection_chain(chain.name)
            self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key})
            # Hook defined outside this class. The name suggests a pause
            # point used by concurrency tests.
            self._block_for_concurrency_test()
            self._insert_collection_chain_rows(chain.key, 0, [child.key for child in child_records])

        record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names))
        self._addCachedRecord(record)
        return record

    def _sanity_check_collection_cycles(
        self, parent_collection_name: str, child_collection_names: list[str]
    ) -> None:
        """Raise an exception if any of the collections in the ``child_names``
        list have ``parent_name`` as a child, creating a collection cycle.

        This is only a sanity check, and does not guarantee that no collection
        cycles are possible. Concurrent updates might allow collection cycles
        to be inserted.
        """
        # Walk every chain reachable from the proposed children. If the parent
        # is among them, adding these children would close a cycle.
        for record in self.resolve_wildcard(
            CollectionWildcard.from_names(child_collection_names),
            flatten_chains=True,
            include_chains=True,
            collection_types={CollectionType.CHAINED},
        ):
            if record.name == parent_collection_name:
                raise CollectionCycleError(
                    f"Cycle in collection chaining when defining '{parent_collection_name}'."
                )

    def _insert_collection_chain_rows(
        self,
        parent_key: K,
        starting_position: int,
        child_keys: list[K],
    ) -> None:
        """Insert ``collection_chain`` rows for ``parent_key``.

        Each key in ``child_keys`` is placed at consecutive positions,
        starting at ``starting_position``, in the order given.
        """
        rows = [
            {
                "parent": parent_key,
                "child": child,
                "position": position,
            }
            for position, child in enumerate(child_keys, starting_position)
        ]
        self._db.insert(self._tables.collection_chain, *rows)

    def _remove_collection_chain_rows(
        self,
        parent_key: K,
        child_keys: list[K],
    ) -> None:
        """Delete the ``collection_chain`` rows of parent ``parent_key`` whose
        child is one of ``child_keys``.
        """
        table = self._tables.collection_chain
        where = sqlalchemy.and_(table.c.parent == parent_key, table.c.child.in_(child_keys))
        self._db.deleteWhere(table, where)

    def prepend_collection_chain(
        self, parent_collection_name: str, child_collection_names: list[str]
    ) -> None:
        if self._caching_context.is_enabled:
            # Avoid having cache-maintenance code around that is unlikely to
            # ever be used.
            raise RuntimeError("Chained collection modification not permitted with active caching context.")

        self._sanity_check_collection_cycles(parent_collection_name, child_collection_names)

        child_records = self.resolve_wildcard(
            CollectionWildcard.from_names(child_collection_names), flatten_chains=False
        )
        child_keys = [child.key for child in child_records]

        with self._db.transaction():
            parent_key = self._find_and_lock_collection_chain(parent_collection_name)
            # Children already in the chain are moved to the front rather than
            # duplicated.
            self._remove_collection_chain_rows(parent_key, child_keys)
            # New children take the positions just below the current minimum,
            # in the order given, so they sort ahead of existing entries.
            starting_position = self._find_lowest_position_in_collection_chain(parent_key) - len(child_keys)
            self._block_for_concurrency_test()
            self._insert_collection_chain_rows(parent_key, starting_position, child_keys)

    def _find_lowest_position_in_collection_chain(self, chain_key: K) -> int:
        """Return the lowest-numbered position in a collection chain, or 0 if
        the chain is empty.
        """
        table = self._tables.collection_chain
        query = sqlalchemy.select(sqlalchemy.func.min(table.c.position)).where(table.c.parent == chain_key)
        with self._db.query(query) as cursor:
            lowest_existing_position = cursor.scalar()

        # MIN over zero rows yields NULL.
        if lowest_existing_position is None:
            return 0

        return lowest_existing_position

    def _find_and_lock_collection_chain(self, collection_name: str) -> K:
        """
        Take a row lock on the specified collection's row in the collections
        table, and return the collection's primary key.

        This lock is used to synchronize updates to collection chains.

        The locking strategy requires cooperation from everything modifying the
        collection chain table -- all operations that modify collection chains
        must obtain this lock first. The database will NOT automatically
        prevent modification of tables based on this lock. The only guarantee
        is that only one caller will be allowed to hold this lock for a given
        collection at a time. Concurrent calls will block until the caller
        holding the lock has completed its transaction.

        Parameters
        ----------
        collection_name : `str`
            Name of the collection whose chain is being modified.

        Returns
        -------
        id : ``K``
            The primary key for the given collection.

        Raises
        ------
        MissingCollectionError
            If the specified collection is not in the database table.
        CollectionTypeError
            If the specified collection is not a chained collection.
        """
        assert self._db.isInTransaction(), (
            "Row locks are only held until the end of the current transaction,"
            " so it makes no sense to take a lock outside a transaction."
        )
        assert self._db.isWriteable(), "Collection row locks are only useful for write operations."

        query = self._select_pkey_by_name(collection_name).with_for_update()
        with self._db.query(query) as cursor:
            rows = cursor.all()

        if len(rows) == 0:
            raise MissingCollectionError(
                f"Parent collection {collection_name} not found when updating collection chain."
            )
        assert len(rows) == 1, "There should only be one entry for each collection in collection table."
        r = rows[0]._mapping
        if r["type"] != CollectionType.CHAINED:
            raise CollectionTypeError(f"Parent collection {collection_name} is not a chained collection.")
        return r["key"]

    @abstractmethod
    def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
        """Return a SQLAlchemy select statement that will return columns from
        the one row in the ``collection`` table matching the given name. The
        select statement includes two columns:

        - ``key`` : the primary key for the collection
        - ``type`` : the collection type
        """
        raise NotImplementedError()