Coverage for python / lsst / analysis / ap / apdbCassandra.py: 19%
120 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:24 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:24 +0000
1# This file is part of analysis_ap.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["DbCassandraQuery"]
26import os
27import warnings
28from typing import TYPE_CHECKING, cast
30import pandas as pd
32import lsst.utils
33from lsst.ap.association import UnpackApdbFlags
34from lsst.dax.apdb import Apdb, ApdbCassandra, ApdbTables
35from lsst.dax.apdb.cassandra.cassandra_utils import quote_id
36from lsst.pipe.base import Instrument
37from lsst.resources import ResourcePathExpression
38from .apdb import DbQuery
40if TYPE_CHECKING:
41 import lsst.daf.butler
class DbCassandraQuery(DbQuery):
    """Implementation of `DbQuery` interface for Cassandra backend.

    Parameters
    ----------
    config_uri : `~lsst.resources.ResourcePathExpression`
        URI or local file path pointing to a file with serialized
        configuration, or a string with a "label:" prefix to locate
        configuration in APDB index.
    instrument : `str`
        Short name (e.g. "DECam") of instrument to make a dataId unpacker
        and to add to the table columns; supports any gen3 instrument.
        To be deprecated once this information is in the database.
    """

    timeout = 600
    """Timeout for queries in seconds. Regular timeout specified in APDB
    configuration could be too short for full-scan queries that this
    class executes.
    """

    def __init__(
        self,
        config_uri: ResourcePathExpression,
        *,
        instrument: Instrument | None = None,
    ):
        self._instrument = instrument

        flag_map = os.path.join(
            lsst.utils.getPackageDir("ap_association"), "data/association-flag-map.yaml"
        )
        self._unpacker = UnpackApdbFlags(flag_map, "DiaSource")

        self.set_excluded_diaSource_flags(
            [
                "base_PixelFlags_flag_bad",
                "base_PixelFlags_flag_suspect",
                "base_PixelFlags_flag_saturatedCenter",
                "base_PixelFlags_flag_interpolated",
                "base_PixelFlags_flag_interpolatedCenter",
                "base_PixelFlags_flag_edge",
            ]
        )

        # We depend on ApdbCassandra for many things which we do not want to
        # reimplement here.
        apdb = Apdb.from_uri(config_uri)
        if not isinstance(apdb, ApdbCassandra):
            raise TypeError(
                f"Configuration file {config_uri} was produced for non-Cassandra backend."
            )
        self._apdb = apdb

        # NOTE: not getting instrument here, as I don't know the interface for
        # it and we don't want to rely on Cassandra for any analysis tooling.

    def set_excluded_diaSource_flags(self, flag_list: list[str]) -> None:
        # Docstring is inherited from base class.
        for flag in flag_list:
            if not self._unpacker.flagExists(flag, columnName="flags"):
                raise ValueError(f"flag {flag} not included in DiaSource flags")

        self.diaSource_flags_exclude = flag_list

    def _filter_flags(self, catalog: pd.DataFrame, column_name: str = "flags") -> None:
        """Filter catalog in-place to remove rows that have any of the
        excluded DiaSource flags set in their flag bitmask column.

        Parameters
        ----------
        catalog : `pandas.DataFrame`
            Catalog to filter, update happens in-place.
        column_name : `str`, optional
            Name of flag column to query.
        """
        bitmask = int(
            self._unpacker.makeFlagBitMask(
                self.diaSource_flags_exclude, columnName=column_name
            )
        )
        if bitmask == 0:
            # A zero mask would select nothing; warn because this almost
            # certainly means the supplied flag list was empty or invalid.
            warnings.warn(
                f"Flag bitmask is zero. Supplied flags: {self.diaSource_flags_exclude}",
                RuntimeWarning,
            )
        mask = (catalog[column_name] & bitmask) != 0
        catalog.drop(catalog[mask].index, inplace=True)

    def _build_query(
        self,
        table: ApdbTables,
        *,
        columns: list[str] | None = None,
        where: str = "",
        limit: int = -1,
    ) -> str:
        """Build a CQL query for a specific table and selection.

        Parameters
        ----------
        table : `ApdbTables`
            Table to query.
        columns : `list` [`str`] or `None`, optional
            Columns to return; all columns if `None` or empty.
        where : `str`, optional
            Selection clause, possibly containing "?" placeholders.
        limit : `int`, optional
            Maximum number of rows to return; unlimited if non-positive.

        Returns
        -------
        query : `str`
            Complete CQL query string.
        """
        if columns:
            what = ",".join(quote_id(column) for column in columns)
        else:
            what = "*"

        query = f'SELECT {what} from "{self._apdb._keyspace}"."{table.name}"'
        if where:
            query += f" WHERE {where}"
        if limit > 0:
            query += f" LIMIT {limit}"
        # Full-scan selections on non-key columns need server-side filtering.
        query += " ALLOW FILTERING"
        return query

    def _query_table(
        self,
        table: ApdbTables,
        *,
        where: str = "",
        params: tuple | None = None,
        limit: int = -1,
    ) -> pd.DataFrame:
        """Execute a query against a single APDB table.

        Parameters
        ----------
        table : `ApdbTables`
            Table to query; all of its APDB columns are returned.
        where : `str`, optional
            Selection clause, possibly containing "?" placeholders.
        params : `tuple` or `None`, optional
            Values to bind to the "?" placeholders in ``where``.
        limit : `int`, optional
            Maximum number of rows to return; unlimited if non-positive.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Query result as a DataFrame (via the "read_pandas" execution
            profile).
        """
        column_names = self._apdb._schema.apdbColumnNames(table)
        query = self._build_query(table, columns=column_names, where=where, limit=limit)
        statement = self._apdb._preparer.prepare(query)
        result = self._apdb._session.execute(
            statement,
            params,
            timeout=self.timeout,
            execution_profile="read_pandas",
        )
        return cast(pd.DataFrame, result._current_rows)

    def _load_record(self, table: ApdbTables, id_column: str, id: int) -> pd.Series:
        """Return the first record in ``table`` whose ``id_column`` equals
        ``id``.

        Raises
        ------
        IndexError
            Raised if no matching record exists.
        """
        catalog = self._query_table(table, where=f'"{id_column}" = ?', params=(id,))
        return catalog.iloc[0]

    def load_sources_for_object(
        self, dia_object_id: int, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._query_table(
            ApdbTables.DiaSource,
            where='"diaObjectId" = ?',
            params=(dia_object_id,),
            limit=int(limit),
        )
        catalog.sort_values(by=["visit", "detector", "diaSourceId"], inplace=True)

        if exclude_flagged:
            self._filter_flags(catalog)

        return catalog

    def load_forced_sources_for_object(
        self, dia_object_id: int, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._query_table(
            ApdbTables.DiaForcedSource,
            where='"diaObjectId" = ?',
            params=(dia_object_id,),
            limit=int(limit),
        )
        catalog.sort_values(by=["visit", "detector", "diaForcedSourceId"], inplace=True)

        if exclude_flagged:
            self._filter_flags(catalog)

        return catalog

    def load_source(self, id: int) -> pd.Series:
        # Docstring is inherited from base class.
        return self._load_record(ApdbTables.DiaSource, "diaSourceId", id)

    def load_sources(
        self, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._query_table(ApdbTables.DiaSource, limit=int(limit))
        catalog.sort_values(by=["visit", "detector", "diaSourceId"], inplace=True)

        if exclude_flagged:
            self._filter_flags(catalog)

        return catalog

    def load_object(self, id: int) -> pd.Series:
        # Docstring is inherited from base class.
        return self._load_record(ApdbTables.DiaObjectLast, "diaObjectId", id)

    def load_objects(self, limit: int = 100000, latest: bool = True) -> pd.DataFrame:
        # Docstring is inherited from base class.
        if latest:
            table = ApdbTables.DiaObjectLast
        else:
            # when we do replication then we don't always generate DiaObject
            # contents.
            config = self._apdb.config
            if config.use_insert_id and config.use_insert_id_skips_diaobjects:
                raise ValueError("DiaObject history is not available for this database")
            table = ApdbTables.DiaObject

        catalog = self._query_table(table, limit=int(limit))
        catalog.sort_values(by=["diaObjectId"], inplace=True)
        return catalog

    def load_forced_source(self, id: int) -> pd.Series:
        # Docstring is inherited from base class.
        return self._load_record(ApdbTables.DiaForcedSource, "diaForcedSourceId", id)

    def load_forced_sources(self, limit: int = 100000) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._query_table(ApdbTables.DiaForcedSource, limit=int(limit))
        catalog.sort_values(by=["visit", "detector", "diaForcedSourceId"], inplace=True)

        return catalog