Coverage for python/lsst/daf/butler/registry/collections/_base.py: 92%

182 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-27 09:43 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from ... import ddl 

30 

31__all__ = () 

32 

33import itertools 

34from abc import abstractmethod 

35from collections import namedtuple 

36from collections.abc import Iterable, Iterator, Set 

37from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast 

38 

39import sqlalchemy 

40 

41from ..._timespan import Timespan, TimespanDatabaseRepresentation 

42from ...dimensions import DimensionUniverse 

43from .._collection_type import CollectionType 

44from .._exceptions import MissingCollectionError 

45from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple 

46from ..wildcards import CollectionWildcard 

47 

48if TYPE_CHECKING: 

49 from ..interfaces import Database, DimensionRecordStorageManager 

50 

51 

def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign-key specification that points at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column in the referring table that holds the collection ID.
    collectionIdName : `str`
        Primary-key column of the collections table being referenced.
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`
        (for example ``onDelete``).

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        The foreign key specification.

    Notes
    -----
    The referenced table name is hard-coded as ``"collection"``, and the
    collections table is assumed to have a single-column primary key.
    """
    source_columns = (sourceColumnName,)
    target_columns = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=source_columns, target=target_columns, **kwargs)

78 

79 

# Bundle of the three SQLAlchemy tables used by the collection manager:
# the collections table itself, the run table and the chain table.
CollectionTablesTuple = namedtuple("CollectionTablesTuple", "collection run collection_chain")

81 

82 

def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Primary-key column name of the collections table; the run table uses
        the same name for its own primary key.
    collectionIdType
        Type of that primary-key column, one of the `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` describing how this
        database stores timespans; it supplies the timespan columns.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the run table.

    Notes
    -----
    The identifying column is assumed to have the same name in both the
    collections and run tables. The metadata column names (``host`` and
    the timespan columns) are fixed.
    """
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # Deleting a row from the collections table cascades to its run row.
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[collection_fk])
    # Timespan columns are backend-specific, so they come from the
    # representation class; they are nullable because a run may not have a
    # timespan recorded.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec

123 

124 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Primary-key column name of the collections table.
    collectionIdType
        Type of that primary-key column, one of the `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the collection chain table.

    Notes
    -----
    A chain is an ordered one-to-many relation between collections: each row
    links a ``parent`` collection to a ``child`` collection at a given
    ``position``, and ``(parent, position)`` is the primary key. The column
    names are fixed and are hard-coded elsewhere in this module.
    """
    parent_field = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    position_field = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    child_field = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Deleting the parent collection removes its chain links. The child
    # foreign key has no onDelete action, so the database default applies
    # (presumably this blocks deleting a collection that is still a chain
    # member).
    parent_fk = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    child_fk = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parent_field, position_field, child_field],
        foreignKeys=[parent_fk, child_fk],
    )

158 

159 

class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the run table definition produced by
    `makeRunTableSpec`. The only non-fixed name in that schema is the
    primary-key column name, which must be passed to the constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID; it can be the same as ``name`` if ``name`` is
        used for identification. Usually this is an integer or string, but it
        can be another database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in the run table.
    host : `str`, optional
        Name of the host where the run was produced.
    timespan : `Timespan`, optional
        Timespan for this run; if `None`, an unbounded
        ``Timespan(begin=None, end=None)`` is stored instead.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        idColumnName: str,
        host: str | None = None,
        timespan: Timespan | None = None,
    ):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: str | None = None, timespan: Timespan | None = None) -> None:
        # Docstring inherited from RunRecord.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # ``Database.update`` takes a ``where`` mapping from column name to
        # the *row-dict key* that holds the value to match. This is the same
        # convention ``DefaultCollectionManager.setDocumentation`` uses with
        # ``{idColumn: "key"}``. Passing the key value itself here would
        # name a bind parameter that no row supplies. The ID travels under a
        # non-column key (presumably because SQLAlchemy reserves column names
        # for the SET clause of an UPDATE).
        row = {
            "key": self.key,
            "host": host,
        }
        # Adds the backend-specific timespan columns to ``row``.
        self._db.getTimespanRepresentation().update(timespan, result=row)
        count = self._db.update(self._table, {self._idName: "key"}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Only update the in-memory state once the database write succeeded.
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> str | None:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan

231 

232 

class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    This class assumes the chain table definition produced by
    `makeCollectionChainTableSpec`. All of that table's column names
    (``parent``, ``child``, ``position``) are fixed and hard-coded here.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID; it can be the same as ``name`` if ``name`` is
        used for identification. Usually this is an integer or string, but it
        can be another database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table holding the parent/child chain links.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        universe: DimensionUniverse,
    ):
        super().__init__(key=key, name=name, universe=universe)
        self._db = db
        self._table = table
        self._universe = universe

    def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        # Resolve the child names without flattening nested chains, so a
        # child that is itself a chain is linked directly.
        resolved_children = manager.resolve_wildcard(
            CollectionWildcard.from_names(children), flatten_chains=False
        )
        link_rows = [
            {
                "parent": self.key,
                "child": child_record.key,
                "position": index,
            }
            for index, child_record in enumerate(resolved_children)
        ]
        # Replace the chain atomically: drop every existing link of this
        # parent, then insert the new ordered list.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *link_rows)

    def _load(self, manager: CollectionManager) -> tuple[str, ...]:
        # Docstring inherited from ChainedCollectionRecord.
        columns = self._table.columns
        query = (
            sqlalchemy.sql.select(columns.child)
            .select_from(self._table)
            .where(columns.parent == self.key)
            .order_by(columns.position)
        )
        # Child keys come back in chain order; map each to its name through
        # the manager's record cache.
        with self._db.query(query) as query_result:
            child_names = tuple(
                manager[mapping[columns.child]].name for mapping in query_result.mappings()
            )
        return child_names

298 

299 

# Type of the collection primary-key values that index the record cache of
# ``DefaultCollectionManager``.
K = TypeVar("K")


class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    This implementation uses the record classes defined in this module and
    relies on the same schema assumptions described for them.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in the collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.
    registry_schema_version : `VersionTuple` or `None`, optional
        Schema version, forwarded to the base class.

    Notes
    -----
    The implementation eagerly pre-fetches all collection records and caches
    them in memory. `refresh` rebuilds the cache from the database, and
    `register` and `remove` update it incrementally.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._records: dict[K, CollectionRecord] = {}  # indexed by record ID
        self._dimensions = dimensions

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        # Outer join so that non-RUN collections (with no run row) are also
        # returned; their run columns come back NULL.
        sql = sqlalchemy.sql.select(
            *(list(self._tables.collection.columns) + list(self._tables.run.columns))
        ).select_from(self._tables.collection.join(self._tables.run, isouter=True))
        # Put found records into a temporary instead of updating self._records
        # in place, for exception safety.
        records: list[CollectionRecord] = []
        chains: list[DefaultChainedCollectionRecord] = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        with self._db.query(sql) as sql_result:
            sql_rows = sql_result.mappings().fetchall()
        for row in sql_rows:
            collection_id = row[self._tables.collection.columns[self._collectionIdName]]
            name = row[self._tables.collection.columns.name]
            collection_type = CollectionType(row["type"])
            record: CollectionRecord
            if collection_type is CollectionType.RUN:
                record = DefaultRunRecord(
                    key=collection_id,
                    name=name,
                    db=self._db,
                    table=self._tables.run,
                    idColumnName=self._collectionIdName,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
            elif collection_type is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    table=self._tables.collection_chain,
                    name=name,
                    universe=self._dimensions.universe,
                )
                chains.append(record)
            else:
                record = CollectionRecord(key=collection_id, name=name, type=collection_type)
            records.append(record)
        self._setRecordCache(records)
        # Chain children are resolved through the cache, so chains can only be
        # refreshed after every record is in it.
        for chain in chains:
            try:
                chain.refresh(self)
            except MissingCollectionError:
                # This indicates a race condition in which some other client
                # created a new collection and added it as a child of this
                # (pre-existing) chain between the time we fetched all
                # collections and the time we queried for parent-child
                # relationships.
                # Because that's some other unrelated client, we shouldn't care
                # about that parent collection anyway, so we just drop it on
                # the floor (a manual refresh can be used to get it back).
                self._removeCachedRecord(chain)

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord, bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        # NOTE(review): a cached record is returned as-is even if its type
        # differs from ``type``; only the uncached path compares the type
        # against the database row via ``sync``.
        record = self._getByName(name)
        if record is None:
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = row[self._collectionIdName]
            if type is CollectionType.RUN:
                TimespanReprClass = self._db.getTimespanRepresentation()
                # Ensure the companion run row exists and read back its
                # metadata columns.
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = DefaultRunRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=self._tables.run,
                    idColumnName=self._collectionIdName,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=self._tables.collection_chain,
                    universe=self._dimensions.universe,
                )
            else:
                record = CollectionRecord(key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # The delete may raise, e.g. on a foreign-key violation when the
        # collection is still referenced (the chain table's ``child`` key has
        # no ON DELETE action). The cache is only updated after it succeeds.
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            return self._records[key]
        except KeyError as err:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord]:
        # Docstring inherited
        if done is None:
            done = set()
        # By default chains are reported exactly when they are not flattened.
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord]:
            # ``done`` holds names already emitted or expanded, so each
            # collection is visited at most once.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for name in cast(ChainedCollectionRecord, record).children:
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(self.find(name), done)  # noqa: F821

        result: list[CollectionRecord] = []

        if wildcard.patterns is ...:
            for record in self._records.values():
                result.extend(resolve_nested(record, done))
            # Presumably deleted to break the reference cycle created by the
            # self-referencing recursive closure.
            del resolve_nested
            return result
        for name in wildcard.strings:
            result.extend(resolve_nested(self.find(name), done))
        if wildcard.patterns:
            for record in self._records.values():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        del resolve_nested
        return result

    def getDocumentation(self, key: Any) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: Any, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # The ``where`` mapping names the row-dict key ("key") that holds the
        # value to match against the ID column.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Replace the internal record cache with the given records.

        All previously cached records are discarded. The new cache is built
        before being swapped in, so the old one is kept if iterating
        ``records`` fails.
        """
        self._records = {record.key: record for record in records}

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Add single record to cache."""
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Remove single record from cache."""
        del self._records[record.key]

    @abstractmethod
    def _getByName(self, name: str) -> CollectionRecord | None:
        """Find collection record given collection name."""
        raise NotImplementedError()