Coverage for python/lsst/dax/apdb/apdbCassandra.py: 16%
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ["ApdbCassandraConfig", "ApdbCassandra"]

from datetime import datetime, timedelta
import logging
import numpy as np
import pandas
from typing import cast, Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union

try:
    import cbor
except ImportError:
    cbor = None

# If cassandra-driver is not installed the module can still be imported,
# but ApdbCassandra cannot be instantiated.
try:
    import cassandra
    from cassandra.cluster import Cluster
    from cassandra.concurrent import execute_concurrent
    from cassandra.policies import RoundRobinPolicy, WhiteListRoundRobinPolicy, AddressTranslator
    import cassandra.query
    CASSANDRA_IMPORTED = True
except ImportError:
    CASSANDRA_IMPORTED = False

import lsst.daf.base as dafBase
from lsst.pex.config import ChoiceField, Field, ListField
from lsst import sphgeom
from .timer import Timer
from .apdb import Apdb, ApdbConfig
from .apdbSchema import ApdbTables, ColumnDef, TableDef
from .apdbCassandraSchema import ApdbCassandraSchema


_LOG = logging.getLogger(__name__)


class CassandraMissingError(Exception):
    """Exception raised when the cassandra-driver module cannot be imported."""

    def __init__(self) -> None:
        super().__init__("cassandra-driver module cannot be imported")


class ApdbCassandraConfig(ApdbConfig):
    """Configuration class for Cassandra-based APDB implementation."""

    contact_points = ListField(
        dtype=str,
        doc="The list of contact points to try connecting for cluster discovery.",
        default=["127.0.0.1"]
    )
    private_ips = ListField(
        dtype=str,
        doc="List of internal IP addresses for contact_points.",
        default=[]
    )
    keyspace = Field(
        dtype=str,
        doc="Default keyspace for operations.",
        default="apdb"
    )
    read_consistency = Field(
        dtype=str,
        doc="Name for consistency level of read operations, default: QUORUM, can be ONE.",
        default="QUORUM"
    )
    write_consistency = Field(
        dtype=str,
        doc="Name for consistency level of write operations, default: QUORUM, can be ONE.",
        default="QUORUM"
    )
    read_timeout = Field(
        dtype=float,
        doc="Timeout in seconds for read operations.",
        default=120.
    )
    write_timeout = Field(
        dtype=float,
        doc="Timeout in seconds for write operations.",
        default=10.
    )
    read_concurrency = Field(
        dtype=int,
        doc="Concurrency level for read operations.",
        default=500
    )
    protocol_version = Field(
        dtype=int,
        doc="Cassandra protocol version to use, default is V4.",
        default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0
    )
    dia_object_columns = ListField(
        dtype=str,
        doc="List of columns to read from DiaObject, by default read all columns.",
        default=[]
    )
    prefix = Field(
        dtype=str,
        doc="Prefix to add to table names.",
        default=""
    )
    part_pixelization = ChoiceField(
        dtype=str,
        allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"),
        doc="Pixelization used for partitioning index.",
        default="mq3c"
    )
    part_pix_level = Field(
        dtype=int,
        doc="Pixelization level used for partitioning index.",
        default=10
    )
    ra_dec_columns = ListField(
        dtype=str,
        default=["ra", "decl"],
        doc="Names of ra/dec columns in DiaObject table."
    )
    timer = Field(
        dtype=bool,
        doc="If True then print/log timing information.",
        default=False
    )
    time_partition_tables = Field(
        dtype=bool,
        doc="Use per-partition tables for sources instead of partitioning by time.",
        default=True
    )
    time_partition_days = Field(
        dtype=int,
        doc="Time partitioning granularity in days, this value must not be changed"
            " after database is initialized.",
        default=30
    )
    time_partition_start = Field(
        dtype=str,
        doc="Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI."
            " This is used only when time_partition_tables is True.",
        default="2018-12-01T00:00:00"
    )
    time_partition_end = Field(
        dtype=str,
        doc="Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI."
            " This is used only when time_partition_tables is True.",
        default="2030-01-01T00:00:00"
    )
    query_per_time_part = Field(
        dtype=bool,
        default=False,
        doc="If True then build a separate query for each time partition, otherwise build one single query. "
            "This is only used when time_partition_tables is False in schema config."
    )
    query_per_spatial_part = Field(
        dtype=bool,
        default=False,
        doc="If True then build one query per spatial partition, otherwise build a single query."
    )
    pandas_delay_conv = Field(
        dtype=bool,
        default=True,
        doc="If True then combine result rows before converting to pandas."
    )
    packing = ChoiceField(
        dtype=str,
        allowed=dict(none="No field packing", cbor="Pack using CBOR"),
        doc="Packing method for table records.",
        default="none"
    )
    prepared_statements = Field(
        dtype=bool,
        default=True,
        doc="If True use Cassandra prepared statements."
    )
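

# Example: a minimal sketch of how this configuration might be customized
# before creating an ApdbCassandra instance. The host names, keyspace and
# chosen option values below are placeholders for illustration, not values
# shipped with this package.
def _example_config() -> ApdbCassandraConfig:
    """Build a hypothetical ApdbCassandraConfig for illustration only."""
    config = ApdbCassandraConfig()
    config.contact_points = ["10.0.0.1", "10.0.0.2"]  # placeholder Cassandra hosts
    config.keyspace = "apdb"
    config.read_consistency = "ONE"       # relax read consistency
    config.time_partition_tables = True   # one source table per 30-day partition
    config.part_pixelization = "mq3c"     # spatial partitioning scheme
    return config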


class Partitioner:
    """Class that calculates indices of the objects for partitioning.

    Used internally by `ApdbCassandra`.

    Parameters
    ----------
    config : `ApdbCassandraConfig`
    """
    def __init__(self, config: ApdbCassandraConfig):
        pix = config.part_pixelization
        if pix == "htm":
            self.pixelator = sphgeom.HtmPixelization(config.part_pix_level)
        elif pix == "q3c":
            self.pixelator = sphgeom.Q3cPixelization(config.part_pix_level)
        elif pix == "mq3c":
            self.pixelator = sphgeom.Mq3cPixelization(config.part_pix_level)
        else:
            raise ValueError(f"unknown pixelization: {pix}")

    def pixels(self, region: sphgeom.Region) -> List[int]:
        """Compute the set of pixel indices for a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
        """
        # we want the finest set of pixels, so ask for as many pixels as possible
        ranges = self.pixelator.envelope(region, 1_000_000)
        indices = []
        for lower, upper in ranges:
            indices += list(range(lower, upper))
        return indices

    def pixel(self, direction: sphgeom.UnitVector3d) -> int:
        """Compute the index of the pixel for a given direction.

        Parameters
        ----------
        direction : `lsst.sphgeom.UnitVector3d`
        """
        index = self.pixelator.index(direction)
        return index
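

# Example: a small sketch of how Partitioner maps sky positions to partition
# indices. The UnitVector3d/LonLat calls mirror those used elsewhere in this
# module; the Circle/Angle construction is an assumption about the
# lsst.sphgeom API and the coordinates are placeholders.
def _example_partitioner(config: ApdbCassandraConfig) -> None:
    partitioner = Partitioner(config)
    # a single direction maps to a single spatial partition index
    uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
    print("pixel:", partitioner.pixel(uv3d))
    # a small circular region maps to the list of partition indices to query
    region = sphgeom.Circle(uv3d, sphgeom.Angle.fromDegrees(0.1))
    print("#pixels:", len(partitioner.pixels(region)))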


if CASSANDRA_IMPORTED:

    class _AddressTranslator(AddressTranslator):
        """Translate internal IP address to external.

        Only used for docker-based setup, not a viable long-term solution.
        """
        def __init__(self, public_ips: List[str], private_ips: List[str]):
            self._map = dict((k, v) for k, v in zip(private_ips, public_ips))

        def translate(self, private_ip: str) -> str:
            return self._map.get(private_ip, private_ip)
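

# Example: a minimal sketch of the private-to-public address mapping performed
# by _AddressTranslator. The IP addresses are placeholders and the class is
# only defined when cassandra-driver is importable.
def _example_address_translation() -> None:
    translator = _AddressTranslator(
        public_ips=["203.0.113.10", "203.0.113.11"],
        private_ips=["10.0.0.1", "10.0.0.2"],
    )
    print(translator.translate("10.0.0.2"))   # -> "203.0.113.11"
    print(translator.translate("10.0.0.99"))  # unknown addresses pass through unchanged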


def _rows_to_pandas(colnames: List[str], rows: List[Tuple],
                    packedColumns: List[ColumnDef]) -> pandas.DataFrame:
    """Convert result rows to pandas.

    Unpacks BLOBs that were packed on insert.

    Parameters
    ----------
    colnames : `list` [ `str` ]
        Names of the columns.
    rows : `list` of `tuple`
        Result rows.
    packedColumns : `list` [ `ColumnDef` ]
        Column definitions for packed columns.

    Returns
    -------
    catalog : `pandas.DataFrame`
        DataFrame with the result set.
    """
    try:
        idx = colnames.index("apdb_packed")
    except ValueError:
        # no packed columns
        return pandas.DataFrame.from_records(rows, columns=colnames)

    # make data frame for non-packed columns
    df = pandas.DataFrame.from_records(rows, columns=colnames, exclude=["apdb_packed"])

    # make records with packed data only as dicts
    packed_rows = []
    for row in rows:
        blob = row[idx]
        if blob[:5] == b"cbor:":
            blob = cbor.loads(blob[5:])
        else:
            raise ValueError(f"Unexpected BLOB format: {blob!r}")
        packed_rows.append(blob)

    # make data frame from packed data
    packed = pandas.DataFrame.from_records(packed_rows, columns=[col.name for col in packedColumns])

    # convert timestamps which are integer milliseconds into datetime
    for col in packedColumns:
        if col.type == "DATETIME":
            packed[col.name] = pandas.to_datetime(packed[col.name], unit="ms", origin="unix")

    return pandas.concat([df, packed], axis=1)
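

# Example: a minimal sketch of the simple code path of _rows_to_pandas, where
# no "apdb_packed" column is present and rows are converted directly. The
# column names and values below are made up for illustration.
def _example_rows_to_pandas() -> pandas.DataFrame:
    colnames = ["diaObjectId", "ra", "decl"]
    rows = [(1, 45.0, -30.0), (2, 45.1, -30.1)]
    # with an empty packedColumns list the rows become a plain DataFrame
    return _rows_to_pandas(colnames, rows, packedColumns=[])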


class _PandasRowFactory:
    """Create pandas DataFrame from Cassandra result set.

    Parameters
    ----------
    packedColumns : `list` [ `ColumnDef` ]
        Column definitions for packed columns.
    """
    def __init__(self, packedColumns: Iterable[ColumnDef]):
        self.packedColumns = list(packedColumns)

    def __call__(self, colnames: List[str], rows: List[Tuple]) -> pandas.DataFrame:
        """Convert result set into output catalog.

        Parameters
        ----------
        colnames : `list` [ `str` ]
            Names of the columns.
        rows : `list` of `tuple`
            Result rows.

        Returns
        -------
        catalog : `pandas.DataFrame`
            DataFrame with the result set.
        """
        return _rows_to_pandas(colnames, rows, self.packedColumns)


class _RawRowFactory:
    """Row factory that makes no conversions.

    Parameters
    ----------
    packedColumns : `list` [ `ColumnDef` ]
        Column definitions for packed columns.
    """
    def __init__(self, packedColumns: Iterable[ColumnDef]):
        self.packedColumns = list(packedColumns)

    def __call__(self, colnames: List[str], rows: List[Tuple]) -> Tuple[List[str], List[Tuple]]:
        """Return parameters without change.

        Parameters
        ----------
        colnames : `list` of `str`
            Names of the columns.
        rows : `list` of `tuple`
            Result rows.

        Returns
        -------
        colnames : `list` of `str`
            Names of the columns.
        rows : `list` of `tuple`
            Result rows.
        """
        return (colnames, rows)
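

# Example: a short sketch contrasting the two row factories. _PandasRowFactory
# converts each result set to a DataFrame immediately, while _RawRowFactory
# returns the raw rows so that ApdbCassandra can merge results from many
# queries and convert once (the pandas_delay_conv option). The column names
# and rows are placeholders.
def _example_row_factories() -> None:
    colnames = ["diaObjectId", "ra", "decl"]
    rows = [(1, 45.0, -30.0)]
    as_dataframe = _PandasRowFactory(packedColumns=[])(colnames, rows)
    as_raw = _RawRowFactory(packedColumns=[])(colnames, rows)
    print(type(as_dataframe))  # pandas.DataFrame
    print(as_raw)              # (colnames, rows) unchanged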


class ApdbCassandra(Apdb):
    """Implementation of APDB database on top of Apache Cassandra.

    The implementation is configured via the standard ``pex_config`` mechanism
    using the `ApdbCassandraConfig` configuration class. For examples of
    different configurations check the config/ folder.

    Parameters
    ----------
    config : `ApdbCassandraConfig`
        Configuration object.
    """

    partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI)
    """Start time for partition 0, this should never be changed."""

    def __init__(self, config: ApdbCassandraConfig):

        if not CASSANDRA_IMPORTED:
            raise CassandraMissingError()

        self.config = config

        _LOG.debug("ApdbCassandra Configuration:")
        _LOG.debug("    read_consistency: %s", self.config.read_consistency)
        _LOG.debug("    write_consistency: %s", self.config.write_consistency)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)
        _LOG.debug("    part_pixelization: %s", self.config.part_pixelization)
        _LOG.debug("    part_pix_level: %s", self.config.part_pix_level)
        _LOG.debug("    query_per_time_part: %s", self.config.query_per_time_part)
        _LOG.debug("    query_per_spatial_part: %s", self.config.query_per_spatial_part)

        self._partitioner = Partitioner(config)

        addressTranslator: Optional[AddressTranslator] = None
        if config.private_ips:
            loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
            addressTranslator = _AddressTranslator(config.contact_points, config.private_ips)
        else:
            loadBalancePolicy = RoundRobinPolicy()

        self._read_consistency = getattr(cassandra.ConsistencyLevel, config.read_consistency)
        self._write_consistency = getattr(cassandra.ConsistencyLevel, config.write_consistency)

        self._cluster = Cluster(contact_points=self.config.contact_points,
                                load_balancing_policy=loadBalancePolicy,
                                address_translator=addressTranslator,
                                protocol_version=self.config.protocol_version)
        self._session = self._cluster.connect(keyspace=config.keyspace)
        self._session.row_factory = cassandra.query.named_tuple_factory

        self._schema = ApdbCassandraSchema(session=self._session,
                                           schema_file=self.config.schema_file,
                                           extra_schema_file=self.config.extra_schema_file,
                                           prefix=self.config.prefix,
                                           packing=self.config.packing,
                                           time_partition_tables=self.config.time_partition_tables)
        self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD)

    def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class

        if self.config.time_partition_tables:
            time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI)
            time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI)
            part_range = (
                self._time_partition(time_partition_start),
                self._time_partition(time_partition_end) + 1
            )
            self._schema.makeSchema(drop=drop, part_range=part_range)
        else:
            self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
        # docstring is inherited from a base class
        packedColumns = self._schema.packedColumns(ApdbTables.DiaObjectLast)
        self._session.row_factory = _PandasRowFactory(packedColumns)
        self._session.default_fetch_size = None

        pixels = self._partitioner.pixels(region)
        _LOG.debug("getDiaObjects: #partitions: %s", len(pixels))
        pixels_str = ",".join([str(pix) for pix in pixels])

        queries: List[Tuple] = []
        query = f'SELECT * from "DiaObjectLast" WHERE "apdb_part" IN ({pixels_str})'
        queries += [(cassandra.query.SimpleStatement(query, consistency_level=self._read_consistency), {})]
        _LOG.debug("getDiaObjects: #queries: %s", len(queries))
        # _LOG.debug("getDiaObjects: queries: %s", queries)

        objects = None
        with Timer('DiaObject select', self.config.timer):
            # submit all queries
            futures = [self._session.execute_async(query, values, timeout=self.config.read_timeout)
                       for query, values in queries]
            # TODO: This orders result processing which is not very efficient
            dataframes = [future.result()._current_rows for future in futures]
            # concatenate all frames
            if len(dataframes) == 1:
                objects = dataframes[0]
            else:
                objects = pandas.concat(dataframes)

        _LOG.debug("found %s DiaObjects", objects.shape[0])
        return objects
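
    # Example: for a region covering spatial partitions 1234 and 1235 (purely
    # illustrative pixel indices), the query built by getDiaObjects above is
    #
    #     SELECT * from "DiaObjectLast" WHERE "apdb_part" IN (1234,1235)
    #
    # i.e. the region is translated into a set of "apdb_part" pixel indices
    # and a single IN query is issued against the DiaObjectLast table.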
    def getDiaSources(self, region: sphgeom.Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        return self._getSources(region, object_ids, visit_time, ApdbTables.DiaSource,
                                self.config.read_sources_months)

    def getDiaForcedSources(self, region: sphgeom.Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        return self._getSources(region, object_ids, visit_time, ApdbTables.DiaForcedSource,
                                self.config.read_forced_sources_months)

    def _getSources(self, region: sphgeom.Region,
                    object_ids: Optional[Iterable[int]],
                    visit_time: dafBase.DateTime,
                    table_name: ApdbTables,
                    months: int) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaSource instances given a set of DiaObject IDs.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Spherical region.
        object_ids :
            Collection of DiaObject IDs.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        table_name : `ApdbTables`
            Name of the table, either "DiaSource" or "DiaForcedSource".
        months : `int`
            Number of months of history to return; if negative, the whole
            history is returned (note: negative does not work with the
            table-per-partition case).

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaSource records. `None` is returned if
            ``months`` is 0; an empty catalog is returned if ``object_ids``
            is an empty collection.
        """
        if months == 0:
            return None
        object_id_set: Set[int] = set()
        if object_ids is not None:
            object_id_set = set(object_ids)
            if len(object_id_set) == 0:
                return self._make_empty_catalog(table_name)

        packedColumns = self._schema.packedColumns(table_name)
        if self.config.pandas_delay_conv:
            self._session.row_factory = _RawRowFactory(packedColumns)
        else:
            self._session.row_factory = _PandasRowFactory(packedColumns)
        self._session.default_fetch_size = None

        # spatial pixels included into query
        pixels = self._partitioner.pixels(region)
        _LOG.debug("_getSources: %s #partitions: %s", table_name.name, len(pixels))

        # spatial part of WHERE
        spatial_where = []
        if self.config.query_per_spatial_part:
            spatial_where = [f'"apdb_part" = {pixel}' for pixel in pixels]
        else:
            pixels_str = ",".join([str(pix) for pix in pixels])
            spatial_where = [f'"apdb_part" IN ({pixels_str})']

        # temporal part of WHERE, can be empty
        temporal_where = []
        # time partitions and table names to query, there may be multiple
        # tables depending on configuration
        full_name = self._schema.tableName(table_name)
        tables = [full_name]
        mjd_now = visit_time.get(system=dafBase.DateTime.MJD)
        mjd_begin = mjd_now - months*30
        time_part_now = self._time_partition(mjd_now)
        time_part_begin = self._time_partition(mjd_begin)
        time_parts = list(range(time_part_begin, time_part_now + 1))
        if self.config.time_partition_tables:
            tables = [f"{full_name}_{part}" for part in time_parts]
        else:
            if self.config.query_per_time_part:
                temporal_where = [f'"apdb_time_part" = {time_part}' for time_part in time_parts]
            else:
                time_part_list = ",".join([str(part) for part in time_parts])
                temporal_where = [f'"apdb_time_part" IN ({time_part_list})']

        # Build all queries
        queries: List[str] = []
        for table in tables:
            query = f'SELECT * from "{table}" WHERE '
            for spatial in spatial_where:
                if temporal_where:
                    for temporal in temporal_where:
                        queries.append(query + spatial + " AND " + temporal)
                else:
                    queries.append(query + spatial)
        # _LOG.debug("_getSources: queries: %s", queries)

        statements: List[Tuple] = [
            (cassandra.query.SimpleStatement(query, consistency_level=self._read_consistency), {})
            for query in queries
        ]
        _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))

        with Timer(table_name.name + ' select', self.config.timer):
            # submit all queries
            results = execute_concurrent(self._session, statements, results_generator=True,
                                         concurrency=self.config.read_concurrency)
            if self.config.pandas_delay_conv:
                _LOG.debug("making pandas data frame out of rows/columns")
                columns: Any = None
                rows = []
                for success, result in results:
                    result = result._current_rows
                    if success:
                        if columns is None:
                            columns = result[0]
                        elif columns != result[0]:
                            _LOG.error("different columns returned by queries: %s and %s",
                                       columns, result[0])
                            raise ValueError(
                                f"different columns returned by queries: {columns} and {result[0]}"
                            )
                        rows += result[1]
                    else:
                        _LOG.error("error returned by query: %s", result)
                        raise result
                catalog = _rows_to_pandas(columns, rows, self._schema.packedColumns(table_name))
                _LOG.debug("pandas catalog shape: %s", catalog.shape)
                # filter by given object IDs
                if len(object_id_set) > 0:
                    catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])
            else:
                _LOG.debug("making pandas data frame out of set of data frames")
                dataframes = []
                for success, result in results:
                    if success:
                        dataframes.append(result._current_rows)
                    else:
                        _LOG.error("error returned by query: %s", result)
                        raise result
                # concatenate all frames
                if len(dataframes) == 1:
                    catalog = dataframes[0]
                else:
                    catalog = pandas.concat(dataframes)
                _LOG.debug("pandas catalog shape: %s", catalog.shape)
                # filter by given object IDs
                if len(object_id_set) > 0:
                    catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])

        # precise filtering on midPointTai
        catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_begin])

        _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
        return catalog
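
    # Example: with time_partition_tables=True and time_partition_days=30, a
    # 12-month DiaSource read for a visit in mid-2020 spans partitions
    # roughly 602 through 614 (illustrative numbers), so _getSources above
    # issues one query per table named like "DiaSource_602" ... "DiaSource_614"
    # for each spatial constraint. With time_partition_tables=False the same
    # constraint is instead expressed as an '"apdb_time_part" IN (...)' clause
    # against a single table.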
    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill region partition column for DiaObjects
        objects = self._add_obj_part(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy apdb_part column from DiaObjects to DiaSources
            sources = self._add_src_part(sources, objects)
            self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time)

        if forced_sources is not None:
            forced_sources = self._add_fsrc_part(forced_sources, objects)
            self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time)

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        """
        visit_time_dt = visit_time.toPython()
        extra_columns = dict(lastNonForcedSource=visit_time_dt)
        self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, visit_time, extra_columns=extra_columns)

        extra_columns["validityStart"] = visit_time_dt
        time_part: Optional[int] = self._time_partition(visit_time)
        if not self.config.time_partition_tables:
            extra_columns["apdb_time_part"] = time_part
            time_part = None

        self._storeObjectsPandas(objs, ApdbTables.DiaObject, visit_time,
                                 extra_columns=extra_columns, time_part=time_part)

    def _storeDiaSources(self, table_name: ApdbTables, sources: pandas.DataFrame,
                         visit_time: dafBase.DateTime) -> None:
        """Store catalog of DIASources or DIAForcedSources from current visit.

        Parameters
        ----------
        table_name : `ApdbTables`
            Table where records are stored, either "DiaSource" or
            "DiaForcedSource".
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        """
        time_part: Optional[int] = self._time_partition(visit_time)
        extra_columns = {}
        if not self.config.time_partition_tables:
            extra_columns["apdb_time_part"] = time_part
            time_part = None

        self._storeObjectsPandas(sources, table_name, visit_time,
                                 extra_columns=extra_columns, time_part=time_part)

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class
        raise NotImplementedError()

    def _storeObjectsPandas(self, objects: pandas.DataFrame, table_name: ApdbTables,
                            visit_time: dafBase.DateTime, extra_columns: Optional[Mapping] = None,
                            time_part: Optional[int] = None) -> None:
        """Generic store method.

        Takes catalog of records and stores a bunch of objects in a table.

        Parameters
        ----------
        objects : `pandas.DataFrame`
            Catalog containing object records.
        table_name : `ApdbTables`
            Name of the table as defined in APDB schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        extra_columns : `dict`, optional
            Mapping (column_name, column_value) which gives column values to
            add to every row, only if the column is missing in catalog
            records.
        time_part : `int`, optional
            If not `None` then insert into a per-partition table.
        """

        def qValue(v: Any) -> Any:
            """Transform object into a value for query."""
            if v is None:
                pass
            elif isinstance(v, datetime):
                v = int((v - datetime(1970, 1, 1)) / timedelta(seconds=1))*1000
            elif isinstance(v, (bytes, str)):
                pass
            else:
                try:
                    if not np.isfinite(v):
                        v = None
                except TypeError:
                    pass
            return v

        def quoteId(columnName: str) -> str:
            """Smart quoting for column names. Lower-case names are not quoted."""
            if not columnName.islower():
                columnName = '"' + columnName + '"'
            return columnName

        # use extra columns if specified
        if extra_columns is None:
            extra_columns = {}
        extra_fields = list(extra_columns.keys())

        df_fields = [column for column in objects.columns
                     if column not in extra_fields]

        column_map = self._schema.getColumnMap(table_name)
        # list of columns (as in cat schema)
        fields = [column_map[field].name for field in df_fields if field in column_map]
        fields += extra_fields

        # check that all partitioning and clustering columns are defined
        required_columns = self._schema.partitionColumns(table_name) \
            + self._schema.clusteringColumns(table_name)
        missing_columns = [column for column in required_columns if column not in fields]
        if missing_columns:
            raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")

        blob_columns = set(col.name for col in self._schema.packedColumns(table_name))
        # _LOG.debug("blob_columns: %s", blob_columns)

        qfields = [quoteId(field) for field in fields if field not in blob_columns]
        if blob_columns:
            qfields += [quoteId("apdb_packed")]
        qfields_str = ','.join(qfields)

        with Timer(table_name.name + ' query build', self.config.timer):

            table = self._schema.tableName(table_name)
            if time_part is not None:
                table = f"{table}_{time_part}"

            prepared: Optional[cassandra.query.PreparedStatement] = None
            if self.config.prepared_statements:
                holders = ','.join(['?']*len(qfields))
                query = f'INSERT INTO "{table}" ({qfields_str}) VALUES ({holders})'
                prepared = self._session.prepare(query)
            queries = cassandra.query.BatchStatement(consistency_level=self._write_consistency)
            for rec in objects.itertuples(index=False):
                values = []
                blob = {}
                for field in df_fields:
                    if field not in column_map:
                        continue
                    value = getattr(rec, field)
                    if column_map[field].type == "DATETIME":
                        if isinstance(value, pandas.Timestamp):
                            value = qValue(value.to_pydatetime())
                        else:
                            # Assume it's seconds since epoch, Cassandra
                            # datetime is in milliseconds
                            value = int(value*1000)
                    if field in blob_columns:
                        blob[field] = qValue(value)
                    else:
                        values.append(qValue(value))
                for field in extra_fields:
                    value = extra_columns[field]
                    if field in blob_columns:
                        blob[field] = qValue(value)
                    else:
                        values.append(qValue(value))
                if blob_columns:
                    if self.config.packing == "cbor":
                        blob = b"cbor:" + cbor.dumps(blob)
                    values.append(blob)
                holders = ','.join(['%s']*len(values))
                if prepared is not None:
                    stmt = prepared
                else:
                    query = f'INSERT INTO "{table}" ({qfields_str}) VALUES ({holders})'
                    # _LOG.debug("query: %r", query)
                    # _LOG.debug("values: %s", values)
                    stmt = cassandra.query.SimpleStatement(query, consistency_level=self._write_consistency)
                queries.add(stmt, values)

        # _LOG.debug("query: %s", query)
        _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), objects.shape[0])
        with Timer(table_name.name + ' insert', self.config.timer):
            self._session.execute(queries, timeout=self.config.write_timeout)
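
    # Example: the qValue helper above maps Python values to CQL-friendly
    # ones; a few illustrative cases (following the conventions in this
    # module):
    #
    #     qValue(datetime(2020, 1, 1))  ->  1577836800000  # ms since Unix epoch
    #     qValue(float("nan"))          ->  None            # non-finite -> NULL
    #     qValue("DiaSource")           ->  "DiaSource"     # strings pass through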
    def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate spatial partition for each record and add it to a
        DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (apdb_part). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        # calculate pixelization index for every DiaObject
        apdb_part = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
            idx = self._partitioner.pixel(uv3d)
            apdb_part[i] = idx
        df = df.copy()
        df["apdb_part"] = apdb_part
        return df

    def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add apdb_part column to DiaSource catalog.

        Notes
        -----
        This method copies the apdb_part value from a matching DiaObject
        record. The DiaObject catalog needs to have an apdb_part column filled
        by the ``_add_obj_part`` method and DiaSource records need to be
        associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (apdb_part). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: apdb_part for diaObjectId, apdb_part
            in zip(objs["diaObjectId"], objs["apdb_part"])
        }
        apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (diaObjId, ra, dec) in enumerate(zip(sources["diaObjectId"],
                                                    sources[ra_col], sources[dec_col])):
            if diaObjId == 0:
                # DiaSources associated with SolarSystemObjects do not have an
                # associated DiaObject, hence we skip the lookup and set the
                # partition based on their own ra/dec
                uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
                idx = self._partitioner.pixel(uv3d)
                apdb_part[i] = idx
            else:
                apdb_part[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources["apdb_part"] = apdb_part
        return sources

    def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add apdb_part column to DiaForcedSource catalog.

        Notes
        -----
        This method copies the apdb_part value from a matching DiaObject
        record. The DiaObject catalog needs to have an apdb_part column filled
        by the ``_add_obj_part`` method and DiaForcedSource records need to be
        associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (apdb_part). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: apdb_part for diaObjectId, apdb_part
            in zip(objs["diaObjectId"], objs["apdb_part"])
        }
        apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            apdb_part[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources["apdb_part"] = apdb_part
        return sources

    def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int:
        """Calculate time partition number for a given time.

        Parameters
        ----------
        time : `float` or `lsst.daf.base.DateTime`
            Time for which to calculate partition number. Can be a float,
            meaning MJD, or `lsst.daf.base.DateTime`.

        Returns
        -------
        partition : `int`
            Partition number for a given time.
        """
        if isinstance(time, dafBase.DateTime):
            mjd = time.get(system=dafBase.DateTime.MJD)
        else:
            mjd = time
        days_since_epoch = mjd - self._partition_zero_epoch_mjd
        partition = int(days_since_epoch) // self.config.time_partition_days
        return partition
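
    # Example: a worked _time_partition calculation with the default
    # time_partition_days=30. The partition-zero epoch 1970-01-01 is about
    # MJD 40587, so for a visit at MJD 58849 (2020-01-01):
    #
    #     days_since_epoch = 58849 - 40587 = 18262
    #     partition        = 18262 // 30   = 608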
    def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
        """Make an empty catalog for a table with a given name.

        Parameters
        ----------
        table_name : `ApdbTables`
            Name of the table.

        Returns
        -------
        catalog : `pandas.DataFrame`
            An empty catalog.
        """
        table = self._schema.tableSchemas[table_name]

        data = {columnDef.name: pandas.Series(dtype=columnDef.dtype) for columnDef in table.columns}
        return pandas.DataFrame(data)
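

# Example: a minimal end-to-end sketch of how ApdbCassandra might be used.
# The connection parameters, region size and catalog contents are
# placeholders; real catalogs must match the APDB schema loaded from
# schema_file, and the Circle/Angle construction is an assumption about the
# lsst.sphgeom API.
def _example_apdb_cassandra_usage(objects: pandas.DataFrame,
                                  sources: pandas.DataFrame) -> pandas.DataFrame:
    config = ApdbCassandraConfig()
    config.contact_points = ["127.0.0.1"]   # placeholder Cassandra host
    config.keyspace = "apdb"

    apdb = ApdbCassandra(config)
    apdb.makeSchema(drop=False)             # create tables if not present

    # store one visit worth of data
    visit_time = dafBase.DateTime("2020-01-01T00:00:00", dafBase.DateTime.TAI)
    apdb.store(visit_time, objects, sources)

    # read DiaObjects back for a small region around (45, -30) degrees
    center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
    region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
    return apdb.getDiaObjects(region)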