Coverage for python/lsst/dax/apdb/apdbSql.py: 14%
503 statements
coverage.py v7.4.1, created at 2024-02-03 10:52 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""
25from __future__ import annotations
27__all__ = ["ApdbSqlConfig", "ApdbSql"]
29import logging
30from collections.abc import Callable, Iterable, Mapping, MutableMapping
31from contextlib import closing, suppress
32from typing import TYPE_CHECKING, Any, cast
34import lsst.daf.base as dafBase
35import numpy as np
36import pandas
37import sqlalchemy
38from felis.simple import Table
39from lsst.pex.config import ChoiceField, Field, ListField
40from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
41from lsst.utils.iteration import chunk_iterable
42from sqlalchemy import func, inspection, sql
43from sqlalchemy.engine import Inspector
44from sqlalchemy.pool import NullPool
46from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
47from .apdbMetadataSql import ApdbMetadataSql
48from .apdbSchema import ApdbTables
49from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
50from .timer import Timer
51from .versionTuple import IncompatibleVersionError, VersionTuple
53if TYPE_CHECKING:  53 ↛ 54 (line 53 didn't jump to line 54, because the condition on line 53 was never true)
54 import sqlite3
56 from .apdbMetadata import ApdbMetadata
58_LOG = logging.getLogger(__name__)
60VERSION = VersionTuple(0, 1, 0)
61"""Version for the code defined in this module. This needs to be updated
62(following compatibility rules) when schema produced by this code changes.
63"""
65if pandas.__version__.partition(".")[0] == "1":  65 ↛ 67 (line 65 didn't jump to line 67, because the condition on line 65 was never true)
67 class _ConnectionHackSA2(sqlalchemy.engine.Connectable):
68 """Terrible hack to workaround Pandas 1 incomplete support for
69 sqlalchemy 2.
71 We need to pass a Connection instance to pandas method, but in SA 2 the
72 Connection class lost ``connect`` method which is used by Pandas.
73 """
75 def __init__(self, connection: sqlalchemy.engine.Connection):
76 self._connection = connection
78 def connect(self, **kwargs: Any) -> Any:
79 return self
81 @property
82 def execute(self) -> Callable:
83 return self._connection.execute
85 @property
86 def execution_options(self) -> Callable:
87 return self._connection.execution_options
89 @property
90 def connection(self) -> Any:
91 return self._connection.connection
93 def __enter__(self) -> sqlalchemy.engine.Connection:
94 return self._connection
96 def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
97 # Do not close connection here
98 pass
100 @inspection._inspects(_ConnectionHackSA2)
101 def _connection_insp(conn: _ConnectionHackSA2) -> Inspector:
102 return Inspector._construct(Inspector._init_connection, conn._connection)
104else:
105 # Pandas 2.0 supports SQLAlchemy 2 correctly.
106 def _ConnectionHackSA2( # type: ignore[no-redef]
107 conn: sqlalchemy.engine.Connectable,
108 ) -> sqlalchemy.engine.Connectable:
109 return conn
112def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
113 """Change the type of uint64 columns to int64, and return copy of data
114 frame.
115 """
116 names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
117 return df.astype({name: np.int64 for name in names})
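# Illustrative sketch (not part of the original module): a data frame holding
# a uint64 identifier column is converted to int64 by the helper above, e.g.
#
#     df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
#     df = _coerce_uint64(df)  # "diaObjectId" dtype is now int64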
120def _make_midpointMjdTai_start(visit_time: dafBase.DateTime, months: int) -> float:
121 """Calculate starting point for time-based source search.
123 Parameters
124 ----------
125 visit_time : `lsst.daf.base.DateTime`
126 Time of current visit.
127 months : `int`
128 Number of months in the sources history.
130 Returns
131 -------
132 time : `float`
133 A ``midpointMjdTai`` starting point, MJD time.
134 """
135 # TODO: `system` must be consistent with the code in ap_association
136 # (see DM-31996)
137 return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
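# Worked example (illustrative values): for a visit at MJD 60000.0 and a
# 12-month history window the starting point is 60000.0 - 12 * 30 = 59640.0.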
140def _onSqlite3Connect(
141 dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
142) -> None:
143 # Enable foreign keys
144 with closing(dbapiConnection.cursor()) as cursor:
145 cursor.execute("PRAGMA foreign_keys=ON;")
148class ApdbSqlConfig(ApdbConfig):
149 """APDB configuration class for SQL implementation (ApdbSql)."""
151 db_url = Field[str](doc="SQLAlchemy database connection URI")
152 isolation_level = ChoiceField[str](
153 doc=(
154 "Transaction isolation level, if unset then backend-default value "
155 "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
156 "Some backends may not support every allowed value."
157 ),
158 allowed={
159 "READ_COMMITTED": "Read committed",
160 "READ_UNCOMMITTED": "Read uncommitted",
161 "REPEATABLE_READ": "Repeatable read",
162 "SERIALIZABLE": "Serializable",
163 },
164 default=None,
165 optional=True,
166 )
167 connection_pool = Field[bool](
168 doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
169 default=True,
170 )
171 connection_timeout = Field[float](
172 doc=(
173 "Maximum time to wait time for database lock to be released before exiting. "
174 "Defaults to sqlalchemy defaults if not set."
175 ),
176 default=None,
177 optional=True,
178 )
179 sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
180 dia_object_index = ChoiceField[str](
181 doc="Indexing mode for DiaObject table",
182 allowed={
183 "baseline": "Index defined in baseline schema",
184 "pix_id_iov": "(pixelId, objectId, iovStart) PK",
185 "last_object_table": "Separate DiaObjectLast table",
186 },
187 default="baseline",
188 )
189 htm_level = Field[int](doc="HTM indexing level", default=20)
190 htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
191 htm_index_column = Field[str](
192 default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
193 )
194 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
195 dia_object_columns = ListField[str](
196 doc="List of columns to read from DiaObject, by default read all columns", default=[]
197 )
198 prefix = Field[str](doc="Prefix to add to table names and index names", default="")
199 namespace = Field[str](
200 doc=(
201 "Namespace or schema name for all tables in APDB database. "
202 "Presently only works for PostgreSQL backend. "
203 "If schema with this name does not exist it will be created when "
204 "APDB tables are created."
205 ),
206 default=None,
207 optional=True,
208 )
209 timer = Field[bool](doc="If True then print/log timing information", default=False)
211 def validate(self) -> None:
212 super().validate()
213 if len(self.ra_dec_columns) != 2:
214 raise ValueError("ra_dec_columns must have exactly two column names")
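# Minimal configuration sketch (illustrative, not part of the original module;
# the db_url value is an assumption and any fields required by the ApdbConfig
# base class are omitted):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"  # any SQLAlchemy connection URI
#     config.dia_object_index = "last_object_table"
#     apdb = ApdbSql(config)  # ApdbSql.__init__ calls config.validate()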
217class ApdbSqlTableData(ApdbTableData):
218 """Implementation of ApdbTableData that wraps sqlalchemy Result."""
220 def __init__(self, result: sqlalchemy.engine.Result):
221 self._keys = list(result.keys())
222 self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))
224 def column_names(self) -> list[str]:
225 return self._keys
227 def rows(self) -> Iterable[tuple]:
228 return self._rows
231class ApdbSql(Apdb):
232 """Implementation of APDB interface based on SQL database.
234 The implementation is configured via the standard ``pex_config`` mechanism
235 using the `ApdbSqlConfig` configuration class. For examples of different
236 configurations check the ``config/`` folder.
238 Parameters
239 ----------
240 config : `ApdbSqlConfig`
241 Configuration object.
242 """
244 ConfigClass = ApdbSqlConfig
246 metadataSchemaVersionKey = "version:schema"
247 """Name of the metadata key to store schema version number."""
249 metadataCodeVersionKey = "version:ApdbSql"
250 """Name of the metadata key to store code version number."""
252 def __init__(self, config: ApdbSqlConfig):
253 config.validate()
254 self.config = config
256 _LOG.debug("APDB Configuration:")
257 _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
258 _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
259 _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
260 _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
261 _LOG.debug(" schema_file: %s", self.config.schema_file)
262 _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
263 _LOG.debug(" schema prefix: %s", self.config.prefix)
265 # The engine is reused between multiple processes; make sure that we
266 # don't share connections by disabling the pool (using the NullPool class).
267 kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
268 conn_args: dict[str, Any] = dict()
269 if not self.config.connection_pool:
270 kw.update(poolclass=NullPool)
271 if self.config.isolation_level is not None:
272 kw.update(isolation_level=self.config.isolation_level)
273 elif self.config.db_url.startswith("sqlite"): # type: ignore
274 # Use READ_UNCOMMITTED as default value for sqlite.
275 kw.update(isolation_level="READ_UNCOMMITTED")
276 if self.config.connection_timeout is not None:
277 if self.config.db_url.startswith("sqlite"):
278 conn_args.update(timeout=self.config.connection_timeout)
279 elif self.config.db_url.startswith(("postgresql", "mysql")):
280 conn_args.update(connect_timeout=self.config.connection_timeout)
281 kw.update(connect_args=conn_args)
282 self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)
284 if self._engine.dialect.name == "sqlite":
285 # Need to enable foreign keys on every new connection.
286 sqlalchemy.event.listen(self._engine, "connect", _onSqlite3Connect)
288 self._schema = ApdbSqlSchema(
289 engine=self._engine,
290 dia_object_index=self.config.dia_object_index,
291 schema_file=self.config.schema_file,
292 schema_name=self.config.schema_name,
293 prefix=self.config.prefix,
294 namespace=self.config.namespace,
295 htm_index_column=self.config.htm_index_column,
296 use_insert_id=config.use_insert_id,
297 )
299 self._metadata: ApdbMetadataSql | None = None
300 if not self._schema.empty():
301 table: sqlalchemy.schema.Table | None = None
302 with suppress(ValueError):
303 table = self._schema.get_table(ApdbTables.metadata)
304 self._metadata = ApdbMetadataSql(self._engine, table)
305 self._versionCheck(self._metadata)
307 self.pixelator = HtmPixelization(self.config.htm_level)
308 self.use_insert_id = self._schema.has_insert_id
310 def _versionCheck(self, metadata: ApdbMetadataSql) -> None:
311 """Check schema version compatibility."""
313 def _get_version(key: str, default: VersionTuple) -> VersionTuple:
314 """Retrieve version number from given metadata key."""
315 if metadata.table_exists():
316 version_str = metadata.get(key)
317 if version_str is None:
318 # Should not happen with existing metadata table.
319 raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
320 return VersionTuple.fromString(version_str)
321 return default
323 # For old databases where the metadata table does not exist we assume
324 # that the version of both code and schema is 0.1.0.
325 initial_version = VersionTuple(0, 1, 0)
326 db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version)
327 db_code_version = _get_version(self.metadataCodeVersionKey, initial_version)
329 # For now there is no way to make read-only APDB instances, assume that
330 # any access can do updates.
331 if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True):
332 raise IncompatibleVersionError(
333 f"Configured schema version {self._schema.schemaVersion()} "
334 f"is not compatible with database version {db_schema_version}"
335 )
336 if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True):
337 raise IncompatibleVersionError(
338 f"Current code version {self.apdbImplementationVersion()} "
339 f"is not compatible with database version {db_code_version}"
340 )
342 @classmethod
343 def apdbImplementationVersion(cls) -> VersionTuple:
344 # Docstring inherited from base class.
345 return VERSION
347 def apdbSchemaVersion(self) -> VersionTuple:
348 # Docstring inherited from base class.
349 return self._schema.schemaVersion()
351 def tableRowCount(self) -> dict[str, int]:
352 """Return dictionary with the table names and row counts.
354 Used by ``ap_proto`` to keep track of the size of the database tables.
355 Depending on database technology this could be expensive operation.
357 Returns
358 -------
359 row_counts : `dict`
360 Dict where each key is a table name and each value is a row count.
361 """
362 res = {}
363 tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
364 if self.config.dia_object_index == "last_object_table":
365 tables.append(ApdbTables.DiaObjectLast)
366 with self._engine.begin() as conn:
367 for table in tables:
368 sa_table = self._schema.get_table(table)
369 stmt = sql.select(func.count()).select_from(sa_table)
370 count: int = conn.execute(stmt).scalar_one()
371 res[table.name] = count
373 return res
375 def tableDef(self, table: ApdbTables) -> Table | None:
376 # docstring is inherited from a base class
377 return self._schema.tableSchemas.get(table)
379 def makeSchema(self, drop: bool = False) -> None:
380 # docstring is inherited from a base class
381 self._schema.makeSchema(drop=drop)
382 # Need to reset metadata after table was created.
383 table: sqlalchemy.schema.Table | None = None
384 with suppress(ValueError):
385 table = self._schema.get_table(ApdbTables.metadata)
386 self._metadata = ApdbMetadataSql(self._engine, table)
388 if self._metadata.table_exists():
389 # Fill version numbers, but only if they are not defined.
390 if self._metadata.get(self.metadataSchemaVersionKey) is None:
391 self._metadata.set(self.metadataSchemaVersionKey, str(self._schema.schemaVersion()))
392 if self._metadata.get(self.metadataCodeVersionKey) is None:
393 self._metadata.set(self.metadataCodeVersionKey, str(self.apdbImplementationVersion()))
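# First-time setup sketch (illustrative): creating the schema also records the
# schema and code versions in the metadata table when they are not yet defined:
#
#     apdb = ApdbSql(config)
#     apdb.makeSchema(drop=False)
#     # metadata now contains "version:schema" and "version:ApdbSql" entries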
395 def getDiaObjects(self, region: Region) -> pandas.DataFrame:
396 # docstring is inherited from a base class
398 # decide what columns we need
399 if self.config.dia_object_index == "last_object_table":
400 table_enum = ApdbTables.DiaObjectLast
401 else:
402 table_enum = ApdbTables.DiaObject
403 table = self._schema.get_table(table_enum)
404 if not self.config.dia_object_columns:
405 columns = self._schema.get_apdb_columns(table_enum)
406 else:
407 columns = [table.c[col] for col in self.config.dia_object_columns]
408 query = sql.select(*columns)
410 # build selection
411 query = query.where(self._filterRegion(table, region))
413 # select latest version of objects
414 if self.config.dia_object_index != "last_object_table":
415 query = query.where(table.c.validityEnd == None) # noqa: E711
417 # _LOG.debug("query: %s", query)
419 # execute select
420 with Timer("DiaObject select", self.config.timer):
421 with self._engine.begin() as conn:
422 objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
423 _LOG.debug("found %s DiaObjects", len(objects))
424 return objects
426 def getDiaSources(
427 self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
428 ) -> pandas.DataFrame | None:
429 # docstring is inherited from a base class
430 if self.config.read_sources_months == 0:
431 _LOG.debug("Skip DiaSources fetching")
432 return None
434 if object_ids is None:
435 # region-based select
436 return self._getDiaSourcesInRegion(region, visit_time)
437 else:
438 return self._getDiaSourcesByIDs(list(object_ids), visit_time)
440 def getDiaForcedSources(
441 self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
442 ) -> pandas.DataFrame | None:
443 # docstring is inherited from a base class
444 if self.config.read_forced_sources_months == 0:
445 _LOG.debug("Skip DiaForceSources fetching")
446 return None
448 if object_ids is None:
449 # This implementation does not support region-based selection.
450 raise NotImplementedError("Region-based selection is not supported")
452 # TODO: DateTime.MJD must be consistent with code in ap_association,
453 # alternatively we can fill midpointMjdTai ourselves in store()
454 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
455 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)
457 with Timer("DiaForcedSource select", self.config.timer):
458 sources = self._getSourcesByIDs(
459 ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
460 )
462 _LOG.debug("found %s DiaForcedSources", len(sources))
463 return sources
465 def containsVisitDetector(self, visit: int, detector: int) -> bool:
466 # docstring is inherited from a base class
467 raise NotImplementedError()
469 def containsCcdVisit(self, ccdVisitId: int) -> bool:
470 """Test whether data for a given visit-detector is present in the APDB.
472 This method is a placeholder until `Apdb.containsVisitDetector` can
473 be implemented.
475 Parameters
476 ----------
477 ccdVisitId : `int`
478 The packed ID of the visit-detector to search for.
480 Returns
481 -------
482 present : `bool`
483 `True` if some DiaSource records exist for the specified
484 observation, `False` otherwise.
485 """
486 # TODO: remove this method in favor of containsVisitDetector on either
487 # DM-41671 or a ticket that removes ccdVisitId from these tables
488 src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource)
489 frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource)
490 # Query should load only one leaf page of the index
491 query1 = sql.select(src_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)
492 # Backup query in case an image was processed but had no diaSources
493 query2 = sql.select(frcsrc_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)
495 with self._engine.begin() as conn:
496 result = conn.execute(query1).scalar_one_or_none()
497 if result is not None:
498 return True
499 else:
500 result = conn.execute(query2).scalar_one_or_none()
501 return result is not None
503 def getInsertIds(self) -> list[ApdbInsertId] | None:
504 # docstring is inherited from a base class
505 if not self._schema.has_insert_id:
506 return None
508 table = self._schema.get_table(ExtraTables.DiaInsertId)
509 assert table is not None, "has_insert_id=True means it must be defined"
510 query = sql.select(table.columns["insert_id"], table.columns["insert_time"]).order_by(
511 table.columns["insert_time"]
512 )
513 with Timer("DiaObject insert id select", self.config.timer):
514 with self._engine.connect() as conn:
515 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
516 ids = []
517 for row in result:
518 insert_time = dafBase.DateTime(int(row[1].timestamp() * 1e9))
519 ids.append(ApdbInsertId(id=row[0], insert_time=insert_time))
520 return ids
522 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
523 # docstring is inherited from a base class
524 if not self._schema.has_insert_id:
525 raise ValueError("APDB is not configured for history storage")
527 table = self._schema.get_table(ExtraTables.DiaInsertId)
529 insert_ids = [id.id for id in ids]
530 where_clause = table.columns["insert_id"].in_(insert_ids)
531 stmt = table.delete().where(where_clause)
532 with self._engine.begin() as conn:
533 conn.execute(stmt)
535 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
536 # docstring is inherited from a base class
537 return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)
539 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
540 # docstring is inherited from a base class
541 return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)
543 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
544 # docstring is inherited from a base class
545 return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)
547 def _get_history(
548 self,
549 ids: Iterable[ApdbInsertId],
550 table_enum: ApdbTables,
551 history_table_enum: ExtraTables,
552 ) -> ApdbTableData:
553 """Return catalog of records for given insert identifiers, common
554 implementation for all DIA tables.
555 """
556 if not self._schema.has_insert_id:
557 raise ValueError("APDB is not configured for history retrieval")
559 table = self._schema.get_table(table_enum)
560 history_table = self._schema.get_table(history_table_enum)
562 join = table.join(history_table)
563 insert_ids = [id.id for id in ids]
564 history_id_column = history_table.columns["insert_id"]
565 apdb_columns = self._schema.get_apdb_columns(table_enum)
566 where_clause = history_id_column.in_(insert_ids)
567 query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)
569 # execute select
570 with Timer(f"{table.name} history select", self.config.timer):
571 with self._engine.begin() as conn:
572 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
573 return ApdbSqlTableData(result)
575 def getSSObjects(self) -> pandas.DataFrame:
576 # docstring is inherited from a base class
578 columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
579 query = sql.select(*columns)
581 # execute select
582 with Timer("DiaObject select", self.config.timer):
583 with self._engine.begin() as conn:
584 objects = pandas.read_sql_query(query, conn)
585 _LOG.debug("found %s SSObjects", len(objects))
586 return objects
588 def store(
589 self,
590 visit_time: dafBase.DateTime,
591 objects: pandas.DataFrame,
592 sources: pandas.DataFrame | None = None,
593 forced_sources: pandas.DataFrame | None = None,
594 ) -> None:
595 # docstring is inherited from a base class
597 # We want to run all inserts in one transaction.
598 with self._engine.begin() as connection:
599 insert_id: ApdbInsertId | None = None
600 if self._schema.has_insert_id:
601 insert_id = ApdbInsertId.new_insert_id(visit_time)
602 self._storeInsertId(insert_id, visit_time, connection)
604 # fill pixelId column for DiaObjects
605 objects = self._add_obj_htm_index(objects)
606 self._storeDiaObjects(objects, visit_time, insert_id, connection)
608 if sources is not None:
609 # copy pixelId column from DiaObjects to DiaSources
610 sources = self._add_src_htm_index(sources, objects)
611 self._storeDiaSources(sources, insert_id, connection)
613 if forced_sources is not None:
614 self._storeDiaForcedSources(forced_sources, insert_id, connection)
616 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
617 # docstring is inherited from a base class
619 idColumn = "ssObjectId"
620 table = self._schema.get_table(ApdbTables.SSObject)
622 # everything to be done in single transaction
623 with self._engine.begin() as conn:
624 # Find record IDs that already exist. Some types like np.int64 can
625 # cause issues with sqlalchemy, convert them to int.
626 ids = sorted(int(oid) for oid in objects[idColumn])
628 query = sql.select(table.columns[idColumn], table.columns[idColumn].in_(ids))
629 result = conn.execute(query)
630 knownIds = set(row.ssObjectId for row in result)
632 filter = objects[idColumn].isin(knownIds)
633 toUpdate = cast(pandas.DataFrame, objects[filter])
634 toInsert = cast(pandas.DataFrame, objects[~filter])
636 # insert new records
637 if len(toInsert) > 0:
638 toInsert.to_sql(
639 table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema
640 )
642 # update existing records
643 if len(toUpdate) > 0:
644 whereKey = f"{idColumn}_param"
645 update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
646 toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
647 values = toUpdate.to_dict("records")
648 result = conn.execute(update, values)
650 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
651 # docstring is inherited from a base class
653 table = self._schema.get_table(ApdbTables.DiaSource)
654 query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))
656 with self._engine.begin() as conn:
657 # Need to make sure that every ID exists in the database, but
658 # executemany may not support rowcount, so iterate and check what
659 # is missing.
660 missing_ids: list[int] = []
661 for key, value in idMap.items():
662 params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
663 result = conn.execute(query, params)
664 if result.rowcount == 0:
665 missing_ids.append(key)
666 if missing_ids:
667 missing = ",".join(str(item) for item in missing_ids)
668 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
670 def dailyJob(self) -> None:
671 # docstring is inherited from a base class
672 pass
674 def countUnassociatedObjects(self) -> int:
675 # docstring is inherited from a base class
677 # Retrieve the DiaObject table.
678 table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)
680 # Construct the sql statement.
681 stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
682 stmt = stmt.where(table.c.validityEnd == None) # noqa: E711
684 # Return the count.
685 with self._engine.begin() as conn:
686 count = conn.execute(stmt).scalar_one()
688 return count
690 @property
691 def metadata(self) -> ApdbMetadata:
692 # docstring is inherited from a base class
693 if self._metadata is None:
694 raise RuntimeError("Database schema was not initialized.")
695 return self._metadata
697 def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
698 """Return catalog of DiaSource instances from given region.
700 Parameters
701 ----------
702 region : `lsst.sphgeom.Region`
703 Region to search for DIASources.
704 visit_time : `lsst.daf.base.DateTime`
705 Time of the current visit.
707 Returns
708 -------
709 catalog : `pandas.DataFrame`
710 Catalog containing DiaSource records.
711 """
712 # TODO: DateTime.MJD must be consistent with code in ap_association,
713 # alternatively we can fill midpointMjdTai ourselves in store()
714 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
715 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)
717 table = self._schema.get_table(ApdbTables.DiaSource)
718 columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
719 query = sql.select(*columns)
721 # build selection
722 time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
723 where = sql.expression.and_(self._filterRegion(table, region), time_filter)
724 query = query.where(where)
726 # execute select
727 with Timer("DiaSource select", self.config.timer):
728 with self._engine.begin() as conn:
729 sources = pandas.read_sql_query(query, conn)
730 _LOG.debug("found %s DiaSources", len(sources))
731 return sources
733 def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
734 """Return catalog of DiaSource instances given set of DiaObject IDs.
736 Parameters
737 ----------
738 object_ids : `list` [`int`]
739 Collection of DiaObject IDs.
740 visit_time : `lsst.daf.base.DateTime`
741 Time of the current visit.
743 Returns
744 -------
745 catalog : `pandas.DataFrame`
746 Catalog containing DiaSource records.
747 """
748 # TODO: DateTime.MJD must be consistent with code in ap_association,
749 # alternatively we can fill midpointMjdTai ourselves in store()
750 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
751 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)
753 with Timer("DiaSource select", self.config.timer):
754 sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)
756 _LOG.debug("found %s DiaSources", len(sources))
757 return sources
759 def _getSourcesByIDs(
760 self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
761 ) -> pandas.DataFrame:
762 """Return catalog of DiaSource or DiaForcedSource instances given set
763 of DiaObject IDs.
765 Parameters
766 ----------
767 table_enum : `ApdbTables`
768 Enum member identifying the table to query.
769 object_ids : `list` [`int`]
770 Collection of DiaObject IDs.
771 midpointMjdTai_start : `float`
772 Earliest midpointMjdTai to retrieve.
774 Returns
775 -------
776 catalog : `pandas.DataFrame`
777 Catalog containing DiaSource or DiaForcedSource records. An empty
778 catalog is returned when ``object_ids`` is empty or when no matching
779 records exist.
780 """
781 table = self._schema.get_table(table_enum)
782 columns = self._schema.get_apdb_columns(table_enum)
784 sources: pandas.DataFrame | None = None
785 if len(object_ids) <= 0:
786 _LOG.debug("ID list is empty, just fetch empty result")
787 query = sql.select(*columns).where(sql.literal(False))
788 with self._engine.begin() as conn:
789 sources = pandas.read_sql_query(query, conn)
790 else:
791 data_frames: list[pandas.DataFrame] = []
792 for ids in chunk_iterable(sorted(object_ids), 1000):
793 query = sql.select(*columns)
795 # Some types like np.int64 can cause issues with
796 # sqlalchemy, convert them to int.
797 int_ids = [int(oid) for oid in ids]
799 # select by object id
800 query = query.where(
801 sql.expression.and_(
802 table.columns["diaObjectId"].in_(int_ids),
803 table.columns["midpointMjdTai"] > midpointMjdTai_start,
804 )
805 )
807 # execute select
808 with self._engine.begin() as conn:
809 data_frames.append(pandas.read_sql_query(query, conn))
811 if len(data_frames) == 1:
812 sources = data_frames[0]
813 else:
814 sources = pandas.concat(data_frames)
815 assert sources is not None, "Catalog cannot be None"
816 return sources
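# Note (sketch of the behaviour above): object IDs are queried in chunks of
# 1000 via chunk_iterable, so e.g. 2500 IDs produce three SELECTs whose
# results are concatenated into a single data frame.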
818 def _storeInsertId(
819 self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
820 ) -> None:
821 dt = visit_time.toPython()
823 table = self._schema.get_table(ExtraTables.DiaInsertId)
825 stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
826 connection.execute(stmt)
828 def _storeDiaObjects(
829 self,
830 objs: pandas.DataFrame,
831 visit_time: dafBase.DateTime,
832 insert_id: ApdbInsertId | None,
833 connection: sqlalchemy.engine.Connection,
834 ) -> None:
835 """Store catalog of DiaObjects from current visit.
837 Parameters
838 ----------
839 objs : `pandas.DataFrame`
840 Catalog with DiaObject records.
841 visit_time : `lsst.daf.base.DateTime`
842 Time of the visit.
843 insert_id : `ApdbInsertId` or `None`
844 Insert identifier, or `None` if insert IDs are not being recorded.
845 """
846 # Some types like np.int64 can cause issues with sqlalchemy, convert
847 # them to int.
848 ids = sorted(int(oid) for oid in objs["diaObjectId"])
849 _LOG.debug("first object ID: %d", ids[0])
851 # TODO: Need to verify that we are using correct scale here for
852 # DATETIME representation (see DM-31996).
853 dt = visit_time.toPython()
855 # everything to be done in single transaction
856 if self.config.dia_object_index == "last_object_table":
857 # Insert and replace all records in LAST table.
858 table = self._schema.get_table(ApdbTables.DiaObjectLast)
860 # Drop the previous objects (pandas cannot upsert).
861 query = table.delete().where(table.columns["diaObjectId"].in_(ids))
863 with Timer(table.name + " delete", self.config.timer):
864 res = connection.execute(query)
865 _LOG.debug("deleted %s objects", res.rowcount)
867 # DiaObjectLast is a subset of DiaObject, strip missing columns
868 last_column_names = [column.name for column in table.columns]
869 last_objs = objs[last_column_names]
870 last_objs = _coerce_uint64(last_objs)
872 if "lastNonForcedSource" in last_objs.columns:
873 # lastNonForcedSource is defined NOT NULL, fill it with visit
874 # time just in case.
875 last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
876 else:
877 extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
878 last_objs.set_index(extra_column.index, inplace=True)
879 last_objs = pandas.concat([last_objs, extra_column], axis="columns")
881 with Timer("DiaObjectLast insert", self.config.timer):
882 last_objs.to_sql(
883 table.name,
884 _ConnectionHackSA2(connection),
885 if_exists="append",
886 index=False,
887 schema=table.schema,
888 )
889 else:
890 # truncate existing validity intervals
891 table = self._schema.get_table(ApdbTables.DiaObject)
893 update = (
894 table.update()
895 .values(validityEnd=dt)
896 .where(
897 sql.expression.and_(
898 table.columns["diaObjectId"].in_(ids),
899 table.columns["validityEnd"].is_(None),
900 )
901 )
902 )
904 # _LOG.debug("query: %s", query)
906 with Timer(table.name + " truncate", self.config.timer):
907 res = connection.execute(update)
908 _LOG.debug("truncated %s intervals", res.rowcount)
910 objs = _coerce_uint64(objs)
912 # Fill additional columns
913 extra_columns: list[pandas.Series] = []
914 if "validityStart" in objs.columns:
915 objs["validityStart"] = dt
916 else:
917 extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
918 if "validityEnd" in objs.columns:
919 objs["validityEnd"] = None
920 else:
921 extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
922 if "lastNonForcedSource" in objs.columns:
923 # lastNonForcedSource is defined NOT NULL, fill it with visit time
924 # just in case.
925 objs["lastNonForcedSource"].fillna(dt, inplace=True)
926 else:
927 extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
928 if extra_columns:
929 objs.set_index(extra_columns[0].index, inplace=True)
930 objs = pandas.concat([objs] + extra_columns, axis="columns")
932 # Insert history data
933 table = self._schema.get_table(ApdbTables.DiaObject)
934 history_data: list[dict] = []
935 history_stmt: Any = None
936 if insert_id is not None:
937 pk_names = [column.name for column in table.primary_key]
938 history_data = objs[pk_names].to_dict("records")
939 for row in history_data:
940 row["insert_id"] = insert_id.id
941 history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
942 history_stmt = history_table.insert()
944 # insert new versions
945 with Timer("DiaObject insert", self.config.timer):
946 objs.to_sql(
947 table.name,
948 _ConnectionHackSA2(connection),
949 if_exists="append",
950 index=False,
951 schema=table.schema,
952 )
953 if history_stmt is not None:
954 connection.execute(history_stmt, history_data)
956 def _storeDiaSources(
957 self,
958 sources: pandas.DataFrame,
959 insert_id: ApdbInsertId | None,
960 connection: sqlalchemy.engine.Connection,
961 ) -> None:
962 """Store catalog of DiaSources from current visit.
964 Parameters
965 ----------
966 sources : `pandas.DataFrame`
967 Catalog containing DiaSource records
968 """
969 table = self._schema.get_table(ApdbTables.DiaSource)
971 # Insert history data
972 history: list[dict] = []
973 history_stmt: Any = None
974 if insert_id is not None:
975 pk_names = [column.name for column in table.primary_key]
976 history = sources[pk_names].to_dict("records")
977 for row in history:
978 row["insert_id"] = insert_id.id
979 history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
980 history_stmt = history_table.insert()
982 # everything to be done in single transaction
983 with Timer("DiaSource insert", self.config.timer):
984 sources = _coerce_uint64(sources)
985 sources.to_sql(
986 table.name,
987 _ConnectionHackSA2(connection),
988 if_exists="append",
989 index=False,
990 schema=table.schema,
991 )
992 if history_stmt is not None:
993 connection.execute(history_stmt, history)
995 def _storeDiaForcedSources(
996 self,
997 sources: pandas.DataFrame,
998 insert_id: ApdbInsertId | None,
999 connection: sqlalchemy.engine.Connection,
1000 ) -> None:
1001 """Store a set of DiaForcedSources from current visit.
1003 Parameters
1004 ----------
1005 sources : `pandas.DataFrame`
1006 Catalog containing DiaForcedSource records
1007 """
1008 table = self._schema.get_table(ApdbTables.DiaForcedSource)
1010 # Insert history data
1011 history: list[dict] = []
1012 history_stmt: Any = None
1013 if insert_id is not None:
1014 pk_names = [column.name for column in table.primary_key]
1015 history = sources[pk_names].to_dict("records")
1016 for row in history:
1017 row["insert_id"] = insert_id.id
1018 history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
1019 history_stmt = history_table.insert()
1021 # everything to be done in single transaction
1022 with Timer("DiaForcedSource insert", self.config.timer):
1023 sources = _coerce_uint64(sources)
1024 sources.to_sql(
1025 table.name,
1026 _ConnectionHackSA2(connection),
1027 if_exists="append",
1028 index=False,
1029 schema=table.schema,
1030 )
1031 if history_stmt is not None:
1032 connection.execute(history_stmt, history)
1034 def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
1035 """Generate a set of HTM indices covering specified region.
1037 Parameters
1038 ----------
1039 region : `lsst.sphgeom.Region`
1040 Region that needs to be indexed.
1042 Returns
1043 -------
1044 Sequence of ranges; each range is a tuple ``(minHtmID, maxHtmID)``.
1045 """
1046 _LOG.debug("region: %s", region)
1047 indices = self.pixelator.envelope(region, self.config.htm_max_ranges)
1049 return indices.ranges()
1051 def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
1052 """Make SQLAlchemy expression for selecting records in a region."""
1053 htm_index_column = table.columns[self.config.htm_index_column]
1054 exprlist = []
1055 pixel_ranges = self._htm_indices(region)
1056 for low, upper in pixel_ranges:
1057 upper -= 1
1058 if low == upper:
1059 exprlist.append(htm_index_column == low)
1060 else:
1061 exprlist.append(sql.expression.between(htm_index_column, low, upper))
1063 return sql.expression.or_(*exprlist)
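# Illustrative sketch (assumed pixel ranges, not taken from the source): for
# ranges [(512, 515), (600, 601)] the expression above reduces to
#
#     pixelId BETWEEN 512 AND 514 OR pixelId = 600
#
# because each range's upper bound is exclusive and is decremented before use.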
1065 def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
1066 """Calculate HTM index for each record and add it to a DataFrame.
1068 Notes
1069 -----
1070 This overrides any existing column in the DataFrame with the same name
1071 (pixelId). The original DataFrame is not changed; a copy of the
1072 DataFrame is returned.
1073 """
1074 # calculate HTM index for every DiaObject
1075 htm_index = np.zeros(df.shape[0], dtype=np.int64)
1076 ra_col, dec_col = self.config.ra_dec_columns
1077 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
1078 uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
1079 idx = self.pixelator.index(uv3d)
1080 htm_index[i] = idx
1081 df = df.copy()
1082 df[self.config.htm_index_column] = htm_index
1083 return df
1085 def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
1086 """Add pixelId column to DiaSource catalog.
1088 Notes
1089 -----
1090 This method copies the pixelId value from the matching DiaObject
1091 record. The DiaObject catalog needs to have a pixelId column filled by
1092 the ``_add_obj_htm_index`` method, and DiaSource records need to be
1093 associated with DiaObjects via the ``diaObjectId`` column.
1095 This overrides any existing column in the DataFrame with the same name
1096 (pixelId). The original DataFrame is not changed; a copy of the
1097 DataFrame is returned.
1098 """
1099 pixel_id_map: dict[int, int] = {
1100 diaObjectId: pixelId
1101 for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
1102 }
1103 # DiaSources associated with SolarSystemObjects do not have an
1104 # associated DiaObject hence we skip them and set their htmIndex
1105 # value to 0.
1106 pixel_id_map[0] = 0
1107 htm_index = np.zeros(sources.shape[0], dtype=np.int64)
1108 for i, diaObjId in enumerate(sources["diaObjectId"]):
1109 htm_index[i] = pixel_id_map[diaObjId]
1110 sources = sources.copy()
1111 sources[self.config.htm_index_column] = htm_index
1112 return sources
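# Illustrative sketch (assumed values): if ``objs`` has diaObjectId [1, 2] with
# pixelId [100, 200], then sources with diaObjectId [2, 1, 0] are assigned
# pixelId [200, 100, 0]; the 0 entry corresponds to a DiaSource associated only
# with a SolarSystemObject.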