Coverage for python/lsst/dax/apdb/sql/apdbSql.py: 15% of 508 statements (coverage.py v7.5.1, created at 2024-05-08 02:52 -0700)

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Iterable, Mapping, MutableMapping
from contextlib import closing, suppress
from typing import TYPE_CHECKING, Any, cast

import astropy.time
import numpy as np
import pandas
import sqlalchemy
import sqlalchemy.dialects.postgresql
import sqlalchemy.dialects.sqlite
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, sql
from sqlalchemy.pool import NullPool

from ..apdb import Apdb, ApdbConfig
from ..apdbConfigFreezer import ApdbConfigFreezer
from ..apdbReplica import ReplicaChunk
from ..apdbSchema import ApdbTables
from ..monitor import MonAgent
from ..schema_model import Table
from ..timer import Timer
from ..versionTuple import IncompatibleVersionError, VersionTuple
from .apdbMetadataSql import ApdbMetadataSql
from .apdbSqlReplica import ApdbSqlReplica
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables

if TYPE_CHECKING:
    import sqlite3

    from ..apdbMetadata import ApdbMetadata

_LOG = logging.getLogger(__name__)

_MON = MonAgent(__name__)

VERSION = VersionTuple(0, 1, 0)
"""Version for the code controlling non-replication tables. This needs to be
updated following compatibility rules when the schema produced by this code
changes.
"""


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
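

# Illustrative note (not part of the module API): as the comments in this
# module say, some numpy types can cause issues with sqlalchemy, which is why
# _coerce_uint64 exists. A hypothetical frame shows the effect; the input
# frame itself is left unmodified:
#
#     df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
#     assert _coerce_uint64(df)["diaObjectId"].dtype == np.int64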


def _make_midpointMjdTai_start(visit_time: astropy.time.Time, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `astropy.time.Time`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: Use of MJD must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.mjd - months * 30
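

# Worked example for the 30-day-month arithmetic above (the values are
# illustrative): for a visit at MJD 60000.0 and a 12-month history window,
#
#     _make_midpointMjdTai_start(astropy.time.Time(60000.0, format="mjd"), 12)
#
# returns 60000.0 - 12 * 30 = 59640.0.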


def _onSqlite3Connect(
    dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
) -> None:
    # Enable foreign keys
    with closing(dbapiConnection.cursor()) as cursor:
        cursor.execute("PRAGMA foreign_keys=ON;")


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level. If unset, the backend-default value "
            "is used, except for the SQLite backend where READ_UNCOMMITTED is "
            "used. Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before exiting. "
            "Defaults to the SQLAlchemy default if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of the HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject; by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in the APDB database. "
            "Presently only works for the PostgreSQL backend. "
            "If a schema with this name does not exist, it will be created "
            "when APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
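

# A minimal configuration sketch (the URL below is a placeholder, not a
# recommended deployment):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.validate()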


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config`` mechanism
    using the `ApdbSqlConfig` configuration class. For examples of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    metadataSchemaVersionKey = "version:schema"
    """Name of the metadata key to store schema version number."""

    metadataCodeVersionKey = "version:ApdbSql"
    """Name of the metadata key to store code version number."""

    metadataReplicaVersionKey = "version:ApdbSqlReplica"
    """Name of the metadata key to store replica code version number."""

    metadataConfigKey = "config:apdb-sql.json"
    """Name of the metadata key to store frozen configuration (JSON)."""

    _frozen_parameters = (
        "use_insert_id",
        "dia_object_index",
        "htm_level",
        "htm_index_column",
        "ra_dec_columns",
    )
    """Names of the config parameters to be frozen in metadata table."""

    def __init__(self, config: ApdbSqlConfig):
        self._engine = self._makeEngine(config)

        sa_metadata = sqlalchemy.MetaData(schema=config.namespace)
        meta_table_name = ApdbTables.metadata.table_name(prefix=config.prefix)
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(sqlalchemy.exc.NoSuchTableError):
            meta_table = sqlalchemy.schema.Table(meta_table_name, sa_metadata, autoload_with=self._engine)

        self._metadata = ApdbMetadataSql(self._engine, meta_table)

        # Read frozen config from metadata.
        config_json = self._metadata.get(self.metadataConfigKey)
        if config_json is not None:
            # Update config from metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](self._frozen_parameters)
            self.config = freezer.update(config, config_json)
        else:
            self.config = config
        self.config.validate()

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            enable_replica=self.config.use_insert_id,
        )

        if self._metadata.table_exists():
            self._versionCheck(self._metadata)

        self.pixelator = HtmPixelization(self.config.htm_level)

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        self._timer_args: list[MonAgent | logging.Logger] = [_MON]
        if self.config.timer:
            self._timer_args.append(_LOG)

    def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer:
        """Create `Timer` instance given its name."""
        return Timer(name, *self._timer_args, tags=tags)

    @classmethod
    def _makeEngine(cls, config: ApdbSqlConfig) -> sqlalchemy.engine.Engine:
        """Make SQLAlchemy engine based on configured parameters.

        Parameters
        ----------
        config : `ApdbSqlConfig`
            Configuration object.
        """
        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw: MutableMapping[str, Any] = dict(echo=config.sql_echo)
        conn_args: dict[str, Any] = dict()
        if not config.connection_pool:
            kw.update(poolclass=NullPool)
        if config.isolation_level is not None:
            kw.update(isolation_level=config.isolation_level)
        elif config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if config.connection_timeout is not None:
            if config.db_url.startswith("sqlite"):
                conn_args.update(timeout=config.connection_timeout)
            elif config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=config.connection_timeout)
        kw.update(connect_args=conn_args)
        engine = sqlalchemy.create_engine(config.db_url, **kw)

        if engine.dialect.name == "sqlite":
            # Need to enable foreign keys on every new connection.
            sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect)

        return engine

    def _versionCheck(self, metadata: ApdbMetadataSql) -> None:
        """Check schema version compatibility."""

        def _get_version(key: str, default: VersionTuple) -> VersionTuple:
            """Retrieve version number from given metadata key."""
            if metadata.table_exists():
                version_str = metadata.get(key)
                if version_str is None:
                    # Should not happen with existing metadata table.
                    raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
                return VersionTuple.fromString(version_str)
            return default

        # For old databases where metadata table does not exist we assume that
        # version of both code and schema is 0.1.0.
        initial_version = VersionTuple(0, 1, 0)
        db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version)
        db_code_version = _get_version(self.metadataCodeVersionKey, initial_version)

        # For now there is no way to make read-only APDB instances, assume that
        # any access can do updates.
        if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True):
            raise IncompatibleVersionError(
                f"Configured schema version {self._schema.schemaVersion()} "
                f"is not compatible with database version {db_schema_version}"
            )
        if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True):
            raise IncompatibleVersionError(
                f"Current code version {self.apdbImplementationVersion()} "
                f"is not compatible with database version {db_code_version}"
            )

        # Check replica code version only if replica is enabled.
        if self._schema.has_replica_chunks:
            db_replica_version = _get_version(self.metadataReplicaVersionKey, initial_version)
            code_replica_version = ApdbSqlReplica.apdbReplicaImplementationVersion()
            if not code_replica_version.checkCompatibility(db_replica_version, True):
                raise IncompatibleVersionError(
                    f"Current replication code version {code_replica_version} "
                    f"is not compatible with database version {db_replica_version}"
                )

    @classmethod
    def apdbImplementationVersion(cls) -> VersionTuple:
        # Docstring inherited from base class.
        return VERSION

    @classmethod
    def init_database(
        cls,
        db_url: str,
        *,
        schema_file: str | None = None,
        schema_name: str | None = None,
        read_sources_months: int | None = None,
        read_forced_sources_months: int | None = None,
        use_insert_id: bool = False,
        connection_timeout: int | None = None,
        dia_object_index: str | None = None,
        htm_level: int | None = None,
        htm_index_column: str | None = None,
        ra_dec_columns: list[str] | None = None,
        prefix: str | None = None,
        namespace: str | None = None,
        drop: bool = False,
    ) -> ApdbSqlConfig:
        """Initialize new APDB instance and make configuration object for it.

        Parameters
        ----------
        db_url : `str`
            SQLAlchemy database URL.
        schema_file : `str`, optional
            Location of (YAML) configuration file with APDB schema. If not
            specified then the default location will be used.
        schema_name : `str`, optional
            Name of the schema in YAML configuration file. If not specified
            then the default name will be used.
        read_sources_months : `int`, optional
            Number of months of history to read from DiaSource.
        read_forced_sources_months : `int`, optional
            Number of months of history to read from DiaForcedSource.
        use_insert_id : `bool`
            If True, make additional tables used for replication to PPDB.
        connection_timeout : `int`, optional
            Database connection timeout in seconds.
        dia_object_index : `str`, optional
            Indexing mode for DiaObject table.
        htm_level : `int`, optional
            HTM indexing level.
        htm_index_column : `str`, optional
            Name of the HTM index column for DiaObject and DiaSource tables.
        ra_dec_columns : `list` [`str`], optional
            Names of ra/dec columns in DiaObject table.
        prefix : `str`, optional
            Optional prefix for all table names.
        namespace : `str`, optional
            Name of the database schema for all APDB tables. If not specified
            then the default schema is used.
        drop : `bool`, optional
            If `True` then drop existing tables before re-creating the schema.

        Returns
        -------
        config : `ApdbSqlConfig`
            Resulting configuration object for a created APDB instance.
        """
        config = ApdbSqlConfig(db_url=db_url, use_insert_id=use_insert_id)
        if schema_file is not None:
            config.schema_file = schema_file
        if schema_name is not None:
            config.schema_name = schema_name
        if read_sources_months is not None:
            config.read_sources_months = read_sources_months
        if read_forced_sources_months is not None:
            config.read_forced_sources_months = read_forced_sources_months
        if connection_timeout is not None:
            config.connection_timeout = connection_timeout
        if dia_object_index is not None:
            config.dia_object_index = dia_object_index
        if htm_level is not None:
            config.htm_level = htm_level
        if htm_index_column is not None:
            config.htm_index_column = htm_index_column
        if ra_dec_columns is not None:
            config.ra_dec_columns = ra_dec_columns
        if prefix is not None:
            config.prefix = prefix
        if namespace is not None:
            config.namespace = namespace

        cls._makeSchema(config, drop=drop)

        return config
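
    # Typical bootstrap sequence (illustrative; the URL is a placeholder):
    # initialize the database once, then construct an instance from the
    # returned configuration.
    #
    #     config = ApdbSql.init_database(db_url="sqlite:///apdb.db")
    #     apdb = ApdbSql(config)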

    def apdbSchemaVersion(self) -> VersionTuple:
        # Docstring inherited from base class.
        return self._schema.schemaVersion()

    def get_replica(self) -> ApdbSqlReplica:
        """Return `ApdbReplica` instance for this database."""
        return ApdbSqlReplica(self._schema, self._engine)

    def tableRowCount(self) -> dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Table | None:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    @classmethod
    def _makeSchema(cls, config: ApdbConfig, drop: bool = False) -> None:
        # docstring is inherited from a base class

        if not isinstance(config, ApdbSqlConfig):
            raise TypeError(f"Unexpected type of configuration object: {type(config)}")

        engine = cls._makeEngine(config)

        # Ask schema class to create all tables.
        schema = ApdbSqlSchema(
            engine=engine,
            dia_object_index=config.dia_object_index,
            schema_file=config.schema_file,
            schema_name=config.schema_name,
            prefix=config.prefix,
            namespace=config.namespace,
            htm_index_column=config.htm_index_column,
            enable_replica=config.use_insert_id,
        )
        schema.makeSchema(drop=drop)

        # Need metadata table to store a few items in it, if the table exists.
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(ValueError):
            meta_table = schema.get_table(ApdbTables.metadata)

        apdb_meta = ApdbMetadataSql(engine, meta_table)
        if apdb_meta.table_exists():
            # Fill version numbers, overwrite if they are already there.
            apdb_meta.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True)
            apdb_meta.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True)
            if config.use_insert_id:
                # Only store replica code version if replica is enabled.
                apdb_meta.set(
                    cls.metadataReplicaVersionKey,
                    str(ApdbSqlReplica.apdbReplicaImplementationVersion()),
                    force=True,
                )

            # Store the frozen part of the configuration in metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](cls._frozen_parameters)
            apdb_meta.set(cls.metadataConfigKey, freezer.to_json(config), force=True)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with self._timer("select_time", tags={"table": "DiaObject"}):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects
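
    # Illustrative region query (assumes an ``apdb`` instance as above; the
    # circle parameters are placeholders, Circle and Angle come from
    # lsst.sphgeom):
    #
    #     from lsst.sphgeom import Angle, Circle
    #     center = UnitVector3d(LonLat.fromDegrees(320.0, -0.5))
    #     region = Circle(center, Angle.fromDegrees(1.0))
    #     objects = apdb.getDiaObjects(region)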

    def getDiaSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with self._timer("select_time", tags={"table": "DiaForcedSource"}):
            sources = self._getSourcesByIDs(
                ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
            )

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def containsVisitDetector(self, visit: int, detector: int) -> bool:
        # docstring is inherited from a base class
        src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource)
        frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource)
        # Query should load only one leaf page of the index
        query1 = sql.select(src_table.c.visit).filter_by(visit=visit, detector=detector).limit(1)

        with self._engine.begin() as conn:
            result = conn.execute(query1).scalar_one_or_none()
            if result is not None:
                return True
            else:
                # Backup query if an image was processed but had no diaSources
                query2 = sql.select(frcsrc_table.c.visit).filter_by(visit=visit, detector=detector).limit(1)
                result = conn.execute(query2).scalar_one_or_none()
                return result is not None

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with self._timer("SSObject_select_time", tags={"table": "SSObject"}):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: astropy.time.Time,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            replica_chunk: ReplicaChunk | None = None
            if self._schema.has_replica_chunks:
                replica_chunk = ReplicaChunk.make_replica_chunk(
                    visit_time, self.config.replica_chunk_seconds
                )
                self._storeReplicaChunk(replica_chunk, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, replica_chunk, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, replica_chunk, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, replica_chunk, connection)
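
    # Illustrative store call (assumes catalogs whose columns match the APDB
    # schema; the DiaObject catalog must include the configured ra/dec columns
    # so that HTM indices can be computed):
    #
    #     apdb.store(visit_time, objects, sources=sources, forced_sources=forced)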

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn], table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(table.name, conn, if_exists="append", index=False, schema=table.schema)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: list[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
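
    # Illustrative reassignment (IDs are placeholders): moves DiaSource 42 to
    # SSObject 2023, clearing its diaObjectId as the UPDATE above does.
    #
    #     apdb.reassignDiaSources({42: 2023})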

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    @property
    def metadata(self) -> ApdbMetadata:
        # docstring is inherited from a base class
        if self._metadata is None:
            raise RuntimeError("Database schema was not initialized.")
        return self._metadata

    def _getDiaSourcesInRegion(self, region: Region, visit_time: astropy.time.Time) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `astropy.time.Time`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with self._timer("DiaSource_select_time", tags={"table": "DiaSource"}):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: astropy.time.Time) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `astropy.time.Time`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with self._timer("select_time", tags={"table": "DiaSource"}):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Database table to read from.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is returned
            when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: pandas.DataFrame | None = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeReplicaChunk(
        self,
        replica_chunk: ReplicaChunk,
        visit_time: astropy.time.Time,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        dt = visit_time.datetime

        table = self._schema.get_table(ExtraTables.ApdbReplicaChunks)

        # We need UPSERT, which is a dialect-specific construct.
        values = {"last_update_time": dt, "unique_id": replica_chunk.unique_id}
        row = {"apdb_replica_chunk": replica_chunk.id} | values
        if connection.dialect.name == "sqlite":
            insert_sqlite = sqlalchemy.dialects.sqlite.insert(table)
            insert_sqlite = insert_sqlite.on_conflict_do_update(index_elements=table.primary_key, set_=values)
            connection.execute(insert_sqlite, row)
        elif connection.dialect.name == "postgresql":
            insert_pg = sqlalchemy.dialects.postgresql.dml.insert(table)
            insert_pg = insert_pg.on_conflict_do_update(constraint=table.primary_key, set_=values)
            connection.execute(insert_pg, row)
        else:
            raise TypeError(f"Unsupported dialect {connection.dialect.name} for upsert.")

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: astropy.time.Time,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `astropy.time.Time`
            Time of the visit.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier, or `None` if replication is disabled.
        connection : `sqlalchemy.engine.Connection`
            Database connection.
        """
        if len(objs) == 0:
            _LOG.debug("No objects to write to database.")
            return

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.datetime

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with self._timer("delete_time", tags={"table": table.name}):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with self._timer("insert_time", tags={"table": "DiaObjectLast"}):
                last_objs.to_sql(
                    table.name,
                    connection,
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            with self._timer("truncate_time", tags={"table": table.name}):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: list[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Prepare replica data
        table = self._schema.get_table(ApdbTables.DiaObject)
        replica_data: list[dict] = []
        replica_stmt: Any = None
        replica_table_name = ""
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = objs[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaObjectChunks)
            replica_table_name = replica_table.name
            replica_stmt = replica_table.insert()

        # insert new versions
        with self._timer("insert_time", tags={"table": table.name}):
            objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
        if replica_stmt is not None:
            with self._timer("insert_time", tags={"table": replica_table_name}):
                connection.execute(replica_stmt, replica_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier, or `None` if replication is disabled.
        connection : `sqlalchemy.engine.Connection`
            Database connection.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Prepare replica data
        replica_data: list[dict] = []
        replica_stmt: Any = None
        replica_table_name = ""
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = sources[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaSourceChunks)
            replica_table_name = replica_table.name
            replica_stmt = replica_table.insert()

        # everything to be done in single transaction
        with self._timer("insert_time", tags={"table": table.name}):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
        if replica_stmt is not None:
            with self._timer("replica_insert_time", tags={"table": replica_table_name}):
                connection.execute(replica_stmt, replica_data)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier, or `None` if replication is disabled.
        connection : `sqlalchemy.engine.Connection`
            Database connection.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Prepare replica data
        replica_data: list[dict] = []
        replica_stmt: Any = None
        replica_table_name = ""
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = sources[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaForcedSourceChunks)
            replica_table_name = replica_table.name
            replica_stmt = replica_table.insert()

        # everything to be done in single transaction
        with self._timer("insert_time", tags={"table": table.name}):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
        if replica_stmt is not None:
            with self._timer("insert_time", tags={"table": replica_table_name}):
                connection.execute(replica_stmt, replica_data)

    def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` [`tuple`]
            Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies the pixelId value from a matching DiaObject record.
        The DiaObject catalog needs to have a pixelId column filled by the
        ``_add_obj_htm_index`` method, and DiaSource records need to be
        associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        pixel_id_map: dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources