Coverage for python/lsst/dax/apdb/apdbCassandra.py: 15%
475 statements
coverage.py v7.3.2, created at 2023-10-26 10:23 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"]
26import logging
27import uuid
28from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union, cast
30import numpy as np
31import pandas
33# If cassandra-driver is not installed the module can still be imported,
34# but ApdbCassandra cannot be instantiated.
35try:
36 import cassandra
37 import cassandra.query
38 from cassandra.auth import AuthProvider, PlainTextAuthProvider
39 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile
40 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy
42 CASSANDRA_IMPORTED = True
43except ImportError:
44 CASSANDRA_IMPORTED = False
46import felis.types
47import lsst.daf.base as dafBase
48from felis.simple import Table
49from lsst import sphgeom
50from lsst.pex.config import ChoiceField, Field, ListField
51from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError
52from lsst.utils.iteration import chunk_iterable
54from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
55from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables
56from .apdbSchema import ApdbTables
57from .cassandra_utils import (
58 ApdbCassandraTableData,
59 literal,
60 pandas_dataframe_factory,
61 quote_id,
62 raw_data_factory,
63 select_concurrent,
64)
65from .pixelization import Pixelization
66from .timer import Timer
68_LOG = logging.getLogger(__name__)
70# Copied from daf_butler.
71DB_AUTH_ENVVAR = "LSST_DB_AUTH"
72"""Default name of the environmental variable that will be used to locate DB
73credentials configuration file. """
75DB_AUTH_PATH = "~/.lsst/db-auth.yaml"
76"""Default path at which it is expected that DB credentials are found."""
79class CassandraMissingError(Exception):
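 """Exception raised when the cassandra-driver module cannot be imported."""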
80 def __init__(self) -> None:
81 super().__init__("cassandra-driver module cannot be imported")
84class ApdbCassandraConfig(ApdbConfig):
85 """Configuration class for Cassandra-based APDB implementation."""
87 contact_points = ListField[str](
88 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"]
89 )
90 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[])
91 port = Field[int](doc="Port number to connect to.", default=9042)
92 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb")
93 username = Field[str](
94 doc=f"Cassandra user name, if empty then {DB_AUTH_PATH} has to provide it with password.",
95 default="",
96 )
97 read_consistency = Field[str](
98 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM"
99 )
100 write_consistency = Field[str](
101 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM"
102 )
103 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0)
104 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=10.0)
105 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500)
106 protocol_version = Field[int](
107 doc="Cassandra protocol version to use, default is V4",
108 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0,
109 )
110 dia_object_columns = ListField[str](
111 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[]
112 )
113 prefix = Field[str](doc="Prefix to add to table names", default="")
114 part_pixelization = ChoiceField[str](
115 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"),
116 doc="Pixelization used for partitioning index.",
117 default="mq3c",
118 )
119 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10)
120 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64)
121 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
122 timer = Field[bool](doc="If True then print/log timing information", default=False)
123 time_partition_tables = Field[bool](
124 doc="Use per-partition tables for sources instead of partitioning by time", default=True
125 )
126 time_partition_days = Field[int](
127 doc=(
128 "Time partitioning granularity in days, this value must not be changed after database is "
129 "initialized"
130 ),
131 default=30,
132 )
133 time_partition_start = Field[str](
134 doc=(
135 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. "
136 "This is used only when time_partition_tables is True."
137 ),
138 default="2018-12-01T00:00:00",
139 )
140 time_partition_end = Field[str](
141 doc=(
142 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. "
143 "This is used only when time_partition_tables is True."
144 ),
145 default="2030-01-01T00:00:00",
146 )
147 query_per_time_part = Field[bool](
148 default=False,
149 doc=(
150 "If True then build separate query for each time partition, otherwise build one single query. "
151 "This is only used when time_partition_tables is False in schema config."
152 ),
153 )
154 query_per_spatial_part = Field[bool](
155 default=False,
156 doc="If True then build one query per spatial partition, otherwise build single query.",
157 )
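# Example (editor's illustrative sketch, not part of the module): a minimal
# configuration for a small cluster. Host names and keyspace are placeholders;
# every other field keeps the default defined above.
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["cassandra1.example.org", "cassandra2.example.org"]
#     config.port = 9042
#     config.keyspace = "apdb"
#     config.time_partition_tables = True
#     config.time_partition_days = 30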
160if CASSANDRA_IMPORTED:
162 class _AddressTranslator(AddressTranslator):
163 """Translate internal IP address to external.
165 Only used for docker-based setups; this is not a viable long-term solution.
166 """
168 def __init__(self, public_ips: List[str], private_ips: List[str]):
169 self._map = dict(zip(private_ips, public_ips))
171 def translate(self, private_ip: str) -> str:
172 return self._map.get(private_ip, private_ip)
175def _quote_column(name: str) -> str:
176 """Quote column name"""
177 if name.islower():
178 return name
179 else:
180 return f'"{name}"'
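# For example, _quote_column("ra") returns the name unchanged, while
# _quote_column("diaObjectId") returns '"diaObjectId"' so that Cassandra
# preserves the mixed-case identifier.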
183class ApdbCassandra(Apdb):
184 """Implementation of APDB database on to of Apache Cassandra.
186 The implementation is configured via the standard ``pex_config`` mechanism
187 using the `ApdbCassandraConfig` configuration class. For examples of
188 different configurations check the config/ folder.
190 Parameters
191 ----------
192 config : `ApdbCassandraConfig`
193 Configuration object.
194 """
196 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI)
197 """Start time for partition 0, this should never be changed."""
199 def __init__(self, config: ApdbCassandraConfig):
200 if not CASSANDRA_IMPORTED:
201 raise CassandraMissingError()
203 config.validate()
204 self.config = config
206 _LOG.debug("ApdbCassandra Configuration:")
207 for key, value in self.config.items():
208 _LOG.debug(" %s: %s", key, value)
210 self._pixelization = Pixelization(
211 config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges
212 )
214 addressTranslator: Optional[AddressTranslator] = None
215 if config.private_ips:
216 addressTranslator = _AddressTranslator(list(config.contact_points), list(config.private_ips))
218 self._keyspace = config.keyspace
220 self._cluster = Cluster(
221 execution_profiles=self._makeProfiles(config),
222 contact_points=self.config.contact_points,
223 port=self.config.port,
224 address_translator=addressTranslator,
225 protocol_version=self.config.protocol_version,
226 auth_provider=self._make_auth_provider(config),
227 )
228 self._session = self._cluster.connect()
229 # Disable result paging
230 self._session.default_fetch_size = None
232 self._schema = ApdbCassandraSchema(
233 session=self._session,
234 keyspace=self._keyspace,
235 schema_file=self.config.schema_file,
236 schema_name=self.config.schema_name,
237 prefix=self.config.prefix,
238 time_partition_tables=self.config.time_partition_tables,
239 use_insert_id=self.config.use_insert_id,
240 )
241 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD)
243 # Cache for prepared statements
244 self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {}
246 def __del__(self) -> None:
247 if hasattr(self, "_cluster"):
248 self._cluster.shutdown()
250 def _make_auth_provider(self, config: ApdbCassandraConfig) -> AuthProvider | None:
251 """Make Cassandra authentication provider instance."""
252 try:
253 dbauth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR)
254 except DbAuthNotFoundError:
255 # Credentials file doesn't exist, use anonymous login.
256 return None
258 empty_username = False
259 # Try every contact point in turn.
260 for hostname in config.contact_points:
261 try:
262 username, password = dbauth.getAuth(
263 "cassandra", config.username, hostname, config.port, config.keyspace
264 )
265 if not username:
266 # Password without user name, try next hostname, but give
267 # warning later if no better match is found.
268 empty_username = True
269 else:
270 return PlainTextAuthProvider(username=username, password=password)
271 except DbAuthNotFoundError:
272 pass
274 if empty_username:
275 _LOG.warning(
276 f"Credentials file ({DB_AUTH_PATH} or ${DB_AUTH_ENVVAR}) provided password but not "
277 f"user name, anonymous Cassandra logon will be attempted."
278 )
280 return None
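# Illustrative sketch of the credential lookup performed above: DbAuth reads
# the file at DB_AUTH_PATH (or the path given by $LSST_DB_AUTH) and is queried
# once per contact point. The host name and keyspace below are placeholders.
#
#     dbauth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR)
#     username, password = dbauth.getAuth(
#         "cassandra", "", "cassandra1.example.org", 9042, "apdb"
#     )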
282 def tableDef(self, table: ApdbTables) -> Optional[Table]:
283 # docstring is inherited from a base class
284 return self._schema.tableSchemas.get(table)
286 def makeSchema(self, drop: bool = False) -> None:
287 # docstring is inherited from a base class
289 if self.config.time_partition_tables:
290 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI)
291 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI)
292 part_range = (
293 self._time_partition(time_partition_start),
294 self._time_partition(time_partition_end) + 1,
295 )
296 self._schema.makeSchema(drop=drop, part_range=part_range)
297 else:
298 self._schema.makeSchema(drop=drop)
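# Typical first-time initialization (editor's sketch; assumes a validated
# ApdbCassandraConfig instance named ``config``):
#
#     apdb = ApdbCassandra(config)
#     apdb.makeSchema(drop=False)  # creates APDB tables in the configured
#                                  # keyspace, including per-partition source
#                                  # tables when time_partition_tables is True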
300 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
301 # docstring is inherited from a base class
303 sp_where = self._spatial_where(region)
304 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))
306 # We need to exclude extra partitioning columns from result.
307 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast)
308 what = ",".join(_quote_column(column) for column in column_names)
310 table_name = self._schema.tableName(ApdbTables.DiaObjectLast)
311 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"'
312 statements: List[Tuple] = []
313 for where, params in sp_where:
314 full_query = f"{query} WHERE {where}"
315 if params:
316 statement = self._prep_statement(full_query)
317 else:
318 # If there are no params then it is likely that query has a
319 # bunch of literals rendered already, no point trying to
320 # prepare it because it's not reusable.
321 statement = cassandra.query.SimpleStatement(full_query)
322 statements.append((statement, params))
323 _LOG.debug("getDiaObjects: #queries: %s", len(statements))
325 with Timer("DiaObject select", self.config.timer):
326 objects = cast(
327 pandas.DataFrame,
328 select_concurrent(
329 self._session, statements, "read_pandas_multi", self.config.read_concurrency
330 ),
331 )
333 _LOG.debug("found %s DiaObjects", objects.shape[0])
334 return objects
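# Example region query (editor's sketch; assumes an ApdbCassandra instance
# named ``apdb``, coordinates and radius are placeholders). A small circular
# region is built with lsst.sphgeom and the latest DiaObject versions inside
# it are returned as a DataFrame.
#
#     center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(35.0, -4.0))
#     region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
#     objects = apdb.getDiaObjects(region)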
336 def getDiaSources(
337 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
338 ) -> Optional[pandas.DataFrame]:
339 # docstring is inherited from a base class
340 months = self.config.read_sources_months
341 if months == 0:
342 return None
343 mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
344 mjd_start = mjd_end - months * 30
346 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)
348 def getDiaForcedSources(
349 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
350 ) -> Optional[pandas.DataFrame]:
351 # docstring is inherited from a base class
352 months = self.config.read_forced_sources_months
353 if months == 0:
354 return None
355 mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
356 mjd_start = mjd_end - months * 30
358 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)
360 def getInsertIds(self) -> list[ApdbInsertId] | None:
361 # docstring is inherited from a base class
362 if not self._schema.has_insert_id:
363 return None
365 # everything goes into a single partition
366 partition = 0
368 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
369 query = f'SELECT insert_time, insert_id FROM "{self._keyspace}"."{table_name}" WHERE partition = ?'
371 result = self._session.execute(
372 self._prep_statement(query),
373 (partition,),
374 timeout=self.config.read_timeout,
375 execution_profile="read_tuples",
376 )
377 # order by insert_time
378 rows = sorted(result)
379 return [ApdbInsertId(row[1]) for row in rows]
381 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
382 # docstring is inherited from a base class
383 if not self._schema.has_insert_id:
384 raise ValueError("APDB is not configured for history storage")
386 insert_ids = [id.id for id in ids]
387 params = ",".join("?" * len(insert_ids))
389 # everything goes into a single partition
390 partition = 0
392 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
393 query = (
394 f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE partition = ? and insert_id IN ({params})'
395 )
397 self._session.execute(
398 self._prep_statement(query),
399 [partition] + insert_ids,
400 timeout=self.config.write_timeout,
401 )
403 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
404 # docstring is inherited from a base class
405 return self._get_history(ExtraTables.DiaObjectInsertId, ids)
407 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
408 # docstring is inherited from a base class
409 return self._get_history(ExtraTables.DiaSourceInsertId, ids)
411 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
412 # docstring is inherited from a base class
413 return self._get_history(ExtraTables.DiaForcedSourceInsertId, ids)
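# Sketch of history retrieval via insert IDs (assumes an ApdbCassandra
# instance named ``apdb`` with ``use_insert_id`` enabled in its configuration):
#
#     insert_ids = apdb.getInsertIds()
#     if insert_ids:
#         history = apdb.getDiaSourcesHistory(insert_ids[-1:])  # latest insert only
#     # ``history`` is an ApdbTableData with the regular DiaSource columns
#     # plus insert_id.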
415 def getSSObjects(self) -> pandas.DataFrame:
416 # docstring is inherited from a base class
417 tableName = self._schema.tableName(ApdbTables.SSObject)
418 query = f'SELECT * from "{self._keyspace}"."{tableName}"'
420 objects = None
421 with Timer("SSObject select", self.config.timer):
422 result = self._session.execute(query, execution_profile="read_pandas")
423 objects = result._current_rows
425 _LOG.debug("found %s DiaObjects", objects.shape[0])
426 return objects
428 def store(
429 self,
430 visit_time: dafBase.DateTime,
431 objects: pandas.DataFrame,
432 sources: Optional[pandas.DataFrame] = None,
433 forced_sources: Optional[pandas.DataFrame] = None,
434 ) -> None:
435 # docstring is inherited from a base class
437 insert_id: ApdbInsertId | None = None
438 if self._schema.has_insert_id:
439 insert_id = ApdbInsertId.new_insert_id()
440 self._storeInsertId(insert_id, visit_time)
442 # fill region partition column for DiaObjects
443 objects = self._add_obj_part(objects)
444 self._storeDiaObjects(objects, visit_time, insert_id)
446 if sources is not None:
447 # copy apdb_part column from DiaObjects to DiaSources
448 sources = self._add_src_part(sources, objects)
449 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, insert_id)
450 self._storeDiaSourcesPartitions(sources, visit_time, insert_id)
452 if forced_sources is not None:
453 forced_sources = self._add_fsrc_part(forced_sources, objects)
454 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, insert_id)
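# Schematic store() call (editor's sketch; assumes an ApdbCassandra instance
# named ``apdb``). Real catalogs must contain every column required by the
# APDB schema, in particular the ra/dec columns named in ``ra_dec_columns``
# and the object/source identifiers; the DataFrame below is only a placeholder.
#
#     visit_time = dafBase.DateTime.now()
#     objects = pandas.DataFrame({"diaObjectId": [1], "ra": [35.0], "dec": [-4.0]})
#     apdb.store(visit_time, objects)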
456 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
457 # docstring is inherited from a base class
458 self._storeObjectsPandas(objects, ApdbTables.SSObject)
460 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
461 # docstring is inherited from a base class
463 # To update a record we need to know its exact primary key (including
464 # partition key) so we start by querying for diaSourceId to find the
465 # primary keys.
467 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition)
468 # split it into 1k IDs per query
469 selects: List[Tuple] = []
470 for ids in chunk_iterable(idMap.keys(), 1_000):
471 ids_str = ",".join(str(item) for item in ids)
472 selects.append(
473 (
474 (
475 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "insert_id" '
476 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})'
477 ),
478 {},
479 )
480 )
482 # No need for DataFrame here, read data as tuples.
483 result = cast(
484 List[Tuple[int, int, int, uuid.UUID | None]],
485 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency),
486 )
488 # Make mapping from source ID to its partition.
489 id2partitions: Dict[int, Tuple[int, int]] = {}
490 id2insert_id: Dict[int, ApdbInsertId] = {}
491 for row in result:
492 id2partitions[row[0]] = row[1:3]
493 if row[3] is not None:
494 id2insert_id[row[0]] = ApdbInsertId(row[3])
496 # make sure we know partitions for each ID
497 if set(id2partitions) != set(idMap):
498 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
499 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
501 # Reassign in standard tables
502 queries = cassandra.query.BatchStatement()
503 table_name = self._schema.tableName(ApdbTables.DiaSource)
504 for diaSourceId, ssObjectId in idMap.items():
505 apdb_part, apdb_time_part = id2partitions[diaSourceId]
506 values: Tuple
507 if self.config.time_partition_tables:
508 query = (
509 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"'
510 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
511 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?'
512 )
513 values = (ssObjectId, apdb_part, diaSourceId)
514 else:
515 query = (
516 f'UPDATE "{self._keyspace}"."{table_name}"'
517 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
518 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?'
519 )
520 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId)
521 queries.add(self._prep_statement(query), values)
523 # Reassign in history tables, only if history is enabled
524 if id2insert_id:
525 # Filter out insert ids that have been deleted already. There is a
526 # potential race with concurrent removal of insert IDs, but it
527 # should be handled by WHERE in UPDATE.
528 known_ids = set()
529 if insert_ids := self.getInsertIds():
530 known_ids = set(insert_ids)
531 id2insert_id = {key: value for key, value in id2insert_id.items() if value in known_ids}
532 if id2insert_id:
533 table_name = self._schema.tableName(ExtraTables.DiaSourceInsertId)
534 for diaSourceId, ssObjectId in idMap.items():
535 if insert_id := id2insert_id.get(diaSourceId):
536 query = (
537 f'UPDATE "{self._keyspace}"."{table_name}" '
538 ' SET "ssObjectId" = ?, "diaObjectId" = NULL '
539 'WHERE "insert_id" = ? AND "diaSourceId" = ?'
540 )
541 values = (ssObjectId, insert_id.id, diaSourceId)
542 queries.add(self._prep_statement(query), values)
544 _LOG.debug("%s: will update %d records", table_name, len(idMap))
545 with Timer(table_name + " update", self.config.timer):
546 self._session.execute(queries, execution_profile="write")
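# Example call (IDs are hypothetical; assumes an ApdbCassandra instance named
# ``apdb``): reassign two DiaSources to one SSObject.
#
#     apdb.reassignDiaSources({387566092652444821: 5364546, 387566092652444822: 5364546})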
548 def dailyJob(self) -> None:
549 # docstring is inherited from a base class
550 pass
552 def countUnassociatedObjects(self) -> int:
553 # docstring is inherited from a base class
555 # It is too inefficient to implement this for Cassandra with the current schema.
556 raise NotImplementedError()
558 def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]:
559 """Make all execution profiles used in the code."""
560 if config.private_ips:
561 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
562 else:
563 loadBalancePolicy = RoundRobinPolicy()
565 read_tuples_profile = ExecutionProfile(
566 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
567 request_timeout=config.read_timeout,
568 row_factory=cassandra.query.tuple_factory,
569 load_balancing_policy=loadBalancePolicy,
570 )
571 read_pandas_profile = ExecutionProfile(
572 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
573 request_timeout=config.read_timeout,
574 row_factory=pandas_dataframe_factory,
575 load_balancing_policy=loadBalancePolicy,
576 )
577 read_raw_profile = ExecutionProfile(
578 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
579 request_timeout=config.read_timeout,
580 row_factory=raw_data_factory,
581 load_balancing_policy=loadBalancePolicy,
582 )
583 # Profile to use with select_concurrent to return pandas data frame
584 read_pandas_multi_profile = ExecutionProfile(
585 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
586 request_timeout=config.read_timeout,
587 row_factory=pandas_dataframe_factory,
588 load_balancing_policy=loadBalancePolicy,
589 )
590 # Profile to use with select_concurrent to return raw data (columns and
591 # rows)
592 read_raw_multi_profile = ExecutionProfile(
593 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
594 request_timeout=config.read_timeout,
595 row_factory=raw_data_factory,
596 load_balancing_policy=loadBalancePolicy,
597 )
598 write_profile = ExecutionProfile(
599 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency),
600 request_timeout=config.write_timeout,
601 load_balancing_policy=loadBalancePolicy,
602 )
603 # To replace default DCAwareRoundRobinPolicy
604 default_profile = ExecutionProfile(
605 load_balancing_policy=loadBalancePolicy,
606 )
607 return {
608 "read_tuples": read_tuples_profile,
609 "read_pandas": read_pandas_profile,
610 "read_raw": read_raw_profile,
611 "read_pandas_multi": read_pandas_multi_profile,
612 "read_raw_multi": read_raw_multi_profile,
613 "write": write_profile,
614 EXEC_PROFILE_DEFAULT: default_profile,
615 }
617 def _getSources(
618 self,
619 region: sphgeom.Region,
620 object_ids: Optional[Iterable[int]],
621 mjd_start: float,
622 mjd_end: float,
623 table_name: ApdbTables,
624 ) -> pandas.DataFrame:
625 """Return catalog of DiaSource instances given set of DiaObject IDs.
627 Parameters
628 ----------
629 region : `lsst.sphgeom.Region`
630 Spherical region.
631 object_ids : iterable [ `int` ], or `None`
632 Collection of DiaObject IDs.
633 mjd_start : `float`
634 Lower bound of time interval.
635 mjd_end : `float`
636 Upper bound of time interval.
637 table_name : `ApdbTables`
638 Name of the table.
640 Returns
641 -------
642 catalog : `pandas.DataFrame`
643 Catalog containing DiaSource records. An empty catalog is returned
644 if ``object_ids`` is empty.
645 """
646 object_id_set: Set[int] = set()
647 if object_ids is not None:
648 object_id_set = set(object_ids)
649 if len(object_id_set) == 0:
650 return self._make_empty_catalog(table_name)
652 sp_where = self._spatial_where(region)
653 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end)
655 # We need to exclude extra partitioning columns from result.
656 column_names = self._schema.apdbColumnNames(table_name)
657 what = ",".join(_quote_column(column) for column in column_names)
659 # Build all queries
660 statements: List[Tuple] = []
661 for table in tables:
662 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"'
663 statements += list(self._combine_where(prefix, sp_where, temporal_where))
664 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))
666 with Timer(table_name.name + " select", self.config.timer):
667 catalog = cast(
668 pandas.DataFrame,
669 select_concurrent(
670 self._session, statements, "read_pandas_multi", self.config.read_concurrency
671 ),
672 )
674 # filter by given object IDs
675 if len(object_id_set) > 0:
676 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])
678 # precise filtering on midpointMjdTai
679 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start])
681 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
682 return catalog
684 def _get_history(self, table: ExtraTables, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
685 """Return records from a particular table given set of insert IDs."""
686 if not self._schema.has_insert_id:
687 raise ValueError("APDB is not configured for history retrieval")
689 insert_ids = [id.id for id in ids]
690 params = ",".join("?" * len(insert_ids))
692 table_name = self._schema.tableName(table)
693 # I know that history table schema has only regular APDB columns plus
694 # an insert_id column, and this is exactly what we need to return from
695 # this method, so selecting a star is fine here.
696 query = f'SELECT * FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})'
697 statement = self._prep_statement(query)
699 with Timer("DiaObject history", self.config.timer):
700 result = self._session.execute(statement, insert_ids, execution_profile="read_raw")
701 table_data = cast(ApdbCassandraTableData, result._current_rows)
702 return table_data
704 def _storeInsertId(self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime) -> None:
705 # Cassandra timestamp uses milliseconds since epoch
706 timestamp = visit_time.nsecs() // 1_000_000
708 # everything goes into a single partition
709 partition = 0
711 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
712 query = (
713 f'INSERT INTO "{self._keyspace}"."{table_name}" (partition, insert_id, insert_time) '
714 "VALUES (?, ?, ?)"
715 )
717 self._session.execute(
718 self._prep_statement(query),
719 (partition, insert_id.id, timestamp),
720 timeout=self.config.write_timeout,
721 execution_profile="write",
722 )
724 def _storeDiaObjects(
725 self, objs: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None
726 ) -> None:
727 """Store catalog of DiaObjects from current visit.
729 Parameters
730 ----------
731 objs : `pandas.DataFrame`
732 Catalog with DiaObject records
733 visit_time : `lsst.daf.base.DateTime`
734 Time of the current visit.
735 """
736 visit_time_dt = visit_time.toPython()
737 extra_columns = dict(lastNonForcedSource=visit_time_dt)
738 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)
740 extra_columns["validityStart"] = visit_time_dt
741 time_part: Optional[int] = self._time_partition(visit_time)
742 if not self.config.time_partition_tables:
743 extra_columns["apdb_time_part"] = time_part
744 time_part = None
746 self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part)
748 if insert_id is not None:
749 extra_columns = dict(insert_id=insert_id.id, validityStart=visit_time_dt)
750 self._storeObjectsPandas(objs, ExtraTables.DiaObjectInsertId, extra_columns=extra_columns)
752 def _storeDiaSources(
753 self,
754 table_name: ApdbTables,
755 sources: pandas.DataFrame,
756 visit_time: dafBase.DateTime,
757 insert_id: ApdbInsertId | None,
758 ) -> None:
759 """Store catalog of DIASources or DIAForcedSources from current visit.
761 Parameters
762 ----------
763 sources : `pandas.DataFrame`
764 Catalog containing DiaSource records
765 visit_time : `lsst.daf.base.DateTime`
766 Time of the current visit.
767 """
768 time_part: Optional[int] = self._time_partition(visit_time)
769 extra_columns: dict[str, Any] = {}
770 if not self.config.time_partition_tables:
771 extra_columns["apdb_time_part"] = time_part
772 time_part = None
774 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part)
776 if insert_id is not None:
777 extra_columns = dict(insert_id=insert_id.id)
778 if table_name is ApdbTables.DiaSource:
779 extra_table = ExtraTables.DiaSourceInsertId
780 else:
781 extra_table = ExtraTables.DiaForcedSourceInsertId
782 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns)
784 def _storeDiaSourcesPartitions(
785 self, sources: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None
786 ) -> None:
787 """Store mapping of diaSourceId to its partitioning values.
789 Parameters
790 ----------
791 sources : `pandas.DataFrame`
792 Catalog containing DiaSource records
793 visit_time : `lsst.daf.base.DateTime`
794 Time of the current visit.
795 """
796 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
797 extra_columns = {
798 "apdb_time_part": self._time_partition(visit_time),
799 "insert_id": insert_id.id if insert_id is not None else None,
800 }
802 self._storeObjectsPandas(
803 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
804 )
806 def _storeObjectsPandas(
807 self,
808 records: pandas.DataFrame,
809 table_name: Union[ApdbTables, ExtraTables],
810 extra_columns: Optional[Mapping] = None,
811 time_part: Optional[int] = None,
812 ) -> None:
813 """Store generic objects.
815 Takes Pandas catalog and stores a bunch of records in a table.
817 Parameters
818 ----------
819 records : `pandas.DataFrame`
820 Catalog containing object records
821 table_name : `ApdbTables` or `ExtraTables`
822 Name of the table as defined in APDB schema.
823 extra_columns : `dict`, optional
824 Mapping (column_name, column_value) which gives fixed values for
825 columns in each row, overrides values in ``records`` if matching
826 columns exist there.
827 time_part : `int`, optional
828 If not `None` then insert into a per-partition table.
830 Notes
831 -----
832 If the Pandas catalog contains additional columns not defined in the
833 table schema, they are ignored. The catalog does not have to contain
834 all columns defined in the table, but partition and clustering keys
835 must be present either in the catalog or in ``extra_columns``.
836 """
837 # use extra columns if specified
838 if extra_columns is None:
839 extra_columns = {}
840 extra_fields = list(extra_columns.keys())
842 # Fields that will come from dataframe.
843 df_fields = [column for column in records.columns if column not in extra_fields]
845 column_map = self._schema.getColumnMap(table_name)
846 # list of columns (as in felis schema)
847 fields = [column_map[field].name for field in df_fields if field in column_map]
848 fields += extra_fields
850 # check that all partitioning and clustering columns are defined
851 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns(
852 table_name
853 )
854 missing_columns = [column for column in required_columns if column not in fields]
855 if missing_columns:
856 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")
858 qfields = [quote_id(field) for field in fields]
859 qfields_str = ",".join(qfields)
861 with Timer(table_name.name + " query build", self.config.timer):
862 table = self._schema.tableName(table_name)
863 if time_part is not None:
864 table = f"{table}_{time_part}"
866 holders = ",".join(["?"] * len(qfields))
867 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})'
868 statement = self._prep_statement(query)
869 queries = cassandra.query.BatchStatement()
870 for rec in records.itertuples(index=False):
871 values = []
872 for field in df_fields:
873 if field not in column_map:
874 continue
875 value = getattr(rec, field)
876 if column_map[field].datatype is felis.types.Timestamp:
877 if isinstance(value, pandas.Timestamp):
878 value = literal(value.to_pydatetime())
879 else:
880 # Assume it is seconds since epoch; Cassandra
881 # timestamps are in milliseconds.
882 value = int(value * 1000)
883 values.append(literal(value))
884 for field in extra_fields:
885 value = extra_columns[field]
886 values.append(literal(value))
887 queries.add(statement, values)
889 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0])
890 with Timer(table_name.name + " insert", self.config.timer):
891 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write")
893 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
894 """Calculate spatial partition for each record and add it to a
895 DataFrame.
897 Notes
898 -----
899 This overrides any existing column in a DataFrame with the same name
900 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
901 returned.
902 """
903 # calculate spatial partition index for every DiaObject
904 apdb_part = np.zeros(df.shape[0], dtype=np.int64)
905 ra_col, dec_col = self.config.ra_dec_columns
906 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
907 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
908 idx = self._pixelization.pixel(uv3d)
909 apdb_part[i] = idx
910 df = df.copy()
911 df["apdb_part"] = apdb_part
912 return df
914 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
915 """Add apdb_part column to DiaSource catalog.
917 Notes
918 -----
919 This method copies apdb_part value from a matching DiaObject record.
920 The DiaObject catalog needs to have an apdb_part column filled by the
921 ``_add_obj_part`` method, and DiaSource records need to be
922 associated with DiaObjects via the ``diaObjectId`` column.
924 This overrides any existing column in a DataFrame with the same name
925 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
926 returned.
927 """
928 pixel_id_map: Dict[int, int] = {
929 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"])
930 }
931 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
932 ra_col, dec_col = self.config.ra_dec_columns
933 for i, (diaObjId, ra, dec) in enumerate(
934 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col])
935 ):
936 if diaObjId == 0:
937 # DiaSources associated with SolarSystemObjects do not have an
938 # associated DiaObject, so we set the partition based on the
939 # source's own ra/dec.
940 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
941 idx = self._pixelization.pixel(uv3d)
942 apdb_part[i] = idx
943 else:
944 apdb_part[i] = pixel_id_map[diaObjId]
945 sources = sources.copy()
946 sources["apdb_part"] = apdb_part
947 return sources
949 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
950 """Add apdb_part column to DiaForcedSource catalog.
952 Notes
953 -----
954 This method copies apdb_part value from a matching DiaObject record.
955 The DiaObject catalog needs to have an apdb_part column filled by the
956 ``_add_obj_part`` method, and DiaForcedSource records need to be
957 associated with DiaObjects via the ``diaObjectId`` column.
959 This overrides any existing column in a DataFrame with the same name
960 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
961 returned.
962 """
963 pixel_id_map: Dict[int, int] = {
964 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"])
965 }
966 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
967 for i, diaObjId in enumerate(sources["diaObjectId"]):
968 apdb_part[i] = pixel_id_map[diaObjId]
969 sources = sources.copy()
970 sources["apdb_part"] = apdb_part
971 return sources
973 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int:
974 """Calculate time partiton number for a given time.
976 Parameters
977 ----------
978 time : `float` or `lsst.daf.base.DateTime`
979 Time for which to calculate partition number. A `float` value is
980 interpreted as MJD.
982 Returns
983 -------
984 partition : `int`
985 Partition number for a given time.
986 """
987 if isinstance(time, dafBase.DateTime):
988 mjd = time.get(system=dafBase.DateTime.MJD)
989 else:
990 mjd = time
991 days_since_epoch = mjd - self._partition_zero_epoch_mjd
992 partition = int(days_since_epoch) // self.config.time_partition_days
993 return partition
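# Worked example: the partition zero epoch (1970-01-01 TAI) is roughly
# MJD 40587, so with the default time_partition_days = 30 a visit at
# MJD 60000 falls into partition (60000 - 40587) // 30 = 647.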
995 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
996 """Make an empty catalog for a table with a given name.
998 Parameters
999 ----------
1000 table_name : `ApdbTables`
1001 Name of the table.
1003 Returns
1004 -------
1005 catalog : `pandas.DataFrame`
1006 An empty catalog.
1007 """
1008 table = self._schema.tableSchemas[table_name]
1010 data = {
1011 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype))
1012 for columnDef in table.columns
1013 }
1014 return pandas.DataFrame(data)
1016 def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement:
1017 """Convert query string into prepared statement."""
1018 stmt = self._prepared_statements.get(query)
1019 if stmt is None:
1020 stmt = self._session.prepare(query)
1021 self._prepared_statements[query] = stmt
1022 return stmt
1024 def _combine_where(
1025 self,
1026 prefix: str,
1027 where1: List[Tuple[str, Tuple]],
1028 where2: List[Tuple[str, Tuple]],
1029 suffix: Optional[str] = None,
1030 ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]:
1031 """Make cartesian product of two parts of WHERE clause into a series
1032 of statements to execute.
1034 Parameters
1035 ----------
1036 prefix : `str`
1037 Initial statement prefix that comes before WHERE clause, e.g.
1038 "SELECT * from Table"
1039 """
1040 # If lists are empty use special sentinels.
1041 if not where1:
1042 where1 = [("", ())]
1043 if not where2:
1044 where2 = [("", ())]
1046 for expr1, params1 in where1:
1047 for expr2, params2 in where2:
1048 full_query = prefix
1049 wheres = []
1050 if expr1:
1051 wheres.append(expr1)
1052 if expr2:
1053 wheres.append(expr2)
1054 if wheres:
1055 full_query += " WHERE " + " AND ".join(wheres)
1056 if suffix:
1057 full_query += " " + suffix
1058 params = params1 + params2
1059 if params:
1060 statement = self._prep_statement(full_query)
1061 else:
1062 # If there are no params then it is likely that query
1063 # has a bunch of literals rendered already, no point
1064 # trying to prepare it.
1065 statement = cassandra.query.SimpleStatement(full_query)
1066 yield (statement, params)
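# Illustration: with where1 = [('"apdb_part" IN (1,2)', ())] and
# where2 = [('"apdb_time_part" = ?', (600,)), ('"apdb_time_part" = ?', (601,))]
# this yields two statements of the form
#     SELECT ... WHERE "apdb_part" IN (1,2) AND "apdb_time_part" = ?
# with parameters (600,) and (601,) respectively; both are prepared because
# their combined parameter tuples are non-empty.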
1068 def _spatial_where(
1069 self, region: Optional[sphgeom.Region], use_ranges: bool = False
1070 ) -> List[Tuple[str, Tuple]]:
1071 """Generate expressions for spatial part of WHERE clause.
1073 Parameters
1074 ----------
1075 region : `sphgeom.Region` or `None`
1076 Spatial region for query results.
1077 use_ranges : `bool`
1078 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <=
1079 p2") instead of exact list of pixels. Should be set to True for
1080 large regions covering very many pixels.
1082 Returns
1083 -------
1084 expressions : `list` [ `tuple` ]
1085 Empty list is returned if ``region`` is `None`, otherwise a list
1086 of one or more (expression, parameters) tuples
1087 """
1088 if region is None:
1089 return []
1090 if use_ranges:
1091 pixel_ranges = self._pixelization.envelope(region)
1092 expressions: List[Tuple[str, Tuple]] = []
1093 for lower, upper in pixel_ranges:
1094 upper -= 1
1095 if lower == upper:
1096 expressions.append(('"apdb_part" = ?', (lower,)))
1097 else:
1098 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper)))
1099 return expressions
1100 else:
1101 pixels = self._pixelization.pixels(region)
1102 if self.config.query_per_spatial_part:
1103 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels]
1104 else:
1105 pixels_str = ",".join([str(pix) for pix in pixels])
1106 return [(f'"apdb_part" IN ({pixels_str})', ())]
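# Illustration of the returned expressions (pixel indices are placeholders):
# with query_per_spatial_part = False a single IN expression is produced,
#     [('"apdb_part" IN (1234,1235,1260)', ())]
# while query_per_spatial_part = True yields one tuple per pixel,
#     [('"apdb_part" = ?', (1234,)), ('"apdb_part" = ?', (1235,)), ...]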
1108 def _temporal_where(
1109 self,
1110 table: ApdbTables,
1111 start_time: Union[float, dafBase.DateTime],
1112 end_time: Union[float, dafBase.DateTime],
1113 query_per_time_part: Optional[bool] = None,
1114 ) -> Tuple[List[str], List[Tuple[str, Tuple]]]:
1115 """Generate table names and expressions for temporal part of WHERE
1116 clauses.
1118 Parameters
1119 ----------
1120 table : `ApdbTables`
1121 Table to select from.
1122 start_time : `dafBase.DateTime` or `float`
1123 Starting DateTime or MJD value of the time range.
1124 end_time : `dafBase.DateTime` or `float`
1125 Ending DateTime or MJD value of the time range.
1126 query_per_time_part : `bool`, optional
1127 If None then use ``query_per_time_part`` from configuration.
1129 Returns
1130 -------
1131 tables : `list` [ `str` ]
1132 List of the table names to query.
1133 expressions : `list` [ `tuple` ]
1134 A list of zero or more (expression, parameters) tuples.
1135 """
1136 tables: List[str]
1137 temporal_where: List[Tuple[str, Tuple]] = []
1138 table_name = self._schema.tableName(table)
1139 time_part_start = self._time_partition(start_time)
1140 time_part_end = self._time_partition(end_time)
1141 time_parts = list(range(time_part_start, time_part_end + 1))
1142 if self.config.time_partition_tables:
1143 tables = [f"{table_name}_{part}" for part in time_parts]
1144 else:
1145 tables = [table_name]
1146 if query_per_time_part is None:
1147 query_per_time_part = self.config.query_per_time_part
1148 if query_per_time_part:
1149 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts]
1150 else:
1151 time_part_list = ",".join([str(part) for part in time_parts])
1152 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())]
1154 return tables, temporal_where
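# Illustration (table names assume an unprefixed schema): for a time range
# covering partitions 647 and 648, time_partition_tables = True gives
#     tables = ["DiaSource_647", "DiaSource_648"], temporal_where = []
# while time_partition_tables = False gives
#     tables = ["DiaSource"], temporal_where = [('"apdb_time_part" IN (647,648)', ())]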