Coverage for python/lsst/dax/apdb/apdbSql.py: 13% (408 statements)

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Iterable, Mapping, MutableMapping
from typing import Any, Dict, List, Optional, Tuple, cast

import lsst.daf.base as dafBase
import numpy as np
import pandas
import sqlalchemy
from felis.simple import Table
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, sql
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
from .apdbSchema import ApdbTables
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
from .timer import Timer

_LOG = logging.getLogger(__name__)


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, return a copy of the data frame."""
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
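
# Illustrative behaviour of _coerce_uint64 (editorial note, not from the original
# source): a DataFrame whose ``diaObjectId`` column has dtype ``uint64`` comes back
# with that column cast to ``int64``; columns of other dtypes are left untouched.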


def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
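
# Illustrative arithmetic (editorial note): a visit at MJD 60000.0 with a 12-month
# history window gives a search start of 60000.0 - 12 * 30 = 59640.0 (MJD).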


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level. If unset then the backend-default value "
            "is used, except for the SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before exiting. "
            "Defaults to the SQLAlchemy default if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "decl"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only makes sense for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
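
# A minimal configuration sketch (editorial illustration; the SQLite URL is
# hypothetical and all other fields keep their defaults):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     apdb = ApdbSql(config)   # ApdbSql is defined below
#     apdb.makeSchema()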


class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self.result = result

    def column_names(self) -> list[str]:
        return self.result.keys()

    def rows(self) -> Iterable[tuple]:
        for row in self.result:
            yield tuple(row)
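
# Illustrative consumption of ApdbSqlTableData (editorial sketch; ``apdb`` and
# ``insert_ids`` are hypothetical):
#
#     data = apdb.getDiaSourcesHistory(insert_ids)
#     columns = data.column_names()
#     for row in data.rows():
#         ...  # each row is a plain tuple ordered as ``columns``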


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config`` mechanism
    using the `ApdbSqlConfig` configuration class. For examples of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):
        config.validate()
        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id

    def tableRowCount(self) -> Dict[str, int]:
        """Return a dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        for table in tables:
            sa_table = self._schema.get_table(table)
            stmt = sql.select([func.count()]).select_from(sa_table)
            count = self._engine.scalar(stmt)
            res[table.name] = count

        return res
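
    # Illustrative return value of tableRowCount() (editorial example, the
    # counts are made up):
    #
    #     {"DiaObject": 1000, "DiaSource": 5000, "DiaForcedSource": 8000}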

    def tableDef(self, table: ApdbTables) -> Optional[Table]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If the list is empty then an empty catalog is returned
            with a correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned if
            the ``read_forced_sources_months`` configuration parameter is set
            to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be not-`None`.
        `NotImplementedError` is raised if `None` is passed.

        This method returns a DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaForcedSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter, w.r.t. ``visit_time``.
        If ``object_ids`` is empty then an empty catalog is always returned
        with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaForcedSource, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"]).order_by(table.columns["insert_time"])
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return [ApdbInsertId(row) for row in result.scalars()]

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Common implementation of the history methods."""
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            connection = self._engine.connect(close_with_result=True)
            result = connection.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
            return ApdbSqlTableData(result)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: dafBase.DateTime,
        objects: pandas.DataFrame,
        sources: Optional[pandas.DataFrame] = None,
        forced_sources: Optional[pandas.DataFrame] = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:

            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id()
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:

            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn], table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row[idColumn] for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(table.name, conn, if_exists="append", index=False, schema=table.schema)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                query = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(query, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what is
            # missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class

        if self._engine.name == "postgresql":

            # do VACUUM on all tables
            _LOG.info("Running VACUUM on all tables")
            connection = self._engine.raw_connection()
            ISOLATION_LEVEL_AUTOCOMMIT = 0
            connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            cursor = connection.cursor()
            cursor.execute("VACUUM ANALYSE")

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.scalar(stmt)

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given a set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: List[int], midPointTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given a
        set of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Database table to query (DiaSource or DiaForcedSource).
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: Optional[pandas.DataFrame] = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(False)
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midPointTai"] > midPointTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
    ) -> None:
        dt = visit_time.toPython()

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: dafBase.DateTime,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not kept.
        connection : `sqlalchemy.engine.Connection`
            Connection to use for database access.
        """

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":

            # insert and replace all records in LAST table, mysql and postgres have
            # non-standard features
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit time
                # just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
        else:

            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            query = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: List[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if history_stmt is not None:
                connection.execute(history_stmt, *history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not kept.
        connection : `sqlalchemy.engine.Connection`
            Connection to use for database access.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if history_stmt is not None:
                connection.execute(history_stmt, *history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not kept.
        connection : `sqlalchemy.engine.Connection`
            Connection to use for database access.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if history_stmt is not None:
                connection.execute(history_stmt, *history)

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` of `tuple`
            Sequence of ranges, each range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()
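
    # Note (editorial): the ranges returned above appear to be half-open
    # [minHtmID, maxHtmID), which is why _filterRegion below subtracts one from
    # the upper bound before building an inclusive BETWEEN expression.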

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ClauseElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df
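
    # Note (editorial): HtmPixelization.index maps the unit vector built from each
    # (ra, dec) pair to a single HTM trixel ID at the configured htm_level; these
    # are the values that _filterRegion matches when selecting by region.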

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by the
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources
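

# Illustrative end-to-end usage sketch (editorial; all variable names are
# hypothetical, ``region`` is an ``lsst.sphgeom.Region`` and ``visit_time`` an
# ``lsst.daf.base.DateTime``):
#
#     apdb = ApdbSql(config)
#     apdb.makeSchema()
#     objects = apdb.getDiaObjects(region)
#     sources = apdb.getDiaSources(region, objects["diaObjectId"], visit_time)
#     apdb.store(visit_time, new_objects, new_sources, new_forced_sources)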