Coverage for python / lsst / daf / butler / registry / collections / nameKey.py: 0%
120 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-06 08:30 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ["NameKeyCollectionManager"]
31import logging
32from collections.abc import Iterable, Mapping
33from typing import TYPE_CHECKING, Any
35import sqlalchemy
37from ... import ddl
38from ..._collection_type import CollectionType
39from ...column_spec import COLLECTION_NAME_MAX_LENGTH
40from ...timespan_database_representation import TimespanDatabaseRepresentation
41from ..interfaces import (
42 ChainedCollectionRecord,
43 CollectionRecord,
44 Joinable,
45 JoinedCollectionsTable,
46 RunRecord,
47 VersionTuple,
48)
49from ._base import (
50 CollectionTablesTuple,
51 DefaultCollectionManager,
52 makeCollectionChainTableSpec,
53 makeRunTableSpec,
54)
56if TYPE_CHECKING:
57 from .._caching_context import CachingContext
58 from ..interfaces import Database, StaticTablesContext
# Specification of the "name" column that serves as the primary key of the
# "collection" table; also used as the template for foreign-key columns that
# reference it (see addCollectionForeignKey / addRunForeignKey below).
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "name", dtype=sqlalchemy.String, length=COLLECTION_NAME_MAX_LENGTH, primaryKey=True
)


# This has to be updated on every schema change
_VERSION = VersionTuple(2, 0, 0)


_LOG = logging.getLogger(__name__)
def _makeTableSpecs(
    TimespanReprClass: type[TimespanDatabaseRepresentation],
) -> CollectionTablesTuple[ddl.TableSpec]:
    """Build the table specifications for the name-keyed collection schema.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Class used to represent timespans in the "run" table.

    Returns
    -------
    specs : `CollectionTablesTuple` [ `ddl.TableSpec` ]
        Specifications for the "collection", "run", and "collection_chain"
        tables, all keyed by collection name.
    """
    # The "collection" table holds one row per collection with its type and
    # optional documentation string.
    collection_spec = ddl.TableSpec(
        fields=[
            _KEY_FIELD_SPEC,
            ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
            ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
        ],
    )
    return CollectionTablesTuple(
        collection=collection_spec,
        run=makeRunTableSpec("name", sqlalchemy.String, TimespanReprClass),
        collection_chain=makeCollectionChainTableSpec("name", sqlalchemy.String),
    )
class NameKeyCollectionManager(DefaultCollectionManager[str]):
    """A `CollectionManager` implementation that uses collection names for
    primary/foreign keys and aggressively loads all collection/run records in
    the database into memory.

    Most of the logic, including caching policy, is implemented in the base
    class, this class only adds customizations specific to this particular
    table schema.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> NameKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="name",
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    def clone(self, db: Database, caching_context: CachingContext) -> NameKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        # Table definitions and schema version are shared with the original;
        # only the database connection and cache differ.
        return NameKeyCollectionManager(
            db,
            tables=self._tables,
            collectionIdName=self._collectionIdName,
            caching_context=caching_context,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        # Copy dtype/length from the primary-key spec so the foreign-key
        # column always matches the "collection.name" column exactly.
        original = _KEY_FIELD_SPEC
        copy = ddl.FieldSpec(
            cls.getCollectionForeignKeyName(prefix), dtype=original.dtype, length=original.length, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    "collection", source=(copy.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return copy

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        # Same pattern as addCollectionForeignKey, but targeting "run".
        original = _KEY_FIELD_SPEC
        copy = ddl.FieldSpec(
            cls.getRunForeignKeyName(prefix), dtype=original.dtype, length=original.length, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec("run", source=(copy.name,), target=(original.name,), onDelete=onDelete)
            )
        return copy

    def getParentChains(self, key: str) -> set[str]:
        # Docstring inherited from CollectionManager.
        table = self._tables.collection_chain
        sql = (
            sqlalchemy.sql.select(table.columns["parent"])
            .select_from(table)
            .where(table.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def lookup_name_sql(
        self, sql_key: sqlalchemy.ColumnElement[str], sql_from_clause: Joinable
    ) -> tuple[sqlalchemy.ColumnElement[str], Joinable]:
        # Docstring inherited.
        # Keys already are names in this schema, so no extra join is needed.
        return sql_key, sql_from_clause

    def join_collections_sql(
        self, sql_key: sqlalchemy.ColumnElement[str], joinable: Joinable
    ) -> JoinedCollectionsTable:
        name_column = self._tables.collection.columns["name"]
        return JoinedCollectionsTable(
            joined_sql=joinable.join(self._tables.collection, onclause=name_column == sql_key),
            name_column=sql_key,
            type_column=self._tables.collection.columns["type"],
        )

    def _fetch_by_name(self, names: Iterable[str], flatten_chains: bool) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        if flatten_chains:
            sql_rows = self._query_recursive(names, _KEY_FIELD_SPEC.dtype)

            # There may be duplicates in the result, select unique names.
            unique_rows = {row[self._collectionIdName]: row for row in sql_rows}

            records, chained_ids = self._rows_to_records(unique_rows.values())
            records += self._rows_to_chains(sql_rows, chained_ids)

            return records
        else:
            # In this schema names are the keys, so no translation is needed.
            return self._fetch_by_key(names)

    def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using names %s.", collection_ids)
        sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from(
            self._tables.collection.join(self._tables.run, isouter=True)
        )

        # "Rename" child column to "name" as expected by _rows_to_chains()
        chain_sql = sqlalchemy.sql.select(
            self._tables.collection_chain.columns["parent"],
            self._tables.collection_chain.columns["position"],
            self._tables.collection_chain.columns["child"].label("name"),
        )

        # We want to keep transactions as short as possible. When we fetch
        # everything we want to quickly fetch things into memory and finish
        # transaction. When we fetch just few records we need to process result
        # of the first query before we can run the second one.
        if collection_ids is not None:
            sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(collection_ids))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    # Retrieve chained collection compositions
                    chain_sql = chain_sql.where(
                        self._tables.collection_chain.columns["parent"].in_(chained_ids)
                    )
                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]:
        """Convert rows returned from collection query to a list of records
        and a list chained collection names.
        """
        records: list[CollectionRecord[str]] = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        chained_ids: list[str] = []
        for row in rows:
            name = row["name"]
            # Renamed from ``type`` to avoid shadowing the builtin.
            collection_type = CollectionType(row["type"])
            record: CollectionRecord[str]
            if collection_type is CollectionType.RUN:
                record = RunRecord[str](
                    key=name,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif collection_type is CollectionType.CHAINED:
                # Need to delay chained collection construction until we
                # fetch their children names.
                chained_ids.append(name)
            else:
                record = CollectionRecord[str](key=name, name=name, type=collection_type)
                records.append(record)

        return records, chained_ids

    def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]:
        """Convert rows returned from collection chain query to a list of
        records.
        """
        chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            if row["parent"] is not None:
                chains_defs[row["parent"]].append((row["position"], row["name"]))

        records: list[CollectionRecord[str]] = []
        for name, children in chains_defs.items():
            # Sort children by their "position" column to restore chain order.
            children_names = [child for _, child in sorted(children)]
            record = ChainedCollectionRecord[str](
                key=name,
                name=name,
                children=children_names,
            )
            records.append(record)

        return records

    def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
        # Label the name column "key" as expected by callers in the base
        # class; for this schema the key and the name are the same value.
        table = self._tables.collection
        return sqlalchemy.select(table.c.name.label("key"), table.c.type).where(
            table.c.name == collection_name
        )

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]