Coverage for python/lsst/dax/apdb/apdbCassandra.py: 17%
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"]
26from datetime import datetime, timedelta
27import logging
28import numpy as np
29import pandas
30from typing import cast, Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
32try:
33 import cbor
34except ImportError:
35 cbor = None
37# If cassandra-driver is not installed the module can still be imported
38# but ApdbCassandra cannot be instantiated.
39try:
40 import cassandra
41 from cassandra.cluster import Cluster
42 from cassandra.concurrent import execute_concurrent
43 from cassandra.policies import RoundRobinPolicy, WhiteListRoundRobinPolicy, AddressTranslator
44 import cassandra.query
45 CASSANDRA_IMPORTED = True
46except ImportError:
47 CASSANDRA_IMPORTED = False
49import lsst.daf.base as dafBase
50from lsst.pex.config import ChoiceField, Field, ListField
51from lsst import sphgeom
52from .timer import Timer
53from .apdb import Apdb, ApdbConfig
54from .apdbSchema import ApdbTables, ColumnDef, TableDef
55from .apdbCassandraSchema import ApdbCassandraSchema
58_LOG = logging.getLogger(__name__)
61class CassandraMissingError(Exception):
62 def __init__(self) -> None:
63 super().__init__("cassandra-driver module cannot be imported")
66class ApdbCassandraConfig(ApdbConfig):
68 contact_points = ListField(
69 dtype=str,
70 doc="The list of contact points to try connecting for cluster discovery.",
71 default=["127.0.0.1"]
72 )
73 private_ips = ListField(
74 dtype=str,
75 doc="List of internal IP addresses for contact_points.",
76 default=[]
77 )
78 keyspace = Field(
79 dtype=str,
80 doc="Default keyspace for operations.",
81 default="apdb"
82 )
83 read_consistency = Field(
84 dtype=str,
85 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.",
86 default="QUORUM"
87 )
88 write_consistency = Field(
89 dtype=str,
90 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.",
91 default="QUORUM"
92 )
93 read_timeout = Field(
94 dtype=float,
95 doc="Timeout in seconds for read operations.",
96 default=120.
97 )
98 write_timeout = Field(
99 dtype=float,
100 doc="Timeout in seconds for write operations.",
101 default=10.
102 )
103 read_concurrency = Field(
104 dtype=int,
105 doc="Concurrency level for read operations.",
106 default=500
107 )
108 protocol_version = Field(
109 dtype=int,
110 doc="Cassandra protocol version to use, default is V4",
111 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0
112 )
113 dia_object_columns = ListField(
114 dtype=str,
115 doc="List of columns to read from DiaObject, by default read all columns",
116 default=[]
117 )
118 prefix = Field(
119 dtype=str,
120 doc="Prefix to add to table names",
121 default=""
122 )
123 part_pixelization = ChoiceField(
124 dtype=str,
125 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"),
126 doc="Pixelization used for partitioning index.",
127 default="mq3c"
128 )
129 part_pix_level = Field(
130 dtype=int,
131 doc="Pixelization level used for partitioning index.",
132 default=10
133 )
134 ra_dec_columns = ListField(
135 dtype=str,
136 default=["ra", "decl"],
137 doc="Names of ra/dec columns in DiaObject table"
138 )
139 timer = Field(
140 dtype=bool,
141 doc="If True then print/log timing information",
142 default=False
143 )
144 time_partition_tables = Field(
145 dtype=bool,
146 doc="Use per-partition tables for sources instead of partitioning by time",
147 default=True
148 )
149 time_partition_days = Field(
150 dtype=int,
151 doc="Time partitioning granularity in days, this value must not be changed"
152 " after the database is initialized",
153 default=30
154 )
155 time_partition_start = Field(
156 dtype=str,
157 doc="Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI."
158 " This is used only when time_partition_tables is True.",
159 default="2018-12-01T00:00:00"
160 )
161 time_partition_end = Field(
162 dtype=str,
163 doc="Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI."
164 " This is used only when time_partition_tables is True.",
165 default="2030-01-01T00:00:00"
166 )
167 query_per_time_part = Field(
168 dtype=bool,
169 default=False,
170 doc="If True then build a separate query for each time partition, otherwise build a single query. "
171 "This is only used when time_partition_tables is False in schema config."
172 )
173 query_per_spatial_part = Field(
174 dtype=bool,
175 default=False,
176 doc="If True then build one query per spatial partition, otherwise build a single query. "
177 )
178 pandas_delay_conv = Field(
179 dtype=bool,
180 default=True,
181 doc="If True then combine result rows before converting to pandas. "
182 )
183 packing = ChoiceField(
184 dtype=str,
185 allowed=dict(none="No field packing", cbor="Pack using CBOR"),
186 doc="Packing method for table records.",
187 default="none"
188 )
189 prepared_statements = Field(
190 dtype=bool,
191 default=True,
192 doc="If True use Cassandra prepared statements."
193 )
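As an illustration of how these options fit together, here is a minimal, hypothetical configuration sketch using only the classes defined in this module; the contact points are placeholders and all other fields (including those inherited from `ApdbConfig`) are assumed to keep their defaults.

config = ApdbCassandraConfig()
config.contact_points = ["cassandra-1.example.org", "cassandra-2.example.org"]  # placeholder hosts
config.keyspace = "apdb"
config.part_pixelization = "mq3c"    # one of: htm, q3c, mq3c
config.time_partition_tables = True
config.time_partition_days = 30

apdb = ApdbCassandra(config)         # raises CassandraMissingError if cassandra-driver is absent
apdb.makeSchema(drop=False)          # create tables, including per-time-partition source tables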
196class Partitioner:
197 """Class that calculates indices of the objects for partitioning.
199 Used internally by `ApdbCassandra`
201 Parameters
202 ----------
203 config : `ApdbCassandraConfig`
204 """
205 def __init__(self, config: ApdbCassandraConfig):
206 pix = config.part_pixelization
207 if pix == "htm":
208 self.pixelator = sphgeom.HtmPixelization(config.part_pix_level)
209 elif pix == "q3c":
210 self.pixelator = sphgeom.Q3cPixelization(config.part_pix_level)
211 elif pix == "mq3c":
212 self.pixelator = sphgeom.Mq3cPixelization(config.part_pix_level)
213 else:
214 raise ValueError(f"unknown pixelization: {pix}")
216 def pixels(self, region: sphgeom.Region) -> List[int]:
217 """Compute the set of pixel indices for the given region.
219 Parameters
220 ----------
221 region : `lsst.sphgeom.Region`
222 """
223 # we want the finest set of pixels, so ask for as many pixels as possible
224 ranges = self.pixelator.envelope(region, 1_000_000)
225 indices = []
226 for lower, upper in ranges:
227 indices += list(range(lower, upper))
228 return indices
230 def pixel(self, direction: sphgeom.UnitVector3d) -> int:
231 """Compute the index of the pixel for the given direction.
233 Parameters
234 ----------
235 direction : `lsst.sphgeom.UnitVector3d`
236 """
237 index = self.pixelator.index(direction)
238 return index
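A small usage sketch for `Partitioner`; the sky position and the 0.5-degree circular region below are arbitrary examples, and the default mq3c pixelization at level 10 is assumed.

from lsst import sphgeom

partitioner = Partitioner(ApdbCassandraConfig())
center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(60.0, -30.0))
region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.5))
part = partitioner.pixel(center)     # partition index for a single direction
parts = partitioner.pixels(region)   # all partition indices overlapping the region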
241if CASSANDRA_IMPORTED:
243 class _AddressTranslator(AddressTranslator):
244 """Translate internal IP address to external.
246 Only used for docker-based setup, not a viable long-term solution.
247 """
248 def __init__(self, public_ips: List[str], private_ips: List[str]):
249 self._map = dict((k, v) for k, v in zip(private_ips, public_ips))
251 def translate(self, private_ip: str) -> str:
252 return self._map.get(private_ip, private_ip)
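For illustration, the translation is a plain dictionary lookup with a pass-through fallback; the addresses below are arbitrary placeholders.

translator = _AddressTranslator(public_ips=["203.0.113.10", "203.0.113.11"],
                                private_ips=["10.0.0.10", "10.0.0.11"])
translator.translate("10.0.0.11")    # -> "203.0.113.11"
translator.translate("10.0.0.99")    # unknown addresses are returned unchanged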
255def _rows_to_pandas(colnames: List[str], rows: List[Tuple],
256 packedColumns: List[ColumnDef]) -> pandas.DataFrame:
257 """Convert result rows to pandas.
259 Unpacks BLOBs that were packed on insert.
261 Parameters
262 ----------
263 colnames : `list` [ `str` ]
264 Names of the columns.
265 rows : `list` of `tuple`
266 Result rows.
267 packedColumns : `list` [ `ColumnDef` ]
268 Column definitions for packed columns.
270 Returns
271 -------
272 catalog : `pandas.DataFrame`
273 DataFrame with the result set.
274 """
275 try:
276 idx = colnames.index("apdb_packed")
277 except ValueError:
278 # no packed columns
279 return pandas.DataFrame.from_records(rows, columns=colnames)
281 # make data frame for non-packed columns
282 df = pandas.DataFrame.from_records(rows, columns=colnames, exclude=["apdb_packed"])
284 # make records with packed data only as dicts
285 packed_rows = []
286 for row in rows:
287 blob = row[idx]
288 if blob[:5] == b"cbor:":
289 blob = cbor.loads(blob[5:])
290 else:
291 raise ValueError(f"Unexpected BLOB format: {blob!r}")
292 packed_rows.append(blob)
294 # make data frame from packed data
295 packed = pandas.DataFrame.from_records(packed_rows, columns=[col.name for col in packedColumns])
297 # convert timestamps which are integer milliseconds into datetime
298 for col in packedColumns:
299 if col.type == "DATETIME":
300 packed[col.name] = pandas.to_datetime(packed[col.name], unit="ms", origin="unix")
302 return pandas.concat([df, packed], axis=1)
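The packing convention that this function undoes can be reproduced with a short standalone sketch; the column names are hypothetical and `cbor` is the optional package imported at the top of the module.

import cbor
import pandas

# packed columns travel as a CBOR-encoded dict prefixed with b"cbor:";
# DATETIME values are stored as integer milliseconds since the Unix epoch
packed = {"someFlag": 3, "someTimestamp": 1_640_995_200_000}
blob = b"cbor:" + cbor.dumps(packed)

assert blob[:5] == b"cbor:"
record = cbor.loads(blob[5:])
frame = pandas.DataFrame.from_records([record])
frame["someTimestamp"] = pandas.to_datetime(frame["someTimestamp"], unit="ms", origin="unix")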
305class _PandasRowFactory:
306 """Create pandas DataFrame from Cassandra result set.
308 Parameters
309 ----------
310 packedColumns : `list` [ `ColumnDef` ]
311 Column definitions for packed columns.
312 """
313 def __init__(self, packedColumns: Iterable[ColumnDef]):
314 self.packedColumns = list(packedColumns)
316 def __call__(self, colnames: List[str], rows: List[Tuple]) -> pandas.DataFrame:
317 """Convert result set into output catalog.
319 Parameters
320 ----------
321 colnames : `list` [ `str` ]
322 Names of the columns.
323 rows : `list` of `tuple`
324 Result rows
326 Returns
327 -------
328 catalog : `pandas.DataFrame`
329 DataFrame with the result set.
330 """
331 return _rows_to_pandas(colnames, rows, self.packedColumns)
334class _RawRowFactory:
335 """Row factory that makes no conversions.
337 Parameters
338 ----------
339 packedColumns : `list` [ `ColumnDef` ]
340 Column definitions for packed columns.
341 """
342 def __init__(self, packedColumns: Iterable[ColumnDef]):
343 self.packedColumns = list(packedColumns)
345 def __call__(self, colnames: List[str], rows: List[Tuple]) -> Tuple[List[str], List[Tuple]]:
346 """Return parameters without change.
348 Parameters
349 ----------
350 colnames : `list` of `str`
351 Names of the columns.
352 rows : `list` of `tuple`
353 Result rows
355 Returns
356 -------
357 colnames : `list` of `str`
358 Names of the columns.
359 rows : `list` of `tuple`
360 Result rows
361 """
362 return (colnames, rows)
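Called directly on a toy result set, the two factories differ only in whether the conversion to pandas happens immediately; the column names and values below are made up.

colnames = ["diaObjectId", "ra", "decl"]
rows = [(1, 60.0, -30.0), (2, 60.1, -29.9)]
raw = _RawRowFactory(packedColumns=[])(colnames, rows)        # returns (colnames, rows) unchanged
frame = _PandasRowFactory(packedColumns=[])(colnames, rows)   # returns a pandas.DataFrame with 2 rows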
365class ApdbCassandra(Apdb):
366 """Implementation of the APDB database on top of Apache Cassandra.
368 The implementation is configured via standard ``pex_config`` mechanism
369 using the `ApdbCassandraConfig` configuration class. For examples of
370 different configurations see the config/ folder.
372 Parameters
373 ----------
374 config : `ApdbCassandraConfig`
375 Configuration object.
376 """
378 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI)
379 """Start time for partition 0, this should never be changed."""
381 def __init__(self, config: ApdbCassandraConfig):
383 if not CASSANDRA_IMPORTED:
384 raise CassandraMissingError()
386 self.config = config
388 _LOG.debug("ApdbCassandra Configuration:")
389 _LOG.debug(" read_consistency: %s", self.config.read_consistency)
390 _LOG.debug(" write_consistency: %s", self.config.write_consistency)
391 _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
392 _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
393 _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
394 _LOG.debug(" schema_file: %s", self.config.schema_file)
395 _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
396 _LOG.debug(" schema prefix: %s", self.config.prefix)
397 _LOG.debug(" part_pixelization: %s", self.config.part_pixelization)
398 _LOG.debug(" part_pix_level: %s", self.config.part_pix_level)
399 _LOG.debug(" query_per_time_part: %s", self.config.query_per_time_part)
400 _LOG.debug(" query_per_spatial_part: %s", self.config.query_per_spatial_part)
402 self._partitioner = Partitioner(config)
404 addressTranslator: Optional[AddressTranslator] = None
405 if config.private_ips:
406 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
407 addressTranslator = _AddressTranslator(config.contact_points, config.private_ips)
408 else:
409 loadBalancePolicy = RoundRobinPolicy()
411 self._read_consistency = getattr(cassandra.ConsistencyLevel, config.read_consistency)
412 self._write_consistency = getattr(cassandra.ConsistencyLevel, config.write_consistency)
414 self._cluster = Cluster(contact_points=self.config.contact_points,
415 load_balancing_policy=loadBalancePolicy,
416 address_translator=addressTranslator,
417 protocol_version=self.config.protocol_version)
418 self._session = self._cluster.connect(keyspace=config.keyspace)
419 self._session.row_factory = cassandra.query.named_tuple_factory
421 self._schema = ApdbCassandraSchema(session=self._session,
422 schema_file=self.config.schema_file,
423 extra_schema_file=self.config.extra_schema_file,
424 prefix=self.config.prefix,
425 packing=self.config.packing,
426 time_partition_tables=self.config.time_partition_tables)
427 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD)
429 def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
430 # docstring is inherited from a base class
431 return self._schema.tableSchemas.get(table)
433 def makeSchema(self, drop: bool = False) -> None:
434 # docstring is inherited from a base class
436 if self.config.time_partition_tables:
437 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI)
438 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI)
439 part_range = (
440 self._time_partition(time_partition_start),
441 self._time_partition(time_partition_end) + 1
442 )
443 self._schema.makeSchema(drop=drop, part_range=part_range)
444 else:
445 self._schema.makeSchema(drop=drop)
447 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
448 # docstring is inherited from a base class
449 packedColumns = self._schema.packedColumns(ApdbTables.DiaObjectLast)
450 self._session.row_factory = _PandasRowFactory(packedColumns)
451 self._session.default_fetch_size = None
453 pixels = self._partitioner.pixels(region)
454 _LOG.debug("getDiaObjects: #partitions: %s", len(pixels))
455 pixels_str = ",".join([str(pix) for pix in pixels])
457 queries: List[Tuple] = []
458 query = f'SELECT * from "DiaObjectLast" WHERE "apdb_part" IN ({pixels_str})'
459 queries += [(cassandra.query.SimpleStatement(query, consistency_level=self._read_consistency), {})]
460 _LOG.debug("getDiaObjects: #queries: %s", len(queries))
461 # _LOG.debug("getDiaObjects: queries: %s", queries)
463 objects = None
464 with Timer('DiaObject select', self.config.timer):
465 # submit all queries
466 futures = [self._session.execute_async(query, values, timeout=self.config.read_timeout)
467 for query, values in queries]
468 # TODO: This orders result processing which is not very efficient
469 dataframes = [future.result()._current_rows for future in futures]
470 # concatenate all frames
471 if len(dataframes) == 1:
472 objects = dataframes[0]
473 else:
474 objects = pandas.concat(dataframes)
476 _LOG.debug("found %s DiaObjects", objects.shape[0])
477 return objects
479 def getDiaSources(self, region: sphgeom.Region,
480 object_ids: Optional[Iterable[int]],
481 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
482 # docstring is inherited from a base class
483 return self._getSources(region, object_ids, visit_time, ApdbTables.DiaSource,
484 self.config.read_sources_months)
486 def getDiaForcedSources(self, region: sphgeom.Region,
487 object_ids: Optional[Iterable[int]],
488 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
489 return self._getSources(region, object_ids, visit_time, ApdbTables.DiaForcedSource,
490 self.config.read_forced_sources_months)
492 def _getSources(self, region: sphgeom.Region,
493 object_ids: Optional[Iterable[int]],
494 visit_time: dafBase.DateTime,
495 table_name: ApdbTables,
496 months: int) -> Optional[pandas.DataFrame]:
497 """Returns catalog of DiaSource instances for a given region and set of DiaObject IDs.
499 Parameters
500 ----------
501 region : `lsst.sphgeom.Region`
502 Spherical region.
503 object_ids : iterable of `int`, or `None`
504 Collection of DiaObject IDs.
505 visit_time : `lsst.daf.base.DateTime`
506 Time of the current visit
507 table_name : `ApdbTables`
508 Name of the table, either "DiaSource" or "DiaForcedSource"
509 months : `int`
510 Number of months of history to return; if negative, the whole
511 history is returned (note: negative values do not work with the
512 table-per-partition case).
514 Returns
515 -------
516 catalog : `pandas.DataFrame`, or `None`
517 Catalog containing DiaSource records. `None` is returned if
518 ``months`` is 0; an empty catalog is returned when ``object_ids`` is empty.
519 """
520 if months == 0:
521 return None
522 object_id_set: Set[int] = set()
523 if object_ids is not None:
524 object_id_set = set(object_ids)
525 if len(object_id_set) == 0:
526 return self._make_empty_catalog(table_name)
528 packedColumns = self._schema.packedColumns(table_name)
529 if self.config.pandas_delay_conv:
530 self._session.row_factory = _RawRowFactory(packedColumns)
531 else:
532 self._session.row_factory = _PandasRowFactory(packedColumns)
533 self._session.default_fetch_size = None
535 # spatial pixels included into query
536 pixels = self._partitioner.pixels(region)
537 _LOG.debug("_getSources: %s #partitions: %s", table_name.name, len(pixels))
539 # spatial part of WHERE
540 spatial_where = []
541 if self.config.query_per_spatial_part:
542 spatial_where = [f'"apdb_part" = {pixel}' for pixel in pixels]
543 else:
544 pixels_str = ",".join([str(pix) for pix in pixels])
545 spatial_where = [f'"apdb_part" IN ({pixels_str})']
547 # temporal part of WHERE, can be empty
548 temporal_where = []
549 # time partitions and table names to query, there may be multiple
550 # tables depending on configuration
551 full_name = self._schema.tableName(table_name)
552 tables = [full_name]
553 mjd_now = visit_time.get(system=dafBase.DateTime.MJD)
554 mjd_begin = mjd_now - months*30
555 time_part_now = self._time_partition(mjd_now)
556 time_part_begin = self._time_partition(mjd_begin)
557 time_parts = list(range(time_part_begin, time_part_now + 1))
558 if self.config.time_partition_tables:
559 tables = [f"{full_name}_{part}" for part in time_parts]
560 else:
561 if self.config.query_per_time_part:
562 temporal_where = [f'"apdb_time_part" = {time_part}' for time_part in time_parts]
563 else:
564 time_part_list = ",".join([str(part) for part in time_parts])
565 temporal_where = [f'"apdb_time_part" IN ({time_part_list})']
567 # Build all queries
568 queries: List[str] = []
569 for table in tables:
570 query = f'SELECT * from "{table}" WHERE '
571 for spatial in spatial_where:
572 if temporal_where:
573 for temporal in temporal_where:
574 queries.append(query + spatial + " AND " + temporal)
575 else:
576 queries.append(query + spatial)
577 # _LOG.debug("_getSources: queries: %s", queries)
579 statements: List[Tuple] = [
580 (cassandra.query.SimpleStatement(query, consistency_level=self._read_consistency), {})
581 for query in queries
582 ]
583 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))
585 with Timer(table_name.name + ' select', self.config.timer):
586 # submit all queries
587 results = execute_concurrent(self._session, statements, results_generator=True,
588 concurrency=self.config.read_concurrency)
589 if self.config.pandas_delay_conv:
590 _LOG.debug("making pandas data frame out of rows/columns")
591 columns: Any = None
592 rows = []
593 for success, result in results:
594 result = result._current_rows
595 if success:
596 if columns is None:
597 columns = result[0]
598 elif columns != result[0]:
599 _LOG.error("different columns returned by queries: %s and %s",
600 columns, result[0])
601 raise ValueError(
602 f"different columns returned by queries: {columns} and {result[0]}"
603 )
604 rows += result[1]
605 else:
606 _LOG.error("error returned by query: %s", result)
607 raise result
608 catalog = _rows_to_pandas(columns, rows, self._schema.packedColumns(table_name))
609 _LOG.debug("pandas catalog shape: %s", catalog.shape)
610 # filter by given object IDs
611 if len(object_id_set) > 0:
612 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])
613 else:
614 _LOG.debug("making pandas data frame out of set of data frames")
615 dataframes = []
616 for success, result in results:
617 if success:
618 dataframes.append(result._current_rows)
619 else:
620 _LOG.error("error returned by query: %s", result)
621 raise result
622 # concatenate all frames
623 if len(dataframes) == 1:
624 catalog = dataframes[0]
625 else:
626 catalog = pandas.concat(dataframes)
627 _LOG.debug("pandas catalog shape: %s", catalog.shape)
628 # filter by given object IDs
629 if len(object_id_set) > 0:
630 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])
632 # precise filtering on midPointTai
633 catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_begin])
635 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
636 return catalog
638 def getDiaObjectsHistory(self,
639 start_time: dafBase.DateTime,
640 end_time: Optional[dafBase.DateTime] = None,
641 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
642 # docstring is inherited from a base class
643 raise NotImplementedError()
645 def getDiaSourcesHistory(self,
646 start_time: dafBase.DateTime,
647 end_time: Optional[dafBase.DateTime] = None,
648 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
649 # docstring is inherited from a base class
650 raise NotImplementedError()
652 def getDiaForcedSourcesHistory(self,
653 start_time: dafBase.DateTime,
654 end_time: Optional[dafBase.DateTime] = None,
655 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
656 # docstring is inherited from a base class
657 raise NotImplementedError()
659 def getSSObjects(self) -> pandas.DataFrame:
660 # docstring is inherited from a base class
661 raise NotImplementedError()
663 def store(self,
664 visit_time: dafBase.DateTime,
665 objects: pandas.DataFrame,
666 sources: Optional[pandas.DataFrame] = None,
667 forced_sources: Optional[pandas.DataFrame] = None) -> None:
668 # docstring is inherited from a base class
670 # fill region partition column for DiaObjects
671 objects = self._add_obj_part(objects)
672 self._storeDiaObjects(objects, visit_time)
674 if sources is not None:
675 # copy apdb_part column from DiaObjects to DiaSources
676 sources = self._add_src_part(sources, objects)
677 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time)
679 if forced_sources is not None:
680 forced_sources = self._add_fsrc_part(forced_sources, objects)
681 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time)
683 def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
684 """Store catalog of DiaObjects from the current visit.
686 Parameters
687 ----------
688 objs : `pandas.DataFrame`
689 Catalog with DiaObject records
690 visit_time : `lsst.daf.base.DateTime`
691 Time of the current visit.
692 """
693 visit_time_dt = visit_time.toPython()
694 extra_columns = dict(lastNonForcedSource=visit_time_dt)
695 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, visit_time, extra_columns=extra_columns)
697 extra_columns["validityStart"] = visit_time_dt
698 time_part: Optional[int] = self._time_partition(visit_time)
699 if not self.config.time_partition_tables:
700 extra_columns["apdb_time_part"] = time_part
701 time_part = None
703 self._storeObjectsPandas(objs, ApdbTables.DiaObject, visit_time,
704 extra_columns=extra_columns, time_part=time_part)
706 def _storeDiaSources(self, table_name: ApdbTables, sources: pandas.DataFrame,
707 visit_time: dafBase.DateTime) -> None:
708 """Store catalog of DiaSources or DiaForcedSources from the current visit.
710 Parameters
711 ----------
712 sources : `pandas.DataFrame`
713 Catalog containing DiaSource records
714 visit_time : `lsst.daf.base.DateTime`
715 Time of the current visit.
716 """
717 time_part: Optional[int] = self._time_partition(visit_time)
718 extra_columns = {}
719 if not self.config.time_partition_tables:
720 extra_columns["apdb_time_part"] = time_part
721 time_part = None
723 self._storeObjectsPandas(sources, table_name, visit_time,
724 extra_columns=extra_columns, time_part=time_part)
726 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
727 # docstring is inherited from a base class
728 raise NotImplementedError()
730 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
731 # docstring is inherited from a base class
732 raise NotImplementedError()
734 def dailyJob(self) -> None:
735 # docstring is inherited from a base class
736 pass
738 def countUnassociatedObjects(self) -> int:
739 # docstring is inherited from a base class
740 raise NotImplementedError()
742 def _storeObjectsPandas(self, objects: pandas.DataFrame, table_name: ApdbTables,
743 visit_time: dafBase.DateTime, extra_columns: Optional[Mapping] = None,
744 time_part: Optional[int] = None) -> None:
745 """Generic store method.
747 Takes a catalog of records and stores them in a table.
749 Parameters
750 ----------
751 objects : `pandas.DataFrame`
752 Catalog containing object records
753 table_name : `ApdbTables`
754 Name of the table as defined in APDB schema.
755 visit_time : `lsst.daf.base.DateTime`
756 Time of the current visit.
757 extra_columns : `dict`, optional
758 Mapping (column_name, column_value) giving column values to add
759 to every row; used only when the column is missing from catalog records.
760 time_part : `int`, optional
761 If not `None` then insert into a per-partition table.
762 """
764 def qValue(v: Any) -> Any:
765 """Transform object into a value for query"""
766 if v is None:
767 pass
768 elif isinstance(v, datetime):
769 v = int((v - datetime(1970, 1, 1)) / timedelta(seconds=1))*1000
770 elif isinstance(v, (bytes, str)):
771 pass
772 else:
773 try:
774 if not np.isfinite(v):
775 v = None
776 except TypeError:
777 pass
778 return v
780 def quoteId(columnName: str) -> str:
781 """Smart quoting for column names.
782 Lower-case names are not quoted.
783 """
784 if not columnName.islower():
785 columnName = '"' + columnName + '"'
786 return columnName
788 # use extra columns if specified
789 if extra_columns is None:
790 extra_columns = {}
791 extra_fields = list(extra_columns.keys())
793 df_fields = [column for column in objects.columns
794 if column not in extra_fields]
796 column_map = self._schema.getColumnMap(table_name)
797 # list of columns (as in cat schema)
798 fields = [column_map[field].name for field in df_fields if field in column_map]
799 fields += extra_fields
801 # check that all partitioning and clustering columns are defined
802 required_columns = self._schema.partitionColumns(table_name) \
803 + self._schema.clusteringColumns(table_name)
804 missing_columns = [column for column in required_columns if column not in fields]
805 if missing_columns:
806 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")
808 blob_columns = set(col.name for col in self._schema.packedColumns(table_name))
809 # _LOG.debug("blob_columns: %s", blob_columns)
811 qfields = [quoteId(field) for field in fields if field not in blob_columns]
812 if blob_columns:
813 qfields += [quoteId("apdb_packed")]
814 qfields_str = ','.join(qfields)
816 with Timer(table_name.name + ' query build', self.config.timer):
818 table = self._schema.tableName(table_name)
819 if time_part is not None:
820 table = f"{table}_{time_part}"
822 prepared: Optional[cassandra.query.PreparedStatement] = None
823 if self.config.prepared_statements:
824 holders = ','.join(['?']*len(qfields))
825 query = f'INSERT INTO "{table}" ({qfields_str}) VALUES ({holders})'
826 prepared = self._session.prepare(query)
827 queries = cassandra.query.BatchStatement(consistency_level=self._write_consistency)
828 for rec in objects.itertuples(index=False):
829 values = []
830 blob = {}
831 for field in df_fields:
832 if field not in column_map:
833 continue
834 value = getattr(rec, field)
835 if column_map[field].type == "DATETIME":
836 if isinstance(value, pandas.Timestamp):
837 value = qValue(value.to_pydatetime())
838 else:
839 # Assume it's seconds since epoch, Cassandra
840 # datetime is in milliseconds
841 value = int(value*1000)
842 if field in blob_columns:
843 blob[field] = qValue(value)
844 else:
845 values.append(qValue(value))
846 for field in extra_fields:
847 value = extra_columns[field]
848 if field in blob_columns:
849 blob[field] = qValue(value)
850 else:
851 values.append(qValue(value))
852 if blob_columns:
853 if self.config.packing == "cbor":
854 blob = b"cbor:" + cbor.dumps(blob)
855 values.append(blob)
856 holders = ','.join(['%s']*len(values))
857 if prepared is not None:
858 stmt = prepared
859 else:
860 query = f'INSERT INTO "{table}" ({qfields_str}) VALUES ({holders})'
861 # _LOG.debug("query: %r", query)
862 # _LOG.debug("values: %s", values)
863 stmt = cassandra.query.SimpleStatement(query, consistency_level=self._write_consistency)
864 queries.add(stmt, values)
866 # _LOG.debug("query: %s", query)
867 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), objects.shape[0])
868 with Timer(table_name.name + ' insert', self.config.timer):
869 self._session.execute(queries, timeout=self.config.write_timeout)
871 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
872 """Calculate spatial partition for each record and add it to a
873 DataFrame.
875 Notes
876 -----
877 This overrides any existing column in a DataFrame with the same name
878 (apdb_part). The original DataFrame is not changed; a copy of the DataFrame
879 is returned.
880 """
881 # calculate HTM index for every DiaObject
882 apdb_part = np.zeros(df.shape[0], dtype=np.int64)
883 ra_col, dec_col = self.config.ra_dec_columns
884 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
885 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
886 idx = self._partitioner.pixel(uv3d)
887 apdb_part[i] = idx
888 df = df.copy()
889 df["apdb_part"] = apdb_part
890 return df
892 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
893 """Add apdb_part column to DiaSource catalog.
895 Notes
896 -----
897 This method copies apdb_part value from a matching DiaObject record.
898 DiaObject catalog needs to have an apdb_part column filled by the
899 ``_add_obj_part`` method and DiaSource records need to be
900 associated with DiaObjects via the ``diaObjectId`` column.
902 This overrides any existing column in a DataFrame with the same name
903 (apdb_part). The original DataFrame is not changed; a copy of the DataFrame
904 is returned.
905 """
906 pixel_id_map: Dict[int, int] = {
907 diaObjectId: apdb_part for diaObjectId, apdb_part
908 in zip(objs["diaObjectId"], objs["apdb_part"])
909 }
910 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
911 ra_col, dec_col = self.config.ra_dec_columns
912 for i, (diaObjId, ra, dec) in enumerate(zip(sources["diaObjectId"],
913 sources[ra_col], sources[dec_col])):
914 if diaObjId == 0:
915 # DiaSources associated with SolarSystemObjects do not have an
916 # associated DiaObject, hence we skip the lookup and set the partition
917 # based on the source's own ra/dec
918 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
919 idx = self._partitioner.pixel(uv3d)
920 apdb_part[i] = idx
921 else:
922 apdb_part[i] = pixel_id_map[diaObjId]
923 sources = sources.copy()
924 sources["apdb_part"] = apdb_part
925 return sources
927 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
928 """Add apdb_part column to DiaForcedSource catalog.
930 Notes
931 -----
932 This method copies apdb_part value from a matching DiaObject record.
933 DiaObject catalog needs to have an apdb_part column filled by the
934 ``_add_obj_part`` method and DiaForcedSource records need to be
935 associated with DiaObjects via the ``diaObjectId`` column.
937 This overrides any existing column in a DataFrame with the same name
938 (apdb_part). The original DataFrame is not changed; a copy of the DataFrame
939 is returned.
940 """
941 pixel_id_map: Dict[int, int] = {
942 diaObjectId: apdb_part for diaObjectId, apdb_part
943 in zip(objs["diaObjectId"], objs["apdb_part"])
944 }
945 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
946 for i, diaObjId in enumerate(sources["diaObjectId"]):
947 apdb_part[i] = pixel_id_map[diaObjId]
948 sources = sources.copy()
949 sources["apdb_part"] = apdb_part
950 return sources
952 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int:
953 """Calculate time partition number for a given time.
955 Parameters
956 ----------
957 time : `float` or `lsst.daf.base.DateTime`
958 Time for which to calculate the partition number; a `float` value is
959 interpreted as MJD, otherwise an `lsst.daf.base.DateTime` is expected.
961 Returns
962 -------
963 partition : `int`
964 Partition number for a given time.
965 """
966 if isinstance(time, dafBase.DateTime):
967 mjd = time.get(system=dafBase.DateTime.MJD)
968 else:
969 mjd = time
970 days_since_epoch = mjd - self._partition_zero_epoch_mjd
971 partition = int(days_since_epoch) // self.config.time_partition_days
972 return partition
974 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
975 """Make an empty catalog for a table with a given name.
977 Parameters
978 ----------
979 table_name : `ApdbTables`
980 Name of the table.
982 Returns
983 -------
984 catalog : `pandas.DataFrame`
985 An empty catalog.
986 """
987 table = self._schema.tableSchemas[table_name]
989 data = {columnDef.name: pandas.Series(dtype=columnDef.dtype) for columnDef in table.columns}
990 return pandas.DataFrame(data)
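The time partitioning arithmetic in `_time_partition` can be checked with a few standalone lines; 40587 is the MJD of the 1970-01-01 partition zero epoch and 30 days is the default `time_partition_days`.

PARTITION_ZERO_EPOCH_MJD = 40587     # MJD of 1970-01-01
TIME_PARTITION_DAYS = 30             # default time_partition_days

def time_partition(mjd: float) -> int:
    # same integer arithmetic as ApdbCassandra._time_partition
    return int(mjd - PARTITION_ZERO_EPOCH_MJD) // TIME_PARTITION_DAYS

print(time_partition(59580.0))       # 2022-01-01 -> (59580 - 40587) // 30 = 633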