Coverage for python/lsst/daf/butler/registry/collections/synthIntKey.py: 99%
106 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-05 11:05 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-05 11:05 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ["SynthIntKeyCollectionManager"]
33from collections.abc import Iterable, Mapping
34from typing import TYPE_CHECKING, Any
36import sqlalchemy
38from ..._timespan import TimespanDatabaseRepresentation
39from .._collection_type import CollectionType
40from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple
41from ._base import (
42 CollectionTablesTuple,
43 DefaultCollectionManager,
44 makeCollectionChainTableSpec,
45 makeRunTableSpec,
46)
48if TYPE_CHECKING:
49 from .._caching_context import CachingContext
50 from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext
# Synthetic, auto-incremented integer primary key of the collection table.
# It also serves as the template (name and dtype) for columns in other tables
# that reference collections; see addCollectionForeignKey/addRunForeignKey.
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "collection_id", dtype=sqlalchemy.BigInteger, primaryKey=True, autoincrement=True
)


# Schema version of this manager, reported by currentVersions().
# This has to be updated on every schema change
_VERSION = VersionTuple(2, 0, 0)
def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple:
    """Build the table specifications used by `SynthIntKeyCollectionManager`.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Database representation of timespans; passed through to
        `makeRunTableSpec`.

    Returns
    -------
    specs : `CollectionTablesTuple`
        Specifications for the ``collection``, ``run`` and
        ``collection_chain`` tables, using a ``collection_id`` column of type
        `sqlalchemy.BigInteger` as the collection key.
    """
    key_column = "collection_id"
    key_type = sqlalchemy.BigInteger

    # Main table: synthetic integer key plus a unique collection name.
    collection_spec = ddl.TableSpec(
        fields=[
            _KEY_FIELD_SPEC,
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, nullable=False),
            ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
            ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
        ],
        unique=[("name",)],
    )
    run_spec = makeRunTableSpec(key_column, key_type, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec(key_column, key_type)

    return CollectionTablesTuple(
        collection=collection_spec,
        run=run_spec,
        collection_chain=chain_spec,
    )
class SynthIntKeyCollectionManager(DefaultCollectionManager[int]):
    """A `CollectionManager` implementation that uses synthetic primary key
    (auto-incremented integer) for collections table.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        dimensions: DimensionRecordStorageManager,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> SynthIntKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="collection_id",
            dimensions=dimensions,
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        # The referencing column has the key's dtype but must never be
        # auto-incremented itself.
        copy = ddl.FieldSpec(
            cls.getCollectionForeignKeyName(prefix), dtype=original.dtype, autoincrement=False, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    "collection", source=(copy.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return copy

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        # Same key dtype as the collection table, without auto-increment.
        copy = ddl.FieldSpec(
            cls.getRunForeignKeyName(prefix), dtype=original.dtype, autoincrement=False, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec("run", source=(copy.name,), target=(original.name,), onDelete=onDelete)
            )
        return copy

    def getParentChains(self, key: int) -> set[str]:
        # Docstring inherited from CollectionManager.
        chain = self._tables.collection_chain
        collection = self._tables.collection
        # Names of all collections that list ``key`` as a direct child.
        sql = (
            sqlalchemy.sql.select(collection.columns["name"])
            .select_from(collection)
            .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"])
            .where(chain.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        return self._fetch("name", names)

    def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        return self._fetch(self._collectionIdName, collection_ids)

    def _fetch(
        self, column_name: str, collections: Iterable[int | str] | None
    ) -> list[CollectionRecord[int]]:
        """Fetch collection records, optionally restricted by one column.

        Parameters
        ----------
        column_name : `str`
            Name of the collection-table column that ``collections`` values
            are matched against (e.g. ``"name"`` or the key column).
        collections : `~collections.abc.Iterable` or `None`
            Values to select with an ``IN`` clause; `None` fetches every
            collection.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Records for the matching collections; chained collections are
            returned as `ChainedCollectionRecord` with their children names.
        """
        collection_chain = self._tables.collection_chain
        collection = self._tables.collection
        # Collections left-outer-joined to the run table, so RUN rows carry
        # their run-specific columns.
        sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from(
            collection.join(self._tables.run, isouter=True)
        )

        # Chain membership rows with the child's name resolved by a join.
        chain_sql = (
            sqlalchemy.sql.select(
                collection_chain.columns["parent"],
                collection_chain.columns["position"],
                collection.columns["name"].label("child_name"),
            )
            .select_from(collection_chain)
            .join(
                collection,
                onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName],
            )
        )

        records: list[CollectionRecord[int]] = []
        # We want to keep transactions as short as possible. When we fetch
        # everything we want to quickly fetch things into memory and finish
        # the transaction. When we fetch just a few records we need to process
        # the first query before we can run the second one.
        if collections is not None:
            sql = sql.where(collection.columns[column_name].in_(collections))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids)))

                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]:
        """Convert rows returned from the collection query to records.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
            Rows from the collection/run outer-join query.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Records for all non-chained collections.
        chained_ids : `dict` [ `int`, `str` ]
            Mapping from key to name for chained collections; their records
            are built later by `_rows_to_chains` once children are known.
        """
        records: list[CollectionRecord[int]] = []
        chained_ids: dict[int, str] = {}
        TimespanReprClass = self._db.getTimespanRepresentation()
        for row in rows:
            key: int = row[self._collectionIdName]
            name: str = row[self._tables.collection.columns.name]
            collection_type = CollectionType(row["type"])
            record: CollectionRecord[int]
            if collection_type is CollectionType.RUN:
                record = RunRecord[int](
                    key=key,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif collection_type is CollectionType.CHAINED:
                # Chained collection construction is delayed until their
                # children names have been fetched.
                chained_ids[key] = name
            else:
                record = CollectionRecord[int](key=key, name=name, type=collection_type)
                records.append(record)
        return records, chained_ids

    def _rows_to_chains(
        self, rows: Iterable[Mapping], chained_ids: dict[int, str]
    ) -> list[CollectionRecord[int]]:
        """Convert rows returned from the collection chain query to records.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
            Rows with ``parent``, ``position`` and ``child_name`` columns.
        chained_ids : `dict` [ `int`, `str` ]
            Mapping from key to name for the chained collections to build.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One `ChainedCollectionRecord` per entry in ``chained_ids``, with
            children ordered by position (empty if no chain rows matched).
        """
        chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            children = chains_defs.get(row["parent"])
            if children is None:
                # A chain row whose parent is not in ``chained_ids`` can show
                # up when the two queries see different database states (e.g.
                # a chain created concurrently under READ COMMITTED). There is
                # no record for that parent, so ignore it instead of raising
                # KeyError.
                continue
            children.append((row["position"], row["child_name"]))

        records: list[CollectionRecord[int]] = []
        for key, children in chains_defs.items():
            name = chained_ids[key]
            children_names = [child for _, child in sorted(children)]
            record = ChainedCollectionRecord[int](
                key=key,
                name=name,
                children=children_names,
            )
            records.append(record)

        return records

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]