Coverage for python/lsst/daf/butler/registry/collections/nameKey.py: 99%
102 statements
coverage.py v7.4.0, created at 2024-01-16 10:43 +0000
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations
29__all__ = ["NameKeyCollectionManager"]
31import logging
32from collections.abc import Iterable, Mapping
33from typing import TYPE_CHECKING, Any
35import sqlalchemy
37from ... import ddl
38from ...timespan_database_representation import TimespanDatabaseRepresentation
39from .._collection_type import CollectionType
40from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple
41from ._base import (
42 CollectionTablesTuple,
43 DefaultCollectionManager,
44 makeCollectionChainTableSpec,
45 makeRunTableSpec,
46)

if TYPE_CHECKING:
    from .._caching_context import CachingContext
    from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext
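

# The collection name itself (a string of at most 64 characters) is the
# primary key of the "collection" table and the foreign key referenced by the
# "run" and "collection_chain" tables built below.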
_KEY_FIELD_SPEC = ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True)

# This has to be updated on every schema change
_VERSION = VersionTuple(2, 0, 0)

_LOG = logging.getLogger(__name__)


def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple:
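    """Construct table specifications for the name-keyed collection schema.

    All three tables (``collection``, ``run``, ``collection_chain``) use the
    collection name string as their primary/foreign key.
    """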
    return CollectionTablesTuple(
        collection=ddl.TableSpec(
            fields=[
                _KEY_FIELD_SPEC,
                ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
                ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
            ],
        ),
        run=makeRunTableSpec("name", sqlalchemy.String, TimespanReprClass),
        collection_chain=makeCollectionChainTableSpec("name", sqlalchemy.String),
    )


class NameKeyCollectionManager(DefaultCollectionManager[str]):
    """A `CollectionManager` implementation that uses collection names for
    primary/foreign keys and aggressively loads all collection/run records in
    the database into memory.

    Most of the logic, including the caching policy, is implemented in the
    base class; this class only adds customizations specific to this
    particular table schema.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        dimensions: DimensionRecordStorageManager,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> NameKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="name",
            dimensions=dimensions,
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        copy = ddl.FieldSpec(
            cls.getCollectionForeignKeyName(prefix), dtype=original.dtype, length=original.length, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    "collection", source=(copy.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return copy

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        copy = ddl.FieldSpec(
            cls.getRunForeignKeyName(prefix), dtype=original.dtype, length=original.length, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec("run", source=(copy.name,), target=(original.name,), onDelete=onDelete)
            )
        return copy

    def getParentChains(self, key: str) -> set[str]:
        # Docstring inherited from CollectionManager.
        table = self._tables.collection_chain
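        # Roughly: SELECT parent FROM collection_chain WHERE child = :key.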
        sql = (
            sqlalchemy.sql.select(table.columns["parent"])
            .select_from(table)
            .where(table.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        return self._fetch_by_key(names)

    def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using names %s.", collection_ids)
        sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from(
            self._tables.collection.join(self._tables.run, isouter=True)
        )
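
        # Second query: chain membership (parent, position, child) rows, used
        # to reconstruct the children of any CHAINED collections found above.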
        chain_sql = sqlalchemy.sql.select(
            self._tables.collection_chain.columns["parent"],
            self._tables.collection_chain.columns["position"],
            self._tables.collection_chain.columns["child"],
        )

        records: list[CollectionRecord[str]] = []
        # We want to keep transactions as short as possible. When we fetch
        # everything, we read it all into memory quickly and finish the
        # transaction. When we fetch just a few records, we need to process
        # the result of the first query before we can run the second one.
        if collection_ids is not None:
            sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(collection_ids))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    # Retrieve chained collection compositions.
                    chain_sql = chain_sql.where(
                        self._tables.collection_chain.columns["parent"].in_(chained_ids)
                    )
                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]:
        """Convert rows returned from the collection query to a list of
        records and a list of chained collection names.
        """
        records: list[CollectionRecord[str]] = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        chained_ids: list[str] = []
        for row in rows:
            name = row[self._tables.collection.columns.name]
            type = CollectionType(row["type"])
            record: CollectionRecord[str]
            if type is CollectionType.RUN:
                record = RunRecord[str](
                    key=name,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif type is CollectionType.CHAINED:
                # Construction of chained collection records is delayed until
                # we have fetched the names of their children.
                chained_ids.append(name)
            else:
                record = CollectionRecord[str](key=name, name=name, type=type)
                records.append(record)

        return records, chained_ids

    def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]:
        """Convert rows returned from the collection chain query to a list of
        records.
        """
        chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            chains_defs[row["parent"]].append((row["position"], row["child"]))

        records: list[CollectionRecord[str]] = []
        for name, children in chains_defs.items():
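            # Children are ordered by their position within the chain.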
            children_names = [child for _, child in sorted(children)]
            record = ChainedCollectionRecord[str](
                key=name,
                name=name,
                children=children_names,
            )
            records.append(record)

        return records

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]
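

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of daf_butler): a standalone approximation of
# the name-keyed ``collection_chain`` lookup performed by ``getParentChains``,
# written with plain SQLAlchemy against in-memory SQLite.  The table layout is
# a simplified assumption based on the specs above, not the real schema; copy
# it into a scratch script (with ``import sqlalchemy``) to run it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    metadata = sqlalchemy.MetaData()
    collection_chain = sqlalchemy.Table(
        "collection_chain",
        metadata,
        sqlalchemy.Column("parent", sqlalchemy.String(64), primary_key=True),
        sqlalchemy.Column("position", sqlalchemy.Integer, primary_key=True),
        sqlalchemy.Column("child", sqlalchemy.String(64), nullable=False),
    )
    engine = sqlalchemy.create_engine("sqlite://")
    metadata.create_all(engine)
    with engine.begin() as connection:
        # One chained collection ("chain") whose only member is "run1".
        connection.execute(
            collection_chain.insert(),
            [{"parent": "chain", "position": 0, "child": "run1"}],
        )
        # Same query shape as getParentChains("run1"): which parents contain
        # this child?
        sql = (
            sqlalchemy.sql.select(collection_chain.columns["parent"])
            .select_from(collection_chain)
            .where(collection_chain.columns["child"] == "run1")
        )
        print(set(connection.execute(sql).scalars().all()))  # {'chain'}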