Coverage for python/lsst/daf/butler/registry/collections/nameKey.py: 99%
99 statements
coverage.py v7.3.2, created at 2023-12-05 11:05 +0000
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ["NameKeyCollectionManager"]

from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any

import sqlalchemy

from ... import ddl
from ..._timespan import TimespanDatabaseRepresentation
from .._collection_type import CollectionType
from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple
from ._base import (
    CollectionTablesTuple,
    DefaultCollectionManager,
    makeCollectionChainTableSpec,
    makeRunTableSpec,
)

if TYPE_CHECKING:
    from .._caching_context import CachingContext
    from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext

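# Field spec for the collection name column. It is the primary key of the
# "collection" table, and its dtype/length are copied whenever a foreign key
# column referencing a collection or run is added to another table.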
_KEY_FIELD_SPEC = ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True)


# This has to be updated on every schema change
_VERSION = VersionTuple(2, 0, 0)

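# Build specifications for the three static collection tables ("collection",
# "run", "collection_chain"); all of them are keyed by the collection name.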
def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple:
    return CollectionTablesTuple(
        collection=ddl.TableSpec(
            fields=[
                _KEY_FIELD_SPEC,
                ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
                ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
            ],
        ),
        run=makeRunTableSpec("name", sqlalchemy.String, TimespanReprClass),
        collection_chain=makeCollectionChainTableSpec("name", sqlalchemy.String),
    )

class NameKeyCollectionManager(DefaultCollectionManager[str]):
    """A `CollectionManager` implementation that uses collection names for
    primary/foreign keys and aggressively loads all collection/run records in
    the database into memory.

    Most of the logic, including the caching policy, is implemented in the
    base class; this class only adds customizations specific to this
    particular table schema.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        dimensions: DimensionRecordStorageManager,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> NameKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        return cls(
            db,
            tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())),  # type: ignore
            collectionIdName="name",
            dimensions=dimensions,
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_name"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
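        # Copy the name key's dtype and length so the new foreign key column
        # is type-compatible with the collection table's primary key.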
        original = _KEY_FIELD_SPEC
        copy = ddl.FieldSpec(
            cls.getCollectionForeignKeyName(prefix), dtype=original.dtype, length=original.length, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec(
                    "collection", source=(copy.name,), target=(original.name,), onDelete=onDelete
                )
            )
        return copy

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        original = _KEY_FIELD_SPEC
        copy = ddl.FieldSpec(
            cls.getRunForeignKeyName(prefix), dtype=original.dtype, length=original.length, **kwargs
        )
        tableSpec.fields.add(copy)
        if constraint:
            tableSpec.foreignKeys.append(
                ddl.ForeignKeySpec("run", source=(copy.name,), target=(original.name,), onDelete=onDelete)
            )
        return copy

    def getParentChains(self, key: str) -> set[str]:
        # Docstring inherited from CollectionManager.
        table = self._tables.collection_chain
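        # Select every chain that lists this collection as a direct child.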
        sql = (
            sqlalchemy.sql.select(table.columns["parent"])
            .select_from(table)
            .where(table.columns["child"] == key)
        )
        with self._db.query(sql) as sql_result:
            parent_names = set(sql_result.scalars().all())
        return parent_names

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
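        # Names are the primary keys in this schema, so fetching by name is
        # the same as fetching by key.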
        return self._fetch_by_key(names)

    def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]:
        # Docstring inherited from base class.
        sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from(
            self._tables.collection.join(self._tables.run, isouter=True)
        )

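        # Separate query for chain membership; it is filtered by parent below
        # when only specific collections were requested.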
        chain_sql = sqlalchemy.sql.select(
            self._tables.collection_chain.columns["parent"],
            self._tables.collection_chain.columns["position"],
            self._tables.collection_chain.columns["child"],
        )

        records: list[CollectionRecord[str]] = []
        # We want to keep transactions as short as possible. When we fetch
        # everything we want to quickly fetch things into memory and finish
        # the transaction. When we fetch just a few records we need to process
        # the result of the first query before we can run the second one.
        if collection_ids is not None:
            sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(collection_ids))
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()

                records, chained_ids = self._rows_to_records(sql_rows)

                if chained_ids:
                    # Retrieve chained collection compositions
                    chain_sql = chain_sql.where(
                        self._tables.collection_chain.columns["parent"].in_(chained_ids)
                    )
                    with self._db.query(chain_sql) as sql_result:
                        chain_rows = sql_result.mappings().fetchall()

                    records += self._rows_to_chains(chain_rows, chained_ids)

        else:
            with self._db.transaction():
                with self._db.query(sql) as sql_result:
                    sql_rows = sql_result.mappings().fetchall()
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)
            records += self._rows_to_chains(chain_rows, chained_ids)

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]:
        """Convert rows returned from the collection query to a list of
        records and a list of chained collection names.
        """
        records: list[CollectionRecord[str]] = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        chained_ids: list[str] = []
        for row in rows:
            name = row[self._tables.collection.columns.name]
            type = CollectionType(row["type"])
            record: CollectionRecord[str]
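            # Run-table columns (host, timespan) come from the outer join and
            # are only meaningful for RUN collections.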
            if type is CollectionType.RUN:
                record = RunRecord[str](
                    key=name,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
                records.append(record)
            elif type is CollectionType.CHAINED:
                # Need to delay chained collection construction until we
                # fetch their children's names.
                chained_ids.append(name)
            else:
                record = CollectionRecord[str](key=name, name=name, type=type)
                records.append(record)

        return records, chained_ids

    def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]:
        """Convert rows returned from the collection chain query to a list of
        records.
        """
        chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            chains_defs[row["parent"]].append((row["position"], row["child"]))

        records: list[CollectionRecord[str]] = []
        for name, children in chains_defs.items():
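            # Sort children by position to restore the chain's search order.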
            children_names = [child for _, child in sorted(children)]
            record = ChainedCollectionRecord[str](
                key=name,
                name=name,
                children=children_names,
            )
            records.append(record)

        return records

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]