Coverage for python/lsst/dax/apdb/apdbSql.py: 13%
377 statements
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

from contextlib import contextmanager
import logging
import numpy as np
import pandas
from typing import cast, Any, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple

import lsst.daf.base as dafBase
from lsst.pex.config import Field, ChoiceField, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
import sqlalchemy
from sqlalchemy import func, sql
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig
from .apdbSchema import ApdbTables, TableDef
from .apdbSqlSchema import ApdbSqlSchema
from .timer import Timer

_LOG = logging.getLogger(__name__)


def _split(seq: Iterable, nItems: int) -> Iterator[List]:
    """Split a sequence into smaller sequences"""
    seq = list(seq)
    while seq:
        yield seq[:nItems]
        del seq[:nItems]
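
# Illustrative example (not part of the original module): _split yields
# successive chunks of at most nItems elements, consuming the copied list
# as it goes, e.g.
#
#     >>> list(_split(range(5), 2))
#     [[0, 1], [2, 3], [4]]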


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
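
# Illustrative example (not part of the original module): uint64 columns are
# downcast so that ``DataFrame.to_sql`` maps them to a signed BIGINT, e.g.
#
#     >>> df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64)})
#     >>> _coerce_uint64(df)["id"].dtype
#     dtype('int64')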


def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
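
# Illustrative example (not part of the original module): a month is
# approximated as 30 days, so for a visit at MJD 59500.0 and a 12-month
# history window the cutoff is 59500.0 - 12 * 30 = 59140.0 (MJD).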


@contextmanager
def _ansi_session(engine: sqlalchemy.engine.Engine) -> Iterator[sqlalchemy.engine.Connection]:
    """Return a connection, making sure that ANSI mode is set for MySQL."""
    with engine.begin() as conn:
        if engine.name == 'mysql':
            conn.execute(sql.text("SET SESSION SQL_MODE = 'ANSI'"))
        yield conn
    return
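
# Illustrative usage (not part of the original module): because the context
# manager wraps ``engine.begin()``, the transaction commits on normal exit
# and rolls back if the body raises, e.g.
#
#     with _ansi_session(engine) as conn:
#         conn.execute(sql.text("SELECT 1"))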


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql).
    """
    db_url = Field(
        dtype=str,
        doc="SQLAlchemy database connection URI"
    )
    isolation_level = ChoiceField(
        dtype=str,
        doc="Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value.",
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable"
        },
        default=None,
        optional=True
    )
    connection_pool = Field(
        dtype=bool,
        doc="If False then disable SQLAlchemy connection pool. "
            "Do not use connection pool when forking.",
        default=True
    )
    connection_timeout = Field(
        dtype=float,
        doc="Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to SQLAlchemy defaults if not set.",
        default=None,
        optional=True
    )
    sql_echo = Field(
        dtype=bool,
        doc="If True then pass SQLAlchemy echo option.",
        default=False
    )
    dia_object_index = ChoiceField(
        dtype=str,
        doc="Indexing mode for DiaObject table",
        allowed={
            'baseline': "Index defined in baseline schema",
            'pix_id_iov': "(pixelId, objectId, iovStart) PK",
            'last_object_table': "Separate DiaObjectLast table"
        },
        default='baseline'
    )
    htm_level = Field(
        dtype=int,
        doc="HTM indexing level",
        default=20
    )
    htm_max_ranges = Field(
        dtype=int,
        doc="Max number of ranges in HTM envelope",
        default=64
    )
    htm_index_column = Field(
        dtype=str,
        default="pixelId",
        doc="Name of the HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField(
        dtype=str,
        default=["ra", "decl"],
        doc="Names of ra/dec columns in DiaObject table"
    )
    dia_object_columns = ListField(
        dtype=str,
        doc="List of columns to read from DiaObject, by default read all columns",
        default=[]
    )
    object_last_replace = Field(
        dtype=bool,
        doc="If True (default) then use \"upsert\" for DiaObjectsLast table",
        default=True
    )
    prefix = Field(
        dtype=str,
        doc="Prefix to add to table names and index names",
        default=""
    )
    explain = Field(
        dtype=bool,
        doc="If True then run EXPLAIN SQL command on each executed query",
        default=False
    )
    timer = Field(
        dtype=bool,
        doc="If True then print/log timing information",
        default=False
    )

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
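
# Illustrative usage sketch (not part of the original module): like any
# pex_config object, the configuration is created, customized, validated,
# and then handed to ApdbSql; the sqlite URL here is only a placeholder and
# other required fields inherited from ApdbConfig are omitted.
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///:memory:"
#     config.dia_object_index = "last_object_table"
#     config.validate()
#     apdb = ApdbSql(config)
#     apdb.makeSchema()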


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config`` mechanism
    using the `ApdbSqlConfig` configuration class. For examples of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):

        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    object_last_replace: %s", self.config.object_last_replace)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(engine=self._engine,
                                     dia_object_index=self.config.dia_object_index,
                                     schema_file=self.config.schema_file,
                                     schema_name=self.config.schema_name,
                                     prefix=self.config.prefix,
                                     htm_index_column=self.config.htm_index_column)

        self.pixelator = HtmPixelization(self.config.htm_level)

    def tableRowCount(self) -> Dict[str, int]:
        """Return a dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this can be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables: List[sqlalchemy.schema.Table] = [
            self._schema.objects, self._schema.sources, self._schema.forcedSources]
        if self.config.dia_object_index == 'last_object_table':
            tables.append(self._schema.objects_last)
        for table in tables:
            stmt = sql.select([func.count()]).select_from(table)
            count = self._engine.scalar(stmt)
            res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        table: sqlalchemy.schema.Table
        if self.config.dia_object_index == 'last_object_table':
            table = self._schema.objects_last
        else:
            table = self._schema.objects
        if not self.config.dia_object_columns:
            query = table.select()
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
            query = sql.select(columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != 'last_object_table':
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        _LOG.debug("query: %s", query)

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('DiaObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(self, region: Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(self, region: Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If the list is empty then an empty catalog is returned
            with a correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaSource records. `None` is returned if the
            ``read_forced_sources_months`` configuration parameter is set
            to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be not-`None`.
        `NotImplementedError` is raised if `None` is passed.

        This method returns a DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter w.r.t.
        ``visit_time``. If ``object_ids`` is empty then an empty catalog is
        always returned with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.forcedSources
        with Timer('DiaForcedSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getDiaObjectsHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.objects
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["validityStart"] >= start_time.toPython(),
            table.columns["validityStart"] < end_time.toPython()
        )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaObject history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects history records", len(catalog))
        return catalog

    def getDiaSourcesHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
            table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
        )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSource history records", len(catalog))
        return catalog

    def getDiaForcedSourcesHistory(self,
                                   start_time: dafBase.DateTime,
                                   end_time: dafBase.DateTime,
                                   region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.forcedSources
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
            table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
        )
        # Forced sources have no pixel index, so no region filtering
        query = query.where(time_filter)

        # execute select
        with Timer('DiaForcedSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaForcedSource history records", len(catalog))
        return catalog

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.ssObjects
        query = table.select()

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('SSObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill pixelId column for DiaObjects
        objects = self._add_obj_htm_index(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy pixelId column from DiaObjects to DiaSources
            sources = self._add_src_htm_index(sources, objects)
            self._storeDiaSources(sources)

        if forced_sources is not None:
            self._storeDiaForcedSources(forced_sources)

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.ssObjects

        # everything to be done in single transaction
        with self._engine.begin() as conn:

            # find record IDs that already exist
            ids = sorted(objects[idColumn])
            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row[idColumn] for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(ApdbTables.SSObject.table_name(), conn, if_exists='append', index=False)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                query = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(query, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class

        if self._engine.name == 'postgresql':

            # do VACUUM on all tables
            _LOG.info("Running VACUUM on all tables")
            connection = self._engine.raw_connection()
            ISOLATION_LEVEL_AUTOCOMMIT = 0
            connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            cursor = connection.cursor()
            cursor.execute("VACUUM ANALYSE")

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.objects

        # Construct the sql statement.
        stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.scalar(stmt)

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime
                               ) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from the given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer('DiaSource select', self.config.timer):
            with _ansi_session(self._engine) as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime
                            ) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given a set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        with Timer('DiaSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(self, table: sqlalchemy.schema.Table,
                         object_ids: List[int],
                         midPointTai_start: float
                         ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given a
        set of DiaObject IDs.

        Parameters
        ----------
        table : `sqlalchemy.schema.Table`
            Database table.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        sources: Optional[pandas.DataFrame] = None
        with _ansi_session(self._engine) as conn:
            if len(object_ids) <= 0:
                _LOG.debug("ID list is empty, just fetch empty result")
                query = table.select().where(False)
                sources = pandas.read_sql_query(query, conn)
            else:
                for ids in _split(sorted(object_ids), 1000):
                    query = f'SELECT * FROM "{table.name}" WHERE '

                    # select by object id
                    ids_str = ",".join(str(id) for id in ids)
                    query += f'"diaObjectId" IN ({ids_str})'
                    query += f' AND "midPointTai" > {midPointTai_start}'

                    # execute select
                    df = pandas.read_sql_query(sql.text(query), conn)
                    if sources is None:
                        sources = df
                    else:
                        sources = sources.append(df)
        assert sources is not None, "Catalog cannot be None"
        return sources
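
    # Illustrative example (not part of the original module): for an ID chunk
    # [1, 2], midPointTai_start = 59140.0, and assuming the table is named
    # DiaSource, the generated query text is
    #
    #     SELECT * FROM "DiaSource" WHERE "diaObjectId" IN (1,2) AND "midPointTai" > 59140.0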

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        """

        ids = sorted(objs['diaObjectId'])
        _LOG.debug("first object ID: %d", ids[0])

        # NOTE: workaround for sqlite, need this here to avoid
        # "database is locked" error.
        table: sqlalchemy.schema.Table = self._schema.objects

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            ids_str = ",".join(str(id) for id in ids)

            if self.config.dia_object_index == 'last_object_table':

                # insert and replace all records in LAST table, mysql and
                # postgres have non-standard features
                table = self._schema.objects_last
                do_replace = self.config.object_last_replace
                # If the input data is of type Pandas, we drop the previous
                # objects regardless of the do_replace setting due to how
                # Pandas inserts objects.
                if not do_replace or isinstance(objs, pandas.DataFrame):
                    query = 'DELETE FROM "' + table.name + '" '
                    query += 'WHERE "diaObjectId" IN (' + ids_str + ') '

                    if self.config.explain:
                        # run the same query with explain
                        self._explain(query, conn)

                    with Timer(table.name + ' delete', self.config.timer):
                        res = conn.execute(sql.text(query))
                    _LOG.debug("deleted %s objects", res.rowcount)

                extra_columns: Dict[str, Any] = dict(lastNonForcedSource=dt)
                with Timer("DiaObjectLast insert", self.config.timer):
                    objs = _coerce_uint64(objs)
                    for col, data in extra_columns.items():
                        objs[col] = data
                    objs.to_sql("DiaObjectLast", conn, if_exists='append',
                                index=False)
            else:

                # truncate existing validity intervals
                table = self._schema.objects
                query = 'UPDATE "' + table.name + '" '
                query += "SET \"validityEnd\" = '" + str(dt) + "' "
                query += 'WHERE "diaObjectId" IN (' + ids_str + ') '
                query += 'AND "validityEnd" IS NULL'

                # _LOG.debug("query: %s", query)

                if self.config.explain:
                    # run the same query with explain
                    self._explain(query, conn)

                with Timer(table.name + ' truncate', self.config.timer):
                    res = conn.execute(sql.text(query))
                _LOG.debug("truncated %s intervals", res.rowcount)

            # insert new versions
            table = self._schema.objects
            extra_columns = dict(lastNonForcedSource=dt, validityStart=dt,
                                 validityEnd=None)
            with Timer("DiaObject insert", self.config.timer):
                objs = _coerce_uint64(objs)
                for col, data in extra_columns.items():
                    objs[col] = data
                objs.to_sql("DiaObject", conn, if_exists='append',
                            index=False)

    def _storeDiaSources(self, sources: pandas.DataFrame) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        """
        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                sources.to_sql("DiaSource", conn, if_exists='append', index=False)

    def _storeDiaForcedSources(self, sources: pandas.DataFrame) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        """

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaForcedSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                sources.to_sql("DiaForcedSource", conn, if_exists='append', index=False)

    def _explain(self, query: str, conn: sqlalchemy.engine.Connection) -> None:
        """Run the query with explain."""

        _LOG.info("explain for query: %s...", query[:64])

        if conn.engine.name == 'mysql':
            query = "EXPLAIN EXTENDED " + query
        else:
            query = "EXPLAIN " + query

        res = conn.execute(sql.text(query))
        if res.returns_rows:
            _LOG.info("explain: %s", res.keys())
            for row in res:
                _LOG.info("explain: %s", row)
        else:
            _LOG.info("EXPLAIN returned nothing")

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug('region: %s', region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        if _LOG.isEnabledFor(logging.DEBUG):
            for irange in indices.ranges():
                _LOG.debug('range: %s %s', self.pixelator.toString(irange[0]),
                           self.pixelator.toString(irange[1]))

        return indices.ranges()
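
    # Illustrative note (not part of the original module): the ranges returned
    # by the pixelization envelope are half-open, e.g. a single pixel 12345
    # comes back as (12345, 12346); _filterRegion below compensates with
    # ``upper -= 1`` to build inclusive comparisons.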

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ClauseElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
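
    # Illustrative example (not part of the original module): for pixel ranges
    # [(100, 104), (200, 201)] and the default htm_index_column, the clause
    # renders roughly as
    #
    #     "pixelId" BETWEEN 100 AND 103 OR "pixelId" = 200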

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies the pixelId value from the matching DiaObject
        record. The DiaObject catalog needs to have a pixelId column filled
        by the ``_add_obj_htm_index`` method, and DiaSource records need to
        be associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId for diaObjectId, pixelId
            in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject, hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources