Coverage for python/lsst/daf/butler/registry/collections/nameKey.py: 99%
104 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-01 11:19 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ["NameKeyCollectionManager"]
31import logging
32from collections.abc import Iterable, Mapping
33from typing import TYPE_CHECKING, Any
35import sqlalchemy
37from ... import ddl
38from ...timespan_database_representation import TimespanDatabaseRepresentation
39from .._collection_type import CollectionType
40from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple
41from ._base import (
42 CollectionTablesTuple,
43 DefaultCollectionManager,
44 makeCollectionChainTableSpec,
45 makeRunTableSpec,
46)
48if TYPE_CHECKING:
49 from .._caching_context import CachingContext
50 from ..interfaces import Database, StaticTablesContext
# Primary-key field of the ``collection`` table. Its dtype and length are
# also copied when collection/run foreign-key fields are added to other
# tables (see ``addCollectionForeignKey`` / ``addRunForeignKey``).
_KEY_FIELD_SPEC = ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True)


# This has to be updated on every schema change
# (it is the only version reported by ``currentVersions``).
_VERSION = VersionTuple(2, 0, 0)


# Module-level logger; used for debug output when fetching records.
_LOG = logging.getLogger(__name__)
def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple:
    """Build the table specifications for the name-keyed collection schema.

    Parameters
    ----------
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Timespan representation class, forwarded to `makeRunTableSpec`.

    Returns
    -------
    specs : `CollectionTablesTuple`
        Specifications for the ``collection``, ``run`` and
        ``collection_chain`` tables, all keyed on the ``name`` column.
    """
    # The collection table itself: name (primary key), type code and docs.
    collection_spec = ddl.TableSpec(
        fields=[
            _KEY_FIELD_SPEC,
            ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
            ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
        ],
    )
    # Auxiliary tables reference collections through a string "name" key.
    run_spec = makeRunTableSpec("name", sqlalchemy.String, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec("name", sqlalchemy.String)
    return CollectionTablesTuple(
        collection=collection_spec,
        run=run_spec,
        collection_chain=chain_spec,
    )
class NameKeyCollectionManager(DefaultCollectionManager[str]):
    """A `CollectionManager` implementation that uses collection names for
    primary/foreign keys and aggressively loads all collection/run records in
    the database into memory.

    Most of the logic, including caching policy, is implemented in the base
    class, this class only adds customizations specific to this particular
    table schema.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> NameKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="name",
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    def clone(self, db: Database, caching_context: CachingContext) -> NameKeyCollectionManager:
        # Build a new manager that shares this one's table objects, key
        # column name and schema version, but uses the given database and
        # caching context.
        return NameKeyCollectionManager(
            db,
            tables=self._tables,
            collectionIdName=self._collectionIdName,
            caching_context=caching_context,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def _addNameForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        fieldName: str,
        targetTable: str,
        onDelete: str | None,
        constraint: bool,
        /,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        """Add a field mirroring ``_KEY_FIELD_SPEC`` (and optionally a
        foreign key constraint on it) to a table specification.

        Parameters
        ----------
        tableSpec : `ddl.TableSpec`
            Table specification, modified in place.
        fieldName : `str`
            Name of the new field.
        targetTable : `str`
            Table name passed to `ddl.ForeignKeySpec`.
        onDelete : `str` or `None`
            Forwarded as ``onDelete`` to `ddl.ForeignKeySpec`.
        constraint : `bool`
            If `True`, also append a `ddl.ForeignKeySpec` to ``tableSpec``.
        **kwargs
            Additional keyword arguments forwarded to `ddl.FieldSpec`.

        Returns
        -------
        fieldSpec : `ddl.FieldSpec`
            The field specification that was added.
        """
        original = _KEY_FIELD_SPEC
        # The new field has the same dtype/length as the referenced key.
        copy = ddl.FieldSpec(fieldName, dtype=original.dtype, length=original.length, **kwargs)
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    targetTable, source=(copy.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return copy

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        return cls._addNameForeignKey(
            tableSpec,
            cls.getCollectionForeignKeyName(prefix),
            "collection",
            onDelete,
            constraint,
            **kwargs,
        )

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        return cls._addNameForeignKey(
            tableSpec,
            cls.getRunForeignKeyName(prefix),
            "run",
            onDelete,
            constraint,
            **kwargs,
        )

    def getParentChains(self, key: str) -> set[str]:
        # Docstring inherited from CollectionManager.
        # Only direct parents are returned: a single lookup on the chain
        # table by child name.
        table = self._tables.collection_chain
        sql = (
            sqlalchemy.sql.select(table.columns["parent"])
            .select_from(table)
            .where(table.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        # Names are the keys in this schema, so this is a plain key lookup.
        return self._fetch_by_key(names)

    def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using names %s.", collection_ids)
        collection_table = self._tables.collection
        run_table = self._tables.run
        chain_table = self._tables.collection_chain

        # Outer join so that collections with no matching row in the run
        # table are still returned.
        sql = sqlalchemy.sql.select(*collection_table.columns, *run_table.columns).select_from(
            collection_table.join(run_table, isouter=True)
        )

        chain_sql = sqlalchemy.sql.select(
            chain_table.columns["parent"],
            chain_table.columns["position"],
            chain_table.columns["child"],
        )

        records: list[CollectionRecord[str]]
        # We want to keep transactions as short as possible. When we fetch
        # everything we want to quickly fetch things into memory and finish
        # transaction. When we fetch just few records we need to process result
        # of the first query before we can run the second one.
        if collection_ids is not None:
            sql = sql.where(collection_table.columns[self._collectionIdName].in_(collection_ids))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    # Retrieve chained collection compositions
                    chain_sql = chain_sql.where(chain_table.columns["parent"].in_(chained_ids))
                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]:
        """Convert rows returned from collection query to a list of records
        and a list of chained collection names.

        CHAINED collections are not converted here; their names are returned
        separately so that their children can be fetched first.
        """
        records: list[CollectionRecord[str]] = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        chained_ids: list[str] = []
        for row in rows:
            # Column objects (not strings) are used as keys where the name
            # is shared between the collection and run tables in the join.
            name = row[self._tables.collection.columns.name]
            type = CollectionType(row["type"])
            record: CollectionRecord[str]
            if type is CollectionType.RUN:
                record = RunRecord[str](
                    key=name,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif type is CollectionType.CHAINED:
                # Need to delay chained collection construction until we
                # fetch their children names.
                chained_ids.append(name)
            else:
                record = CollectionRecord[str](key=name, name=name, type=type)
                records.append(record)

        return records, chained_ids

    def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]:
        """Convert rows returned from collection chain query to a list of
        records.

        One `ChainedCollectionRecord` is produced per entry in
        ``chained_ids`` (with an empty child list if it has no rows), with
        children ordered by their ``position`` value. Rows whose parent is
        not in ``chained_ids`` are ignored.
        """
        chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            children = chains_defs.get(row["parent"])
            if children is None:
                # The collection and chain queries are separate statements,
                # so the chain query can return rows for a parent that the
                # collection query did not see (e.g. created in between).
                # Such rows cannot be matched to a record; skip them instead
                # of failing with KeyError.
                continue
            children.append((row["position"], row["child"]))

        records: list[CollectionRecord[str]] = []
        for name, children in chains_defs.items():
            children_names = [child for _, child in sorted(children)]
            record = ChainedCollectionRecord[str](
                key=name,
                name=name,
                children=children_names,
            )
            records.append(record)

        return records

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]