Coverage for python/lsst/dax/apdb/apdbCassandra.py: 16%
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ApdbCassandraConfig", "ApdbCassandra"]

import logging
import numpy as np
import pandas
from typing import Any, cast, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union

# If cassandra-driver is not there the module can still be imported
# but ApdbCassandra cannot be instantiated.
try:
    import cassandra
    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import RoundRobinPolicy, WhiteListRoundRobinPolicy, AddressTranslator
    import cassandra.query
    CASSANDRA_IMPORTED = True
except ImportError:
    CASSANDRA_IMPORTED = False

import lsst.daf.base as dafBase
from lsst import sphgeom
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.utils.iteration import chunk_iterable
from .timer import Timer
from .apdb import Apdb, ApdbConfig
from .apdbSchema import ApdbTables, TableDef
from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables
from .cassandra_utils import (
    literal,
    pandas_dataframe_factory,
    quote_id,
    raw_data_factory,
    select_concurrent,
)
from .pixelization import Pixelization

_LOG = logging.getLogger(__name__)


class CassandraMissingError(Exception):
    def __init__(self) -> None:
        super().__init__("cassandra-driver module cannot be imported")


class ApdbCassandraConfig(ApdbConfig):

    contact_points = ListField(
        dtype=str,
        doc="The list of contact points to try connecting for cluster discovery.",
        default=["127.0.0.1"]
    )
    private_ips = ListField(
        dtype=str,
        doc="List of internal IP addresses for contact_points.",
        default=[]
    )
    keyspace = Field(
        dtype=str,
        doc="Default keyspace for operations.",
        default="apdb"
    )
    read_consistency = Field(
        dtype=str,
        doc="Name for consistency level of read operations, default: QUORUM, can be ONE.",
        default="QUORUM"
    )
    write_consistency = Field(
        dtype=str,
        doc="Name for consistency level of write operations, default: QUORUM, can be ONE.",
        default="QUORUM"
    )
    read_timeout = Field(
        dtype=float,
        doc="Timeout in seconds for read operations.",
        default=120.
    )
    write_timeout = Field(
        dtype=float,
        doc="Timeout in seconds for write operations.",
        default=10.
    )
    read_concurrency = Field(
        dtype=int,
        doc="Concurrency level for read operations.",
        default=500
    )
    protocol_version = Field(
        dtype=int,
        doc="Cassandra protocol version to use, default is V4.",
        default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0
    )
    dia_object_columns = ListField(
        dtype=str,
        doc="List of columns to read from DiaObject; by default read all columns.",
        default=[]
    )
    prefix = Field(
        dtype=str,
        doc="Prefix to add to table names.",
        default=""
    )
    part_pixelization = ChoiceField(
        dtype=str,
        allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"),
        doc="Pixelization used for partitioning index.",
        default="mq3c"
    )
    part_pix_level = Field(
        dtype=int,
        doc="Pixelization level used for partitioning index.",
        default=10
    )
    part_pix_max_ranges = Field(
        dtype=int,
        doc="Maximum number of ranges in pixelization envelope.",
        default=64
    )
    ra_dec_columns = ListField(
        dtype=str,
        default=["ra", "decl"],
        doc="Names of ra/dec columns in DiaObject table."
    )
    timer = Field(
        dtype=bool,
        doc="If True then print/log timing information.",
        default=False
    )
    time_partition_tables = Field(
        dtype=bool,
        doc="Use per-partition tables for sources instead of partitioning by time.",
        default=True
    )
    time_partition_days = Field(
        dtype=int,
        doc="Time partitioning granularity in days; this value must not be changed"
            " after the database is initialized.",
        default=30
    )
    time_partition_start = Field(
        dtype=str,
        doc="Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI."
            " This is used only when time_partition_tables is True.",
        default="2018-12-01T00:00:00"
    )
    time_partition_end = Field(
        dtype=str,
        doc="Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI."
            " This is used only when time_partition_tables is True.",
        default="2030-01-01T00:00:00"
    )
    query_per_time_part = Field(
        dtype=bool,
        default=False,
        doc="If True then build a separate query for each time partition, otherwise build one single query."
            " This is only used when time_partition_tables is False in schema config."
    )
    query_per_spatial_part = Field(
        dtype=bool,
        default=False,
        doc="If True then build one query per spatial partition, otherwise build a single query."
    )
    pandas_delay_conv = Field(
        dtype=bool,
        default=True,
        doc="If True then combine result rows before converting to pandas."
    )


if CASSANDRA_IMPORTED:

    class _AddressTranslator(AddressTranslator):
        """Translate internal IP address to external.

        Only used for docker-based setup, not a viable long-term solution.
        """
        def __init__(self, public_ips: List[str], private_ips: List[str]):
            self._map = dict((k, v) for k, v in zip(private_ips, public_ips))

        def translate(self, private_ip: str) -> str:
            return self._map.get(private_ip, private_ip)


class ApdbCassandra(Apdb):
    """Implementation of APDB database on top of Apache Cassandra.

    The implementation is configured via the standard ``pex_config`` mechanism
    using the `ApdbCassandraConfig` configuration class. For examples of
    different configurations check the config/ folder.

    Parameters
    ----------
    config : `ApdbCassandraConfig`
        Configuration object.
    """

    partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI)
    """Start time for partition 0, this should never be changed."""

    def __init__(self, config: ApdbCassandraConfig):

        if not CASSANDRA_IMPORTED:
            raise CassandraMissingError()

        self.config = config

        _LOG.debug("ApdbCassandra Configuration:")
        for key, value in self.config.items():
            _LOG.debug(" %s: %s", key, value)

        self._pixelization = Pixelization(
            config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges
        )

        addressTranslator: Optional[AddressTranslator] = None
        if config.private_ips:
            addressTranslator = _AddressTranslator(config.contact_points, config.private_ips)

        self._keyspace = config.keyspace

        self._cluster = Cluster(execution_profiles=self._makeProfiles(config),
                                contact_points=self.config.contact_points,
                                address_translator=addressTranslator,
                                protocol_version=self.config.protocol_version)
        self._session = self._cluster.connect()
        # Disable result paging
        self._session.default_fetch_size = None

        self._schema = ApdbCassandraSchema(session=self._session,
                                           keyspace=self._keyspace,
                                           schema_file=self.config.schema_file,
                                           schema_name=self.config.schema_name,
                                           prefix=self.config.prefix,
                                           time_partition_tables=self.config.time_partition_tables)
        self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD)

        # Cache for prepared statements
        self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {}

    def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class

        if self.config.time_partition_tables:
            time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI)
            time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI)
            part_range = (
                self._time_partition(time_partition_start),
                self._time_partition(time_partition_end) + 1
            )
            self._schema.makeSchema(drop=drop, part_range=part_range)
        else:
            self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        sp_where = self._spatial_where(region)
        _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))

        table_name = self._schema.tableName(ApdbTables.DiaObjectLast)
        query = f'SELECT * from "{self._keyspace}"."{table_name}"'
        statements: List[Tuple] = []
        for where, params in sp_where:
            full_query = f"{query} WHERE {where}"
            if params:
                statement = self._prep_statement(full_query)
            else:
                # If there are no params then it is likely that the query has
                # a bunch of literals rendered already; there is no point
                # trying to prepare it because it is not reusable.
                statement = cassandra.query.SimpleStatement(full_query)
            statements.append((statement, params))
        _LOG.debug("getDiaObjects: #queries: %s", len(statements))

        with Timer('DiaObject select', self.config.timer):
            objects = cast(
                pandas.DataFrame,
                select_concurrent(
                    self._session, statements, "read_pandas_multi", self.config.read_concurrency
                )
            )

        _LOG.debug("found %s DiaObjects", objects.shape[0])
        return objects
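
    # Illustrative usage sketch, not part of the module: querying the latest
    # DiaObject versions inside a small circular region, using the ``apdb``
    # instance from the configuration sketch above. The coordinates and radius
    # are assumptions made only for this example.
    #
    #     center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(10.0, -5.0))
    #     region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
    #     objects = apdb.getDiaObjects(region)  # pandas.DataFrame of DiaObjectLast rows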

    def getDiaSources(self, region: sphgeom.Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        months = self.config.read_sources_months
        if months == 0:
            return None
        mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
        mjd_start = mjd_end - months*30

        return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)

    def getDiaForcedSources(self, region: sphgeom.Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        months = self.config.read_forced_sources_months
        if months == 0:
            return None
        mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
        mjd_start = mjd_end - months*30

        return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)

    def getDiaObjectsHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        sp_where = self._spatial_where(region, use_ranges=True)
        tables, temporal_where = self._temporal_where(ApdbTables.DiaObject, start_time, end_time, True)

        # Build all queries
        statements: List[Tuple] = []
        for table in tables:
            prefix = f'SELECT * from "{self._keyspace}"."{table}"'
            statements += list(self._combine_where(prefix, sp_where, temporal_where, "ALLOW FILTERING"))
        _LOG.debug("getDiaObjectsHistory: #queries: %s", len(statements))

        # Run all selects in parallel
        with Timer("DiaObject history", self.config.timer):
            catalog = cast(
                pandas.DataFrame,
                select_concurrent(
                    self._session, statements, "read_pandas_multi", self.config.read_concurrency
                )
            )

        # precise filtering on validityStart
        validity_start = start_time.toPython()
        validity_end = end_time.toPython()
        catalog = cast(
            pandas.DataFrame,
            catalog[(catalog["validityStart"] >= validity_start) & (catalog["validityStart"] < validity_end)]
        )

        _LOG.debug("found %d DiaObjects", catalog.shape[0])
        return catalog

    def getDiaSourcesHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class
        return self._getSourcesHistory(ApdbTables.DiaSource, start_time, end_time, region)

    def getDiaForcedSourcesHistory(self,
                                   start_time: dafBase.DateTime,
                                   end_time: dafBase.DateTime,
                                   region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class
        return self._getSourcesHistory(ApdbTables.DiaForcedSource, start_time, end_time, region)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class
        tableName = self._schema.tableName(ApdbTables.SSObject)
        query = f'SELECT * from "{self._keyspace}"."{tableName}"'

        objects = None
        with Timer('SSObject select', self.config.timer):
            result = self._session.execute(query, execution_profile="read_pandas")
            objects = result._current_rows

        _LOG.debug("found %s SSObjects", objects.shape[0])
        return objects

    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill region partition column for DiaObjects
        objects = self._add_obj_part(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy apdb_part column from DiaObjects to DiaSources
            sources = self._add_src_part(sources, objects)
            self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time)
            self._storeDiaSourcesPartitions(sources, visit_time)

        if forced_sources is not None:
            forced_sources = self._add_fsrc_part(forced_sources, objects)
            self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time)

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class
        self._storeObjectsPandas(objects, ApdbTables.SSObject)
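
    # Illustrative usage sketch, not part of the module: storing one visit
    # worth of catalogs. The DataFrames are assumed to already follow the APDB
    # schema (DiaObject rows with ra/dec columns, sources associated to
    # DiaObjects via "diaObjectId"); their names here are hypothetical.
    #
    #     visit_time = dafBase.DateTime("2022-06-25T01:34:00", dafBase.DateTime.TAI)
    #     apdb.store(visit_time, objects=dia_objects_df,
    #                sources=dia_sources_df, forced_sources=dia_forced_df)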

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        # To update a record we need to know its exact primary key (including
        # partition key) so we start by querying for diaSourceId to find the
        # primary keys.

        table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition)
        # split it into 1k IDs per query
        selects: List[Tuple] = []
        for ids in chunk_iterable(idMap.keys(), 1_000):
            ids_str = ",".join(str(item) for item in ids)
            selects.append((
                (f'SELECT "diaSourceId", "apdb_part", "apdb_time_part" FROM "{self._keyspace}"."{table_name}"'
                 f' WHERE "diaSourceId" IN ({ids_str})'),
                {}
            ))

        # No need for DataFrame here, read data as tuples.
        result = cast(
            List[Tuple[int, int, int]],
            select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency)
        )

        # Make mapping from source ID to its partition.
        id2partitions: Dict[int, Tuple[int, int]] = {}
        for row in result:
            id2partitions[row[0]] = row[1:]

        # make sure we know partitions for each ID
        if set(id2partitions) != set(idMap):
            missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
            raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

        queries = cassandra.query.BatchStatement()
        table_name = self._schema.tableName(ApdbTables.DiaSource)
        for diaSourceId, ssObjectId in idMap.items():
            apdb_part, apdb_time_part = id2partitions[diaSourceId]
            values: Tuple
            if self.config.time_partition_tables:
                query = (
                    f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"'
                    ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
                    ' WHERE "apdb_part" = ? AND "diaSourceId" = ?'
                )
                values = (ssObjectId, apdb_part, diaSourceId)
            else:
                query = (
                    f'UPDATE "{self._keyspace}"."{table_name}"'
                    ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
                    ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?'
                )
                values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId)
            queries.add(self._prep_statement(query), values)

        _LOG.debug("%s: will update %d records", table_name, len(idMap))
        with Timer(table_name + ' update', self.config.timer):
            self._session.execute(queries, execution_profile="write")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # It's too inefficient to implement it for Cassandra in current schema.
        raise NotImplementedError()

    def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]:
        """Make all execution profiles used in the code."""

        if config.private_ips:
            loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
        else:
            loadBalancePolicy = RoundRobinPolicy()

        pandas_row_factory: Callable
        if not config.pandas_delay_conv:
            pandas_row_factory = pandas_dataframe_factory
        else:
            pandas_row_factory = raw_data_factory

        read_tuples_profile = ExecutionProfile(
            consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
            request_timeout=config.read_timeout,
            row_factory=cassandra.query.tuple_factory,
            load_balancing_policy=loadBalancePolicy,
        )
        read_pandas_profile = ExecutionProfile(
            consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
            request_timeout=config.read_timeout,
            row_factory=pandas_dataframe_factory,
            load_balancing_policy=loadBalancePolicy,
        )
        read_pandas_multi_profile = ExecutionProfile(
            consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
            request_timeout=config.read_timeout,
            row_factory=pandas_row_factory,
            load_balancing_policy=loadBalancePolicy,
        )
        write_profile = ExecutionProfile(
            consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency),
            request_timeout=config.write_timeout,
            load_balancing_policy=loadBalancePolicy,
        )
        # To replace default DCAwareRoundRobinPolicy
        default_profile = ExecutionProfile(
            load_balancing_policy=loadBalancePolicy,
        )
        return {
            "read_tuples": read_tuples_profile,
            "read_pandas": read_pandas_profile,
            "read_pandas_multi": read_pandas_multi_profile,
            "write": write_profile,
            EXEC_PROFILE_DEFAULT: default_profile,
        }

    def _getSources(self, region: sphgeom.Region,
                    object_ids: Optional[Iterable[int]],
                    mjd_start: float,
                    mjd_end: float,
                    table_name: ApdbTables) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given a set of DiaObject IDs.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Spherical region.
        object_ids : `iterable` [ `int` ], or `None`
            Collection of DiaObject IDs.
        mjd_start : `float`
            Lower bound of time interval.
        mjd_end : `float`
            Upper bound of time interval.
        table_name : `ApdbTables`
            Name of the table.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is returned
            if ``object_ids`` is an empty collection.
        """
        object_id_set: Set[int] = set()
        if object_ids is not None:
            object_id_set = set(object_ids)
            if len(object_id_set) == 0:
                return self._make_empty_catalog(table_name)

        sp_where = self._spatial_where(region)
        tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end)

        # Build all queries
        statements: List[Tuple] = []
        for table in tables:
            prefix = f'SELECT * from "{self._keyspace}"."{table}"'
            statements += list(self._combine_where(prefix, sp_where, temporal_where))
        _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))

        with Timer(table_name.name + ' select', self.config.timer):
            catalog = cast(
                pandas.DataFrame,
                select_concurrent(
                    self._session, statements, "read_pandas_multi", self.config.read_concurrency
                )
            )

        # filter by given object IDs
        if len(object_id_set) > 0:
            catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])

        # precise filtering on midPointTai
        catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_start])

        _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
        return catalog

    def _getSourcesHistory(
        self,
        table: ApdbTables,
        start_time: dafBase.DateTime,
        end_time: dafBase.DateTime,
        region: Optional[sphgeom.Region] = None,
    ) -> pandas.DataFrame:
        """Return history of DiaSource or DiaForcedSource records in a given
        time interval.

        Parameters
        ----------
        table : `ApdbTables`
            Name of the table.
        start_time : `dafBase.DateTime`
            Starting time for the history search. A record is selected when
            its ``midPointTai`` falls into an interval between ``start_time``
            (inclusive) and ``end_time`` (exclusive).
        end_time : `dafBase.DateTime`
            Upper limit on time for the history search.
        region : `lsst.sphgeom.Region`, optional
            Spherical region.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        sp_where = self._spatial_where(region, use_ranges=False)
        tables, temporal_where = self._temporal_where(table, start_time, end_time, True)

        # Build all queries
        statements: List[Tuple] = []
        for table_name in tables:
            prefix = f'SELECT * from "{self._keyspace}"."{table_name}"'
            statements += list(self._combine_where(prefix, sp_where, temporal_where, "ALLOW FILTERING"))
        _LOG.debug("_getSourcesHistory: #queries: %s", len(statements))

        # Run all selects in parallel
        with Timer(f"{table.name} history", self.config.timer):
            catalog = cast(
                pandas.DataFrame,
                select_concurrent(
                    self._session, statements, "read_pandas_multi", self.config.read_concurrency
                )
            )

        # precise filtering on midPointTai
        period_start = start_time.get(system=dafBase.DateTime.MJD)
        period_end = end_time.get(system=dafBase.DateTime.MJD)
        catalog = cast(
            pandas.DataFrame,
            catalog[(catalog["midPointTai"] >= period_start) & (catalog["midPointTai"] < period_end)]
        )

        _LOG.debug("found %d %ss", catalog.shape[0], table.name)
        return catalog

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        """
        visit_time_dt = visit_time.toPython()
        extra_columns = dict(lastNonForcedSource=visit_time_dt)
        self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)

        extra_columns["validityStart"] = visit_time_dt
        time_part: Optional[int] = self._time_partition(visit_time)
        if not self.config.time_partition_tables:
            extra_columns["apdb_time_part"] = time_part
            time_part = None

        self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part)

    def _storeDiaSources(self, table_name: ApdbTables, sources: pandas.DataFrame,
                         visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaSources or DiaForcedSources from current visit.

        Parameters
        ----------
        table_name : `ApdbTables`
            Table where the records are stored, either DiaSource or
            DiaForcedSource.
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        """
        time_part: Optional[int] = self._time_partition(visit_time)
        extra_columns = {}
        if not self.config.time_partition_tables:
            extra_columns["apdb_time_part"] = time_part
            time_part = None

        self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part)

    def _storeDiaSourcesPartitions(self, sources: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store mapping of diaSourceId to its partitioning values.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        """
        id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
        extra_columns = {
            "apdb_time_part": self._time_partition(visit_time),
        }

        self._storeObjectsPandas(
            id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
        )

    def _storeObjectsPandas(self, objects: pandas.DataFrame, table_name: Union[ApdbTables, ExtraTables],
                            extra_columns: Optional[Mapping] = None,
                            time_part: Optional[int] = None) -> None:
        """Generic store method.

        Takes a catalog of records and stores them in the specified table.

        Parameters
        ----------
        objects : `pandas.DataFrame`
            Catalog containing object records.
        table_name : `ApdbTables` or `ExtraTables`
            Name of the table as defined in APDB schema.
        extra_columns : `dict`, optional
            Mapping (column_name, column_value) which gives fixed column
            values to add to every row; if a column with the same name is
            present in the catalog its values are ignored.
        time_part : `int`, optional
            If not `None` then insert into a per-partition table.
        """
        # use extra columns if specified
        if extra_columns is None:
            extra_columns = {}
        extra_fields = list(extra_columns.keys())

        df_fields = [
            column for column in objects.columns if column not in extra_fields
        ]

        column_map = self._schema.getColumnMap(table_name)
        # list of columns (as defined in the table schema)
        fields = [column_map[field].name for field in df_fields if field in column_map]
        fields += extra_fields

        # check that all partitioning and clustering columns are defined
        required_columns = self._schema.partitionColumns(table_name) \
            + self._schema.clusteringColumns(table_name)
        missing_columns = [column for column in required_columns if column not in fields]
        if missing_columns:
            raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")

        qfields = [quote_id(field) for field in fields]
        qfields_str = ','.join(qfields)

        with Timer(table_name.name + ' query build', self.config.timer):

            table = self._schema.tableName(table_name)
            if time_part is not None:
                table = f"{table}_{time_part}"

            holders = ','.join(['?']*len(qfields))
            query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})'
            statement = self._prep_statement(query)
            queries = cassandra.query.BatchStatement()
            for rec in objects.itertuples(index=False):
                values = []
                for field in df_fields:
                    if field not in column_map:
                        continue
                    value = getattr(rec, field)
                    if column_map[field].type == "DATETIME":
                        if isinstance(value, pandas.Timestamp):
                            value = literal(value.to_pydatetime())
                        else:
                            # Assume it's seconds since epoch, Cassandra
                            # datetime is in milliseconds
                            value = int(value*1000)
                    values.append(literal(value))
                for field in extra_fields:
                    value = extra_columns[field]
                    values.append(literal(value))
                queries.add(statement, values)

        _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), objects.shape[0])
        with Timer(table_name.name + ' insert', self.config.timer):
            self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write")

    def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate spatial partition for each record and add it to a
        DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (apdb_part). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        # calculate pixelization index for every DiaObject
        apdb_part = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
            idx = self._pixelization.pixel(uv3d)
            apdb_part[i] = idx
        df = df.copy()
        df["apdb_part"] = apdb_part
        return df

    def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add apdb_part column to DiaSource catalog.

        Notes
        -----
        This method copies the apdb_part value from the matching DiaObject
        record. The DiaObject catalog needs to have an apdb_part column filled
        by the ``_add_obj_part`` method, and DiaSource records need to be
        associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (apdb_part). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: apdb_part for diaObjectId, apdb_part
            in zip(objs["diaObjectId"], objs["apdb_part"])
        }
        apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (diaObjId, ra, dec) in enumerate(zip(sources["diaObjectId"],
                                                    sources[ra_col], sources[dec_col])):
            if diaObjId == 0:
                # DiaSources associated with SolarSystemObjects do not have an
                # associated DiaObject, hence we skip them and set the
                # partition based on the source's own ra/dec.
                uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
                idx = self._pixelization.pixel(uv3d)
                apdb_part[i] = idx
            else:
                apdb_part[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources["apdb_part"] = apdb_part
        return sources

    def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add apdb_part column to DiaForcedSource catalog.

        Notes
        -----
        This method copies the apdb_part value from the matching DiaObject
        record. The DiaObject catalog needs to have an apdb_part column filled
        by the ``_add_obj_part`` method, and DiaForcedSource records need to
        be associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (apdb_part). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: apdb_part for diaObjectId, apdb_part
            in zip(objs["diaObjectId"], objs["apdb_part"])
        }
        apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            apdb_part[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources["apdb_part"] = apdb_part
        return sources

    def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int:
        """Calculate time partition number for a given time.

        Parameters
        ----------
        time : `float` or `lsst.daf.base.DateTime`
            Time for which to calculate partition number. A `float` value is
            interpreted as MJD.

        Returns
        -------
        partition : `int`
            Partition number for a given time.
        """
        if isinstance(time, dafBase.DateTime):
            mjd = time.get(system=dafBase.DateTime.MJD)
        else:
            mjd = time
        days_since_epoch = mjd - self._partition_zero_epoch_mjd
        partition = int(days_since_epoch) // self.config.time_partition_days
        return partition
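
    # Worked example, illustrative only: with partition_zero_epoch at
    # 1970-01-01 (MJD 40587) and the default time_partition_days = 30, a time
    # of MJD 59500 maps to partition int(59500 - 40587) // 30 = 18913 // 30 = 630.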

    def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
        """Make an empty catalog for a table with a given name.

        Parameters
        ----------
        table_name : `ApdbTables`
            Name of the table.

        Returns
        -------
        catalog : `pandas.DataFrame`
            An empty catalog.
        """
        table = self._schema.tableSchemas[table_name]

        data = {columnDef.name: pandas.Series(dtype=columnDef.dtype) for columnDef in table.columns}
        return pandas.DataFrame(data)

    def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement:
        """Convert query string into prepared statement."""
        stmt = self._prepared_statements.get(query)
        if stmt is None:
            stmt = self._session.prepare(query)
            self._prepared_statements[query] = stmt
        return stmt

    def _combine_where(
        self,
        prefix: str,
        where1: List[Tuple[str, Tuple]],
        where2: List[Tuple[str, Tuple]],
        suffix: Optional[str] = None,
    ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]:
        """Make cartesian product of two parts of WHERE clause into a series
        of statements to execute.

        Parameters
        ----------
        prefix : `str`
            Initial statement prefix that comes before WHERE clause, e.g.
            "SELECT * from Table".
        where1, where2 : `list` [ `tuple` ]
            Two lists of (expression, parameters) tuples to combine; an empty
            list contributes nothing to the WHERE clause.
        suffix : `str`, optional
            Optional suffix appended after the WHERE clause, e.g.
            "ALLOW FILTERING".
        """
        # If lists are empty use special sentinels.
        if not where1:
            where1 = [("", ())]
        if not where2:
            where2 = [("", ())]

        for expr1, params1 in where1:
            for expr2, params2 in where2:
                full_query = prefix
                wheres = []
                if expr1:
                    wheres.append(expr1)
                if expr2:
                    wheres.append(expr2)
                if wheres:
                    full_query += " WHERE " + " AND ".join(wheres)
                if suffix:
                    full_query += " " + suffix
                params = params1 + params2
                if params:
                    statement = self._prep_statement(full_query)
                else:
                    # If there are no params then it is likely that the query
                    # has a bunch of literals rendered already; there is no
                    # point trying to prepare it.
                    statement = cassandra.query.SimpleStatement(full_query)
                yield (statement, params)
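
    # Illustrative sketch, not part of the module, of what _combine_where
    # yields: one spatial expression combined with one temporal expression
    # (partition values are hypothetical).
    #
    #     where1 = [('"apdb_part" = ?', (100,))]
    #     where2 = [('"apdb_time_part" IN (630,631)', ())]
    #     # combined with prefix 'SELECT * from "apdb"."DiaSource"' this yields
    #     # a single statement equivalent to
    #     #   SELECT * from "apdb"."DiaSource"
    #     #   WHERE "apdb_part" = ? AND "apdb_time_part" IN (630,631)
    #     # with parameters (100,)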

    def _spatial_where(
        self, region: Optional[sphgeom.Region], use_ranges: bool = False
    ) -> List[Tuple[str, Tuple]]:
        """Generate expressions for spatial part of WHERE clause.

        Parameters
        ----------
        region : `sphgeom.Region`
            Spatial region for query results.
        use_ranges : `bool`
            If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <=
            p2") instead of an exact list of pixels. Should be set to True for
            large regions covering very many pixels.

        Returns
        -------
        expressions : `list` [ `tuple` ]
            Empty list is returned if ``region`` is `None`, otherwise a list
            of one or more (expression, parameters) tuples.
        """
        if region is None:
            return []
        if use_ranges:
            pixel_ranges = self._pixelization.envelope(region)
            expressions: List[Tuple[str, Tuple]] = []
            for lower, upper in pixel_ranges:
                upper -= 1
                if lower == upper:
                    expressions.append(('"apdb_part" = ?', (lower, )))
                else:
                    expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper)))
            return expressions
        else:
            pixels = self._pixelization.pixels(region)
            if self.config.query_per_spatial_part:
                return [('"apdb_part" = ?', (pixel,)) for pixel in pixels]
            else:
                pixels_str = ",".join([str(pix) for pix in pixels])
                return [(f'"apdb_part" IN ({pixels_str})', ())]
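
    # Illustrative sketch, not part of the module: for a region covering
    # pixels 100-102 (hypothetical values) this returns
    #
    #     use_ranges=True              -> [('"apdb_part" >= ? AND "apdb_part" <= ?', (100, 102))]
    #     query_per_spatial_part=True  -> [('"apdb_part" = ?', (100,)),
    #                                      ('"apdb_part" = ?', (101,)),
    #                                      ('"apdb_part" = ?', (102,))]
    #     default                      -> [('"apdb_part" IN (100,101,102)', ())]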

    def _temporal_where(
        self,
        table: ApdbTables,
        start_time: Union[float, dafBase.DateTime],
        end_time: Union[float, dafBase.DateTime],
        query_per_time_part: Optional[bool] = None,
    ) -> Tuple[List[str], List[Tuple[str, Tuple]]]:
        """Generate table names and expressions for temporal part of WHERE
        clauses.

        Parameters
        ----------
        table : `ApdbTables`
            Table to select from.
        start_time : `dafBase.DateTime` or `float`
            Starting time of the time range, as DateTime or MJD value.
        end_time : `dafBase.DateTime` or `float`
            Ending time of the time range, as DateTime or MJD value.
        query_per_time_part : `bool`, optional
            If `None` then use ``query_per_time_part`` from configuration.

        Returns
        -------
        tables : `list` [ `str` ]
            List of the table names to query.
        expressions : `list` [ `tuple` ]
            A list of zero or more (expression, parameters) tuples.
        """
        tables: List[str]
        temporal_where: List[Tuple[str, Tuple]] = []
        table_name = self._schema.tableName(table)
        time_part_start = self._time_partition(start_time)
        time_part_end = self._time_partition(end_time)
        time_parts = list(range(time_part_start, time_part_end + 1))
        if self.config.time_partition_tables:
            tables = [f"{table_name}_{part}" for part in time_parts]
        else:
            tables = [table_name]
            if query_per_time_part is None:
                query_per_time_part = self.config.query_per_time_part
            if query_per_time_part:
                temporal_where = [
                    ('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts
                ]
            else:
                time_part_list = ",".join([str(part) for part in time_parts])
                temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())]

        return tables, temporal_where
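
    # Illustrative sketch, not part of the module: for a time range spanning
    # partitions 630 and 631 of the DiaSource table (hypothetical values) this
    # returns
    #
    #     time_partition_tables=True  -> (["DiaSource_630", "DiaSource_631"], [])
    #     query_per_time_part=True    -> (["DiaSource"],
    #                                     [('"apdb_time_part" = ?', (630,)),
    #                                      ('"apdb_time_part" = ?', (631,))])
    #     otherwise                   -> (["DiaSource"],
    #                                     [('"apdb_time_part" IN (630,631)', ())])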