# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""
25from __future__ import annotations
27__all__ = ["ApdbSqlConfig", "ApdbSql"]

from contextlib import contextmanager
import logging
import numpy as np
import pandas
from typing import cast, Any, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple

import lsst.daf.base as dafBase
from lsst.pex.config import Field, ChoiceField, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
import sqlalchemy
from sqlalchemy import (func, sql)
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig
from .apdbSchema import ApdbTables, TableDef
from .apdbSqlSchema import ApdbSqlSchema
from .timer import Timer

_LOG = logging.getLogger(__name__)


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
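
# Illustrative example (not part of the module): per the comments further
# below, numpy integer types can cause issues with sqlalchemy, so frames are
# coerced before writing, e.g.
#
#     df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
#     assert _coerce_uint64(df)["diaObjectId"].dtype == np.int64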


def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
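
# Worked example (illustrative): a month is approximated as 30 days, so for a
# visit at MJD 59580.0 and a 12-month history the search window starts at
# 59580.0 - 12 * 30 = 59220.0 (MJD).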


@contextmanager
def _ansi_session(engine: sqlalchemy.engine.Engine) -> Iterator[sqlalchemy.engine.Connection]:
    """Return a connection, making sure that ANSI mode is set for MySQL.
    """
    with engine.begin() as conn:
        if engine.name == 'mysql':
            conn.execute(sql.text("SET SESSION SQL_MODE = 'ANSI'"))
        yield conn
    return
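
# Example usage (illustrative): the context manager yields a transactional
# connection, switching a MySQL session to ANSI quoting first:
#
#     with _ansi_session(engine) as conn:
#         rows = conn.execute(sql.text("SELECT 1")).fetchall()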


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql).
    """
    db_url = Field(
        dtype=str,
        doc="SQLAlchemy database connection URI"
    )
    isolation_level = ChoiceField(
        dtype=str,
        doc="Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value.",
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable"
        },
        default=None,
        optional=True
    )
    connection_pool = Field(
        dtype=bool,
        doc="If False then disable SQLAlchemy connection pool. "
            "Do not use connection pool when forking.",
        default=True
    )
    connection_timeout = Field(
        dtype=float,
        doc="Maximum time to wait for database lock to be released before "
            "exiting. Defaults to sqlalchemy defaults if not set.",
        default=None,
        optional=True
    )
    sql_echo = Field(
        dtype=bool,
        doc="If True then pass SQLAlchemy echo option.",
        default=False
    )
    dia_object_index = ChoiceField(
        dtype=str,
        doc="Indexing mode for DiaObject table",
        allowed={
            'baseline': "Index defined in baseline schema",
            'pix_id_iov': "(pixelId, objectId, iovStart) PK",
            'last_object_table': "Separate DiaObjectLast table"
        },
        default='baseline'
    )
    htm_level = Field(
        dtype=int,
        doc="HTM indexing level",
        default=20
    )
    htm_max_ranges = Field(
        dtype=int,
        doc="Max number of ranges in HTM envelope",
        default=64
    )
    htm_index_column = Field(
        dtype=str,
        default="pixelId",
        doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField(
        dtype=str,
        default=["ra", "decl"],
        doc="Names of ra/dec columns in DiaObject table"
    )
    dia_object_columns = ListField(
        dtype=str,
        doc="List of columns to read from DiaObject, by default read all columns",
        default=[]
    )
    object_last_replace = Field(
        dtype=bool,
        doc="If True (default) then use \"upsert\" for DiaObjectsLast table",
        default=True
    )
    prefix = Field(
        dtype=str,
        doc="Prefix to add to table names and index names",
        default=""
    )
    namespace = Field(
        dtype=str,
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only makes sense for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True
    )
    explain = Field(
        dtype=bool,
        doc="If True then run EXPLAIN SQL command on each executed query",
        default=False
    )
    timer = Field(
        dtype=bool,
        doc="If True then print/log timing information",
        default=False
    )

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
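

# A minimal configuration sketch (illustrative; the ``db_url`` value is a
# hypothetical SQLite file, real examples live in the ``config/`` folder):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.validate()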


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config``
    mechanism using the `ApdbSqlConfig` configuration class. For examples of
    different configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):

        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    object_last_replace: %s", self.config.object_last_replace)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(engine=self._engine,
                                     dia_object_index=self.config.dia_object_index,
                                     schema_file=self.config.schema_file,
                                     schema_name=self.config.schema_name,
                                     prefix=self.config.prefix,
                                     namespace=self.config.namespace,
                                     htm_index_column=self.config.htm_index_column)

        self.pixelator = HtmPixelization(self.config.htm_level)

    def tableRowCount(self) -> Dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database
        tables. Depending on database technology this could be an expensive
        operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables: List[sqlalchemy.schema.Table] = [
            self._schema.objects, self._schema.sources, self._schema.forcedSources]
        if self.config.dia_object_index == 'last_object_table':
            tables.append(self._schema.objects_last)
        for table in tables:
            stmt = sql.select([func.count()]).select_from(table)
            count = self._engine.scalar(stmt)
            res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        table: sqlalchemy.schema.Table
        if self.config.dia_object_index == 'last_object_table':
            table = self._schema.objects_last
        else:
            table = self._schema.objects
        if not self.config.dia_object_columns:
            query = table.select()
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
            query = sql.select(columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != 'last_object_table':
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        _LOG.debug("query: %s", query)

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('DiaObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects
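
    # Example usage (illustrative; assumes the ``lsst.sphgeom`` Circle and
    # Angle classes, which are not imported by this module):
    #
    #     center = UnitVector3d(LonLat.fromDegrees(35.0, -4.0))
    #     region = Circle(center, Angle.fromDegrees(0.1))
    #     objects = apdb.getDiaObjects(region)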

    def getDiaSources(self, region: Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(self, region: Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If list is empty then empty catalog is returned with a
            correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned
            if the ``read_forced_sources_months`` configuration parameter is
            set to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be
        not-`None`. `NotImplementedError` is raised if `None` is passed.

        This method returns a DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaForcedSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter, w.r.t.
        ``visit_time``. If ``object_ids`` is empty then an empty catalog is
        always returned with a correct schema (columns/types).
        """
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.forcedSources
        with Timer('DiaForcedSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getDiaObjectsHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.objects
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["validityStart"] >= start_time.toPython(),
            table.columns["validityStart"] < end_time.toPython()
        )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaObject history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects history records", len(catalog))
        return catalog

    def getDiaSourcesHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
            table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
        )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSource history records", len(catalog))
        return catalog

    def getDiaForcedSourcesHistory(self,
                                   start_time: dafBase.DateTime,
                                   end_time: dafBase.DateTime,
                                   region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.forcedSources
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
            table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
        )
        # Forced sources have no pixel index, so no region filtering
        query = query.where(time_filter)

        # execute select
        with Timer('DiaForcedSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaForcedSource history records", len(catalog))
        return catalog

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.ssObjects
        query = table.select()

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('SSObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill pixelId column for DiaObjects
        objects = self._add_obj_htm_index(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy pixelId column from DiaObjects to DiaSources
            sources = self._add_src_htm_index(sources, objects)
            self._storeDiaSources(sources)

        if forced_sources is not None:
            self._storeDiaForcedSources(forced_sources)

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.ssObjects

        # everything to be done in single transaction
        with self._engine.begin() as conn:

            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row[idColumn] for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                query = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(query, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
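
    # Example usage (illustrative, hypothetical IDs): reassign two DiaSources
    # to SSObject 42; keys are diaSourceId values, values are the new
    # ssObjectId:
    #
    #     apdb.reassignDiaSources({1001: 42, 1002: 42})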

    def dailyJob(self) -> None:
        # docstring is inherited from a base class

        if self._engine.name == 'postgresql':

            # do VACUUM on all tables
            _LOG.info("Running VACUUM on all tables")
            connection = self._engine.raw_connection()
            ISOLATION_LEVEL_AUTOCOMMIT = 0
            connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            cursor = connection.cursor()
            cursor.execute("VACUUM ANALYSE")

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.objects

        # Construct the sql statement.
        stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.scalar(stmt)

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime
                               ) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer('DiaSource select', self.config.timer):
            with _ansi_session(self._engine) as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime
                            ) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given a set of DiaObject IDs.

        Parameters
        ----------
        object_ids :
            Collection of DiaObject IDs
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        with Timer('DiaSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(self, table: sqlalchemy.schema.Table,
                         object_ids: List[int],
                         midPointTai_start: float
                         ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given a
        set of DiaObject IDs.

        Parameters
        ----------
        table : `sqlalchemy.schema.Table`
            Database table.
        object_ids :
            Collection of DiaObject IDs
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned if ``object_ids`` is empty.
        """
        sources: Optional[pandas.DataFrame] = None
        with _ansi_session(self._engine) as conn:
            if len(object_ids) <= 0:
                _LOG.debug("ID list is empty, just fetch empty result")
                query = table.select().where(False)
                sources = pandas.read_sql_query(query, conn)
            else:
                for ids in chunk_iterable(sorted(object_ids), 1000):
                    query = table.select()

                    # Some types like np.int64 can cause issues with
                    # sqlalchemy, convert them to int.
                    int_ids = [int(oid) for oid in ids]

                    # select by object id
                    query = query.where(
                        sql.expression.and_(
                            table.columns["diaObjectId"].in_(int_ids),
                            table.columns["midPointTai"] > midPointTai_start,
                        )
                    )

                    # execute select
                    df = pandas.read_sql_query(query, conn)
                    if sources is None:
                        sources = df
                    else:
                        sources = sources.append(df)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        """
        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs['diaObjectId'])
        _LOG.debug("first object ID: %d", ids[0])

        # NOTE: workaround for sqlite, need this here to avoid
        # "database is locked" error.
        table: sqlalchemy.schema.Table = self._schema.objects

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            if self.config.dia_object_index == 'last_object_table':

                # insert and replace all records in LAST table, mysql and
                # postgres have non-standard features
                table = self._schema.objects_last
                do_replace = self.config.object_last_replace
                # If the input data is of type Pandas, we drop the previous
                # objects regardless of the do_replace setting due to how
                # Pandas inserts objects.
                if not do_replace or isinstance(objs, pandas.DataFrame):
                    query = table.delete().where(
                        table.columns["diaObjectId"].in_(ids)
                    )

                    if self.config.explain:
                        # run the same query with explain
                        self._explain(query, conn)

                    with Timer(table.name + ' delete', self.config.timer):
                        res = conn.execute(query)
                    _LOG.debug("deleted %s objects", res.rowcount)

                extra_columns: Dict[str, Any] = dict(lastNonForcedSource=dt)
                with Timer("DiaObjectLast insert", self.config.timer):
                    objs = _coerce_uint64(objs)
                    for col, data in extra_columns.items():
                        objs[col] = data
                    objs.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)
            else:

                # truncate existing validity intervals
                table = self._schema.objects

                query = table.update().values(validityEnd=dt).where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )

                # _LOG.debug("query: %s", query)

                if self.config.explain:
                    # run the same query with explain
                    self._explain(query, conn)

                with Timer(table.name + ' truncate', self.config.timer):
                    res = conn.execute(query)
                _LOG.debug("truncated %s intervals", res.rowcount)

            # insert new versions
            table = self._schema.objects
            extra_columns = dict(lastNonForcedSource=dt, validityStart=dt,
                                 validityEnd=None)
            with Timer("DiaObject insert", self.config.timer):
                objs = _coerce_uint64(objs)
                for col, data in extra_columns.items():
                    objs[col] = data
                objs.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

    def _storeDiaSources(self, sources: pandas.DataFrame) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        """
        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                table = self._schema.sources
                sources.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

    def _storeDiaForcedSources(self, sources: pandas.DataFrame) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        """
        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaForcedSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                table = self._schema.forcedSources
                sources.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

    def _explain(self, query: str, conn: sqlalchemy.engine.Connection) -> None:
        """Run the query with EXPLAIN.
        """
        _LOG.info("explain for query: %s...", query[:64])

        if conn.engine.name == 'mysql':
            query = "EXPLAIN EXTENDED " + query
        else:
            query = "EXPLAIN " + query

        res = conn.execute(sql.text(query))
        if res.returns_rows:
            _LOG.info("explain: %s", res.keys())
            for row in res:
                _LOG.info("explain: %s", row)
        else:
            _LOG.info("EXPLAIN returned nothing")

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug('region: %s', region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()
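
    # Illustrative output (made-up numbers): ``envelope`` returns half-open
    # index ranges at the configured HTM level, e.g.
    #
    #     [(12345600, 12345728), (12345856, 12345984)]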

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ClauseElement:
        """Make SQLAlchemy expression for selecting records in a region.
        """
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
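
    # Illustrative result (hypothetical ranges): because HTM ranges are
    # half-open, the upper bound is decremented to make it inclusive; for
    # ranges [(100, 102), (200, 201)] the condition is equivalent to
    #
    #     "pixelId" BETWEEN 100 AND 101 OR "pixelId" = 200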

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId for diaObjectId, pixelId
            in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources