Coverage for python/lsst/daf/butler/registry/collections/nameKey.py: 99%

104 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-25 10:48 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ["NameKeyCollectionManager"] 

30 

31import logging 

32from collections.abc import Iterable, Mapping 

33from typing import TYPE_CHECKING, Any 

34 

35import sqlalchemy 

36 

37from ... import ddl 

38from ...timespan_database_representation import TimespanDatabaseRepresentation 

39from .._collection_type import CollectionType 

40from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple 

41from ._base import ( 

42 CollectionTablesTuple, 

43 DefaultCollectionManager, 

44 makeCollectionChainTableSpec, 

45 makeRunTableSpec, 

46) 

47 

48if TYPE_CHECKING: 

49 from .._caching_context import CachingContext 

50 from ..interfaces import Database, StaticTablesContext 

51 

52 

# Module-level logger.
_LOG = logging.getLogger(__name__)

# Primary-key column of the collection table: this manager keys collections
# directly by their name, stored as a string of length at most 64.
_KEY_FIELD_SPEC = ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True)

# Schema version implemented by this manager.  This has to be updated on
# every schema change.
_VERSION = VersionTuple(2, 0, 0)

61 

62 

def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple:
    """Build the table specifications used by `NameKeyCollectionManager`.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Timespan representation class, passed through to the run table
        specification.

    Returns
    -------
    specs : `CollectionTablesTuple`
        The collection, run, and collection-chain table specifications, each
        using a string ``name`` column as the collection key.
    """
    # The collection table: name key, integer collection type, optional doc.
    collection_spec = ddl.TableSpec(
        fields=[
            _KEY_FIELD_SPEC,
            ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
            ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
        ],
    )
    run_spec = makeRunTableSpec("name", sqlalchemy.String, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec("name", sqlalchemy.String)
    return CollectionTablesTuple(
        collection=collection_spec,
        run=run_spec,
        collection_chain=chain_spec,
    )

75 

76 

class NameKeyCollectionManager(DefaultCollectionManager[str]):
    """A `CollectionManager` implementation that uses collection names for
    primary/foreign keys and aggressively loads all collection/run records in
    the database into memory.

    Most of the logic, including caching policy, is implemented in the base
    class; this class only adds customizations specific to this particular
    table schema.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> NameKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="name",
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    def clone(self, db: Database, caching_context: CachingContext) -> NameKeyCollectionManager:
        """Return a new manager of the same type that reuses this manager's
        table objects, collection ID column name and schema version, but is
        bound to a different database and caching context.

        Parameters
        ----------
        db : `Database`
            Database object for the new manager.
        caching_context : `CachingContext`
            Caching context for the new manager.

        Returns
        -------
        manager : `NameKeyCollectionManager`
            The new manager.
        """
        # Construct through type(self) (just as ``initialize`` constructs
        # through ``cls``) so that cloning a subclass instance does not
        # silently produce an instance of this base class.
        return type(self)(
            db,
            tables=self._tables,
            collectionIdName=self._collectionIdName,
            caching_context=caching_context,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        return cls._add_name_foreign_key(
            tableSpec,
            cls.getCollectionForeignKeyName(prefix),
            "collection",
            onDelete=onDelete,
            constraint=constraint,
            field_kwargs=kwargs,
        )

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        return cls._add_name_foreign_key(
            tableSpec,
            cls.getRunForeignKeyName(prefix),
            "run",
            onDelete=onDelete,
            constraint=constraint,
            field_kwargs=kwargs,
        )

    @classmethod
    def _add_name_foreign_key(
        cls,
        tableSpec: ddl.TableSpec,
        field_name: str,
        target_table: str,
        *,
        onDelete: str | None,
        constraint: bool,
        field_kwargs: Mapping[str, Any],
    ) -> ddl.FieldSpec:
        """Add a column referring to a collection name, and optionally a
        foreign key constraint on it, to a table specification.

        Shared implementation of `addCollectionForeignKey` and
        `addRunForeignKey`.

        Parameters
        ----------
        tableSpec : `ddl.TableSpec`
            Table specification, modified in place.
        field_name : `str`
            Name of the new column.
        target_table : `str`
            Name of the table the foreign key points to (``"collection"`` or
            ``"run"`` in this class).
        onDelete : `str` or `None`
            Forwarded as ``onDelete`` to `ddl.ForeignKeySpec`.
        constraint : `bool`
            If `False`, only the column is added, without a foreign key.
        field_kwargs : `~collections.abc.Mapping` [ `str`, `Any` ]
            Extra keyword arguments forwarded to `ddl.FieldSpec`.

        Returns
        -------
        field : `ddl.FieldSpec`
            Specification of the added column.
        """
        # The new column copies dtype and length from the collection-name
        # primary key it references.
        key = _KEY_FIELD_SPEC
        field = ddl.FieldSpec(field_name, dtype=key.dtype, length=key.length, **field_kwargs)
        tableSpec.fields.add(field)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(target_table, source=(field.name,), target=(key.name,), onDelete=onDelete)
            )
        return field

    def getParentChains(self, key: str) -> set[str]:
        # Docstring inherited from CollectionManager.
        table = self._tables.collection_chain
        sql = (
            sqlalchemy.sql.select(table.columns["parent"])
            .select_from(table)
            .where(table.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        # Names are the keys in this schema, so this is a key lookup.
        return self._fetch_by_key(names)

    def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        # Materialize the requested names once.  Otherwise a one-shot iterator
        # would be logged only as an opaque object.  This also lets an empty
        # request return without opening a transaction; the query would
        # match nothing anyway.
        names: list[str] | None = None if collection_ids is None else list(collection_ids)
        _LOG.debug("Fetching collection records using names %s.", names)
        if names is not None and not names:
            return []

        collection_table = self._tables.collection
        run_table = self._tables.run
        chain_table = self._tables.collection_chain

        # Outer join so non-RUN collections are returned with NULL run columns.
        sql = sqlalchemy.sql.select(*collection_table.columns, *run_table.columns).select_from(
            collection_table.join(run_table, isouter=True)
        )
        chain_sql = sqlalchemy.sql.select(
            chain_table.columns["parent"],
            chain_table.columns["position"],
            chain_table.columns["child"],
        )

        records: list[CollectionRecord[str]]
        # Keep transactions as short as possible.  When fetching everything,
        # both queries run back to back and rows are processed after the
        # transaction ends.  When fetching specific names, the first query's
        # rows must be processed inside the transaction to learn which chained
        # collections need their children fetched by the second query.
        if names is not None:
            sql = sql.where(collection_table.columns[self._collectionIdName].in_(names))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    # Retrieve compositions only for the chained collections found.
                    chain_sql = chain_sql.where(chain_table.columns["parent"].in_(chained_ids))
                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]:
        """Convert rows returned by the collection query into records and a
        list of chained collection names.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
            Rows of the outer join of the collection and run tables.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Records for every non-chained collection in ``rows``: a
            `RunRecord` for RUN collections, a plain `CollectionRecord`
            otherwise.
        chained_ids : `list` [ `str` ]
            Names of the CHAINED collections in ``rows``.  Their records are
            built later by `_rows_to_chains`, once their children are known.
        """
        records: list[CollectionRecord[str]] = []
        chained_ids: list[str] = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        # Name and host are looked up by Column object rather than by string.
        # Presumably this is because the run table also has a "name" column,
        # so a string key would be ambiguous in the joined row.
        name_column = self._tables.collection.columns.name
        host_column = self._tables.run.columns.host
        for row in rows:
            name = row[name_column]
            collection_type = CollectionType(row["type"])
            record: CollectionRecord[str]
            if collection_type is CollectionType.RUN:
                record = RunRecord[str](
                    key=name,
                    name=name,
                    host=row[host_column],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif collection_type is CollectionType.CHAINED:
                # Construction of chained records is delayed until their
                # children's names have been fetched.
                chained_ids.append(name)
            else:
                record = CollectionRecord[str](key=name, name=name, type=collection_type)
                records.append(record)

        return records, chained_ids

    def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]:
        """Convert rows returned by the collection chain query into records.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
            Rows with ``parent``, ``position`` and ``child`` keys.  Every
            ``parent`` must be one of ``chained_ids`` (otherwise `KeyError`
            is raised).
        chained_ids : `list` [ `str` ]
            Names of the chained collections to build records for.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One `ChainedCollectionRecord` per distinct name in
            ``chained_ids``, with children ordered by ``position``.  A chain
            without any rows gets an empty list of children.
        """
        chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            chains_defs[row["parent"]].append((row["position"], row["child"]))

        records: list[CollectionRecord[str]] = []
        for name, children in chains_defs.items():
            children_names = [child for _, child in sorted(children)]
            record = ChainedCollectionRecord[str](
                key=name,
                name=name,
                children=children_names,
            )
            records.append(record)

        return records

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]