Coverage for python/lsst/dax/apdb/apdbCassandra.py: 14% (449 statements)
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"]
26import logging
27import uuid
28from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union, cast
30import numpy as np
31import pandas
33# If cassandra-driver is not there the module can still be imported
34# but ApdbCassandra cannot be instantiated.
35try:
36 import cassandra
37 import cassandra.query
38 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile
39 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy
41 CASSANDRA_IMPORTED = True
42except ImportError:
43 CASSANDRA_IMPORTED = False
45import felis.types
46import lsst.daf.base as dafBase
47from felis.simple import Table
48from lsst import sphgeom
49from lsst.pex.config import ChoiceField, Field, ListField
50from lsst.utils.iteration import chunk_iterable
52from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
53from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables
54from .apdbSchema import ApdbTables
55from .cassandra_utils import (
56 ApdbCassandraTableData,
57 literal,
58 pandas_dataframe_factory,
59 quote_id,
60 raw_data_factory,
61 select_concurrent,
62)
63from .pixelization import Pixelization
64from .timer import Timer
66_LOG = logging.getLogger(__name__)
69class CassandraMissingError(Exception):
70 def __init__(self) -> None:
71 super().__init__("cassandra-driver module cannot be imported")
74class ApdbCassandraConfig(ApdbConfig):
76 contact_points = ListField[str](
77 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"]
78 )
79 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[])
80 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb")
81 read_consistency = Field[str](
82 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM"
83 )
84 write_consistency = Field[str](
85 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM"
86 )
87 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0)
88 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=10.0)
89 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500)
90 protocol_version = Field[int](
91 doc="Cassandra protocol version to use, default is V4",
92 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0,
93 )
94 dia_object_columns = ListField[str](
95 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[]
96 )
97 prefix = Field[str](doc="Prefix to add to table names", default="")
98 part_pixelization = ChoiceField[str](
99 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"),
100 doc="Pixelization used for partitioning index.",
101 default="mq3c",
102 )
103 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10)
104 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64)
    ra_dec_columns = ListField[str](default=["ra", "decl"], doc="Names of ra/dec columns in the DiaObject table")
106 timer = Field[bool](doc="If True then print/log timing information", default=False)
107 time_partition_tables = Field[bool](
108 doc="Use per-partition tables for sources instead of partitioning by time", default=True
109 )
    time_partition_days = Field[int](
        doc=(
            "Time partitioning granularity in days. This value must not be changed after the "
            "database is initialized."
        ),
        default=30,
    )
117 time_partition_start = Field[str](
118 doc=(
119 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. "
120 "This is used only when time_partition_tables is True."
121 ),
122 default="2018-12-01T00:00:00",
123 )
124 time_partition_end = Field[str](
125 doc=(
126 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. "
127 "This is used only when time_partition_tables is True."
128 ),
129 default="2030-01-01T00:00:00",
130 )
    query_per_time_part = Field[bool](
        default=False,
        doc=(
            "If True then build a separate query for each time partition, otherwise build one single "
            "query. This is only used when time_partition_tables is False in the schema config."
        ),
    )
    query_per_spatial_part = Field[bool](
        default=False,
        doc="If True then build one query per spatial partition, otherwise build a single query.",
    )
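
# Illustrative sketch (not part of the module): overriding a few of the
# configuration fields above before creating an APDB instance. The contact
# points and keyspace name are hypothetical.
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["10.0.0.10", "10.0.0.11"]
#     config.keyspace = "apdb_dev"
#     config.read_consistency = "ONE"
#     config.time_partition_tables = True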
if CASSANDRA_IMPORTED:

    class _AddressTranslator(AddressTranslator):
        """Translate internal IP address to external.

        Only used for docker-based setup, not a viable long-term solution.
        """

        def __init__(self, public_ips: List[str], private_ips: List[str]):
            self._map = dict((k, v) for k, v in zip(private_ips, public_ips))

        def translate(self, private_ip: str) -> str:
            return self._map.get(private_ip, private_ip)
159def _quote_column(name: str) -> str:
160 """Quote column name"""
161 if name.islower():
162 return name
163 else:
164 return f'"{name}"'
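
# For example, the quoting behaviour follows directly from the check above:
#
#     _quote_column("ra")           # -> 'ra'
#     _quote_column("diaObjectId")  # -> '"diaObjectId"'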
class ApdbCassandra(Apdb):
    """Implementation of APDB database on top of Apache Cassandra.

    The implementation is configured via the standard ``pex_config`` mechanism
    using the `ApdbCassandraConfig` configuration class. For examples of
    different configurations check the config/ folder.

    Parameters
    ----------
    config : `ApdbCassandraConfig`
        Configuration object.
    """

    partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI)
    """Start time for partition 0; this should never be changed."""
183 def __init__(self, config: ApdbCassandraConfig):
185 if not CASSANDRA_IMPORTED:
186 raise CassandraMissingError()
188 config.validate()
189 self.config = config
191 _LOG.debug("ApdbCassandra Configuration:")
192 for key, value in self.config.items():
193 _LOG.debug(" %s: %s", key, value)
195 self._pixelization = Pixelization(
196 config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges
197 )
199 addressTranslator: Optional[AddressTranslator] = None
200 if config.private_ips:
201 addressTranslator = _AddressTranslator(config.contact_points, config.private_ips)
203 self._keyspace = config.keyspace
205 self._cluster = Cluster(
206 execution_profiles=self._makeProfiles(config),
207 contact_points=self.config.contact_points,
208 address_translator=addressTranslator,
209 protocol_version=self.config.protocol_version,
210 )
211 self._session = self._cluster.connect()
212 # Disable result paging
213 self._session.default_fetch_size = None
215 self._schema = ApdbCassandraSchema(
216 session=self._session,
217 keyspace=self._keyspace,
218 schema_file=self.config.schema_file,
219 schema_name=self.config.schema_name,
220 prefix=self.config.prefix,
221 time_partition_tables=self.config.time_partition_tables,
222 use_insert_id=self.config.use_insert_id,
223 )
224 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD)
226 # Cache for prepared statements
227 self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {}
229 def __del__(self) -> None:
230 self._cluster.shutdown()
232 def tableDef(self, table: ApdbTables) -> Optional[Table]:
233 # docstring is inherited from a base class
234 return self._schema.tableSchemas.get(table)
236 def makeSchema(self, drop: bool = False) -> None:
237 # docstring is inherited from a base class
239 if self.config.time_partition_tables:
240 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI)
241 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI)
242 part_range = (
243 self._time_partition(time_partition_start),
244 self._time_partition(time_partition_end) + 1,
245 )
246 self._schema.makeSchema(drop=drop, part_range=part_range)
247 else:
248 self._schema.makeSchema(drop=drop)
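
    # Illustrative sketch (assumes a reachable Cassandra cluster and the
    # hypothetical ``config`` object from the example near the top of this
    # module):
    #
    #     apdb = ApdbCassandra(config)
    #     apdb.makeSchema(drop=False)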
250 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
251 # docstring is inherited from a base class
253 sp_where = self._spatial_where(region)
254 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))
256 # We need to exclude extra partitioning columns from result.
257 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast)
258 what = ",".join(_quote_column(column) for column in column_names)
260 table_name = self._schema.tableName(ApdbTables.DiaObjectLast)
261 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"'
262 statements: List[Tuple] = []
263 for where, params in sp_where:
264 full_query = f"{query} WHERE {where}"
265 if params:
266 statement = self._prep_statement(full_query)
267 else:
268 # If there are no params then it is likely that query has a
269 # bunch of literals rendered already, no point trying to
270 # prepare it because it's not reusable.
271 statement = cassandra.query.SimpleStatement(full_query)
272 statements.append((statement, params))
273 _LOG.debug("getDiaObjects: #queries: %s", len(statements))
275 with Timer("DiaObject select", self.config.timer):
276 objects = cast(
277 pandas.DataFrame,
278 select_concurrent(
279 self._session, statements, "read_pandas_multi", self.config.read_concurrency
280 ),
281 )
283 _LOG.debug("found %s DiaObjects", objects.shape[0])
284 return objects
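
    # Illustrative sketch, assuming the usual ``lsst.sphgeom`` API: querying
    # the latest DiaObject versions inside a small circular region (all
    # values hypothetical).
    #
    #     center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
    #     region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
    #     objects = apdb.getDiaObjects(region)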
286 def getDiaSources(
287 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
288 ) -> Optional[pandas.DataFrame]:
289 # docstring is inherited from a base class
290 months = self.config.read_sources_months
291 if months == 0:
292 return None
293 mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
294 mjd_start = mjd_end - months * 30
296 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)
298 def getDiaForcedSources(
299 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
300 ) -> Optional[pandas.DataFrame]:
301 # docstring is inherited from a base class
302 months = self.config.read_forced_sources_months
303 if months == 0:
304 return None
305 mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
306 mjd_start = mjd_end - months * 30
308 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)
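
    # Sketch of the lookback window used by the two methods above: with
    # ``read_sources_months`` set to 12, sources newer than roughly
    # ``visit_time - 12 * 30`` days (in MJD) are returned, and a value of 0
    # disables the query entirely.
    #
    #     visit_time = dafBase.DateTime("2022-01-01T00:00:00", dafBase.DateTime.TAI)
    #     sources = apdb.getDiaSources(region, object_ids=None, visit_time=visit_time)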
310 def getInsertIds(self) -> list[ApdbInsertId] | None:
311 # docstring is inherited from a base class
312 if not self._schema.has_insert_id:
313 return None
315 # everything goes into a single partition
316 partition = 0
318 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
319 query = f'SELECT insert_time, insert_id FROM "{self._keyspace}"."{table_name}" WHERE partition = ?'
321 result = self._session.execute(
322 self._prep_statement(query),
323 (partition,),
324 timeout=self.config.read_timeout,
325 execution_profile="read_tuples",
326 )
327 # order by insert_time
328 rows = sorted(result)
329 return [ApdbInsertId(row[1]) for row in rows]
331 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
332 # docstring is inherited from a base class
333 if not self._schema.has_insert_id:
334 raise ValueError("APDB is not configured for history storage")
336 insert_ids = [id.id for id in ids]
337 params = ",".join("?" * len(insert_ids))
339 # everything goes into a single partition
340 partition = 0
342 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
343 query = (
344 f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE partition = ? and insert_id IN ({params})'
345 )
347 self._session.execute(
348 self._prep_statement(query),
349 [partition] + insert_ids,
350 timeout=self.config.write_timeout,
351 )
353 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
354 # docstring is inherited from a base class
355 return self._get_history(ExtraTables.DiaObjectInsertId, ids)
357 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
358 # docstring is inherited from a base class
359 return self._get_history(ExtraTables.DiaSourceInsertId, ids)
361 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
362 # docstring is inherited from a base class
363 return self._get_history(ExtraTables.DiaForcedSourceInsertId, ids)
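
    # Illustrative sketch: retrieving and pruning history by insert ID. This
    # only works when the schema was created with ``use_insert_id`` enabled.
    #
    #     insert_ids = apdb.getInsertIds()
    #     if insert_ids:
    #         history = apdb.getDiaSourcesHistory(insert_ids[-1:])
    #         apdb.deleteInsertIds(insert_ids[:-1])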
365 def getSSObjects(self) -> pandas.DataFrame:
366 # docstring is inherited from a base class
367 tableName = self._schema.tableName(ApdbTables.SSObject)
368 query = f'SELECT * from "{self._keyspace}"."{tableName}"'
370 objects = None
371 with Timer("SSObject select", self.config.timer):
372 result = self._session.execute(query, execution_profile="read_pandas")
373 objects = result._current_rows
        _LOG.debug("found %s SSObjects", objects.shape[0])
376 return objects
378 def store(
379 self,
380 visit_time: dafBase.DateTime,
381 objects: pandas.DataFrame,
382 sources: Optional[pandas.DataFrame] = None,
383 forced_sources: Optional[pandas.DataFrame] = None,
384 ) -> None:
385 # docstring is inherited from a base class
387 insert_id: ApdbInsertId | None = None
388 if self._schema.has_insert_id:
389 insert_id = ApdbInsertId.new_insert_id()
390 self._storeInsertId(insert_id, visit_time)
392 # fill region partition column for DiaObjects
393 objects = self._add_obj_part(objects)
394 self._storeDiaObjects(objects, visit_time, insert_id)
396 if sources is not None:
397 # copy apdb_part column from DiaObjects to DiaSources
398 sources = self._add_src_part(sources, objects)
399 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, insert_id)
400 self._storeDiaSourcesPartitions(sources, visit_time, insert_id)
402 if forced_sources is not None:
403 forced_sources = self._add_fsrc_part(forced_sources, objects)
404 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, insert_id)
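
    # Illustrative sketch of a store() call. The DataFrames are heavily
    # abbreviated; the real APDB schema requires many more columns than
    # shown here.
    #
    #     objects = pandas.DataFrame({"diaObjectId": [1], "ra": [45.0], "decl": [-30.0]})
    #     sources = pandas.DataFrame(
    #         {"diaSourceId": [10], "diaObjectId": [1], "ra": [45.0], "decl": [-30.0]}
    #     )
    #     apdb.store(visit_time, objects, sources)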
406 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
407 # docstring is inherited from a base class
408 self._storeObjectsPandas(objects, ApdbTables.SSObject)
410 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
411 # docstring is inherited from a base class
413 # To update a record we need to know its exact primary key (including
414 # partition key) so we start by querying for diaSourceId to find the
415 # primary keys.
417 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition)
418 # split it into 1k IDs per query
419 selects: List[Tuple] = []
420 for ids in chunk_iterable(idMap.keys(), 1_000):
421 ids_str = ",".join(str(item) for item in ids)
422 selects.append(
423 (
424 (
425 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "insert_id" '
426 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})'
427 ),
428 {},
429 )
430 )
432 # No need for DataFrame here, read data as tuples.
433 result = cast(
434 List[Tuple[int, int, int, uuid.UUID | None]],
435 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency),
436 )
438 # Make mapping from source ID to its partition.
439 id2partitions: Dict[int, Tuple[int, int]] = {}
440 id2insert_id: Dict[int, ApdbInsertId] = {}
441 for row in result:
442 id2partitions[row[0]] = row[1:3]
443 if row[3] is not None:
444 id2insert_id[row[0]] = ApdbInsertId(row[3])
446 # make sure we know partitions for each ID
447 if set(id2partitions) != set(idMap):
448 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
449 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
451 # Reassign in standard tables
452 queries = cassandra.query.BatchStatement()
453 table_name = self._schema.tableName(ApdbTables.DiaSource)
454 for diaSourceId, ssObjectId in idMap.items():
455 apdb_part, apdb_time_part = id2partitions[diaSourceId]
456 values: Tuple
457 if self.config.time_partition_tables:
458 query = (
459 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"'
460 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
461 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?'
462 )
463 values = (ssObjectId, apdb_part, diaSourceId)
464 else:
465 query = (
466 f'UPDATE "{self._keyspace}"."{table_name}"'
467 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
468 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?'
469 )
470 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId)
471 queries.add(self._prep_statement(query), values)
473 # Reassign in history tables, only if history is enabled
474 if id2insert_id:
475 # Filter out insert ids that have been deleted already. There is a
476 # potential race with concurrent removal of insert IDs, but it
477 # should be handled by WHERE in UPDATE.
478 known_ids = set()
479 if insert_ids := self.getInsertIds():
480 known_ids = set(insert_ids)
481 id2insert_id = {key: value for key, value in id2insert_id.items() if value in known_ids}
482 if id2insert_id:
483 table_name = self._schema.tableName(ExtraTables.DiaSourceInsertId)
484 for diaSourceId, ssObjectId in idMap.items():
485 if insert_id := id2insert_id.get(diaSourceId):
486 query = (
487 f'UPDATE "{self._keyspace}"."{table_name}" '
488 ' SET "ssObjectId" = ?, "diaObjectId" = NULL '
489 'WHERE "insert_id" = ? AND "diaSourceId" = ?'
490 )
491 values = (ssObjectId, insert_id.id, diaSourceId)
492 queries.add(self._prep_statement(query), values)
494 _LOG.debug("%s: will update %d records", table_name, len(idMap))
495 with Timer(table_name + " update", self.config.timer):
496 self._session.execute(queries, execution_profile="write")
498 def dailyJob(self) -> None:
499 # docstring is inherited from a base class
500 pass
502 def countUnassociatedObjects(self) -> int:
503 # docstring is inherited from a base class
505 # It's too inefficient to implement it for Cassandra in current schema.
506 raise NotImplementedError()
508 def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]:
509 """Make all execution profiles used in the code."""
511 if config.private_ips:
512 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
513 else:
514 loadBalancePolicy = RoundRobinPolicy()
516 read_tuples_profile = ExecutionProfile(
517 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
518 request_timeout=config.read_timeout,
519 row_factory=cassandra.query.tuple_factory,
520 load_balancing_policy=loadBalancePolicy,
521 )
522 read_pandas_profile = ExecutionProfile(
523 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
524 request_timeout=config.read_timeout,
525 row_factory=pandas_dataframe_factory,
526 load_balancing_policy=loadBalancePolicy,
527 )
528 read_raw_profile = ExecutionProfile(
529 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
530 request_timeout=config.read_timeout,
531 row_factory=raw_data_factory,
532 load_balancing_policy=loadBalancePolicy,
533 )
534 # Profile to use with select_concurrent to return pandas data frame
535 read_pandas_multi_profile = ExecutionProfile(
536 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
537 request_timeout=config.read_timeout,
538 row_factory=pandas_dataframe_factory,
539 load_balancing_policy=loadBalancePolicy,
540 )
541 # Profile to use with select_concurrent to return raw data (columns and
542 # rows)
543 read_raw_multi_profile = ExecutionProfile(
544 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
545 request_timeout=config.read_timeout,
546 row_factory=raw_data_factory,
547 load_balancing_policy=loadBalancePolicy,
548 )
549 write_profile = ExecutionProfile(
550 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency),
551 request_timeout=config.write_timeout,
552 load_balancing_policy=loadBalancePolicy,
553 )
554 # To replace default DCAwareRoundRobinPolicy
555 default_profile = ExecutionProfile(
556 load_balancing_policy=loadBalancePolicy,
557 )
558 return {
559 "read_tuples": read_tuples_profile,
560 "read_pandas": read_pandas_profile,
561 "read_raw": read_raw_profile,
562 "read_pandas_multi": read_pandas_multi_profile,
563 "read_raw_multi": read_raw_multi_profile,
564 "write": write_profile,
565 EXEC_PROFILE_DEFAULT: default_profile,
566 }
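
    # The profile names returned above are referenced by name elsewhere in
    # this class, for example:
    #
    #     self._session.execute(query, execution_profile="read_pandas")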
568 def _getSources(
569 self,
570 region: sphgeom.Region,
571 object_ids: Optional[Iterable[int]],
572 mjd_start: float,
573 mjd_end: float,
574 table_name: ApdbTables,
575 ) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given a set of DiaObject IDs.
578 Parameters
579 ----------
580 region : `lsst.sphgeom.Region`
581 Spherical region.
        object_ids : iterable [ `int` ], or `None`
            Collection of DiaObject IDs to filter on; if `None`, no filtering
            is done.
584 mjd_start : `float`
585 Lower bound of time interval.
586 mjd_end : `float`
587 Upper bound of time interval.
588 table_name : `ApdbTables`
589 Name of the table.
591 Returns
592 -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is returned
            if ``object_ids`` is empty.
596 """
597 object_id_set: Set[int] = set()
598 if object_ids is not None:
599 object_id_set = set(object_ids)
600 if len(object_id_set) == 0:
601 return self._make_empty_catalog(table_name)
603 sp_where = self._spatial_where(region)
604 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end)
606 # We need to exclude extra partitioning columns from result.
607 column_names = self._schema.apdbColumnNames(table_name)
608 what = ",".join(_quote_column(column) for column in column_names)
610 # Build all queries
611 statements: List[Tuple] = []
612 for table in tables:
613 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"'
614 statements += list(self._combine_where(prefix, sp_where, temporal_where))
615 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))
617 with Timer(table_name.name + " select", self.config.timer):
618 catalog = cast(
619 pandas.DataFrame,
620 select_concurrent(
621 self._session, statements, "read_pandas_multi", self.config.read_concurrency
622 ),
623 )
625 # filter by given object IDs
626 if len(object_id_set) > 0:
627 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])
629 # precise filtering on midPointTai
630 catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_start])
632 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
633 return catalog
635 def _get_history(self, table: ExtraTables, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
636 """Return records from a particular table given set of insert IDs."""
637 if not self._schema.has_insert_id:
638 raise ValueError("APDB is not configured for history retrieval")
640 insert_ids = [id.id for id in ids]
641 params = ",".join("?" * len(insert_ids))
643 table_name = self._schema.tableName(table)
        # The history table schema has only the regular APDB columns plus an
        # insert_id column, which is exactly what this method needs to return,
        # so ``SELECT *`` is fine here.
647 query = f'SELECT * FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})'
648 statement = self._prep_statement(query)
650 with Timer("DiaObject history", self.config.timer):
651 result = self._session.execute(statement, insert_ids, execution_profile="read_raw")
652 table_data = cast(ApdbCassandraTableData, result._current_rows)
653 return table_data
655 def _storeInsertId(self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime) -> None:
657 # Cassandra timestamp uses milliseconds since epoch
658 timestamp = visit_time.nsecs() // 1_000_000
660 # everything goes into a single partition
661 partition = 0
663 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
664 query = (
665 f'INSERT INTO "{self._keyspace}"."{table_name}" (partition, insert_id, insert_time) '
666 "VALUES (?, ?, ?)"
667 )
669 self._session.execute(
670 self._prep_statement(query),
671 (partition, insert_id.id, timestamp),
672 timeout=self.config.write_timeout,
673 execution_profile="write",
674 )
676 def _storeDiaObjects(
677 self, objs: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None
678 ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier for history tracking, or `None` if history is
            not kept.
        """
688 visit_time_dt = visit_time.toPython()
689 extra_columns = dict(lastNonForcedSource=visit_time_dt)
690 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)
692 extra_columns["validityStart"] = visit_time_dt
693 time_part: Optional[int] = self._time_partition(visit_time)
694 if not self.config.time_partition_tables:
695 extra_columns["apdb_time_part"] = time_part
696 time_part = None
698 self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part)
700 if insert_id is not None:
701 extra_columns = dict(insert_id=insert_id.id, validityStart=visit_time_dt)
702 self._storeObjectsPandas(objs, ExtraTables.DiaObjectInsertId, extra_columns=extra_columns)
704 def _storeDiaSources(
705 self,
706 table_name: ApdbTables,
707 sources: pandas.DataFrame,
708 visit_time: dafBase.DateTime,
709 insert_id: ApdbInsertId | None,
710 ) -> None:
        """Store catalog of DiaSources or DiaForcedSources from current visit.

        Parameters
        ----------
        table_name : `ApdbTables`
            Table where to store the data, either DiaSource or DiaForcedSource.
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier for history tracking, or `None` if history is
            not kept.
        """
720 time_part: Optional[int] = self._time_partition(visit_time)
721 extra_columns: dict[str, Any] = {}
722 if not self.config.time_partition_tables:
723 extra_columns["apdb_time_part"] = time_part
724 time_part = None
726 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part)
728 if insert_id is not None:
729 extra_columns = dict(insert_id=insert_id.id)
730 if table_name is ApdbTables.DiaSource:
731 extra_table = ExtraTables.DiaSourceInsertId
732 else:
733 extra_table = ExtraTables.DiaForcedSourceInsertId
734 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns)
736 def _storeDiaSourcesPartitions(
737 self, sources: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None
738 ) -> None:
739 """Store mapping of diaSourceId to its partitioning values.
741 Parameters
742 ----------
743 sources : `pandas.DataFrame`
744 Catalog containing DiaSource records
745 visit_time : `lsst.daf.base.DateTime`
746 Time of the current visit.
747 """
748 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
749 extra_columns = {
750 "apdb_time_part": self._time_partition(visit_time),
751 "insert_id": insert_id.id if insert_id is not None else None,
752 }
754 self._storeObjectsPandas(
755 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
756 )
758 def _storeObjectsPandas(
759 self,
760 records: pandas.DataFrame,
761 table_name: Union[ApdbTables, ExtraTables],
762 extra_columns: Optional[Mapping] = None,
763 time_part: Optional[int] = None,
764 ) -> None:
        """Generic store method.

        Take a Pandas catalog and store its records in a table.
769 Parameters
770 ----------
771 records : `pandas.DataFrame`
772 Catalog containing object records
        table_name : `ApdbTables` or `ExtraTables`
            Name of the table as defined in APDB schema.
775 extra_columns : `dict`, optional
776 Mapping (column_name, column_value) which gives fixed values for
777 columns in each row, overrides values in ``records`` if matching
778 columns exist there.
779 time_part : `int`, optional
780 If not `None` then insert into a per-partition table.
782 Notes
783 -----
        If the Pandas catalog contains additional columns not defined in the
        table schema, they are ignored. The catalog does not have to contain
        all columns defined in a table, but partition and clustering keys must
        be present either in the catalog or in ``extra_columns``.
788 """
789 # use extra columns if specified
790 if extra_columns is None:
791 extra_columns = {}
792 extra_fields = list(extra_columns.keys())
794 # Fields that will come from dataframe.
795 df_fields = [column for column in records.columns if column not in extra_fields]
797 column_map = self._schema.getColumnMap(table_name)
798 # list of columns (as in felis schema)
799 fields = [column_map[field].name for field in df_fields if field in column_map]
800 fields += extra_fields
802 # check that all partitioning and clustering columns are defined
803 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns(
804 table_name
805 )
806 missing_columns = [column for column in required_columns if column not in fields]
807 if missing_columns:
808 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")
810 qfields = [quote_id(field) for field in fields]
811 qfields_str = ",".join(qfields)
813 with Timer(table_name.name + " query build", self.config.timer):
815 table = self._schema.tableName(table_name)
816 if time_part is not None:
817 table = f"{table}_{time_part}"
819 holders = ",".join(["?"] * len(qfields))
820 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})'
821 statement = self._prep_statement(query)
822 queries = cassandra.query.BatchStatement()
823 for rec in records.itertuples(index=False):
824 values = []
825 for field in df_fields:
826 if field not in column_map:
827 continue
828 value = getattr(rec, field)
829 if column_map[field].datatype is felis.types.Timestamp:
830 if isinstance(value, pandas.Timestamp):
831 value = literal(value.to_pydatetime())
832 else:
833 # Assume it's seconds since epoch, Cassandra
834 # datetime is in milliseconds
835 value = int(value * 1000)
836 values.append(literal(value))
837 for field in extra_fields:
838 value = extra_columns[field]
839 values.append(literal(value))
840 queries.add(statement, values)
842 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0])
843 with Timer(table_name.name + " insert", self.config.timer):
844 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write")
    def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate spatial partition for each record and add it to a
        DataFrame.

        Notes
        -----
        This overrides any existing column with the same name (apdb_part).
        The original DataFrame is not changed; a copy is returned.
        """
        # Calculate pixel index for every DiaObject using the configured
        # pixelization.
857 apdb_part = np.zeros(df.shape[0], dtype=np.int64)
858 ra_col, dec_col = self.config.ra_dec_columns
859 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
860 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
861 idx = self._pixelization.pixel(uv3d)
862 apdb_part[i] = idx
863 df = df.copy()
864 df["apdb_part"] = apdb_part
865 return df
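
    # Sketch of the per-row computation above for a single position; the
    # resulting index depends on the configured pixelization and level, and
    # ``apdb`` here is a hypothetical instance.
    #
    #     uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
    #     pixel_index = apdb._pixelization.pixel(uv3d)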
867 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add apdb_part column to DiaSource catalog.

        Notes
        -----
        This method copies the apdb_part value from the matching DiaObject
        record. The DiaObject catalog needs to have an apdb_part column filled
        by the ``_add_obj_part`` method, and DiaSource records need to be
        associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column with the same name (apdb_part).
        The original DataFrame is not changed; a copy is returned.
        """
881 pixel_id_map: Dict[int, int] = {
882 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"])
883 }
884 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
885 ra_col, dec_col = self.config.ra_dec_columns
886 for i, (diaObjId, ra, dec) in enumerate(
887 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col])
888 ):
            if diaObjId == 0:
                # DiaSources associated with SolarSystemObjects do not have an
                # associated DiaObject, so the partition is computed from the
                # source's own ra/dec instead.
893 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
894 idx = self._pixelization.pixel(uv3d)
895 apdb_part[i] = idx
896 else:
897 apdb_part[i] = pixel_id_map[diaObjId]
898 sources = sources.copy()
899 sources["apdb_part"] = apdb_part
900 return sources
    def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add apdb_part column to DiaForcedSource catalog.

        Notes
        -----
        This method copies the apdb_part value from the matching DiaObject
        record. The DiaObject catalog needs to have an apdb_part column filled
        by the ``_add_obj_part`` method, and DiaForcedSource records need to
        be associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column with the same name (apdb_part).
        The original DataFrame is not changed; a copy is returned.
        """
916 pixel_id_map: Dict[int, int] = {
917 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"])
918 }
919 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
920 for i, diaObjId in enumerate(sources["diaObjectId"]):
921 apdb_part[i] = pixel_id_map[diaObjId]
922 sources = sources.copy()
923 sources["apdb_part"] = apdb_part
924 return sources
    def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int:
        """Calculate time partition number for a given time.

        Parameters
        ----------
        time : `float` or `lsst.daf.base.DateTime`
            Time for which to calculate partition number. A `float` value is
            interpreted as MJD.
935 Returns
936 -------
937 partition : `int`
938 Partition number for a given time.
939 """
940 if isinstance(time, dafBase.DateTime):
941 mjd = time.get(system=dafBase.DateTime.MJD)
942 else:
943 mjd = time
944 days_since_epoch = mjd - self._partition_zero_epoch_mjd
945 partition = int(days_since_epoch) // self.config.time_partition_days
946 return partition
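
    # Worked example with the defaults above (time_partition_days = 30,
    # partition zero epoch 1970-01-01, i.e. MJD 40587): an observation at
    # MJD 59580.5 (around 2022-01-01) falls into partition
    # int(59580.5 - 40587) // 30 == 18993 // 30 == 633.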
948 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
949 """Make an empty catalog for a table with a given name.
951 Parameters
952 ----------
953 table_name : `ApdbTables`
954 Name of the table.
956 Returns
957 -------
958 catalog : `pandas.DataFrame`
959 An empty catalog.
960 """
961 table = self._schema.tableSchemas[table_name]
963 data = {
964 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype))
965 for columnDef in table.columns
966 }
967 return pandas.DataFrame(data)
969 def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement:
970 """Convert query string into prepared statement."""
971 stmt = self._prepared_statements.get(query)
972 if stmt is None:
973 stmt = self._session.prepare(query)
974 self._prepared_statements[query] = stmt
975 return stmt
977 def _combine_where(
978 self,
979 prefix: str,
980 where1: List[Tuple[str, Tuple]],
981 where2: List[Tuple[str, Tuple]],
982 suffix: Optional[str] = None,
983 ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]:
984 """Make cartesian product of two parts of WHERE clause into a series
985 of statements to execute.
987 Parameters
988 ----------
989 prefix : `str`
990 Initial statement prefix that comes before WHERE clause, e.g.
991 "SELECT * from Table"
992 """
993 # If lists are empty use special sentinels.
994 if not where1:
995 where1 = [("", ())]
996 if not where2:
997 where2 = [("", ())]
999 for expr1, params1 in where1:
1000 for expr2, params2 in where2:
1001 full_query = prefix
1002 wheres = []
1003 if expr1:
1004 wheres.append(expr1)
1005 if expr2:
1006 wheres.append(expr2)
1007 if wheres:
1008 full_query += " WHERE " + " AND ".join(wheres)
1009 if suffix:
1010 full_query += " " + suffix
1011 params = params1 + params2
1012 if params:
1013 statement = self._prep_statement(full_query)
1014 else:
1015 # If there are no params then it is likely that query
1016 # has a bunch of literals rendered already, no point
1017 # trying to prepare it.
1018 statement = cassandra.query.SimpleStatement(full_query)
1019 yield (statement, params)
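
    # Example of the cartesian product above: with
    # where1 = [('"apdb_part" = ?', (1,))] and
    # where2 = [('"apdb_time_part" = ?', (633,)), ('"apdb_time_part" = ?', (634,))]
    # two statements are generated, each of the form
    # '<prefix> WHERE "apdb_part" = ? AND "apdb_time_part" = ?', with
    # parameters (1, 633) and (1, 634) respectively.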
1021 def _spatial_where(
1022 self, region: Optional[sphgeom.Region], use_ranges: bool = False
1023 ) -> List[Tuple[str, Tuple]]:
1024 """Generate expressions for spatial part of WHERE clause.
1026 Parameters
1027 ----------
1028 region : `sphgeom.Region`
1029 Spatial region for query results.
1030 use_ranges : `bool`
1031 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <=
1032 p2") instead of exact list of pixels. Should be set to True for
1033 large regions covering very many pixels.
1035 Returns
1036 -------
1037 expressions : `list` [ `tuple` ]
1038 Empty list is returned if ``region`` is `None`, otherwise a list
1039 of one or more (expression, parameters) tuples
1040 """
1041 if region is None:
1042 return []
1043 if use_ranges:
1044 pixel_ranges = self._pixelization.envelope(region)
1045 expressions: List[Tuple[str, Tuple]] = []
1046 for lower, upper in pixel_ranges:
1047 upper -= 1
1048 if lower == upper:
1049 expressions.append(('"apdb_part" = ?', (lower,)))
1050 else:
1051 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper)))
1052 return expressions
1053 else:
1054 pixels = self._pixelization.pixels(region)
1055 if self.config.query_per_spatial_part:
1056 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels]
1057 else:
1058 pixels_str = ",".join([str(pix) for pix in pixels])
1059 return [(f'"apdb_part" IN ({pixels_str})', ())]
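
    # Example shapes of the expressions returned above (pixel values are
    # hypothetical):
    #
    #     use_ranges=True  -> [('"apdb_part" >= ? AND "apdb_part" <= ?', (1000, 1015))]
    #     use_ranges=False, query_per_spatial_part=False
    #                      -> [('"apdb_part" IN (1000,1001,1002)', ())]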
1061 def _temporal_where(
1062 self,
1063 table: ApdbTables,
1064 start_time: Union[float, dafBase.DateTime],
1065 end_time: Union[float, dafBase.DateTime],
1066 query_per_time_part: Optional[bool] = None,
1067 ) -> Tuple[List[str], List[Tuple[str, Tuple]]]:
1068 """Generate table names and expressions for temporal part of WHERE
1069 clauses.
1071 Parameters
1072 ----------
1073 table : `ApdbTables`
1074 Table to select from.
        start_time : `dafBase.DateTime` or `float`
            Starting DateTime or MJD value of the time range.
        end_time : `dafBase.DateTime` or `float`
            Ending DateTime or MJD value of the time range.
1079 query_per_time_part : `bool`, optional
1080 If None then use ``query_per_time_part`` from configuration.
1082 Returns
1083 -------
1084 tables : `list` [ `str` ]
1085 List of the table names to query.
1086 expressions : `list` [ `tuple` ]
1087 A list of zero or more (expression, parameters) tuples.
1088 """
1089 tables: List[str]
1090 temporal_where: List[Tuple[str, Tuple]] = []
1091 table_name = self._schema.tableName(table)
1092 time_part_start = self._time_partition(start_time)
1093 time_part_end = self._time_partition(end_time)
1094 time_parts = list(range(time_part_start, time_part_end + 1))
1095 if self.config.time_partition_tables:
1096 tables = [f"{table_name}_{part}" for part in time_parts]
1097 else:
1098 tables = [table_name]
1099 if query_per_time_part is None:
1100 query_per_time_part = self.config.query_per_time_part
1101 if query_per_time_part:
1102 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts]
1103 else:
1104 time_part_list = ",".join([str(part) for part in time_parts])
1105 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())]
1107 return tables, temporal_where
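
    # Example of the two shapes returned above for a query covering time
    # partitions 633 and 634 of the DiaSource table (ignoring any configured
    # table name prefix):
    #
    #     time_partition_tables=True:
    #         (["DiaSource_633", "DiaSource_634"], [])
    #     time_partition_tables=False, query_per_time_part=False:
    #         (["DiaSource"], [('"apdb_time_part" IN (633,634)', ())])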