Coverage for python/lsst/daf/butler/registry/collections/_base.py: 96%

175 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-25 10:48 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from ... import ddl 

30 

31__all__ = () 

32 

33import itertools 

34from abc import abstractmethod 

35from collections import namedtuple 

36from collections.abc import Iterable, Iterator, Set 

37from typing import TYPE_CHECKING, Any, TypeVar, cast 

38 

39import sqlalchemy 

40 

41from ...timespan_database_representation import TimespanDatabaseRepresentation 

42from .._collection_type import CollectionType 

43from .._exceptions import MissingCollectionError 

44from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple 

45from ..wildcards import CollectionWildcard 

46 

47if TYPE_CHECKING: 

48 from .._caching_context import CachingContext 

49 from ..interfaces import Database 

50 

51 

def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column in the referring table that holds the collection key.
    collectionIdName : `str`
        Primary-key column of the collections table being referenced.
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        The resulting foreign key specification.

    Notes
    -----
    The referenced table name is hard-coded as "collection", and the
    collection primary key is assumed to consist of a single column.
    """
    source_columns = (sourceColumnName,)
    target_columns = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=source_columns, target=target_columns, **kwargs)

78 

79 

# Bundle of the SQLAlchemy table objects used by the collection manager:
# the main "collection" table, the "run" table and the "collection_chain" table.
CollectionTablesTuple = namedtuple("CollectionTablesTuple", "collection run collection_chain")

81 

82 

def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table; the run
        table uses a column with the same name as its own primary key.
    collectionIdType : `type`
        Type of that primary-key column, one of the `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` describing how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification of the run table.

    Notes
    -----
    The identifying column shares its name with the collections table PK;
    this assumption is relied upon elsewhere as well. The remaining
    metadata columns ("host" and the timespan columns) have fixed names.
    """
    key_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # Rows in the run table are removed together with their collection row.
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[key_field, host_field], foreignKeys=[collection_fk])
    # Timespan columns depend on the database-specific representation class,
    # so they are appended after the fixed columns; they may be NULL.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec

122 

123 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table.
    collectionIdType : `type`
        Type of that primary-key column, one of the `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification of the collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections: each row links a ``parent`` chain to a ``child`` at a
    given ``position``. These column names are fixed and are also used
    directly by the manager code.
    """
    parent_field = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    position_field = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    child_field = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Deleting the parent collection cascades to its chain rows; the child
    # foreign key specifies no ``onDelete`` action.
    parent_fk = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    child_fk = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parent_field, position_field, child_field],
        foreignKeys=[parent_fk, child_fk],
    )

157 

158 

# Type variable for the type of a collection's primary key, as stored in
# ``CollectionRecord.key`` by `DefaultCollectionManager`.
K = TypeVar("K")

160 

161 

class DefaultCollectionManager(CollectionManager[K]):
    """Default `CollectionManager` implementation.

    This implementation uses the record classes `CollectionRecord`,
    `RunRecord` and `ChainedCollectionRecord` from ``..interfaces`` and
    assumes the table layout produced by `makeRunTableSpec` and
    `makeCollectionChainTableSpec` in this module.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    caching_context : `CachingContext`
        Caching context to use.
    registry_schema_version : `VersionTuple` or `None`, optional
        The version of the registry schema.

    Notes
    -----
    When ``caching_context.collection_records`` is not `None`, records are
    cached there as they are fetched, registered or updated. The `refresh`
    method clears that cache, so later lookups are re-read from the database.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._caching_context = caching_context

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        # Drop all cached records (when caching is enabled) so that later
        # lookups go back to the database.
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.clear()

    def _fetch_all(self) -> list[CollectionRecord[K]]:
        """Return records for all collections.

        If the cache is enabled and already marked as complete, its records
        are returned directly. Otherwise all records are read from the
        database and, when caching is enabled, stored in the cache marked as
        complete.
        """
        if self._caching_context.collection_records is not None:
            if self._caching_context.collection_records.full:
                return list(self._caching_context.collection_records.records())
        records = self._fetch_by_key(None)
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.set(records, full=True)
        return records

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord[K], bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        if record is None:
            # Insert (or match by name) the row in the main collection table.
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = cast(K, row[self._collectionIdName])
            if type is CollectionType.RUN:
                # RUN collections also need a row in the run table, keyed by
                # the same collection ID.
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = RunRecord[K](
                    key=collection_id,
                    name=name,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                # A newly registered chain starts out with no children.
                record = ChainedCollectionRecord[K](
                    key=collection_id,
                    name=name,
                    children=[],
                )
            else:
                record = CollectionRecord[K](key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Return multiple records given their names.

        Records are returned in the order of ``names``. Names not found in
        the cache are fetched from the database in a single call and added
        to the cache.

        Raises
        ------
        MissingCollectionError
            Raised if any of the names does not correspond to a collection.
        """
        names = list(names)
        # To protect against potential races in cache updates.
        records: dict[str, CollectionRecord[K] | None] = {}
        if self._caching_context.collection_records is not None:
            for name in names:
                records[name] = self._caching_context.collection_records.get_by_name(name)
            fetch_names = [name for name, record in records.items() if record is None]
        else:
            fetch_names = list(names)
            records = {name: None for name in fetch_names}
        if fetch_names:
            for record in self._fetch_by_name(fetch_names):
                records[record.name] = record
                self._addCachedRecord(record)
        missing_names = [name for name, record in records.items() if record is None]
        if len(missing_names) == 1:
            raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
        elif len(missing_names) > 1:
            raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
        return [cast(CollectionRecord[K], records[name]) for name in names]

    def __getitem__(self, key: Any) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
                return record
        if records := self._fetch_by_key([key]):
            record = records[0]
            if self._caching_context.collection_records is not None:
                self._caching_context.collection_records.add(record)
            return record
        else:
            raise MissingCollectionError(f"Collection with key '{key}' not found.")

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord[K]]:
        # Docstring inherited
        if done is None:
            done = set()
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]:
            # Skip collections already yielded or already expanded.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(child, done)  # noqa: F821

        result: list[CollectionRecord[K]] = []

        # ``resolve_nested`` refers to itself through its enclosing scope; it
        # is deleted before returning, presumably to break that reference
        # cycle.
        if wildcard.patterns is ...:
            for record in self._fetch_all():
                result.extend(resolve_nested(record, done))
            del resolve_nested
            return result
        if wildcard.strings:
            for record in self._find_many(wildcard.strings):
                result.extend(resolve_nested(record, done))
        if wildcard.patterns:
            for record in self._fetch_all():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        del resolve_nested
        return result

    def getDocumentation(self, key: K) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: K, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # The ``where`` mapping matches the PK column against the row's
        # "key" entry.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _addCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Add single record to cache, if caching is enabled."""
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.add(record)

    def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Remove single record from cache, if caching is enabled."""
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.discard(record)

    def _getByName(self, name: str) -> CollectionRecord[K] | None:
        """Find collection record given collection name.

        The cache is consulted first; on a miss the record is fetched from
        the database and cached. Returns `None` if no such collection exists.
        """
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
                return record
        records = self._fetch_by_name([name])
        for record in records:
            self._addCachedRecord(record)
        return records[0] if records else None

    @abstractmethod
    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their names."""
        raise NotImplementedError()

    @abstractmethod
    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their keys, or fetch
        all collections if argument is None.
        """
        raise NotImplementedError()

    def update_chain(
        self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
    ) -> ChainedCollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        # Materialize once: ``children`` may be a one-shot iterator, but it is
        # read more than once below (cycle check and final resolution).
        children = list(children)
        children_as_wildcard = CollectionWildcard.from_names(children)
        # Reject the update if ``chain`` is reachable from any of the new
        # children, which would create a cycle.
        for record in self.resolve_wildcard(
            children_as_wildcard,
            flatten_chains=True,
            include_chains=True,
            collection_types={CollectionType.CHAINED},
        ):
            if record == chain:
                raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.")
        if flatten:
            children = tuple(
                record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True)
            )

        rows = []
        position = itertools.count()
        names = []
        for child in self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False):
            rows.append(
                {
                    "parent": chain.key,
                    "child": child.key,
                    "position": next(position),
                }
            )
            names.append(child.name)
        # Replace the chain's membership rows atomically.
        with self._db.transaction():
            self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key})
            self._db.insert(self._tables.collection_chain, *rows)

        record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names))
        self._addCachedRecord(record)
        return record