Coverage for python/lsst/daf/butler/registry/collections/nameKey.py: 99%

102 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-16 10:43 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ["NameKeyCollectionManager"] 

30 

31import logging 

32from collections.abc import Iterable, Mapping 

33from typing import TYPE_CHECKING, Any 

34 

35import sqlalchemy 

36 

37from ... import ddl 

38from ...timespan_database_representation import TimespanDatabaseRepresentation 

39from .._collection_type import CollectionType 

40from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple 

41from ._base import ( 

42 CollectionTablesTuple, 

43 DefaultCollectionManager, 

44 makeCollectionChainTableSpec, 

45 makeRunTableSpec, 

46) 

47 

48if TYPE_CHECKING: 

49 from .._caching_context import CachingContext 

50 from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext 

51 

52 

53_KEY_FIELD_SPEC = ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True) 

54 

55 

56# This has to be updated on every schema change 

57_VERSION = VersionTuple(2, 0, 0) 

58 

59 

60_LOG = logging.getLogger(__name__) 

61 

62 

63def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple: 

64 return CollectionTablesTuple( 

65 collection=ddl.TableSpec( 

66 fields=[ 

67 _KEY_FIELD_SPEC, 

68 ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False), 

69 ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True), 

70 ], 

71 ), 

72 run=makeRunTableSpec("name", sqlalchemy.String, TimespanReprClass), 

73 collection_chain=makeCollectionChainTableSpec("name", sqlalchemy.String), 

74 ) 

75 

76 

77class NameKeyCollectionManager(DefaultCollectionManager[str]): 

78 """A `CollectionManager` implementation that uses collection names for 

79 primary/foreign keys and aggressively loads all collection/run records in 

80 the database into memory. 

81 

82 Most of the logic, including caching policy, is implemented in the base 

83 class, this class only adds customizations specific to this particular 

84 table schema. 

85 """ 

86 

87 @classmethod 

88 def initialize( 

89 cls, 

90 db: Database, 

91 context: StaticTablesContext, 

92 *, 

93 dimensions: DimensionRecordStorageManager, 

94 caching_context: CachingContext, 

95 registry_schema_version: VersionTuple | None = None, 

96 ) -> NameKeyCollectionManager: 

97 # Docstring inherited from CollectionManager. 

98 return cls( 

99 db, 

100 tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())), # type: ignore 

101 collectionIdName="name", 

102 dimensions=dimensions, 

103 caching_context=caching_context, 

104 registry_schema_version=registry_schema_version, 

105 ) 

106 

107 @classmethod 

108 def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str: 

109 # Docstring inherited from CollectionManager. 

110 return f"{prefix}_name" 

111 

112 @classmethod 

113 def getRunForeignKeyName(cls, prefix: str = "run") -> str: 

114 # Docstring inherited from CollectionManager. 

115 return f"{prefix}_name" 

116 

117 @classmethod 

118 def addCollectionForeignKey( 

119 cls, 

120 tableSpec: ddl.TableSpec, 

121 *, 

122 prefix: str = "collection", 

123 onDelete: str | None = None, 

124 constraint: bool = True, 

125 **kwargs: Any, 

126 ) -> ddl.FieldSpec: 

127 # Docstring inherited from CollectionManager. 

128 original = _KEY_FIELD_SPEC 

129 copy = ddl.FieldSpec( 

130 cls.getCollectionForeignKeyName(prefix), dtype=original.dtype, length=original.length, **kwargs 

131 ) 

132 tableSpec.fields.add(copy) 

133 if constraint: 

134 tableSpec.foreignKeys.append( 

135 ddl.ForeignKeySpec( 

136 "collection", source=(copy.name,), target=(original.name,), onDelete=onDelete 

137 ) 

138 ) 

139 return copy 

140 

141 @classmethod 

142 def addRunForeignKey( 

143 cls, 

144 tableSpec: ddl.TableSpec, 

145 *, 

146 prefix: str = "run", 

147 onDelete: str | None = None, 

148 constraint: bool = True, 

149 **kwargs: Any, 

150 ) -> ddl.FieldSpec: 

151 # Docstring inherited from CollectionManager. 

152 original = _KEY_FIELD_SPEC 

153 copy = ddl.FieldSpec( 

154 cls.getRunForeignKeyName(prefix), dtype=original.dtype, length=original.length, **kwargs 

155 ) 

156 tableSpec.fields.add(copy) 

157 if constraint: 157 ↛ 161line 157 didn't jump to line 161, because the condition on line 157 was never false

158 tableSpec.foreignKeys.append( 

159 ddl.ForeignKeySpec("run", source=(copy.name,), target=(original.name,), onDelete=onDelete) 

160 ) 

161 return copy 

162 

163 def getParentChains(self, key: str) -> set[str]: 

164 # Docstring inherited from CollectionManager. 

165 table = self._tables.collection_chain 

166 sql = ( 

167 sqlalchemy.sql.select(table.columns["parent"]) 

168 .select_from(table) 

169 .where(table.columns["child"] == key) 

170 ) 

171 with self._db.query(sql) as sql_result: 

172 parent_names = set(sql_result.scalars().all()) 

173 return parent_names 

174 

175 def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]: 

176 # Docstring inherited from base class. 

177 return self._fetch_by_key(names) 

178 

179 def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]: 

180 # Docstring inherited from base class. 

181 _LOG.debug("Fetching collection records using names %s.", collection_ids) 

182 sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from( 

183 self._tables.collection.join(self._tables.run, isouter=True) 

184 ) 

185 

186 chain_sql = sqlalchemy.sql.select( 

187 self._tables.collection_chain.columns["parent"], 

188 self._tables.collection_chain.columns["position"], 

189 self._tables.collection_chain.columns["child"], 

190 ) 

191 

192 records: list[CollectionRecord[str]] = [] 

193 # We want to keep transactions as short as possible. When we fetch 

194 # everything we want to quickly fetch things into memory and finish 

195 # transaction. When we fetch just few records we need to process result 

196 # of the first query before we can run the second one. 

197 if collection_ids is not None: 

198 sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(collection_ids)) 

199 with self._db.transaction(): 

200 with self._db.query(sql) as sql_result: 

201 sql_rows = sql_result.mappings().fetchall() 

202 

203 records, chained_ids = self._rows_to_records(sql_rows) 

204 

205 if chained_ids: 

206 # Retrieve chained collection compositions 

207 chain_sql = chain_sql.where( 

208 self._tables.collection_chain.columns["parent"].in_(chained_ids) 

209 ) 

210 with self._db.query(chain_sql) as sql_result: 

211 chain_rows = sql_result.mappings().fetchall() 

212 

213 records += self._rows_to_chains(chain_rows, chained_ids) 

214 

215 else: 

216 with self._db.transaction(): 

217 with self._db.query(sql) as sql_result: 

218 sql_rows = sql_result.mappings().fetchall() 

219 with self._db.query(chain_sql) as sql_result: 

220 chain_rows = sql_result.mappings().fetchall() 

221 

222 records, chained_ids = self._rows_to_records(sql_rows) 

223 records += self._rows_to_chains(chain_rows, chained_ids) 

224 

225 return records 

226 

227 def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]: 

228 """Convert rows returned from collection query to a list of records 

229 and a list chained collection names. 

230 """ 

231 records: list[CollectionRecord[str]] = [] 

232 TimespanReprClass = self._db.getTimespanRepresentation() 

233 chained_ids: list[str] = [] 

234 for row in rows: 

235 name = row[self._tables.collection.columns.name] 

236 type = CollectionType(row["type"]) 

237 record: CollectionRecord[str] 

238 if type is CollectionType.RUN: 

239 record = RunRecord[str]( 

240 key=name, 

241 name=name, 

242 host=row[self._tables.run.columns.host], 

243 timespan=TimespanReprClass.extract(row), 

244 ) 

245 records.append(record) 

246 elif type is CollectionType.CHAINED: 

247 # Need to delay chained collection construction until to 

248 # fetch their children names. 

249 chained_ids.append(name) 

250 else: 

251 record = CollectionRecord[str](key=name, name=name, type=type) 

252 records.append(record) 

253 

254 return records, chained_ids 

255 

256 def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]: 

257 """Convert rows returned from collection chain query to a list of 

258 records. 

259 """ 

260 chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids} 

261 for row in rows: 

262 chains_defs[row["parent"]].append((row["position"], row["child"])) 

263 

264 records: list[CollectionRecord[str]] = [] 

265 for name, children in chains_defs.items(): 

266 children_names = [child for _, child in sorted(children)] 

267 record = ChainedCollectionRecord[str]( 

268 key=name, 

269 name=name, 

270 children=children_names, 

271 ) 

272 records.append(record) 

273 

274 return records 

275 

276 @classmethod 

277 def currentVersions(cls) -> list[VersionTuple]: 

278 # Docstring inherited from VersionedExtension. 

279 return [_VERSION]