Coverage for python/lsst/dax/apdb/apdbSql.py: 13% of 428 statements
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods."""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Callable, Iterable, Mapping, MutableMapping
from typing import Any, Dict, List, Optional, Tuple, cast

import lsst.daf.base as dafBase
import numpy as np
import pandas
import sqlalchemy
from felis.simple import Table
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, inspection, sql
from sqlalchemy.engine import Inspector
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
from .apdbSchema import ApdbTables
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
from .timer import Timer

_LOG = logging.getLogger(__name__)


if pandas.__version__.partition(".")[0] == "1":

    class _ConnectionHackSA2(sqlalchemy.engine.Connectable):
        """Terrible hack to work around Pandas 1's incomplete support for
        SQLAlchemy 2.

        We need to pass a Connection instance to pandas methods, but in SA 2
        the Connection class lost the ``connect`` method that pandas uses.
        """

        def __init__(self, connection: sqlalchemy.engine.Connection):
            self._connection = connection

        def connect(self, **kwargs: Any) -> Any:
            return self

        @property
        def execute(self) -> Callable:
            return self._connection.execute

        @property
        def execution_options(self) -> Callable:
            return self._connection.execution_options

        @property
        def connection(self) -> Any:
            return self._connection.connection

        def __enter__(self) -> sqlalchemy.engine.Connection:
            return self._connection

        def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
            # Do not close the connection here; its lifetime is managed by
            # the caller.
            pass

    @inspection._inspects(_ConnectionHackSA2)
    def _connection_insp(conn: _ConnectionHackSA2) -> Inspector:
        return Inspector._construct(Inspector._init_connection, conn._connection)

else:
    # Pandas 2.0 supports SQLAlchemy 2 correctly.
    def _ConnectionHackSA2(  # type: ignore[no-redef]
        conn: sqlalchemy.engine.Connectable,
    ) -> sqlalchemy.engine.Connectable:
        return conn


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
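
# Illustration of the coercion above, as a sketch with made-up data: a uint64
# column is converted to int64 (which database drivers handle more reliably),
# while other dtypes are left untouched.
#
#     df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64), "flux": [0.5, 1.5]})
#     _coerce_uint64(df).dtypes  # id -> int64, flux -> float64 (unchanged)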


def _make_midpointMjdTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
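
# Worked example of the cutoff arithmetic above, with made-up numbers: a month
# is approximated as 30 days, so for a visit at MJD 60000.0 and a 12-month
# history window the search starts at
#
#     60000.0 - 12 * 30 = 59640.0  (MJD)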


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to sqlalchemy defaults if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
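
# A minimal configuration sketch (illustrative values only, not a
# recommendation; ``db_url`` is the only field here without a usable default):
#
#     config = ApdbSqlConfig()
#     config.db_url = "postgresql://server/apdb"
#     config.namespace = "apdb_schema"
#     config.connection_pool = False  # e.g. when worker processes fork
#     config.validate()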


class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self._keys = list(result.keys())
        self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))

    def column_names(self) -> list[str]:
        return self._keys

    def rows(self) -> Iterable[tuple]:
        return self._rows


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config`` mechanism
    using the `ApdbSqlConfig` configuration class. For examples of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):
        config.validate()
        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug(" schema_file: %s", self.config.schema_file)
        _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug(" schema prefix: %s", self.config.prefix)

        # The engine may be reused between multiple processes; make sure that
        # we don't share connections by disabling the pool (using NullPool).
        kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id

    def tableRowCount(self) -> Dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[Table]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(
                ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
            )

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"]).order_by(table.columns["insert_time"])
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return [ApdbInsertId(row) for row in result.scalars()]

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Return catalog of records for given insert identifiers, common
        implementation for all DIA tables.
        """
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            with self._engine.begin() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return ApdbSqlTableData(result)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: dafBase.DateTime,
        objects: pandas.DataFrame,
        sources: Optional[pandas.DataFrame] = None,
        forced_sources: Optional[pandas.DataFrame] = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id()
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(
                    table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema
                )

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: List[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Enum of the database table to query.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is returned
            if ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: Optional[pandas.DataFrame] = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources
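
    # The loop above bounds the size of each IN() clause by chunking the ID
    # list; a sketch of the idea with a chunk size of 2 (illustrative only,
    # the exact chunk type returned by chunk_iterable may vary):
    #
    #     for ids in chunk_iterable([1, 2, 3, 4, 5], 2):
    #         ...  # ids -> [1, 2], then [3, 4], then [5]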

    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
    ) -> None:
        dt = visit_time.toPython()

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: dafBase.DateTime,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not saved.
        connection : `sqlalchemy.engine.Connection`
            Connection for the current transaction.
        """
        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    _ConnectionHackSA2(connection),
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: List[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not saved.
        connection : `sqlalchemy.engine.Connection`
            Connection for the current transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not saved.
        connection : `sqlalchemy.engine.Connection`
            Connection for the current transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` of `tuple`
            Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            # Ranges from sphgeom are half-open; make the upper bound
            # inclusive for SQL BETWEEN.
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
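
    # Sketch of the expression built above for two hypothetical pixel ranges,
    # [100, 104) and [200, 201) (values made up for illustration); it renders
    # to SQL roughly as:
    #
    #     "pixelId" BETWEEN 100 AND 103 OR "pixelId" = 200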

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed; a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df
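
    # Pixelization sketch for a single made-up sky position at the default
    # level-20 HTM (values illustrative, not from real data):
    #
    #     uv3d = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))
    #     HtmPixelization(20).index(uv3d)  # -> 64-bit HTM pixel index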

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed; a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject, hence we skip them and set their pixelId
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources