Coverage for python / lsst / analysis / ap / apdb.py: 22%
164 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 08:46 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 08:46 +0000
1# This file is part of analysis_ap.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""APDB connection management and data access tools.

Provides read-only query helpers that load DiaObject, DiaSource and
DiaForcedSource records from an APDB into pandas objects. SQL backends are
implemented for SQLite files and running PostgreSQL instances.
"""

__all__ = ["DbQuery", "ApdbSqliteQuery", "ApdbPostgresQuery"]
27import abc
28import contextlib
29import warnings
31import pandas as pd
32import sqlalchemy
class DbQuery(abc.ABC):
    """Abstract interface for APDB queries.

    Every method here raises `NotImplementedError`; concrete backends
    override them.

    Notes
    -----
    APDB interface used by AP pipeline is defined by `lsst.dax.apdb.Apdb`
    class. Methods in this class are for non-pipeline tools that can analyse
    data produced by pipeline. APDB schema is not designed for analysis queries
    and performance of these methods can be non-optimal, especially for
    Cassandra backend. It is expected that these analysis queries should not be
    executed on production Cassandra service.
    """

    def set_excluded_diaSource_flags(self, flag_list: list[str]) -> None:
        """Set flags of diaSources to exclude when loading diaSources.

        Any diaSources with configured flags are not returned
        when calling `load_sources_for_object` or `load_sources`
        with `exclude_flagged = True`.

        Parameters
        ----------
        flag_list : `list` [`str`]
            Flag names to exclude.
        """
        raise NotImplementedError()

    def load_sources_for_object(
        self, dia_object_id: int, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        """Load diaSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load sources for.
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`, optional
            Maximum number of rows to return.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaSources for the specified diaObject.
        """
        raise NotImplementedError()

    def load_forced_sources_for_object(
        self, dia_object_id: int, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        """Load diaForcedSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load forced sources for.
        exclude_flagged : `bool`, optional
            Exclude forced sources that have selected flags set.
            The flags are the ones configured for diaSources via
            `set_excluded_diaSource_flags`.
        limit : `int`, optional
            Maximum number of rows to return.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaForcedSources for the specified diaObject.
        """
        raise NotImplementedError()

    def load_source(self, id: int) -> pd.Series:
        """Load one diaSource.

        Parameters
        ----------
        id : `int`
            The diaSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested diaSource.
        """
        raise NotImplementedError()

    def load_sources(self, exclude_flagged: bool = False, limit: int = 100000) -> pd.DataFrame:
        """Load diaSources.

        Parameters
        ----------
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int` or `None`, optional
            Maximum number of rows to return; ``None`` returns all rows.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaSources, up to ``limit`` rows.
        """
        raise NotImplementedError()

    def load_object(self, id: int) -> pd.Series:
        """Load the most-recently updated version of one diaObject.

        Parameters
        ----------
        id : `int`
            The diaObjectId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested object.
        """
        raise NotImplementedError()

    def load_objects(self, limit: int = 100000, latest: bool = True) -> pd.DataFrame:
        """Load all diaObjects.

        Parameters
        ----------
        limit : `int` or `None`, optional
            Maximum number of rows to return; ``None`` returns all rows.
        latest : `bool`, optional
            Only load diaObjects where validityEnd is None.
            These are the most-recently updated diaObjects.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaObjects, up to ``limit`` rows.
        """
        raise NotImplementedError()

    def load_forced_source(self, id: int) -> pd.Series:
        """Load one diaForcedSource.

        Parameters
        ----------
        id : `int`
            The diaForcedSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested forced source.
        """
        raise NotImplementedError()

    def load_forced_sources(self, limit: int = 100000) -> pd.DataFrame:
        """Load all diaForcedSources.

        Parameters
        ----------
        limit : `int` or `None`, optional
            Maximum number of rows to return; ``None`` returns all rows.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaForcedSources, up to ``limit`` rows.
        """
        raise NotImplementedError()
class DbSqlQuery(DbQuery):
    """Base class for APDB connection and query management for SQL backends.

    Subclasses must set ``self._tables`` (a mapping from table name to
    `sqlalchemy.schema.Table`) and provide a ``connection`` property usable
    as a context manager *before* calling ``DbSqlQuery.__init__``, because
    the constructor validates the default flags against the DiaSource table
    and reads the instrument name from the ``metadata`` table.

    Parameters
    ----------
    instrument : `str`, optional
        Deprecated and ignored: the instrument name is now read from the
        APDB ``metadata`` table. Passing a value emits a `FutureWarning`;
        the argument will be removed after v29.
    """

    def __init__(self, instrument=None):
        if instrument is not None:
            warnings.warn("The instrument name is now pulled from the APDB; "
                          "this kwarg is ignored and will be removed after v29",
                          FutureWarning,
                          stacklevel=2)

        self.set_excluded_diaSource_flags(['pixelFlags_bad',
                                           'pixelFlags_suspect',
                                           'pixelFlags_saturatedCenter',
                                           'pixelFlags_interpolated',
                                           'pixelFlags_interpolatedCenter',
                                           'pixelFlags_edge',
                                           ])

        # Look up the value stored under the "instrument" key of the APDB
        # metadata table; ``scalar()`` yields None if no such row exists.
        key = "instrument"
        table = self._tables["metadata"]
        sql = sqlalchemy.sql.select(table.columns.value).where(table.columns.name == key)
        with self.connection as conn:
            result = conn.execute(sql)
            self._instrument = result.scalar()

    @property
    @contextlib.contextmanager
    @abc.abstractmethod
    def connection(self):
        """Context manager for database connections.

        Yields
        ------
        connection : `sqlalchemy.engine.Connection`
            Connection to the database that will be queried. Whether the
            connection is closed after the context manager is closed is
            implementation dependent.
        """
        pass

    def set_excluded_diaSource_flags(self, flag_list):
        """Set flags of diaSources to exclude when loading diaSources.

        Any diaSources with configured flags are not returned
        when calling `load_sources_for_object` or `load_sources`
        with `exclude_flagged = True`.

        Parameters
        ----------
        flag_list : `list` [`str`]
            Flag names to exclude.

        Raises
        ------
        ValueError
            Raised if any flag is not a column of the DiaSource table.
        """
        # Copy so that later changes to the caller's list cannot slip
        # unvalidated flag names into the exclusion list.
        flag_list = list(flag_list)
        columns = self._tables["DiaSource"].columns
        for flag in flag_list:
            if flag not in columns:
                raise ValueError(f"flag {flag} not included in DiaSource flags")

        self.diaSource_flags_exclude = flag_list

    def _make_flag_exclusion_query(self, query, table, flag_list):
        """Return a query restricted to rows that have none of the chosen
        flags set.

        Parameters
        ----------
        query : `sqlalchemy.sql.Select`
            Query to add the where clause to.
        table : `sqlalchemy.schema.Table`
            Table containing the flag columns.
        flag_list : `list` [`str`]
            Flag names to exclude; an empty list leaves ``query`` unchanged.

        Returns
        -------
        query : `sqlalchemy.sql.Select`
            Query that keeps only rows where every chosen flag is False.
        """
        if not flag_list:
            # Calling ``sqlalchemy.and_()`` with no clauses is deprecated,
            # and there is nothing to exclude anyway.
            return query
        # Require every chosen flag to be False. Rows with a NULL flag are
        # dropped as well, since ``NULL = false`` is not true in SQL.
        return query.where(sqlalchemy.and_(*(table.columns[flag_col] == False  # noqa: E712
                                             for flag_col in flag_list)))

    def load_sources_for_object(self, dia_object_id, exclude_flagged=False, limit=100000):
        """Load diaSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load sources for.
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int` or `None`, optional
            Maximum number of rows to return; ``None`` returns all rows.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaSources for the specified diaObject, ordered
            by visit, detector and diaSourceId.
        """
        table = self._tables["DiaSource"]
        query = table.select().where(table.columns["diaObjectId"] == dia_object_id)
        if exclude_flagged:
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaSourceId"])
        if limit is not None:
            query = query.limit(limit)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)

        self._fill_from_instrument(result)
        return result

    def load_forced_sources_for_object(self, dia_object_id, exclude_flagged=False, limit=100000):
        """Load diaForcedSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load forced sources for.
        exclude_flagged : `bool`, optional
            Exclude forced sources that have selected flags set.
            The flags are the diaSource flags configured with
            `set_excluded_diaSource_flags`.
        limit : `int` or `None`, optional
            Maximum number of rows to return; ``None`` returns all rows.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaForcedSources for the specified diaObject,
            ordered by visit, detector and diaForcedSourceId.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select().where(table.columns["diaObjectId"] == dia_object_id)
        if exclude_flagged:
            # NOTE(review): the flags are validated against DiaSource columns
            # only; if DiaForcedSource lacks one of them the column lookup
            # fails. Confirm the forced-source schema carries these flags.
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaForcedSourceId"])
        if limit is not None:
            query = query.limit(limit)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)

        self._fill_from_instrument(result)
        return result

    def load_source(self, id):
        """Load one diaSource.

        Parameters
        ----------
        id : `int`
            The diaSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested diaSource.

        Raises
        ------
        RuntimeError
            Raised if no diaSource has this id.
        """
        table = self._tables["DiaSource"]
        query = table.select().where(table.columns["diaSourceId"] == id)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)
        if len(result) == 0:
            raise RuntimeError(f"diaSourceId={id} not found in DiaSource table")

        self._fill_from_instrument(result)
        return result.iloc[0]

    def load_sources(self, exclude_flagged=False, limit=100000):
        """Load diaSources.

        Parameters
        ----------
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int` or `None`, optional
            Maximum number of rows to return; ``None`` returns all rows.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaSources, ordered by visit, detector and
            diaSourceId.
        """
        table = self._tables["DiaSource"]
        query = table.select()
        if exclude_flagged:
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaSourceId"])
        if limit is not None:
            query = query.limit(limit)

        with self.connection as connection:
            result = pd.read_sql_query(query, connection)

        self._fill_from_instrument(result)
        return result

    def load_object(self, id):
        """Load the most-recently updated version of one diaObject.

        Parameters
        ----------
        id : `int`
            The diaObjectId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested object (the row whose validityEnd is NULL).

        Raises
        ------
        RuntimeError
            Raised if no current version of this diaObject exists.
        """
        table = self._tables["DiaObject"]
        query = table.select().where(table.columns["validityEnd"] == None)  # noqa: E711
        query = query.where(table.columns["diaObjectId"] == id)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)
        if len(result) == 0:
            raise RuntimeError(f"diaObjectId={id} not found in DiaObject table")

        return result.iloc[0]

    def load_objects(self, limit=100000, latest=True):
        """Load all diaObjects.

        Parameters
        ----------
        limit : `int` or `None`, optional
            Maximum number of rows to return; ``None`` returns all rows.
        latest : `bool`, optional
            Only load diaObjects where validityEnd is None.
            These are the most-recently updated diaObjects.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaObjects, ordered by diaObjectId.
        """
        table = self._tables["DiaObject"]
        if latest:
            query = table.select().where(table.columns["validityEnd"] == None)  # noqa: E711
        else:
            query = table.select()
        query = query.order_by(table.columns["diaObjectId"])
        if limit is not None:
            query = query.limit(limit)

        with self.connection as connection:
            result = pd.read_sql_query(query, connection)

        return result

    def load_forced_source(self, id):
        """Load one diaForcedSource.

        Parameters
        ----------
        id : `int`
            The diaForcedSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested forced source.

        Raises
        ------
        RuntimeError
            Raised if no diaForcedSource has this id.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select().where(table.columns["diaForcedSourceId"] == id)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)
        if len(result) == 0:
            raise RuntimeError(f"diaForcedSourceId={id} not found in DiaForcedSource table")

        self._fill_from_instrument(result)
        return result.iloc[0]

    def load_forced_sources(self, limit=100000):
        """Load all diaForcedSources.

        Parameters
        ----------
        limit : `int` or `None`, optional
            Maximum number of rows to return; ``None`` returns all rows.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaForcedSources, ordered by visit, detector and
            diaForcedSourceId.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select()
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaForcedSourceId"])
        if limit is not None:
            query = query.limit(limit)

        with self.connection as connection:
            result = pd.read_sql_query(query, connection)
        self._fill_from_instrument(result)
        return result

    def _fill_from_instrument(self, diaSources):
        """Add an ``instrument`` column to a table of sources.

        The value is the instrument name read from the APDB ``metadata``
        table when this object was constructed.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            Pandas dataframe with diaSources from an APDB; modified in-place.
            An empty dataframe is left untouched.
        """
        # do nothing for an empty dataframe
        if len(diaSources) == 0:
            return

        diaSources['instrument'] = self._instrument
class ApdbSqliteQuery(DbSqlQuery):
    """Open an sqlite3 APDB file to load data from it.

    The engine uses a connection pool. Each use of ``connection`` checks a
    connection out of the pool and returns it when the context exits, so the
    underlying sqlite connection stays open between queries. Closing and
    re-opening the file would re-scan the whole file every time, and we don't
    need to worry about multiple users when working with local sqlite files.

    Parameters
    ----------
    filename : `str`
        Path to the sqlite3 file containing the APDB to load.
    instrument : `str`, optional
        Deprecated and ignored: the instrument name is now read from the
        APDB metadata. Passing a value emits a `FutureWarning`.
    **kwargs
        Forwarded to the base-class constructor.
    """

    def __init__(self, filename, instrument=None, **kwargs):
        # For sqlite, use a larger pool and a faster timeout, to allow many
        # repeat transactions with the same connection, as transactions on
        # our sqlite DBs should be small and fast.
        self._engine = sqlalchemy.create_engine(f"sqlite:///{filename}",
                                                pool_timeout=5, pool_size=200)

        with self.connection as connection:
            metadata = sqlalchemy.MetaData()
            metadata.reflect(bind=connection)
            self._tables = metadata.tables
        # Forward ``instrument`` so the base class emits its deprecation
        # warning, as ApdbPostgresQuery does.
        super().__init__(instrument=instrument, **kwargs)

    @property
    @contextlib.contextmanager
    def connection(self):
        """Context manager yielding a pooled database connection.

        The connection is returned to the engine's pool on exit instead of
        being leaked until garbage collection.
        """
        _connection = self._engine.connect()
        try:
            yield _connection
        finally:
            _connection.close()
class ApdbPostgresQuery(DbSqlQuery):
    """Load data from a running PostgreSQL APDB instance.

    No connection is held between queries: the engine is created with
    `sqlalchemy.pool.NullPool`, and a new connection is opened each time the
    ``connection`` context manager is entered and closed when it exits.

    Parameters
    ----------
    namespace : `str`
        Database namespace holding the APDB tables (a "schema" in
        PostgreSQL terminology).
    url : `str`, optional
        Complete connection URL, given without the leading
        ``postgresql://`` scheme.
    instrument : `str`, optional
        Deprecated and ignored: the instrument name is now read from the
        APDB metadata. Passing a value emits a `FutureWarning`.
    **kwargs
        Forwarded to the base-class constructor.
    """

    def __init__(self, namespace, url="rubin@usdf-prompt-processing-dev.slac.stanford.edu/lsst-devl",
                 instrument=None, **kwargs):
        conn_str = f"postgresql://{url}"
        self._connection_string = conn_str
        self._namespace = namespace
        self._engine = sqlalchemy.create_engine(conn_str, poolclass=sqlalchemy.pool.NullPool)

        with self.connection as db:
            reflected = sqlalchemy.MetaData(schema=namespace)
            reflected.reflect(bind=db)
            # Key tables by their bare name rather than the
            # schema-qualified key used by the reflected metadata.
            self._tables = {tbl.name: tbl for tbl in reflected.tables.values()}
        super().__init__(instrument=instrument, **kwargs)

    @property
    @contextlib.contextmanager
    def connection(self):
        """Context manager yielding a fresh connection, closed on exit."""
        with self._engine.connect() as live_connection:
            yield live_connection