Coverage for python/lsst/analysis/ap/apdb.py: 25%
139 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-15 03:02 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-15 03:02 -0700
1# This file is part of analysis_ap.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""APDB connection management and data access tools.
23"""
25__all__ = ["DbQuery", "ApdbSqliteQuery", "ApdbPostgresQuery"]
27import abc
28import contextlib
30import pandas as pd
31import sqlalchemy
class DbQuery(abc.ABC):
    """Base class for APDB connection and query management.

    Subclasses must specify a ``connection`` property to use as a context-
    manager for queries. They must also set ``self._tables`` (a mapping of
    table name to `sqlalchemy.schema.Table`) before calling this class's
    ``__init__``, because the default excluded flags are validated against
    the columns of the ``DiaSource`` table during initialization.

    Parameters
    ----------
    instrument : `str`
        Short name (e.g. "DECam") of instrument to make a dataId unpacker
        and to add to the table columns; supports any gen3 instrument.
        To be deprecated once this information is in the database.

    Raises
    ------
    RuntimeError
        Raised if ``instrument`` is `None` or empty.
    """

    def __init__(self, instrument=None):
        if not instrument:
            raise RuntimeError("Instrument is required until DM-39502, "
                               "when it will be part of the APDB metadata.")
        self._instrument = instrument
        self.set_excluded_diaSource_flags(['pixelFlags_bad',
                                           'pixelFlags_suspect',
                                           'pixelFlags_saturatedCenter',
                                           'pixelFlags_interpolated',
                                           'pixelFlags_interpolatedCenter',
                                           'pixelFlags_edge',
                                           ])

    @property
    @contextlib.contextmanager
    @abc.abstractmethod
    def connection(self):
        """Context manager for database connections.

        Yields
        ------
        connection : `sqlalchemy.engine.Connection`
            Connection to the database that will be queried. Whether the
            connection is closed after the context manager is closed is
            implementation dependent.
        """

    def set_excluded_diaSource_flags(self, flag_list):
        """Set flags of diaSources to exclude when loading diaSources.

        Any diaSources with configured flags are not returned
        when calling `load_sources_for_object` or `load_sources`
        with `exclude_flagged = True`.

        Parameters
        ----------
        flag_list : `list` [`str`]
            Flag names to exclude.

        Raises
        ------
        ValueError
            Raised if any flag is not a column of the ``DiaSource`` table.
        """
        # Snapshot the input first: this keeps later mutation of the caller's
        # list from silently changing the configuration, and lets a one-shot
        # iterable be both validated and stored.
        flags = list(flag_list)
        known_columns = self._tables["DiaSource"].columns
        for flag in flags:
            if flag not in known_columns:
                raise ValueError(f"flag {flag} not included in DiaSource flags")

        self.diaSource_flags_exclude = flags

    def _make_flag_exclusion_query(self, query, table, flag_list):
        """Return an SQL where query that excludes sources with chosen flags.

        Parameters
        ----------
        query : `sqlalchemy.sql.Query`
            Query to include the where statement in.
        table : `sqlalchemy.schema.Table`
            Table containing the column to be queried.
        flag_list : `list` [`str`]
            Flag names to exclude.

        Returns
        -------
        query : `sqlalchemy.sql.Query`
            Query that omits rows having any of the flags set; returned
            unchanged if ``flag_list`` is empty.
        """
        # Nothing to exclude: avoid building an empty OR clause.
        if not flag_list:
            return query

        # Build a clause that matches any source with one or more chosen
        # flags, and keep only the rows that do NOT match it.
        any_flag_set = sqlalchemy.or_(*[table.columns[flag_col] == 1
                                        for flag_col in flag_list])
        return query.where(sqlalchemy.not_(any_flag_set))

    def _read_query(self, query, limit=None, fill_instrument=False):
        """Execute a query and return the rows as a data frame.

        Parameters
        ----------
        query : `sqlalchemy.sql.Query`
            Query to execute.
        limit : `int`, optional
            Maximum number of rows to return; `None` means no limit.
        fill_instrument : `bool`, optional
            Add the ``instrument`` column to the result.

        Returns
        -------
        data : `pandas.DataFrame`
            The rows returned by the query.
        """
        if limit is not None:
            query = query.limit(limit)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)
        if fill_instrument:
            self._fill_from_instrument(result)
        return result

    def load_sources_for_object(self, dia_object_id, exclude_flagged=False, limit=100000):
        """Load diaSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load sources for.
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`
            Maximum number of rows to return; `None` means no limit.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaSources for the specified diaObject.
        """
        table = self._tables["DiaSource"]
        query = table.select().where(table.columns["diaObjectId"] == dia_object_id)
        if exclude_flagged:
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaSourceId"])
        return self._read_query(query, limit=limit, fill_instrument=True)

    def load_forced_sources_for_object(self, dia_object_id, exclude_flagged=False, limit=100000):
        """Load diaForcedSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load sources for.
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`
            Maximum number of rows to return; `None` means no limit.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaForcedSources for the specified diaObject.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select().where(table.columns["diaObjectId"] == dia_object_id)
        if exclude_flagged:
            # NOTE(review): the excluded flags are validated against the
            # DiaSource table only; this lookup fails if DiaForcedSource
            # lacks one of those columns -- confirm against the schema.
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaForcedSourceId"])
        return self._read_query(query, limit=limit, fill_instrument=True)

    def load_source(self, id):
        """Load one diaSource.

        Parameters
        ----------
        id : `int`
            The diaSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested diaSource.

        Raises
        ------
        RuntimeError
            Raised if no row has the requested diaSourceId.
        """
        table = self._tables["DiaSource"]
        query = table.select().where(table.columns["diaSourceId"] == id)
        result = self._read_query(query, fill_instrument=True)
        if len(result) == 0:
            raise RuntimeError(f"diaSourceId={id} not found in DiaSource table")
        return result.iloc[0]

    def load_sources(self, exclude_flagged=False, limit=100000):
        """Load diaSources.

        Parameters
        ----------
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`
            Maximum number of rows to return; `None` means no limit.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaSources.
        """
        table = self._tables["DiaSource"]
        query = table.select()
        if exclude_flagged:
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaSourceId"])
        return self._read_query(query, limit=limit, fill_instrument=True)

    def load_object(self, id):
        """Load the most-recently updated version of one diaObject.

        Parameters
        ----------
        id : `int`
            The diaObjectId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested object.

        Raises
        ------
        RuntimeError
            Raised if no current row has the requested diaObjectId.
        """
        table = self._tables["DiaObject"]
        # Rows with a NULL validityEnd are the current version of an object.
        query = table.select().where(table.columns["validityEnd"].is_(None))
        query = query.where(table.columns["diaObjectId"] == id)
        result = self._read_query(query)
        if len(result) == 0:
            raise RuntimeError(f"diaObjectId={id} not found in DiaObject table")
        return result.iloc[0]

    def load_objects(self, limit=100000, latest=True):
        """Load all diaObjects.

        Parameters
        ----------
        limit : `int`
            Maximum number of rows to return; `None` means no limit.
        latest : `bool`
            Only load diaObjects where validityEnd is None.
            These are the most-recently updated diaObjects.
            If `False`, rows are returned regardless of validityEnd.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaObjects.
        """
        table = self._tables["DiaObject"]
        # Always start from a full select so that ``latest=False`` has a
        # query to work with; the validity filter is strictly optional.
        query = table.select()
        if latest:
            query = query.where(table.columns["validityEnd"].is_(None))
        query = query.order_by(table.columns["diaObjectId"])
        return self._read_query(query, limit=limit)

    def load_forced_source(self, id):
        """Load one diaForcedSource.

        Parameters
        ----------
        id : `int`
            The diaForcedSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested forced source.

        Raises
        ------
        RuntimeError
            Raised if no row has the requested diaForcedSourceId.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select().where(table.columns["diaForcedSourceId"] == id)
        result = self._read_query(query, fill_instrument=True)
        if len(result) == 0:
            raise RuntimeError(f"diaForcedSourceId={id} not found in DiaForcedSource table")
        return result.iloc[0]

    def load_forced_sources(self, limit=100000):
        """Load all diaForcedSources.

        Parameters
        ----------
        limit : `int`
            Maximum number of rows to return; `None` means no limit.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaForcedSources.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select()
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaForcedSourceId"])
        return self._read_query(query, limit=limit, fill_instrument=True)

    def _fill_from_instrument(self, diaSources):
        """Add an ``instrument`` column to a data frame of sources.

        This method is temporary, until APDB has instrument in its metadata.

        Parameters
        ----------
        diaSources : `pandas.core.frame.DataFrame`
            Pandas dataframe with diaSources from an APDB; modified in-place.
            A data frame with no rows is left untouched.
        """
        # do nothing for an empty series
        if len(diaSources) == 0:
            return

        diaSources['instrument'] = self._instrument
class ApdbSqliteQuery(DbQuery):
    """Open an sqlite3 APDB file to load data from it.

    This class creates its engine once at initialization and reuses it for
    every query, because our sqlite usage is to load a local file and we
    don't need to worry about multiple users when working with local sqlite
    files. Each use of the ``connection`` context manager checks a
    connection out of the engine and releases it again on exit.

    Parameters
    ----------
    filename : `str`
        Path to the sqlite3 file containing the APDB to load.
    instrument : `str`
        Short name (e.g. "DECam") of instrument to make a dataId unpacker
        and to add to the table columns; supports any gen3 instrument.
        To be deprecated once this information is in the database.
    **kwargs
        Additional keyword arguments forwarded to `DbQuery`.
    """

    def __init__(self, filename, instrument=None, **kwargs):
        # For sqlite, use a larger pool and a faster timeout, to allow many
        # repeat transactions with the same connection, as transactions on
        # our sqlite DBs should be small and fast.
        self._engine = sqlalchemy.create_engine(f"sqlite:///{filename}",
                                                pool_timeout=5, pool_size=200)

        # The table definitions must be in place before the base class
        # initializer runs, since it validates flags against them.
        with self.connection as connection:
            metadata = sqlalchemy.MetaData()
            metadata.reflect(bind=connection)
        self._tables = metadata.tables
        super().__init__(instrument=instrument, **kwargs)

    @property
    @contextlib.contextmanager
    def connection(self):
        """Context manager yielding a connection from the sqlite engine.

        Yields
        ------
        connection : `sqlalchemy.engine.Connection`
            Connection to the sqlite database; released when the context
            manager exits, including on error.
        """
        connection = self._engine.connect()
        try:
            yield connection
        finally:
            # Release the connection explicitly instead of leaving it
            # checked out until garbage collection.
            # NOTE(review): with a pooling engine this returns the
            # connection to the pool rather than closing the file --
            # confirm the pool class used by the deployed SQLAlchemy.
            connection.close()
class ApdbPostgresQuery(DbQuery):
    """Load data from a running postgres APDB instance.

    A database connection is opened only while the ``connection`` context
    manager is active, and is closed again as soon as the context exits.

    Parameters
    ----------
    namespace : `str`
        Database namespace to load from. Called "schema" in postgres docs.
    url : `str`
        Complete url to connect to postgres database, without prepended
        ``postgresql://``.
    instrument : `str`
        Short name (e.g. "DECam") of instrument to make a dataId unpacker
        and to add to the table columns; supports any gen3 instrument.
        To be deprecated once this information is in the database.
    **kwargs
        Additional keyword arguments forwarded to `DbQuery`.
    """

    def __init__(self, namespace, url="rubin@usdf-prompt-processing-dev.slac.stanford.edu/lsst-devl",
                 instrument=None, **kwargs):
        self._namespace = namespace
        self._connection_string = f"postgresql://{url}"
        self._engine = sqlalchemy.create_engine(
            self._connection_string,
            poolclass=sqlalchemy.pool.NullPool,
        )

        reflected = sqlalchemy.MetaData(schema=namespace)
        with self.connection as db:
            reflected.reflect(bind=db)
        # Key the tables by their bare name, so that lookups do not need
        # the schema prefix.
        self._tables = {tbl.name: tbl for tbl in reflected.tables.values()}
        super().__init__(instrument=instrument, **kwargs)

    @property
    @contextlib.contextmanager
    def connection(self):
        """Context manager yielding a freshly opened postgres connection.

        Yields
        ------
        connection : `sqlalchemy.engine.Connection`
            Connection to the database; closed when the context exits.
        """
        with self._engine.connect() as live:
            yield live