Coverage for python / lsst / daf / butler / registry / collections / _base.py: 0%

293 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 08:30 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from ... import ddl 

30 

31__all__ = () 

32 

33from abc import abstractmethod 

34from collections.abc import Callable, Iterable, Iterator, Mapping, Set 

35from contextlib import contextmanager 

36from typing import TYPE_CHECKING, Any, Generic, Literal, NamedTuple, TypeVar, cast 

37 

38import sqlalchemy 

39 

40from lsst.utils.iteration import chunk_iterable 

41 

42from ..._collection_type import CollectionType 

43from ..._exceptions import CollectionCycleError, CollectionTypeError, MissingCollectionError 

44from ...timespan_database_representation import TimespanDatabaseRepresentation 

45from .._collection_record_cache import CollectionRecordCache 

46from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple 

47from ..wildcards import CollectionWildcard 

48 

49if TYPE_CHECKING: 

50 from .._caching_context import CachingContext 

51 from ..interfaces import Database 

52 

53 

def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification that points at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Name of the referring column in the source table.
    collectionIdName : `str`
        Name of the primary-key column in the collections table.
    **kwargs
        Additional keyword arguments forwarded unchanged to
        `ddl.ForeignKeySpec` (e.g. ``onDelete``).

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The collections table name is hardcoded as ``"collection"``, and the
    collection primary key is assumed to be a single column.
    """
    source_columns = (sourceColumnName,)
    target_columns = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=source_columns, target=target_columns, **kwargs)

80 

81 

# Type parameter for `CollectionTablesTuple`: in this module it is either a
# table specification or a `sqlalchemy.Table` (see the manager's __init__).
_T = TypeVar("_T")


class CollectionTablesTuple(NamedTuple, Generic[_T]):
    """Named tuple bundling the three tables (or table specs) that make up
    the collections schema: the main collection table, the run metadata
    table, and the chain-membership table.
    """

    collection: _T
    run: _T
    collection_chain: _T

90 

def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column in the collections table; the run
        table reuses the same name for its own primary key.
    collectionIdType : `type`
        SQLAlchemy type of that primary-key column.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Class describing how timespans are stored in this database; it
        contributes its own set of columns to the table.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    The identifying column is assumed to share its name between the
    collections and run tables; the metadata column names ("host" and the
    timespan columns) are fixed.
    """
    spec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True),
            ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128),
        ],
        foreignKeys=[
            # Deleting a collection row cascades into its run metadata.
            _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE"),
        ],
    )
    # Timespan storage is backend-dependent; the representation class adds
    # whatever (nullable) columns it needs.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec

130 

131 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column in the collections table.
    collectionIdType : `type`
        SQLAlchemy type of that primary-key column.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections; each row links one parent to one child at an integer
    position. Column names here are fixed and hardcoded elsewhere in this
    module.
    """
    fields = [
        ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True),
        ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
        ddl.FieldSpec("child", dtype=collectionIdType, nullable=False),
    ]
    foreign_keys = [
        # Deleting a parent chain removes its membership rows automatically...
        _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE"),
        # ...but a child still referenced by some chain must not be deleted.
        _makeCollectionForeignKey("child", collectionIdName),
    ]
    return ddl.TableSpec(fields=fields, foreignKeys=foreign_keys)

165 

166 

# ``K`` is the collection primary-key type; the concrete subclass decides
# what it actually is.
K = TypeVar("K")


class DefaultCollectionManager(CollectionManager[K]):
    """Default `CollectionManager` implementation.

    This implementation uses record classes defined in this module and is
    based on the same assumptions about schema outlined in the record classes.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    caching_context : `CachingContext`
        Caching context to use.
    registry_schema_version : `VersionTuple` or `None`, optional
        The version of the registry schema.

    Notes
    -----
    Implementation uses "aggressive" pre-fetching and caching of the records
    in memory. Memory cache is synchronized from database when `refresh`
    method is called.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple[sqlalchemy.Table],
        collectionIdName: str,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        # Database interface used for every query issued by this manager.
        self._db = db
        # The collection / run / collection_chain SQLAlchemy tables.
        self._tables = tables
        # Name of the collection table's PK column; also used as column name
        # in the run table.
        self._collectionIdName = collectionIdName
        # Holds the (optional) CollectionRecordCache consulted by lookups.
        self._caching_context = caching_context

210 

211 def refresh(self) -> None: 

212 # Docstring inherited from CollectionManager. 

213 if self._caching_context.collection_records is not None: 

214 self._caching_context.collection_records.clear() 

215 

216 def _fetch_all(self, collection_cache: CollectionRecordCache | None = None) -> list[CollectionRecord[K]]: 

217 """Retrieve all records into cache if not done so yet.""" 

218 if collection_cache is None: 

219 collection_cache = self._caching_context.collection_records 

220 if collection_cache is not None: 

221 if collection_cache.full: 

222 return list(collection_cache.records()) 

223 records = self._fetch_by_key(None) 

224 if collection_cache is not None: 

225 collection_cache.set(records, full=True) 

226 return records 

227 

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord[K], bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        if record is None:
            # Insert-or-verify the main collection row; the flag tells us
            # whether a new row was actually created.
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = cast(K, row[self._collectionIdName])
            if type is CollectionType.RUN:
                # RUN collections get a companion row in the run table that
                # holds host and timespan metadata.
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = RunRecord[K](
                    key=collection_id,
                    name=name,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                # A freshly registered chain starts with no children.
                record = ChainedCollectionRecord[K](
                    key=collection_id,
                    name=name,
                    children=[],
                )
            else:
                record = CollectionRecord[K](key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

270 

271 def remove(self, name: str) -> None: 

272 # Docstring inherited from CollectionManager. 

273 record = self._getByName(name) 

274 if record is None: 

275 raise MissingCollectionError(f"No collection with name '{name}' found.") 

276 # This may raise 

277 self._db.delete( 

278 self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key} 

279 ) 

280 self._removeCachedRecord(record) 

281 

282 def find(self, name: str) -> CollectionRecord[K]: 

283 # Docstring inherited from CollectionManager. 

284 result = self._getByName(name) 

285 if result is None: 

286 raise MissingCollectionError(f"No collection with name '{name}' found.") 

287 return result 

288 

    def _find_many(
        self, names: Iterable[str], flatten_chains: bool, collection_cache: CollectionRecordCache | None
    ) -> list[CollectionRecord[K]]:
        """Return multiple records given their names.

        Parameters
        ----------
        names : `~collections.abc.Iterable` [`str`]
            Collection names to search for.
        flatten_chains : `bool`
            If `True` then also retrieve recursively collection records for all
            chained collections in the input list. Child collections are not
            returned but are stored in collection cache.
        collection_cache : `CollectionRecordCache`
            If `None` then the cache from the caching context will be used if
            that is not `None`. Collections are searched in the cache first,
            collections that are missing from the cache are fetched from
            database. All fetched collections are added to the cache.

        Returns
        -------
        records : `list` [`CollectionRecord`]
            Collection records. Records are ordered according to the input list
            and expanded depth-first if ``flatten_chains`` is True.

        Raises
        ------
        MissingCollectionError
            Raised if any of the requested collections does not exist.
        """

        def check_cache(
            name: str, cache: CollectionRecordCache, flatten_chains: bool
        ) -> list[CollectionRecord[K]]:
            """Check that cache contains a record for a given name and all its
            child records if ``flatten_chain`` is True.

            Parameters
            ----------
            name : `str`
                Collection name.
            cache : `CollectionRecordCache`
                Record cache.
            flatten_chains : `bool`
                If `True` then return all children records recursively.

            Returns
            -------
            records : `list` [`CollectionRecord`]
                Records from cache, including all child records if
                ``flatten_chains`` is True.

            Raises
            ------
            LookupError
                Raised if any record is missing from cache. If LookupError is
                raised then no records are generated.
            """
            record = cache.get_by_name(name)
            if record is not None:
                records = [record]
                if flatten_chains and record.type is CollectionType.CHAINED:
                    # Check all children recursively.
                    for child_name in cast(ChainedCollectionRecord, record).children:
                        records += check_cache(child_name, cache, flatten_chains)
                return records
            else:
                raise LookupError(name)

        names = list(names)
        if collection_cache is None:
            collection_cache = self._caching_context.collection_records

        # To protect against potential races in cache updates.
        records: dict[str, CollectionRecord[K]] = {}
        fetch_names = []
        if collection_cache is not None:
            # Take whatever is in the cache; any name whose record (or, when
            # flattening, any of whose descendants) is missing falls through
            # to the database fetch below.
            for name in names:
                try:
                    for record in check_cache(name, collection_cache, flatten_chains):
                        records[record.name] = record
                except LookupError:
                    fetch_names.append(name)
        else:
            fetch_names = names

        if fetch_names:
            # Fetch all missing collections and optionally their children.
            for record in self._fetch_by_name(fetch_names, flatten_chains):
                records[record.name] = record
                self._addCachedRecord(record, collection_cache)

        missing_names = [name for name in names if name not in records]
        if len(missing_names) == 1:
            raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
        elif len(missing_names) > 1:
            raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")

        def order(names: Iterable[str]) -> Iterator[CollectionRecord[K]]:
            # Re-emit the accumulated records in input order, expanding
            # chains depth-first when requested.
            for name in names:
                record = records[name]
                yield record
                if flatten_chains and record.type is CollectionType.CHAINED:
                    # Also return all children recursively.
                    yield from order(cast(ChainedCollectionRecord, record).children)

        return list(order(names))

391 

392 def __getitem__(self, key: Any) -> CollectionRecord[K]: 

393 # Docstring inherited from CollectionManager. 

394 if self._caching_context.collection_records is not None: 

395 if (record := self._caching_context.collection_records.get_by_key(key)) is not None: 

396 return record 

397 if records := self._fetch_by_key([key]): 

398 record = records[0] 

399 if self._caching_context.collection_records is not None: 

400 self._caching_context.collection_records.add(record) 

401 return record 

402 else: 

403 raise MissingCollectionError(f"Collection with key '{key}' not found.") 

404 

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord[K]]:
        # Docstring inherited
        # Default: chain records themselves are dropped when the chains are
        # being flattened into their children.
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def filter_types(records: Iterable[CollectionRecord[K]]) -> Iterator[CollectionRecord[K]]:
            # Keep only requested types, omitting CHAINED records unless
            # ``include_chains`` says otherwise.
            for record in records:
                if record.type in collection_types:
                    if record.type is not CollectionType.CHAINED or include_chains:
                        yield record

        if wildcard.patterns is ...:
            # As _fetch_all() returns all records without duplicates, we just
            # have to filter types.
            return list(filter_types(self._fetch_all()))

        cache: CollectionRecordCache | None = None
        result: list[CollectionRecord[K]] = []
        # ``done_keys`` de-duplicates records that match both a pattern and an
        # explicit name (or appear in more than one chain).
        done_keys: set[K] = set()
        explicit_names = list(wildcard.strings)
        if wildcard.patterns:
            if explicit_names or flatten_chains:
                # To be efficient in case both patterns and strings are
                # specified we want to have caching enabled for at least the
                # duration of this call. Chains flattening can also produce
                # additional names to look for.
                cache = self._caching_context.collection_records or CollectionRecordCache()
            all_records = self._fetch_all(cache)
            for record in filter_types(all_records):
                if record.key not in done_keys:
                    if any(p.fullmatch(record.name) for p in wildcard.patterns):
                        result.append(record)
                        done_keys.add(record.key)
            if flatten_chains:
                # If flattening include children names of all matching chains.
                for record in all_records:
                    if isinstance(record, ChainedCollectionRecord):
                        if any(p.fullmatch(record.name) for p in wildcard.patterns):
                            explicit_names.extend(record.children)

        if explicit_names:
            # _find_many() returns correctly ordered records, but there may be
            # duplicates.
            for record in filter_types(self._find_many(explicit_names, flatten_chains, cache)):
                if record.key not in done_keys:
                    result.append(record)
                    done_keys.add(record.key)

        return result

460 

461 def getDocumentation(self, key: K) -> str | None: 

462 # Docstring inherited from CollectionManager. 

463 docs = self.get_docs([key]) 

464 return docs.get(key) 

465 

466 def get_docs(self, keys: Iterable[K]) -> Mapping[K, str]: 

467 # Docstring inherited from CollectionManager. 

468 docs: dict[K, str] = {} 

469 id_column = self._tables.collection.columns[self._collectionIdName] 

470 doc_column = self._tables.collection.columns.doc 

471 for chunk in chunk_iterable(keys): 

472 sql = ( 

473 sqlalchemy.sql.select(id_column, doc_column) 

474 .select_from(self._tables.collection) 

475 .where(sqlalchemy.sql.and_(id_column.in_(chunk), doc_column != sqlalchemy.literal(""))) 

476 ) 

477 with self._db.query(sql) as sql_result: 

478 for row in sql_result: 

479 docs[row[0]] = row[1] 

480 return docs 

481 

    def setDocumentation(self, key: K, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # `Database.update` takes a mapping from column name to a placeholder
        # name ("key" here); the placeholder's value is supplied in the row
        # dict together with the new column values.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

485 

486 def _addCachedRecord( 

487 self, record: CollectionRecord[K], collection_cache: CollectionRecordCache | None = None 

488 ) -> None: 

489 """Add single record to cache.""" 

490 if collection_cache is None: 

491 collection_cache = self._caching_context.collection_records 

492 if collection_cache is not None: 

493 collection_cache.add(record) 

494 

495 def _removeCachedRecord(self, record: CollectionRecord[K]) -> None: 

496 """Remove single record from cache.""" 

497 if self._caching_context.collection_records is not None: 

498 self._caching_context.collection_records.discard(record) 

499 

500 def _getByName(self, name: str) -> CollectionRecord[K] | None: 

501 """Find collection record given collection name.""" 

502 if self._caching_context.collection_records is not None: 

503 if (record := self._caching_context.collection_records.get_by_name(name)) is not None: 

504 return record 

505 records = self._fetch_by_name([name], False) 

506 for record in records: 

507 self._addCachedRecord(record) 

508 return records[0] if records else None 

509 

    @abstractmethod
    def _fetch_by_name(self, names: Iterable[str], flatten_chains: bool) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their names.

        If ``flatten_chains`` is `True`, records for the children of any
        chained collections in ``names`` are fetched recursively as well.
        Names that do not exist are silently omitted from the result.
        """
        raise NotImplementedError()

514 

    @abstractmethod
    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
        """Fetch collection record from database given its key, or fetch all
        collections if argument is None.
        """
        raise NotImplementedError()

521 

    def update_chain(
        self,
        parent_collection_name: str,
        child_collection_names: list[str],
        allow_use_in_caching_context: bool = False,
    ) -> None:
        # Docstring presumably inherited from CollectionManager -- this
        # replaces the full membership of a chained collection.
        with self._modify_collection_chain(
            parent_collection_name,
            child_collection_names,
            # update_chain is currently used in setCollectionChain, which is
            # called within caching contexts. (At least in Butler.import_ and
            # possibly other places.) So, unlike the other collection chain
            # modification methods, it has to update the collection cache.
            skip_caching_check=allow_use_in_caching_context,
        ) as c:
            # Wipe the existing membership rows, then re-insert the new
            # children numbered consecutively from position 0.
            self._db.delete(self._tables.collection_chain, ["parent"], {"parent": c.parent_key})
            self._block_for_concurrency_test()
            self._insert_collection_chain_rows(c.parent_key, 0, c.child_keys)

            # Refresh the cached chain record to reflect the new child list.
            names = [child.name for child in c.child_records]
            record = ChainedCollectionRecord[K](c.parent_key, parent_collection_name, children=tuple(names))
            self._addCachedRecord(record)

544 

    def prepend_collection_chain(
        self, parent_collection_name: str, child_collection_names: list[str]
    ) -> None:
        # Insert the new children ahead of the chain's current first entry.
        self._add_to_collection_chain(
            parent_collection_name, child_collection_names, self._find_prepend_position
        )

551 

    def extend_collection_chain(self, parent_collection_name: str, child_collection_names: list[str]) -> None:
        # Insert the new children after the chain's current last entry.
        self._add_to_collection_chain(
            parent_collection_name, child_collection_names, self._find_extend_position
        )

556 

    def _add_to_collection_chain(
        self,
        parent_collection_name: str,
        child_collection_names: list[str],
        position_func: Callable[[_CollectionChainModificationContext], int],
    ) -> None:
        """Insert children into a chain at the position chosen by
        ``position_func`` (shared implementation for prepend/extend).
        """
        with self._modify_collection_chain(parent_collection_name, child_collection_names) as c:
            # Remove any of the new children that are already in the
            # collection, so they move to a new position instead of being
            # duplicated.
            self._remove_collection_chain_rows(c.parent_key, c.child_keys)
            # Figure out where to insert the new children.
            starting_position = position_func(c)
            self._block_for_concurrency_test()
            self._insert_collection_chain_rows(c.parent_key, starting_position, c.child_keys)

572 

    def remove_from_collection_chain(
        self, parent_collection_name: str, child_collection_names: list[str]
    ) -> None:
        """Delete the named children from a chained collection."""
        with self._modify_collection_chain(
            parent_collection_name,
            child_collection_names,
            # Removing members from a chain can't create collection cycles
            skip_cycle_check=True,
            # It is OK for multiple instances of `remove_from_collection_chain`
            # to run concurrently on the same collection, because it doesn't
            # read/modify the position numbers of the children -- it only
            # deletes existing rows.
            #
            # However, other chain modification operations must still be
            # blocked to avoid consistency issues.
            exclusive_lock=False,
        ) as c:
            self._block_for_concurrency_test()
            self._remove_collection_chain_rows(c.parent_key, c.child_keys)

592 

    @contextmanager
    def _modify_collection_chain(
        self,
        parent_collection_name: str,
        child_collection_names: list[str],
        *,
        skip_caching_check: bool = False,
        skip_cycle_check: bool = False,
        exclusive_lock: bool = True,
    ) -> Iterator[_CollectionChainModificationContext[K]]:
        """Resolve the parent and child collections, open a transaction, lock
        the parent chain, and yield the keys/records needed by the caller to
        modify the chain's membership rows.
        """
        if (not skip_caching_check) and self._caching_context.collection_records is not None:
            # Avoid having cache-maintenance code around that is unlikely to
            # ever be used.
            raise RuntimeError("Chained collection modification not permitted with active caching context.")

        if not skip_cycle_check:
            self._sanity_check_collection_cycles(parent_collection_name, child_collection_names)

        # Look up the collection primary keys corresponding to the
        # user-provided list of child collection names. Because there is no
        # locking for the child collections, it's possible for a concurrent
        # deletion of one of the children to cause a foreign key constraint
        # violation when we attempt to insert them in the collection chain
        # table later.
        child_records = self.resolve_wildcard(
            CollectionWildcard.from_names(child_collection_names), flatten_chains=False
        )
        child_keys = [child.key for child in child_records]

        with self._db.transaction():
            # Lock the parent collection to prevent concurrent updates to the
            # same collection chain.
            parent_key = self._find_and_lock_collection_chain(
                parent_collection_name, exclusive_lock=exclusive_lock
            )
            # The caller's body runs here, inside the transaction and while
            # the parent row lock is held.
            yield _CollectionChainModificationContext[K](
                parent_key=parent_key, child_keys=child_keys, child_records=child_records
            )

631 

632 def _sanity_check_collection_cycles( 

633 self, parent_collection_name: str, child_collection_names: list[str] 

634 ) -> None: 

635 """Raise an exception if any of the collections in the ``child_names`` 

636 list have ``parent_name`` as a child, creating a collection cycle. 

637 

638 This is only a sanity check, and does not guarantee that no collection 

639 cycles are possible. Concurrent updates might allow collection cycles 

640 to be inserted. 

641 """ 

642 for record in self.resolve_wildcard( 

643 CollectionWildcard.from_names(child_collection_names), 

644 flatten_chains=True, 

645 include_chains=True, 

646 collection_types={CollectionType.CHAINED}, 

647 ): 

648 if record.name == parent_collection_name: 

649 raise CollectionCycleError( 

650 f"Cycle in collection chaining when defining '{parent_collection_name}'." 

651 ) 

652 

653 def _insert_collection_chain_rows( 

654 self, 

655 parent_key: K, 

656 starting_position: int, 

657 child_keys: list[K], 

658 ) -> None: 

659 rows = [ 

660 { 

661 "parent": parent_key, 

662 "child": child, 

663 "position": position, 

664 } 

665 for position, child in enumerate(child_keys, starting_position) 

666 ] 

667 

668 # It's possible for the DB to raise an exception for the integers being 

669 # out of range here. The position column is only a 16-bit number. 

670 # Even if there aren't an unreasonably large number of children in the 

671 # collection, a series of many deletes and insertions could cause the 

672 # space to become fragmented. 

673 # 

674 # If this ever actually happens, we should consider doing a migration 

675 # to increase the position column to a 32-bit number. 

676 # To fix it in the short term, you can re-write the collection chain to 

677 # defragment it by doing something like: 

678 # registry.setCollectionChain( 

679 # parent, 

680 # registry.getCollectionChain(parent) 

681 # ) 

682 self._db.insert(self._tables.collection_chain, *rows) 

683 

684 def _remove_collection_chain_rows( 

685 self, 

686 parent_key: K, 

687 child_keys: list[K], 

688 ) -> None: 

689 table = self._tables.collection_chain 

690 where = sqlalchemy.and_(table.c.parent == parent_key, table.c.child.in_(child_keys)) 

691 self._db.deleteWhere(table, where) 

692 

693 def _find_prepend_position(self, c: _CollectionChainModificationContext) -> int: 

694 """Return the position where children can be inserted to 

695 prepend them to a collection chain. 

696 """ 

697 return self._find_position_in_collection_chain(c.parent_key, "begin") - len(c.child_keys) 

698 

699 def _find_extend_position(self, c: _CollectionChainModificationContext) -> int: 

700 """Return the position where children can be inserted to append them to 

701 a collection chain. 

702 """ 

703 return self._find_position_in_collection_chain(c.parent_key, "end") + 1 

704 

705 def _find_position_in_collection_chain(self, chain_key: K, begin_or_end: Literal["begin", "end"]) -> int: 

706 """Return the lowest or highest numbered position in a collection 

707 chain, or 0 if the chain is empty. 

708 """ 

709 table = self._tables.collection_chain 

710 

711 func: sqlalchemy.Function 

712 match begin_or_end: 

713 case "begin": 

714 func = sqlalchemy.func.min(table.c.position) 

715 case "end": 

716 func = sqlalchemy.func.max(table.c.position) 

717 

718 query = sqlalchemy.select(func).where(table.c.parent == chain_key) 

719 with self._db.query(query) as cursor: 

720 position = cursor.scalar() 

721 

722 if position is None: 

723 return 0 

724 

725 return position 

726 

    def _find_and_lock_collection_chain(self, collection_name: str, *, exclusive_lock: bool) -> K:
        """
        Take a row lock on the specified collection's row in the collections
        table, and return the collection's primary key.

        This lock is used to synchronize updates to collection chains.

        The locking strategy requires cooperation from everything modifying the
        collection chain table -- all operations that modify collection chains
        must obtain this lock first. The database will NOT automatically
        prevent modification of tables based on this lock. The only guarantee
        is that only one caller will be allowed to hold the exclusive lock for
        a given collection at a time. Concurrent calls will block until the
        caller holding the lock has completed its transaction.

        Parameters
        ----------
        collection_name : `str`
            Name of the collection whose chain is being modified.
        exclusive_lock : `bool`
            If `True`, an exclusive lock will be taken to block all concurrent
            modifications to the same collection. If `False`, a shared lock
            will be taken which will only block operations that request an
            exclusive lock.

        Returns
        -------
        id : ``K``
            The primary key for the given collection.

        Raises
        ------
        MissingCollectionError
            If the specified collection is not in the database table.
        CollectionTypeError
            If the specified collection is not a chained collection.
        """
        assert self._db.isInTransaction(), (
            "Row locks are only held until the end of the current transaction,"
            " so it makes no sense to take a lock outside a transaction."
        )
        assert self._db.isWriteable(), "Collection row locks are only useful for write operations."

        # SQLAlchemy's with_for_update(read=True) emits a shared
        # (FOR SHARE-style) lock; read=False emits FOR UPDATE.
        query = self._select_pkey_by_name(collection_name).with_for_update(read=not exclusive_lock)
        with self._db.query(query) as cursor:
            rows = cursor.all()

        if len(rows) == 0:
            raise MissingCollectionError(
                f"Parent collection {collection_name} not found when updating collection chain."
            )
        assert len(rows) == 1, "There should only be one entry for each collection in collection table."
        r = rows[0]._mapping
        if r["type"] != CollectionType.CHAINED:
            raise CollectionTypeError(f"Parent collection {collection_name} is not a chained collection.")
        return r["key"]

783 

    @abstractmethod
    def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
        """Return a SQLAlchemy select statement that will return columns from
        the one row in the ``collection`` table matching the given name. The
        select statement includes two columns:

        - ``key`` : the primary key for the collection
        - ``type`` : the collection type
        """
        raise NotImplementedError()

794 

    def _query_recursive(
        self,
        collections: Iterable[str],
        key_type: type,
    ) -> list[Mapping]:
        """Run the query that recursively finds collections and all their
        child collections.

        Parameters
        ----------
        collections : `~collections.abc.Iterable` [`str`]
            List of collection names to retrieve.
        key_type : `type`
            Type of the key column, e.g. `sqlalchemy.BigInteger`.

        Returns
        -------
        rows : `list` [`~collections.abc.Mapping`]
            Database rows resulting from the query. Each row contains a
            combination of columns from ``collections`` table and
            ``collection_chain`` table joined on ``child`` column,
            ``child`` column is not included into returned mappings.
            For top-level collections both ``parent`` and ``position`` will
            be `None`. Same collection can appear multiple times if it is a
            child of multiple collections.
        """
        # Make recursive CTE to fetch everything in one query. There may be
        # duplicate collection names in the result, but it should not affect
        # performance too much for the limited number of input collections.
        #
        # The query will look like
        #
        #     WITH RECURSIVE chains AS (
        #         SELECT
        #             coll_1.*,
        #             cast(NULL as KEY_TYPE) parent,
        #             cast(NULL as SMALLINT) position
        #         FROM
        #             collection coll_1
        #         WHERE
        #             coll_1.name IN (:collections)
        #         UNION ALL
        #         SELECT
        #             coll_2.*,
        #             chain_2.parent,
        #             chain_2.position
        #         FROM
        #             collection coll_2
        #             JOIN collection_chain chain_2
        #                 ON coll_2.key_column = chain_2.child
        #             JOIN chains ON chain_2.parent = chains.key_column
        #     )
        #     SELECT
        #         ch.*,
        #         run.host,
        #         run.timespan
        #     FROM
        #         chains ch
        #         LEFT OUTER JOIN run
        #             ON ch.key_column = run.key_column;
        #
        chain_table = self._tables.collection_chain
        collection_table = self._tables.collection
        run_table = self._tables.run
        key_column = self._collectionIdName

        # First CTE select: the anchor rows are the explicitly named
        # collections, with NULL parent/position markers.
        coll_1 = collection_table.alias("coll_1")
        chains_cte = (
            sqlalchemy.select(
                *coll_1.columns,
                sqlalchemy.cast(None, type_=key_type).label("parent"),
                sqlalchemy.cast(None, type_=sqlalchemy.SmallInteger).label("position"),
            )
            .where(coll_1.columns["name"].in_(collections))
            .cte("chains", recursive=True)
        )

        # Second CTE select: recursive term joining children of anything
        # already in the CTE through the chain-membership table.
        cte_alias = chains_cte.alias()
        coll_2 = collection_table.alias("coll_2")
        chain_2 = chain_table.alias("chain_2")
        chains_cte = chains_cte.union_all(
            sqlalchemy.select(
                *coll_2.columns, chain_2.columns["parent"], chain_2.columns["position"]
            ).select_from(
                coll_2.join(chain_2, onclause=(coll_2.columns[key_column] == chain_2.columns["child"])).join(
                    cte_alias, onclause=(chain_2.columns["parent"] == cte_alias.columns[key_column])
                )
            )
        )

        # Outer select joining chains CTE with run table using LEFT OUTER JOIN.
        TimespanReprClass = self._db.getTimespanRepresentation()
        query = sqlalchemy.select(
            *chains_cte.columns,
            run_table.columns["host"],
            *[run_table.columns[column] for column in TimespanReprClass.getFieldNames()],
        ).select_from(
            chains_cte.join(
                run_table,
                isouter=True,
                onclause=(chains_cte.columns[key_column] == run_table.columns[key_column]),
            )
        )

        with self._db.transaction():
            with self._db.query(query) as sql_result:
                return list(sql_result.mappings().fetchall())

904 

905 

class _CollectionChainModificationContext(NamedTuple, Generic[K]):
    """Values yielded by `_modify_collection_chain` to its callers."""

    # Primary key of the chained (parent) collection being modified.
    parent_key: K
    # Primary keys of the resolved child collections.
    child_keys: list[K]
    # Full records for those children, as returned by ``resolve_wildcard``.
    child_records: list[CollectionRecord[K]]