Coverage for python/lsst/daf/butler/registry/collections/_base.py: 88%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

150 statements  

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = () 

24 

25import itertools 

26from abc import abstractmethod 

27from collections import namedtuple 

28from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Tuple, Type, TypeVar 

29 

30import sqlalchemy 

31 

32from ...core import DimensionUniverse, Timespan, TimespanDatabaseRepresentation, ddl 

33from .._collectionType import CollectionType 

34from .._exceptions import MissingCollectionError 

35from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord 

36from ..wildcards import CollectionSearch 

37 

38if TYPE_CHECKING: 38 ↛ 39 line 38 didn't jump to line 39, because the condition on line 38 was never true

39 from ..interfaces import Database, DimensionRecordStorageManager 

40 

41 

def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column in the referring table that holds the collection identifier.
    collectionIdName : `str`
        Primary key column of the collections table.
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The name of the collections table is fixed to ``"collection"``, and the
    collection primary key is assumed to be a single column.
    """
    # Both sides of the key are single-column tuples.
    source_columns = (sourceColumnName,)
    target_columns = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=source_columns, target=target_columns, **kwargs)

68 

69 

# Named tuple bundling the three tables this module works with: the
# collections table, the run table and the collection chain table.
CollectionTablesTuple = namedtuple("CollectionTablesTuple", "collection run collection_chain")

71 

72 

def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: Type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary key column of the collections table; the run
        table uses the same name for its own primary key.
    collectionIdType
        Type of the primary key column in the collections table, one of the
        `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    The identifying column is assumed to have the same name in both the
    collections and run tables. The names of the remaining columns, which
    hold run metadata, are fixed.
    """
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # A run row is deleted together with the collection row it refers to.
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[collection_fk])
    # Timespan columns depend on the database representation; all are nullable.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec

113 

114 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary key column of the collections table.
    collectionIdType
        Type of the primary key column in the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections. The column names ("parent", "position", "child") are fixed
    and are also hard-coded in the code that reads and writes this table.
    """
    # (parent, position) is the compound primary key; child is the member.
    columns = [
        ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True),
        ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
        ddl.FieldSpec("child", dtype=collectionIdType, nullable=False),
    ]
    # Only the parent reference cascades on delete; the child reference
    # is created without an onDelete argument.
    references = [
        _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE"),
        _makeCollectionForeignKey("child", collectionIdName),
    ]
    return ddl.TableSpec(fields=columns, foreignKeys=references)

148 

149 

class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the same run table definition as produced by
    `makeRunTableSpec`. The only non-fixed name in the schema is the PK
    column name, which needs to be passed to the constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run. If `None`, an unbounded timespan
        (``begin=None, end=None``) is used.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        idColumnName: str,
        host: Optional[str] = None,
        timespan: Optional[Timespan] = None,
    ):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: Optional[str] = None, timespan: Optional[Timespan] = None) -> None:
        # Docstring inherited from RunRecord.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # ``Database.update`` expects ``where`` to map a column name to the
        # *key* in the row dictionary that holds the value to match, not to
        # the value itself. Store the record ID under a dedicated "key"
        # entry, the same convention `DefaultCollectionManager.setDocumentation`
        # in this module uses.
        row = {
            "key": self.key,
            "host": host,
        }
        # Adds the timespan column value(s) to ``row`` in place.
        self._db.getTimespanRepresentation().update(timespan, result=row)
        count = self._db.update(self._table, {self._idName: "key"}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Only update the in-memory state once the database write succeeded.
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> Optional[str]:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan

221 

222 

class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    This class relies on the chain table layout produced by
    `makeCollectionChainTableSpec`. All column names in that table
    ("parent", "child", "position") are fixed and hard-coded in the methods.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        universe: DimensionUniverse,
    ):
        super().__init__(key=key, name=name, universe=universe)
        self._universe = universe
        self._table = table
        self._db = db

    def _update(self, manager: CollectionManager, children: CollectionSearch) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        # One row per direct child, numbered from zero in iteration order.
        rows = [
            {
                "parent": self.key,
                "child": child.key,
                "position": position,
            }
            for position, child in enumerate(children.iter(manager, flattenChains=False))
        ]
        # Replace the whole chain inside one transaction: drop every existing
        # row for this parent, then insert the new ones.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *rows)

    def _load(self, manager: CollectionManager) -> CollectionSearch:
        # Docstring inherited from ChainedCollectionRecord.
        columns = self._table.columns
        query = sqlalchemy.sql.select(
            columns.child,
        ).select_from(self._table)
        query = query.where(columns.parent == self.key)
        query = query.order_by(columns.position)
        # Translate child IDs to names through the manager, in chain order.
        child_names = [manager[row._mapping[columns.child]].name for row in self._db.query(query)]
        return CollectionSearch.fromExpression(child_names)

289 

290 

291K = TypeVar("K") 

292 

293 

294class DefaultCollectionManager(Generic[K], CollectionManager): 

295 """Default `CollectionManager` implementation. 

296 

297 This implementation uses record classes defined in this module and is 

298 based on the same assumptions about schema outlined in the record classes. 

299 

300 Parameters 

301 ---------- 

302 db : `Database` 

303 Interface to the underlying database engine and namespace. 

304 tables : `CollectionTablesTuple` 

305 Named tuple of SQLAlchemy table objects. 

306 collectionIdName : `str` 

307 Name of the column in collections table that identifies it (PK). 

308 dimensions : `DimensionRecordStorageManager` 

309 Manager object for the dimensions in this `Registry`. 

310 

311 Notes 

312 ----- 

313 Implementation uses "aggressive" pre-fetching and caching of the records 

314 in memory. Memory cache is synchronized from database when `refresh` 

315 method is called. 

316 """ 

317 

318 def __init__( 

319 self, 

320 db: Database, 

321 tables: CollectionTablesTuple, 

322 collectionIdName: str, 

323 *, 

324 dimensions: DimensionRecordStorageManager, 

325 ): 

326 self._db = db 

327 self._tables = tables 

328 self._collectionIdName = collectionIdName 

329 self._records: Dict[K, CollectionRecord] = {} # indexed by record ID 

330 self._dimensions = dimensions 

331 

332 def refresh(self) -> None: 

333 # Docstring inherited from CollectionManager. 

334 sql = sqlalchemy.sql.select( 

335 *(list(self._tables.collection.columns) + list(self._tables.run.columns)) 

336 ).select_from(self._tables.collection.join(self._tables.run, isouter=True)) 

337 # Put found records into a temporary instead of updating self._records 

338 # in place, for exception safety. 

339 records = [] 

340 chains = [] 

341 TimespanReprClass = self._db.getTimespanRepresentation() 

342 for row in self._db.query(sql).mappings(): 

343 collection_id = row[self._tables.collection.columns[self._collectionIdName]] 

344 name = row[self._tables.collection.columns.name] 

345 type = CollectionType(row["type"]) 

346 record: CollectionRecord 

347 if type is CollectionType.RUN: 

348 record = DefaultRunRecord( 

349 key=collection_id, 

350 name=name, 

351 db=self._db, 

352 table=self._tables.run, 

353 idColumnName=self._collectionIdName, 

354 host=row[self._tables.run.columns.host], 

355 timespan=TimespanReprClass.extract(row), 

356 ) 

357 elif type is CollectionType.CHAINED: 

358 record = DefaultChainedCollectionRecord( 

359 db=self._db, 

360 key=collection_id, 

361 table=self._tables.collection_chain, 

362 name=name, 

363 universe=self._dimensions.universe, 

364 ) 

365 chains.append(record) 

366 else: 

367 record = CollectionRecord(key=collection_id, name=name, type=type) 

368 records.append(record) 

369 self._setRecordCache(records) 

370 for chain in chains: 

371 try: 

372 chain.refresh(self) 

373 except MissingCollectionError: 

374 # This indicates a race condition in which some other client 

375 # created a new collection and added it as a child of this 

376 # (pre-existing) chain between the time we fetched all 

377 # collections and the time we queried for parent-child 

378 # relationships. 

379 # Because that's some other unrelated client, we shouldn't care 

380 # about that parent collection anyway, so we just drop it on 

381 # the floor (a manual refresh can be used to get it back). 

382 self._removeCachedRecord(chain) 

383 

384 def register( 

385 self, name: str, type: CollectionType, doc: Optional[str] = None 

386 ) -> Tuple[CollectionRecord, bool]: 

387 # Docstring inherited from CollectionManager. 

388 registered = False 

389 record = self._getByName(name) 

390 if record is None: 

391 row, inserted_or_updated = self._db.sync( 

392 self._tables.collection, 

393 keys={"name": name}, 

394 compared={"type": int(type)}, 

395 extra={"doc": doc}, 

396 returning=[self._collectionIdName], 

397 ) 

398 assert isinstance(inserted_or_updated, bool) 

399 registered = inserted_or_updated 

400 assert row is not None 

401 collection_id = row[self._collectionIdName] 

402 if type is CollectionType.RUN: 

403 TimespanReprClass = self._db.getTimespanRepresentation() 

404 row, _ = self._db.sync( 

405 self._tables.run, 

406 keys={self._collectionIdName: collection_id}, 

407 returning=("host",) + TimespanReprClass.getFieldNames(), 

408 ) 

409 assert row is not None 

410 record = DefaultRunRecord( 

411 db=self._db, 

412 key=collection_id, 

413 name=name, 

414 table=self._tables.run, 

415 idColumnName=self._collectionIdName, 

416 host=row["host"], 

417 timespan=TimespanReprClass.extract(row), 

418 ) 

419 elif type is CollectionType.CHAINED: 

420 record = DefaultChainedCollectionRecord( 

421 db=self._db, 

422 key=collection_id, 

423 name=name, 

424 table=self._tables.collection_chain, 

425 universe=self._dimensions.universe, 

426 ) 

427 else: 

428 record = CollectionRecord(key=collection_id, name=name, type=type) 

429 self._addCachedRecord(record) 

430 return record, registered 

431 

432 def remove(self, name: str) -> None: 

433 # Docstring inherited from CollectionManager. 

434 record = self._getByName(name) 

435 if record is None: 435 ↛ 436line 435 didn't jump to line 436, because the condition on line 435 was never true

436 raise MissingCollectionError(f"No collection with name '{name}' found.") 

437 # This may raise 

438 self._db.delete( 

439 self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key} 

440 ) 

441 self._removeCachedRecord(record) 

442 

443 def find(self, name: str) -> CollectionRecord: 

444 # Docstring inherited from CollectionManager. 

445 result = self._getByName(name) 

446 if result is None: 

447 raise MissingCollectionError(f"No collection with name '{name}' found.") 

448 return result 

449 

450 def __getitem__(self, key: Any) -> CollectionRecord: 

451 # Docstring inherited from CollectionManager. 

452 try: 

453 return self._records[key] 

454 except KeyError as err: 

455 raise MissingCollectionError(f"Collection with key '{key}' not found.") from err 

456 

457 def __iter__(self) -> Iterator[CollectionRecord]: 

458 yield from self._records.values() 

459 

460 def getDocumentation(self, key: Any) -> Optional[str]: 

461 # Docstring inherited from CollectionManager. 

462 sql = ( 

463 sqlalchemy.sql.select(self._tables.collection.columns.doc) 

464 .select_from(self._tables.collection) 

465 .where(self._tables.collection.columns[self._collectionIdName] == key) 

466 ) 

467 return self._db.query(sql).scalar() 

468 

469 def setDocumentation(self, key: Any, doc: Optional[str]) -> None: 

470 # Docstring inherited from CollectionManager. 

471 self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc}) 

472 

473 def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None: 

474 """Set internal record cache to contain given records, 

475 old cached records will be removed. 

476 """ 

477 self._records = {} 

478 for record in records: 

479 self._records[record.key] = record 

480 

481 def _addCachedRecord(self, record: CollectionRecord) -> None: 

482 """Add single record to cache.""" 

483 self._records[record.key] = record 

484 

485 def _removeCachedRecord(self, record: CollectionRecord) -> None: 

486 """Remove single record from cache.""" 

487 del self._records[record.key] 

488 

489 @abstractmethod 

490 def _getByName(self, name: str) -> Optional[CollectionRecord]: 

491 """Find collection record given collection name.""" 

492 raise NotImplementedError()