Coverage for python/lsst/dax/apdb/apdbSql.py: 14%
372 statements
coverage.py v6.5.0, created at 2024-03-20 00:41 -0700
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""
25from __future__ import annotations
27__all__ = ["ApdbSqlConfig", "ApdbSql"]
29import logging
30from collections.abc import Iterable, Iterator, Mapping, MutableMapping
31from contextlib import contextmanager
32from typing import Any, Dict, List, Optional, Tuple, cast
34import lsst.daf.base as dafBase
35import numpy as np
36import pandas
37import sqlalchemy
38from felis.simple import Table
39from lsst.pex.config import ChoiceField, Field, ListField
40from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
41from lsst.utils.iteration import chunk_iterable
42from sqlalchemy import func, sql
43from sqlalchemy.pool import NullPool
45from .apdb import Apdb, ApdbConfig
46from .apdbSchema import ApdbTables
47from .apdbSqlSchema import ApdbSqlSchema
48from .timer import Timer
50_LOG = logging.getLogger(__name__)
53def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
54 """Change type of the uint64 columns to int64, return copy of data frame.
55 """
56 names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
57 return df.astype({name: np.int64 for name in names})
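# Illustrative sketch (hypothetical data frame) of what the coercion above does:
# a uint64 column keeps its values but is stored as int64, which SQLAlchemy
# handles more reliably.
#
#     df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
#     _coerce_uint64(df).dtypes["diaObjectId"]   # dtype('int64')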
60def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
61 """Calculate starting point for time-based source search.
63 Parameters
64 ----------
65 visit_time : `lsst.daf.base.DateTime`
66 Time of current visit.
67 months : `int`
68 Number of months in the sources history.
70 Returns
71 -------
72 time : `float`
73 A ``midPointTai`` starting point, MJD time.
74 """
75 # TODO: `system` must be consistent with the code in ap_association
76 # (see DM-31996)
77 return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
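# Illustrative sketch (assumed values, and assuming the (value, system, scale)
# DateTime constructor): a visit at MJD 60000.0 with a 12-month window gives a
# starting point of 60000.0 - 12 * 30 = 59640.0.
#
#     visit_time = dafBase.DateTime(60000.0, dafBase.DateTime.MJD, dafBase.DateTime.TAI)
#     _make_midPointTai_start(visit_time, 12)   # 59640.0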
80@contextmanager
81def _ansi_session(engine: sqlalchemy.engine.Engine) -> Iterator[sqlalchemy.engine.Connection]:
82 """Returns a connection, makes sure that ANSI mode is set for MySQL
83 """
84 with engine.begin() as conn:
85 if engine.name == 'mysql':
86 conn.execute(sql.text("SET SESSION SQL_MODE = 'ANSI'"))
87 yield conn
88 return
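# Illustrative usage sketch for the context manager above: the body runs inside
# a transaction on the yielded connection, with ANSI mode enabled for MySQL.
#
#     with _ansi_session(engine) as conn:
#         conn.execute(sql.text("SELECT 1"))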
91class ApdbSqlConfig(ApdbConfig):
92 """APDB configuration class for SQL implementation (ApdbSql).
93 """
94 db_url = Field[str](
95 doc="SQLAlchemy database connection URI"
96 )
97 isolation_level = ChoiceField[str](
98 doc="Transaction isolation level, if unset then backend-default value "
99 "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
100 "Some backends may not support every allowed value.",
101 allowed={
102 "READ_COMMITTED": "Read committed",
103 "READ_UNCOMMITTED": "Read uncommitted",
104 "REPEATABLE_READ": "Repeatable read",
105 "SERIALIZABLE": "Serializable"
106 },
107 default=None,
108 optional=True
109 )
110 connection_pool = Field[bool](
111 doc="If False then disable SQLAlchemy connection pool. "
112 "Do not use connection pool when forking.",
113 default=True
114 )
115 connection_timeout = Field[float](
116 doc="Maximum time to wait time for database lock to be released before "
117 "exiting. Defaults to sqlalchemy defaults if not set.",
118 default=None,
119 optional=True
120 )
121 sql_echo = Field[bool](
122 doc="If True then pass SQLAlchemy echo option.",
123 default=False
124 )
125 dia_object_index = ChoiceField[str](
126 doc="Indexing mode for DiaObject table",
127 allowed={
128 'baseline': "Index defined in baseline schema",
129 'pix_id_iov': "(pixelId, objectId, iovStart) PK",
130 'last_object_table': "Separate DiaObjectLast table"
131 },
132 default='baseline'
133 )
134 htm_level = Field[int](
135 doc="HTM indexing level",
136 default=20
137 )
138 htm_max_ranges = Field[int](
139 doc="Max number of ranges in HTM envelope",
140 default=64
141 )
142 htm_index_column = Field[str](
143 default="pixelId",
144 doc="Name of a HTM index column for DiaObject and DiaSource tables"
145 )
146 ra_dec_columns = ListField[str](
147 default=["ra", "decl"],
148 doc="Names ra/dec columns in DiaObject table"
149 )
150 dia_object_columns = ListField[str](
151 doc="List of columns to read from DiaObject, by default read all columns",
152 default=[]
153 )
154 object_last_replace = Field[bool](
155 doc="If True (default) then use \"upsert\" for DiaObjectsLast table",
156 default=True,
157 deprecated="This field is not used and will be removed on 2022-21-31."
158 )
159 prefix = Field[str](
160 doc="Prefix to add to table names and index names",
161 default=""
162 )
163 namespace = Field[str](
164 doc=(
165 "Namespace or schema name for all tables in APDB database. "
166 "Presently only makes sense for PostgresQL backend. "
167 "If schema with this name does not exist it will be created when "
168 "APDB tables are created."
169 ),
170 default=None,
171 optional=True
172 )
173 explain = Field[bool](
174 doc="If True then run EXPLAIN SQL command on each executed query",
175 default=False
176 )
177 timer = Field[bool](
178 doc="If True then print/log timing information",
179 default=False
180 )
182 def validate(self) -> None:
183 super().validate()
184 if len(self.ra_dec_columns) != 2:
185 raise ValueError("ra_dec_columns must have exactly two column names")
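# Illustrative configuration sketch (hypothetical values): a file-backed SQLite
# APDB without a connection pool.
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.connection_pool = False
#     config.dia_object_index = "baseline"
#     config.validate()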
188class ApdbSql(Apdb):
189 """Implementation of APDB interface based on SQL database.
191 The implementation is configured via standard ``pex_config`` mechanism
192 using `ApdbSqlConfig` configuration class. For an example of different
193 configurations check ``config/`` folder.
195 Parameters
196 ----------
197 config : `ApdbSqlConfig`
198 Configuration object.
199 """
201 ConfigClass = ApdbSqlConfig
203 def __init__(self, config: ApdbSqlConfig):
205 config.validate()
206 self.config = config
208 _LOG.debug("APDB Configuration:")
209 _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
210 _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
211 _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
212 _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
213 _LOG.debug(" schema_file: %s", self.config.schema_file)
214 _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
215 _LOG.debug(" schema prefix: %s", self.config.prefix)
217 # engine is reused between multiple processes, make sure that we don't
218 # share connections by disabling pool (by using NullPool class)
219 kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
220 conn_args: Dict[str, Any] = dict()
221 if not self.config.connection_pool:
222 kw.update(poolclass=NullPool)
223 if self.config.isolation_level is not None:
224 kw.update(isolation_level=self.config.isolation_level)
225 elif self.config.db_url.startswith("sqlite"): # type: ignore
226 # Use READ_UNCOMMITTED as default value for sqlite.
227 kw.update(isolation_level="READ_UNCOMMITTED")
228 if self.config.connection_timeout is not None:
229 if self.config.db_url.startswith("sqlite"):
230 conn_args.update(timeout=self.config.connection_timeout)
231 elif self.config.db_url.startswith(("postgresql", "mysql")):
232 conn_args.update(connect_timeout=self.config.connection_timeout)
233 kw.update(connect_args=conn_args)
234 self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)
236 self._schema = ApdbSqlSchema(engine=self._engine,
237 dia_object_index=self.config.dia_object_index,
238 schema_file=self.config.schema_file,
239 schema_name=self.config.schema_name,
240 prefix=self.config.prefix,
241 namespace=self.config.namespace,
242 htm_index_column=self.config.htm_index_column)
244 self.pixelator = HtmPixelization(self.config.htm_level)
246 def tableRowCount(self) -> Dict[str, int]:
247 """Returns dictionary with the table names and row counts.
249 Used by ``ap_proto`` to keep track of the size of the database tables.
250 Depending on database technology this could be an expensive operation.
252 Returns
253 -------
254 row_counts : `dict`
255 Dict where key is a table name and value is a row count.
256 """
257 res = {}
258 tables: List[sqlalchemy.schema.Table] = [
259 self._schema.objects, self._schema.sources, self._schema.forcedSources]
260 if self.config.dia_object_index == 'last_object_table':
261 tables.append(self._schema.objects_last)
262 for table in tables:
263 stmt = sql.select([func.count()]).select_from(table)
264 count = self._engine.scalar(stmt)
265 res[table.name] = count
267 return res
269 def tableDef(self, table: ApdbTables) -> Optional[Table]:
270 # docstring is inherited from a base class
271 return self._schema.tableSchemas.get(table)
273 def makeSchema(self, drop: bool = False) -> None:
274 # docstring is inherited from a base class
275 self._schema.makeSchema(drop=drop)
277 def getDiaObjects(self, region: Region) -> pandas.DataFrame:
278 # docstring is inherited from a base class
280 # decide what columns we need
281 table: sqlalchemy.schema.Table
282 if self.config.dia_object_index == 'last_object_table':
283 table = self._schema.objects_last
284 else:
285 table = self._schema.objects
286 if not self.config.dia_object_columns:
287 query = table.select()
288 else:
289 columns = [table.c[col] for col in self.config.dia_object_columns]
290 query = sql.select(columns)
292 # build selection
293 query = query.where(self._filterRegion(table, region))
295 # select latest version of objects
296 if self.config.dia_object_index != 'last_object_table':
297 query = query.where(table.c.validityEnd == None) # noqa: E711
299 # _LOG.debug("query: %s", query)
301 if self.config.explain:
302 # run the same query with explain
303 self._explain(query, self._engine)
305 # execute select
306 with Timer('DiaObject select', self.config.timer):
307 with self._engine.begin() as conn:
308 objects = pandas.read_sql_query(query, conn)
309 _LOG.debug("found %s DiaObjects", len(objects))
310 return objects
312 def getDiaSources(self, region: Region,
313 object_ids: Optional[Iterable[int]],
314 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
315 # docstring is inherited from a base class
316 if self.config.read_sources_months == 0:
317 _LOG.debug("Skip DiaSources fetching")
318 return None
320 if object_ids is None:
321 # region-based select
322 return self._getDiaSourcesInRegion(region, visit_time)
323 else:
324 return self._getDiaSourcesByIDs(list(object_ids), visit_time)
326 def getDiaForcedSources(self, region: Region,
327 object_ids: Optional[Iterable[int]],
328 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
329 """Return catalog of DiaForcedSource instances from a given region.
331 Parameters
332 ----------
333 region : `lsst.sphgeom.Region`
334 Region to search for DIASources.
335 object_ids : iterable [ `int` ], optional
336 List of DiaObject IDs to further constrain the set of returned
337 sources. If the list is empty then an empty catalog is returned
338 with a correct schema.
339 visit_time : `lsst.daf.base.DateTime`
340 Time of the current visit.
342 Returns
343 -------
344 catalog : `pandas.DataFrame`, or `None`
345 Catalog containing DiaSource records. `None` is returned if
346 ``read_sources_months`` configuration parameter is set to 0.
348 Raises
349 ------
350 NotImplementedError
351 Raised if ``object_ids`` is `None`.
353 Notes
354 -----
355 Even though the base class allows `None` to be passed for ``object_ids``,
356 this class requires ``object_ids`` to be not-`None`.
357 `NotImplementedError` is raised if `None` is passed.
359 This method returns a DiaForcedSource catalog for a region with additional
360 filtering based on DiaObject IDs. Only a subset of the DiaForcedSource
361 history is returned, limited by the ``read_forced_sources_months`` config
362 parameter, w.r.t. ``visit_time``. If ``object_ids`` is empty then an empty
363 catalog is always returned with a correct schema (columns/types).
364 """
366 if self.config.read_forced_sources_months == 0:
367 _LOG.debug("Skip DiaForceSources fetching")
368 return None
370 if object_ids is None:
371 # This implementation does not support region-based selection.
372 raise NotImplementedError("Region-based selection is not supported")
374 # TODO: DateTime.MJD must be consistent with code in ap_association,
375 # alternatively we can fill midPointTai ourselves in store()
376 midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
377 _LOG.debug("midPointTai_start = %.6f", midPointTai_start)
379 table: sqlalchemy.schema.Table = self._schema.forcedSources
380 with Timer('DiaForcedSource select', self.config.timer):
381 sources = self._getSourcesByIDs(table, list(object_ids), midPointTai_start)
383 _LOG.debug("found %s DiaForcedSources", len(sources))
384 return sources
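# Illustrative call sketch for the method above (hypothetical inputs): forced
# sources for two known objects in a visit; passing ``object_ids=None`` would
# raise NotImplementedError here.
#
#     forced = apdb.getDiaForcedSources(region, [12345, 67890], visit_time)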
386 def getDiaObjectsHistory(self,
387 start_time: dafBase.DateTime,
388 end_time: dafBase.DateTime,
389 region: Optional[Region] = None) -> pandas.DataFrame:
390 # docstring is inherited from a base class
392 table = self._schema.objects
393 query = table.select()
395 # build selection
396 time_filter = sql.expression.and_(
397 table.columns["validityStart"] >= start_time.toPython(),
398 table.columns["validityStart"] < end_time.toPython()
399 )
401 if region:
402 where = sql.expression.and_(self._filterRegion(table, region), time_filter)
403 query = query.where(where)
404 else:
405 query = query.where(time_filter)
407 # execute select
408 with Timer('DiaObject history select', self.config.timer):
409 with self._engine.begin() as conn:
410 catalog = pandas.read_sql_query(query, conn)
411 _LOG.debug("found %s DiaObjects history records", len(catalog))
412 return catalog
414 def getDiaSourcesHistory(self,
415 start_time: dafBase.DateTime,
416 end_time: dafBase.DateTime,
417 region: Optional[Region] = None) -> pandas.DataFrame:
418 # docstring is inherited from a base class
420 table = self._schema.sources
421 query = table.select()
423 # build selection
424 time_filter = sql.expression.and_(
425 table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
426 table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
427 )
429 if region:
430 where = sql.expression.and_(self._filterRegion(table, region), time_filter)
431 query = query.where(where)
432 else:
433 query = query.where(time_filter)
435 # execute select
436 with Timer('DiaSource history select', self.config.timer):
437 with self._engine.begin() as conn:
438 catalog = pandas.read_sql_query(query, conn)
439 _LOG.debug("found %s DiaSource history records", len(catalog))
440 return catalog
442 def getDiaForcedSourcesHistory(self,
443 start_time: dafBase.DateTime,
444 end_time: dafBase.DateTime,
445 region: Optional[Region] = None) -> pandas.DataFrame:
446 # docstring is inherited from a base class
448 table = self._schema.forcedSources
449 query = table.select()
451 # build selection
452 time_filter = sql.expression.and_(
453 table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
454 table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
455 )
456 # Forced sources have no pixel index, so no region filtering
457 query = query.where(time_filter)
459 # execute select
460 with Timer('DiaForcedSource history select', self.config.timer):
461 with self._engine.begin() as conn:
462 catalog = pandas.read_sql_query(query, conn)
463 _LOG.debug("found %s DiaForcedSource history records", len(catalog))
464 return catalog
466 def getSSObjects(self) -> pandas.DataFrame:
467 # docstring is inherited from a base class
469 table = self._schema.ssObjects
470 query = table.select()
472 if self.config.explain:
473 # run the same query with explain
474 self._explain(query, self._engine)
476 # execute select
477 with Timer('DiaObject select', self.config.timer):
478 with self._engine.begin() as conn:
479 objects = pandas.read_sql_query(query, conn)
480 _LOG.debug("found %s SSObjects", len(objects))
481 return objects
483 def store(self,
484 visit_time: dafBase.DateTime,
485 objects: pandas.DataFrame,
486 sources: Optional[pandas.DataFrame] = None,
487 forced_sources: Optional[pandas.DataFrame] = None) -> None:
488 # docstring is inherited from a base class
490 # fill pixelId column for DiaObjects
491 objects = self._add_obj_htm_index(objects)
492 self._storeDiaObjects(objects, visit_time)
494 if sources is not None:
495 # copy pixelId column from DiaObjects to DiaSources
496 sources = self._add_src_htm_index(sources, objects)
497 self._storeDiaSources(sources)
499 if forced_sources is not None:
500 self._storeDiaForcedSources(forced_sources)
502 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
503 # docstring is inherited from a base class
505 idColumn = "ssObjectId"
506 table = self._schema.ssObjects
508 # everything to be done in single transaction
509 with self._engine.begin() as conn:
511 # Find record IDs that already exist. Some types like np.int64 can
512 # cause issues with sqlalchemy, convert them to int.
513 ids = sorted(int(oid) for oid in objects[idColumn])
515 query = sql.select(table.columns[idColumn], table.columns[idColumn].in_(ids))
516 result = conn.execute(query)
517 knownIds = set(row[idColumn] for row in result)
519 filter = objects[idColumn].isin(knownIds)
520 toUpdate = cast(pandas.DataFrame, objects[filter])
521 toInsert = cast(pandas.DataFrame, objects[~filter])
523 # insert new records
524 if len(toInsert) > 0:
525 toInsert.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)
527 # update existing records
528 if len(toUpdate) > 0:
529 whereKey = f"{idColumn}_param"
530 query = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
531 toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
532 values = toUpdate.to_dict("records")
533 result = conn.execute(query, values)
535 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
536 # docstring is inherited from a base class
538 table = self._schema.sources
539 query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))
541 with self._engine.begin() as conn:
542 # Need to make sure that every ID exists in the database, but
543 # executemany may not support rowcount, so iterate and check what is
544 # missing.
545 missing_ids: List[int] = []
546 for key, value in idMap.items():
547 params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
548 result = conn.execute(query, params)
549 if result.rowcount == 0:
550 missing_ids.append(key)
551 if missing_ids:
552 missing = ",".join(str(item)for item in missing_ids)
553 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
555 def dailyJob(self) -> None:
556 # docstring is inherited from a base class
558 if self._engine.name == 'postgresql':
560 # do VACUUM on all tables
561 _LOG.info("Running VACUUM on all tables")
562 connection = self._engine.raw_connection()
563 ISOLATION_LEVEL_AUTOCOMMIT = 0
564 connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
565 cursor = connection.cursor()
566 cursor.execute("VACUUM ANALYSE")
568 def countUnassociatedObjects(self) -> int:
569 # docstring is inherited from a base class
571 # Retrieve the DiaObject table.
572 table: sqlalchemy.schema.Table = self._schema.objects
574 # Construct the sql statement.
575 stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
576 stmt = stmt.where(table.c.validityEnd == None) # noqa: E711
578 # Return the count.
579 with self._engine.begin() as conn:
580 count = conn.scalar(stmt)
582 return count
584 def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime
585 ) -> pandas.DataFrame:
586 """Returns catalog of DiaSource instances from given region.
588 Parameters
589 ----------
590 region : `lsst.sphgeom.Region`
591 Region to search for DIASources.
592 visit_time : `lsst.daf.base.DateTime`
593 Time of the current visit.
595 Returns
596 -------
597 catalog : `pandas.DataFrame`
598 Catalog containing DiaSource records.
599 """
600 # TODO: DateTime.MJD must be consistent with code in ap_association,
601 # alternatively we can fill midPointTai ourselves in store()
602 midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
603 _LOG.debug("midPointTai_start = %.6f", midPointTai_start)
605 table: sqlalchemy.schema.Table = self._schema.sources
606 query = table.select()
608 # build selection
609 time_filter = table.columns["midPointTai"] > midPointTai_start
610 where = sql.expression.and_(self._filterRegion(table, region), time_filter)
611 query = query.where(where)
613 # execute select
614 with Timer('DiaSource select', self.config.timer):
615 with _ansi_session(self._engine) as conn:
616 sources = pandas.read_sql_query(query, conn)
617 _LOG.debug("found %s DiaSources", len(sources))
618 return sources
620 def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime
621 ) -> pandas.DataFrame:
622 """Returns catalog of DiaSource instances given set of DiaObject IDs.
624 Parameters
625 ----------
626 object_ids :
627 Collection of DiaObject IDs
628 visit_time : `lsst.daf.base.DateTime`
629 Time of the current visit.
631 Returns
632 -------
633 catalog : `pandas.DataFrame`
634 Catalog containing DiaSource records.
635 """
636 # TODO: DateTime.MJD must be consistent with code in ap_association,
637 # alternatively we can fill midPointTai ourselves in store()
638 midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
639 _LOG.debug("midPointTai_start = %.6f", midPointTai_start)
641 table: sqlalchemy.schema.Table = self._schema.sources
642 with Timer('DiaSource select', self.config.timer):
643 sources = self._getSourcesByIDs(table, object_ids, midPointTai_start)
645 _LOG.debug("found %s DiaSources", len(sources))
646 return sources
648 def _getSourcesByIDs(self, table: sqlalchemy.schema.Table,
649 object_ids: List[int],
650 midPointTai_start: float
651 ) -> pandas.DataFrame:
652 """Returns catalog of DiaSource or DiaForcedSource instances given set
653 of DiaObject IDs.
655 Parameters
656 ----------
657 table : `sqlalchemy.schema.Table`
658 Database table.
659 object_ids :
660 Collection of DiaObject IDs
661 midPointTai_start : `float`
662 Earliest midPointTai to retrieve.
664 Returns
665 -------
666 catalog : `pandas.DataFrame`
667 Catalog containing DiaSource records. An empty catalog is returned
668 when ``object_ids`` is empty.
670 """
671 sources: Optional[pandas.DataFrame] = None
672 with _ansi_session(self._engine) as conn:
673 if len(object_ids) <= 0:
674 _LOG.debug("ID list is empty, just fetch empty result")
675 query = table.select().where(False)
676 sources = pandas.read_sql_query(query, conn)
677 else:
678 for ids in chunk_iterable(sorted(object_ids), 1000):
679 query = table.select()
681 # Some types like np.int64 can cause issues with
682 # sqlalchemy, convert them to int.
683 int_ids = [int(oid) for oid in ids]
685 # select by object id
686 query = query.where(
687 sql.expression.and_(
688 table.columns["diaObjectId"].in_(int_ids),
689 table.columns["midPointTai"] > midPointTai_start,
690 )
691 )
693 # execute select
694 df = pandas.read_sql_query(query, conn)
695 if sources is None:
696 sources = df
697 else:
698 sources = sources.append(df)
699 assert sources is not None, "Catalog cannot be None"
700 return sources
702 def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
703 """Store catalog of DiaObjects from current visit.
705 Parameters
706 ----------
707 objs : `pandas.DataFrame`
708 Catalog with DiaObject records.
709 visit_time : `lsst.daf.base.DateTime`
710 Time of the visit.
711 """
713 # Some types like np.int64 can cause issues with sqlalchemy, convert
714 # them to int.
715 ids = sorted(int(oid) for oid in objs['diaObjectId'])
716 _LOG.debug("first object ID: %d", ids[0])
718 # NOTE: workaround for sqlite, need this here to avoid
719 # "database is locked" error.
720 table: sqlalchemy.schema.Table = self._schema.objects
722 # TODO: Need to verify that we are using correct scale here for
723 # DATETIME representation (see DM-31996).
724 dt = visit_time.toPython()
726 # everything to be done in single transaction
727 with _ansi_session(self._engine) as conn:
729 if self.config.dia_object_index == 'last_object_table':
731 # insert and replace all records in LAST table, mysql and postgres have
732 # non-standard features
733 table = self._schema.objects_last
735 # Drop the previous objects (pandas cannot upsert).
736 query = table.delete().where(
737 table.columns["diaObjectId"].in_(ids)
738 )
740 if self.config.explain:
741 # run the same query with explain
742 self._explain(query, conn)
744 with Timer(table.name + ' delete', self.config.timer):
745 res = conn.execute(query)
746 _LOG.debug("deleted %s objects", res.rowcount)
748 # DiaObjectLast is a subset of DiaObject, strip missing columns
749 last_column_names = [column.name for column in table.columns]
750 last_objs = objs[last_column_names]
752 extra_columns: Dict[str, Any] = dict(lastNonForcedSource=dt)
753 with Timer("DiaObjectLast insert", self.config.timer):
754 last_objs = _coerce_uint64(last_objs)
755 for col, data in extra_columns.items():
756 last_objs[col] = data
757 last_objs.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)
758 else:
760 # truncate existing validity intervals
761 table = self._schema.objects
763 query = table.update().values(validityEnd=dt).where(
764 sql.expression.and_(
765 table.columns["diaObjectId"].in_(ids),
766 table.columns["validityEnd"].is_(None),
767 )
768 )
770 # _LOG.debug("query: %s", query)
772 if self.config.explain:
773 # run the same query with explain
774 self._explain(query, conn)
776 with Timer(table.name + ' truncate', self.config.timer):
777 res = conn.execute(query)
778 _LOG.debug("truncated %s intervals", res.rowcount)
780 # insert new versions
781 table = self._schema.objects
782 extra_columns = dict(lastNonForcedSource=dt, validityStart=dt,
783 validityEnd=None)
784 with Timer("DiaObject insert", self.config.timer):
785 objs = _coerce_uint64(objs)
786 if extra_columns:
787 columns: List[pandas.Series] = []
788 for col, data in extra_columns.items():
789 columns.append(pandas.Series([data]*len(objs), name=col))
790 objs.set_index(columns[0].index, inplace=True)
791 objs = pandas.concat([objs] + columns, axis="columns")
792 objs.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)
794 def _storeDiaSources(self, sources: pandas.DataFrame) -> None:
795 """Store catalog of DiaSources from current visit.
797 Parameters
798 ----------
799 sources : `pandas.DataFrame`
800 Catalog containing DiaSource records
801 """
802 # everything to be done in single transaction
803 with _ansi_session(self._engine) as conn:
805 with Timer("DiaSource insert", self.config.timer):
806 sources = _coerce_uint64(sources)
807 table = self._schema.sources
808 sources.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)
810 def _storeDiaForcedSources(self, sources: pandas.DataFrame) -> None:
811 """Store a set of DiaForcedSources from current visit.
813 Parameters
814 ----------
815 sources : `pandas.DataFrame`
816 Catalog containing DiaForcedSource records
817 """
819 # everything to be done in single transaction
820 with _ansi_session(self._engine) as conn:
822 with Timer("DiaForcedSource insert", self.config.timer):
823 sources = _coerce_uint64(sources)
824 table = self._schema.forcedSources
825 sources.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)
827 def _explain(self, query: str, conn: sqlalchemy.engine.Connection) -> None:
828 """Run the query with explain
829 """
831 _LOG.info("explain for query: %s...", query[:64])
833 if conn.engine.name == 'mysql':
834 query = "EXPLAIN EXTENDED " + query
835 else:
836 query = "EXPLAIN " + query
838 res = conn.execute(sql.text(query))
839 if res.returns_rows:
840 _LOG.info("explain: %s", res.keys())
841 for row in res:
842 _LOG.info("explain: %s", row)
843 else:
844 _LOG.info("EXPLAIN returned nothing")
846 def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
847 """Generate a set of HTM indices covering specified region.
849 Parameters
850 ----------
851 region: `sphgeom.Region`
852 Region that needs to be indexed.
854 Returns
855 -------
856 Sequence of ranges; each range is a tuple (minHtmID, maxHtmID).
857 """
858 _LOG.debug('region: %s', region)
859 indices = self.pixelator.envelope(region, self.config.htm_max_ranges)
861 return indices.ranges()
863 def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ClauseElement:
864 """Make SQLAlchemy expression for selecting records in a region.
865 """
866 htm_index_column = table.columns[self.config.htm_index_column]
867 exprlist = []
868 pixel_ranges = self._htm_indices(region)
869 for low, upper in pixel_ranges:
870 upper -= 1
871 if low == upper:
872 exprlist.append(htm_index_column == low)
873 else:
874 exprlist.append(sql.expression.between(htm_index_column, low, upper))
876 return sql.expression.or_(*exprlist)
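# Illustrative sketch (assumed pixel ranges): for ranges [(100, 101), (200, 210)]
# the expression built above is equivalent to
#
#     "pixelId" = 100 OR "pixelId" BETWEEN 200 AND 209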
878 def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
879 """Calculate HTM index for each record and add it to a DataFrame.
881 Notes
882 -----
883 This overrides any existing column in a DataFrame with the same name
884 (pixelId). Original DataFrame is not changed, copy of a DataFrame is
885 returned.
886 """
887 # calculate HTM index for every DiaObject
888 htm_index = np.zeros(df.shape[0], dtype=np.int64)
889 ra_col, dec_col = self.config.ra_dec_columns
890 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
891 uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
892 idx = self.pixelator.index(uv3d)
893 htm_index[i] = idx
894 df = df.copy()
895 df[self.config.htm_index_column] = htm_index
896 return df
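# Illustrative sketch (hypothetical catalog, default ra_dec_columns and
# htm_index_column): each row's (ra, decl) is pixelized at the configured
# HTM level.
#
#     objs = pandas.DataFrame({"ra": [10.0], "decl": [-5.0]})
#     apdb._add_obj_htm_index(objs)["pixelId"]   # level-20 HTM index of that point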
898 def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
899 """Add pixelId column to DiaSource catalog.
901 Notes
902 -----
903 This method copies pixelId value from a matching DiaObject record.
904 DiaObject catalog needs to have a pixelId column filled by
905 ``_add_obj_htm_index`` method and DiaSource records need to be
906 associated with DiaObjects via the ``diaObjectId`` column.
908 This overrides any existing column in a DataFrame with the same name
909 (pixelId). Original DataFrame is not changed, copy of a DataFrame is
910 returned.
911 """
912 pixel_id_map: Dict[int, int] = {
913 diaObjectId: pixelId for diaObjectId, pixelId
914 in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
915 }
916 # DiaSources associated with SolarSystemObjects do not have an
917 # associated DiaObject hence we skip them and set their htmIndex
918 # value to 0.
919 pixel_id_map[0] = 0
920 htm_index = np.zeros(sources.shape[0], dtype=np.int64)
921 for i, diaObjId in enumerate(sources["diaObjectId"]):
922 htm_index[i] = pixel_id_map[diaObjId]
923 sources = sources.copy()
924 sources[self.config.htm_index_column] = htm_index
925 return sources