Coverage for python/lsst/dax/apdb/apdbSql.py: 15%
522 statements
coverage.py v7.4.4, created at 2024-04-10 10:38 +0000
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""
25from __future__ import annotations
27__all__ = ["ApdbSqlConfig", "ApdbSql"]
29import logging
30from collections.abc import Iterable, Mapping, MutableMapping
31from contextlib import closing, suppress
32from typing import TYPE_CHECKING, Any, cast
34import astropy.time
35import numpy as np
36import pandas
37import sqlalchemy
38from felis.simple import Table
39from lsst.pex.config import ChoiceField, Field, ListField
40from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
41from lsst.utils.iteration import chunk_iterable
42from sqlalchemy import func, sql
43from sqlalchemy.pool import NullPool
45from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
46from .apdbConfigFreezer import ApdbConfigFreezer
47from .apdbMetadataSql import ApdbMetadataSql
48from .apdbSchema import ApdbTables
49from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
50from .timer import Timer
51from .versionTuple import IncompatibleVersionError, VersionTuple
53if TYPE_CHECKING:
54 import sqlite3
56 from .apdbMetadata import ApdbMetadata
58_LOG = logging.getLogger(__name__)
60VERSION = VersionTuple(0, 1, 0)
61"""Version for the code defined in this module. This needs to be updated
62(following compatibility rules) when schema produced by this code changes.
63"""


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return copy of data
    frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
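
# A minimal sketch of the coercion above (the column name is hypothetical):
#
#     df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
#     assert _coerce_uint64(df)["diaObjectId"].dtype == np.int64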


def _make_midpointMjdTai_start(visit_time: astropy.time.Time, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `astropy.time.Time`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: Use of MJD must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.mjd - months * 30
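
# For example, for a visit at MJD 60000.0 with a 12-month history this
# approximates a month as 30 days and returns 60000.0 - 12 * 30 = 59640.0.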


def _onSqlite3Connect(
    dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
) -> None:
    # Enable foreign keys
    with closing(dbapiConnection.cursor()) as cursor:
        cursor.execute("PRAGMA foreign_keys=ON;")


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to SQLAlchemy defaults if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
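
# A minimal configuration sketch (the SQLite URL is hypothetical):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.validate()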


class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self._keys = list(result.keys())
        self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))

    def column_names(self) -> list[str]:
        return self._keys

    def rows(self) -> Iterable[tuple]:
        return self._rows


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using `ApdbSqlConfig` configuration class. For examples of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    metadataSchemaVersionKey = "version:schema"
    """Name of the metadata key to store schema version number."""

    metadataCodeVersionKey = "version:ApdbSql"
    """Name of the metadata key to store code version number."""

    metadataConfigKey = "config:apdb-sql.json"
    """Name of the metadata key to store frozen part of the configuration."""

    _frozen_parameters = (
        "use_insert_id",
        "dia_object_index",
        "htm_level",
        "htm_index_column",
        "ra_dec_columns",
    )
    """Names of the config parameters to be frozen in metadata table."""

    def __init__(self, config: ApdbSqlConfig):
        self._engine = self._makeEngine(config)

        sa_metadata = sqlalchemy.MetaData(schema=config.namespace)
        meta_table_name = ApdbTables.metadata.table_name(prefix=config.prefix)
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(sqlalchemy.exc.NoSuchTableError):
            meta_table = sqlalchemy.schema.Table(meta_table_name, sa_metadata, autoload_with=self._engine)

        self._metadata = ApdbMetadataSql(self._engine, meta_table)

        # Read frozen config from metadata.
        config_json = self._metadata.get(self.metadataConfigKey)
        if config_json is not None:
            # Update config from metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](self._frozen_parameters)
            self.config = freezer.update(config, config_json)
        else:
            self.config = config
        self.config.validate()

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=self.config.use_insert_id,
        )

        if self._metadata.table_exists():
            self._versionCheck(self._metadata)

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

    @classmethod
    def _makeEngine(cls, config: ApdbSqlConfig) -> sqlalchemy.engine.Engine:
        """Make SQLAlchemy engine based on configured parameters.

        Parameters
        ----------
        config : `ApdbSqlConfig`
            Configuration object.
        """
        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw: MutableMapping[str, Any] = dict(echo=config.sql_echo)
        conn_args: dict[str, Any] = dict()
        if not config.connection_pool:
            kw.update(poolclass=NullPool)
        if config.isolation_level is not None:
            kw.update(isolation_level=config.isolation_level)
        elif config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if config.connection_timeout is not None:
            if config.db_url.startswith("sqlite"):
                conn_args.update(timeout=config.connection_timeout)
            elif config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=config.connection_timeout)
        kw.update(connect_args=conn_args)
        engine = sqlalchemy.create_engine(config.db_url, **kw)

        if engine.dialect.name == "sqlite":
            # Need to enable foreign keys on every new connection.
            sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect)

        return engine

    def _versionCheck(self, metadata: ApdbMetadataSql) -> None:
        """Check schema version compatibility."""

        def _get_version(key: str, default: VersionTuple) -> VersionTuple:
            """Retrieve version number from given metadata key."""
            if metadata.table_exists():
                version_str = metadata.get(key)
                if version_str is None:
                    # Should not happen with existing metadata table.
                    raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
                return VersionTuple.fromString(version_str)
            return default

        # For old databases where metadata table does not exist we assume that
        # version of both code and schema is 0.1.0.
        initial_version = VersionTuple(0, 1, 0)
        db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version)
        db_code_version = _get_version(self.metadataCodeVersionKey, initial_version)

        # For now there is no way to make read-only APDB instances, assume that
        # any access can do updates.
        if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True):
            raise IncompatibleVersionError(
                f"Configured schema version {self._schema.schemaVersion()} "
                f"is not compatible with database version {db_schema_version}"
            )
        if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True):
            raise IncompatibleVersionError(
                f"Current code version {self.apdbImplementationVersion()} "
                f"is not compatible with database version {db_code_version}"
            )

    @classmethod
    def apdbImplementationVersion(cls) -> VersionTuple:
        # Docstring inherited from base class.
        return VERSION

    @classmethod
    def init_database(
        cls,
        db_url: str,
        *,
        schema_file: str | None = None,
        schema_name: str | None = None,
        read_sources_months: int | None = None,
        read_forced_sources_months: int | None = None,
        use_insert_id: bool = False,
        connection_timeout: int | None = None,
        dia_object_index: str | None = None,
        htm_level: int | None = None,
        htm_index_column: str | None = None,
        ra_dec_columns: list[str] | None = None,
        prefix: str | None = None,
        namespace: str | None = None,
        drop: bool = False,
    ) -> ApdbSqlConfig:
        """Initialize new APDB instance and make configuration object for it.

        Parameters
        ----------
        db_url : `str`
            SQLAlchemy database URL.
        schema_file : `str`, optional
            Location of (YAML) configuration file with APDB schema. If not
            specified then default location will be used.
        schema_name : `str`, optional
            Name of the schema in YAML configuration file. If not specified
            then default name will be used.
        read_sources_months : `int`, optional
            Number of months of history to read from DiaSource.
        read_forced_sources_months : `int`, optional
            Number of months of history to read from DiaForcedSource.
        use_insert_id : `bool`
            If True, make additional tables used for replication to PPDB.
        connection_timeout : `int`, optional
            Database connection timeout in seconds.
        dia_object_index : `str`, optional
            Indexing mode for DiaObject table.
        htm_level : `int`, optional
            HTM indexing level.
        htm_index_column : `str`, optional
            Name of a HTM index column for DiaObject and DiaSource tables.
        ra_dec_columns : `list` [`str`], optional
            Names of ra/dec columns in DiaObject table.
        prefix : `str`, optional
            Optional prefix for all table names.
        namespace : `str`, optional
            Name of the database schema for all APDB tables. If not specified
            then default schema is used.
        drop : `bool`, optional
            If `True` then drop existing tables before re-creating the schema.

        Returns
        -------
        config : `ApdbSqlConfig`
            Resulting configuration object for a created APDB instance.
        """
        config = ApdbSqlConfig(db_url=db_url, use_insert_id=use_insert_id)
        if schema_file is not None:
            config.schema_file = schema_file
        if schema_name is not None:
            config.schema_name = schema_name
        if read_sources_months is not None:
            config.read_sources_months = read_sources_months
        if read_forced_sources_months is not None:
            config.read_forced_sources_months = read_forced_sources_months
        if connection_timeout is not None:
            config.connection_timeout = connection_timeout
        if dia_object_index is not None:
            config.dia_object_index = dia_object_index
        if htm_level is not None:
            config.htm_level = htm_level
        if htm_index_column is not None:
            config.htm_index_column = htm_index_column
        if ra_dec_columns is not None:
            config.ra_dec_columns = ra_dec_columns
        if prefix is not None:
            config.prefix = prefix
        if namespace is not None:
            config.namespace = namespace

        cls._makeSchema(config, drop=drop)

        return config
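
    # A minimal sketch of creating and opening a new APDB instance (the
    # SQLite URL is hypothetical):
    #
    #     config = ApdbSql.init_database(db_url="sqlite:///apdb.db")
    #     apdb = ApdbSql(config)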

    def apdbSchemaVersion(self) -> VersionTuple:
        # Docstring inherited from base class.
        return self._schema.schemaVersion()

    def tableRowCount(self) -> dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Table | None:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    @classmethod
    def _makeSchema(cls, config: ApdbConfig, drop: bool = False) -> None:
        # docstring is inherited from a base class

        if not isinstance(config, ApdbSqlConfig):
            raise TypeError(f"Unexpected type of configuration object: {type(config)}")

        engine = cls._makeEngine(config)

        # Ask schema class to create all tables.
        schema = ApdbSqlSchema(
            engine=engine,
            dia_object_index=config.dia_object_index,
            schema_file=config.schema_file,
            schema_name=config.schema_name,
            prefix=config.prefix,
            namespace=config.namespace,
            htm_index_column=config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )
        schema.makeSchema(drop=drop)

        # Need metadata table to store a few items in it, if table exists.
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(ValueError):
            meta_table = schema.get_table(ApdbTables.metadata)

        apdb_meta = ApdbMetadataSql(engine, meta_table)
        if apdb_meta.table_exists():
            # Fill version numbers, overwrite if they are already there.
            apdb_meta.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True)
            apdb_meta.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True)

            # Store frozen part of a configuration in metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](cls._frozen_parameters)
            apdb_meta.set(cls.metadataConfigKey, freezer.to_json(config), force=True)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(
                ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
            )

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def containsVisitDetector(self, visit: int, detector: int) -> bool:
        # docstring is inherited from a base class
        raise NotImplementedError()

    def containsCcdVisit(self, ccdVisitId: int) -> bool:
        """Test whether data for a given visit-detector is present in the APDB.

        This method is a placeholder until `Apdb.containsVisitDetector` can
        be implemented.

        Parameters
        ----------
        ccdVisitId : `int`
            The packed ID of the visit-detector to search for.

        Returns
        -------
        present : `bool`
            `True` if some DiaSource records exist for the specified
            observation, `False` otherwise.
        """
        # TODO: remove this method in favor of containsVisitDetector on either
        # DM-41671 or a ticket that removes ccdVisitId from these tables
        src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource)
        frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource)
        # Query should load only one leaf page of the index
        query1 = sql.select(src_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)
        # Backup query in case an image was processed but had no diaSources
        query2 = sql.select(frcsrc_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)

        with self._engine.begin() as conn:
            result = conn.execute(query1).scalar_one_or_none()
            if result is not None:
                return True
            else:
                result = conn.execute(query2).scalar_one_or_none()
                return result is not None
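
    # For example (the packed visit-detector ID is hypothetical):
    #
    #     if not apdb.containsCcdVisit(ccdVisitId=2023111500001):
    #         ...  # observation has not been stored yet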

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"], table.columns["insert_time"]).order_by(
            table.columns["insert_time"]
        )
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                ids = []
                for row in result:
                    insert_time = astropy.time.Time(row[1].timestamp(), format="unix_tai")
                    ids.append(ApdbInsertId(id=row[0], insert_time=insert_time))
                return ids
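
    # A minimal sketch of draining the replication queue with these methods
    # (``ppdb`` and its ``store`` call are hypothetical consumers):
    #
    #     ids = apdb.getInsertIds()
    #     if ids:
    #         data = apdb.getDiaSourcesHistory(ids)
    #         ppdb.store(data)
    #         apdb.deleteInsertIds(ids)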

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Return catalog of records for given insert identifiers, common
        implementation for all DIA tables.
        """
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            with self._engine.begin() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return ApdbSqlTableData(result)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: astropy.time.Time,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id(visit_time)
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)
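
    # A minimal sketch of a store call (``objects`` and ``sources`` are
    # hypothetical catalogs produced by source association):
    #
    #     visit_time = astropy.time.Time.now()
    #     apdb.store(visit_time, objects, sources)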

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(table.name, conn, if_exists="append", index=False, schema=table.schema)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: list[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
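
    # For example, reassigning two DiaSources to a single SSObject (all IDs
    # are hypothetical):
    #
    #     apdb.reassignDiaSources({1001: 42, 1002: 42})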

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    @property
    def metadata(self) -> ApdbMetadata:
        # docstring is inherited from a base class
        if self._metadata is None:
            raise RuntimeError("Database schema was not initialized.")
        return self._metadata

    def _getDiaSourcesInRegion(self, region: Region, visit_time: astropy.time.Time) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `astropy.time.Time`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: astropy.time.Time) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `astropy.time.Time`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Enum member identifying the database table to query.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is returned
            when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: pandas.DataFrame | None = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: astropy.time.Time, connection: sqlalchemy.engine.Connection
    ) -> None:
        dt = visit_time.datetime

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: astropy.time.Time,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `astropy.time.Time`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Database connection with an open transaction.
        """
        if len(objs) == 0:
            _LOG.debug("No objects to write to database.")
            return

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.datetime

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    connection,
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: list[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if history_stmt is not None:
                connection.execute(history_stmt, history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Database connection with an open transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Database connection with an open transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` [`tuple`]
            Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()
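
    # A minimal sketch of how the envelope is used (assuming a small circular
    # region; coordinates and radius are hypothetical):
    #
    #     from lsst.sphgeom import Angle, Circle
    #     center = UnitVector3d(LonLat.fromDegrees(45.0, 30.0))
    #     region = Circle(center, Angle.fromDegrees(0.1))
    #     ranges = self._htm_indices(region)  # e.g. [(minId1, maxId1), ...]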

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            # Ranges are half-open, but BETWEEN is inclusive.
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df
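
    # For example (a hypothetical catalog using the default ra/dec column
    # names):
    #
    #     objs = pandas.DataFrame({"diaObjectId": [1], "ra": [45.0], "dec": [30.0]})
    #     objs = apdb._add_obj_htm_index(objs)  # adds/overwrites "pixelId"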

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources