Coverage for python/lsst/dax/apdb/apdbCassandra.py: 16%
486 statements
coverage.py v7.4.0, created at 2024-01-12 10:17 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"]
26import logging
27import uuid
28from collections.abc import Iterable, Iterator, Mapping, Set
29from typing import Any, cast
31import numpy as np
32import pandas
34# If cassandra-driver is not there the module can still be imported
35# but ApdbCassandra cannot be instantiated.
36try:
37 import cassandra
38 import cassandra.query
39 from cassandra.auth import AuthProvider, PlainTextAuthProvider
40 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile
41 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy
43 CASSANDRA_IMPORTED = True
44except ImportError:
45 CASSANDRA_IMPORTED = False
47import felis.types
48import lsst.daf.base as dafBase
49from felis.simple import Table
50from lsst import sphgeom
51from lsst.pex.config import ChoiceField, Field, ListField
52from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError
53from lsst.utils.iteration import chunk_iterable
55from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
56from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables
57from .apdbSchema import ApdbTables
58from .cassandra_utils import (
59 ApdbCassandraTableData,
60 literal,
61 pandas_dataframe_factory,
62 quote_id,
63 raw_data_factory,
64 select_concurrent,
65)
66from .pixelization import Pixelization
67from .timer import Timer
69_LOG = logging.getLogger(__name__)
71# Copied from daf_butler.
72DB_AUTH_ENVVAR = "LSST_DB_AUTH"
73"""Default name of the environmental variable that will be used to locate DB
74credentials configuration file. """
76DB_AUTH_PATH = "~/.lsst/db-auth.yaml"
77"""Default path at which it is expected that DB credentials are found."""
80class CassandraMissingError(Exception):
81 def __init__(self) -> None:
82 super().__init__("cassandra-driver module cannot be imported")
85class ApdbCassandraConfig(ApdbConfig):
86 """Configuration class for Cassandra-based APDB implementation."""
88 contact_points = ListField[str](
89 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"]
90 )
91 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[])
92 port = Field[int](doc="Port number to connect to.", default=9042)
93 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb")
94 username = Field[str](
95 doc=f"Cassandra user name, if empty then {DB_AUTH_PATH} has to provide it with password.",
96 default="",
97 )
98 read_consistency = Field[str](
99 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM"
100 )
101 write_consistency = Field[str](
102 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM"
103 )
104 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0)
105 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=10.0)
106 remove_timeout = Field[float](doc="Timeout in seconds for remove operations.", default=600.0)
107 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500)
108 protocol_version = Field[int](
109 doc="Cassandra protocol version to use, default is V4",
110 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0,
111 )
112 dia_object_columns = ListField[str](
113 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[]
114 )
115 prefix = Field[str](doc="Prefix to add to table names", default="")
116 part_pixelization = ChoiceField[str](
117 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"),
118 doc="Pixelization used for partitioning index.",
119 default="mq3c",
120 )
121 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10)
122 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64)
123 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
124 timer = Field[bool](doc="If True then print/log timing information", default=False)
125 time_partition_tables = Field[bool](
126 doc="Use per-partition tables for sources instead of partitioning by time", default=True
127 )
128 time_partition_days = Field[int](
129 doc=(
130 "Time partitioning granularity in days, this value must not be changed after database is "
131 "initialized"
132 ),
133 default=30,
134 )
135 time_partition_start = Field[str](
136 doc=(
137 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. "
138 "This is used only when time_partition_tables is True."
139 ),
140 default="2018-12-01T00:00:00",
141 )
142 time_partition_end = Field[str](
143 doc=(
144 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. "
145 "This is used only when time_partition_tables is True."
146 ),
147 default="2030-01-01T00:00:00",
148 )
149 query_per_time_part = Field[bool](
150 default=False,
151 doc=(
152 "If True then build separate query for each time partition, otherwise build one single query. "
153 "This is only used when time_partition_tables is False in schema config."
154 ),
155 )
156 query_per_spatial_part = Field[bool](
157 default=False,
158 doc="If True then build one query per spatial partition, otherwise build single query.",
159 )
160 use_insert_id_skips_diaobjects = Field[bool](
161 default=False,
162 doc=(
163 "If True then do not store DiaObjects when use_insert_id is True "
164 "(DiaObjectsInsertId has the same data)."
165 ),
166 )
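# A minimal usage sketch, assuming a reachable Cassandra cluster and the default
# APDB schema; the contact point and keyspace below are placeholders:
#
#     import lsst.daf.base as dafBase
#     from lsst.dax.apdb.apdbCassandra import ApdbCassandra, ApdbCassandraConfig
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["cassandra01.example.org"]
#     config.keyspace = "apdb"
#     config.time_partition_tables = True
#     apdb = ApdbCassandra(config)      # connects to the cluster
#     apdb.makeSchema(drop=False)       # create the APDB tables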
169if CASSANDRA_IMPORTED:
171 class _AddressTranslator(AddressTranslator):
172 """Translate internal IP address to external.
174 Only used for a docker-based setup; not a viable long-term solution.
175 """
177 def __init__(self, public_ips: list[str], private_ips: list[str]):
178 self._map = dict((k, v) for k, v in zip(private_ips, public_ips))
180 def translate(self, private_ip: str) -> str:
181 return self._map.get(private_ip, private_ip)
184def _quote_column(name: str) -> str:
185 """Quote column name"""
186 if name.islower():
187 return name
188 else:
189 return f'"{name}"'
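# Because Cassandra folds unquoted identifiers to lower case, mixed-case schema
# columns must be double-quoted; for example:
#
#     _quote_column("ra")          returns  ra
#     _quote_column("diaObjectId") returns  "diaObjectId"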
192class ApdbCassandra(Apdb):
193 """Implementation of APDB database on to of Apache Cassandra.
195 The implementation is configured via standard ``pex_config`` mechanism
196 using `ApdbCassandraConfig` configuration class. For an example of
197 different configurations check config/ folder.
199 Parameters
200 ----------
201 config : `ApdbCassandraConfig`
202 Configuration object.
203 """
205 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI)
206 """Start time for partition 0, this should never be changed."""
208 def __init__(self, config: ApdbCassandraConfig):
209 if not CASSANDRA_IMPORTED:
210 raise CassandraMissingError()
212 config.validate()
213 self.config = config
215 _LOG.debug("ApdbCassandra Configuration:")
216 for key, value in self.config.items():
217 _LOG.debug(" %s: %s", key, value)
219 self._pixelization = Pixelization(
220 config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges
221 )
223 addressTranslator: AddressTranslator | None = None
224 if config.private_ips:
225 addressTranslator = _AddressTranslator(list(config.contact_points), list(config.private_ips))
227 self._keyspace = config.keyspace
229 self._cluster = Cluster(
230 execution_profiles=self._makeProfiles(config),
231 contact_points=self.config.contact_points,
232 port=self.config.port,
233 address_translator=addressTranslator,
234 protocol_version=self.config.protocol_version,
235 auth_provider=self._make_auth_provider(config),
236 )
237 self._session = self._cluster.connect()
238 # Disable result paging
239 self._session.default_fetch_size = None
241 self._schema = ApdbCassandraSchema(
242 session=self._session,
243 keyspace=self._keyspace,
244 schema_file=self.config.schema_file,
245 schema_name=self.config.schema_name,
246 prefix=self.config.prefix,
247 time_partition_tables=self.config.time_partition_tables,
248 use_insert_id=self.config.use_insert_id,
249 )
250 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD)
252 # Cache for prepared statements
253 self._prepared_statements: dict[str, cassandra.query.PreparedStatement] = {}
255 def __del__(self) -> None:
256 if hasattr(self, "_cluster"):
257 self._cluster.shutdown()
259 def _make_auth_provider(self, config: ApdbCassandraConfig) -> AuthProvider | None:
260 """Make Cassandra authentication provider instance."""
261 try:
262 dbauth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR)
263 except DbAuthNotFoundError:
264 # Credentials file doesn't exist, use anonymous login.
265 return None
267 empty_username = True
268 # Try every contact point in turn.
269 for hostname in config.contact_points:
270 try:
271 username, password = dbauth.getAuth(
272 "cassandra", config.username, hostname, config.port, config.keyspace
273 )
274 if not username:
275 # Password without user name, try next hostname, but give
276 # warning later if no better match is found.
277 empty_username = True
278 else:
279 return PlainTextAuthProvider(username=username, password=password)
280 except DbAuthNotFoundError:
281 pass
283 if empty_username:
284 _LOG.warning(
285 f"Credentials file ({DB_AUTH_PATH} or ${DB_AUTH_ENVVAR}) provided password but not "
286 f"user name, anonymous Cassandra logon will be attempted."
287 )
289 return None
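# The credentials file consumed above follows the lsst.utils.db_auth
# conventions: a YAML list of match rules with url/username/password keys.
# A sketch of a plausible entry (values are placeholders, and the exact
# matching rules are defined by DbAuth, not by this module):
#
#     # contents of ~/.lsst/db-auth.yaml, mode 600
#     - url: cassandra://cassandra01.example.org:9042/apdb
#       username: apdb_user
#       password: not-a-real-password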
291 def tableDef(self, table: ApdbTables) -> Table | None:
292 # docstring is inherited from a base class
293 return self._schema.tableSchemas.get(table)
295 def makeSchema(self, drop: bool = False) -> None:
296 # docstring is inherited from a base class
298 if self.config.time_partition_tables:
299 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI)
300 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI)
301 part_range = (
302 self._time_partition(time_partition_start),
303 self._time_partition(time_partition_end) + 1,
304 )
305 self._schema.makeSchema(drop=drop, part_range=part_range)
306 else:
307 self._schema.makeSchema(drop=drop)
309 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
310 # docstring is inherited from a base class
312 sp_where = self._spatial_where(region)
313 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))
315 # We need to exclude extra partitioning columns from result.
316 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast)
317 what = ",".join(_quote_column(column) for column in column_names)
319 table_name = self._schema.tableName(ApdbTables.DiaObjectLast)
320 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"'
321 statements: list[tuple] = []
322 for where, params in sp_where:
323 full_query = f"{query} WHERE {where}"
324 if params:
325 statement = self._prep_statement(full_query)
326 else:
327 # If there are no params then it is likely that query has a
328 # bunch of literals rendered already, no point trying to
329 # prepare it because it's not reusable.
330 statement = cassandra.query.SimpleStatement(full_query)
331 statements.append((statement, params))
332 _LOG.debug("getDiaObjects: #queries: %s", len(statements))
334 with Timer("DiaObject select", self.config.timer):
335 objects = cast(
336 pandas.DataFrame,
337 select_concurrent(
338 self._session, statements, "read_pandas_multi", self.config.read_concurrency
339 ),
340 )
342 _LOG.debug("found %s DiaObjects", objects.shape[0])
343 return objects
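# A usage sketch, assuming an ApdbCassandra instance ``apdb`` and using
# lsst.sphgeom to build a small circular region (coordinates are placeholders):
#
#     from lsst import sphgeom
#
#     center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(150.0, 2.5))
#     region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
#     objects = apdb.getDiaObjects(region)   # pandas.DataFrame of DiaObjectLast rows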
345 def getDiaSources(
346 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
347 ) -> pandas.DataFrame | None:
348 # docstring is inherited from a base class
349 months = self.config.read_sources_months
350 if months == 0:
351 return None
352 mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
353 mjd_start = mjd_end - months * 30
355 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)
357 def getDiaForcedSources(
358 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
359 ) -> pandas.DataFrame | None:
360 # docstring is inherited from a base class
361 months = self.config.read_forced_sources_months
362 if months == 0:
363 return None
364 mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
365 mjd_start = mjd_end - months * 30
367 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)
369 def containsVisitDetector(self, visit: int, detector: int) -> bool:
370 # docstring is inherited from a base class
371 raise NotImplementedError()
373 def getInsertIds(self) -> list[ApdbInsertId] | None:
374 # docstring is inherited from a base class
375 if not self._schema.has_insert_id:
376 return None
378 # everything goes into a single partition
379 partition = 0
381 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
382 query = f'SELECT insert_time, insert_id FROM "{self._keyspace}"."{table_name}" WHERE partition = ?'
384 result = self._session.execute(
385 self._prep_statement(query),
386 (partition,),
387 timeout=self.config.read_timeout,
388 execution_profile="read_tuples",
389 )
390 # order by insert_time
391 rows = sorted(result)
392 return [
393 ApdbInsertId(id=row[1], insert_time=dafBase.DateTime(int(row[0].timestamp() * 1e9)))
394 for row in rows
395 ]
397 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
398 # docstring is inherited from a base class
399 if not self._schema.has_insert_id:
400 raise ValueError("APDB is not configured for history storage")
402 all_insert_ids = [id.id for id in ids]
403 # There is 64k limit on number of markers in Cassandra CQL
404 for insert_ids in chunk_iterable(all_insert_ids, 20_000):
405 params = ",".join("?" * len(insert_ids))
407 # everything goes into a single partition
408 partition = 0
410 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
411 query = (
412 f'DELETE FROM "{self._keyspace}"."{table_name}" '
413 f"WHERE partition = ? AND insert_id IN ({params})"
414 )
416 self._session.execute(
417 self._prep_statement(query),
418 [partition] + list(insert_ids),
419 timeout=self.config.remove_timeout,
420 )
422 # Also remove those insert_ids from Dia*InsertId tables.
423 for table in (
424 ExtraTables.DiaObjectInsertId,
425 ExtraTables.DiaSourceInsertId,
426 ExtraTables.DiaForcedSourceInsertId,
427 ):
428 table_name = self._schema.tableName(table)
429 query = f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})'
430 self._session.execute(
431 self._prep_statement(query),
432 insert_ids,
433 timeout=self.config.remove_timeout,
434 )
436 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
437 # docstring is inherited from a base class
438 return self._get_history(ExtraTables.DiaObjectInsertId, ids)
440 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
441 # docstring is inherited from a base class
442 return self._get_history(ExtraTables.DiaSourceInsertId, ids)
444 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
445 # docstring is inherited from a base class
446 return self._get_history(ExtraTables.DiaForcedSourceInsertId, ids)
448 def getSSObjects(self) -> pandas.DataFrame:
449 # docstring is inherited from a base class
450 tableName = self._schema.tableName(ApdbTables.SSObject)
451 query = f'SELECT * from "{self._keyspace}"."{tableName}"'
453 objects = None
454 with Timer("SSObject select", self.config.timer):
455 result = self._session.execute(query, execution_profile="read_pandas")
456 objects = result._current_rows
458 _LOG.debug("found %s SSObjects", objects.shape[0])
459 return objects
461 def store(
462 self,
463 visit_time: dafBase.DateTime,
464 objects: pandas.DataFrame,
465 sources: pandas.DataFrame | None = None,
466 forced_sources: pandas.DataFrame | None = None,
467 ) -> None:
468 # docstring is inherited from a base class
470 insert_id: ApdbInsertId | None = None
471 if self._schema.has_insert_id:
472 insert_id = ApdbInsertId.new_insert_id(visit_time)
473 self._storeInsertId(insert_id, visit_time)
475 # fill region partition column for DiaObjects
476 objects = self._add_obj_part(objects)
477 self._storeDiaObjects(objects, visit_time, insert_id)
479 if sources is not None:
480 # copy apdb_part column from DiaObjects to DiaSources
481 sources = self._add_src_part(sources, objects)
482 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, insert_id)
483 self._storeDiaSourcesPartitions(sources, visit_time, insert_id)
485 if forced_sources is not None:
486 forced_sources = self._add_fsrc_part(forced_sources, objects)
487 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, insert_id)
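# A usage sketch for a single visit, assuming an ApdbCassandra instance ``apdb``
# and catalogs that already carry the columns required by the APDB schema
# (``dia_objects``, ``dia_sources`` and ``dia_forced_sources`` are placeholders):
#
#     import lsst.daf.base as dafBase
#
#     visit_time = dafBase.DateTime("2023-05-01T06:00:00", dafBase.DateTime.TAI)
#     apdb.store(visit_time, dia_objects, dia_sources, dia_forced_sources)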
489 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
490 # docstring is inherited from a base class
491 self._storeObjectsPandas(objects, ApdbTables.SSObject)
493 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
494 # docstring is inherited from a base class
496 # To update a record we need to know its exact primary key (including
497 # partition key) so we start by querying for diaSourceId to find the
498 # primary keys.
500 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition)
501 # split it into 1k IDs per query
502 selects: list[tuple] = []
503 for ids in chunk_iterable(idMap.keys(), 1_000):
504 ids_str = ",".join(str(item) for item in ids)
505 selects.append(
506 (
507 (
508 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "insert_id" '
509 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})'
510 ),
511 {},
512 )
513 )
515 # No need for DataFrame here, read data as tuples.
516 result = cast(
517 list[tuple[int, int, int, uuid.UUID | None]],
518 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency),
519 )
521 # Make mapping from source ID to its partition.
522 id2partitions: dict[int, tuple[int, int]] = {}
523 id2insert_id: dict[int, uuid.UUID] = {}
524 for row in result:
525 id2partitions[row[0]] = row[1:3]
526 if row[3] is not None:
527 id2insert_id[row[0]] = row[3]
529 # make sure we know partitions for each ID
530 if set(id2partitions) != set(idMap):
531 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
532 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
534 # Reassign in standard tables
535 queries = cassandra.query.BatchStatement()
536 table_name = self._schema.tableName(ApdbTables.DiaSource)
537 for diaSourceId, ssObjectId in idMap.items():
538 apdb_part, apdb_time_part = id2partitions[diaSourceId]
539 values: tuple
540 if self.config.time_partition_tables:
541 query = (
542 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"'
543 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
544 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?'
545 )
546 values = (ssObjectId, apdb_part, diaSourceId)
547 else:
548 query = (
549 f'UPDATE "{self._keyspace}"."{table_name}"'
550 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
551 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?'
552 )
553 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId)
554 queries.add(self._prep_statement(query), values)
556 # Reassign in history tables, only if history is enabled
557 if id2insert_id:
558 # Filter out insert ids that have been deleted already. There is a
559 # potential race with concurrent removal of insert IDs, but it
560 # should be handled by WHERE in UPDATE.
561 known_ids = set()
562 if insert_ids := self.getInsertIds():
563 known_ids = set(insert_id.id for insert_id in insert_ids)
564 id2insert_id = {key: value for key, value in id2insert_id.items() if value in known_ids}
565 if id2insert_id:
566 table_name = self._schema.tableName(ExtraTables.DiaSourceInsertId)
567 for diaSourceId, ssObjectId in idMap.items():
568 if insert_id := id2insert_id.get(diaSourceId):
569 query = (
570 f'UPDATE "{self._keyspace}"."{table_name}" '
571 ' SET "ssObjectId" = ?, "diaObjectId" = NULL '
572 'WHERE "insert_id" = ? AND "diaSourceId" = ?'
573 )
574 values = (ssObjectId, insert_id, diaSourceId)
575 queries.add(self._prep_statement(query), values)
577 _LOG.debug("%s: will update %d records", table_name, len(idMap))
578 with Timer(table_name + " update", self.config.timer):
579 self._session.execute(queries, execution_profile="write")
581 def dailyJob(self) -> None:
582 # docstring is inherited from a base class
583 pass
585 def countUnassociatedObjects(self) -> int:
586 # docstring is inherited from a base class
588 # It's too inefficient to implement it for Cassandra in current schema.
589 raise NotImplementedError()
591 def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]:
592 """Make all execution profiles used in the code."""
593 if config.private_ips:
594 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
595 else:
596 loadBalancePolicy = RoundRobinPolicy()
598 read_tuples_profile = ExecutionProfile(
599 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
600 request_timeout=config.read_timeout,
601 row_factory=cassandra.query.tuple_factory,
602 load_balancing_policy=loadBalancePolicy,
603 )
604 read_pandas_profile = ExecutionProfile(
605 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
606 request_timeout=config.read_timeout,
607 row_factory=pandas_dataframe_factory,
608 load_balancing_policy=loadBalancePolicy,
609 )
610 read_raw_profile = ExecutionProfile(
611 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
612 request_timeout=config.read_timeout,
613 row_factory=raw_data_factory,
614 load_balancing_policy=loadBalancePolicy,
615 )
616 # Profile to use with select_concurrent to return pandas data frame
617 read_pandas_multi_profile = ExecutionProfile(
618 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
619 request_timeout=config.read_timeout,
620 row_factory=pandas_dataframe_factory,
621 load_balancing_policy=loadBalancePolicy,
622 )
623 # Profile to use with select_concurrent to return raw data (columns and
624 # rows)
625 read_raw_multi_profile = ExecutionProfile(
626 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
627 request_timeout=config.read_timeout,
628 row_factory=raw_data_factory,
629 load_balancing_policy=loadBalancePolicy,
630 )
631 write_profile = ExecutionProfile(
632 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency),
633 request_timeout=config.write_timeout,
634 load_balancing_policy=loadBalancePolicy,
635 )
636 # To replace default DCAwareRoundRobinPolicy
637 default_profile = ExecutionProfile(
638 load_balancing_policy=loadBalancePolicy,
639 )
640 return {
641 "read_tuples": read_tuples_profile,
642 "read_pandas": read_pandas_profile,
643 "read_raw": read_raw_profile,
644 "read_pandas_multi": read_pandas_multi_profile,
645 "read_raw_multi": read_raw_multi_profile,
646 "write": write_profile,
647 EXEC_PROFILE_DEFAULT: default_profile,
648 }
650 def _getSources(
651 self,
652 region: sphgeom.Region,
653 object_ids: Iterable[int] | None,
654 mjd_start: float,
655 mjd_end: float,
656 table_name: ApdbTables,
657 ) -> pandas.DataFrame:
658 """Return catalog of DiaSource instances given set of DiaObject IDs.
660 Parameters
661 ----------
662 region : `lsst.sphgeom.Region`
663 Spherical region.
664 object_ids : `iterable` [ `int` ], or `None`
665 Collection of DiaObject IDs. If `None` then no filtering by ID is done.
666 mjd_start : `float`
667 Lower bound of time interval.
668 mjd_end : `float`
669 Upper bound of time interval.
670 table_name : `ApdbTables`
671 Name of the table.
673 Returns
674 -------
675 catalog : `pandas.DataFrame`
676 Catalog containing DiaSource records. Empty catalog is returned if
677 ``object_ids`` is empty.
678 """
679 object_id_set: Set[int] = set()
680 if object_ids is not None:
681 object_id_set = set(object_ids)
682 if len(object_id_set) == 0:
683 return self._make_empty_catalog(table_name)
685 sp_where = self._spatial_where(region)
686 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end)
688 # We need to exclude extra partitioning columns from result.
689 column_names = self._schema.apdbColumnNames(table_name)
690 what = ",".join(_quote_column(column) for column in column_names)
692 # Build all queries
693 statements: list[tuple] = []
694 for table in tables:
695 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"'
696 statements += list(self._combine_where(prefix, sp_where, temporal_where))
697 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))
699 with Timer(table_name.name + " select", self.config.timer):
700 catalog = cast(
701 pandas.DataFrame,
702 select_concurrent(
703 self._session, statements, "read_pandas_multi", self.config.read_concurrency
704 ),
705 )
707 # filter by given object IDs
708 if len(object_id_set) > 0:
709 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])
711 # precise filtering on midpointMjdTai
712 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start])
714 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
715 return catalog
717 def _get_history(self, table: ExtraTables, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
718 """Return records from a particular table given set of insert IDs."""
719 if not self._schema.has_insert_id:
720 raise ValueError("APDB is not configured for history retrieval")
722 insert_ids = [id.id for id in ids]
723 params = ",".join("?" * len(insert_ids))
725 table_name = self._schema.tableName(table)
726 # The history table schema has only regular APDB columns plus an
727 # insert_id column, which is exactly what we need to return from
728 # this method, so a SELECT * is fine here.
729 query = f'SELECT * FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})'
730 statement = self._prep_statement(query)
732 with Timer("DiaObject history", self.config.timer):
733 result = self._session.execute(statement, insert_ids, execution_profile="read_raw")
734 table_data = cast(ApdbCassandraTableData, result._current_rows)
735 return table_data
737 def _storeInsertId(self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime) -> None:
738 # Cassandra timestamp uses milliseconds since epoch
739 timestamp = insert_id.insert_time.nsecs() // 1_000_000
741 # everything goes into a single partition
742 partition = 0
744 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
745 query = (
746 f'INSERT INTO "{self._keyspace}"."{table_name}" (partition, insert_id, insert_time) '
747 "VALUES (?, ?, ?)"
748 )
750 self._session.execute(
751 self._prep_statement(query),
752 (partition, insert_id.id, timestamp),
753 timeout=self.config.write_timeout,
754 execution_profile="write",
755 )
757 def _storeDiaObjects(
758 self, objs: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None
759 ) -> None:
760 """Store catalog of DiaObjects from current visit.
762 Parameters
763 ----------
764 objs : `pandas.DataFrame`
765 Catalog with DiaObject records
766 visit_time : `lsst.daf.base.DateTime`
767 Time of the current visit.
768 """
769 visit_time_dt = visit_time.toPython()
770 extra_columns = dict(lastNonForcedSource=visit_time_dt)
771 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)
773 extra_columns["validityStart"] = visit_time_dt
774 time_part: int | None = self._time_partition(visit_time)
775 if not self.config.time_partition_tables:
776 extra_columns["apdb_time_part"] = time_part
777 time_part = None
779 # Only store DiaObjects if not storing insert_ids or explicitly
780 # configured to always store them
781 if insert_id is None or not self.config.use_insert_id_skips_diaobjects:
782 self._storeObjectsPandas(
783 objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part
784 )
786 if insert_id is not None:
787 extra_columns = dict(insert_id=insert_id.id, validityStart=visit_time_dt)
788 self._storeObjectsPandas(objs, ExtraTables.DiaObjectInsertId, extra_columns=extra_columns)
790 def _storeDiaSources(
791 self,
792 table_name: ApdbTables,
793 sources: pandas.DataFrame,
794 visit_time: dafBase.DateTime,
795 insert_id: ApdbInsertId | None,
796 ) -> None:
797 """Store catalog of DIASources or DIAForcedSources from current visit.
799 Parameters
800 ----------
801 sources : `pandas.DataFrame`
802 Catalog containing DiaSource records
803 visit_time : `lsst.daf.base.DateTime`
804 Time of the current visit.
805 """
806 time_part: int | None = self._time_partition(visit_time)
807 extra_columns: dict[str, Any] = {}
808 if not self.config.time_partition_tables:
809 extra_columns["apdb_time_part"] = time_part
810 time_part = None
812 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part)
814 if insert_id is not None:
815 extra_columns = dict(insert_id=insert_id.id)
816 if table_name is ApdbTables.DiaSource:
817 extra_table = ExtraTables.DiaSourceInsertId
818 else:
819 extra_table = ExtraTables.DiaForcedSourceInsertId
820 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns)
822 def _storeDiaSourcesPartitions(
823 self, sources: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None
824 ) -> None:
825 """Store mapping of diaSourceId to its partitioning values.
827 Parameters
828 ----------
829 sources : `pandas.DataFrame`
830 Catalog containing DiaSource records
831 visit_time : `lsst.daf.base.DateTime`
832 Time of the current visit.
833 """
834 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
835 extra_columns = {
836 "apdb_time_part": self._time_partition(visit_time),
837 "insert_id": insert_id.id if insert_id is not None else None,
838 }
840 self._storeObjectsPandas(
841 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
842 )
844 def _storeObjectsPandas(
845 self,
846 records: pandas.DataFrame,
847 table_name: ApdbTables | ExtraTables,
848 extra_columns: Mapping | None = None,
849 time_part: int | None = None,
850 ) -> None:
851 """Store generic objects.
853 Takes Pandas catalog and stores a bunch of records in a table.
855 Parameters
856 ----------
857 records : `pandas.DataFrame`
858 Catalog containing object records
859 table_name : `ApdbTables`
860 Name of the table as defined in APDB schema.
861 extra_columns : `dict`, optional
862 Mapping (column_name, column_value) which gives fixed values for
863 columns in each row, overrides values in ``records`` if matching
864 columns exist there.
865 time_part : `int`, optional
866 If not `None` then insert into a per-partition table.
868 Notes
869 -----
870 If the Pandas catalog contains additional columns not defined in the
871 table schema they are ignored. The catalog does not have to contain all
872 columns defined in a table, but partition and clustering keys must be
873 present either in the catalog or in ``extra_columns``.
874 """
875 # use extra columns if specified
876 if extra_columns is None:
877 extra_columns = {}
878 extra_fields = list(extra_columns.keys())
880 # Fields that will come from dataframe.
881 df_fields = [column for column in records.columns if column not in extra_fields]
883 column_map = self._schema.getColumnMap(table_name)
884 # list of columns (as in felis schema)
885 fields = [column_map[field].name for field in df_fields if field in column_map]
886 fields += extra_fields
888 # check that all partitioning and clustering columns are defined
889 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns(
890 table_name
891 )
892 missing_columns = [column for column in required_columns if column not in fields]
893 if missing_columns:
894 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")
896 qfields = [quote_id(field) for field in fields]
897 qfields_str = ",".join(qfields)
899 with Timer(table_name.name + " query build", self.config.timer):
900 table = self._schema.tableName(table_name)
901 if time_part is not None:
902 table = f"{table}_{time_part}"
904 holders = ",".join(["?"] * len(qfields))
905 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})'
906 statement = self._prep_statement(query)
907 queries = cassandra.query.BatchStatement()
908 for rec in records.itertuples(index=False):
909 values = []
910 for field in df_fields:
911 if field not in column_map:
912 continue
913 value = getattr(rec, field)
914 if column_map[field].datatype is felis.types.Timestamp:
915 if isinstance(value, pandas.Timestamp):
916 value = literal(value.to_pydatetime())
917 else:
918 # Assume it's seconds since epoch, Cassandra
919 # datetime is in milliseconds
920 value = int(value * 1000)
921 values.append(literal(value))
922 for field in extra_fields:
923 value = extra_columns[field]
924 values.append(literal(value))
925 queries.add(statement, values)
927 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0])
928 with Timer(table_name.name + " insert", self.config.timer):
929 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write")
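# Conceptually, a catalog with schema columns A and B plus
# extra_columns={"C": value} is written as a batch of prepared statements of
# the form (keyspace and table names come from the configuration):
#
#     INSERT INTO "<keyspace>"."<table>" ("A","B","C") VALUES (?,?,?)
#
# with one set of bound values per DataFrame row.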
931 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
932 """Calculate spatial partition for each record and add it to a
933 DataFrame.
935 Notes
936 -----
937 This overrides any existing column in the DataFrame with the same name
938 (apdb_part). The original DataFrame is not changed; a modified copy is
939 returned.
940 """
941 # Calculate pixel index for every DiaObject using the configured pixelization.
942 apdb_part = np.zeros(df.shape[0], dtype=np.int64)
943 ra_col, dec_col = self.config.ra_dec_columns
944 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
945 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
946 idx = self._pixelization.pixel(uv3d)
947 apdb_part[i] = idx
948 df = df.copy()
949 df["apdb_part"] = apdb_part
950 return df
952 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
953 """Add apdb_part column to DiaSource catalog.
955 Notes
956 -----
957 This method copies the apdb_part value from the matching DiaObject record.
958 The DiaObject catalog needs to have an apdb_part column filled by
959 the ``_add_obj_part`` method, and DiaSource records need to be
960 associated with DiaObjects via the ``diaObjectId`` column.
962 This overrides any existing column in the DataFrame with the same name
963 (apdb_part). The original DataFrame is not changed; a modified copy is
964 returned.
965 """
966 pixel_id_map: dict[int, int] = {
967 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"])
968 }
969 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
970 ra_col, dec_col = self.config.ra_dec_columns
971 for i, (diaObjId, ra, dec) in enumerate(
972 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col])
973 ):
974 if diaObjId == 0:
975 # DiaSources associated with SolarSystemObjects do not have an
976 # associated DiaObject, so compute their partition from the
977 # source's own ra/dec instead.
978 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
979 idx = self._pixelization.pixel(uv3d)
980 apdb_part[i] = idx
981 else:
982 apdb_part[i] = pixel_id_map[diaObjId]
983 sources = sources.copy()
984 sources["apdb_part"] = apdb_part
985 return sources
987 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
988 """Add apdb_part column to DiaForcedSource catalog.
990 Notes
991 -----
992 This method copies the apdb_part value from the matching DiaObject record.
993 The DiaObject catalog needs to have an apdb_part column filled by
994 the ``_add_obj_part`` method, and DiaForcedSource records need to be
995 associated with DiaObjects via the ``diaObjectId`` column.
997 This overrides any existing column in the DataFrame with the same name
998 (apdb_part). The original DataFrame is not changed; a modified copy is
999 returned.
1000 """
1001 pixel_id_map: dict[int, int] = {
1002 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"])
1003 }
1004 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
1005 for i, diaObjId in enumerate(sources["diaObjectId"]):
1006 apdb_part[i] = pixel_id_map[diaObjId]
1007 sources = sources.copy()
1008 sources["apdb_part"] = apdb_part
1009 return sources
1011 def _time_partition(self, time: float | dafBase.DateTime) -> int:
1012 """Calculate time partiton number for a given time.
1014 Parameters
1015 ----------
1016 time : `float` or `lsst.daf.base.DateTime`
1017 Time for which to calculate partition number. Can be a float,
1018 interpreted as MJD, or an `lsst.daf.base.DateTime`.
1020 Returns
1021 -------
1022 partition : `int`
1023 Partition number for a given time.
1024 """
1025 if isinstance(time, dafBase.DateTime):
1026 mjd = time.get(system=dafBase.DateTime.MJD)
1027 else:
1028 mjd = time
1029 days_since_epoch = mjd - self._partition_zero_epoch_mjd
1030 partition = int(days_since_epoch) // self.config.time_partition_days
1031 return partition
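# A worked example, assuming the default time_partition_days = 30: the
# partition-zero epoch 1970-01-01 (TAI) corresponds to MJD 40587, so a visit
# at MJD 59580.0 (about 2022-01-01) lands in partition
#
#     int(59580.0 - 40587.0) // 30  ==  18993 // 30  ==  633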
1033 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
1034 """Make an empty catalog for a table with a given name.
1036 Parameters
1037 ----------
1038 table_name : `ApdbTables`
1039 Name of the table.
1041 Returns
1042 -------
1043 catalog : `pandas.DataFrame`
1044 An empty catalog.
1045 """
1046 table = self._schema.tableSchemas[table_name]
1048 data = {
1049 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype))
1050 for columnDef in table.columns
1051 }
1052 return pandas.DataFrame(data)
1054 def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement:
1055 """Convert query string into prepared statement."""
1056 stmt = self._prepared_statements.get(query)
1057 if stmt is None:
1058 stmt = self._session.prepare(query)
1059 self._prepared_statements[query] = stmt
1060 return stmt
1062 def _combine_where(
1063 self,
1064 prefix: str,
1065 where1: list[tuple[str, tuple]],
1066 where2: list[tuple[str, tuple]],
1067 suffix: str | None = None,
1068 ) -> Iterator[tuple[cassandra.query.Statement, tuple]]:
1069 """Make cartesian product of two parts of WHERE clause into a series
1070 of statements to execute.
1072 Parameters
1073 ----------
1074 prefix : `str`
1075 Initial statement prefix that comes before WHERE clause, e.g.
1076 "SELECT * from Table"
1077 """
1078 # If lists are empty use special sentinels.
1079 if not where1:
1080 where1 = [("", ())]
1081 if not where2:
1082 where2 = [("", ())]
1084 for expr1, params1 in where1:
1085 for expr2, params2 in where2:
1086 full_query = prefix
1087 wheres = []
1088 if expr1:
1089 wheres.append(expr1)
1090 if expr2:
1091 wheres.append(expr2)
1092 if wheres:
1093 full_query += " WHERE " + " AND ".join(wheres)
1094 if suffix:
1095 full_query += " " + suffix
1096 params = params1 + params2
1097 if params:
1098 statement = self._prep_statement(full_query)
1099 else:
1100 # If there are no params then it is likely that query
1101 # has a bunch of literals rendered already, no point
1102 # trying to prepare it.
1103 statement = cassandra.query.SimpleStatement(full_query)
1104 yield (statement, params)
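# A worked example: with two spatial and two temporal terms,
#
#     where1 = [('"apdb_part" = ?', (1234,)), ('"apdb_part" = ?', (1235,))]
#     where2 = [('"apdb_time_part" = ?', (633,)), ('"apdb_time_part" = ?', (634,))]
#
# the generator yields 2 x 2 = 4 prepared statements of the form
#
#     <prefix> WHERE "apdb_part" = ? AND "apdb_time_part" = ?
#
# each paired with its combined parameter tuple, e.g. (1234, 633).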
1106 def _spatial_where(
1107 self, region: sphgeom.Region | None, use_ranges: bool = False
1108 ) -> list[tuple[str, tuple]]:
1109 """Generate expressions for spatial part of WHERE clause.
1111 Parameters
1112 ----------
1113 region : `sphgeom.Region`
1114 Spatial region for query results.
1115 use_ranges : `bool`
1116 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <=
1117 p2") instead of exact list of pixels. Should be set to True for
1118 large regions covering very many pixels.
1120 Returns
1121 -------
1122 expressions : `list` [ `tuple` ]
1123 Empty list is returned if ``region`` is `None`, otherwise a list
1124 of one or more (expression, parameters) tuples
1125 """
1126 if region is None:
1127 return []
1128 if use_ranges:
1129 pixel_ranges = self._pixelization.envelope(region)
1130 expressions: list[tuple[str, tuple]] = []
1131 for lower, upper in pixel_ranges:
1132 upper -= 1
1133 if lower == upper:
1134 expressions.append(('"apdb_part" = ?', (lower,)))
1135 else:
1136 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper)))
1137 return expressions
1138 else:
1139 pixels = self._pixelization.pixels(region)
1140 if self.config.query_per_spatial_part:
1141 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels]
1142 else:
1143 pixels_str = ",".join([str(pix) for pix in pixels])
1144 return [(f'"apdb_part" IN ({pixels_str})', ())]
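# A worked example, assuming a region that maps to pixels 1234 and 1235: with
# the default query_per_spatial_part = False the pixels are rendered as
# literals in a single IN expression,
#
#     [('"apdb_part" IN (1234,1235)', ())]
#
# while query_per_spatial_part = True yields one parametrized expression per
# pixel,
#
#     [('"apdb_part" = ?', (1234,)), ('"apdb_part" = ?', (1235,))]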
1146 def _temporal_where(
1147 self,
1148 table: ApdbTables,
1149 start_time: float | dafBase.DateTime,
1150 end_time: float | dafBase.DateTime,
1151 query_per_time_part: bool | None = None,
1152 ) -> tuple[list[str], list[tuple[str, tuple]]]:
1153 """Generate table names and expressions for temporal part of WHERE
1154 clauses.
1156 Parameters
1157 ----------
1158 table : `ApdbTables`
1159 Table to select from.
1160 start_time : `dafBase.DateTime` or `float`
1161 Starting DateTime or MJD value of the time range.
1162 end_time : `dafBase.DateTime` or `float`
1163 Ending DateTime or MJD value of the time range.
1164 query_per_time_part : `bool`, optional
1165 If None then use ``query_per_time_part`` from configuration.
1167 Returns
1168 -------
1169 tables : `list` [ `str` ]
1170 List of the table names to query.
1171 expressions : `list` [ `tuple` ]
1172 A list of zero or more (expression, parameters) tuples.
1173 """
1174 tables: list[str]
1175 temporal_where: list[tuple[str, tuple]] = []
1176 table_name = self._schema.tableName(table)
1177 time_part_start = self._time_partition(start_time)
1178 time_part_end = self._time_partition(end_time)
1179 time_parts = list(range(time_part_start, time_part_end + 1))
1180 if self.config.time_partition_tables:
1181 tables = [f"{table_name}_{part}" for part in time_parts]
1182 else:
1183 tables = [table_name]
1184 if query_per_time_part is None:
1185 query_per_time_part = self.config.query_per_time_part
1186 if query_per_time_part:
1187 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts]
1188 else:
1189 time_part_list = ",".join([str(part) for part in time_parts])
1190 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())]
1192 return tables, temporal_where
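# A worked example, assuming a query that spans time partitions 633 and 634
# (table names shown without any configured prefix): with
# time_partition_tables = True the method returns per-partition tables and no
# temporal WHERE terms,
#
#     (["DiaSource_633", "DiaSource_634"], [])
#
# while with time_partition_tables = False and query_per_time_part = False it
# returns a single table and one IN expression,
#
#     (["DiaSource"], [('"apdb_time_part" IN (633,634)', ())])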