Coverage for python/lsst/daf/butler/registry/collections/synthIntKey.py: 99%
112 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-25 10:48 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-25 10:48 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ["SynthIntKeyCollectionManager"]
33import logging
34from collections.abc import Iterable, Mapping
35from typing import TYPE_CHECKING, Any
37import sqlalchemy
39from ...timespan_database_representation import TimespanDatabaseRepresentation
40from .._collection_type import CollectionType
41from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple
42from ._base import (
43 CollectionTablesTuple,
44 DefaultCollectionManager,
45 makeCollectionChainTableSpec,
46 makeRunTableSpec,
47)
49if TYPE_CHECKING:
50 from .._caching_context import CachingContext
51 from ..interfaces import Database, StaticTablesContext
# Logger for this module.
_LOG = logging.getLogger(__name__)

# This has to be updated on every schema change.
_VERSION = VersionTuple(2, 0, 0)

# Specification of the primary key column of the collection table: an
# auto-incremented big integer named "collection_id".  Foreign-key columns
# created by the manager copy their dtype from this spec.
_KEY_FIELD_SPEC = ddl.FieldSpec(
    "collection_id",
    dtype=sqlalchemy.BigInteger,
    primaryKey=True,
    autoincrement=True,
)
def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple:
    """Build the specifications of the tables used by the collection manager.

    Parameters
    ----------
    TimespanReprClass : `type` [`TimespanDatabaseRepresentation`]
        Timespan representation class, forwarded to `makeRunTableSpec`.

    Returns
    -------
    specs : `CollectionTablesTuple`
        Specifications for the ``collection``, ``run`` and
        ``collection_chain`` tables.
    """
    # The collection table: synthetic integer key plus a unique name, the
    # collection type stored as a small integer, and optional documentation.
    collection_fields = [
        _KEY_FIELD_SPEC,
        ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, nullable=False),
        ddl.FieldSpec("type", dtype=sqlalchemy.SmallInteger, nullable=False),
        ddl.FieldSpec("doc", dtype=sqlalchemy.Text, nullable=True),
    ]
    collection_spec = ddl.TableSpec(fields=collection_fields, unique=[("name",)])

    # Run and chain tables are keyed by the same big-integer collection ID.
    run_spec = makeRunTableSpec("collection_id", sqlalchemy.BigInteger, TimespanReprClass)
    chain_spec = makeCollectionChainTableSpec("collection_id", sqlalchemy.BigInteger)

    return CollectionTablesTuple(
        collection=collection_spec,
        run=run_spec,
        collection_chain=chain_spec,
    )
class SynthIntKeyCollectionManager(DefaultCollectionManager[int]):
    """Collection manager whose collection table is keyed by a synthetic,
    auto-incremented integer column.

    Foreign-key columns added to other tables through this manager are
    integer ``<prefix>_id`` columns with the same dtype as that key.
    """

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> SynthIntKeyCollectionManager:
        # Docstring inherited from CollectionManager.
        table_specs = _makeTableSpecs(db.getTimespanRepresentation())
        tables = context.addTableTuple(table_specs)  # type: ignore
        return cls(
            db,
            tables=tables,
            collectionIdName="collection_id",
            caching_context=caching_context,
            registry_schema_version=registry_schema_version,
        )

    def clone(self, db: Database, caching_context: CachingContext) -> SynthIntKeyCollectionManager:
        # The new manager reuses this instance's tables, ID column name and
        # schema version; only the database and caching context are replaced.
        return SynthIntKeyCollectionManager(
            db,
            tables=self._tables,
            collectionIdName=self._collectionIdName,
            caching_context=caching_context,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def getCollectionForeignKeyName(cls, prefix: str = "collection") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def getRunForeignKeyName(cls, prefix: str = "run") -> str:
        # Docstring inherited from CollectionManager.
        return f"{prefix}_id"

    @classmethod
    def addCollectionForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "collection",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        key_spec = _KEY_FIELD_SPEC
        field_name = cls.getCollectionForeignKeyName(prefix)
        # Same dtype as the primary key, but never auto-incremented.
        fk_field = ddl.FieldSpec(field_name, dtype=key_spec.dtype, autoincrement=False, **kwargs)
        tableSpec.fields.add(fk_field)
        if not constraint:
            return fk_field
        fk_spec = ddl.ForeignKeySpec(
            "collection",
            source=(fk_field.name,),
            target=(key_spec.name,),
            onDelete=onDelete,
        )
        tableSpec.foreignKeys.append(fk_spec)
        return fk_field

    @classmethod
    def addRunForeignKey(
        cls,
        tableSpec: ddl.TableSpec,
        *,
        prefix: str = "run",
        onDelete: str | None = None,
        constraint: bool = True,
        **kwargs: Any,
    ) -> ddl.FieldSpec:
        # Docstring inherited from CollectionManager.
        key_spec = _KEY_FIELD_SPEC
        field_name = cls.getRunForeignKeyName(prefix)
        # Same dtype as the primary key, but never auto-incremented.
        fk_field = ddl.FieldSpec(field_name, dtype=key_spec.dtype, autoincrement=False, **kwargs)
        tableSpec.fields.add(fk_field)
        if not constraint:
            return fk_field
        fk_spec = ddl.ForeignKeySpec(
            "run",
            source=(fk_field.name,),
            target=(key_spec.name,),
            onDelete=onDelete,
        )
        tableSpec.foreignKeys.append(fk_spec)
        return fk_field

    def getParentChains(self, key: int) -> set[str]:
        # Docstring inherited from CollectionManager.
        chain_table = self._tables.collection_chain
        collection_table = self._tables.collection
        # Join each chain row to the collection row of its parent, so that the
        # parent's name can be selected for every chain containing ``key``.
        parent_join = collection_table.columns[self._collectionIdName] == chain_table.columns["parent"]
        query = (
            sqlalchemy.sql.select(collection_table.columns["name"])
            .select_from(collection_table)
            .join(chain_table, onclause=parent_join)
            .where(chain_table.columns["child"] == key)
        )
        with self._db.query(query) as result:
            names = set(result.scalars().all())
        return names

    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using names %s.", names)
        return self._fetch("name", names)

    def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]:
        # Docstring inherited from base class.
        _LOG.debug("Fetching collection records using IDs %s.", collection_ids)
        return self._fetch(self._collectionIdName, collection_ids)

    def _fetch(
        self, column_name: str, collections: Iterable[int | str] | None
    ) -> list[CollectionRecord[int]]:
        """Fetch collection records selected by the values of one column.

        Parameters
        ----------
        column_name : `str`
            Name of the collection table column to match against.
        collections : `~collections.abc.Iterable` [`int` or `str`] or `None`
            Values to look for in ``column_name``.  If `None`, no filter is
            applied and records for all collections are returned.

        Returns
        -------
        records : `list` [`CollectionRecord`]
            Records for the selected collections; records of chained
            collections follow all other records.
        """
        chain_table = self._tables.collection_chain
        collection_table = self._tables.collection

        # Collections outer-joined to runs: rows of non-run collections are
        # kept even though they have no matching run row.
        records_query = sqlalchemy.sql.select(
            *collection_table.columns, *self._tables.run.columns
        ).select_from(collection_table.join(self._tables.run, isouter=True))

        # Chain membership, with each child's ID resolved to its name.
        child_join = chain_table.columns["child"] == collection_table.columns[self._collectionIdName]
        chains_query = (
            sqlalchemy.sql.select(
                chain_table.columns["parent"],
                chain_table.columns["position"],
                collection_table.columns["name"].label("child_name"),
            )
            .select_from(chain_table)
            .join(collection_table, onclause=child_join)
        )

        records: list[CollectionRecord[int]] = []

        # We want to keep transactions as short as possible.  When fetching
        # everything, both queries are read into memory quickly and the
        # transaction is finished before any processing.  When fetching just
        # a few records, the first query has to be processed before the
        # second one can be run.
        if collections is None:
            with self._db.transaction():
                with self._db.query(records_query) as result:
                    record_rows = result.mappings().fetchall()
                with self._db.query(chains_query) as result:
                    chain_rows = result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(record_rows)
            records.extend(self._rows_to_chains(chain_rows, chained_ids))
            return records

        records_query = records_query.where(collection_table.columns[column_name].in_(collections))
        with self._db.transaction():
            with self._db.query(records_query) as result:
                record_rows = result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(record_rows)

            # Only query chain membership if some chained collection was
            # actually selected, and only for those chains.
            if chained_ids:
                chains_query = chains_query.where(chain_table.columns["parent"].in_(list(chained_ids)))

                with self._db.query(chains_query) as result:
                    chain_rows = result.mappings().fetchall()

                records.extend(self._rows_to_chains(chain_rows, chained_ids))

        return records

    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]:
        """Turn rows of the collection query into records.

        Chained collections are not converted here; instead their names are
        returned in a dict keyed by collection ID, so that their records can
        be built once the names of their children are known.
        """
        plain_records: list[CollectionRecord[int]] = []
        chain_names: dict[int, str] = {}
        timespan_cls = self._db.getTimespanRepresentation()
        for row in rows:
            key: int = row[self._collectionIdName]
            name: str = row[self._tables.collection.columns.name]
            collection_type = CollectionType(row["type"])

            if collection_type is CollectionType.CHAINED:
                # Construction is delayed until children names are fetched.
                chain_names[key] = name
                continue

            record: CollectionRecord[int]
            if collection_type is CollectionType.RUN:
                record = RunRecord[int](
                    key=key,
                    name=name,
                    host=row[self._tables.run.columns.host],
                    timespan=timespan_cls.extract(row),
                )
            else:
                record = CollectionRecord[int](key=key, name=name, type=collection_type)
            plain_records.append(record)
        return plain_records, chain_names

    def _rows_to_chains(
        self, rows: Iterable[Mapping], chained_ids: dict[int, str]
    ) -> list[CollectionRecord[int]]:
        """Turn rows of the collection chain query into chained collection
        records, one for each entry of ``chained_ids``.
        """
        # Start with an empty child list for every chain, so that a chain
        # without any rows still produces a record.
        members_by_chain: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
        for row in rows:
            members = members_by_chain[row["parent"]]
            members.append((row["position"], row["child_name"]))

        # Sorting the (position, name) pairs puts children in chain order.
        return [
            ChainedCollectionRecord[int](
                key=chain_id,
                name=chained_ids[chain_id],
                children=[child_name for _, child_name in sorted(members)],
            )
            for chain_id, members in members_by_chain.items()
        ]

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from VersionedExtension.
        return [_VERSION]