Coverage for python/lsst/daf/butler/registry/collections/_base.py: 86%

176 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-05 11:05 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from ... import ddl 

30 

31__all__ = () 

32 

33import itertools 

34from abc import abstractmethod 

35from collections import namedtuple 

36from collections.abc import Iterable, Iterator, Set 

37from typing import TYPE_CHECKING, Any, TypeVar, cast 

38 

39import sqlalchemy 

40 

41from ..._timespan import TimespanDatabaseRepresentation 

42from .._collection_type import CollectionType 

43from .._exceptions import MissingCollectionError 

44from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple 

45from ..wildcards import CollectionWildcard 

46 

47if TYPE_CHECKING: 

48 from .._caching_context import CachingContext 

49 from ..interfaces import Database, DimensionRecordStorageManager 

50 

51 

def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column of the referring table that holds the reference.
    collectionIdName : `str`
        Column of the collections table being referenced (its PK).
    **kwargs
        Extra keyword arguments forwarded unchanged to
        `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        The resulting foreign key specification.

    Notes
    -----
    The name of the collections table is hardcoded as ``"collection"``, and
    both sides of the key are single-column tuples, i.e. the collection
    primary key is taken to be one column.
    """
    referring_columns = (sourceColumnName,)
    referenced_columns = (collectionIdName,)
    return ddl.ForeignKeySpec(
        "collection", source=referring_columns, target=referenced_columns, **kwargs
    )

78 

79 

# Bundle of the table objects used by the collection managers; the field
# order is fixed: "collection", "run", "collection_chain".
CollectionTablesTuple = namedtuple("CollectionTablesTuple", ("collection", "run", "collection_chain"))

81 

82 

def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the identifying (PK) column of the collections table; the
        run table uses the same name for its own primary key column.
    collectionIdType
        Type of that PK column, one of the `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` describing how
        timespans are stored in this database; it supplies the timespan
        field specifications.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the run table.

    Notes
    -----
    Besides the primary key, which is also a foreign key into the
    collections table with ``ON DELETE CASCADE``, the table has a fixed
    ``host`` string column and the nullable timespan columns produced by
    ``TimespanReprClass``.
    """
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[collection_fk])
    # Timespan columns are appended after construction, one field at a time.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec

123 

124 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the identifying (PK) column of the collections table.
    collectionIdType
        Type of that PK column, one of the `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the collection chain table.

    Notes
    -----
    The table has fixed column names: ``parent`` and ``position`` form the
    primary key and ``child`` is a non-nullable column. Both ``parent`` and
    ``child`` are foreign keys into the collections table; only the
    ``parent`` key cascades on delete.
    """
    parent_field = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    position_field = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    child_field = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    parent_fk = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    child_fk = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parent_field, position_field, child_field],
        foreignKeys=[parent_fk, child_fk],
    )

158 

159 

K = TypeVar("K")


class DefaultCollectionManager(CollectionManager[K]):
    """Default `CollectionManager` implementation.

    This implementation uses record classes defined in this module and is
    based on the same assumptions about schema outlined in the record classes.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.
    caching_context : `CachingContext`
        Object whose ``collection_records`` attribute is either `None` (no
        caching) or a cache of collection records used by this manager.
    registry_schema_version : `VersionTuple` or `None`, optional
        Passed through to the `CollectionManager` constructor.

    Notes
    -----
    Records are looked up in ``caching_context.collection_records`` first
    when that attribute is not `None`, and anything fetched from the database
    is added to it. Calling `refresh` empties that cache; it does not reload
    it.

    Subclasses must implement `_fetch_by_name` and `_fetch_by_key`.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._dimensions = dimensions
        self._caching_context = caching_context

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.clear()

    def _fetch_all(self) -> list[CollectionRecord[K]]:
        """Retrieve all records into cache if not done so yet.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            All collection records, taken from the cache when it is marked
            as full and from the database otherwise.
        """
        if self._caching_context.collection_records is not None:
            if self._caching_context.collection_records.full:
                return list(self._caching_context.collection_records.records())
        # `None` asks the subclass to fetch every collection.
        records = self._fetch_by_key(None)
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.set(records, full=True)
        return records

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord[K], bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        if record is None:
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = cast(K, row[self._collectionIdName])
            if type is CollectionType.RUN:
                # RUN collections also get a row in the run table, which
                # supplies the host and timespan of the record.
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = RunRecord[K](
                    key=collection_id,
                    name=name,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                # A newly registered chain starts out with no children.
                record = ChainedCollectionRecord[K](
                    key=collection_id,
                    name=name,
                    children=[],
                )
            else:
                record = CollectionRecord[K](key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        # Only reached when the delete succeeded.
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Return multiple records given their names.

        Parameters
        ----------
        names : `~collections.abc.Iterable` [ `str` ]
            Names of the collections to look up; it is iterated only once.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One record per requested name, in the order requested (repeated
            names yield repeated records).

        Raises
        ------
        MissingCollectionError
            Raised if any of the names cannot be found.
        """
        names = list(names)
        # To protect against potential races in cache updates, results are
        # gathered in a local mapping instead of being re-read from the cache.
        # One entry per distinct name; `None` means "not resolved yet".
        records: dict[str, CollectionRecord | None] = dict.fromkeys(names)
        if self._caching_context.collection_records is not None:
            for name in names:
                records[name] = self._caching_context.collection_records.get_by_name(name)
        # Each unresolved name is requested from the database only once.
        fetch_names = [name for name, record in records.items() if record is None]
        if fetch_names:
            for record in self._fetch_by_name(fetch_names):
                records[record.name] = record
                self._addCachedRecord(record)
        missing_names = [name for name, record in records.items() if record is None]
        if len(missing_names) == 1:
            raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
        elif len(missing_names) > 1:
            raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
        return [cast(CollectionRecord[K], records[name]) for name in names]

    def __getitem__(self, key: Any) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
                return record
        if records := self._fetch_by_key([key]):
            record = records[0]
            if self._caching_context.collection_records is not None:
                self._caching_context.collection_records.add(record)
            return record
        else:
            raise MissingCollectionError(f"Collection with key '{key}' not found.")

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord[K]]:
        # Docstring inherited
        if done is None:
            done = set()
        # Unless told otherwise, chains themselves are reported exactly when
        # they are not being flattened.
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]:
            # Names already visited are skipped, so a collection is yielded at
            # most once per `done` set.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(child, done)  # noqa: F821

        result: list[CollectionRecord[K]] = []

        if wildcard.patterns is ...:
            # Ellipsis means "everything": walk all known collections.
            for record in self._fetch_all():
                result.extend(resolve_nested(record, done))
            # NOTE(review): presumably deleted to break the reference cycle a
            # self-recursive closure forms with its own cell -- confirm.
            del resolve_nested
            return result
        if wildcard.strings:
            for record in self._find_many(wildcard.strings):
                result.extend(resolve_nested(record, done))
        if wildcard.patterns:
            for record in self._fetch_all():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        del resolve_nested
        return result

    def getDocumentation(self, key: K) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: K, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _addCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Add single record to cache (no-op when caching is disabled)."""
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.add(record)

    def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Remove single record from cache (no-op when caching is disabled)."""
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.discard(record)

    def _getByName(self, name: str) -> CollectionRecord[K] | None:
        """Find collection record given collection name.

        Returns
        -------
        record : `CollectionRecord` or `None`
            The cached record if there is one, otherwise the first record
            fetched from the database, or `None` if nothing was fetched.
        """
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
                return record
        records = self._fetch_by_name([name])
        for record in records:
            self._addCachedRecord(record)
        return records[0] if records else None

    @abstractmethod
    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their names."""
        raise NotImplementedError()

    @abstractmethod
    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their keys, or fetch
        all collections if argument is None.
        """
        raise NotImplementedError()

    def update_chain(
        self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
    ) -> ChainedCollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        children_as_wildcard = CollectionWildcard.from_names(children)
        # Walk every chain reachable from the new children; finding `chain`
        # itself there would make the definition circular.
        for record in self.resolve_wildcard(
            children_as_wildcard,
            flatten_chains=True,
            include_chains=True,
            collection_types={CollectionType.CHAINED},
        ):
            if record == chain:
                raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.")
        if flatten:
            # Replace the requested children with their flattened expansion.
            flattened_names = tuple(
                record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True)
            )
            children_as_wildcard = CollectionWildcard.from_names(flattened_names)
        # `children` is only promised to be an iterable, so it may be a
        # one-shot iterator that the `from_names` call above has exhausted.
        # The wildcard built from it is reused instead of iterating `children`
        # a second time, which could silently produce an empty chain.

        rows = []
        names = []
        for position, child in enumerate(self.resolve_wildcard(children_as_wildcard, flatten_chains=False)):
            rows.append(
                {
                    "parent": chain.key,
                    "child": child.key,
                    "position": position,
                }
            )
            names.append(child.name)
        # Replace the stored chain definition in a single transaction.
        with self._db.transaction():
            self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key})
            self._db.insert(self._tables.collection_chain, *rows)

        updated = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names))
        self._addCachedRecord(updated)
        return updated