Coverage for python/lsst/dax/apdb/apdbSql.py: 13%
453 statements
coverage.py v7.4.0, created at 2024-01-12 10:17 +0000
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""
25from __future__ import annotations
27__all__ = ["ApdbSqlConfig", "ApdbSql"]
29import logging
30from collections.abc import Callable, Iterable, Mapping, MutableMapping
31from contextlib import closing
32from typing import TYPE_CHECKING, Any, cast
34import lsst.daf.base as dafBase
35import numpy as np
36import pandas
37import sqlalchemy
38from felis.simple import Table
39from lsst.pex.config import ChoiceField, Field, ListField
40from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
41from lsst.utils.iteration import chunk_iterable
42from sqlalchemy import func, inspection, sql
43from sqlalchemy.engine import Inspector
44from sqlalchemy.pool import NullPool
46from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
47from .apdbSchema import ApdbTables
48from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
49from .timer import Timer

if TYPE_CHECKING:
    import sqlite3

_LOG = logging.getLogger(__name__)


if pandas.__version__.partition(".")[0] == "1":

    class _ConnectionHackSA2(sqlalchemy.engine.Connectable):
        """Terrible hack to work around incomplete support for SQLAlchemy 2
        in pandas 1.

        We need to pass a Connection instance to pandas methods, but in
        SQLAlchemy 2 the Connection class lost the ``connect`` method that
        pandas uses.
        """

        def __init__(self, connection: sqlalchemy.engine.Connection):
            self._connection = connection

        def connect(self, **kwargs: Any) -> Any:
            return self

        @property
        def execute(self) -> Callable:
            return self._connection.execute

        @property
        def execution_options(self) -> Callable:
            return self._connection.execution_options

        @property
        def connection(self) -> Any:
            return self._connection.connection

        def __enter__(self) -> sqlalchemy.engine.Connection:
            return self._connection

        def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
            # Do not close the connection here.
            pass

    @inspection._inspects(_ConnectionHackSA2)
    def _connection_insp(conn: _ConnectionHackSA2) -> Inspector:
        return Inspector._construct(Inspector._init_connection, conn._connection)

else:
    # Pandas 2.0 supports SQLAlchemy 2 correctly.
    def _ConnectionHackSA2(  # type: ignore[no-redef]
        conn: sqlalchemy.engine.Connectable,
    ) -> sqlalchemy.engine.Connectable:
        return conn


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
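
# A minimal usage sketch for ``_coerce_uint64`` (illustrative only; the
# column name below is made up):
#
#     df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64)})
#     df = _coerce_uint64(df)
#     assert df["id"].dtype == np.int64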


def _make_midpointMjdTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
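
# Worked example (hypothetical numbers): for a visit at MJD 60000.0 and a
# 12-month history window this returns 60000.0 - 12 * 30 = 59640.0, i.e.
# the earliest ``midpointMjdTai`` that searches will consider.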


def _onSqlite3Connect(
    dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
) -> None:
    # Enable foreign keys
    with closing(dbapiConnection.cursor()) as cursor:
        cursor.execute("PRAGMA foreign_keys=ON;")


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before "
            "exiting. Falls back to the SQLAlchemy default if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
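
# A minimal configuration sketch, assuming an in-memory SQLite URL;
# ``pex_config`` fields are assigned as plain attributes:
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///:memory:"
#     config.connection_pool = False  # recommended when the process forks
#     config.validate()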


class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self._keys = list(result.keys())
        self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))

    def column_names(self) -> list[str]:
        return self._keys

    def rows(self) -> Iterable[tuple]:
        return self._rows
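
# Usage sketch (``table_data`` would come from one of the history methods
# of `ApdbSql` below):
#
#     for row in table_data.rows():
#         record = dict(zip(table_data.column_names(), row))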


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using the `ApdbSqlConfig` configuration class. For examples of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):
        config.validate()
        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
        conn_args: dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        if self._engine.dialect.name == "sqlite":
            # Need to enable foreign keys on every new connection.
            sqlalchemy.event.listen(self._engine, "connect", _onSqlite3Connect)

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id
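
    # Instantiation sketch (``config`` as in the `ApdbSqlConfig` example
    # above; the schema must already exist or be created first):
    #
    #     apdb = ApdbSql(config)
    #     apdb.makeSchema(drop=False)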

    def tableRowCount(self) -> dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res
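
    # Usage sketch:
    #
    #     counts = apdb.tableRowCount()
    #     _LOG.info("DiaObject rows: %d", counts["DiaObject"])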

    def tableDef(self, table: ApdbTables) -> Table | None:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects
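
    # Usage sketch, assuming a small circular region built with
    # ``lsst.sphgeom`` (``Circle`` and ``Angle`` are not imported by this
    # module):
    #
    #     from lsst.sphgeom import Angle, Circle
    #
    #     center = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))
    #     objects = apdb.getDiaObjects(Circle(center, Angle.fromDegrees(0.1)))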

    def getDiaSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(
                ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
            )

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def containsVisitDetector(self, visit: int, detector: int) -> bool:
        # docstring is inherited from a base class
        raise NotImplementedError()

    def containsCcdVisit(self, ccdVisitId: int) -> bool:
        """Test whether data for a given visit-detector is present in the APDB.

        This method is a placeholder until `Apdb.containsVisitDetector` can
        be implemented.

        Parameters
        ----------
        ccdVisitId : `int`
            The packed ID of the visit-detector to search for.

        Returns
        -------
        present : `bool`
            `True` if some DiaSource records exist for the specified
            observation, `False` otherwise.
        """
        # TODO: remove this method in favor of containsVisitDetector on either
        # DM-41671 or a ticket that removes ccdVisitId from these tables
        src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource)
        frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource)
        # Query should load only one leaf page of the index
        query1 = sql.select(src_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)
        # Backup query in case an image was processed but had no diaSources
        query2 = sql.select(frcsrc_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)

        with self._engine.begin() as conn:
            result = conn.execute(query1).scalar_one_or_none()
            if result is not None:
                return True
            else:
                result = conn.execute(query2).scalar_one_or_none()
                return result is not None
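
    # Usage sketch (the packed ID below is a made-up value):
    #
    #     if apdb.containsCcdVisit(2023110800123):
    #         _LOG.info("sources already stored for this visit-detector")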

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"], table.columns["insert_time"]).order_by(
            table.columns["insert_time"]
        )
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                ids = []
                for row in result:
                    insert_time = dafBase.DateTime(int(row[1].timestamp() * 1e9))
                    ids.append(ApdbInsertId(id=row[0], insert_time=insert_time))
        return ids

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)
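
    # Usage sketch for insert-id bookkeeping (only meaningful when the APDB
    # was created with ``use_insert_id`` enabled):
    #
    #     insert_ids = apdb.getInsertIds()
    #     if insert_ids:
    #         history = apdb.getDiaSourcesHistory(insert_ids[:10])
    #         apdb.deleteInsertIds(insert_ids[:10])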

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Return catalog of records for given insert identifiers, common
        implementation for all DIA tables.
        """
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            with self._engine.begin() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return ApdbSqlTableData(result)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: dafBase.DateTime,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id(visit_time)
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)
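
    # Usage sketch (the DataFrames are hypothetical catalogs produced by
    # association; ``objects`` must contain the configured ra/dec columns):
    #
    #     visit_time = dafBase.DateTime.now()
    #     apdb.store(visit_time, objects, sources, forced_sources)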

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in a single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(
                    table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema
                )

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)
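
    # Upsert sketch (a hypothetical two-column catalog; real SSObject
    # records carry many more columns):
    #
    #     objects = pandas.DataFrame({"ssObjectId": [1, 2], "arc": [10.0, 20.0]})
    #     apdb.storeSSObjects(objects)  # inserts new rows, updates known ones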

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: list[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
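
    # Usage sketch: reassign two DiaSources (keys) to SSObjects (values):
    #
    #     apdb.reassignDiaSources({12345: 9001, 12346: 9002})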

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids :
            Collection of DiaObject IDs
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Database table to query.
        object_ids :
            Collection of DiaObject IDs
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is returned
            when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: pandas.DataFrame | None = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
    ) -> None:
        dt = visit_time.toPython()

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: dafBase.DateTime,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection for the active transaction.
        """
        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in a single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    _ConnectionHackSA2(connection),
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: list[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection for the active transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in a single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection for the active transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in a single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` of `tuple`
            Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()
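
    # Standalone sketch of the underlying pixelization calls; the returned
    # ranges are half-open, which is why ``_filterRegion`` below subtracts
    # one from the upper bound:
    #
    #     pixelator = HtmPixelization(20)
    #     ranges = pixelator.envelope(region, 64).ranges()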

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df
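
    # Standalone sketch of the per-row index computation (hypothetical
    # coordinates):
    #
    #     uv3d = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))
    #     pixel_id = HtmPixelization(20).index(uv3d)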

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources