Coverage for python/lsst/dax/apdb/apdbSql.py: 15%
425 statements
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Callable, Iterable, Mapping, MutableMapping
from typing import Any, Dict, List, Optional, Tuple, cast

import lsst.daf.base as dafBase
import numpy as np
import pandas
import sqlalchemy
from felis.simple import Table
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, inspection, sql
from sqlalchemy.engine import Inspector
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
from .apdbSchema import ApdbTables
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
from .timer import Timer

_LOG = logging.getLogger(__name__)


class _ConnectionHackSA2(sqlalchemy.engine.Connectable):
    """Terrible hack to work around Pandas' incomplete support for
    SQLAlchemy 2.

    We need to pass a Connection instance to Pandas methods, but in SA 2
    the Connection class lost the ``connect`` method that Pandas uses.
    """

    def __init__(self, connection: sqlalchemy.engine.Connection):
        self._connection = connection

    def connect(self) -> Any:
        return self

    @property
    def execute(self) -> Callable:
        return self._connection.execute

    @property
    def execution_options(self) -> Callable:
        return self._connection.execution_options

    @property
    def connection(self) -> Any:
        return self._connection.connection

    def __enter__(self) -> sqlalchemy.engine.Connection:
        return self._connection

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        # Do not close the connection here.
        pass


@inspection._inspects(_ConnectionHackSA2)
def _connection_insp(conn: _ConnectionHackSA2) -> Inspector:
    return Inspector._construct(Inspector._init_connection, conn._connection)


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64; return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
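

# Illustrative check (a hypothetical helper, not part of the module API)
# showing what ``_coerce_uint64`` does: uint64 columns become int64 and the
# original frame is left untouched. The column name below is made up.
def _example_coerce_uint64() -> None:
    df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
    coerced = _coerce_uint64(df)
    assert coerced["diaObjectId"].dtype == np.int64
    # ``astype`` returns a copy, the input frame keeps its dtype.
    assert df["diaObjectId"].dtype == np.uint64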


def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
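

# Worked example (a hypothetical helper with illustrative values): a visit
# at MJD 60000.0 with a 12-month window starts the search 12 * 30 = 360 days
# earlier, at MJD 59640.0. This assumes the DateTime(value, system, scale)
# constructor of lsst.daf.base.
def _example_midPointTai_start() -> None:
    visit = dafBase.DateTime(60000.0, dafBase.DateTime.MJD, dafBase.DateTime.TAI)
    assert abs(_make_midPointTai_start(visit, 12) - 59640.0) < 1e-6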


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to sqlalchemy defaults if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "decl"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only makes sense for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
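

# A minimal configuration sketch (a hypothetical helper with illustrative
# values, assuming a local SQLite file). ``read_sources_months`` and
# ``read_forced_sources_months`` come from the ApdbConfig base class; all
# other options keep their defaults.
def _example_config() -> ApdbSqlConfig:
    config = ApdbSqlConfig()
    config.db_url = "sqlite:///apdb.db"
    config.read_sources_months = 12
    config.read_forced_sources_months = 12
    config.validate()
    return config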


class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self._keys = list(result.keys())
        self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))

    def column_names(self) -> list[str]:
        return self._keys

    def rows(self) -> Iterable[tuple]:
        return self._rows
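

# Illustrative helper (an assumption of this example, not part of the API):
# ApdbTableData only exposes column names and row tuples, which is enough to
# build e.g. a DataFrame from a history query result.
def _table_data_to_dataframe(data: ApdbTableData) -> pandas.DataFrame:
    """Convert table data to a DataFrame (example helper)."""
    return pandas.DataFrame(data.rows(), columns=data.column_names())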


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using `ApdbSqlConfig` configuration class. For examples of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):
        config.validate()
        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # The engine is reused between multiple processes; make sure that we
        # do not share connections by disabling the pool (using NullPool).
        kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id

    def tableRowCount(self) -> Dict[str, int]:
        """Return a dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database
        tables. Depending on database technology this could be an expensive
        operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[Table]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DiaForcedSources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If list is empty then empty catalog is returned with a
            correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned
            if the ``read_forced_sources_months`` configuration parameter is
            set to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be not-`None`.
        `NotImplementedError` is raised if `None` is passed.

        This method returns a DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaForcedSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter w.r.t.
        ``visit_time``. If ``object_ids`` is empty then an empty catalog is
        always returned with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association;
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaForcedSource, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"]).order_by(table.columns["insert_time"])
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return [ApdbInsertId(row) for row in result.scalars()]

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Common implementation of the history methods."""
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            with self._engine.begin() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return ApdbSqlTableData(result)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: dafBase.DateTime,
        objects: pandas.DataFrame,
        sources: Optional[pandas.DataFrame] = None,
        forced_sources: Optional[pandas.DataFrame] = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id()
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)
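
    # Usage sketch for store() (the catalog below shows an illustrative
    # column subset; real catalogs carry the full APDB schema):
    #
    #   objects = pandas.DataFrame({"diaObjectId": [1], "ra": [10.0], "decl": [-5.0]})
    #   apdb.store(visit_time, objects, sources=None, forced_sources=None)
    #
    # DiaSource/DiaForcedSource catalogs, when given, are written in the
    # same transaction as the DiaObjects.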

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64
            # can cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(
                    table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema
                )

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check
            # what is missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
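
    # For example (hypothetical IDs), moving two DiaSources to SSObject 4000
    # and detaching them from their DiaObjects:
    #
    #   apdb.reassignDiaSources({1001: 4000, 1002: 4000})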

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association;
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given a set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association;
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: List[int], midPointTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given a
        set of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Source table, either ``ApdbTables.DiaSource`` or
            ``ApdbTables.DiaForcedSource``.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: Optional[pandas.DataFrame] = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midPointTai"] > midPointTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
    ) -> None:
        dt = visit_time.toPython()

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: dafBase.DateTime,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection for the enclosing transaction.
        """

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in the LAST table; MySQL and
            # Postgres only offer non-standard features for upsert.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    _ConnectionHackSA2(connection),
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: List[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit
            # time just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection for the enclosing transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection for the enclosing transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` of `tuple`
            Sequence of ranges, each range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
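
    # For example (hypothetical ranges), pixel ranges [(512, 513), (520, 524)]
    # render roughly as
    #   pixelId = 512 OR pixelId BETWEEN 520 AND 523
    # since the upper bound of each envelope range is exclusive, hence the
    # ``upper -= 1`` above.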

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies the pixelId value from a matching DiaObject
        record. The DiaObject catalog needs to have a pixelId column filled
        by the ``_add_obj_htm_index`` method, and DiaSource records need to
        be associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject, hence we skip them and set their pixelId
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources
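

# End-to-end usage sketch (a hypothetical helper with illustrative values,
# not part of this module): create an APDB backed by a local SQLite file,
# build the schema, and run a region query. Circle and Angle come from
# lsst.sphgeom; ``_example_config`` is the configuration sketch above.
def _example_region_query() -> pandas.DataFrame:
    from lsst.sphgeom import Angle, Circle

    config = _example_config()
    apdb = ApdbSql(config)
    apdb.makeSchema(drop=True)

    center = UnitVector3d(LonLat.fromDegrees(10.0, -5.0))
    region = Circle(center, Angle.fromDegrees(0.1))
    # On a fresh database this returns an empty DataFrame with the full
    # DiaObject schema.
    return apdb.getDiaObjects(region)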