Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ()

from abc import abstractmethod
from collections import namedtuple
import itertools
from typing import (
    Any,
    Dict,
    Generic,
    Iterable,
    Iterator,
    Optional,
    TYPE_CHECKING,
    TypeVar,
)

import astropy.time
import sqlalchemy

from ...core import ddl
from ...core.timespan import Timespan, TIMESPAN_FIELD_SPECS
from .._collectionType import CollectionType
from ..interfaces import (
    ChainedCollectionRecord,
    CollectionManager,
    CollectionRecord,
    MissingCollectionError,
    RunRecord,
)
from ..wildcards import CollectionSearch, Ellipsis

if TYPE_CHECKING:
    from ..interfaces import Database

56 

57 

def _makeCollectionForeignKey(sourceColumnName: str, collectionIdName: str,
                              **kwargs: Any) -> ddl.ForeignKeySpec:
    """Define foreign key specification that refers to collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Name of the column in the referring table.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    **kwargs
        Additional keyword arguments passed directly to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    This function assumes fixed name ("collection") of a collections table.
    There is also a general assumption that collection primary key consists
    of a single column.
    """
    # Single-column key on both sides: one source column in the referring
    # table pointing at the one identifying column of "collection".
    return ddl.ForeignKeySpec("collection", source=(sourceColumnName,), target=(collectionIdName,),
                              **kwargs)

84 

85 

# Named bundle of the three table objects used by the collection manager in
# this module; fields are accessed by name (``collection``, ``run``,
# ``collection_chain``).
CollectionTablesTuple = namedtuple("CollectionTablesTuple", ["collection", "run", "collection_chain"])

87 

88 

def makeRunTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Define specification for "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    Assumption here and in the code below is that the name of the identifying
    column is the same in both collections and run tables. The names of
    non-identifying columns containing run metadata are fixed.
    """
    return ddl.TableSpec(
        fields=[
            # The run's primary key has the same name as, and is a foreign
            # key to, the identifying column of the collections table.
            ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True),
            TIMESPAN_FIELD_SPECS.begin,
            TIMESPAN_FIELD_SPECS.end,
            ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128),
        ],
        foreignKeys=[
            # Deleting the collection row also deletes its run row.
            _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE"),
        ],
    )

122 

123 

def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Define specification for "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    Collection chain is simply an ordered one-to-many relation between
    collections. The names of the columns in the table are fixed and
    also hardcoded in the code below.
    """
    return ddl.TableSpec(
        fields=[
            # (parent, position) is the composite primary key, which keeps
            # the children of one parent ordered and unique per position.
            ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True),
            ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
            ddl.FieldSpec("child", dtype=collectionIdType, nullable=False),
            ddl.FieldSpec("dataset_type_name", dtype=sqlalchemy.String, length=128, nullable=True),
        ],
        foreignKeys=[
            # Only the parent link cascades on delete; the child link is
            # declared without an ``onDelete`` action.
            _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE"),
            _makeCollectionForeignKey("child", collectionIdName),
        ],
    )

158 

159 

class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the same run table definition as produced by
    `makeRunTableSpec` function. The only non-fixed name in the schema
    is the PK column name, this needs to be passed in a constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run. If not provided, a timespan with both
        ``begin`` and ``end`` set to `None` is used.
    """
    def __init__(self, db: Database, key: Any, name: str, *, table: sqlalchemy.schema.Table,
                 idColumnName: str, host: Optional[str] = None,
                 timespan: Optional[Timespan[astropy.time.Time]] = None):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: Optional[str] = None,
               timespan: Optional[Timespan[astropy.time.Time]] = None) -> None:
        # Docstring inherited from RunRecord.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # All metadata columns are always written, so an argument that is
        # left at its default overwrites the stored value with `None`.
        row = {
            self._idName: self.key,
            TIMESPAN_FIELD_SPECS.begin.name: timespan.begin,
            TIMESPAN_FIELD_SPECS.end.name: timespan.end,
            "host": host,
        }
        count = self._db.update(self._table, {self._idName: self.key}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Only update in-memory state once the database write has succeeded.
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> Optional[str]:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan[astropy.time.Time]:
        # Docstring inherited from RunRecord.
        return self._timespan

224 

225 

class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    This class assumes the same chain table definition as produced by
    `makeCollectionChainTableSpec` function. All column names in the table
    are fixed and hard-coded in the methods.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    """
    def __init__(self, db: Database, key: Any, name: str, *, table: sqlalchemy.schema.Table):
        super().__init__(key=key, name=name)
        self._db = db
        self._table = table

    def _update(self, manager: CollectionManager, children: CollectionSearch) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        rows = []
        # A single counter is shared by all rows so that positions are
        # unique per parent even when one child expands into several rows.
        position = itertools.count()
        for child, restriction in children.iterPairs(manager, flattenChains=False):
            if restriction.names is Ellipsis:
                # Unrestricted child: stored with a NULL dataset type name.
                rows.append({"parent": self.key, "child": child.key,
                             "position": next(position), "dataset_type_name": None})
            else:
                # Restricted child: one row per allowed dataset type name.
                for name in restriction.names:
                    rows.append({"parent": self.key, "child": child.key,
                                 "position": next(position), "dataset_type_name": name})
        # Replace the full set of children inside a single transaction.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *rows)

    def _load(self, manager: CollectionManager) -> CollectionSearch:
        # Docstring inherited from ChainedCollectionRecord.
        sql = sqlalchemy.sql.select(
            [self._table.columns.child, self._table.columns.dataset_type_name]
        ).select_from(
            self._table
        ).where(
            self._table.columns.parent == self.key
        ).order_by(
            self._table.columns.position
        )
        # It's fine to have consecutive rows with the same collection name
        # and different dataset type names - CollectionSearch will group those
        # up for us.
        children = []
        for row in self._db.query(sql):
            key = row[self._table.columns.child]
            restriction = row[self._table.columns.dataset_type_name]
            if not restriction:
                # `_update` stores an unrestricted child as NULL; an empty
                # string is treated the same way here.
                restriction = ...
            record = manager[key]
            children.append((record.name, restriction))
        return CollectionSearch.fromExpression(children)

290 

291 

# Type variable for the collection key (record ID) used to index the
# manager's record cache.
K = TypeVar("K")

293 

294 

class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    This implementation uses record classes defined in this module and is
    based on the same assumptions about schema outlined in the record classes.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).

    Notes
    -----
    Implementation uses "aggressive" pre-fetching and caching of the records
    in memory. Memory cache is synchronized from database when `refresh`
    method is called.

    Lookup of a record by collection name is left to subclasses, which must
    implement `_getByName`.
    """
    def __init__(self, db: Database, tables: CollectionTablesTuple, collectionIdName: str):
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._records: Dict[K, CollectionRecord] = {}  # indexed by record ID

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        # Outer join so that collections without a row in the run table are
        # returned as well.
        sql = sqlalchemy.sql.select(
            self._tables.collection.columns + self._tables.run.columns
        ).select_from(
            self._tables.collection.join(self._tables.run, isouter=True)
        )
        # Put found records into a temporary instead of updating self._records
        # in place, for exception safety.
        records = []
        chains = []
        for row in self._db.query(sql).fetchall():
            collection_id = row[self._tables.collection.columns[self._collectionIdName]]
            name = row[self._tables.collection.columns.name]
            collectionType = CollectionType(row["type"])
            record: CollectionRecord
            if collectionType is CollectionType.RUN:
                record = DefaultRunRecord(
                    key=collection_id,
                    name=name,
                    db=self._db,
                    table=self._tables.run,
                    idColumnName=self._collectionIdName,
                    host=row[self._tables.run.columns.host],
                    timespan=Timespan(
                        begin=row[self._tables.run.columns[TIMESPAN_FIELD_SPECS.begin.name]],
                        end=row[self._tables.run.columns[TIMESPAN_FIELD_SPECS.end.name]],
                    )
                )
            elif collectionType is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(db=self._db,
                                                        key=collection_id,
                                                        table=self._tables.collection_chain,
                                                        name=name)
                chains.append(record)
            else:
                record = CollectionRecord(key=collection_id, name=name, type=collectionType)
            records.append(record)
        self._setRecordCache(records)
        # Chains are refreshed only after the cache is populated, because
        # loading their children looks records up through this manager.
        for chain in chains:
            chain.refresh(self)

    def register(self, name: str, type: CollectionType) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        # A record already returned by `_getByName` is handed back unchanged;
        # ``type`` is not compared against it here.
        record = self._getByName(name)
        if record is None:
            row, _ = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                returning=[self._collectionIdName],
            )
            assert row is not None
            collection_id = row[self._collectionIdName]
            if type is CollectionType.RUN:
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=["host", TIMESPAN_FIELD_SPECS.begin.name, TIMESPAN_FIELD_SPECS.end.name],
                )
                assert row is not None
                record = DefaultRunRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=self._tables.run,
                    idColumnName=self._collectionIdName,
                    host=row["host"],
                    timespan=Timespan(
                        row[TIMESPAN_FIELD_SPECS.begin.name],
                        row[TIMESPAN_FIELD_SPECS.end.name]
                    ),
                )
            elif type is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(db=self._db, key=collection_id, name=name,
                                                        table=self._tables.collection_chain)
            else:
                record = CollectionRecord(key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise; the cache entry is dropped only after the database
        # delete has succeeded.
        self._db.delete(self._tables.collection, [self._collectionIdName],
                        {self._collectionIdName: record.key})
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            return self._records[key]
        except KeyError as err:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err

    def __iter__(self) -> Iterator[CollectionRecord]:
        yield from self._records.values()

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Set internal record cache to contain given records,
        old cached records will be removed.
        """
        # Build the new mapping completely before rebinding the attribute so
        # that a failure while iterating ``records`` leaves the old cache
        # intact.
        self._records = {record.key: record for record in records}

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Add single record to cache.
        """
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Remove single record from cache.
        """
        del self._records[record.key]

    @abstractmethod
    def _getByName(self, name: str) -> Optional[CollectionRecord]:
        """Find collection record given collection name.

        Returns `None` when no record with this name is known; callers in
        this class rely on that to detect a missing collection.
        """
        raise NotImplementedError()