Coverage for python/lsst/daf/butler/registry/collections/_base.py: 88%

152 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-09-30 02:18 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = () 

24 

25import itertools 

26from abc import abstractmethod 

27from collections import namedtuple 

28from collections.abc import Iterable, Iterator 

29from typing import TYPE_CHECKING, Any, Generic, TypeVar 

30 

31import sqlalchemy 

32 

33from ...core import DimensionUniverse, Timespan, TimespanDatabaseRepresentation, ddl 

34from .._collectionType import CollectionType 

35from .._exceptions import MissingCollectionError 

36from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord 

37from ..wildcards import CollectionSearch 

38 

39if TYPE_CHECKING: 39 ↛ 40line 39 didn't jump to line 40, because the condition on line 39 was never true

40 from ..interfaces import Database, DimensionRecordStorageManager 

41 

42 

def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Name of the column in the referring table.
    collectionIdName : `str`
        Name of the identifying (PK) column in the collections table.
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The name of the collections table is hard-coded as "collection", and the
    collection primary key is assumed to consist of a single column.
    """
    # Both sides of the key are single-column, hence the one-element tuples.
    sourceColumns = (sourceColumnName,)
    targetColumns = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=sourceColumns, target=targetColumns, **kwargs)

69 

70 

# Bundle of the three tables used by the classes in this module, in the
# order: collections table, run table, collection chain table.
CollectionTablesTuple = namedtuple("CollectionTablesTuple", "collection run collection_chain")

72 

73 

def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the specification of the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the identifying (PK) column in the collections table.
    collectionIdType
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    The identifying column carries the same name in both the collections
    and run tables; this assumption is shared by the rest of this module.
    The names of the remaining (run metadata) columns are fixed.
    """
    idField = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    hostField = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # The run PK doubles as a foreign key to the collections table; rows go
    # away together with the collection they describe.
    collectionKey = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[idField, hostField], foreignKeys=[collectionKey])
    # Timespan columns depend on the database representation, so they are
    # appended after the fixed columns.
    for timespanField in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespanField)
    return spec

114 

115 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the specification of the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the identifying (PK) column in the collections table.
    collectionIdType
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections. The column names ("parent", "position", "child") are fixed
    and are also hard-coded elsewhere in this module.
    """
    chainFields = [
        ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True),
        ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
        ddl.FieldSpec("child", dtype=collectionIdType, nullable=False),
    ]
    # Only the parent link cascades on delete; the child link is a plain
    # foreign key.
    chainKeys = [
        _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE"),
        _makeCollectionForeignKey("child", collectionIdName),
    ]
    return ddl.TableSpec(fields=chainFields, foreignKeys=chainKeys)

149 

150 

class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the same run table definition as produced by
    `makeRunTableSpec`. The only non-fixed name in the schema is the PK
    column name, which must be passed to the constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run. If `None`, an unbounded timespan
        (``begin=None, end=None``) is used.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        idColumnName: str,
        host: str | None = None,
        timespan: Timespan | None = None,
    ):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: str | None = None, timespan: Timespan | None = None) -> None:
        # Docstring inherited from RunRecord.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # The ``where`` argument of ``Database.update`` maps a column name to
        # the *key in the row dict* that holds the value to match (this is
        # how `DefaultCollectionManager.setDocumentation` in this module
        # calls it).  The ID value therefore travels inside the row under a
        # separate "key" entry instead of being placed in ``where`` itself.
        row = {
            "key": self.key,
            "host": host,
        }
        self._db.getTimespanRepresentation().update(timespan, result=row)
        count = self._db.update(self._table, {self._idName: "key"}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Only refresh the in-memory state once the database accepted the
        # change.
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> str | None:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan

222 

223 

class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    This class assumes the same chain table definition as produced by
    `makeCollectionChainTableSpec`. All column names in that table
    ("parent", "child", "position") are fixed and hard-coded in the methods.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        universe: DimensionUniverse,
    ):
        super().__init__(key=key, name=name, universe=universe)
        self._universe = universe
        self._table = table
        self._db = db

    def _update(self, manager: CollectionManager, children: CollectionSearch) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        # One row per direct child; "position" records the order in which the
        # search yields the children, starting from zero.
        chainRows = [
            {
                "parent": self.key,
                "child": member.key,
                "position": index,
            }
            for index, member in enumerate(children.iter(manager, flattenChains=False))
        ]
        # Replace the whole chain inside one transaction: drop every existing
        # row for this parent, then insert the new ones.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *chainRows)

    def _load(self, manager: CollectionManager) -> CollectionSearch:
        # Docstring inherited from ChainedCollectionRecord.
        chainTable = self._table
        query = (
            sqlalchemy.sql.select(
                chainTable.columns.child,
            )
            .select_from(chainTable)
            .where(chainTable.columns.parent == self.key)
            .order_by(chainTable.columns.position)
        )
        # Translate each child key into its collection name through the
        # manager, keeping the order given by "position".
        childNames = []
        for row in self._db.query(query):
            childKey = row._mapping[chainTable.columns.child]
            childNames.append(manager[childKey].name)
        return CollectionSearch.fromExpression(childNames)

290 

291 

# Type of the collection key used to index the in-memory record cache.
K = TypeVar("K")


class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    This implementation uses the record classes defined in this module and
    shares their assumptions about the schema.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the identifying (PK) column in the collections table.
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.

    Notes
    -----
    Records are pre-fetched "aggressively" and cached in memory. The cache
    is synchronized from the database when `refresh` is called.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
    ):
        super().__init__()
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._records: dict[K, CollectionRecord] = {}  # indexed by record ID
        self._dimensions = dimensions

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        collectionTable = self._tables.collection
        runTable = self._tables.run
        # Outer join so that non-run collections (no matching run row) are
        # still returned.
        query = sqlalchemy.sql.select(*collectionTable.columns, *runTable.columns).select_from(
            collectionTable.join(runTable, isouter=True)
        )
        # Collect everything in temporaries and swap the cache in afterwards
        # rather than mutating self._records in place, for exception safety.
        fetched = []
        chainRecords = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        for row in self._db.query(query).mappings():
            collection_id = row[collectionTable.columns[self._collectionIdName]]
            name = row[collectionTable.columns.name]
            collectionType = CollectionType(row["type"])
            record: CollectionRecord
            if collectionType is CollectionType.CHAINED:
                chainRecord = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    table=self._tables.collection_chain,
                    name=name,
                    universe=self._dimensions.universe,
                )
                # Children are loaded later, once every record is cached.
                chainRecords.append(chainRecord)
                record = chainRecord
            elif collectionType is CollectionType.RUN:
                record = DefaultRunRecord(
                    key=collection_id,
                    name=name,
                    db=self._db,
                    table=runTable,
                    idColumnName=self._collectionIdName,
                    host=row[runTable.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
            else:
                record = CollectionRecord(key=collection_id, name=name, type=collectionType)
            fetched.append(record)
        self._setRecordCache(fetched)
        for chainRecord in chainRecords:
            try:
                chainRecord.refresh(self)
            except MissingCollectionError:
                # This indicates a race condition in which some other client
                # created a new collection and added it as a child of this
                # (pre-existing) chain between the time we fetched all
                # collections and the time we queried for parent-child
                # relationships.
                # Because that's some other unrelated client, we shouldn't
                # care about that parent collection anyway, so we just drop
                # it on the floor (a manual refresh can be used to get it
                # back).
                self._removeCachedRecord(chainRecord)

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord, bool]:
        # Docstring inherited from CollectionManager.
        cached = self._getByName(name)
        if cached is not None:
            # Already in the cache: this call registered nothing.
            return cached, False
        row, inserted_or_updated = self._db.sync(
            self._tables.collection,
            keys={"name": name},
            compared={"type": int(type)},
            extra={"doc": doc},
            returning=[self._collectionIdName],
        )
        assert isinstance(inserted_or_updated, bool)
        assert row is not None
        collection_id = row[self._collectionIdName]
        record: CollectionRecord
        if type is CollectionType.RUN:
            # Runs also need a row in the run table; read back the run
            # metadata columns so the record reflects what is stored.
            TimespanReprClass = self._db.getTimespanRepresentation()
            runRow, _ = self._db.sync(
                self._tables.run,
                keys={self._collectionIdName: collection_id},
                returning=("host",) + TimespanReprClass.getFieldNames(),
            )
            assert runRow is not None
            record = DefaultRunRecord(
                db=self._db,
                key=collection_id,
                name=name,
                table=self._tables.run,
                idColumnName=self._collectionIdName,
                host=runRow["host"],
                timespan=TimespanReprClass.extract(runRow),
            )
        elif type is CollectionType.CHAINED:
            record = DefaultChainedCollectionRecord(
                db=self._db,
                key=collection_id,
                name=name,
                table=self._tables.collection_chain,
                universe=self._dimensions.universe,
            )
        else:
            record = CollectionRecord(key=collection_id, name=name, type=type)
        self._addCachedRecord(record)
        return record, inserted_or_updated

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        target = self._getByName(name)
        if target is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        idName = self._collectionIdName
        # This may raise; the cache entry is dropped only after the database
        # delete went through.
        self._db.delete(self._tables.collection, [idName], {idName: target.key})
        self._removeCachedRecord(target)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        found = self._getByName(name)
        if found is not None:
            return found
        raise MissingCollectionError(f"No collection with name '{name}' found.")

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            return self._records[key]
        except KeyError as err:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err

    def __iter__(self) -> Iterator[CollectionRecord]:
        for record in self._records.values():
            yield record

    def getDocumentation(self, key: Any) -> str | None:
        # Docstring inherited from CollectionManager.
        collectionTable = self._tables.collection
        idColumn = collectionTable.columns[self._collectionIdName]
        query = (
            sqlalchemy.sql.select(collectionTable.columns.doc)
            .select_from(collectionTable)
            .where(idColumn == key)
        )
        return self._db.query(query).scalar()

    def setDocumentation(self, key: Any, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # The ID column is matched against the value stored under "key" in
        # the row dict; "doc" is the value being written.
        matchOn = {self._collectionIdName: "key"}
        values = {"key": key, "doc": doc}
        self._db.update(self._tables.collection, matchOn, values)

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Replace the internal record cache with the given records;
        previously cached records are discarded.
        """
        self._records = {}
        self._records.update((record.key, record) for record in records)

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Add a single record to the cache, indexed by its key."""
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Remove a single record from the cache."""
        self._records.pop(record.key)

    @abstractmethod
    def _getByName(self, name: str) -> CollectionRecord | None:
        """Find collection record given collection name."""
        raise NotImplementedError()