# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

22"""Module defining Apdb class and related methods."""
24from __future__ import annotations
26__all__ = ["ApdbSql"]
28import datetime
29import json
30import logging
31import urllib.parse
32import uuid
33import warnings
34from collections import Counter
35from collections.abc import Iterable, Mapping, MutableMapping
36from contextlib import closing
37from typing import TYPE_CHECKING, Any
39import astropy.time
40import numpy as np
41import pandas
42import sqlalchemy
43import sqlalchemy.dialects.postgresql
44import sqlalchemy.dialects.sqlite
45from sqlalchemy import func, sql
46from sqlalchemy.pool import NullPool
48from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
49from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError
50from lsst.utils.iteration import chunk_iterable
52from ..apdb import Apdb
53from ..apdbConfigFreezer import ApdbConfigFreezer
54from ..apdbReplica import ReplicaChunk
55from ..apdbSchema import ApdbSchema, ApdbTables
56from ..apdbUpdateRecord import (
57 ApdbCloseDiaObjectValidityRecord,
58 ApdbReassignDiaSourceToDiaObjectRecord,
59 ApdbUpdateNDiaSourcesRecord,
60 ApdbUpdateRecord,
61)
62from ..config import ApdbConfig
63from ..monitor import MonAgent
64from ..recordIds import DiaObjectId, DiaSourceId
65from ..schema_model import Table
66from ..timer import Timer
67from ..versionTuple import IncompatibleVersionError, VersionTuple
68from .apdbMetadataSql import ApdbMetadataSql
69from .apdbSqlAdmin import ApdbSqlAdmin
70from .apdbSqlReplica import ApdbSqlReplica, ApdbSqlTableData
71from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
72from .config import ApdbSqlConfig
74if TYPE_CHECKING:
75 import sqlite3
77 from ..apdbMetadata import ApdbMetadata
78 from ..apdbUpdateRecord import ApdbUpdateRecord
80_LOG = logging.getLogger(__name__)
82_MON = MonAgent(__name__)
84VERSION = VersionTuple(1, 2, 1)
85"""Version for the code controlling non-replication tables. This needs to be
86updated following compatibility rules when schema produced by this code
87changes.
88"""
def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype(dict.fromkeys(names, np.int64))

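# For example, a frame with a column of dtype uint64, e.g.
#
#     pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64)})
#
# comes back from _coerce_uint64() with "id" coerced to int64 (most backends
# have no unsigned 64-bit integer type); other columns are unchanged.
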
def _make_midpointMjdTai_start(visit_time: astropy.time.Time, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `astropy.time.Time`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: Use of MJD must be consistent with the code in ap_association
    # (see DM-31996)
    return float(visit_time.tai.mjd) - months * 30

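# For example, a visit at MJD 60000.0 (TAI) with a 12-month history window
# yields 60000.0 - 12 * 30 = 59640.0 as the starting midpointMjdTai (a month
# is approximated as 30 days here).
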
def _onSqlite3Connect(
    dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
) -> None:
    # Enable foreign keys
    with closing(dbapiConnection.cursor()) as cursor:
        cursor.execute("PRAGMA foreign_keys=ON;")

class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config``
    mechanism using the `ApdbSqlConfig` configuration class. For examples of
    different configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    metadataSchemaVersionKey = "version:schema"
    """Name of the metadata key to store schema version number."""

    metadataCodeVersionKey = "version:ApdbSql"
    """Name of the metadata key to store code version number."""

    metadataReplicaVersionKey = "version:ApdbSqlReplica"
    """Name of the metadata key to store replica code version number."""

    metadataConfigKey = "config:apdb-sql.json"
    """Name of the metadata key to store the frozen part of the
    configuration."""

    metadataDedupKey = "status:deduplication.json"
    """Name of the metadata key to store deduplication status."""

    _frozen_parameters = (
        "enable_replica",
        "dia_object_index",
        "pixelization.htm_level",
        "pixelization.htm_index_column",
        "ra_dec_columns",
    )
    """Names of the config parameters to be frozen in metadata table."""

    def __init__(self, config: ApdbSqlConfig):
        self._engine = self._makeEngine(config, create=False)

        sa_metadata = sqlalchemy.MetaData(schema=config.namespace)
        meta_table_name = ApdbTables.metadata.table_name(prefix=config.prefix)
        meta_table = sqlalchemy.schema.Table(meta_table_name, sa_metadata, autoload_with=self._engine)
        self._metadata = ApdbMetadataSql(self._engine, meta_table)

        # Get table schemas.
        self._table_schema = ApdbSchema(config.schema_file, config.ss_schema_file)

        # Check that versions are compatible; this must be the first thing
        # done before reading the frozen config.
        self._db_schema_version = self._versionCheck(self._metadata, self._table_schema.schemaVersion())

        # Read frozen config from metadata.
        config_json = self._metadata.get(self.metadataConfigKey)
        if config_json is not None:
            # Update config from metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](self._frozen_parameters)
            self.config = freezer.update(config, config_json)
        else:
            self.config = config

        self._schema = ApdbSqlSchema(
            table_schema=self._table_schema,
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.pixelization.htm_index_column,
            enable_replica=self.config.enable_replica,
        )

        self.pixelator = HtmPixelization(self.config.pixelization.htm_level)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("ApdbSql Configuration: %s", self.config.model_dump())

    def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer:
        """Create `Timer` instance given its name."""
        return Timer(name, _MON, tags=tags)

    @classmethod
    def _makeEngine(cls, config: ApdbSqlConfig, *, create: bool) -> sqlalchemy.engine.Engine:
        """Make SQLAlchemy engine based on configured parameters.

        Parameters
        ----------
        config : `ApdbSqlConfig`
            Configuration object.
        create : `bool`
            Whether to try to create new database file, only relevant for
            SQLite backend which always creates new files by default.
        """
        # The engine may be reused between multiple processes; make sure that
        # we don't share connections by disabling pooling (using NullPool).
        kw: MutableMapping[str, Any] = dict(config.connection_config.extra_parameters)
        conn_args: dict[str, Any] = {}
        if not config.connection_config.connection_pool:
            kw.update(poolclass=NullPool)
        if config.connection_config.isolation_level is not None:
            kw.update(isolation_level=config.connection_config.isolation_level)
        elif config.db_url.startswith("sqlite"):
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if config.connection_config.connection_timeout is not None:
            if config.db_url.startswith("sqlite"):
                conn_args.update(timeout=config.connection_config.connection_timeout)
            elif config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=int(config.connection_config.connection_timeout))
        kw.update(connect_args=conn_args)
        engine = sqlalchemy.create_engine(cls._connection_url(config.db_url, create=create), **kw)

        if engine.dialect.name == "sqlite":
            # Need to enable foreign keys on every new connection.
            sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect)

        return engine

    @classmethod
    def _connection_url(cls, config_url: str, *, create: bool) -> sqlalchemy.engine.URL | str:
        """Generate a complete URL for database with proper credentials.

        Parameters
        ----------
        config_url : `str`
            Database URL as specified in configuration.
        create : `bool`
            Whether to try to create new database file, only relevant for
            SQLite backend which always creates new files by default.

        Returns
        -------
        connection_url : `sqlalchemy.engine.URL` or `str`
            Connection URL including credentials.
        """
        # Allow third-party authentication mechanisms by assuming the
        # connection string is correct when we cannot recognize matching
        # (dialect, host, database) keys.
        components = urllib.parse.urlparse(config_url)
        if all((components.scheme is not None, components.hostname is not None, components.path is not None)):
            try:
                db_auth = DbAuth()
                config_url = db_auth.getUrl(config_url)
            except DbAuthNotFoundError:
                # Credentials file doesn't exist or no matching credentials,
                # use default auth.
                pass

        # SQLite has a nasty habit of creating empty databases when they do
        # not exist; tell it not to do that unless we do need to create it.
        if not create:
            config_url = cls._update_sqlite_url(config_url)

        return config_url

    @classmethod
    def _update_sqlite_url(cls, url_string: str) -> str:
        """If URL refers to sqlite dialect, update it so that the backend does
        not try to create database file if it does not exist already.

        Parameters
        ----------
        url_string : `str`
            Connection string.

        Returns
        -------
        url_string : `str`
            Possibly updated connection string.
        """
        try:
            url = sqlalchemy.make_url(url_string)
        except sqlalchemy.exc.SQLAlchemyError:
            # If parsing fails it means some special format, likely not
            # sqlite, so we just return it unchanged.
            return url_string

        if url.get_backend_name() == "sqlite":
            # Massage url so that database name starts with "file:" and
            # option string has "mode=rw&uri=true". Database name
            # should look like a path (:memory: is not supported by
            # Apdb, but someone could still try to use it).
            database = url.database
            if database and not database.startswith((":", "file:")):
                query = dict(url.query, mode="rw", uri="true")
                # If ``database`` is an absolute path then the original URL
                # should include four slashes after "sqlite:". Humans are bad
                # at counting things beyond four and sometimes an extra slash
                # gets added unintentionally, which causes sqlite to treat the
                # initial element as "authority" and to complain. Strip extra
                # slashes at the start of the path to avoid that (DM-46077).
                if database.startswith("//"):
                    warnings.warn(
                        f"Database URL contains extra leading slashes which will be removed: {url}",
                        stacklevel=3,
                    )
                    database = "/" + database.lstrip("/")
                url = url.set(database=f"file:{database}", query=query)
                url_string = url.render_as_string()

        return url_string

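    # _update_sqlite_url example: "sqlite:////tmp/apdb.db" is rewritten to
    # "sqlite:///file:/tmp/apdb.db?mode=rw&uri=true", while URLs for other
    # backends (and sqlite ":memory:" or "file:" forms) are returned
    # unchanged.
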
    @classmethod
    def _versionCheck(cls, metadata: ApdbMetadataSql, schema_version: VersionTuple) -> VersionTuple:
        """Check schema version compatibility and return the database schema
        version.
        """

        def _get_version(key: str) -> VersionTuple:
            """Retrieve version number from given metadata key."""
            version_str = metadata.get(key)
            if version_str is None:
                # Should not happen with existing metadata table.
                raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
            return VersionTuple.fromString(version_str)

        db_schema_version = _get_version(cls.metadataSchemaVersionKey)
        db_code_version = _get_version(cls.metadataCodeVersionKey)

        # For now there is no way to make read-only APDB instances, assume
        # that any access can do updates.
        if not schema_version.checkCompatibility(db_schema_version):
            raise IncompatibleVersionError(
                f"Configured schema version {schema_version} "
                f"is not compatible with database version {db_schema_version}"
            )
        if not cls.apdbImplementationVersion().checkCompatibility(db_code_version):
            raise IncompatibleVersionError(
                f"Current code version {cls.apdbImplementationVersion()} "
                f"is not compatible with database version {db_code_version}"
            )

        # Check replica code version only if replica is enabled. Sort of a
        # chicken-and-egg problem: `enable_replica` is a part of the frozen
        # configuration, but we cannot read the frozen configuration until we
        # validate versions. Assume that if the replica version is present
        # then replication is enabled.
        if metadata.get(cls.metadataReplicaVersionKey) is not None:
            db_replica_version = _get_version(cls.metadataReplicaVersionKey)
            code_replica_version = ApdbSqlReplica.apdbReplicaImplementationVersion()
            if not code_replica_version.checkCompatibility(db_replica_version):
                raise IncompatibleVersionError(
                    f"Current replication code version {code_replica_version} "
                    f"is not compatible with database version {db_replica_version}"
                )

        return db_schema_version

    @classmethod
    def apdbImplementationVersion(cls) -> VersionTuple:
        """Return version number for current APDB implementation.

        Returns
        -------
        version : `VersionTuple`
            Version of the code defined in implementation class.
        """
        return VERSION

    @classmethod
    def init_database(
        cls,
        db_url: str,
        *,
        schema_file: str | None = None,
        ss_schema_file: str | None = None,
        read_sources_months: int | None = None,
        read_forced_sources_months: int | None = None,
        enable_replica: bool = False,
        connection_timeout: int | None = None,
        dia_object_index: str | None = None,
        htm_level: int | None = None,
        htm_index_column: str | None = None,
        ra_dec_columns: tuple[str, str] | None = None,
        prefix: str | None = None,
        namespace: str | None = None,
        drop: bool = False,
    ) -> ApdbSqlConfig:
        """Initialize new APDB instance and make configuration object for it.

        Parameters
        ----------
        db_url : `str`
            SQLAlchemy database URL.
        schema_file : `str`, optional
            Location of (YAML) configuration file with APDB schema. If not
            specified then default location will be used.
        ss_schema_file : `str`, optional
            Location of (YAML) configuration file with SSO schema. If not
            specified then default location will be used.
        read_sources_months : `int`, optional
            Number of months of history to read from DiaSource.
        read_forced_sources_months : `int`, optional
            Number of months of history to read from DiaForcedSource.
        enable_replica : `bool`, optional
            If True, make additional tables used for replication to PPDB.
        connection_timeout : `int`, optional
            Database connection timeout in seconds.
        dia_object_index : `str`, optional
            Indexing mode for DiaObject table.
        htm_level : `int`, optional
            HTM indexing level.
        htm_index_column : `str`, optional
            Name of a HTM index column for DiaObject and DiaSource tables.
        ra_dec_columns : `tuple` [`str`, `str`], optional
            Names of ra/dec columns in DiaObject table.
        prefix : `str`, optional
            Optional prefix for all table names.
        namespace : `str`, optional
            Name of the database schema for all APDB tables. If not specified
            then default schema is used.
        drop : `bool`, optional
            If `True` then drop existing tables before re-creating the schema.

        Returns
        -------
        config : `ApdbSqlConfig`
            Resulting configuration object for a created APDB instance.
        """
        config = ApdbSqlConfig(db_url=db_url, enable_replica=enable_replica)
        if schema_file is not None:
            config.schema_file = schema_file
        if ss_schema_file is not None:
            config.ss_schema_file = ss_schema_file
        if read_sources_months is not None:
            config.read_sources_months = read_sources_months
        if read_forced_sources_months is not None:
            config.read_forced_sources_months = read_forced_sources_months
        if connection_timeout is not None:
            config.connection_config.connection_timeout = connection_timeout
        if dia_object_index is not None:
            config.dia_object_index = dia_object_index
        if htm_level is not None:
            config.pixelization.htm_level = htm_level
        if htm_index_column is not None:
            config.pixelization.htm_index_column = htm_index_column
        if ra_dec_columns is not None:
            config.ra_dec_columns = ra_dec_columns
        if prefix is not None:
            config.prefix = prefix
        if namespace is not None:
            config.namespace = namespace

        cls._makeSchema(config, drop=drop)

        # SQLite has a nasty habit of creating an empty database by default;
        # update URL in config file to disable that behavior.
        config.db_url = cls._update_sqlite_url(config.db_url)

        return config

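    # A minimal usage sketch (the SQLite URL below is hypothetical):
    #
    #     config = ApdbSql.init_database(db_url="sqlite:////tmp/apdb.db")
    #     apdb = ApdbSql(config)
    #
    # init_database() creates the schema and freezes selected parameters in
    # the metadata table; a later ApdbSql(config) re-reads the frozen values
    # and overrides the in-memory configuration with them.
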
    def get_replica(self) -> ApdbSqlReplica:
        """Return `ApdbReplica` instance for this database."""
        return ApdbSqlReplica(self._schema, self._engine, self._db_schema_version)

    def tableRowCount(self) -> dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res

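    # Example return value (illustrative numbers only):
    #
    #     {"DiaObject": 1250, "DiaSource": 5300, "DiaForcedSource": 8100}
    #
    # "DiaObjectLast" is included only when dia_object_index is set to
    # "last_object_table".
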
    def getConfig(self) -> ApdbSqlConfig:
        # docstring is inherited from a base class
        return self.config

    def tableDef(self, table: ApdbTables) -> Table | None:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    @classmethod
    def _makeSchema(cls, config: ApdbConfig, drop: bool = False) -> None:
        # docstring is inherited from a base class

        if not isinstance(config, ApdbSqlConfig):
            raise TypeError(f"Unexpected type of configuration object: {type(config)}")

        engine = cls._makeEngine(config, create=True)

        table_schema = ApdbSchema(config.schema_file, config.ss_schema_file)

        # Ask schema class to create all tables.
        schema = ApdbSqlSchema(
            table_schema=table_schema,
            engine=engine,
            dia_object_index=config.dia_object_index,
            prefix=config.prefix,
            namespace=config.namespace,
            htm_index_column=config.pixelization.htm_index_column,
            enable_replica=config.enable_replica,
        )
        schema.makeSchema(drop=drop)

        # Need the metadata table to store a few items in it.
        meta_table = schema.get_table(ApdbTables.metadata)
        apdb_meta = ApdbMetadataSql(engine, meta_table)

        # Fill version numbers, overwrite if they are already there.
        apdb_meta.set(cls.metadataSchemaVersionKey, str(table_schema.schemaVersion()), force=True)
        apdb_meta.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True)
        if config.enable_replica:
            # Only store replica code version if replica is enabled.
            apdb_meta.set(
                cls.metadataReplicaVersionKey,
                str(ApdbSqlReplica.apdbReplicaImplementationVersion()),
                force=True,
            )

        # Store frozen part of a configuration in metadata.
        freezer = ApdbConfigFreezer[ApdbSqlConfig](cls._frozen_parameters)
        apdb_meta.set(cls.metadataConfigKey, freezer.to_json(config), force=True)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        validity_end_column = self._timestamp_column_name("validityEnd")

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.columns[validity_end_column] == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with self._timer("select_time", tags={"table": "DiaObject"}) as timer:
            with self._engine.begin() as conn:
                result = conn.execute(query)
                column_defs = self._table_schema.tableSchemas[table_enum].columns
                table_data = ApdbSqlTableData(result, column_defs)
                objects = table_data.to_pandas()
            timer.add_values(row_count=len(objects))
        _LOG.debug("found %s DiaObjects", len(objects))
        return self._fix_result_timestamps(objects)

    def getDiaSources(
        self,
        region: Region,
        object_ids: Iterable[int] | None,
        visit_time: astropy.time.Time,
        start_time: astropy.time.Time | None = None,
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if start_time is None and self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if start_time is None:
            start_time_mjdTai = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        else:
            start_time_mjdTai = float(start_time.tai.mjd)
        _LOG.debug("start_time_mjdTai = %.6f", start_time_mjdTai)

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, start_time_mjdTai)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), start_time_mjdTai)

    def getDiaForcedSources(
        self,
        region: Region,
        object_ids: Iterable[int] | None,
        visit_time: astropy.time.Time,
        start_time: astropy.time.Time | None = None,
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if start_time is None and self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection. In
            # the past DiaForcedSource schema did not have ra/dec columns (it
            # had x/y columns). ra/dec were added at some point, so we could
            # add a pixelId column to this table if/when needed.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        if start_time is None:
            start_time_mjdTai = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        else:
            start_time_mjdTai = float(start_time.tai.mjd)
        _LOG.debug("start_time_mjdTai = %.6f", start_time_mjdTai)

        with self._timer("select_time", tags={"table": "DiaForcedSource"}) as timer:
            sources = self._getSourcesByIDs(ApdbTables.DiaForcedSource, list(object_ids), start_time_mjdTai)
            timer.add_values(row_count=len(sources))

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getDiaObjectsForDedup(self, since: astropy.time.Time | None = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        if since is None:
            # Read last deduplication time from metadata.
            dedup_str = self._metadata.get(self.metadataDedupKey)
            if dedup_str is not None:
                dedup_state = json.loads(dedup_str)
                dedup_time_str = dedup_state["dedup_time_iso_tai"]
                since = astropy.time.Time(dedup_time_str, format="iso", scale="tai")

        validity_start_column = self._timestamp_column_name("validityStart")

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)

        if not self.config.dia_object_columns_for_dedup:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            column_names = list(self.config.dia_object_columns_for_dedup)
            if validity_start_column not in column_names:
                column_names.insert(0, validity_start_column)
            if "diaObjectId" not in column_names:
                column_names.insert(0, "diaObjectId")
            columns = [table.columns[col] for col in column_names]

        query = sql.select(*columns)

        # build selection
        if since is not None:
            timestamp = self._timestamp_column_value(since)
            query = query.where(table.columns[validity_start_column] >= timestamp)

        # execute select
        with self._timer("select_time", tags={"table": "DiaObject"}) as timer:
            with self._engine.begin() as conn:
                result = conn.execute(query)
                column_defs = self._table_schema.tableSchemas[table_enum].columns
                table_data = ApdbSqlTableData(result, column_defs)
                objects = table_data.to_pandas()
            timer.add_values(row_count=len(objects))
        _LOG.debug("found %s DiaObjects", len(objects))
        return self._fix_result_timestamps(objects)

    def getDiaSourcesForDiaObjects(
        self, objects: list[DiaObjectId], start_time: astropy.time.Time, max_dist_arcsec: float = 1.0
    ) -> pandas.DataFrame:
        # docstring is inherited from a base class
        object_ids = {object_id.diaObjectId for object_id in objects}
        return self._getDiaSourcesByIDs(list(object_ids), float(start_time.tai.mjd))

    def containsVisitDetector(
        self,
        visit: int,
        detector: int,
        region: Region,
        visit_time: astropy.time.Time,
    ) -> bool:
        # docstring is inherited from a base class
        src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource)
        frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource)
        # Query should load only one leaf page of the index
        query1 = sql.select(src_table.c.visit).filter_by(visit=visit, detector=detector).limit(1)

        with self._engine.begin() as conn:
            result = conn.execute(query1).scalar_one_or_none()
            if result is not None:
                return True
            else:
                # Backup query if an image was processed but had no diaSources
                query2 = sql.select(frcsrc_table.c.visit).filter_by(visit=visit, detector=detector).limit(1)
                result = conn.execute(query2).scalar_one_or_none()
                return result is not None

    def store(
        self,
        visit_time: astropy.time.Time,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class
        objects = self._fix_input_timestamps(objects)
        if sources is not None:
            sources = self._fix_input_timestamps(sources)
        if forced_sources is not None:
            forced_sources = self._fix_input_timestamps(forced_sources)

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            replica_chunk: ReplicaChunk | None = None
            if self._schema.replication_enabled:
                replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, self.config.replica_chunk_seconds)
                self._storeReplicaChunk(replica_chunk, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_spatial_index(objects)
            self._storeDiaObjects(objects, visit_time, replica_chunk, connection)

            if sources is not None:
                # fill pixelId column for DiaSources
                sources = self._add_spatial_index(sources)
                self._storeDiaSources(sources, replica_chunk, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, replica_chunk, connection)

    def reassignDiaSourcesToDiaObjects(
        self,
        idMap: Mapping[DiaSourceId, int],
        *,
        increment_nDiaSources: bool = True,
        decrement_nDiaSources: bool = True,
    ) -> None:
        # docstring is inherited from a base class

        new_object_ids = set(idMap.values())
        source_ids = {source.diaSourceId for source in idMap}

        current_time = self._current_time()
        current_time_ns = int(current_time.unix_tai * 1e9)

        with self._engine.begin() as conn:
            # Make sure that all DiaSources exist.
            found_sources = self._get_diasource_data(conn, source_ids, "diaObjectId")
            if missing_ids := (source_ids - {row.diaSourceId for row in found_sources}):
                raise LookupError(f"Some source IDs are missing from DiaSource table: {missing_ids}")
            original_object_ids = {row.diaSourceId: row.diaObjectId for row in found_sources}

            # Make sure that all DiaObjects exist; we also want to know the
            # nDiaSources count for current and new records because we want
            # to send updated values to the replica.
            all_object_ids = new_object_ids | set(original_object_ids.values())
            found_objects = self._get_diaobject_data(conn, all_object_ids, "ra", "dec", "nDiaSources")
            if missing_ids := (new_object_ids - {row.diaObjectId for row in found_objects}):
                raise LookupError(f"Some object IDs are missing from DiaObject table: {missing_ids}")

            found_objects_by_id = {row.diaObjectId: row for row in found_objects}

            update_records: list[ApdbUpdateRecord] = []
            update_order = 0

            # Update DiaSources.
            source_table = self._schema.get_table(ApdbTables.DiaSource)
            for source, diaObjectId in idMap.items():
                update = (
                    source_table.update()
                    .where(source_table.columns["diaSourceId"] == source.diaSourceId)
                    .values(diaObjectId=diaObjectId)
                )
                conn.execute(update)

                if self._schema.replication_enabled:
                    update_records.append(
                        ApdbReassignDiaSourceToDiaObjectRecord(
                            diaSourceId=source.diaSourceId,
                            ra=source.ra,
                            dec=source.dec,
                            midpointMjdTai=source.midpointMjdTai,
                            diaObjectId=diaObjectId,
                            update_time_ns=current_time_ns,
                            update_order=update_order,
                        )
                    )
                    update_order += 1

            # DiaObject tables to update.
            object_tables = [self._schema.get_table(ApdbTables.DiaObject)]
            if self.config.dia_object_index == "last_object_table":
                object_tables.append(self._schema.get_table(ApdbTables.DiaObjectLast))

            # Things to increment/decrement.
            increments: Counter = Counter()
            if increment_nDiaSources:
                increments.update(idMap.values())
            if decrement_nDiaSources:
                increments.subtract(original_object_ids[source_id.diaSourceId] for source_id in idMap)

            if increments:
                for table in object_tables:
                    for diaObjectId, increment in increments.items():
                        update = (
                            table.update()
                            .where(table.columns["diaObjectId"] == diaObjectId)
                            .values(nDiaSources=table.columns["nDiaSources"] + increment)
                        )
                        conn.execute(update)

                # Also send updated values to replica.
                if self._schema.replication_enabled:
                    for diaObjectId, increment in increments.items():
                        dia_object = found_objects_by_id[diaObjectId]
                        update_records.append(
                            ApdbUpdateNDiaSourcesRecord(
                                diaObjectId=diaObjectId,
                                ra=dia_object.ra,
                                dec=dia_object.dec,
                                nDiaSources=dia_object.nDiaSources + increment,
                                update_time_ns=current_time_ns,
                                update_order=update_order,
                            )
                        )
                        update_order += 1

            if update_records:
                replica_chunk = ReplicaChunk.make_replica_chunk(
                    current_time, self.config.replica_chunk_seconds
                )
                self._storeUpdateRecords(update_records, replica_chunk, connection=conn, store_chunk=True)

    def setValidityEnd(
        self, objects: list[DiaObjectId], validityEnd: astropy.time.Time, raise_on_missing_id: bool = False
    ) -> int:
        # docstring is inherited from a base class
        if not objects:
            return 0

        requested_ids = {obj.diaObjectId for obj in objects}

        validity_end_column = self._timestamp_column_name("validityEnd")
        validityEnd_value = self._timestamp_column_value(validityEnd)

        # Find all matching DiaObjects with validityEnd = NULL.
        table = self._schema.get_table(ApdbTables.DiaObject)
        query = sql.select(table.columns["diaObjectId"]).where(
            sqlalchemy.and_(
                table.columns["diaObjectId"].in_(sorted(requested_ids)),
                table.columns[validity_end_column].is_(None),
            )
        )

        with self._engine.begin() as conn:
            result = conn.execute(query)
            found_ids = set(result.scalars())

            # Check that we found all that is requested.
            if raise_on_missing_id:
                if missing_ids := (requested_ids - found_ids):
                    raise LookupError(f"Some object IDs are missing from DiaObjectLast table: {missing_ids}")

            # Filter existing records.
            if len(objects) != len(found_ids):
                objects = [obj for obj in objects if obj.diaObjectId in found_ids]

            if not objects:
                return 0

            values = {validity_end_column: validityEnd_value}
            update = (
                table.update()
                .where(
                    sqlalchemy.and_(
                        table.columns["diaObjectId"].in_(sorted(found_ids)),
                        table.columns[validity_end_column].is_(None),
                    )
                )
                .values(**values)
            )
            result = conn.execute(update)
            if result.rowcount != len(found_ids):
                raise RuntimeError(
                    f"Unexpected mismatch in the number of records updated. Object IDs = {found_ids}"
                )

            # Also drop them from DiaObjectLast.
            if self.config.dia_object_index == "last_object_table":
                last_table = self._schema.get_table(ApdbTables.DiaObjectLast)
                delete = last_table.delete().where(last_table.columns["diaObjectId"].in_(sorted(found_ids)))
                result = conn.execute(delete)
                if result.rowcount != len(found_ids):
                    raise RuntimeError(
                        f"Unexpected mismatch in the number of records deleted. Object IDs = {found_ids}"
                    )

            # If replication is enabled then send all updates.
            if self._schema.replication_enabled:
                current_time = self._current_time()
                current_time_ns = int(current_time.unix_tai * 1e9)
                replica_chunk = ReplicaChunk.make_replica_chunk(current_time, self.config.replica_chunk_seconds)

                update_records = [
                    ApdbCloseDiaObjectValidityRecord(
                        diaObjectId=obj.diaObjectId,
                        ra=obj.ra,
                        dec=obj.dec,
                        update_time_ns=current_time_ns,
                        update_order=index,
                        validityEndMjdTai=float(validityEnd.tai.mjd),
                        nDiaSources=None,
                    )
                    for index, obj in enumerate(objects)
                ]

                self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True)

        return len(objects)

    def resetDedup(self, dedup_time: astropy.time.Time | None = None) -> None:
        # docstring is inherited from a base class

        # SQL backend does not have separate dedup tables, nothing to delete,
        # only save last dedup time in metadata.
        if dedup_time is None:
            dedup_time = self._current_time()
        data = {"dedup_time_iso_tai": dedup_time.tai.to_value("iso")}
        data_json = json.dumps(data)
        self._metadata.set(self.metadataDedupKey, data_json, force=True)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        timestamp: float | datetime.datetime
        now = self._current_time()
        timestamp_column = self._timestamp_column_name("ssObjectReassocTime")
        timestamp = self._timestamp_column_value(now)

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: list[int] = []
            for key, value in idMap.items():
                params = {
                    "srcId": key,
                    "diaObjectId": 0,
                    "ssObjectId": value,
                    timestamp_column: timestamp,
                }
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        validity_end_column = self._timestamp_column_name("validityEnd")
        stmt = (
            sql.select(func.count())
            .select_from(table)
            .where(
                sqlalchemy.and_(
                    table.columns["nDiaSources"] == 1,
                    table.columns[validity_end_column].is_(None),
                )
            )
        )

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    @property
    def schema(self) -> ApdbSchema:
        # docstring is inherited from a base class
        return self._table_schema

    @property
    def metadata(self) -> ApdbMetadata:
        # docstring is inherited from a base class
        return self._metadata

    @property
    def admin(self) -> ApdbSqlAdmin:
        # docstring is inherited from a base class
        return ApdbSqlAdmin(self.pixelator)

    def _getDiaSourcesInRegion(self, region: Region, start_time_mjdTai: float) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        start_time_mjdTai : `float`
            Lower bound of time window for the query.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > start_time_mjdTai
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with self._timer("DiaSource_select_time", tags={"table": "DiaSource"}) as timer:
            with self._engine.begin() as conn:
                result = conn.execute(query)
                column_defs = self._table_schema.tableSchemas[ApdbTables.DiaSource].columns
                table_data = ApdbSqlTableData(result, column_defs)
                sources = table_data.to_pandas()
            timer.add_values(row_count=len(sources))
        _LOG.debug("found %s DiaSources", len(sources))
        return self._fix_result_timestamps(sources)

    def _getDiaSourcesByIDs(self, object_ids: list[int], start_time_mjdTai: float) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        start_time_mjdTai : `float`
            Lower bound of time window for the query.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        with self._timer("select_time", tags={"table": "DiaSource"}) as timer:
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, start_time_mjdTai)
            timer.add_values(row_count=len(sources))

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Enum identifying the table to query.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is
            returned when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)
        column_defs = self._table_schema.tableSchemas[table_enum].columns

        sources: pandas.DataFrame | None = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                result = conn.execute(query)
                table_data = ApdbSqlTableData(result, column_defs)
                sources = table_data.to_pandas()
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] >= midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    result = conn.execute(query)
                    table_data = ApdbSqlTableData(result, column_defs)
                    data_frames.append(table_data.to_pandas())

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return self._fix_result_timestamps(sources)

    def _storeReplicaChunk(
        self,
        replica_chunk: ReplicaChunk,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        # Astropy datetime conversions produce naive datetime, even though
        # all astropy times are in UTC. Add UTC timezone to the timestamp so
        # that the database can store a correct value.
        dt = datetime.datetime.fromtimestamp(replica_chunk.last_update_time.unix_tai, tz=datetime.UTC)

        table = self._schema.get_table(ExtraTables.ApdbReplicaChunks)

        # We need UPSERT which is a dialect-specific construct.
        values = {"last_update_time": dt, "unique_id": replica_chunk.unique_id}
        row = {"apdb_replica_chunk": replica_chunk.id} | values
        if connection.dialect.name == "sqlite":
            insert_sqlite = sqlalchemy.dialects.sqlite.insert(table)
            insert_sqlite = insert_sqlite.on_conflict_do_update(index_elements=table.primary_key, set_=values)
            connection.execute(insert_sqlite, row)
        elif connection.dialect.name == "postgresql":
            insert_pg = sqlalchemy.dialects.postgresql.dml.insert(table)
            insert_pg = insert_pg.on_conflict_do_update(constraint=table.primary_key, set_=values)
            connection.execute(insert_pg, row)
        else:
            raise TypeError(f"Unsupported dialect {connection.dialect.name} for upsert.")

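    # Roughly, the UPSERT above compiles to something like (a sketch; exact
    # SQL depends on the dialect):
    #
    #     INSERT INTO "ApdbReplicaChunks" (apdb_replica_chunk, last_update_time, unique_id)
    #     VALUES (...) ON CONFLICT (apdb_replica_chunk)
    #     DO UPDATE SET last_update_time = ..., unique_id = ...
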
    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: astropy.time.Time,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `astropy.time.Time`
            Time of the visit.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier, `None` if replication is disabled.
        connection : `sqlalchemy.engine.Connection`
            Database connection to use for inserts.
        """
        if len(objs) == 0:
            _LOG.debug("No objects to write to database.")
            return

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        validity_start_column = self._timestamp_column_name("validityStart")
        validity_end_column = self._timestamp_column_name("validityEnd")
        timestamp = self._timestamp_column_value(visit_time)

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # DiaObjectLast did not have this column in the past.
            use_validity_start = self._schema.check_column(ApdbTables.DiaObjectLast, validity_start_column)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with self._timer("delete_time", tags={"table": table.name}) as timer:
                res = connection.execute(query)
                timer.add_values(row_count=res.rowcount)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            if validity_start_column in last_column_names and validity_start_column not in objs.columns:
                last_column_names.remove(validity_start_column)
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            # Fill validityStart, only when it is in the schema.
            if use_validity_start:
                if validity_start_column in last_objs:
                    last_objs[validity_start_column] = timestamp
                else:
                    extra_column = pandas.Series([timestamp] * len(last_objs), name=validity_start_column)
                    last_objs.set_index(extra_column.index, inplace=True)
                    last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with self._timer("insert_time", tags={"table": "DiaObjectLast"}) as timer:
                last_objs.to_sql(
                    table.name,
                    connection,
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
                timer.add_values(row_count=len(last_objs))
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(**{validity_end_column: timestamp})
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns[validity_end_column].is_(None),
                    )
                )
            )

            with self._timer("truncate_time", tags={"table": table.name}) as timer:
                res = connection.execute(update)
                timer.add_values(row_count=res.rowcount)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: list[pandas.Series] = []
        if validity_start_column in objs.columns:
            objs[validity_start_column] = timestamp
        else:
            extra_columns.append(pandas.Series([timestamp] * len(objs), name=validity_start_column))
        if validity_end_column in objs.columns:
            objs[validity_end_column] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name=validity_end_column))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert replica data
        table = self._schema.get_table(ApdbTables.DiaObject)
        replica_data: list[dict] = []
        replica_stmt: Any = None
        replica_table_name = ""
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = objs[pk_names].to_dict("records")
            if replica_data:
                for row in replica_data:
                    row["apdb_replica_chunk"] = replica_chunk.id
                replica_table = self._schema.get_table(ExtraTables.DiaObjectChunks)
                replica_table_name = replica_table.name
                replica_stmt = replica_table.insert()

        # insert new versions
        with self._timer("insert_time", tags={"table": table.name}) as timer:
            objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            timer.add_values(row_count=len(objs))
        if replica_stmt is not None:
            with self._timer("insert_time", tags={"table": replica_table_name}) as timer:
                connection.execute(replica_stmt, replica_data)
                timer.add_values(row_count=len(replica_data))

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier, `None` if replication is disabled.
        connection : `sqlalchemy.engine.Connection`
            Database connection to use for inserts.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert replica data
        replica_data: list[dict] = []
        replica_stmt: Any = None
        replica_table_name = ""
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = sources[pk_names].to_dict("records")
            if replica_data:
                for row in replica_data:
                    row["apdb_replica_chunk"] = replica_chunk.id
                replica_table = self._schema.get_table(ExtraTables.DiaSourceChunks)
                replica_table_name = replica_table.name
                replica_stmt = replica_table.insert()

        # everything to be done in single transaction
        with self._timer("insert_time", tags={"table": table.name}) as timer:
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            timer.add_values(row_count=len(sources))
        if replica_stmt is not None:
            with self._timer("replica_insert_time", tags={"table": replica_table_name}) as timer:
                connection.execute(replica_stmt, replica_data)
                timer.add_values(row_count=len(replica_data))

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier, `None` if replication is disabled.
        connection : `sqlalchemy.engine.Connection`
            Database connection to use for inserts.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert replica data
        replica_data: list[dict] = []
        replica_stmt: Any = None
        replica_table_name = ""
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = sources[pk_names].to_dict("records")
            if replica_data:
                for row in replica_data:
                    row["apdb_replica_chunk"] = replica_chunk.id
                replica_table = self._schema.get_table(ExtraTables.DiaForcedSourceChunks)
                replica_table_name = replica_table.name
                replica_stmt = replica_table.insert()

        # everything to be done in single transaction
        with self._timer("insert_time", tags={"table": table.name}) as timer:
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            timer.add_values(row_count=len(sources))
        if replica_stmt is not None:
            with self._timer("insert_time", tags={"table": replica_table_name}) as timer:
                connection.execute(replica_stmt, replica_data)
                timer.add_values(row_count=len(replica_data))

    def _storeUpdateRecords(
        self,
        records: Iterable[ApdbUpdateRecord],
        chunk: ReplicaChunk,
        *,
        store_chunk: bool = False,
        connection: sqlalchemy.engine.Connection | None = None,
    ) -> None:
        """Store ApdbUpdateRecords in the replica table for those records.

        Parameters
        ----------
        records : `list` [`ApdbUpdateRecord`]
            Records to store.
        chunk : `ReplicaChunk`
            Replica chunk for these records.
        store_chunk : `bool`
            If True then also store replica chunk.
        connection : `sqlalchemy.engine.Connection`
            SQLAlchemy connection to use; if `None` then a new connection
            will be made. `None` is useful for tests only, regular use will
            call this method in the same transaction that saves other types
            of records.

        Raises
        ------
        TypeError
            Raised if replication is not enabled for this instance.
        """
        if not self._schema.replication_enabled:
            raise TypeError("Replication is not enabled for this APDB instance.")

        apdb_replica_chunk = chunk.id
        # Do not use unique_id from ReplicaChunk as it could be reused in
        # multiple calls to this method.
        update_unique_id = uuid.uuid4()

        record_dicts = []
        for record in records:
            record_dicts.append(
                {
                    "apdb_replica_chunk": apdb_replica_chunk,
                    "update_time_ns": record.update_time_ns,
                    "update_order": record.update_order,
                    "update_unique_id": update_unique_id,
                    "update_payload": record.to_json(),
                }
            )

        if not record_dicts:
            return

        # TODO: Need to check that table exists.
        table = self._schema.get_table(ExtraTables.ApdbUpdateRecordChunks)

        def _do_store(connection: sqlalchemy.engine.Connection) -> None:
            if store_chunk:
                self._storeReplicaChunk(chunk, connection)
            with self._timer("insert_time", tags={"table": table.name}) as timer:
                connection.execute(table.insert(), record_dicts)
                timer.add_values(row_count=len(record_dicts))

        if connection is None:
            with self._engine.begin() as connection:
                _do_store(connection)
        else:
            _do_store(connection)

    def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` [`tuple`]
            Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.pixelization.htm_max_ranges)

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.pixelization.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)

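    # A sketch of the generated filter, assuming the default "pixelId" index
    # column: pixel ranges [(512, 516), (520, 521)] translate to
    # "pixelId BETWEEN 512 AND 515 OR pixelId = 520". Upper bounds returned
    # by _htm_indices() are exclusive, hence the "upper -= 1" above.
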
    def _add_spatial_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate spatial index for each record and add it to a DataFrame.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame which has to contain ra/dec columns, names of these
            columns are defined by configuration ``ra_dec_columns`` field.

        Returns
        -------
        df : `pandas.DataFrame`
            DataFrame with ``pixelId`` column which contains pixel index
            for ra/dec coordinates.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.pixelization.htm_index_column] = htm_index
        return df

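    # Illustrative example, assuming the default ra_dec_columns ("ra", "dec"):
    #
    #     df = pandas.DataFrame({"ra": [10.0], "dec": [-5.0], "flux": [1.0]})
    #     indexed = apdb._add_spatial_index(df)
    #
    # "indexed" is a copy of df with an extra "pixelId" column holding the
    # HTM index of each row; df itself is left unmodified.
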
    def _fix_input_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Update timestamp columns in input DataFrame to be aware datetime
        type in UTC.

        AP pipeline generates naive datetime instances, we want them to be
        aware before they go to database. All naive timestamps are assumed to
        be in UTC timezone (they should be TAI).
        """
        # Find all columns with aware non-UTC timestamps and convert to UTC.
        columns = [
            column
            for column, dtype in df.dtypes.items()
            if isinstance(dtype, pandas.DatetimeTZDtype) and dtype.tz is not datetime.UTC
        ]
        for column in columns:
            df[column] = df[column].dt.tz_convert(datetime.UTC)
        # Find all columns with naive timestamps and add UTC timezone.
        columns = [
            column for column, dtype in df.dtypes.items() if pandas.api.types.is_datetime64_dtype(dtype)
        ]
        for column in columns:
            df[column] = df[column].dt.tz_localize(datetime.UTC)
        return df

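    # For example, a naive pandas.Timestamp("2026-01-01 00:00:00") column
    # becomes the aware "2026-01-01 00:00:00+00:00" (UTC), while aware
    # non-UTC columns are converted to UTC.
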
    def _fix_result_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Update timestamp columns to be naive datetime type in returned
        DataFrame.

        AP pipeline code expects DataFrames to contain naive datetime columns,
        while Postgres queries return timezone-aware type. This method
        converts those columns to naive datetime in UTC timezone.
        """
        # Find all columns with aware timestamps.
        columns = [column for column, dtype in df.dtypes.items() if isinstance(dtype, pandas.DatetimeTZDtype)]
        for column in columns:
            # tz_convert(None) will convert to UTC and drop timezone.
            df[column] = df[column].dt.tz_convert(None)
        return df

    def _timestamp_column_name(self, column: str) -> str:
        """Return column name before/after schema migration to MJD TAI."""
        return self.schema.timestamp_column_name(column)

    def _timestamp_column_value(self, time: astropy.time.Time) -> float | datetime.datetime:
        """Return column value before/after schema migration to MJD TAI."""
        if self.schema.has_mjd_timestamps:
            return float(time.tai.mjd)
        else:
            return time.datetime.astimezone(tz=datetime.UTC)

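    # For example, astropy.time.Time("2026-01-01", scale="tai") maps to the
    # float 61041.0 (MJD TAI) after the schema migration, or to the
    # corresponding timezone-aware UTC datetime before it.
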
    def _get_diaobject_data(
        self, conn: sqlalchemy.engine.Connection, object_ids: Iterable[int], *columns: str
    ) -> list:
        """Select records from either DiaObject or DiaObjectLast and return
        selected rows as named tuples.
        """
        where: sqlalchemy.ColumnElement[bool]
        if self.config.dia_object_index == "last_object_table":
            table = self._schema.get_table(ApdbTables.DiaObjectLast)
            where = table.columns["diaObjectId"].in_(sorted(object_ids))
        else:
            table = self._schema.get_table(ApdbTables.DiaObject)
            validity_end_column = self._timestamp_column_name("validityEnd")
            where = sqlalchemy.and_(
                table.columns["diaObjectId"].in_(sorted(object_ids)),
                table.columns[validity_end_column].is_(None),
            )
        column_list = [table.columns["diaObjectId"]] + [table.columns[column] for column in columns]
        query = sql.select(*column_list).where(where)
        result = conn.execute(query)

        return list(result)

    def _get_diasource_data(
        self, conn: sqlalchemy.engine.Connection, source_ids: Iterable[int], *columns: str
    ) -> list:
        """Select records from DiaSource table by diaSourceId and return
        selected rows as named tuples.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)
        where = table.columns["diaSourceId"].in_(sorted(source_ids))
        column_list = [table.columns["diaSourceId"]] + [table.columns[column] for column in columns]
        query = sql.select(*column_list).where(where)
        result = conn.execute(query)

        return list(result)