Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["AggressiveNameKeyCollectionManager"] 

24 

25from collections import namedtuple 

26import astropy.time 

27from typing import ( 

28 Any, 

29 Iterator, 

30 Optional, 

31 TYPE_CHECKING, 

32) 

33 

34import sqlalchemy 

35 

36from ...core import ddl 

37from ...core.timespan import Timespan, TIMESPAN_FIELD_SPECS 

38from .._collectionType import CollectionType 

39from ..interfaces import ( 

40 ChainedCollectionRecord, 

41 CollectionManager, 

42 CollectionRecord, 

43 MissingCollectionError, 

44 RunRecord, 

45) 

46from ..wildcards import CollectionSearch 

47 

48if TYPE_CHECKING: 48 ↛ 49line 48 didn't jump to line 49, because the condition on line 48 was never true

49 from .database import Database, StaticTablesContext 

50 

51 

52_TablesTuple = namedtuple("CollectionTablesTuple", ["collection", "run", "collection_chain"]) 

53 

def _makeTablesSpec() -> _TablesTuple:
    """Construct the static table specifications used by this manager.

    Returns
    -------
    specs : `_TablesTuple`
        Named tuple of `ddl.TableSpec` objects for the ``collection``,
        ``run``, and ``collection_chain`` tables.
    """
    # One row per registered collection of any type; ``type`` stores the
    # integer value of its `CollectionType`.
    collectionSpec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True),
            ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
        ],
    )
    # Extra columns for RUN collections.  The CASCADE foreign key removes a
    # run row together with its ``collection`` row.
    runSpec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True),
            TIMESPAN_FIELD_SPECS.begin,
            TIMESPAN_FIELD_SPECS.end,
            ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("collection", source=("name",), target=("name",), onDelete="CASCADE"),
        ],
    )
    # Ordered children of CHAINED collections.  A child may occupy several
    # consecutive positions (one per dataset-type restriction).  An empty
    # ``dataset_type_name`` stands for "no restriction".
    chainSpec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec("parent", dtype=sqlalchemy.String, length=64, primaryKey=True),
            ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
            ddl.FieldSpec("child", dtype=sqlalchemy.String, length=64, nullable=False),
            ddl.FieldSpec("dataset_type_name", dtype=sqlalchemy.String, length=128, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("collection", source=("parent",), target=("name",), onDelete="CASCADE"),
            # No onDelete here.  Presumably the database then refuses to
            # delete a collection still used as a child -- confirm against
            # ddl.ForeignKeySpec defaults.
            ddl.ForeignKeySpec("collection", source=("child",), target=("name",)),
        ],
    )
    return _TablesTuple(collection=collectionSpec, run=runSpec, collection_chain=chainSpec)


_TABLES_SPEC = _makeTablesSpec()

85 

86 

class NameKeyCollectionRecord(CollectionRecord):
    """`CollectionRecord` whose database key is simply the collection name.

    In this module it represents collections that are neither RUN nor
    CHAINED; those two types have dedicated record classes.
    """

    @property
    def key(self) -> str:
        # Docstring inherited from CollectionRecord.
        # Names double as primary/foreign keys, so no lookup is required.
        return self.name

96 

97 

class NameKeyRunRecord(RunRecord):
    """`RunRecord` implementation keyed directly on the run's string name.

    Besides the name, an instance caches the run's ``host`` and ``timespan``
    and can write new values for them back to the ``run`` table.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    name : `str`
        Name of the run; also used as its primary/foreign key.
    table : `sqlalchemy.schema.Table`
        The ``run`` table that `update` writes to.
    host : `str`, optional
        Host value associated with the run.
    timespan : `Timespan`, optional
        Timespan associated with the run.  `None` is replaced by
        ``Timespan(begin=None, end=None)``.
    """
    def __init__(self, db: Database, name: str, *, table: sqlalchemy.schema.Table,
                 host: Optional[str] = None, timespan: Optional[Timespan[astropy.time.Time]] = None):
        super().__init__(name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        self._timespan = timespan if timespan is not None else Timespan(begin=None, end=None)

    def update(self, host: Optional[str] = None,
               timespan: Optional[Timespan[astropy.time.Time]] = None) -> None:
        # Docstring inherited from RunRecord.
        newTimespan = Timespan(begin=None, end=None) if timespan is None else timespan
        values = {
            "name": self.name,
            TIMESPAN_FIELD_SPECS.begin.name: newTimespan.begin,
            TIMESPAN_FIELD_SPECS.end.name: newTimespan.end,
            "host": host,
        }
        # NOTE(review): the ``where`` mapping passes the run's *name value*
        # as the mapping value.  If `Database.update` expects those values to
        # be keys into ``values`` (e.g. {"name": "name"}), this call needs
        # revisiting -- confirm against Database.update.
        count = self._db.update(self._table, {"name": self.name}, values)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # The cached state changes only after the database write succeeded.
        self._host, self._timespan = host, newTimespan

    @property
    def key(self) -> str:
        # Docstring inherited from CollectionRecord.
        # Name-keyed scheme: the key is the name.
        return self.name

    @property
    def host(self) -> Optional[str]:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan[astropy.time.Time]:
        # Docstring inherited from RunRecord.
        return self._timespan

142 

143 

class NameKeyChainedCollectionRecord(ChainedCollectionRecord):
    """`ChainedCollectionRecord` implementation keyed directly on the chain's
    string name.

    The chain's children are persisted in the ``collection_chain`` table, one
    row per (child, dataset type restriction) pair, ordered by ``position``.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    name : `str`
        Name of the chained collection; also used as its key.
    table : `sqlalchemy.schema.Table`
        The ``collection_chain`` table.
    """
    def __init__(self, db: Database, name: str, *, table: sqlalchemy.schema.Table):
        super().__init__(name=name)
        self._db = db
        self._table = table

    @property
    def key(self) -> str:
        # Docstring inherited from CollectionRecord.
        return self.name

    def _update(self, manager: CollectionManager, children: CollectionSearch) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        # Expand every child into one (child key, dataset type name) entry per
        # restriction.  An unrestricted child (``...``) is stored as "".
        entries = []
        for child, restriction in children.iter(manager, withRestrictions=True, flattenChains=False):
            datasetTypeNames = [""] if restriction.names is ... else restriction.names
            entries.extend((child.key, datasetTypeName) for datasetTypeName in datasetTypeNames)
        rows = [
            {"parent": self.key, "child": childKey, "position": position,
             "dataset_type_name": datasetTypeName}
            for position, (childKey, datasetTypeName) in enumerate(entries)
        ]
        # Replace the stored chain wholesale inside a single transaction.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *rows)

    def _load(self, manager: CollectionManager) -> CollectionSearch:
        # Docstring inherited from ChainedCollectionRecord.
        columns = self._table.columns
        query = (
            sqlalchemy.sql.select([columns.child, columns.dataset_type_name])
            .select_from(self._table)
            .where(columns.parent == self.key)
            .order_by(columns.position)
        )
        # Consecutive rows for the same child with different dataset types
        # are fine: CollectionSearch.fromExpression groups them.
        children = []
        for row in self._db.query(query):
            childKey = row[columns.child]
            # NULL or "" both mean "no restriction", spelled ``...`` in memory.
            restriction = row[columns.dataset_type_name] or ...
            children.append((manager[childKey].name, restriction))
        return CollectionSearch.fromExpression(children)

199 

200 

class AggressiveNameKeyCollectionManager(CollectionManager):
    """A `CollectionManager` implementation that uses collection names for
    primary/foreign keys and aggressively loads all collection/run records in
    the database into memory.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `_TablesTuple`
        Named tuple of SQLAlchemy table objects.
    """
    def __init__(self, db: Database, tables: _TablesTuple):
        self._db = db
        self._tables = tables
        # In-memory cache of collection records keyed by name.  `refresh`
        # rebuilds it, `register` adds to it, and `remove` deletes from it.
        self._records = {}

    @classmethod
    def initialize(cls, db: Database, context: StaticTablesContext) -> CollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(db, tables=context.addTableTuple(_TABLES_SPEC))

    @classmethod
    def addCollectionForeignKey(cls, tableSpec: ddl.TableSpec, *, prefix: Optional[str] = "collection",
                                onDelete: Optional[str] = None, **kwds: Any) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        # An explicit None falls back to the default prefix.
        if prefix is None:
            prefix = "collection"
        original = _TABLES_SPEC.collection.fields["name"]
        # Mirror the type/length of the referenced key column.
        copy = ddl.FieldSpec(cls.getCollectionForeignKeyName(prefix), dtype=original.dtype,
                             length=original.length, **kwds)
        tableSpec.fields.add(copy)
        tableSpec.foreignKeys.append(ddl.ForeignKeySpec("collection", source=(copy.name,),
                                                        target=(original.name,), onDelete=onDelete))
        return copy

    @classmethod
    def addRunForeignKey(cls, tableSpec: ddl.TableSpec, *, prefix: Optional[str] = "run",
                         onDelete: Optional[str] = None, **kwds: Any) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        # An explicit None falls back to the default prefix.
        if prefix is None:
            prefix = "run"
        original = _TABLES_SPEC.run.fields["name"]
        # Mirror the type/length of the referenced key column.
        copy = ddl.FieldSpec(cls.getRunForeignKeyName(prefix), dtype=original.dtype,
                             length=original.length, **kwds)
        tableSpec.fields.add(copy)
        tableSpec.foreignKeys.append(ddl.ForeignKeySpec("run", source=(copy.name,),
                                                        target=(original.name,), onDelete=onDelete))
        return copy

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        # Left outer join: non-RUN collections get NULLs for the run columns.
        sql = sqlalchemy.sql.select(
            self._tables.collection.columns + self._tables.run.columns
        ).select_from(
            self._tables.collection.join(self._tables.run, isouter=True)
        )
        # Put found records into a temporary instead of updating self._records
        # in place, for exception safety.
        records = {}
        chains = []
        for row in self._db.query(sql).fetchall():
            name = row[self._tables.collection.columns.name]
            collectionType = CollectionType(row["type"])
            if collectionType is CollectionType.RUN:
                record = NameKeyRunRecord(
                    name=name,
                    db=self._db,
                    table=self._tables.run,
                    host=row[self._tables.run.columns.host],
                    timespan=Timespan(
                        begin=row[self._tables.run.columns[TIMESPAN_FIELD_SPECS.begin.name]],
                        end=row[self._tables.run.columns[TIMESPAN_FIELD_SPECS.end.name]],
                    )
                )
            elif collectionType is CollectionType.CHAINED:
                record = NameKeyChainedCollectionRecord(db=self._db, table=self._tables.collection_chain,
                                                        name=name)
                chains.append(record)
            else:
                record = NameKeyCollectionRecord(type=collectionType, name=name)
            records[record.name] = record
        self._records = records
        # Chains are loaded last because resolving their children looks up
        # other records through this manager.
        for chain in chains:
            chain.refresh(self)

    def register(self, name: str, type: CollectionType) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        # NOTE(review): a cached record is returned as-is, even if its type
        # differs from ``type``.  The database-level type comparison done by
        # ``sync`` below only happens on a cache miss.
        record = self._records.get(name)
        if record is None:
            kwds = {"name": name}
            self._db.sync(
                self._tables.collection,
                keys=kwds,
                compared={"type": int(type)},
            )
            if type is CollectionType.RUN:
                row, _ = self._db.sync(
                    self._tables.run,
                    keys=kwds,
                    returning={"host", TIMESPAN_FIELD_SPECS.begin.name, TIMESPAN_FIELD_SPECS.end.name},
                )
                record = NameKeyRunRecord(
                    db=self._db,
                    table=self._tables.run,
                    host=row["host"],
                    timespan=Timespan(
                        row[TIMESPAN_FIELD_SPECS.begin.name],
                        row[TIMESPAN_FIELD_SPECS.end.name]
                    ),
                    **kwds
                )
            elif type is CollectionType.CHAINED:
                record = NameKeyChainedCollectionRecord(db=self._db, table=self._tables.collection_chain,
                                                        **kwds)
            else:
                record = NameKeyCollectionRecord(type=type, **kwds)
            self._records[record.name] = record
        return record

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        try:
            record = self._records.pop(name)
        except KeyError as err:
            raise MissingCollectionError(f"No collection with name '{name}' found.") from err
        try:
            self._db.delete(self._tables.collection, ["name"], {"name": name})
        except Exception:
            # Restore the cache entry so memory stays consistent with the
            # database when the delete fails.
            self._records[name] = record
            raise

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        result = self._records.get(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        # Keys are collection names in this implementation.
        try:
            return self._records[key]
        except KeyError as err:
            # Format ``key`` itself: str(KeyError(k)) is repr(k), which would
            # add a second pair of quotes to the message.
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err

    def __iter__(self) -> Iterator[CollectionRecord]:
        # Docstring inherited from CollectionManager.
        yield from self._records.values()