Coverage for python/lsst/daf/butler/registry/collections/_base.py: 91%

181 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-17 09:32 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = () 

24 

25import itertools 

26from abc import abstractmethod 

27from collections import namedtuple 

28from collections.abc import Iterable, Iterator, Set 

29from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast 

30 

31import sqlalchemy 

32from lsst.utils.ellipsis import Ellipsis 

33 

34from ...core import DimensionUniverse, Timespan, TimespanDatabaseRepresentation, ddl 

35from .._collectionType import CollectionType 

36from .._exceptions import MissingCollectionError 

37from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple 

38from ..wildcards import CollectionWildcard 

39 

40if TYPE_CHECKING: 

41 from ..interfaces import Database, DimensionRecordStorageManager 

42 

43 

def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification that points at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Name of the column in the referring table.
    collectionIdName : `str`
        Name of the column in the collections table that identifies a
        collection (its primary key).
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The collections table is assumed to be named "collection", and its
    primary key is assumed to consist of a single column.
    """
    target_table = "collection"
    source_columns = (sourceColumnName,)
    target_columns = (collectionIdName,)
    return ddl.ForeignKeySpec(target_table, source=source_columns, target=target_columns, **kwargs)

70 

71 

# Bundle of the three SQLAlchemy tables used by the collection managers:
# the main collections table, the run table and the collection-chain table.
CollectionTablesTuple = namedtuple("CollectionTablesTuple", "collection run collection_chain")

73 

74 

def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Define specification for "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    The identifying column is assumed to have the same name in both the
    collections and run tables (here and in the code below). The names of
    the non-identifying columns holding run metadata are fixed.
    """
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # Deleting the collection row also deletes its run row.
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[collection_fk])
    # Timespan columns are created nullable.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec

115 

116 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Define specification for "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections. The table's column names ("parent", "position", "child")
    are fixed and also hard-coded in the code below.
    """
    parent_field = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    position_field = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    child_field = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Deleting a parent collection cascades to its chain rows; no ON DELETE
    # action is requested for the child reference.
    parent_fk = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    child_fk = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parent_field, position_field, child_field],
        foreignKeys=[parent_fk, child_fk],
    )

150 

151 

class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the same run table definition as produced by the
    `makeRunTableSpec` function. The only non-fixed name in the schema is
    the PK column name, which has to be passed to the constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run. If not given, ``Timespan(begin=None, end=None)``
        is used.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        idColumnName: str,
        host: str | None = None,
        timespan: Timespan | None = None,
    ):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: str | None = None, timespan: Timespan | None = None) -> None:
        # Docstring inherited from RunRecord.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # `Database.update` takes a ``where`` mapping from each column name to
        # the key *in the row dictionaries* that holds the value to match. This
        # is the same convention `DefaultCollectionManager.setDocumentation`
        # uses. The ID is stored under a separate "key" entry rather than
        # under the column name, so the ID column is not also treated as a
        # column to SET.
        row: dict[str, Any] = {"key": self.key, "host": host}
        self._db.getTimespanRepresentation().update(timespan, result=row)
        count = self._db.update(self._table, {self._idName: "key"}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> str | None:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan

223 

224 

class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    This class assumes the chain table layout produced by the
    `makeCollectionChainTableSpec` function. All column names in that table
    ("parent", "child", "position") are fixed and hard-coded in the methods
    below.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        universe: DimensionUniverse,
    ):
        super().__init__(key=key, name=name, universe=universe)
        self._db, self._table, self._universe = db, table, universe

    def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        # Children that are themselves chains are not expanded
        # (``flatten_chains=False``); each child is linked directly.
        child_records = manager.resolve_wildcard(
            CollectionWildcard.from_names(children), flatten_chains=False
        )
        # Positions are assigned 0, 1, 2, ... in resolution order.
        rows = [
            {"parent": self.key, "child": child_record.key, "position": position}
            for position, child_record in zip(itertools.count(), child_records)
        ]
        # Remove the old links and insert the new ones in a single transaction.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *rows)

    def _load(self, manager: CollectionManager) -> tuple[str, ...]:
        # Docstring inherited from ChainedCollectionRecord.
        columns = self._table.columns
        query = (
            sqlalchemy.sql.select(columns.child)
            .select_from(self._table)
            .where(columns.parent == self.key)
            .order_by(columns.position)
        )
        with self._db.query(query) as sql_result:
            # Child keys are translated back to names through the manager.
            child_names = [manager[mapping[columns.child]].name for mapping in sql_result.mappings()]
        return tuple(child_names)

290 

291 

# Type of the key that identifies a collection (the value of its primary-key
# column); used to index the record cache of `DefaultCollectionManager`.
K = TypeVar("K")


class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    This implementation uses the record classes defined in this module and
    relies on the same schema assumptions described for those classes.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.
    registry_schema_version : `VersionTuple` or `None`, optional
        Forwarded unchanged to the `CollectionManager` constructor.

    Notes
    -----
    The implementation aggressively pre-fetches records and caches them in
    memory. The memory cache is synchronized from the database when the
    `refresh` method is called.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._records: dict[K, CollectionRecord] = {}  # indexed by record ID
        self._dimensions = dimensions

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        sql = sqlalchemy.sql.select(
            *(list(self._tables.collection.columns) + list(self._tables.run.columns))
        ).select_from(self._tables.collection.join(self._tables.run, isouter=True))
        # Put found records into a temporary instead of updating self._records
        # in place, for exception safety.
        records = []
        chains = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        with self._db.query(sql) as sql_result:
            sql_rows = sql_result.mappings().fetchall()
        for row in sql_rows:
            collection_id = row[self._tables.collection.columns[self._collectionIdName]]
            name = row[self._tables.collection.columns.name]
            # Named so as not to shadow the builtin ``type``.
            collection_type = CollectionType(row["type"])
            record: CollectionRecord
            if collection_type is CollectionType.RUN:
                record = DefaultRunRecord(
                    key=collection_id,
                    name=name,
                    db=self._db,
                    table=self._tables.run,
                    idColumnName=self._collectionIdName,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
            elif collection_type is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    table=self._tables.collection_chain,
                    name=name,
                    universe=self._dimensions.universe,
                )
                chains.append(record)
            else:
                record = CollectionRecord(key=collection_id, name=name, type=collection_type)
            records.append(record)
        self._setRecordCache(records)
        # Chains are refreshed only after the cache is fully populated,
        # presumably because they resolve child keys through this manager
        # (see `DefaultChainedCollectionRecord._load`).
        for chain in chains:
            try:
                chain.refresh(self)
            except MissingCollectionError:
                # This indicates a race condition in which some other client
                # created a new collection and added it as a child of this
                # (pre-existing) chain between the time we fetched all
                # collections and the time we queried for parent-child
                # relationships.
                # Because that's some other unrelated client, we shouldn't care
                # about that parent collection anyway, so we just drop it on
                # the floor (a manual refresh can be used to get it back).
                self._removeCachedRecord(chain)

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord, bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        # A cached record is reused only if it has the requested type.
        # Otherwise go to the database. ``sync`` is given the requested type in
        # ``compared``, so a conflicting existing collection is reported there
        # instead of a wrong-typed record being returned silently. If the
        # cache was merely stale, the cached entry is replaced.
        if record is None or record.type is not type:
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = row[self._collectionIdName]
            if type is CollectionType.RUN:
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = DefaultRunRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=self._tables.run,
                    idColumnName=self._collectionIdName,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=self._tables.collection_chain,
                    universe=self._dimensions.universe,
                )
            else:
                record = CollectionRecord(key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # The database may reject this delete (for example because another
        # table still references the collection). In that case the exception
        # propagates and the cache is left unchanged.
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            return self._records[key]
        except KeyError as err:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord]:
        # Docstring inherited
        if done is None:
            done = set()
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord]:
            # Yield ``record`` if its type was requested and, when flattening,
            # recurse into chain children. Names in ``done`` are skipped, so
            # each collection is yielded at most once.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for name in cast(ChainedCollectionRecord, record).children:
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(self.find(name), done)  # noqa: F821

        result: list[CollectionRecord] = []

        if wildcard.patterns is Ellipsis:
            for record in self._records.values():
                result.extend(resolve_nested(record, done))
            # The closure refers to itself; it is deleted explicitly once done,
            # presumably to break that reference cycle.
            del resolve_nested
            return result
        for name in wildcard.strings:
            result.extend(resolve_nested(self.find(name), done))
        if wildcard.patterns:
            for record in self._records.values():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        del resolve_nested
        return result

    def getDocumentation(self, key: Any) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: Any, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # ``where`` maps the ID column to the row key ("key") holding its value.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Set internal record cache to contain given records,
        old cached records will be removed.
        """
        self._records = {record.key: record for record in records}

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Add single record to cache."""
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Remove single record from cache.

        Raises `KeyError` if the record's key is not in the cache.
        """
        del self._records[record.key]

    @abstractmethod
    def _getByName(self, name: str) -> CollectionRecord | None:
        """Find collection record given collection name.

        Returns `None` if no collection with that name is known.
        """
        raise NotImplementedError()