Coverage for python/lsst/dax/apdb/apdbSql.py: 15%
500 statements
coverage.py v7.4.4, created at 2024-03-20 11:36 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""
25from __future__ import annotations
27__all__ = ["ApdbSqlConfig", "ApdbSql"]
29import logging
30from collections.abc import Iterable, Mapping, MutableMapping
31from contextlib import closing, suppress
32from typing import TYPE_CHECKING, Any, cast
34import lsst.daf.base as dafBase
35import numpy as np
36import pandas
37import sqlalchemy
38from felis.simple import Table
39from lsst.pex.config import ChoiceField, Field, ListField
40from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
41from lsst.utils.iteration import chunk_iterable
42from sqlalchemy import func, sql
43from sqlalchemy.pool import NullPool
45from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
46from .apdbConfigFreezer import ApdbConfigFreezer
47from .apdbMetadataSql import ApdbMetadataSql
48from .apdbSchema import ApdbTables
49from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
50from .timer import Timer
51from .versionTuple import IncompatibleVersionError, VersionTuple
53if TYPE_CHECKING:
54 import sqlite3
56 from .apdbMetadata import ApdbMetadata
58_LOG = logging.getLogger(__name__)
60VERSION = VersionTuple(0, 1, 0)
61"""Version for the code defined in this module. This needs to be updated
62(following compatibility rules) when schema produced by this code changes.
63"""
66def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
67 """Change the type of uint64 columns to int64, and return copy of data
68 frame.
69 """
70 names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
71 return df.astype({name: np.int64 for name in names})
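# Minimal sketch of the coercion above (column name and values are
# illustrative only):
#
#     >>> import numpy as np
#     >>> import pandas
#     >>> df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
#     >>> _coerce_uint64(df)["diaObjectId"].dtype
#     dtype('int64')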
74def _make_midpointMjdTai_start(visit_time: dafBase.DateTime, months: int) -> float:
75 """Calculate starting point for time-based source search.
77 Parameters
78 ----------
79 visit_time : `lsst.daf.base.DateTime`
80 Time of current visit.
81 months : `int`
82 Number of months in the sources history.
84 Returns
85 -------
86 time : `float`
87 A ``midpointMjdTai`` starting point, MJD time.
88 """
89 # TODO: `system` must be consistent with the code in ap_association
90 # (see DM-31996)
91 return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
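# Worked example: for a visit at MJD 60000.0 and ``months=12`` this helper
# returns 60000.0 - 12 * 30 = 59640.0, i.e. each month is approximated as
# 30 days.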
94def _onSqlite3Connect(
95 dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
96) -> None:
97 # Enable foreign keys
98 with closing(dbapiConnection.cursor()) as cursor:
99 cursor.execute("PRAGMA foreign_keys=ON;")
102class ApdbSqlConfig(ApdbConfig):
103 """APDB configuration class for SQL implementation (ApdbSql)."""
105 db_url = Field[str](doc="SQLAlchemy database connection URI")
106 isolation_level = ChoiceField[str](
107 doc=(
108 "Transaction isolation level, if unset then backend-default value "
109 "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
110 "Some backends may not support every allowed value."
111 ),
112 allowed={
113 "READ_COMMITTED": "Read committed",
114 "READ_UNCOMMITTED": "Read uncommitted",
115 "REPEATABLE_READ": "Repeatable read",
116 "SERIALIZABLE": "Serializable",
117 },
118 default=None,
119 optional=True,
120 )
121 connection_pool = Field[bool](
122 doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
123 default=True,
124 )
125 connection_timeout = Field[float](
126 doc=(
127 "Maximum time to wait time for database lock to be released before exiting. "
128 "Defaults to sqlalchemy defaults if not set."
129 ),
130 default=None,
131 optional=True,
132 )
133 sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
134 dia_object_index = ChoiceField[str](
135 doc="Indexing mode for DiaObject table",
136 allowed={
137 "baseline": "Index defined in baseline schema",
138 "pix_id_iov": "(pixelId, objectId, iovStart) PK",
139 "last_object_table": "Separate DiaObjectLast table",
140 },
141 default="baseline",
142 )
143 htm_level = Field[int](doc="HTM indexing level", default=20)
144 htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
145 htm_index_column = Field[str](
146 default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
147 )
148 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
149 dia_object_columns = ListField[str](
150 doc="List of columns to read from DiaObject, by default read all columns", default=[]
151 )
152 prefix = Field[str](doc="Prefix to add to table names and index names", default="")
153 namespace = Field[str](
154 doc=(
155 "Namespace or schema name for all tables in APDB database. "
156 "Presently only works for PostgreSQL backend. "
157 "If schema with this name does not exist it will be created when "
158 "APDB tables are created."
159 ),
160 default=None,
161 optional=True,
162 )
163 timer = Field[bool](doc="If True then print/log timing information", default=False)
165 def validate(self) -> None:
166 super().validate()
167 if len(self.ra_dec_columns) != 2:
168 raise ValueError("ra_dec_columns must have exactly two column names")
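# A minimal configuration sketch using the fields defined above (the database
# URL and index mode are illustrative values, not defaults):
#
#     >>> from lsst.dax.apdb.apdbSql import ApdbSqlConfig
#     >>> config = ApdbSqlConfig()
#     >>> config.db_url = "sqlite:///apdb.db"
#     >>> config.dia_object_index = "last_object_table"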
171class ApdbSqlTableData(ApdbTableData):
172 """Implementation of ApdbTableData that wraps sqlalchemy Result."""
174 def __init__(self, result: sqlalchemy.engine.Result):
175 self._keys = list(result.keys())
176 self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))
178 def column_names(self) -> list[str]:
179 return self._keys
181 def rows(self) -> Iterable[tuple]:
182 return self._rows
185class ApdbSql(Apdb):
186 """Implementation of APDB interface based on SQL database.
188 The implementation is configured via standard ``pex_config`` mechanism
189 using `ApdbSqlConfig` configuration class. For an example of different
190 configurations check ``config/`` folder.
192 Parameters
193 ----------
194 config : `ApdbSqlConfig`
195 Configuration object.
196 """
198 ConfigClass = ApdbSqlConfig
200 metadataSchemaVersionKey = "version:schema"
201 """Name of the metadata key to store schema version number."""
203 metadataCodeVersionKey = "version:ApdbSql"
204 """Name of the metadata key to store code version number."""
206 metadataConfigKey = "config:apdb-sql.json"
207 """Name of the metadata key to store code version number."""
209 _frozen_parameters = (
210 "use_insert_id",
211 "dia_object_index",
212 "htm_level",
213 "htm_index_column",
214 "ra_dec_columns",
215 )
216 """Names of the config parameters to be frozen in metadata table."""
218 def __init__(self, config: ApdbSqlConfig):
219 self._engine = self._makeEngine(config)
221 sa_metadata = sqlalchemy.MetaData(schema=config.namespace)
222 meta_table_name = ApdbTables.metadata.table_name(prefix=config.prefix)
223 meta_table: sqlalchemy.schema.Table | None = None
224 with suppress(sqlalchemy.exc.NoSuchTableError):
225 meta_table = sqlalchemy.schema.Table(meta_table_name, sa_metadata, autoload_with=self._engine)
227 self._metadata = ApdbMetadataSql(self._engine, meta_table)
229 # Read frozen config from metadata.
230 config_json = self._metadata.get(self.metadataConfigKey)
231 if config_json is not None:
232 # Update config from metadata.
233 freezer = ApdbConfigFreezer[ApdbSqlConfig](self._frozen_parameters)
234 self.config = freezer.update(config, config_json)
235 else:
236 self.config = config
237 self.config.validate()
239 self._schema = ApdbSqlSchema(
240 engine=self._engine,
241 dia_object_index=self.config.dia_object_index,
242 schema_file=self.config.schema_file,
243 schema_name=self.config.schema_name,
244 prefix=self.config.prefix,
245 namespace=self.config.namespace,
246 htm_index_column=self.config.htm_index_column,
247 use_insert_id=self.config.use_insert_id,
248 )
250 if self._metadata.table_exists():
251 self._versionCheck(self._metadata)
253 self.pixelator = HtmPixelization(self.config.htm_level)
254 self.use_insert_id = self._schema.has_insert_id
256 _LOG.debug("APDB Configuration:")
257 _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
258 _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
259 _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
260 _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
261 _LOG.debug(" schema_file: %s", self.config.schema_file)
262 _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
263 _LOG.debug(" schema prefix: %s", self.config.prefix)
265 @classmethod
266 def _makeEngine(cls, config: ApdbSqlConfig) -> sqlalchemy.engine.Engine:
267 """Make SQLALchemy engine based on configured parameters.
269 Parameters
270 ----------
271 config : `ApdbSqlConfig`
272 Configuration object.
273 """
274 # engine is reused between multiple processes, make sure that we don't
275 # share connections by disabling pool (by using NullPool class)
276 kw: MutableMapping[str, Any] = dict(echo=config.sql_echo)
277 conn_args: dict[str, Any] = dict()
278 if not config.connection_pool:
279 kw.update(poolclass=NullPool)
280 if config.isolation_level is not None:
281 kw.update(isolation_level=config.isolation_level)
282 elif config.db_url.startswith("sqlite"): # type: ignore
283 # Use READ_UNCOMMITTED as default value for sqlite.
284 kw.update(isolation_level="READ_UNCOMMITTED")
285 if config.connection_timeout is not None:
286 if config.db_url.startswith("sqlite"):
287 conn_args.update(timeout=config.connection_timeout)
288 elif config.db_url.startswith(("postgresql", "mysql")):
289 conn_args.update(connect_timeout=config.connection_timeout)
290 kw.update(connect_args=conn_args)
291 engine = sqlalchemy.create_engine(config.db_url, **kw)
293 if engine.dialect.name == "sqlite":
294 # Need to enable foreign keys on every new connection.
295 sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect)
297 return engine
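    # For a SQLite URL with the connection pool disabled and a connection
    # timeout configured, the call above reduces to roughly this sketch:
    #
    #     >>> engine = sqlalchemy.create_engine(
    #     ...     "sqlite:///apdb.db",
    #     ...     echo=False,
    #     ...     poolclass=NullPool,
    #     ...     isolation_level="READ_UNCOMMITTED",
    #     ...     connect_args={"timeout": 60.0},
    #     ... )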
299 def _versionCheck(self, metadata: ApdbMetadataSql) -> None:
300 """Check schema version compatibility."""
302 def _get_version(key: str, default: VersionTuple) -> VersionTuple:
303 """Retrieve version number from given metadata key."""
304 if metadata.table_exists():
305 version_str = metadata.get(key)
306 if version_str is None:
307 # Should not happen with existing metadata table.
308 raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
309 return VersionTuple.fromString(version_str)
310 return default
312 # For old databases where metadata table does not exist we assume that
313 # version of both code and schema is 0.1.0.
314 initial_version = VersionTuple(0, 1, 0)
315 db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version)
316 db_code_version = _get_version(self.metadataCodeVersionKey, initial_version)
318 # For now there is no way to make read-only APDB instances, assume that
319 # any access can do updates.
320 if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True):
321 raise IncompatibleVersionError(
322 f"Configured schema version {self._schema.schemaVersion()} "
323 f"is not compatible with database version {db_schema_version}"
324 )
325 if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True):
326 raise IncompatibleVersionError(
327 f"Current code version {self.apdbImplementationVersion()} "
328 f"is not compatible with database version {db_code_version}"
329 )
331 @classmethod
332 def apdbImplementationVersion(cls) -> VersionTuple:
333 # Docstring inherited from base class.
334 return VERSION
336 def apdbSchemaVersion(self) -> VersionTuple:
337 # Docstring inherited from base class.
338 return self._schema.schemaVersion()
340 def tableRowCount(self) -> dict[str, int]:
341 """Return dictionary with the table names and row counts.
343 Used by ``ap_proto`` to keep track of the size of the database tables.
344 Depending on database technology this could be an expensive operation.
346 Returns
347 -------
348 row_counts : `dict`
349 Dict where key is a table name and value is a row count.
350 """
351 res = {}
352 tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
353 if self.config.dia_object_index == "last_object_table":
354 tables.append(ApdbTables.DiaObjectLast)
355 with self._engine.begin() as conn:
356 for table in tables:
357 sa_table = self._schema.get_table(table)
358 stmt = sql.select(func.count()).select_from(sa_table)
359 count: int = conn.execute(stmt).scalar_one()
360 res[table.name] = count
362 return res
364 def tableDef(self, table: ApdbTables) -> Table | None:
365 # docstring is inherited from a base class
366 return self._schema.tableSchemas.get(table)
368 @classmethod
369 def makeSchema(cls, config: ApdbConfig, drop: bool = False) -> None:
370 # docstring is inherited from a base class
372 if not isinstance(config, ApdbSqlConfig):
373 raise TypeError(f"Unexpected type of configuration object: {type(config)}")
375 engine = cls._makeEngine(config)
377 # Ask schema class to create all tables.
378 schema = ApdbSqlSchema(
379 engine=engine,
380 dia_object_index=config.dia_object_index,
381 schema_file=config.schema_file,
382 schema_name=config.schema_name,
383 prefix=config.prefix,
384 namespace=config.namespace,
385 htm_index_column=config.htm_index_column,
386 use_insert_id=config.use_insert_id,
387 )
388 schema.makeSchema(drop=drop)
390 # Need metadata table to store a few items in it, if table exists.
391 meta_table: sqlalchemy.schema.Table | None = None
392 with suppress(ValueError):
393 meta_table = schema.get_table(ApdbTables.metadata)
395 apdb_meta = ApdbMetadataSql(engine, meta_table)
396 if apdb_meta.table_exists():
397 # Fill version numbers, overwrite if they are already there.
398 apdb_meta.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True)
399 apdb_meta.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True)
401 # Store frozen part of a configuration in metadata.
402 freezer = ApdbConfigFreezer[ApdbSqlConfig](cls._frozen_parameters)
403 apdb_meta.set(cls.metadataConfigKey, freezer.to_json(config), force=True)
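    # Typical bootstrap sequence (sketch; assumes ``config`` is a valid
    # ApdbSqlConfig with ``db_url`` set):
    #
    #     >>> ApdbSql.makeSchema(config)   # create tables, store versions/config
    #     >>> apdb = ApdbSql(config)       # connect and re-read frozen config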
405 def getDiaObjects(self, region: Region) -> pandas.DataFrame:
406 # docstring is inherited from a base class
408 # decide what columns we need
409 if self.config.dia_object_index == "last_object_table":
410 table_enum = ApdbTables.DiaObjectLast
411 else:
412 table_enum = ApdbTables.DiaObject
413 table = self._schema.get_table(table_enum)
414 if not self.config.dia_object_columns:
415 columns = self._schema.get_apdb_columns(table_enum)
416 else:
417 columns = [table.c[col] for col in self.config.dia_object_columns]
418 query = sql.select(*columns)
420 # build selection
421 query = query.where(self._filterRegion(table, region))
423 # select latest version of objects
424 if self.config.dia_object_index != "last_object_table":
425 query = query.where(table.c.validityEnd == None) # noqa: E711
427 # _LOG.debug("query: %s", query)
429 # execute select
430 with Timer("DiaObject select", self.config.timer):
431 with self._engine.begin() as conn:
432 objects = pandas.read_sql_query(query, conn)
433 _LOG.debug("found %s DiaObjects", len(objects))
434 return objects
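    # Usage sketch: the region is any ``lsst.sphgeom.Region``, for example a
    # small circle (``apdb`` as in the earlier sketch; coordinates and radius
    # are arbitrary):
    #
    #     >>> from lsst.sphgeom import Angle, Circle, LonLat, UnitVector3d
    #     >>> center = UnitVector3d(LonLat.fromDegrees(10.0, -5.0))
    #     >>> region = Circle(center, Angle.fromDegrees(0.1))
    #     >>> objects = apdb.getDiaObjects(region)   # pandas.DataFrame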
436 def getDiaSources(
437 self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
438 ) -> pandas.DataFrame | None:
439 # docstring is inherited from a base class
440 if self.config.read_sources_months == 0:
441 _LOG.debug("Skip DiaSources fetching")
442 return None
444 if object_ids is None:
445 # region-based select
446 return self._getDiaSourcesInRegion(region, visit_time)
447 else:
448 return self._getDiaSourcesByIDs(list(object_ids), visit_time)
450 def getDiaForcedSources(
451 self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
452 ) -> pandas.DataFrame | None:
453 # docstring is inherited from a base class
454 if self.config.read_forced_sources_months == 0:
455 _LOG.debug("Skip DiaForceSources fetching")
456 return None
458 if object_ids is None:
459 # This implementation does not support region-based selection.
460 raise NotImplementedError("Region-based selection is not supported")
462 # TODO: DateTime.MJD must be consistent with code in ap_association,
463 # alternatively we can fill midpointMjdTai ourselves in store()
464 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
465 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)
467 with Timer("DiaForcedSource select", self.config.timer):
468 sources = self._getSourcesByIDs(
469 ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
470 )
472 _LOG.debug("found %s DiaForcedSources", len(sources))
473 return sources
475 def containsVisitDetector(self, visit: int, detector: int) -> bool:
476 # docstring is inherited from a base class
477 raise NotImplementedError()
479 def containsCcdVisit(self, ccdVisitId: int) -> bool:
480 """Test whether data for a given visit-detector is present in the APDB.
482 This method is a placeholder until `Apdb.containsVisitDetector` can
483 be implemented.
485 Parameters
486 ----------
487 ccdVisitId : `int`
488 The packed ID of the visit-detector to search for.
490 Returns
491 -------
492 present : `bool`
493 `True` if some DiaSource records exist for the specified
494 observation, `False` otherwise.
495 """
496 # TODO: remove this method in favor of containsVisitDetector on either
497 # DM-41671 or a ticket that removes ccdVisitId from these tables
498 src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource)
499 frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource)
500 # Query should load only one leaf page of the index
501 query1 = sql.select(src_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)
502 # Backup query in case an image was processed but had no diaSources
503 query2 = sql.select(frcsrc_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)
505 with self._engine.begin() as conn:
506 result = conn.execute(query1).scalar_one_or_none()
507 if result is not None:
508 return True
509 else:
510 result = conn.execute(query2).scalar_one_or_none()
511 return result is not None
513 def getInsertIds(self) -> list[ApdbInsertId] | None:
514 # docstring is inherited from a base class
515 if not self._schema.has_insert_id:
516 return None
518 table = self._schema.get_table(ExtraTables.DiaInsertId)
519 assert table is not None, "has_insert_id=True means it must be defined"
520 query = sql.select(table.columns["insert_id"], table.columns["insert_time"]).order_by(
521 table.columns["insert_time"]
522 )
523 with Timer("DiaObject insert id select", self.config.timer):
524 with self._engine.connect() as conn:
525 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
526 ids = []
527 for row in result:
528 insert_time = dafBase.DateTime(int(row[1].timestamp() * 1e9))
529 ids.append(ApdbInsertId(id=row[0], insert_time=insert_time))
530 return ids
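    # History retrieval sketch (only meaningful when the schema was created
    # with ``use_insert_id`` enabled; ``apdb`` is an ApdbSql instance):
    #
    #     >>> insert_ids = apdb.getInsertIds()
    #     >>> if insert_ids:
    #     ...     history = apdb.getDiaSourcesHistory(insert_ids[-1:])
    #     ...     rows = list(history.rows())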
532 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
533 # docstring is inherited from a base class
534 if not self._schema.has_insert_id:
535 raise ValueError("APDB is not configured for history storage")
537 table = self._schema.get_table(ExtraTables.DiaInsertId)
539 insert_ids = [id.id for id in ids]
540 where_clause = table.columns["insert_id"].in_(insert_ids)
541 stmt = table.delete().where(where_clause)
542 with self._engine.begin() as conn:
543 conn.execute(stmt)
545 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
546 # docstring is inherited from a base class
547 return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)
549 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
550 # docstring is inherited from a base class
551 return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)
553 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
554 # docstring is inherited from a base class
555 return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)
557 def _get_history(
558 self,
559 ids: Iterable[ApdbInsertId],
560 table_enum: ApdbTables,
561 history_table_enum: ExtraTables,
562 ) -> ApdbTableData:
563 """Return catalog of records for given insert identifiers, common
564 implementation for all DIA tables.
565 """
566 if not self._schema.has_insert_id:
567 raise ValueError("APDB is not configured for history retrieval")
569 table = self._schema.get_table(table_enum)
570 history_table = self._schema.get_table(history_table_enum)
572 join = table.join(history_table)
573 insert_ids = [id.id for id in ids]
574 history_id_column = history_table.columns["insert_id"]
575 apdb_columns = self._schema.get_apdb_columns(table_enum)
576 where_clause = history_id_column.in_(insert_ids)
577 query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)
579 # execute select
580 with Timer(f"{table.name} history select", self.config.timer):
581 with self._engine.begin() as conn:
582 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
583 return ApdbSqlTableData(result)
585 def getSSObjects(self) -> pandas.DataFrame:
586 # docstring is inherited from a base class
588 columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
589 query = sql.select(*columns)
591 # execute select
592 with Timer("DiaObject select", self.config.timer):
593 with self._engine.begin() as conn:
594 objects = pandas.read_sql_query(query, conn)
595 _LOG.debug("found %s SSObjects", len(objects))
596 return objects
598 def store(
599 self,
600 visit_time: dafBase.DateTime,
601 objects: pandas.DataFrame,
602 sources: pandas.DataFrame | None = None,
603 forced_sources: pandas.DataFrame | None = None,
604 ) -> None:
605 # docstring is inherited from a base class
607 # We want to run all inserts in one transaction.
608 with self._engine.begin() as connection:
609 insert_id: ApdbInsertId | None = None
610 if self._schema.has_insert_id:
611 insert_id = ApdbInsertId.new_insert_id(visit_time)
612 self._storeInsertId(insert_id, visit_time, connection)
614 # fill pixelId column for DiaObjects
615 objects = self._add_obj_htm_index(objects)
616 self._storeDiaObjects(objects, visit_time, insert_id, connection)
618 if sources is not None:
619 # copy pixelId column from DiaObjects to DiaSources
620 sources = self._add_src_htm_index(sources, objects)
621 self._storeDiaSources(sources, insert_id, connection)
623 if forced_sources is not None:
624 self._storeDiaForcedSources(forced_sources, insert_id, connection)
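    # Store sketch: ``visit_time`` is an ``lsst.daf.base.DateTime`` and the
    # DataFrames follow the APDB schema column names (``objects``, ``sources``
    # and ``forced_sources`` are hypothetical variables):
    #
    #     >>> visit_time = dafBase.DateTime.now()
    #     >>> apdb.store(visit_time, objects, sources, forced_sources)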
626 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
627 # docstring is inherited from a base class
629 idColumn = "ssObjectId"
630 table = self._schema.get_table(ApdbTables.SSObject)
632 # everything to be done in single transaction
633 with self._engine.begin() as conn:
634 # Find record IDs that already exist. Some types like np.int64 can
635 # cause issues with sqlalchemy, convert them to int.
636 ids = sorted(int(oid) for oid in objects[idColumn])
638 query = sql.select(table.columns[idColumn], table.columns[idColumn].in_(ids))
639 result = conn.execute(query)
640 knownIds = set(row.ssObjectId for row in result)
642 filter = objects[idColumn].isin(knownIds)
643 toUpdate = cast(pandas.DataFrame, objects[filter])
644 toInsert = cast(pandas.DataFrame, objects[~filter])
646 # insert new records
647 if len(toInsert) > 0:
648 toInsert.to_sql(table.name, conn, if_exists="append", index=False, schema=table.schema)
650 # update existing records
651 if len(toUpdate) > 0:
652 whereKey = f"{idColumn}_param"
653 update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
654 toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
655 values = toUpdate.to_dict("records")
656 result = conn.execute(update, values)
658 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
659 # docstring is inherited from a base class
661 table = self._schema.get_table(ApdbTables.DiaSource)
662 query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))
664 with self._engine.begin() as conn:
665 # Need to make sure that every ID exists in the database, but
666 # executemany may not support rowcount, so iterate and check what
667 # is missing.
668 missing_ids: list[int] = []
669 for key, value in idMap.items():
670 params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
671 result = conn.execute(query, params)
672 if result.rowcount == 0:
673 missing_ids.append(key)
674 if missing_ids:
675 missing = ",".join(str(item) for item in missing_ids)
676 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
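    # Reassignment sketch: keys are DiaSource IDs, values are SSObject IDs
    # (illustrative numbers); matched sources get ``diaObjectId`` reset to 0:
    #
    #     >>> apdb.reassignDiaSources({123456789: 987654321})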
678 def dailyJob(self) -> None:
679 # docstring is inherited from a base class
680 pass
682 def countUnassociatedObjects(self) -> int:
683 # docstring is inherited from a base class
685 # Retrieve the DiaObject table.
686 table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)
688 # Construct the sql statement.
689 stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
690 stmt = stmt.where(table.c.validityEnd == None) # noqa: E711
692 # Return the count.
693 with self._engine.begin() as conn:
694 count = conn.execute(stmt).scalar_one()
696 return count
698 @property
699 def metadata(self) -> ApdbMetadata:
700 # docstring is inherited from a base class
701 if self._metadata is None:
702 raise RuntimeError("Database schema was not initialized.")
703 return self._metadata
705 def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
706 """Return catalog of DiaSource instances from given region.
708 Parameters
709 ----------
710 region : `lsst.sphgeom.Region`
711 Region to search for DIASources.
712 visit_time : `lsst.daf.base.DateTime`
713 Time of the current visit.
715 Returns
716 -------
717 catalog : `pandas.DataFrame`
718 Catalog containing DiaSource records.
719 """
720 # TODO: DateTime.MJD must be consistent with code in ap_association,
721 # alternatively we can fill midpointMjdTai ourselves in store()
722 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
723 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)
725 table = self._schema.get_table(ApdbTables.DiaSource)
726 columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
727 query = sql.select(*columns)
729 # build selection
730 time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
731 where = sql.expression.and_(self._filterRegion(table, region), time_filter)
732 query = query.where(where)
734 # execute select
735 with Timer("DiaSource select", self.config.timer):
736 with self._engine.begin() as conn:
737 sources = pandas.read_sql_query(query, conn)
738 _LOG.debug("found %s DiaSources", len(sources))
739 return sources
741 def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
742 """Return catalog of DiaSource instances given set of DiaObject IDs.
744 Parameters
745 ----------
746 object_ids :
747 Collection of DiaObject IDs
748 visit_time : `lsst.daf.base.DateTime`
749 Time of the current visit.
751 Returns
752 -------
753 catalog : `pandas.DataFrame`
754 Catalog containing DiaSource records.
755 """
756 # TODO: DateTime.MJD must be consistent with code in ap_association,
757 # alternatively we can fill midpointMjdTai ourselves in store()
758 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
759 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)
761 with Timer("DiaSource select", self.config.timer):
762 sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)
764 _LOG.debug("found %s DiaSources", len(sources))
765 return sources
767 def _getSourcesByIDs(
768 self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
769 ) -> pandas.DataFrame:
770 """Return catalog of DiaSource or DiaForcedSource instances given set
771 of DiaObject IDs.
773 Parameters
774 ----------
775 table_enum : `ApdbTables`
776 Enum member identifying the database table to query.
777 object_ids :
778 Collection of DiaObject IDs
779 midpointMjdTai_start : `float`
780 Earliest midpointMjdTai to retrieve.
782 Returns
783 -------
784 catalog : `pandas.DataFrame`
785 Catalog containing DiaSource records. An empty catalog is returned
786 when ``object_ids`` is empty; the ``read_sources_months`` check is
787 performed by the calling methods.
788 """
789 table = self._schema.get_table(table_enum)
790 columns = self._schema.get_apdb_columns(table_enum)
792 sources: pandas.DataFrame | None = None
793 if len(object_ids) <= 0:
794 _LOG.debug("ID list is empty, just fetch empty result")
795 query = sql.select(*columns).where(sql.literal(False))
796 with self._engine.begin() as conn:
797 sources = pandas.read_sql_query(query, conn)
798 else:
799 data_frames: list[pandas.DataFrame] = []
800 for ids in chunk_iterable(sorted(object_ids), 1000):
801 query = sql.select(*columns)
803 # Some types like np.int64 can cause issues with
804 # sqlalchemy, convert them to int.
805 int_ids = [int(oid) for oid in ids]
807 # select by object id
808 query = query.where(
809 sql.expression.and_(
810 table.columns["diaObjectId"].in_(int_ids),
811 table.columns["midpointMjdTai"] > midpointMjdTai_start,
812 )
813 )
815 # execute select
816 with self._engine.begin() as conn:
817 data_frames.append(pandas.read_sql_query(query, conn))
819 if len(data_frames) == 1:
820 sources = data_frames[0]
821 else:
822 sources = pandas.concat(data_frames)
823 assert sources is not None, "Catalog cannot be None"
824 return sources
826 def _storeInsertId(
827 self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
828 ) -> None:
829 dt = visit_time.toPython()
831 table = self._schema.get_table(ExtraTables.DiaInsertId)
833 stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
834 connection.execute(stmt)
836 def _storeDiaObjects(
837 self,
838 objs: pandas.DataFrame,
839 visit_time: dafBase.DateTime,
840 insert_id: ApdbInsertId | None,
841 connection: sqlalchemy.engine.Connection,
842 ) -> None:
843 """Store catalog of DiaObjects from current visit.
845 Parameters
846 ----------
847 objs : `pandas.DataFrame`
848 Catalog with DiaObject records.
849 visit_time : `lsst.daf.base.DateTime`
850 Time of the visit.
851 insert_id : `ApdbInsertId`
852 Insert identifier.
853 """
854 if len(objs) == 0:
855 _LOG.debug("No objects to write to database.")
856 return
858 # Some types like np.int64 can cause issues with sqlalchemy, convert
859 # them to int.
860 ids = sorted(int(oid) for oid in objs["diaObjectId"])
861 _LOG.debug("first object ID: %d", ids[0])
863 # TODO: Need to verify that we are using correct scale here for
864 # DATETIME representation (see DM-31996).
865 dt = visit_time.toPython()
867 # everything to be done in single transaction
868 if self.config.dia_object_index == "last_object_table":
869 # Insert and replace all records in LAST table.
870 table = self._schema.get_table(ApdbTables.DiaObjectLast)
872 # Drop the previous objects (pandas cannot upsert).
873 query = table.delete().where(table.columns["diaObjectId"].in_(ids))
875 with Timer(table.name + " delete", self.config.timer):
876 res = connection.execute(query)
877 _LOG.debug("deleted %s objects", res.rowcount)
879 # DiaObjectLast is a subset of DiaObject, strip missing columns
880 last_column_names = [column.name for column in table.columns]
881 last_objs = objs[last_column_names]
882 last_objs = _coerce_uint64(last_objs)
884 if "lastNonForcedSource" in last_objs.columns:
885 # lastNonForcedSource is defined NOT NULL, fill it with visit
886 # time just in case.
887 last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
888 else:
889 extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
890 last_objs.set_index(extra_column.index, inplace=True)
891 last_objs = pandas.concat([last_objs, extra_column], axis="columns")
893 with Timer("DiaObjectLast insert", self.config.timer):
894 last_objs.to_sql(
895 table.name,
896 connection,
897 if_exists="append",
898 index=False,
899 schema=table.schema,
900 )
901 else:
902 # truncate existing validity intervals
903 table = self._schema.get_table(ApdbTables.DiaObject)
905 update = (
906 table.update()
907 .values(validityEnd=dt)
908 .where(
909 sql.expression.and_(
910 table.columns["diaObjectId"].in_(ids),
911 table.columns["validityEnd"].is_(None),
912 )
913 )
914 )
916 # _LOG.debug("query: %s", query)
918 with Timer(table.name + " truncate", self.config.timer):
919 res = connection.execute(update)
920 _LOG.debug("truncated %s intervals", res.rowcount)
922 objs = _coerce_uint64(objs)
924 # Fill additional columns
925 extra_columns: list[pandas.Series] = []
926 if "validityStart" in objs.columns:
927 objs["validityStart"] = dt
928 else:
929 extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
930 if "validityEnd" in objs.columns:
931 objs["validityEnd"] = None
932 else:
933 extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
934 if "lastNonForcedSource" in objs.columns:
935 # lastNonForcedSource is defined NOT NULL, fill it with visit time
936 # just in case.
937 objs["lastNonForcedSource"].fillna(dt, inplace=True)
938 else:
939 extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
940 if extra_columns:
941 objs.set_index(extra_columns[0].index, inplace=True)
942 objs = pandas.concat([objs] + extra_columns, axis="columns")
944 # Insert history data
945 table = self._schema.get_table(ApdbTables.DiaObject)
946 history_data: list[dict] = []
947 history_stmt: Any = None
948 if insert_id is not None:
949 pk_names = [column.name for column in table.primary_key]
950 history_data = objs[pk_names].to_dict("records")
951 for row in history_data:
952 row["insert_id"] = insert_id.id
953 history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
954 history_stmt = history_table.insert()
956 # insert new versions
957 with Timer("DiaObject insert", self.config.timer):
958 objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
959 if history_stmt is not None:
960 connection.execute(history_stmt, history_data)
962 def _storeDiaSources(
963 self,
964 sources: pandas.DataFrame,
965 insert_id: ApdbInsertId | None,
966 connection: sqlalchemy.engine.Connection,
967 ) -> None:
968 """Store catalog of DiaSources from current visit.
970 Parameters
971 ----------
972 sources : `pandas.DataFrame`
973 Catalog containing DiaSource records
974 """
975 table = self._schema.get_table(ApdbTables.DiaSource)
977 # Insert history data
978 history: list[dict] = []
979 history_stmt: Any = None
980 if insert_id is not None:
981 pk_names = [column.name for column in table.primary_key]
982 history = sources[pk_names].to_dict("records")
983 for row in history:
984 row["insert_id"] = insert_id.id
985 history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
986 history_stmt = history_table.insert()
988 # everything to be done in single transaction
989 with Timer("DiaSource insert", self.config.timer):
990 sources = _coerce_uint64(sources)
991 sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
992 if history_stmt is not None:
993 connection.execute(history_stmt, history)
995 def _storeDiaForcedSources(
996 self,
997 sources: pandas.DataFrame,
998 insert_id: ApdbInsertId | None,
999 connection: sqlalchemy.engine.Connection,
1000 ) -> None:
1001 """Store a set of DiaForcedSources from current visit.
1003 Parameters
1004 ----------
1005 sources : `pandas.DataFrame`
1006 Catalog containing DiaForcedSource records
1007 """
1008 table = self._schema.get_table(ApdbTables.DiaForcedSource)
1010 # Insert history data
1011 history: list[dict] = []
1012 history_stmt: Any = None
1013 if insert_id is not None:
1014 pk_names = [column.name for column in table.primary_key]
1015 history = sources[pk_names].to_dict("records")
1016 for row in history:
1017 row["insert_id"] = insert_id.id
1018 history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
1019 history_stmt = history_table.insert()
1021 # everything to be done in single transaction
1022 with Timer("DiaForcedSource insert", self.config.timer):
1023 sources = _coerce_uint64(sources)
1024 sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
1025 if history_stmt is not None:
1026 connection.execute(history_stmt, history)
1028 def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
1029 """Generate a set of HTM indices covering specified region.
1031 Parameters
1032 ----------
1033 region : `lsst.sphgeom.Region`
1034 Region that needs to be indexed.
1036 Returns
1037 -------
1038 Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
1039 """
1040 _LOG.debug("region: %s", region)
1041 indices = self.pixelator.envelope(region, self.config.htm_max_ranges)
1043 return indices.ranges()
1045 def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
1046 """Make SQLAlchemy expression for selecting records in a region."""
1047 htm_index_column = table.columns[self.config.htm_index_column]
1048 exprlist = []
1049 pixel_ranges = self._htm_indices(region)
1050 for low, upper in pixel_ranges:
1051 upper -= 1
1052 if low == upper:
1053 exprlist.append(htm_index_column == low)
1054 else:
1055 exprlist.append(sql.expression.between(htm_index_column, low, upper))
1057 return sql.expression.or_(*exprlist)
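    # Worked example: if ``_htm_indices`` returns [(100, 104), (200, 201)],
    # the upper bounds are first decremented (ranges are half-open), giving
    # the clause ``pixelId BETWEEN 100 AND 103 OR pixelId = 200`` for the
    # default ``htm_index_column``.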
1059 def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
1060 """Calculate HTM index for each record and add it to a DataFrame.
1062 Notes
1063 -----
1064 This overrides any existing column in a DataFrame with the same name
1065 (pixelId). Original DataFrame is not changed, copy of a DataFrame is
1066 returned.
1067 """
1068 # calculate HTM index for every DiaObject
1069 htm_index = np.zeros(df.shape[0], dtype=np.int64)
1070 ra_col, dec_col = self.config.ra_dec_columns
1071 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
1072 uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
1073 idx = self.pixelator.index(uv3d)
1074 htm_index[i] = idx
1075 df = df.copy()
1076 df[self.config.htm_index_column] = htm_index
1077 return df
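    # Pixelization sketch for a single position (coordinates are arbitrary):
    #
    #     >>> pix = HtmPixelization(20)
    #     >>> pix.index(UnitVector3d(LonLat.fromDegrees(45.0, 30.0)))  # pixel ID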
1079 def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
1080 """Add pixelId column to DiaSource catalog.
1082 Notes
1083 -----
1084 This method copies pixelId value from a matching DiaObject record.
1085 DiaObject catalog needs to have a pixelId column filled by
1086 ``_add_obj_htm_index`` method and DiaSource records need to be
1087 associated to DiaObjects via ``diaObjectId`` column.
1089 This overrides any existing column in a DataFrame with the same name
1090 (pixelId). Original DataFrame is not changed, copy of a DataFrame is
1091 returned.
1092 """
1093 pixel_id_map: dict[int, int] = {
1094 diaObjectId: pixelId
1095 for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
1096 }
1097 # DiaSources associated with SolarSystemObjects do not have an
1098 # associated DiaObject hence we skip them and set their htmIndex
1099 # value to 0.
1100 pixel_id_map[0] = 0
1101 htm_index = np.zeros(sources.shape[0], dtype=np.int64)
1102 for i, diaObjId in enumerate(sources["diaObjectId"]):
1103 htm_index[i] = pixel_id_map[diaObjId]
1104 sources = sources.copy()
1105 sources[self.config.htm_index_column] = htm_index
1106 return sources