Coverage for python/lsst/daf/butler/registry/collections/synthIntKey.py: 99%
118 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 09:54 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 09:54 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ["SynthIntKeyCollectionManager"]
33import logging
34from collections.abc import Iterable, Mapping
35from typing import TYPE_CHECKING, Any
37import sqlalchemy
39from ...column_spec import COLLECTION_NAME_MAX_LENGTH
40from ...timespan_database_representation import TimespanDatabaseRepresentation
41from .._collection_type import CollectionType
42from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple
43from ._base import (
44 CollectionTablesTuple,
45 DefaultCollectionManager,
46 makeCollectionChainTableSpec,
47 makeRunTableSpec,
48)
50if TYPE_CHECKING:
51 from .._caching_context import CachingContext
52 from ..interfaces import Database, StaticTablesContext
_LOG = logging.getLogger(__name__)

# Schema version implemented by this manager.
# This has to be updated on every schema change.
_VERSION = VersionTuple(2, 0, 0)

# Specification of the synthetic, auto-incrementing integer primary key of the
# ``collection`` table.  Foreign-key fields referencing collections and runs
# copy its dtype (with autoincrement disabled).
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "collection_id",
    primaryKey=True,
    autoincrement=True,
    dtype=sqlalchemy.BigInteger,
)
def _makeTableSpecs(
    TimespanReprClass: type[TimespanDatabaseRepresentation],
) -> CollectionTablesTuple[ddl.TableSpec]:
    """Build the table specifications used by `SynthIntKeyCollectionManager`.

    Parameters
    ----------
    TimespanReprClass : `type` [`TimespanDatabaseRepresentation`]
        Timespan representation class, passed through to
        `makeRunTableSpec` when building the ``run`` table.

    Returns
    -------
    specs : `CollectionTablesTuple` [`ddl.TableSpec`]
        Specifications for the ``collection``, ``run`` and
        ``collection_chain`` tables.  The latter two are keyed on a
        ``collection_id`` `sqlalchemy.BigInteger` column.
    """
    # Columns of the main collection table, in table order; the key comes
    # first and is the shared synthetic primary key.
    nameField = ddl.FieldSpec(
        "name", dtype=sqlalchemy.String, length=COLLECTION_NAME_MAX_LENGTH, nullable=False
    )
    typeField = ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False)
    docField = ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True)
    collectionSpec = ddl.TableSpec(
        fields=[_KEY_FIELD_SPEC, nameField, typeField, docField],
        unique=[("name",)],
    )

    runSpec = makeRunTableSpec("collection_id", sqlalchemy.BigInteger, TimespanReprClass)
    chainSpec = makeCollectionChainTableSpec("collection_id", sqlalchemy.BigInteger)

    return CollectionTablesTuple(collection=collectionSpec, run=runSpec, collection_chain=chainSpec)
class SynthIntKeyCollectionManager(DefaultCollectionManager[int]):
    """A `CollectionManager` implementation that uses synthetic primary key
    (auto-incremented integer) for collections table.

    The key column name is taken from ``_KEY_FIELD_SPEC``; foreign keys to
    collections and runs are integer fields named ``{prefix}_id``.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> SynthIntKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName=_KEY_FIELD_SPEC.name,
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    def clone(self, db: Database, caching_context: CachingContext) -> SynthIntKeyCollectionManager:
        return SynthIntKeyCollectionManager(
            db,
            tables=self._tables,
            collectionIdName=self._collectionIdName,
            caching_context=caching_context,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def _add_key_field(
        cls,
        tableSpec: ddl.TableSpec,
        fieldName: str,
        targetTable: str,
        onDelete: str | None,
        constraint: bool,
        /,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        """Add a field referencing the synthetic collection key to a table.

        Parameters are positional-only so that arbitrary ``kwargs`` forwarded
        to `ddl.FieldSpec` can never collide with them.

        Parameters
        ----------
        tableSpec : `ddl.TableSpec`
            Specification of the table to modify in place.
        fieldName : `str`
            Name of the new field.
        targetTable : `str`
            Name of the table the foreign key constraint points at.
        onDelete : `str` or `None`
            ``ON DELETE`` action for the constraint.
        constraint : `bool`
            If `True`, also append a foreign key constraint.
        **kwargs
            Additional keyword arguments forwarded to `ddl.FieldSpec`.

        Returns
        -------
        fieldSpec : `ddl.FieldSpec`
            The newly added field, which copies the key's dtype but never
            auto-increments.
        """
        original = _KEY_FIELD_SPEC
        copy = ddl.FieldSpec(fieldName, dtype=original.dtype, autoincrement=False, **kwargs)
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(targetTable, source=(copy.name,), target=(original.name,), onDelete=onDelete)
            )
        return copy

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        return cls._add_key_field(
            tableSpec, cls.getCollectionForeignKeyName(prefix), "collection", onDelete, constraint, **kwargs
        )

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        return cls._add_key_field(
            tableSpec, cls.getRunForeignKeyName(prefix), "run", onDelete, constraint, **kwargs
        )

    def getParentChains(self, key: int) -> set[str]:
        # Docstring inherited from CollectionManager.
        chain = self._tables.collection_chain
        collection = self._tables.collection
        sql = (
            sqlalchemy.sql.select(collection.columns["name"])
            .select_from(collection)
            .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"])
            .where(chain.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def lookup_name_sql(
        self, sql_key: sqlalchemy.ColumnElement[int], sql_from_clause: sqlalchemy.FromClause
    ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]:
        # Docstring inherited.
        collection = self._tables.collection
        return (
            collection.c.name,
            sql_from_clause.join(collection, onclause=collection.c[_KEY_FIELD_SPEC.name] == sql_key),
        )

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using names %s.", names)
        return self._fetch("name", names)

    def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using IDs %s.", collection_ids)
        return self._fetch(self._collectionIdName, collection_ids)

    def _fetch(
        self, column_name: str, collections: Iterable[int | str] | None
    ) -> list[CollectionRecord[int]]:
        """Fetch collection records, optionally filtered on one column.

        Parameters
        ----------
        column_name : `str`
            Name of the ``collection`` table column that ``collections``
            values are matched against.
        collections : `~collections.abc.Iterable` [`int` | `str`] or `None`
            Values to match; `None` fetches every collection.

        Returns
        -------
        records : `list` [`CollectionRecord`]
            Records for the matching collections; chained collection records
            come after all other records.
        """
        collection_chain = self._tables.collection_chain
        collection = self._tables.collection
        sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from(
            collection.join(self._tables.run, isouter=True)
        )

        chain_sql = (
            sqlalchemy.sql.select(
                collection_chain.columns["parent"],
                collection_chain.columns["position"],
                collection.columns["name"].label("child_name"),
            )
            .select_from(collection_chain)
            .join(
                collection,
                onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName],
            )
        )

        records: list[CollectionRecord[int]]
        # We want to keep transactions as short as possible. When we fetch
        # everything, we quickly load both result sets into memory and finish
        # the transaction before converting rows. When we fetch just a few
        # records, the first query has to be processed before the second one
        # can be run, because the chain query is restricted to the chained
        # collections found by the first one.
        if collections is not None:
            sql = sql.where(collection.columns[column_name].in_(collections))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids)))

                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]:
        """Convert rows returned from the collection query to records.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [`~collections.abc.Mapping`]
            Rows from the outer join of the ``collection`` and ``run`` tables.

        Returns
        -------
        records : `list` [`CollectionRecord`]
            Records for every non-chained collection in ``rows``.
        chained_ids : `dict` [`int`, `str`]
            Mapping from key to name for chained collections, whose records
            can only be built once their children are known.
        """
        records: list[CollectionRecord[int]] = []
        chained_ids: dict[int, str] = {}
        TimespanReprClass = self._db.getTimespanRepresentation()
        for row in rows:
            key: int = row[self._collectionIdName]
            name: str = row[self._tables.collection.columns.name]
            collection_type = CollectionType(row["type"])
            if collection_type is CollectionType.RUN:
                records.append(
                    RunRecord[int](
                        key=key,
                        name=name,
                        host=row[self._tables.run.columns.host],
                        timespan=TimespanReprClass.extract(row),
                    )
                )
            elif collection_type is CollectionType.CHAINED:
                # Need to delay chained collection construction until we
                # fetch their children names.
                chained_ids[key] = name
            else:
                records.append(CollectionRecord[int](key=key, name=name, type=collection_type))
        return records, chained_ids

    def _rows_to_chains(
        self, rows: Iterable[Mapping], chained_ids: dict[int, str]
    ) -> list[CollectionRecord[int]]:
        """Convert rows returned from the collection chain query to records.

        Parameters
        ----------
        rows : `~collections.abc.Iterable` [`~collections.abc.Mapping`]
            Rows with ``parent``, ``position`` and ``child_name`` keys.
        chained_ids : `dict` [`int`, `str`]
            Mapping from key to name of every chained collection to build.

        Returns
        -------
        records : `list` [`CollectionRecord`]
            One `ChainedCollectionRecord` per entry in ``chained_ids`` (even
            those without children), with children ordered by position.
        """
        # Pre-populate so that chains without any children still get a record.
        chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            chains_defs[row["parent"]].append((row["position"], row["child_name"]))

        records: list[CollectionRecord[int]] = []
        for key, children in chains_defs.items():
            children_names = [child for _, child in sorted(children)]
            records.append(
                ChainedCollectionRecord[int](
                    key=key,
                    name=chained_ids[key],
                    children=children_names,
                )
            )

        return records

    def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
        """Return a query selecting the key (labeled ``key``) and type of the
        collection with the given name.
        """
        table = self._tables.collection
        return sqlalchemy.select(table.c[_KEY_FIELD_SPEC.name].label("key"), table.c.type).where(
            table.c.name == collection_name
        )

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]