Coverage for python/lsst/daf/butler/registry/collections/_base.py: 92%

180 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-10-02 07:59 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = () 

30 

31import itertools 

32from abc import abstractmethod 

33from collections import namedtuple 

34from collections.abc import Iterable, Iterator, Set 

35from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast 

36 

37import sqlalchemy 

38 

39from ...core import DimensionUniverse, Timespan, TimespanDatabaseRepresentation, ddl 

40from .._collectionType import CollectionType 

41from .._exceptions import MissingCollectionError 

42from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple 

43from ..wildcards import CollectionWildcard 

44 

45if TYPE_CHECKING: 

46 from ..interfaces import Database, DimensionRecordStorageManager 

47 

48 

def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column in the referring table that holds the collection reference.
    collectionIdName : `str`
        Column of the collections table that identifies it (PK).
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The name of the collections table is hard-coded as "collection", and
    the collection primary key is taken to be a single column.
    """
    # Both sides of the key are single-column tuples.
    source_columns = (sourceColumnName,)
    target_columns = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=source_columns, target=target_columns, **kwargs)

75 

76 

# Named tuple grouping the three tables used by the collection managers
# in this module, in the fixed order (collection, run, collection_chain).
CollectionTablesTuple = namedtuple(
    "CollectionTablesTuple",
    ("collection", "run", "collection_chain"),
)

78 

79 

def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    The identifying column carries the same name in the collections and run
    tables; this function and the record classes in this module rely on
    that. The names of the non-identifying columns that hold run metadata
    are fixed.
    """
    # The run table shares its primary key with the collections table.
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # The key back to the collections table is declared with
    # onDelete="CASCADE".
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[collection_fk])
    # Timespan columns depend on the database representation and are all
    # nullable.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec

120 

121 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections. The column names ("parent", "position", "child") are
    fixed, and the record classes in this module hard-code them as well.
    """
    # (parent, position) is the primary key.
    parent_field = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    position_field = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    child_field = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Only the parent reference is declared with onDelete="CASCADE"; the
    # child reference gets no onDelete argument.
    parent_fk = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    child_fk = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parent_field, position_field, child_field],
        foreignKeys=[parent_fk, child_fk],
    )

155 

156 

class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the same run table definition as produced by
    `makeRunTableSpec`. The only non-fixed name in the schema is the PK
    column name, which must be passed to the constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run. If `None`, an unbounded timespan
        (``begin=None, end=None``) is used.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        idColumnName: str,
        host: str | None = None,
        timespan: Timespan | None = None,
    ):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: str | None = None, timespan: Timespan | None = None) -> None:
        """Write new host and timespan values to the run table.

        Parameters
        ----------
        host : `str`, optional
            Name of the host where run was produced.
        timespan : `Timespan`, optional
            Timespan for this run. If `None`, an unbounded timespan
            (``begin=None, end=None``) is stored.

        Raises
        ------
        RuntimeError
            Raised if the update did not affect exactly one row.
        """
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # The ``where`` argument of `Database.update` maps a column name to
        # the *key in the row dictionary* that supplies the value to match,
        # which is the convention `DefaultCollectionManager.setDocumentation`
        # follows. The record ID therefore goes into the row under a
        # dedicated "key" entry rather than being used as the mapping value.
        row = {
            "key": self.key,
            "host": host,
        }
        # Adds the database-specific timespan column values to ``row``.
        self._db.getTimespanRepresentation().update(timespan, result=row)
        count = self._db.update(self._table, {self._idName: "key"}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Only update the in-memory state once the database write succeeded.
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> str | None:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan

228 

229 

class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    This class relies on the chain table layout produced by
    `makeCollectionChainTableSpec`: the column names "parent", "child" and
    "position" are fixed and hard-coded in the methods below.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        universe: DimensionUniverse,
    ):
        super().__init__(key=key, name=name, universe=universe)
        self._universe = universe
        self._table = table
        self._db = db

    def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        # Resolve the child names to records first (chains are kept as
        # single entries, not flattened) so every row can carry the child's
        # key together with its position in the chain.
        members = manager.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False)
        rows = [
            {
                "parent": self.key,
                "child": member.key,
                "position": index,
            }
            for index, member in enumerate(members)
        ]
        # Replace the whole chain definition inside one transaction.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *rows)

    def _load(self, manager: CollectionManager) -> tuple[str, ...]:
        # Docstring inherited from ChainedCollectionRecord.
        chain_table = self._table
        query = (
            sqlalchemy.sql.select(
                chain_table.columns.child,
            )
            .select_from(chain_table)
            .where(chain_table.columns.parent == self.key)
            .order_by(chain_table.columns.position)
        )
        with self._db.query(query) as query_result:
            # Translate each child key into a collection name via the manager.
            child_names = [manager[row[chain_table.columns.child]].name for row in query_result.mappings()]
        return tuple(child_names)

295 

296 

# Type of the record ID used to index the in-memory record cache.
K = TypeVar("K")


class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    This implementation is built on the record classes defined in this
    module and shares their assumptions about the schema.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.
    registry_schema_version : `VersionTuple`, optional
        Passed through unchanged to the `CollectionManager` constructor.

    Notes
    -----
    Records are pre-fetched "aggressively" and cached in memory. The cache
    is synchronized from the database when `refresh` is called.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._dimensions = dimensions
        self._records: dict[K, CollectionRecord] = {}  # indexed by record ID

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        collection_table = self._tables.collection
        run_table = self._tables.run
        # Outer join, so non-RUN collections (no row in the run table) are
        # still returned.
        query = sqlalchemy.sql.select(*collection_table.columns, *run_table.columns).select_from(
            collection_table.join(run_table, isouter=True)
        )
        # Collect the new records in temporaries rather than mutating
        # self._records in place, for exception safety.
        loaded: list[CollectionRecord] = []
        chained: list[DefaultChainedCollectionRecord] = []
        timespan_repr = self._db.getTimespanRepresentation()
        with self._db.query(query) as query_result:
            fetched = query_result.mappings().fetchall()
        for row in fetched:
            collection_id = row[collection_table.columns[self._collectionIdName]]
            name = row[collection_table.columns.name]
            collection_type = CollectionType(row["type"])
            record: CollectionRecord
            if collection_type is CollectionType.RUN:
                record = DefaultRunRecord(
                    key=collection_id,
                    name=name,
                    db=self._db,
                    table=run_table,
                    idColumnName=self._collectionIdName,
                    host=row[run_table.columns.host],
                    timespan=timespan_repr.extract(row),
                )
            elif collection_type is CollectionType.CHAINED:
                chain_record = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    table=self._tables.collection_chain,
                    name=name,
                    universe=self._dimensions.universe,
                )
                chained.append(chain_record)
                record = chain_record
            else:
                record = CollectionRecord(key=collection_id, name=name, type=collection_type)
            loaded.append(record)
        self._setRecordCache(loaded)
        # Chains can only be refreshed once the cache holds all records.
        for chain in chained:
            try:
                chain.refresh(self)
            except MissingCollectionError:
                # A race: some other client created a new collection and
                # added it as a child of this (pre-existing) chain after we
                # fetched all collections but before we queried the
                # parent-child relationships.  That client is unrelated to
                # us, so we should not care about this parent collection
                # anyway; drop it from the cache (a manual refresh can be
                # used to get it back).
                self._removeCachedRecord(chain)

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord, bool]:
        # Docstring inherited from CollectionManager.
        existing = self._getByName(name)
        if existing is not None:
            # Already known to this manager: nothing is registered.
            return existing, False
        row, inserted_or_updated = self._db.sync(
            self._tables.collection,
            keys={"name": name},
            compared={"type": int(type)},
            extra={"doc": doc},
            returning=[self._collectionIdName],
        )
        assert isinstance(inserted_or_updated, bool)
        assert row is not None
        collection_id = row[self._collectionIdName]
        record: CollectionRecord
        if type is CollectionType.RUN:
            timespan_repr = self._db.getTimespanRepresentation()
            # RUN collections also need a row in the run table; fetch its
            # host and timespan values to populate the record.
            run_row, _ = self._db.sync(
                self._tables.run,
                keys={self._collectionIdName: collection_id},
                returning=("host",) + timespan_repr.getFieldNames(),
            )
            assert run_row is not None
            record = DefaultRunRecord(
                db=self._db,
                key=collection_id,
                name=name,
                table=self._tables.run,
                idColumnName=self._collectionIdName,
                host=run_row["host"],
                timespan=timespan_repr.extract(run_row),
            )
        elif type is CollectionType.CHAINED:
            record = DefaultChainedCollectionRecord(
                db=self._db,
                key=collection_id,
                name=name,
                table=self._tables.collection_chain,
                universe=self._dimensions.universe,
            )
        else:
            record = CollectionRecord(key=collection_id, name=name, type=type)
        self._addCachedRecord(record)
        return record, inserted_or_updated

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        id_name = self._collectionIdName
        # This may raise, in which case the cache is left untouched.
        self._db.delete(self._tables.collection, [id_name], {id_name: record.key})
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is not None:
            return record
        raise MissingCollectionError(f"No collection with name '{name}' found.")

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            return self._records[key]
        except KeyError as exc:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from exc

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord]:
        # Docstring inherited
        if done is None:
            done = set()
        if include_chains is None:
            # By default chains are reported only when they are not flattened.
            include_chains = not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord]:
            # Yield ``record`` and/or its flattened children, skipping any
            # name already in ``done``.
            if record.name in done:
                return
            is_chain = record.type is CollectionType.CHAINED
            if record.type in collection_types:
                done.add(record.name)
                if include_chains or not is_chain:
                    yield record
            if is_chain and flatten_chains:
                done.add(record.name)
                for child_name in cast(ChainedCollectionRecord, record).children:
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(self.find(child_name), done)  # noqa: F821

        resolved: list[CollectionRecord] = []

        if wildcard.patterns is ...:
            # Ellipsis means "match everything": walk every cached record.
            for cached in self._records.values():
                resolved.extend(resolve_nested(cached, done))
            del resolve_nested
            return resolved
        # Explicit names come first, in the order given.
        for explicit_name in wildcard.strings:
            resolved.extend(resolve_nested(self.find(explicit_name), done))
        # Then any cached record whose full name matches a pattern.
        if wildcard.patterns:
            for cached in self._records.values():
                if any(pattern.fullmatch(cached.name) for pattern in wildcard.patterns):
                    resolved.extend(resolve_nested(cached, done))
        del resolve_nested
        return resolved

    def getDocumentation(self, key: Any) -> str | None:
        # Docstring inherited from CollectionManager.
        collection_table = self._tables.collection
        query = (
            sqlalchemy.sql.select(collection_table.columns.doc)
            .select_from(collection_table)
            .where(collection_table.columns[self._collectionIdName] == key)
        )
        with self._db.query(query) as query_result:
            return query_result.scalar()

    def setDocumentation(self, key: Any, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # The id column is matched against the "key" entry of the row.
        where = {self._collectionIdName: "key"}
        row = {"key": key, "doc": doc}
        self._db.update(self._tables.collection, where, row)

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Replace the internal record cache with the given records; any
        previously cached records are dropped.
        """
        self._records = {}
        self._records.update((record.key, record) for record in records)

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Insert one record into the cache, keyed by its ``key``."""
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Drop one record from the cache."""
        self._records.pop(record.key)

    @abstractmethod
    def _getByName(self, name: str) -> CollectionRecord | None:
        """Find collection record given collection name."""
        raise NotImplementedError()

544 raise NotImplementedError()