# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
22"""Module defining Apdb class and related methods.
23"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

from contextlib import contextmanager
import logging
import numpy as np
import pandas
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type

import lsst.daf.base as dafBase
from lsst.pex.config import Field, ChoiceField, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
import sqlalchemy
from sqlalchemy import (func, sql)
from sqlalchemy.pool import NullPool
from .apdb import Apdb, ApdbConfig
from .apdbSqlSchema import ApdbSqlSchema
from . import timer


_LOG = logging.getLogger(__name__)


class Timer:
    """Timer class defining context manager which tracks execution timing.

    Typical use:

        with Timer("timer_name"):
            do_something

    On exit from the block it will print elapsed time.

    See also :py:mod:`timer` module.
    """
    def __init__(self, name: str, do_logging: bool = True, log_before_cursor_execute: bool = False):
        self._log_before_cursor_execute = log_before_cursor_execute
        self._do_logging = do_logging
        self._timer1 = timer.Timer(name)
        self._timer2 = timer.Timer(name + " (before/after cursor)")

    def __enter__(self) -> Timer:
        """
        Enter context, start timer
        """
        # event.listen(engine.Engine, "before_cursor_execute", self._start_timer)
        # event.listen(engine.Engine, "after_cursor_execute", self._stop_timer)
        self._timer1.start()
        return self

    def __exit__(self, exc_type: Optional[Type], exc_val: Any, exc_tb: Any) -> Any:
        """
        Exit context, stop and dump timer
        """
        if exc_type is None:
            self._timer1.stop()
            if self._do_logging:
                self._timer1.dump()
        # event.remove(engine.Engine, "before_cursor_execute", self._start_timer)
        # event.remove(engine.Engine, "after_cursor_execute", self._stop_timer)
        return False

    def _start_timer(self, conn, cursor, statement, parameters, context, executemany):  # type: ignore
        """Start counting"""
        if self._log_before_cursor_execute:
            _LOG.info("before_cursor_execute")
        self._timer2.start()

    def _stop_timer(self, conn, cursor, statement, parameters, context, executemany):  # type: ignore
        """Stop counting"""
        self._timer2.stop()
        if self._do_logging:
            self._timer2.dump()


def _split(seq: Iterable, nItems: int) -> Iterator[List]:
    """Split a sequence into smaller sequences"""
    seq = list(seq)
    while seq:
        yield seq[:nItems]
        del seq[:nItems]
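
# Illustrative sketch of _split behavior (hypothetical values); the generator
# consumes its working copy as it yields fixed-size chunks:
#
#     >>> list(_split(range(5), 2))
#     [[0, 1], [2, 3], [4]]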


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change type of the uint64 columns to int64, return copy of data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
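
# Illustrative sketch (hypothetical frame): uint64 columns are downcast
# because database backends lack an unsigned 64-bit integer type:
#
#     >>> df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64)})
#     >>> _coerce_uint64(df).dtypes["id"]
#     dtype('int64')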


def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
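
# Worked example (hypothetical numbers): a month is approximated as 30 days,
# so for a visit at MJD 59580.0 with a 12-month history window the cutoff is
# 59580.0 - 12 * 30 = 59220.0.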


@contextmanager
def _ansi_session(engine: sqlalchemy.engine.Engine) -> Iterator[sqlalchemy.engine.Connection]:
    """Returns a connection, makes sure that ANSI mode is set for MySQL
    """
    with engine.begin() as conn:
        if engine.name == 'mysql':
            conn.execute(sql.text("SET SESSION SQL_MODE = 'ANSI'"))
        yield conn
    return
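
# Minimal usage sketch: _ansi_session wraps engine.begin(), so the enclosed
# statements run in one transaction that commits on normal exit and rolls
# back on exception:
#
#     with _ansi_session(engine) as conn:
#         conn.execute(sql.text("SELECT 1"))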


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql).
    """
    db_url = Field(dtype=str, doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField(dtype=str,
                                  doc="Transaction isolation level, if unset then backend-default value "
                                      "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
                                      "Some backends may not support every allowed value.",
                                  allowed={"READ_COMMITTED": "Read committed",
                                           "READ_UNCOMMITTED": "Read uncommitted",
                                           "REPEATABLE_READ": "Repeatable read",
                                           "SERIALIZABLE": "Serializable"},
                                  default=None,
                                  optional=True)
    connection_pool = Field(dtype=bool,
                            doc=("If False then disable SQLAlchemy connection pool. "
                                 "Do not use connection pool when forking."),
                            default=True)
    connection_timeout = Field(dtype=float,
                               doc="Maximum time to wait for database lock to be released before "
                                   "exiting. Defaults to SQLAlchemy defaults if not set.",
                               default=None,
                               optional=True)
    sql_echo = Field(dtype=bool,
                     doc="If True then pass SQLAlchemy echo option.",
                     default=False)
    dia_object_index = ChoiceField(dtype=str,
                                   doc="Indexing mode for DiaObject table",
                                   allowed={'baseline': "Index defined in baseline schema",
                                            'pix_id_iov': "(pixelId, objectId, iovStart) PK",
                                            'last_object_table': "Separate DiaObjectLast table"},
                                   default='baseline')
    htm_level = Field(dtype=int,
                      doc="HTM indexing level",
                      default=20)
    htm_max_ranges = Field(dtype=int,
                           doc="Max number of ranges in HTM envelope",
                           default=64)
    htm_index_column = Field(dtype=str, default="pixelId",
                             doc="Name of a HTM index column for DiaObject and DiaSource tables")
    ra_dec_columns = ListField(dtype=str, default=["ra", "decl"],
                               doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField(dtype=str,
                                   doc="List of columns to read from DiaObject, by default read all columns",
                                   default=[])
    object_last_replace = Field(dtype=bool,
                                doc="If True (default) then use \"upsert\" for DiaObjectsLast table",
                                default=True)
    prefix = Field(dtype=str,
                   doc="Prefix to add to table names and index names",
                   default="")
    explain = Field(dtype=bool,
                    doc="If True then run EXPLAIN SQL command on each executed query",
                    default=False)
    timer = Field(dtype=bool,
                  doc="If True then print/log timing information",
                  default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
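
# A minimal configuration sketch (the SQLite URL below is hypothetical):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.dia_object_index = "last_object_table"
#     config.validate()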


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using `ApdbSqlConfig` configuration class. For an example of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):

        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    object_last_replace: %s", self.config.object_last_replace)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(engine=self._engine,
                                     dia_object_index=self.config.dia_object_index,
                                     schema_file=self.config.schema_file,
                                     extra_schema_file=self.config.extra_schema_file,
                                     prefix=self.config.prefix)

        self.pixelator = HtmPixelization(self.config.htm_level)

    def tableRowCount(self) -> Dict[str, int]:
        """Returns dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables: List[sqlalchemy.schema.Table] = [
            self._schema.objects, self._schema.sources, self._schema.forcedSources]
        if self.config.dia_object_index == 'last_object_table':
            tables.append(self._schema.objects_last)
        for table in tables:
            stmt = sql.select([func.count()]).select_from(table)
            count = self._engine.scalar(stmt)
            res[table.name] = count

        return res

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        table: sqlalchemy.schema.Table
        if self.config.dia_object_index == 'last_object_table':
            table = self._schema.objects_last
        else:
            table = self._schema.objects
        if not self.config.dia_object_columns:
            query = table.select()
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
            query = sql.select(columns)

        # build selection
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))
        query = query.where(sql.expression.or_(*exprlist))

        # select latest version of objects
        if self.config.dia_object_index != 'last_object_table':
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        _LOG.debug("query: %s", query)

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('DiaObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(self, region: Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(self, region: Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If list is empty then empty catalog is returned with a
            correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned if
            ``read_forced_sources_months`` configuration parameter is set
            to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be not-`None`.
        `NotImplementedError` is raised if `None` is passed.

        This method returns DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter w.r.t. ``visit_time``.
        If ``object_ids`` is empty then an empty catalog is always returned
        with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.forcedSources
        with Timer('DiaForcedSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill pixelId column for DiaObjects
        objects = self._add_obj_htm_index(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy pixelId column from DiaObjects to DiaSources
            sources = self._add_src_htm_index(sources, objects)
            self._storeDiaSources(sources)

        if forced_sources is not None:
            self._storeDiaForcedSources(forced_sources)
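
    # Hypothetical usage sketch: catalogs are pandas DataFrames, and the
    # DiaObject catalog must carry the configured ra/dec columns so that
    # _add_obj_htm_index can compute the pixelId column:
    #
    #     apdb.store(visit_time, objects_df, sources=sources_df)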

    def dailyJob(self) -> None:
        # docstring is inherited from a base class

        if self._engine.name == 'postgresql':

            # do VACUUM on all tables
            _LOG.info("Running VACUUM on all tables")
            connection = self._engine.raw_connection()
            ISOLATION_LEVEL_AUTOCOMMIT = 0
            connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            cursor = connection.cursor()
            cursor.execute("VACUUM ANALYSE")

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.objects

        # Construct the sql statement.
        stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.scalar(stmt)

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime
                               ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        query = table.select()

        # build selection
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(sql.expression.or_(*exprlist), time_filter)
        query = query.where(where)

        # execute select
        with Timer('DiaSource select', self.config.timer):
            with _ansi_session(self._engine) as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime
                            ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids :
            Collection of DiaObject IDs
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        with Timer('DiaSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(self, table: sqlalchemy.schema.Table,
                         object_ids: List[int],
                         midPointTai_start: float
                         ) -> pandas.DataFrame:
        """Returns catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table : `sqlalchemy.schema.Table`
            Database table.
        object_ids :
            Collection of DiaObject IDs
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        sources: Optional[pandas.DataFrame] = None
        with _ansi_session(self._engine) as conn:
            if len(object_ids) <= 0:
                _LOG.debug("ID list is empty, just fetch empty result")
                query = table.select().where(False)
                sources = pandas.read_sql_query(query, conn)
            else:
                for ids in _split(sorted(object_ids), 1000):
                    query = f'SELECT * FROM "{table.name}" WHERE '

                    # select by object id
                    ids_str = ",".join(str(id) for id in ids)
                    query += f'"diaObjectId" IN ({ids_str})'
                    query += f' AND "midPointTai" > {midPointTai_start}'

                    # execute select
                    df = pandas.read_sql_query(sql.text(query), conn)
                    if sources is None:
                        sources = df
                    else:
                        sources = sources.append(df)
        assert sources is not None, "Catalog cannot be None"
        return sources
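
    # Sketch of the SQL generated for one chunk (IDs and cutoff hypothetical);
    # sorted IDs are split into chunks of 1000 to bound the IN-list size:
    #
    #     SELECT * FROM "DiaSource" WHERE "diaObjectId" IN (1,2,3)
    #         AND "midPointTai" > 59220.0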

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        """

        ids = sorted(objs['diaObjectId'])
        _LOG.debug("first object ID: %d", ids[0])

        # NOTE: workaround for sqlite, need this here to avoid
        # "database is locked" error.
        table: sqlalchemy.schema.Table = self._schema.objects

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            ids_str = ",".join(str(id) for id in ids)

            if self.config.dia_object_index == 'last_object_table':

                # insert and replace all records in LAST table, mysql and
                # postgres have non-standard features
                table = self._schema.objects_last
                do_replace = self.config.object_last_replace
                # If the input data is of type Pandas, we drop the previous
                # objects regardless of the do_replace setting due to how
                # Pandas inserts objects.
                if not do_replace or isinstance(objs, pandas.DataFrame):
                    query = 'DELETE FROM "' + table.name + '" '
                    query += 'WHERE "diaObjectId" IN (' + ids_str + ') '

                    if self.config.explain:
                        # run the same query with explain
                        self._explain(query, conn)

                    with Timer(table.name + ' delete', self.config.timer):
                        res = conn.execute(sql.text(query))
                    _LOG.debug("deleted %s objects", res.rowcount)

                extra_columns: Dict[str, Any] = dict(lastNonForcedSource=dt)
                with Timer("DiaObjectLast insert", self.config.timer):
                    objs = _coerce_uint64(objs)
                    for col, data in extra_columns.items():
                        objs[col] = data
                    objs.to_sql("DiaObjectLast", conn, if_exists='append',
                                index=False)
            else:

                # truncate existing validity intervals
                table = self._schema.objects
                query = 'UPDATE "' + table.name + '" '
                query += "SET \"validityEnd\" = '" + str(dt) + "' "
                query += 'WHERE "diaObjectId" IN (' + ids_str + ') '
                query += 'AND "validityEnd" IS NULL'

                # _LOG.debug("query: %s", query)

                if self.config.explain:
                    # run the same query with explain
                    self._explain(query, conn)

                with Timer(table.name + ' truncate', self.config.timer):
                    res = conn.execute(sql.text(query))
                _LOG.debug("truncated %s intervals", res.rowcount)

            # insert new versions
            table = self._schema.objects
            extra_columns = dict(lastNonForcedSource=dt, validityStart=dt,
                                 validityEnd=None)
            with Timer("DiaObject insert", self.config.timer):
                objs = _coerce_uint64(objs)
                for col, data in extra_columns.items():
                    objs[col] = data
                objs.to_sql("DiaObject", conn, if_exists='append',
                            index=False)

    def _storeDiaSources(self, sources: pandas.DataFrame) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        """
        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                sources.to_sql("DiaSource", conn, if_exists='append', index=False)

    def _storeDiaForcedSources(self, sources: pandas.DataFrame) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        """

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaForcedSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                sources.to_sql("DiaForcedSource", conn, if_exists='append', index=False)

    def _explain(self, query: str, conn: sqlalchemy.engine.Connection) -> None:
        """Run the query with explain
        """

        _LOG.info("explain for query: %s...", query[:64])

        if conn.engine.name == 'mysql':
            query = "EXPLAIN EXTENDED " + query
        else:
            query = "EXPLAIN " + query

        res = conn.execute(sql.text(query))
        if res.returns_rows:
            _LOG.info("explain: %s", res.keys())
            for row in res:
                _LOG.info("explain: %s", row)
        else:
            _LOG.info("EXPLAIN returned nothing")

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug('region: %s', region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        if _LOG.isEnabledFor(logging.DEBUG):
            for irange in indices.ranges():
                _LOG.debug('range: %s %s', self.pixelator.toString(irange[0]),
                           self.pixelator.toString(irange[1]))

        return indices.ranges()
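
    # Note on range semantics: sphgeom envelope ranges are half-open
    # [begin, end), which is why callers subtract one from the upper bound
    # before building inclusive BETWEEN expressions; e.g. a returned range
    # (512, 516) covers pixel IDs 512 through 515.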

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId for diaObjectId, pixelId
            in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }

        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources