Coverage for python/lsst/daf/butler/registry/collections/nameKey.py: 99%

99 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-01 10:59 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ["NameKeyCollectionManager"] 

30 

31from collections.abc import Iterable, Mapping 

32from typing import TYPE_CHECKING, Any 

33 

34import sqlalchemy 

35 

36from ... import ddl 

37from ..._timespan import TimespanDatabaseRepresentation 

38from .._collection_type import CollectionType 

39from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple 

40from ._base import ( 

41 CollectionTablesTuple, 

42 DefaultCollectionManager, 

43 makeCollectionChainTableSpec, 

44 makeRunTableSpec, 

45) 

46 

47if TYPE_CHECKING: 

48 from .._caching_context import CachingContext 

49 from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext 

50 

51 

# Specification of the key column of the collection table. Collections are
# keyed by their name, declared here as a string field of length 64 that
# serves as the primary key. The foreign-key helpers in this module copy the
# dtype and length from this spec.
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "name",
    dtype=sqlalchemy.String,
    length=64,
    primaryKey=True,
)


# Schema version reported by `currentVersions`.
# This has to be updated on every schema change.
_VERSION = VersionTuple(2, 0, 0)

57 

58 

def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple:
    """Build the table specifications used by the name-keyed schema.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Timespan representation class, forwarded to `makeRunTableSpec`.

    Returns
    -------
    specs : `CollectionTablesTuple`
        Specifications for the ``collection``, ``run`` and
        ``collection_chain`` tables, all keyed by a string ``name`` column.
    """
    # The collection table: name key, integer collection type, optional doc.
    collection_spec = ddl.TableSpec(
        fields=[
            _KEY_FIELD_SPEC,
            ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
            ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
        ],
    )
    run_spec = makeRunTableSpec("name", sqlalchemy.String, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec("name", sqlalchemy.String)
    return CollectionTablesTuple(
        collection=collection_spec,
        run=run_spec,
        collection_chain=chain_spec,
    )

71 

72 

class NameKeyCollectionManager(DefaultCollectionManager[str]):
    """A `CollectionManager` implementation that uses collection names for
    primary/foreign keys and aggressively loads all collection/run records in
    the database into memory.

    Most of the logic, including caching policy, lives in the base class;
    this class only adds the customizations that are specific to this
    particular table schema.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        dimensions: DimensionRecordStorageManager,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> NameKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        timespan_repr = db.getTimespanRepresentation()
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(timespan_repr)),  # type: ignore
            collectionIdName="name",
            dimensions=dimensions,
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        key_spec = _KEY_FIELD_SPEC
        # The new field mirrors the key column's dtype and length; any extra
        # field options come from the caller via ``kwargs``.
        fk_field = ddl.FieldSpec(
            cls.getCollectionForeignKeyName(prefix),
            dtype=key_spec.dtype,
            length=key_spec.length,
            **kwargs,
        )
        tableSpec.fields.add(fk_field)
        if not constraint:
            # Caller asked for the column only, without a foreign key.
            return fk_field
        tableSpec.foreignKeys.append(
            ddl.ForeignKeySpec(
                "collection", source=(fk_field.name,), target=(key_spec.name,), onDelete=onDelete
            )
        )
        return fk_field

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        key_spec = _KEY_FIELD_SPEC
        # The new field mirrors the key column's dtype and length; any extra
        # field options come from the caller via ``kwargs``.
        fk_field = ddl.FieldSpec(
            cls.getRunForeignKeyName(prefix),
            dtype=key_spec.dtype,
            length=key_spec.length,
            **kwargs,
        )
        tableSpec.fields.add(fk_field)
        if not constraint:
            # Caller asked for the column only, without a foreign key.
            return fk_field
        tableSpec.foreignKeys.append(
            ddl.ForeignKeySpec("run", source=(fk_field.name,), target=(key_spec.name,), onDelete=onDelete)
        )
        return fk_field

    def getParentChains(self, key: str) -> set[str]:
        # Docstring inherited from CollectionManager.
        chain_table = self._tables.collection_chain
        # Select the "parent" of every chain row whose "child" is ``key``.
        query = (
            sqlalchemy.sql.select(chain_table.columns["parent"])
            .select_from(chain_table)
            .where(chain_table.columns["child"] == key)
        )
        with self._db.query(query) as result:
            parents = set(result.scalars().all())
        return parents

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        # Names are the keys in this schema, so this is a plain delegation.
        return self._fetch_by_key(names)

    def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        collection_table = self._tables.collection
        run_table = self._tables.run
        chain_table = self._tables.collection_chain

        # Collection rows, outer-joined to the run table.
        sql = sqlalchemy.sql.select(*collection_table.columns, *run_table.columns).select_from(
            collection_table.join(run_table, isouter=True)
        )
        # Chain membership rows as (parent, position, child).
        chain_sql = sqlalchemy.sql.select(
            chain_table.columns["parent"],
            chain_table.columns["position"],
            chain_table.columns["child"],
        )

        # Transactions are kept as short as possible. When fetching
        # everything, both queries are read into memory first and the rows
        # are converted only after the transaction is finished. When fetching
        # a subset, the result of the first query has to be processed before
        # the second query can be constructed.
        if collection_ids is None:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)
            return records

        sql = sql.where(collection_table.columns[self._collectionIdName].in_(collection_ids))
        with self._db.transaction():
            with self._db.query(sql) as sql_result:
                sql_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)

            if chained_ids:
                # Retrieve the composition of the chained collections found.
                chain_sql = chain_sql.where(chain_table.columns["parent"].in_(chained_ids))
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

                records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]:
        """Convert rows returned from the collection query to a list of
        records and a list of chained collection names.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
            Rows of the collection query. Each row is indexed by the
            collection ``name`` column, by the ``"type"`` key and, for RUN
            collections, by the run ``host`` column; RUN rows are also passed
            to the timespan representation's ``extract``.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Records for every row that is not a CHAINED collection.
        chained_ids : `list` [ `str` ]
            Names of the CHAINED collections, for which no record is built
            here.
        """
        records: list[CollectionRecord[str]] = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        chained_ids: list[str] = []
        for row in rows:
            name = row[self._tables.collection.columns.name]
            collection_type = CollectionType(row["type"])
            if collection_type is CollectionType.CHAINED:
                # Construction of chained collections has to be delayed
                # until the names of their children have been fetched.
                chained_ids.append(name)
                continue

            record: CollectionRecord[str]
            if collection_type is CollectionType.RUN:
                record = RunRecord[str](
                    key=name,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
            else:
                record = CollectionRecord[str](key=name, name=name, type=collection_type)
            records.append(record)

        return records, chained_ids

    def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]:
        """Convert rows returned from the collection chain query to a list
        of records.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
            Rows of the chain query, indexed by the ``"parent"``,
            ``"position"`` and ``"child"`` keys. Every ``"parent"`` value
            must be present in ``chained_ids``.
        chained_ids : `list` [ `str` ]
            Names of the chained collections to build records for. A name
            without any matching row yields a record with no children.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One `ChainedCollectionRecord` per name in ``chained_ids``, with
            children ordered by position.
        """
        # Start from an empty definition for every requested chain so that
        # chains without children are still returned.
        chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            chains_defs[row["parent"]].append((row["position"], row["child"]))

        # Sorting the (position, child) pairs orders children by position.
        records: list[CollectionRecord[str]] = [
            ChainedCollectionRecord[str](
                key=name,
                name=name,
                children=[child for _, child in sorted(children)],
            )
            for name, children in chains_defs.items()
        ]
        return records

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]