# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Iterable, Mapping, MutableMapping
from contextlib import closing, suppress
from typing import TYPE_CHECKING, Any, cast

import astropy.time
import numpy as np
import pandas
import sqlalchemy
import sqlalchemy.dialects.postgresql
import sqlalchemy.dialects.sqlite
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, sql
from sqlalchemy.pool import NullPool

from ..apdb import Apdb, ApdbConfig
from ..apdbConfigFreezer import ApdbConfigFreezer
from ..apdbReplica import ReplicaChunk
from ..apdbSchema import ApdbTables
from ..schema_model import Table
from ..timer import Timer
from ..versionTuple import IncompatibleVersionError, VersionTuple
from .apdbMetadataSql import ApdbMetadataSql
from .apdbSqlReplica import ApdbSqlReplica
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables

if TYPE_CHECKING:
    import sqlite3

    from ..apdbMetadata import ApdbMetadata

_LOG = logging.getLogger(__name__)

VERSION = VersionTuple(0, 1, 0)
"""Version for the code controlling non-replication tables. This needs to be
updated following compatibility rules when the schema produced by this code
changes.
"""


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
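
# A hedged illustration of what ``_coerce_uint64`` does (the example frame is
# hypothetical, not from this module): uint64 columns are cast to int64
# because some database drivers reject unsigned 64-bit values.
#
#     df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
#     _coerce_uint64(df).dtypes["diaObjectId"]    # -> dtype('int64')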


def _make_midpointMjdTai_start(visit_time: astropy.time.Time, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `astropy.time.Time`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: Use of MJD must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.mjd - months * 30
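
# Worked example with hypothetical numbers: for a visit at MJD 60000.0 and a
# 12-month window the cutoff is 60000.0 - 12 * 30 = 59640.0; a "month" is
# approximated as 30 days here.
#
#     t = astropy.time.Time(60000.0, format="mjd", scale="tai")
#     _make_midpointMjdTai_start(t, 12)    # -> 59640.0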


def _onSqlite3Connect(
    dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
) -> None:
    # Enable foreign keys
    with closing(dbapiConnection.cursor()) as cursor:
        cursor.execute("PRAGMA foreign_keys=ON;")


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before exiting. "
            "Defaults to the sqlalchemy default if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
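
# A minimal configuration sketch (the SQLite URL is a hypothetical example,
# not something this module prescribes):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.validate()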


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config``
    mechanism using the `ApdbSqlConfig` configuration class. For examples of
    different configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    metadataSchemaVersionKey = "version:schema"
    """Name of the metadata key to store schema version number."""

    metadataCodeVersionKey = "version:ApdbSql"
    """Name of the metadata key to store code version number."""

    metadataReplicaVersionKey = "version:ApdbSqlReplica"
    """Name of the metadata key to store replica code version number."""

    metadataConfigKey = "config:apdb-sql.json"
    """Name of the metadata key to store the frozen part of the configuration."""

    _frozen_parameters = (
        "use_insert_id",
        "dia_object_index",
        "htm_level",
        "htm_index_column",
        "ra_dec_columns",
    )
    """Names of the config parameters to be frozen in metadata table."""

    def __init__(self, config: ApdbSqlConfig):
        self._engine = self._makeEngine(config)

        sa_metadata = sqlalchemy.MetaData(schema=config.namespace)
        meta_table_name = ApdbTables.metadata.table_name(prefix=config.prefix)
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(sqlalchemy.exc.NoSuchTableError):
            meta_table = sqlalchemy.schema.Table(meta_table_name, sa_metadata, autoload_with=self._engine)

        self._metadata = ApdbMetadataSql(self._engine, meta_table)

        # Read frozen config from metadata.
        config_json = self._metadata.get(self.metadataConfigKey)
        if config_json is not None:
            # Update config from metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](self._frozen_parameters)
            self.config = freezer.update(config, config_json)
        else:
            self.config = config
        self.config.validate()

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            enable_replica=self.config.use_insert_id,
        )

        if self._metadata.table_exists():
            self._versionCheck(self._metadata)

        self.pixelator = HtmPixelization(self.config.htm_level)

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

    @classmethod
    def _makeEngine(cls, config: ApdbSqlConfig) -> sqlalchemy.engine.Engine:
        """Make SQLAlchemy engine based on configured parameters.

        Parameters
        ----------
        config : `ApdbSqlConfig`
            Configuration object.
        """
        # The engine is reused between multiple processes; make sure that we
        # don't share connections by disabling the pool (by using NullPool).
        kw: MutableMapping[str, Any] = dict(echo=config.sql_echo)
        conn_args: dict[str, Any] = dict()
        if not config.connection_pool:
            kw.update(poolclass=NullPool)
        if config.isolation_level is not None:
            kw.update(isolation_level=config.isolation_level)
        elif config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if config.connection_timeout is not None:
            if config.db_url.startswith("sqlite"):
                conn_args.update(timeout=config.connection_timeout)
            elif config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=config.connection_timeout)
        kw.update(connect_args=conn_args)
        engine = sqlalchemy.create_engine(config.db_url, **kw)

        if engine.dialect.name == "sqlite":
            # Need to enable foreign keys on every new connection.
            sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect)

        return engine

    def _versionCheck(self, metadata: ApdbMetadataSql) -> None:
        """Check schema version compatibility."""

        def _get_version(key: str, default: VersionTuple) -> VersionTuple:
            """Retrieve version number from given metadata key."""
            if metadata.table_exists():
                version_str = metadata.get(key)
                if version_str is None:
                    # Should not happen with existing metadata table.
                    raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
                return VersionTuple.fromString(version_str)
            return default

        # For old databases where metadata table does not exist we assume
        # that version of both code and schema is 0.1.0.
        initial_version = VersionTuple(0, 1, 0)
        db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version)
        db_code_version = _get_version(self.metadataCodeVersionKey, initial_version)

        # For now there is no way to make read-only APDB instances, assume
        # that any access can do updates.
        if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True):
            raise IncompatibleVersionError(
                f"Configured schema version {self._schema.schemaVersion()} "
                f"is not compatible with database version {db_schema_version}"
            )
        if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True):
            raise IncompatibleVersionError(
                f"Current code version {self.apdbImplementationVersion()} "
                f"is not compatible with database version {db_code_version}"
            )

        # Check replica code version only if replica is enabled.
        if self._schema.has_replica_chunks:
            db_replica_version = _get_version(self.metadataReplicaVersionKey, initial_version)
            code_replica_version = ApdbSqlReplica.apdbReplicaImplementationVersion()
            if not code_replica_version.checkCompatibility(db_replica_version, True):
                raise IncompatibleVersionError(
                    f"Current replication code version {code_replica_version} "
                    f"is not compatible with database version {db_replica_version}"
                )

    @classmethod
    def apdbImplementationVersion(cls) -> VersionTuple:
        # Docstring inherited from base class.
        return VERSION

    @classmethod
    def init_database(
        cls,
        db_url: str,
        *,
        schema_file: str | None = None,
        schema_name: str | None = None,
        read_sources_months: int | None = None,
        read_forced_sources_months: int | None = None,
        use_insert_id: bool = False,
        connection_timeout: int | None = None,
        dia_object_index: str | None = None,
        htm_level: int | None = None,
        htm_index_column: str | None = None,
        ra_dec_columns: list[str] | None = None,
        prefix: str | None = None,
        namespace: str | None = None,
        drop: bool = False,
    ) -> ApdbSqlConfig:
        """Initialize new APDB instance and make configuration object for it.

        Parameters
        ----------
        db_url : `str`
            SQLAlchemy database URL.
        schema_file : `str`, optional
            Location of (YAML) configuration file with APDB schema. If not
            specified then default location will be used.
        schema_name : `str`, optional
            Name of the schema in YAML configuration file. If not specified
            then default name will be used.
        read_sources_months : `int`, optional
            Number of months of history to read from DiaSource.
        read_forced_sources_months : `int`, optional
            Number of months of history to read from DiaForcedSource.
        use_insert_id : `bool`
            If True, make additional tables used for replication to PPDB.
        connection_timeout : `int`, optional
            Database connection timeout in seconds.
        dia_object_index : `str`, optional
            Indexing mode for DiaObject table.
        htm_level : `int`, optional
            HTM indexing level.
        htm_index_column : `str`, optional
            Name of a HTM index column for DiaObject and DiaSource tables.
        ra_dec_columns : `list` [`str`], optional
            Names of ra/dec columns in DiaObject table.
        prefix : `str`, optional
            Optional prefix for all table names.
        namespace : `str`, optional
            Name of the database schema for all APDB tables. If not specified
            then default schema is used.
        drop : `bool`, optional
            If `True` then drop existing tables before re-creating the schema.

        Returns
        -------
        config : `ApdbSqlConfig`
            Resulting configuration object for a created APDB instance.
        """
        config = ApdbSqlConfig(db_url=db_url, use_insert_id=use_insert_id)
        if schema_file is not None:
            config.schema_file = schema_file
        if schema_name is not None:
            config.schema_name = schema_name
        if read_sources_months is not None:
            config.read_sources_months = read_sources_months
        if read_forced_sources_months is not None:
            config.read_forced_sources_months = read_forced_sources_months
        if connection_timeout is not None:
            config.connection_timeout = connection_timeout
        if dia_object_index is not None:
            config.dia_object_index = dia_object_index
        if htm_level is not None:
            config.htm_level = htm_level
        if htm_index_column is not None:
            config.htm_index_column = htm_index_column
        if ra_dec_columns is not None:
            config.ra_dec_columns = ra_dec_columns
        if prefix is not None:
            config.prefix = prefix
        if namespace is not None:
            config.namespace = namespace

        cls._makeSchema(config, drop=drop)

        return config
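
    # A hedged end-to-end sketch (the URL is illustrative, not part of this
    # module): create the schema once, then instantiate the APDB against it.
    #
    #     config = ApdbSql.init_database(db_url="sqlite:///apdb.db")
    #     apdb = ApdbSql(config)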

    def apdbSchemaVersion(self) -> VersionTuple:
        # Docstring inherited from base class.
        return self._schema.schemaVersion()

    def get_replica(self) -> ApdbSqlReplica:
        """Return `ApdbReplica` instance for this database."""
        return ApdbSqlReplica(self._schema, self._engine)

    def tableRowCount(self) -> dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database
        tables. Depending on database technology this could be an expensive
        operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res
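
    # Shape of the returned mapping, with purely hypothetical counts:
    #
    #     {"DiaObject": 12345, "DiaSource": 67890, "DiaForcedSource": 23456}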

    def tableDef(self, table: ApdbTables) -> Table | None:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    @classmethod
    def _makeSchema(cls, config: ApdbConfig, drop: bool = False) -> None:
        # docstring is inherited from a base class

        if not isinstance(config, ApdbSqlConfig):
            raise TypeError(f"Unexpected type of configuration object: {type(config)}")

        engine = cls._makeEngine(config)

        # Ask schema class to create all tables.
        schema = ApdbSqlSchema(
            engine=engine,
            dia_object_index=config.dia_object_index,
            schema_file=config.schema_file,
            schema_name=config.schema_name,
            prefix=config.prefix,
            namespace=config.namespace,
            htm_index_column=config.htm_index_column,
            enable_replica=config.use_insert_id,
        )
        schema.makeSchema(drop=drop)

        # Need metadata table to store a few items in it, if the table
        # exists.
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(ValueError):
            meta_table = schema.get_table(ApdbTables.metadata)

        apdb_meta = ApdbMetadataSql(engine, meta_table)
        if apdb_meta.table_exists():
            # Fill version numbers, overwrite if they are already there.
            apdb_meta.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True)
            apdb_meta.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True)
            if config.use_insert_id:
                # Only store replica code version if replica is enabled.
                apdb_meta.set(
                    cls.metadataReplicaVersionKey,
                    str(ApdbSqlReplica.apdbReplicaImplementationVersion()),
                    force=True,
                )

            # Store frozen part of a configuration in metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](cls._frozen_parameters)
            apdb_meta.set(cls.metadataConfigKey, freezer.to_json(config), force=True)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(
                ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
            )

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def containsVisitDetector(self, visit: int, detector: int) -> bool:
        # docstring is inherited from a base class
        raise NotImplementedError()

    def containsCcdVisit(self, ccdVisitId: int) -> bool:
        """Test whether data for a given visit-detector is present in the APDB.

        This method is a placeholder until `Apdb.containsVisitDetector` can
        be implemented.

        Parameters
        ----------
        ccdVisitId : `int`
            The packed ID of the visit-detector to search for.

        Returns
        -------
        present : `bool`
            `True` if some DiaSource records exist for the specified
            observation, `False` otherwise.
        """
        # TODO: remove this method in favor of containsVisitDetector on either
        # DM-41671 or a ticket that removes ccdVisitId from these tables
        src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource)
        frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource)
        # Query should load only one leaf page of the index
        query1 = sql.select(src_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)
        # Backup query in case an image was processed but had no diaSources
        query2 = sql.select(frcsrc_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)

        with self._engine.begin() as conn:
            result = conn.execute(query1).scalar_one_or_none()
            if result is not None:
                return True
            else:
                result = conn.execute(query2).scalar_one_or_none()
                return result is not None
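
    # The two LIMIT 1 probes above translate to SQL of roughly this shape
    # (an illustrative rendering; exact quoting is dialect-dependent):
    #
    #     SELECT "ccdVisitId" FROM "DiaSource" WHERE "ccdVisitId" = :id LIMIT 1
    #     SELECT "ccdVisitId" FROM "DiaForcedSource" WHERE "ccdVisitId" = :id LIMIT 1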

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: astropy.time.Time,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            replica_chunk: ReplicaChunk | None = None
            if self._schema.has_replica_chunks:
                replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, self.config.replica_chunk_seconds)
                self._storeReplicaChunk(replica_chunk, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, replica_chunk, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, replica_chunk, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, replica_chunk, connection)

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64
            # can cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(table.name, conn, if_exists="append", index=False, schema=table.schema)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: list[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"The following DiaSource IDs do not exist in the database: {missing}")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    @property
    def metadata(self) -> ApdbMetadata:
        # docstring is inherited from a base class
        if self._metadata is None:
            raise RuntimeError("Database schema was not initialized.")
        return self._metadata

    def _getDiaSourcesInRegion(self, region: Region, visit_time: astropy.time.Time) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `astropy.time.Time`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: astropy.time.Time) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `astropy.time.Time`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Enum member identifying the table to query.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is
            returned if ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: pandas.DataFrame | None = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeReplicaChunk(
        self,
        replica_chunk: ReplicaChunk,
        visit_time: astropy.time.Time,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        dt = visit_time.datetime

        table = self._schema.get_table(ExtraTables.ApdbReplicaChunks)

        # We need UPSERT, which is a dialect-specific construct.
        values = {"last_update_time": dt, "unique_id": replica_chunk.unique_id}
        row = {"apdb_replica_chunk": replica_chunk.id} | values
        if connection.dialect.name == "sqlite":
            insert_sqlite = sqlalchemy.dialects.sqlite.insert(table)
            insert_sqlite = insert_sqlite.on_conflict_do_update(index_elements=table.primary_key, set_=values)
            connection.execute(insert_sqlite, row)
        elif connection.dialect.name == "postgresql":
            insert_pg = sqlalchemy.dialects.postgresql.dml.insert(table)
            insert_pg = insert_pg.on_conflict_do_update(constraint=table.primary_key, set_=values)
            connection.execute(insert_pg, row)
        else:
            raise TypeError(f"Unsupported dialect {connection.dialect.name} for upsert.")
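
    # For reference, the SQLite branch above emits SQL of roughly this shape
    # (an illustrative rendering, not SQLAlchemy's verbatim output):
    #
    #     INSERT INTO "ApdbReplicaChunks" (apdb_replica_chunk, last_update_time, unique_id)
    #     VALUES (:chunk, :dt, :uid)
    #     ON CONFLICT (apdb_replica_chunk)
    #     DO UPDATE SET last_update_time = :dt, unique_id = :uid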

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: astropy.time.Time,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `astropy.time.Time`
            Time of the visit.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier.
        """
        if len(objs) == 0:
            _LOG.debug("No objects to write to database.")
            return

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.datetime

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    connection,
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: list[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert replica data
        table = self._schema.get_table(ApdbTables.DiaObject)
        replica_data: list[dict] = []
        replica_stmt: Any = None
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = objs[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaObjectChunks)
            replica_stmt = replica_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if replica_stmt is not None:
                connection.execute(replica_stmt, replica_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert replica data
        replica_data: list[dict] = []
        replica_stmt: Any = None
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = sources[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaSourceChunks)
            replica_stmt = replica_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if replica_stmt is not None:
                connection.execute(replica_stmt, replica_data)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert replica data
        replica_data: list[dict] = []
        replica_stmt: Any = None
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = sources[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaForcedSourceChunks)
            replica_stmt = replica_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if replica_stmt is not None:
                connection.execute(replica_stmt, replica_data)

    def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` [`tuple`]
            Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()
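
    # Note on range semantics (inferred from ``_filterRegion`` below, which
    # subtracts one from the upper bound): each returned tuple appears to be
    # half-open, [min, max). A hypothetical result could look like
    #
    #     [(231928233984, 231928242176), (231928250368, 231928258560)]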

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
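
    # With two pixel ranges the resulting filter renders roughly as follows
    # (illustrative; the column name comes from ``htm_index_column``):
    #
    #     "pixelId" = :low_1 OR "pixelId" BETWEEN :low_2 AND :upper_2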

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df
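
    # A hedged sketch of the per-row computation (hypothetical coordinates;
    # the column names come from ``ra_dec_columns``, default "ra"/"dec"):
    #
    #     uv3d = UnitVector3d(LonLat.fromDegrees(45.0, 30.0))
    #     pixel_id = HtmPixelization(20).index(uv3d)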

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources