Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = () 

24 

25from abc import abstractmethod 

26from collections import namedtuple 

27import itertools 

28from typing import ( 

29 Any, 

30 Dict, 

31 Generic, 

32 Iterable, 

33 Iterator, 

34 Optional, 

35 Type, 

36 TYPE_CHECKING, 

37 TypeVar, 

38) 

39 

40import sqlalchemy 

41 

42from ...core import DimensionUniverse, TimespanDatabaseRepresentation, ddl, Timespan 

43from .._collectionType import CollectionType 

44from ..interfaces import ( 

45 ChainedCollectionRecord, 

46 CollectionManager, 

47 CollectionRecord, 

48 MissingCollectionError, 

49 RunRecord, 

50) 

51from ..wildcards import CollectionSearch 

52 

53if TYPE_CHECKING: 53 ↛ 54line 53 didn't jump to line 54, because the condition on line 53 was never true

54 from ..interfaces import Database, DimensionRecordStorageManager 

55 

56 

def _makeCollectionForeignKey(sourceColumnName: str, collectionIdName: str,
                              **kwargs: Any) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Name of the column in the referring table.
    collectionIdName : `str`
        Name of the primary-key column of the collections table.
    **kwargs
        Additional keyword arguments forwarded unchanged to
        `ddl.ForeignKeySpec` (e.g. ``onDelete``).

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The collections table is assumed to be named "collection" and to have a
    primary key made of a single column.
    """
    referencingColumns = (sourceColumnName,)
    referencedColumns = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=referencingColumns, target=referencedColumns,
                              **kwargs)

83 

84 

# Bundle of the three table objects used by the default collection managers,
# in the order: collections table, run table, collection-chain table.
CollectionTablesTuple = namedtuple(
    "CollectionTablesTuple",
    "collection run collection_chain",
)

86 

87 

def makeRunTableSpec(collectionIdName: str, collectionIdType: type,
                     tsRepr: Type[TimespanDatabaseRepresentation]) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table; the run
        table uses the same column name for its own primary key.
    collectionIdType
        Type of that primary-key column, one of the `sqlalchemy` types.
    tsRepr : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the run table.

    Notes
    -----
    The identifying column has the same name in the collections and run
    tables. All other (run metadata) column names are fixed: ``host`` plus
    whatever columns ``tsRepr`` contributes for the timespan.
    """
    idField = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    hostField = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # Removing a collection row cascades to its run row.
    collectionKey = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[idField, hostField], foreignKeys=[collectionKey])
    # Timespan storage is database-specific, so its (nullable) column(s) are
    # supplied by the representation class rather than hard-coded here.
    for timespanField in tsRepr.makeFieldSpecs(nullable=True):
        spec.fields.add(timespanField)
    return spec

127 

128 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table.
    collectionIdType
        Type of that primary-key column, one of the `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections: each row links a ``parent`` collection to a ``child`` at a
    given ``position``. These column names are fixed and are also used
    directly by the code that reads and writes chains.
    """
    parentField = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    positionField = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    childField = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Deleting a parent collection cascades to its chain rows; no ON DELETE
    # action is requested for the child reference.
    parentKey = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    childKey = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parentField, positionField, childField],
        foreignKeys=[parentKey, childKey],
    )

162 

163 

class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the run table layout produced by `makeRunTableSpec`.
    The only non-fixed name in that schema is the primary-key column, whose
    name must be passed to the constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run; an unbounded timespan is used if not given.
    """

    def __init__(self, db: Database, key: Any, name: str, *, table: sqlalchemy.schema.Table,
                 idColumnName: str, host: Optional[str] = None,
                 timespan: Optional[Timespan] = None):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        # A missing timespan is stored as one with no begin and no end.
        self._timespan = timespan if timespan is not None else Timespan(begin=None, end=None)
        self._idName = idColumnName

    def update(self, host: Optional[str] = None,
               timespan: Optional[Timespan] = None) -> None:
        # Docstring inherited from RunRecord.
        newTimespan = Timespan(begin=None, end=None) if timespan is None else timespan
        keyColumn = self._idName
        row = {keyColumn: self.key, "host": host}
        # The timespan representation writes its column value(s) into ``row``
        # (passed as ``result``); the return value is not used.
        self._db.getTimespanRepresentation().update(newTimespan, result=row)
        count = self._db.update(self._table, {keyColumn: self.key}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Only touch the in-memory state once the database write succeeded.
        self._host = host
        self._timespan = newTimespan

    @property
    def host(self) -> Optional[str]:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan

227 

228 

class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    This class assumes the chain table layout produced by
    `makeCollectionChainTableSpec`; the column names ``parent``, ``child`` and
    ``position`` are hard-coded in its methods.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(self, db: Database, key: Any, name: str, *, table: sqlalchemy.schema.Table,
                 universe: DimensionUniverse):
        super().__init__(key=key, name=name, universe=universe)
        self._db = db
        self._table = table
        self._universe = universe

    def _update(self, manager: CollectionManager, children: CollectionSearch) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        positions = itertools.count()
        # Children are stored unflattened (nested chains stay as single
        # entries), numbered 0, 1, 2, ... in search order.
        rows = [
            {"parent": self.key, "child": child.key, "position": next(positions)}
            for child in children.iter(manager, flattenChains=False)
        ]
        # Replace the whole chain atomically: drop old links, insert new ones.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *rows)

    def _load(self, manager: CollectionManager) -> CollectionSearch:
        # Docstring inherited from ChainedCollectionRecord.
        cols = self._table.columns
        query = sqlalchemy.sql.select([cols.child])
        query = query.select_from(self._table)
        query = query.where(cols.parent == self.key)
        query = query.order_by(cols.position)
        # Translate stored child keys back into collection names via the
        # manager's cache, preserving chain order.
        childNames = [manager[row[cols.child]].name for row in self._db.query(query)]
        return CollectionSearch.fromExpression(childNames)

286 

287 

# Type of the keys indexing the manager's record cache (``CollectionRecord.key``,
# i.e. the collection primary key); concrete subclasses bind it.
K = TypeVar("K")


class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    This implementation uses record classes defined in this module and is
    based on the same assumptions about schema outlined in the record classes.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.

    Notes
    -----
    Implementation uses "aggressive" pre-fetching and caching of the records
    in memory. Memory cache is synchronized from database when `refresh`
    method is called.

    Subclasses must implement `_getByName`. They may extend the cache hooks
    (`_setRecordCache`, `_addCachedRecord`, `_removeCachedRecord`) to keep
    additional indices in sync.
    """

    def __init__(self, db: Database, tables: CollectionTablesTuple, collectionIdName: str, *,
                 dimensions: DimensionRecordStorageManager):
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        # All known collection records, indexed by record key (PK).
        self._records: Dict[K, CollectionRecord] = {}
        self._dimensions = dimensions

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        collectionTable = self._tables.collection
        runTable = self._tables.run
        # Outer join so collections without a run-table row (non-RUN types)
        # are still returned, with NULL run columns.
        sql = sqlalchemy.sql.select(
            collectionTable.columns + runTable.columns
        ).select_from(
            collectionTable.join(runTable, isouter=True)
        )
        # Put found records into a temporary instead of updating self._records
        # in place, for exception safety.
        records = []
        chains = []
        tsRepr = self._db.getTimespanRepresentation()
        for row in self._db.query(sql).fetchall():
            collectionId = row[collectionTable.columns[self._collectionIdName]]
            name = row[collectionTable.columns.name]
            collectionType = CollectionType(row["type"])
            record: CollectionRecord
            if collectionType is CollectionType.RUN:
                record = DefaultRunRecord(
                    key=collectionId,
                    name=name,
                    db=self._db,
                    table=runTable,
                    idColumnName=self._collectionIdName,
                    host=row[runTable.columns.host],
                    timespan=tsRepr.extract(row),
                )
            elif collectionType is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(db=self._db,
                                                        key=collectionId,
                                                        table=self._tables.collection_chain,
                                                        name=name,
                                                        universe=self._dimensions.universe)
                chains.append(record)
            else:
                record = CollectionRecord(key=collectionId, name=name, type=collectionType)
            records.append(record)
        # Chain records resolve their children through this manager, so the
        # new cache has to be installed before they are refreshed. Remember
        # the previous cache so a failure while loading chains does not leave
        # half-initialized chain records behind.
        previousRecords = list(self._records.values())
        self._setRecordCache(records)
        try:
            for chain in chains:
                chain.refresh(self)
        except BaseException:
            # Restore through the hook so subclass-maintained indices are
            # rolled back too.
            self._setRecordCache(previousRecords)
            raise

    def register(self, name: str, type: CollectionType) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            row, _ = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                returning=[self._collectionIdName],
            )
            assert row is not None
            collectionId = row[self._collectionIdName]
            if type is CollectionType.RUN:
                tsRepr = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collectionId},
                    returning=("host",) + tsRepr.getFieldNames(),
                )
                assert row is not None
                record = DefaultRunRecord(
                    db=self._db,
                    key=collectionId,
                    name=name,
                    table=self._tables.run,
                    idColumnName=self._collectionIdName,
                    host=row["host"],
                    timespan=tsRepr.extract(row),
                )
            elif type is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(db=self._db, key=collectionId, name=name,
                                                        table=self._tables.collection_chain,
                                                        universe=self._dimensions.universe)
            else:
                record = CollectionRecord(key=collectionId, name=name, type=type)
            self._addCachedRecord(record)
        return record

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # The database delete may raise, e.g. when the collection is still
        # referenced as a chain child (that foreign key has no ON DELETE
        # action); the cache is only updated after the delete succeeds.
        self._db.delete(self._tables.collection, [self._collectionIdName],
                        {self._collectionIdName: record.key})
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            return self._records[key]
        except KeyError as err:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err

    def __iter__(self) -> Iterator[CollectionRecord]:
        yield from self._records.values()

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Replace the whole record cache with the given records.

        Any previously cached records are discarded.
        """
        self._records = {}
        for record in records:
            self._records[record.key] = record

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Add a single record to the cache (replacing one with the same key).
        """
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Remove a single record from the cache; raises `KeyError` if absent.
        """
        del self._records[record.key]

    @abstractmethod
    def _getByName(self, name: str) -> Optional[CollectionRecord]:
        """Find collection record given collection name.

        Returns `None` if no collection with that name is known.
        """
        raise NotImplementedError()