Coverage for python/lsst/daf/butler/registry/collections/synthIntKey.py: 0%
128 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-30 08:41 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ["SynthIntKeyCollectionManager"]
33import logging
34from collections.abc import Iterable, Mapping
35from typing import TYPE_CHECKING, Any
37import sqlalchemy
39from ..._collection_type import CollectionType
40from ...column_spec import COLLECTION_NAME_MAX_LENGTH
41from ...timespan_database_representation import TimespanDatabaseRepresentation
42from ..interfaces import (
43 ChainedCollectionRecord,
44 CollectionRecord,
45 Joinable,
46 JoinedCollectionsTable,
47 RunRecord,
48 VersionTuple,
49)
50from ._base import (
51 CollectionTablesTuple,
52 DefaultCollectionManager,
53 makeCollectionChainTableSpec,
54 makeRunTableSpec,
55)
57if TYPE_CHECKING:
58 from .._caching_context import CachingContext
59 from ..interfaces import Database, StaticTablesContext
# Field specification for the synthetic integer primary key of the collection
# table; also used as the template for foreign-key columns referencing it.
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "collection_id", dtype=sqlalchemy.BigInteger, primaryKey=True, autoincrement=True
)


# This has to be updated on every schema change
_VERSION = VersionTuple(2, 0, 0)

_LOG = logging.getLogger(__name__)
def _makeTableSpecs(
    TimespanReprClass: type[TimespanDatabaseRepresentation],
) -> CollectionTablesTuple[ddl.TableSpec]:
    """Construct the table specifications used by
    `SynthIntKeyCollectionManager`.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Class used to represent timespans in the run table.

    Returns
    -------
    specs : `CollectionTablesTuple` [ `ddl.TableSpec` ]
        Specifications for the collection, run, and collection-chain tables.
    """
    collection_spec = ddl.TableSpec(
        fields=[
            _KEY_FIELD_SPEC,
            ddl.FieldSpec(
                "name", dtype=sqlalchemy.String, length=COLLECTION_NAME_MAX_LENGTH, nullable=False
            ),
            ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
            ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
        ],
        unique=[("name",)],
    )
    run_spec = makeRunTableSpec("collection_id", sqlalchemy.BigInteger, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec("collection_id", sqlalchemy.BigInteger)
    return CollectionTablesTuple(collection=collection_spec, run=run_spec, collection_chain=chain_spec)
class SynthIntKeyCollectionManager(DefaultCollectionManager[int]):
    """A `CollectionManager` implementation that uses a synthetic primary key
    (auto-incremented integer) for the collections table.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> SynthIntKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="collection_id",
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    def clone(self, db: Database, caching_context: CachingContext) -> SynthIntKeyCollectionManager:
        # Make a copy of this manager bound to a different database object
        # and caching context, reusing the existing table definitions and
        # schema version.
        return SynthIntKeyCollectionManager(
            db,
            tables=self._tables,
            collectionIdName=self._collectionIdName,
            caching_context=caching_context,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        # Same dtype as the primary key, but a referring column must never
        # auto-increment.
        copy = ddl.FieldSpec(
            cls.getCollectionForeignKeyName(prefix), dtype=original.dtype, autoincrement=False, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    "collection", source=(copy.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return copy

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        # Same dtype as the primary key, but a referring column must never
        # auto-increment.
        copy = ddl.FieldSpec(
            cls.getRunForeignKeyName(prefix), dtype=original.dtype, autoincrement=False, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec("run", source=(copy.name,), target=(original.name,), onDelete=onDelete)
            )
        return copy

    def getParentChains(self, key: int) -> set[str]:
        # Docstring inherited from CollectionManager.
        chain = self._tables.collection_chain
        collection = self._tables.collection
        # Select the names of all chained collections whose membership
        # includes the collection with this key.
        sql = (
            sqlalchemy.sql.select(collection.columns["name"])
            .select_from(collection)
            .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"])
            .where(chain.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def lookup_name_sql(
        self, sql_key: sqlalchemy.ColumnElement[int], sql_from_clause: Joinable
    ) -> tuple[sqlalchemy.ColumnElement[str], Joinable]:
        # Docstring inherited.
        joined = self.join_collections_sql(sql_key, sql_from_clause)
        return (joined.name_column, joined.joined_sql)

    def join_collections_sql(
        self, sql_key: sqlalchemy.ColumnElement[int], joinable: Joinable
    ) -> JoinedCollectionsTable:
        """Join the collection table to the given SQL construct on the
        collection primary key, exposing the joined FROM clause and the
        collection ``name`` and ``type`` columns.
        """
        return JoinedCollectionsTable(
            joined_sql=joinable.join(
                self._tables.collection, onclause=self._tables.collection.c[_KEY_FIELD_SPEC.name] == sql_key
            ),
            name_column=self._tables.collection.columns["name"],
            type_column=self._tables.collection.columns["type"],
        )

    def _fetch_by_name(self, names: Iterable[str], flatten_chains: bool) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using names %s.", names)
        if flatten_chains:
            sql_rows = self._query_recursive(names, _KEY_FIELD_SPEC.dtype)

            # There may be duplicates in the result, select unique IDs.
            unique_rows = {row[self._collectionIdName]: row for row in sql_rows}

            # Build chained-collection records from the full (non-deduplicated)
            # row set, since those rows carry chain membership.
            records, chained_ids = self._rows_to_records(unique_rows.values())
            records += self._rows_to_chains(sql_rows, chained_ids)

            return records
        else:
            return self._fetch("name", names)

    def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using IDs %s.", collection_ids)
        return self._fetch(self._collectionIdName, collection_ids)

    def _fetch(
        self, column_name: str, collections: Iterable[int | str] | None
    ) -> list[CollectionRecord[int]]:
        """Fetch collection records whose ``column_name`` column matches one
        of the given values, or all records if ``collections`` is `None`.
        """
        collection_chain = self._tables.collection_chain
        collection = self._tables.collection
        # Collection rows outer-joined to the run table, so RUN collections
        # come back with their extra columns in the same row.
        sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from(
            collection.join(self._tables.run, isouter=True)
        )

        # Chain membership: parent chain ID, child position, and child name.
        chain_sql = (
            sqlalchemy.sql.select(
                collection_chain.columns["parent"],
                collection_chain.columns["position"],
                collection.columns["name"],
            )
            .select_from(collection_chain)
            .join(
                collection,
                onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName],
            )
        )

        records: list[CollectionRecord[int]] = []
        # We want to keep transactions as short as possible. When we fetch
        # everything we want to quickly fetch things into memory and finish
        # the transaction. When we fetch just a few records we need to
        # process the first query before we can run the second one.
        if collections is not None:
            sql = sql.where(collection.columns[column_name].in_(collections))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    # Restrict the chain query to the chains actually found.
                    chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids)))

                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            # Process rows outside the transaction; both result sets are
            # already in memory.
            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]:
        """Convert rows returned from the collection query to a list of
        records and a dict mapping chained-collection ID to name.
        """
        records: list[CollectionRecord[int]] = []
        chained_ids: dict[int, str] = {}
        TimespanReprClass = self._db.getTimespanRepresentation()
        for row in rows:
            key: int = row[self._collectionIdName]
            name: str = row["name"]
            type = CollectionType(row["type"])
            record: CollectionRecord[int]
            if type is CollectionType.RUN:
                record = RunRecord[int](
                    key=key,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif type is CollectionType.CHAINED:
                # Need to delay chained collection construction until we
                # fetch their children's names.
                chained_ids[key] = name
            else:
                record = CollectionRecord[int](key=key, name=name, type=type)
                records.append(record)
        return records, chained_ids

    def _rows_to_chains(
        self, rows: Iterable[Mapping], chained_ids: dict[int, str]
    ) -> list[CollectionRecord[int]]:
        """Convert rows returned from the collection chain query to a list of
        chained collection records.
        """
        chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            # Skip rows that carry no chain membership (NULL parent).
            if row["parent"] is not None:
                chains_defs[row["parent"]].append((row["position"], row["name"]))

        records: list[CollectionRecord[int]] = []
        for key, children in chains_defs.items():
            name = chained_ids[key]
            # Order children by their position within the chain.
            children_names = [child for _, child in sorted(children)]
            record = ChainedCollectionRecord[int](
                key=key,
                name=name,
                children=children_names,
            )
            records.append(record)

        return records

    def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
        """Return a SELECT for the primary key (labeled ``key``) and type of
        the collection with the given name.
        """
        table = self._tables.collection
        return sqlalchemy.select(table.c.collection_id.label("key"), table.c.type).where(
            table.c.name == collection_name
        )

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]