Coverage for python/lsst/daf/butler/registry/collections/nameKey.py: 99%

105 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-07 11:02 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ["NameKeyCollectionManager"] 

30 

31import logging 

32from collections.abc import Iterable, Mapping 

33from typing import TYPE_CHECKING, Any 

34 

35import sqlalchemy 

36 

37from ... import ddl 

38from ...column_spec import COLLECTION_NAME_MAX_LENGTH 

39from ...timespan_database_representation import TimespanDatabaseRepresentation 

40from .._collection_type import CollectionType 

41from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple 

42from ._base import ( 

43 CollectionTablesTuple, 

44 DefaultCollectionManager, 

45 makeCollectionChainTableSpec, 

46 makeRunTableSpec, 

47) 

48 

49if TYPE_CHECKING: 

50 from .._caching_context import CachingContext 

51 from ..interfaces import Database, StaticTablesContext 

52 

53 

54_KEY_FIELD_SPEC = ddl.FieldSpec( 

55 "name", dtype=sqlalchemy.String, length=COLLECTION_NAME_MAX_LENGTH, primaryKey=True 

56) 

57 

58 

59# This has to be updated on every schema change 

60_VERSION = VersionTuple(2, 0, 0) 

61 

62 

63_LOG = logging.getLogger(__name__) 

64 

65 

66def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple: 

67 return CollectionTablesTuple( 

68 collection=ddl.TableSpec( 

69 fields=[ 

70 _KEY_FIELD_SPEC, 

71 ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False), 

72 ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True), 

73 ], 

74 ), 

75 run=makeRunTableSpec("name", sqlalchemy.String, TimespanReprClass), 

76 collection_chain=makeCollectionChainTableSpec("name", sqlalchemy.String), 

77 ) 

78 

79 

80class NameKeyCollectionManager(DefaultCollectionManager[str]): 

81 """A `CollectionManager` implementation that uses collection names for 

82 primary/foreign keys and aggressively loads all collection/run records in 

83 the database into memory. 

84 

85 Most of the logic, including caching policy, is implemented in the base 

86 class, this class only adds customizations specific to this particular 

87 table schema. 

88 """ 

89 

90 @classmethod 

91 def initialize( 

92 cls, 

93 db: Database, 

94 context: StaticTablesContext, 

95 *, 

96 caching_context: CachingContext, 

97 registry_schema_version: VersionTuple | None = None, 

98 ) -> NameKeyCollectionManager: 

99 # Docstring inherited from CollectionManager. 

100 return cls( 

101 db, 

102 tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())), # type: ignore 

103 collectionIdName="name", 

104 caching_context=caching_context, 

105 registry_schema_version=registry_schema_version, 

106 ) 

107 

108 def clone(self, db: Database, caching_context: CachingContext) -> NameKeyCollectionManager: 

109 return NameKeyCollectionManager( 

110 db, 

111 tables=self._tables, 

112 collectionIdName=self._collectionIdName, 

113 caching_context=caching_context, 

114 registry_schema_version=self._registry_schema_version, 

115 ) 

116 

117 @classmethod 

118 def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str: 

119 # Docstring inherited from CollectionManager. 

120 return f"{prefix}_name" 

121 

122 @classmethod 

123 def getRunForeignKeyName(cls, prefix: str = "run") -> str: 

124 # Docstring inherited from CollectionManager. 

125 return f"{prefix}_name" 

126 

127 @classmethod 

128 def addCollectionForeignKey( 

129 cls, 

130 tableSpec: ddl.TableSpec, 

131 *, 

132 prefix: str = "collection", 

133 onDelete: str | None = None, 

134 constraint: bool = True, 

135 **kwargs: Any, 

136 ) -> ddl.FieldSpec: 

137 # Docstring inherited from CollectionManager. 

138 original = _KEY_FIELD_SPEC 

139 copy = ddl.FieldSpec( 

140 cls.getCollectionForeignKeyName(prefix), dtype=original.dtype, length=original.length, **kwargs 

141 ) 

142 tableSpec.fields.add(copy) 

143 if constraint: 

144 tableSpec.foreignKeys.append( 

145 ddl.ForeignKeySpec( 

146 "collection", source=(copy.name,), target=(original.name,), onDelete=onDelete 

147 ) 

148 ) 

149 return copy 

150 

151 @classmethod 

152 def addRunForeignKey( 

153 cls, 

154 tableSpec: ddl.TableSpec, 

155 *, 

156 prefix: str = "run", 

157 onDelete: str | None = None, 

158 constraint: bool = True, 

159 **kwargs: Any, 

160 ) -> ddl.FieldSpec: 

161 # Docstring inherited from CollectionManager. 

162 original = _KEY_FIELD_SPEC 

163 copy = ddl.FieldSpec( 

164 cls.getRunForeignKeyName(prefix), dtype=original.dtype, length=original.length, **kwargs 

165 ) 

166 tableSpec.fields.add(copy) 

167 if constraint: 167 ↛ 171line 167 didn't jump to line 171, because the condition on line 167 was never false

168 tableSpec.foreignKeys.append( 

169 ddl.ForeignKeySpec("run", source=(copy.name,), target=(original.name,), onDelete=onDelete) 

170 ) 

171 return copy 

172 

173 def getParentChains(self, key: str) -> set[str]: 

174 # Docstring inherited from CollectionManager. 

175 table = self._tables.collection_chain 

176 sql = ( 

177 sqlalchemy.sql.select(table.columns["parent"]) 

178 .select_from(table) 

179 .where(table.columns["child"] == key) 

180 ) 

181 with self._db.query(sql) as sql_result: 

182 parent_names = set(sql_result.scalars().all()) 

183 return parent_names 

184 

185 def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]: 

186 # Docstring inherited from base class. 

187 return self._fetch_by_key(names) 

188 

189 def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]: 

190 # Docstring inherited from base class. 

191 _LOG.debug("Fetching collection records using names %s.", collection_ids) 

192 sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from( 

193 self._tables.collection.join(self._tables.run, isouter=True) 

194 ) 

195 

196 chain_sql = sqlalchemy.sql.select( 

197 self._tables.collection_chain.columns["parent"], 

198 self._tables.collection_chain.columns["position"], 

199 self._tables.collection_chain.columns["child"], 

200 ) 

201 

202 records: list[CollectionRecord[str]] = [] 

203 # We want to keep transactions as short as possible. When we fetch 

204 # everything we want to quickly fetch things into memory and finish 

205 # transaction. When we fetch just few records we need to process result 

206 # of the first query before we can run the second one. 

207 if collection_ids is not None: 

208 sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(collection_ids)) 

209 with self._db.transaction(): 

210 with self._db.query(sql) as sql_result: 

211 sql_rows = sql_result.mappings().fetchall() 

212 

213 records, chained_ids = self._rows_to_records(sql_rows) 

214 

215 if chained_ids: 

216 # Retrieve chained collection compositions 

217 chain_sql = chain_sql.where( 

218 self._tables.collection_chain.columns["parent"].in_(chained_ids) 

219 ) 

220 with self._db.query(chain_sql) as sql_result: 

221 chain_rows = sql_result.mappings().fetchall() 

222 

223 records += self._rows_to_chains(chain_rows, chained_ids) 

224 

225 else: 

226 with self._db.transaction(): 

227 with self._db.query(sql) as sql_result: 

228 sql_rows = sql_result.mappings().fetchall() 

229 with self._db.query(chain_sql) as sql_result: 

230 chain_rows = sql_result.mappings().fetchall() 

231 

232 records, chained_ids = self._rows_to_records(sql_rows) 

233 records += self._rows_to_chains(chain_rows, chained_ids) 

234 

235 return records 

236 

237 def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]: 

238 """Convert rows returned from collection query to a list of records 

239 and a list chained collection names. 

240 """ 

241 records: list[CollectionRecord[str]] = [] 

242 TimespanReprClass = self._db.getTimespanRepresentation() 

243 chained_ids: list[str] = [] 

244 for row in rows: 

245 name = row[self._tables.collection.columns.name] 

246 type = CollectionType(row["type"]) 

247 record: CollectionRecord[str] 

248 if type is CollectionType.RUN: 

249 record = RunRecord[str]( 

250 key=name, 

251 name=name, 

252 host=row[self._tables.run.columns.host], 

253 timespan=TimespanReprClass.extract(row), 

254 ) 

255 records.append(record) 

256 elif type is CollectionType.CHAINED: 

257 # Need to delay chained collection construction until to 

258 # fetch their children names. 

259 chained_ids.append(name) 

260 else: 

261 record = CollectionRecord[str](key=name, name=name, type=type) 

262 records.append(record) 

263 

264 return records, chained_ids 

265 

266 def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]: 

267 """Convert rows returned from collection chain query to a list of 

268 records. 

269 """ 

270 chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids} 

271 for row in rows: 

272 chains_defs[row["parent"]].append((row["position"], row["child"])) 

273 

274 records: list[CollectionRecord[str]] = [] 

275 for name, children in chains_defs.items(): 

276 children_names = [child for _, child in sorted(children)] 

277 record = ChainedCollectionRecord[str]( 

278 key=name, 

279 name=name, 

280 children=children_names, 

281 ) 

282 records.append(record) 

283 

284 return records 

285 

286 @classmethod 

287 def currentVersions(cls) -> list[VersionTuple]: 

288 # Docstring inherited from VersionedExtension. 

289 return [_VERSION]