Coverage for python/lsst/dax/apdb/apdbSql.py: 13%
428 statements
coverage.py v7.2.3, created at 2023-04-27 10:00 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""
25from __future__ import annotations
27__all__ = ["ApdbSqlConfig", "ApdbSql"]
29import logging
30from collections.abc import Callable, Iterable, Mapping, MutableMapping
31from typing import Any, Dict, List, Optional, Tuple, cast
33import lsst.daf.base as dafBase
34import numpy as np
35import pandas
36import sqlalchemy
37from felis.simple import Table
38from lsst.pex.config import ChoiceField, Field, ListField
39from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
40from lsst.utils.iteration import chunk_iterable
41from sqlalchemy import func, inspection, sql
42from sqlalchemy.engine import Inspector
43from sqlalchemy.pool import NullPool
45from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
46from .apdbSchema import ApdbTables
47from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
48from .timer import Timer
50_LOG = logging.getLogger(__name__)
53if pandas.__version__.partition(".")[0] == "1":  # coverage: 53 ↛ 55 (line 53 didn't jump to line 55, because the condition on line 53 was never true)
55 class _ConnectionHackSA2(sqlalchemy.engine.Connectable):
56 """Terrible hack to workaround Pandas 1 incomplete support for
57 sqlalchemy 2.
59 We need to pass a Connection instance to pandas method, but in SA 2 the
60 Connection class lost ``connect`` method which is used by Pandas.
61 """
63 def __init__(self, connection: sqlalchemy.engine.Connection):
64 self._connection = connection
66 def connect(self, **kwargs: Any) -> Any:
67 return self
69 @property
70 def execute(self) -> Callable:
71 return self._connection.execute
73 @property
74 def execution_options(self) -> Callable:
75 return self._connection.execution_options
77 @property
78 def connection(self) -> Any:
79 return self._connection.connection
81 def __enter__(self) -> sqlalchemy.engine.Connection:
82 return self._connection
84 def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
85 # Do not close connection here
86 pass
88 @inspection._inspects(_ConnectionHackSA2)
89 def _connection_insp(conn: _ConnectionHackSA2) -> Inspector:
90 return Inspector._construct(Inspector._init_connection, conn._connection)
92else:
93 # Pandas 2.0 supports SQLAlchemy 2 correctly.
94 def _ConnectionHackSA2( # type: ignore[no-redef]
95 conn: sqlalchemy.engine.Connectable,
96 ) -> sqlalchemy.engine.Connectable:
97 return conn
100def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
101 """Change type of the uint64 columns to int64, return copy of data frame."""
102 names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
103 return df.astype({name: np.int64 for name in names})
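# Illustrative sketch (hypothetical "flags" column): uint64 columns are viewed as
# int64 because SQL backends generally lack an unsigned 64-bit integer type.
#
#     >>> df = pandas.DataFrame({"flags": np.array([1, 2], dtype=np.uint64)})
#     >>> _coerce_uint64(df)["flags"].dtype == np.int64
#     True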
106def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
107 """Calculate starting point for time-based source search.
109 Parameters
110 ----------
111 visit_time : `lsst.daf.base.DateTime`
112 Time of current visit.
113 months : `int`
114 Number of months in the sources history.
116 Returns
117 -------
118 time : `float`
119 A ``midPointTai`` starting point, MJD time.
120 """
121 # TODO: `system` must be consistent with the code in ap_association
122 # (see DM-31996)
123 return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
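# Worked example: for a visit at MJD 60000.0 (TAI) and a 12-month history window the
# cutoff is 60000.0 - 12 * 30 = 59640.0, since months are approximated as 30 days.
#
#     >>> t = dafBase.DateTime(60000.0, dafBase.DateTime.MJD, dafBase.DateTime.TAI)
#     >>> _make_midPointTai_start(t, 12)
#     59640.0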
126class ApdbSqlConfig(ApdbConfig):
127 """APDB configuration class for SQL implementation (ApdbSql)."""
129 db_url = Field[str](doc="SQLAlchemy database connection URI")
130 isolation_level = ChoiceField[str](
131 doc=(
132 "Transaction isolation level, if unset then backend-default value "
133 "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
134 "Some backends may not support every allowed value."
135 ),
136 allowed={
137 "READ_COMMITTED": "Read committed",
138 "READ_UNCOMMITTED": "Read uncommitted",
139 "REPEATABLE_READ": "Repeatable read",
140 "SERIALIZABLE": "Serializable",
141 },
142 default=None,
143 optional=True,
144 )
145 connection_pool = Field[bool](
146 doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
147 default=True,
148 )
149 connection_timeout = Field[float](
150 doc=(
151 "Maximum time to wait time for database lock to be released before exiting. "
152 "Defaults to sqlalchemy defaults if not set."
153 ),
154 default=None,
155 optional=True,
156 )
157 sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
158 dia_object_index = ChoiceField[str](
159 doc="Indexing mode for DiaObject table",
160 allowed={
161 "baseline": "Index defined in baseline schema",
162 "pix_id_iov": "(pixelId, objectId, iovStart) PK",
163 "last_object_table": "Separate DiaObjectLast table",
164 },
165 default="baseline",
166 )
167 htm_level = Field[int](doc="HTM indexing level", default=20)
168 htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
169 htm_index_column = Field[str](
170 default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
171 )
172 ra_dec_columns = ListField[str](default=["ra", "decl"], doc="Names of ra/dec columns in DiaObject table")
173 dia_object_columns = ListField[str](
174 doc="List of columns to read from DiaObject, by default read all columns", default=[]
175 )
176 prefix = Field[str](doc="Prefix to add to table names and index names", default="")
177 namespace = Field[str](
178 doc=(
179 "Namespace or schema name for all tables in APDB database. "
180 "Presently only works for PostgreSQL backend. "
181 "If schema with this name does not exist it will be created when "
182 "APDB tables are created."
183 ),
184 default=None,
185 optional=True,
186 )
187 timer = Field[bool](doc="If True then print/log timing information", default=False)
189 def validate(self) -> None:
190 super().validate()
191 if len(self.ra_dec_columns) != 2:
192 raise ValueError("ra_dec_columns must have exactly two column names")
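# Illustrative sketch: a minimal configuration for a file-based SQLite APDB; the URL
# and option values below are examples only, all other fields keep their defaults.
#
#     >>> config = ApdbSqlConfig()
#     >>> config.db_url = "sqlite:///apdb.db"
#     >>> config.dia_object_index = "last_object_table"
#     >>> config.connection_pool = False  # e.g. when the process will fork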
195class ApdbSqlTableData(ApdbTableData):
196 """Implementation of ApdbTableData that wraps sqlalchemy Result."""
198 def __init__(self, result: sqlalchemy.engine.Result):
199 self._keys = list(result.keys())
200 self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))
202 def column_names(self) -> list[str]:
203 return self._keys
205 def rows(self) -> Iterable[tuple]:
206 return self._rows
209class ApdbSql(Apdb):
210 """Implementation of APDB interface based on SQL database.
212 The implementation is configured via the standard ``pex_config`` mechanism
213 using the `ApdbSqlConfig` configuration class. For examples of different
214 configurations check the ``config/`` folder.
216 Parameters
217 ----------
218 config : `ApdbSqlConfig`
219 Configuration object.
220 """
222 ConfigClass = ApdbSqlConfig
224 def __init__(self, config: ApdbSqlConfig):
225 config.validate()
226 self.config = config
228 _LOG.debug("APDB Configuration:")
229 _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
230 _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
231 _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
232 _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
233 _LOG.debug(" schema_file: %s", self.config.schema_file)
234 _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
235 _LOG.debug(" schema prefix: %s", self.config.prefix)
237 # engine is reused between multiple processes, make sure that we don't
238 # share connections by disabling pool (by using NullPool class)
239 kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
240 conn_args: Dict[str, Any] = dict()
241 if not self.config.connection_pool:
242 kw.update(poolclass=NullPool)
243 if self.config.isolation_level is not None:
244 kw.update(isolation_level=self.config.isolation_level)
245 elif self.config.db_url.startswith("sqlite"): # type: ignore
246 # Use READ_UNCOMMITTED as default value for sqlite.
247 kw.update(isolation_level="READ_UNCOMMITTED")
248 if self.config.connection_timeout is not None:
249 if self.config.db_url.startswith("sqlite"):
250 conn_args.update(timeout=self.config.connection_timeout)
251 elif self.config.db_url.startswith(("postgresql", "mysql")):
252 conn_args.update(connect_timeout=self.config.connection_timeout)
253 kw.update(connect_args=conn_args)
254 self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)
256 self._schema = ApdbSqlSchema(
257 engine=self._engine,
258 dia_object_index=self.config.dia_object_index,
259 schema_file=self.config.schema_file,
260 schema_name=self.config.schema_name,
261 prefix=self.config.prefix,
262 namespace=self.config.namespace,
263 htm_index_column=self.config.htm_index_column,
264 use_insert_id=config.use_insert_id,
265 )
267 self.pixelator = HtmPixelization(self.config.htm_level)
268 self.use_insert_id = self._schema.has_insert_id
270 def tableRowCount(self) -> Dict[str, int]:
271 """Returns dictionary with the table names and row counts.
273 Used by ``ap_proto`` to keep track of the size of the database tables.
274 Depending on database technology this could be expensive operation.
276 Returns
277 -------
278 row_counts : `dict`
279 Dict where key is a table name and value is a row count.
280 """
281 res = {}
282 tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
283 if self.config.dia_object_index == "last_object_table":
284 tables.append(ApdbTables.DiaObjectLast)
285 with self._engine.begin() as conn:
286 for table in tables:
287 sa_table = self._schema.get_table(table)
288 stmt = sql.select(func.count()).select_from(sa_table)
289 count: int = conn.execute(stmt).scalar_one()
290 res[table.name] = count
292 return res
294 def tableDef(self, table: ApdbTables) -> Optional[Table]:
295 # docstring is inherited from a base class
296 return self._schema.tableSchemas.get(table)
298 def makeSchema(self, drop: bool = False) -> None:
299 # docstring is inherited from a base class
300 self._schema.makeSchema(drop=drop)
302 def getDiaObjects(self, region: Region) -> pandas.DataFrame:
303 # docstring is inherited from a base class
305 # decide what columns we need
306 if self.config.dia_object_index == "last_object_table":
307 table_enum = ApdbTables.DiaObjectLast
308 else:
309 table_enum = ApdbTables.DiaObject
310 table = self._schema.get_table(table_enum)
311 if not self.config.dia_object_columns:
312 columns = self._schema.get_apdb_columns(table_enum)
313 else:
314 columns = [table.c[col] for col in self.config.dia_object_columns]
315 query = sql.select(*columns)
317 # build selection
318 query = query.where(self._filterRegion(table, region))
320 # select latest version of objects
321 if self.config.dia_object_index != "last_object_table":
322 query = query.where(table.c.validityEnd == None) # noqa: E711
324 # _LOG.debug("query: %s", query)
326 # execute select
327 with Timer("DiaObject select", self.config.timer):
328 with self._engine.begin() as conn:
329 objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
330 _LOG.debug("found %s DiaObjects", len(objects))
331 return objects
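# Illustrative usage sketch (assumes an existing ``apdb`` instance): query the latest
# DiaObject versions inside a small circular region.
#
#     >>> from lsst.sphgeom import Angle, Circle
#     >>> center = UnitVector3d(LonLat.fromDegrees(10.0, -5.0))
#     >>> region = Circle(center, Angle.fromDegrees(0.05))
#     >>> objects = apdb.getDiaObjects(region)  # pandas.DataFrame of DiaObject rows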
333 def getDiaSources(
334 self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
335 ) -> Optional[pandas.DataFrame]:
336 # docstring is inherited from a base class
337 if self.config.read_sources_months == 0:
338 _LOG.debug("Skip DiaSources fetching")
339 return None
341 if object_ids is None:
342 # region-based select
343 return self._getDiaSourcesInRegion(region, visit_time)
344 else:
345 return self._getDiaSourcesByIDs(list(object_ids), visit_time)
347 def getDiaForcedSources(
348 self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
349 ) -> Optional[pandas.DataFrame]:
350 """Return catalog of DiaForcedSource instances from a given region.
352 Parameters
353 ----------
354 region : `lsst.sphgeom.Region`
355 Region to search for DIASources.
356 object_ids : iterable [ `int` ], optional
357 List of DiaObject IDs to further constrain the set of returned
358 sources. If the list is empty then an empty catalog is returned with a
359 correct schema.
360 visit_time : `lsst.daf.base.DateTime`
361 Time of the current visit.
363 Returns
364 -------
365 catalog : `pandas.DataFrame`, or `None`
366 Catalog containing DiaForcedSource records. `None` is returned if
367 ``read_forced_sources_months`` configuration parameter is set to 0.
369 Raises
370 ------
371 NotImplementedError
372 Raised if ``object_ids`` is `None`.
374 Notes
375 -----
376 Even though the base class allows `None` to be passed for ``object_ids``,
377 this class requires ``object_ids`` to be not-`None`.
378 `NotImplementedError` is raised if `None` is passed.
380 This method returns a DiaForcedSource catalog for a region with additional
381 filtering based on DiaObject IDs. Only a subset of the DiaForcedSource
382 history is returned, limited by the ``read_forced_sources_months`` config
383 parameter, w.r.t. ``visit_time``. If ``object_ids`` is empty then an empty
384 catalog is always returned with a correct schema (columns/types).
385 """
387 if self.config.read_forced_sources_months == 0:
388 _LOG.debug("Skip DiaForceSources fetching")
389 return None
391 if object_ids is None:
392 # This implementation does not support region-based selection.
393 raise NotImplementedError("Region-based selection is not supported")
395 # TODO: DateTime.MJD must be consistent with code in ap_association,
396 # alternatively we can fill midPointTai ourselves in store()
397 midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
398 _LOG.debug("midPointTai_start = %.6f", midPointTai_start)
400 with Timer("DiaForcedSource select", self.config.timer):
401 sources = self._getSourcesByIDs(ApdbTables.DiaForcedSource, list(object_ids), midPointTai_start)
403 _LOG.debug("found %s DiaForcedSources", len(sources))
404 return sources
406 def getInsertIds(self) -> list[ApdbInsertId] | None:
407 # docstring is inherited from a base class
408 if not self._schema.has_insert_id:
409 return None
411 table = self._schema.get_table(ExtraTables.DiaInsertId)
412 assert table is not None, "has_insert_id=True means it must be defined"
413 query = sql.select(table.columns["insert_id"]).order_by(table.columns["insert_time"])
414 with Timer("DiaObject insert id select", self.config.timer):
415 with self._engine.connect() as conn:
416 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
417 return [ApdbInsertId(row) for row in result.scalars()]
419 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
420 # docstring is inherited from a base class
421 if not self._schema.has_insert_id:
422 raise ValueError("APDB is not configured for history storage")
424 table = self._schema.get_table(ExtraTables.DiaInsertId)
426 insert_ids = [id.id for id in ids]
427 where_clause = table.columns["insert_id"].in_(insert_ids)
428 stmt = table.delete().where(where_clause)
429 with self._engine.begin() as conn:
430 conn.execute(stmt)
432 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
433 # docstring is inherited from a base class
434 return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)
436 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
437 # docstring is inherited from a base class
438 return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)
440 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
441 # docstring is inherited from a base class
442 return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)
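# Illustrative usage sketch (assumes ``apdb`` was configured with ``use_insert_id``
# enabled): fetch the history recorded for the most recent insert.
#
#     >>> insert_ids = apdb.getInsertIds() or []
#     >>> if insert_ids:
#     ...     data = apdb.getDiaSourcesHistory(insert_ids[-1:])
#     ...     columns = data.column_names()  # "insert_id" followed by the APDB columns
#     ...     rows = list(data.rows())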
444 def _get_history(
445 self,
446 ids: Iterable[ApdbInsertId],
447 table_enum: ApdbTables,
448 history_table_enum: ExtraTables,
449 ) -> ApdbTableData:
450 """Common implementation of the history methods."""
451 if not self._schema.has_insert_id:
452 raise ValueError("APDB is not configured for history retrieval")
454 table = self._schema.get_table(table_enum)
455 history_table = self._schema.get_table(history_table_enum)
457 join = table.join(history_table)
458 insert_ids = [id.id for id in ids]
459 history_id_column = history_table.columns["insert_id"]
460 apdb_columns = self._schema.get_apdb_columns(table_enum)
461 where_clause = history_id_column.in_(insert_ids)
462 query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)
464 # execute select
465 with Timer(f"{table.name} history select", self.config.timer):
466 with self._engine.begin() as conn:
467 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
468 return ApdbSqlTableData(result)
470 def getSSObjects(self) -> pandas.DataFrame:
471 # docstring is inherited from a base class
473 columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
474 query = sql.select(*columns)
476 # execute select
477 with Timer("DiaObject select", self.config.timer):
478 with self._engine.begin() as conn:
479 objects = pandas.read_sql_query(query, conn)
480 _LOG.debug("found %s SSObjects", len(objects))
481 return objects
483 def store(
484 self,
485 visit_time: dafBase.DateTime,
486 objects: pandas.DataFrame,
487 sources: Optional[pandas.DataFrame] = None,
488 forced_sources: Optional[pandas.DataFrame] = None,
489 ) -> None:
490 # docstring is inherited from a base class
492 # We want to run all inserts in one transaction.
493 with self._engine.begin() as connection:
494 insert_id: ApdbInsertId | None = None
495 if self._schema.has_insert_id:
496 insert_id = ApdbInsertId.new_insert_id()
497 self._storeInsertId(insert_id, visit_time, connection)
499 # fill pixelId column for DiaObjects
500 objects = self._add_obj_htm_index(objects)
501 self._storeDiaObjects(objects, visit_time, insert_id, connection)
503 if sources is not None:
504 # copy pixelId column from DiaObjects to DiaSources
505 sources = self._add_src_htm_index(sources, objects)
506 self._storeDiaSources(sources, insert_id, connection)
508 if forced_sources is not None:
509 self._storeDiaForcedSources(forced_sources, insert_id, connection)
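# Illustrative usage sketch (``objects_df``, ``sources_df``, and ``forced_df`` are
# hypothetical DataFrames carrying the APDB schema columns, including ``diaObjectId``
# and the configured ra/dec columns for DiaObjects):
#
#     >>> visit_time = dafBase.DateTime.now()
#     >>> apdb.store(visit_time, objects_df, sources_df, forced_df)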
511 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
512 # docstring is inherited from a base class
514 idColumn = "ssObjectId"
515 table = self._schema.get_table(ApdbTables.SSObject)
517 # everything to be done in single transaction
518 with self._engine.begin() as conn:
519 # Find record IDs that already exist. Some types like np.int64 can
520 # cause issues with sqlalchemy, convert them to int.
521 ids = sorted(int(oid) for oid in objects[idColumn])
523 query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
524 result = conn.execute(query)
525 knownIds = set(row.ssObjectId for row in result)
527 filter = objects[idColumn].isin(knownIds)
528 toUpdate = cast(pandas.DataFrame, objects[filter])
529 toInsert = cast(pandas.DataFrame, objects[~filter])
531 # insert new records
532 if len(toInsert) > 0:
533 toInsert.to_sql(
534 table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema
535 )
537 # update existing records
538 if len(toUpdate) > 0:
539 whereKey = f"{idColumn}_param"
540 update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
541 toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
542 values = toUpdate.to_dict("records")
543 result = conn.execute(update, values)
545 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
546 # docstring is inherited from a base class
548 table = self._schema.get_table(ApdbTables.DiaSource)
549 query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))
551 with self._engine.begin() as conn:
552 # Need to make sure that every ID exists in the database, but
553 # executemany may not support rowcount, so iterate and check what is
554 # missing.
555 missing_ids: List[int] = []
556 for key, value in idMap.items():
557 params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
558 result = conn.execute(query, params)
559 if result.rowcount == 0:
560 missing_ids.append(key)
561 if missing_ids:
562 missing = ",".join(str(item) for item in missing_ids)
563 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
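# Illustrative usage sketch: move two DiaSources to an SSObject; keys are
# ``diaSourceId`` values, values are the target ``ssObjectId`` (IDs are made up).
#
#     >>> apdb.reassignDiaSources({123456789: 42, 123456790: 42})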
565 def dailyJob(self) -> None:
566 # docstring is inherited from a base class
567 pass
569 def countUnassociatedObjects(self) -> int:
570 # docstring is inherited from a base class
572 # Retrieve the DiaObject table.
573 table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)
575 # Construct the sql statement.
576 stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
577 stmt = stmt.where(table.c.validityEnd == None) # noqa: E711
579 # Return the count.
580 with self._engine.begin() as conn:
581 count = conn.execute(stmt).scalar_one()
583 return count
585 def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
586 """Returns catalog of DiaSource instances from given region.
588 Parameters
589 ----------
590 region : `lsst.sphgeom.Region`
591 Region to search for DIASources.
592 visit_time : `lsst.daf.base.DateTime`
593 Time of the current visit.
595 Returns
596 -------
597 catalog : `pandas.DataFrame`
598 Catalog containing DiaSource records.
599 """
600 # TODO: DateTime.MJD must be consistent with code in ap_association,
601 # alternatively we can fill midPointTai ourselves in store()
602 midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
603 _LOG.debug("midPointTai_start = %.6f", midPointTai_start)
605 table = self._schema.get_table(ApdbTables.DiaSource)
606 columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
607 query = sql.select(*columns)
609 # build selection
610 time_filter = table.columns["midPointTai"] > midPointTai_start
611 where = sql.expression.and_(self._filterRegion(table, region), time_filter)
612 query = query.where(where)
614 # execute select
615 with Timer("DiaSource select", self.config.timer):
616 with self._engine.begin() as conn:
617 sources = pandas.read_sql_query(query, conn)
618 _LOG.debug("found %s DiaSources", len(sources))
619 return sources
621 def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
622 """Returns catalog of DiaSource instances given set of DiaObject IDs.
624 Parameters
625 ----------
626 object_ids :
627 Collection of DiaObject IDs
628 visit_time : `lsst.daf.base.DateTime`
629 Time of the current visit.
631 Returns
632 -------
633 catalog : `pandas.DataFrame`
634 Catalog containing DiaSource records.
635 """
636 # TODO: DateTime.MJD must be consistent with code in ap_association,
637 # alternatively we can fill midPointTai ourselves in store()
638 midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
639 _LOG.debug("midPointTai_start = %.6f", midPointTai_start)
641 with Timer("DiaSource select", self.config.timer):
642 sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midPointTai_start)
644 _LOG.debug("found %s DiaSources", len(sources))
645 return sources
647 def _getSourcesByIDs(
648 self, table_enum: ApdbTables, object_ids: List[int], midPointTai_start: float
649 ) -> pandas.DataFrame:
650 """Returns catalog of DiaSource or DiaForcedSource instances given set
651 of DiaObject IDs.
653 Parameters
654 ----------
655 table_enum : `ApdbTables`
656 APDB enum for the database table to query.
657 object_ids :
658 Collection of DiaObject IDs
659 midPointTai_start : `float`
660 Earliest midPointTai to retrieve.
662 Returns
663 -------
664 catalog : `pandas.DataFrame`
665 Catalog containing DiaSource records. `None` is returned if
666 ``read_sources_months`` configuration parameter is set to 0 or
667 when ``object_ids`` is empty.
668 """
669 table = self._schema.get_table(table_enum)
670 columns = self._schema.get_apdb_columns(table_enum)
672 sources: Optional[pandas.DataFrame] = None
673 if len(object_ids) <= 0:
674 _LOG.debug("ID list is empty, just fetch empty result")
675 query = sql.select(*columns).where(sql.literal(False))
676 with self._engine.begin() as conn:
677 sources = pandas.read_sql_query(query, conn)
678 else:
679 data_frames: list[pandas.DataFrame] = []
680 for ids in chunk_iterable(sorted(object_ids), 1000):
681 query = sql.select(*columns)
683 # Some types like np.int64 can cause issues with
684 # sqlalchemy, convert them to int.
685 int_ids = [int(oid) for oid in ids]
687 # select by object id
688 query = query.where(
689 sql.expression.and_(
690 table.columns["diaObjectId"].in_(int_ids),
691 table.columns["midPointTai"] > midPointTai_start,
692 )
693 )
695 # execute select
696 with self._engine.begin() as conn:
697 data_frames.append(pandas.read_sql_query(query, conn))
699 if len(data_frames) == 1:
700 sources = data_frames[0]
701 else:
702 sources = pandas.concat(data_frames)
703 assert sources is not None, "Catalog cannot be None"
704 return sources
706 def _storeInsertId(
707 self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
708 ) -> None:
709 dt = visit_time.toPython()
711 table = self._schema.get_table(ExtraTables.DiaInsertId)
713 stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
714 connection.execute(stmt)
716 def _storeDiaObjects(
717 self,
718 objs: pandas.DataFrame,
719 visit_time: dafBase.DateTime,
720 insert_id: ApdbInsertId | None,
721 connection: sqlalchemy.engine.Connection,
722 ) -> None:
723 """Store catalog of DiaObjects from current visit.
725 Parameters
726 ----------
727 objs : `pandas.DataFrame`
728 Catalog with DiaObject records.
729 visit_time : `lsst.daf.base.DateTime`
730 Time of the visit.
731 insert_id : `ApdbInsertId`
732 Insert identifier.
733 """
735 # Some types like np.int64 can cause issues with sqlalchemy, convert
736 # them to int.
737 ids = sorted(int(oid) for oid in objs["diaObjectId"])
738 _LOG.debug("first object ID: %d", ids[0])
740 # TODO: Need to verify that we are using correct scale here for
741 # DATETIME representation (see DM-31996).
742 dt = visit_time.toPython()
744 # everything to be done in single transaction
745 if self.config.dia_object_index == "last_object_table":
746 # insert and replace all records in LAST table; mysql and postgres only
747 # have non-standard upsert features, so delete and re-insert instead
748 table = self._schema.get_table(ApdbTables.DiaObjectLast)
750 # Drop the previous objects (pandas cannot upsert).
751 query = table.delete().where(table.columns["diaObjectId"].in_(ids))
753 with Timer(table.name + " delete", self.config.timer):
754 res = connection.execute(query)
755 _LOG.debug("deleted %s objects", res.rowcount)
757 # DiaObjectLast is a subset of DiaObject, strip missing columns
758 last_column_names = [column.name for column in table.columns]
759 last_objs = objs[last_column_names]
760 last_objs = _coerce_uint64(last_objs)
762 if "lastNonForcedSource" in last_objs.columns:
763 # lastNonForcedSource is defined NOT NULL, fill it with visit time
764 # just in case.
765 last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
766 else:
767 extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
768 last_objs.set_index(extra_column.index, inplace=True)
769 last_objs = pandas.concat([last_objs, extra_column], axis="columns")
771 with Timer("DiaObjectLast insert", self.config.timer):
772 last_objs.to_sql(
773 table.name,
774 _ConnectionHackSA2(connection),
775 if_exists="append",
776 index=False,
777 schema=table.schema,
778 )
779 else:
780 # truncate existing validity intervals
781 table = self._schema.get_table(ApdbTables.DiaObject)
783 update = (
784 table.update()
785 .values(validityEnd=dt)
786 .where(
787 sql.expression.and_(
788 table.columns["diaObjectId"].in_(ids),
789 table.columns["validityEnd"].is_(None),
790 )
791 )
792 )
794 # _LOG.debug("query: %s", query)
796 with Timer(table.name + " truncate", self.config.timer):
797 res = connection.execute(update)
798 _LOG.debug("truncated %s intervals", res.rowcount)
800 objs = _coerce_uint64(objs)
802 # Fill additional columns
803 extra_columns: List[pandas.Series] = []
804 if "validityStart" in objs.columns:
805 objs["validityStart"] = dt
806 else:
807 extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
808 if "validityEnd" in objs.columns:
809 objs["validityEnd"] = None
810 else:
811 extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
812 if "lastNonForcedSource" in objs.columns:
813 # lastNonForcedSource is defined NOT NULL, fill it with visit time
814 # just in case.
815 objs["lastNonForcedSource"].fillna(dt, inplace=True)
816 else:
817 extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
818 if extra_columns:
819 objs.set_index(extra_columns[0].index, inplace=True)
820 objs = pandas.concat([objs] + extra_columns, axis="columns")
822 # Insert history data
823 table = self._schema.get_table(ApdbTables.DiaObject)
824 history_data: list[dict] = []
825 history_stmt: Any = None
826 if insert_id is not None:
827 pk_names = [column.name for column in table.primary_key]
828 history_data = objs[pk_names].to_dict("records")
829 for row in history_data:
830 row["insert_id"] = insert_id.id
831 history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
832 history_stmt = history_table.insert()
834 # insert new versions
835 with Timer("DiaObject insert", self.config.timer):
836 objs.to_sql(
837 table.name,
838 _ConnectionHackSA2(connection),
839 if_exists="append",
840 index=False,
841 schema=table.schema,
842 )
843 if history_stmt is not None:
844 connection.execute(history_stmt, history_data)
846 def _storeDiaSources(
847 self,
848 sources: pandas.DataFrame,
849 insert_id: ApdbInsertId | None,
850 connection: sqlalchemy.engine.Connection,
851 ) -> None:
852 """Store catalog of DiaSources from current visit.
854 Parameters
855 ----------
856 sources : `pandas.DataFrame`
857 Catalog containing DiaSource records
858 """
859 table = self._schema.get_table(ApdbTables.DiaSource)
861 # Insert history data
862 history: list[dict] = []
863 history_stmt: Any = None
864 if insert_id is not None:
865 pk_names = [column.name for column in table.primary_key]
866 history = sources[pk_names].to_dict("records")
867 for row in history:
868 row["insert_id"] = insert_id.id
869 history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
870 history_stmt = history_table.insert()
872 # everything to be done in single transaction
873 with Timer("DiaSource insert", self.config.timer):
874 sources = _coerce_uint64(sources)
875 sources.to_sql(
876 table.name,
877 _ConnectionHackSA2(connection),
878 if_exists="append",
879 index=False,
880 schema=table.schema,
881 )
882 if history_stmt is not None:
883 connection.execute(history_stmt, history)
885 def _storeDiaForcedSources(
886 self,
887 sources: pandas.DataFrame,
888 insert_id: ApdbInsertId | None,
889 connection: sqlalchemy.engine.Connection,
890 ) -> None:
891 """Store a set of DiaForcedSources from current visit.
893 Parameters
894 ----------
895 sources : `pandas.DataFrame`
896 Catalog containing DiaForcedSource records
897 """
898 table = self._schema.get_table(ApdbTables.DiaForcedSource)
900 # Insert history data
901 history: list[dict] = []
902 history_stmt: Any = None
903 if insert_id is not None:
904 pk_names = [column.name for column in table.primary_key]
905 history = sources[pk_names].to_dict("records")
906 for row in history:
907 row["insert_id"] = insert_id.id
908 history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
909 history_stmt = history_table.insert()
911 # everything to be done in single transaction
912 with Timer("DiaForcedSource insert", self.config.timer):
913 sources = _coerce_uint64(sources)
914 sources.to_sql(
915 table.name,
916 _ConnectionHackSA2(connection),
917 if_exists="append",
918 index=False,
919 schema=table.schema,
920 )
921 if history_stmt is not None:
922 connection.execute(history_stmt, history)
924 def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
925 """Generate a set of HTM indices covering specified region.
927 Parameters
928 ----------
929 region: `sphgeom.Region`
930 Region that needs to be indexed.
932 Returns
933 -------
934 Sequence of ranges; each range is a tuple (minHtmID, maxHtmID).
935 """
936 _LOG.debug("region: %s", region)
937 indices = self.pixelator.envelope(region, self.config.htm_max_ranges)
939 return indices.ranges()
941 def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
942 """Make SQLAlchemy expression for selecting records in a region."""
943 htm_index_column = table.columns[self.config.htm_index_column]
944 exprlist = []
945 pixel_ranges = self._htm_indices(region)
946 for low, upper in pixel_ranges:
947 upper -= 1
948 if low == upper:
949 exprlist.append(htm_index_column == low)
950 else:
951 exprlist.append(sql.expression.between(htm_index_column, low, upper))
953 return sql.expression.or_(*exprlist)
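# Worked example: for HTM ranges [(512, 516), (600, 601)] the upper bounds are first
# decremented, so the generated filter is roughly
# ``pixelId BETWEEN 512 AND 515 OR pixelId = 600`` (single-pixel ranges collapse to an
# equality test).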
955 def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
956 """Calculate HTM index for each record and add it to a DataFrame.
958 Notes
959 -----
960 This overrides any existing column in a DataFrame with the same name
961 (pixelId). Original DataFrame is not changed, copy of a DataFrame is
962 returned.
963 """
964 # calculate HTM index for every DiaObject
965 htm_index = np.zeros(df.shape[0], dtype=np.int64)
966 ra_col, dec_col = self.config.ra_dec_columns
967 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
968 uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
969 idx = self.pixelator.index(uv3d)
970 htm_index[i] = idx
971 df = df.copy()
972 df[self.config.htm_index_column] = htm_index
973 return df
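# Illustrative sketch: the per-row value computed above is just the HTM index of the
# (ra, dec) direction at the configured level, e.g.
#
#     >>> pix = HtmPixelization(20)
#     >>> idx = pix.index(UnitVector3d(LonLat.fromDegrees(45.0, 30.0)))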
975 def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
976 """Add pixelId column to DiaSource catalog.
978 Notes
979 -----
980 This method copies pixelId value from a matching DiaObject record.
981 DiaObject catalog needs to have a pixelId column filled by
982 ``_add_obj_htm_index`` method and DiaSource records need to be
983 associated to DiaObjects via ``diaObjectId`` column.
985 This overrides any existing column in a DataFrame with the same name
986 (pixelId). Original DataFrame is not changed, copy of a DataFrame is
987 returned.
988 """
989 pixel_id_map: Dict[int, int] = {
990 diaObjectId: pixelId
991 for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
992 }
993 # DiaSources associated with SolarSystemObjects do not have an
994 # associated DiaObject hence we skip them and set their htmIndex
995 # value to 0.
996 pixel_id_map[0] = 0
997 htm_index = np.zeros(sources.shape[0], dtype=np.int64)
998 for i, diaObjId in enumerate(sources["diaObjectId"]):
999 htm_index[i] = pixel_id_map[diaObjId]
1000 sources = sources.copy()
1001 sources[self.config.htm_index_column] = htm_index
1002 return sources