Coverage for python / lsst / dax / apdb / cassandra / apdbCassandra.py: 9%
809 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 10:35 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbCassandra"]
26import datetime
27import json
28import logging
29import random
30import uuid
31import warnings
32from collections import Counter, defaultdict
33from collections.abc import Iterable, Mapping, Set
34from typing import TYPE_CHECKING, Any, cast
36import numpy as np
37import pandas
39# If cassandra-driver is not there the module can still be imported
40# but ApdbCassandra cannot be instantiated.
41try:
42 import cassandra
43 import cassandra.query
44 from cassandra.query import UNSET_VALUE
46 CASSANDRA_IMPORTED = True
47except ImportError:
48 CASSANDRA_IMPORTED = False
50import astropy.time
51import felis.datamodel
53from lsst import sphgeom
54from lsst.utils.iteration import chunk_iterable
56from ..apdb import Apdb, ApdbConfig
57from ..apdbConfigFreezer import ApdbConfigFreezer
58from ..apdbReplica import ApdbTableData, ReplicaChunk
59from ..apdbSchema import ApdbSchema, ApdbTables
60from ..apdbUpdateRecord import (
61 ApdbCloseDiaObjectValidityRecord,
62 ApdbReassignDiaSourceToDiaObjectRecord,
63 ApdbUpdateNDiaSourcesRecord,
64)
65from ..monitor import MonAgent
66from ..recordIds import DiaObjectId, DiaSourceId
67from ..schema_model import Table
68from ..timer import Timer
69from ..versionTuple import VersionTuple
70from .apdbCassandraAdmin import ApdbCassandraAdmin
71from .apdbCassandraReplica import ApdbCassandraReplica
72from .apdbCassandraSchema import ApdbCassandraSchema, CreateTableOptions, ExtraTables
73from .apdbMetadataCassandra import ApdbMetadataCassandra
74from .cassandra_utils import (
75 ApdbCassandraTableData,
76 execute_concurrent,
77 literal,
78 select_concurrent,
79)
80from .config import ApdbCassandraConfig, ApdbCassandraConnectionConfig, ApdbCassandraTimePartitionRange
81from .connectionContext import ConnectionContext, DbVersions
82from .exceptions import CassandraMissingError
83from .partitioner import Partitioner
84from .queries import Column as C # noqa: N817
85from .queries import ColumnExpr, Delete, Insert, QExpr, Select, Update
86from .sessionFactory import SessionContext, SessionFactory
88if TYPE_CHECKING:
89 from ..apdbMetadata import ApdbMetadata
90 from ..apdbUpdateRecord import ApdbUpdateRecord
92_LOG = logging.getLogger(__name__)
94_MON = MonAgent(__name__)
96VERSION = VersionTuple(1, 3, 0)
97"""Version for the code controlling non-replication tables. This needs to be
98updated following compatibility rules when schema produced by this code
99changes.
100"""
103class ApdbCassandra(Apdb):
104 """Implementation of APDB database with Apache Cassandra backend.
106 Parameters
107 ----------
108 config : `ApdbCassandraConfig`
109 Configuration object.
110 """
    def __init__(self, config: ApdbCassandraConfig):
        # Fail early if the optional cassandra-driver package is missing;
        # the module itself imports fine without it (see CASSANDRA_IMPORTED
        # guard at module level), but instances cannot function.
        if not CASSANDRA_IMPORTED:
            raise CassandraMissingError()

        self._config = config
        self._keyspace = config.keyspace
        self._schema = ApdbSchema(config.schema_file, config.ss_schema_file)

        # No connection is made here; the session is created lazily by the
        # `_context` property on first database access.
        self._session_factory = SessionFactory(config)
        self._connection_context: ConnectionContext | None = None
121 self._connection_context: ConnectionContext | None = None
    @property
    def _context(self) -> ConnectionContext:
        """Establish connection if not established and return context.

        The context is created once and cached; subsequent accesses reuse
        the same `ConnectionContext` (and its Cassandra session).
        """
        if self._connection_context is None:
            # Collect versions of the currently-running code; the context is
            # expected to check them against what is stored in the database.
            # NOTE(review): this reads `self.schema`, while __init__ assigns
            # `self._schema` — presumably `schema` is exposed by the `Apdb`
            # base class or a property not visible here; confirm.
            current_versions = DbVersions(
                schema_version=self.schema.schemaVersion(),
                code_version=self.apdbImplementationVersion(),
                replica_version=ApdbCassandraReplica.apdbReplicaImplementationVersion(),
            )
            _LOG.debug("Current versions: %s", current_versions)

            session = self._session_factory.session()
            self._connection_context = ConnectionContext(
                session, self._config, self.schema.tableSchemas, current_versions
            )

            # Guarded so model_dump() is only evaluated when DEBUG is on.
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("ApdbCassandra Configuration: %s", self._connection_context.config.model_dump())

        return self._connection_context
    def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer:
        """Create `Timer` instance given its name.

        Parameters
        ----------
        name : `str`
            Timer name, reported to the module-level monitoring agent.
        tags : `~collections.abc.Mapping` [`str`, `str` | `int`], optional
            Additional tags attached to the timing record.
        """
        return Timer(name, _MON, tags=tags)
    @classmethod
    def apdbImplementationVersion(cls) -> VersionTuple:
        """Return version number for current APDB implementation.

        Returns
        -------
        version : `VersionTuple`
            Version of the code defined in implementation class.
        """
        # Module-level constant; bump it when the non-replication schema
        # produced by this code changes (see comment on VERSION above).
        return VERSION
    def getConfig(self) -> ApdbCassandraConfig:
        # docstring is inherited from a base class
        # Returned config comes from the connection context, which may have
        # merged in the frozen parameters stored in database metadata.
        return self._context.config
    def tableDef(self, table: ApdbTables) -> Table | None:
        # docstring is inherited from a base class
        # Returns None for tables absent from the loaded schema.
        # NOTE(review): uses `self.schema` (see note in `_context`) —
        # presumably provided by the base class; confirm.
        return self.schema.tableSchemas.get(table)
167 @classmethod
168 def init_database(
169 cls,
170 hosts: tuple[str, ...],
171 keyspace: str,
172 *,
173 schema_file: str | None = None,
174 ss_schema_file: str | None = None,
175 read_sources_months: int | None = None,
176 read_forced_sources_months: int | None = None,
177 enable_replica: bool = False,
178 replica_skips_diaobjects: bool = False,
179 port: int | None = None,
180 username: str | None = None,
181 prefix: str | None = None,
182 part_pixelization: str | None = None,
183 part_pix_level: int | None = None,
184 time_partition_tables: bool = True,
185 time_partition_start: str | None = None,
186 time_partition_end: str | None = None,
187 read_consistency: str | None = None,
188 write_consistency: str | None = None,
189 read_timeout: int | None = None,
190 write_timeout: int | None = None,
191 ra_dec_columns: tuple[str, str] | None = None,
192 replication_factor: int | None = None,
193 drop: bool = False,
194 table_options: CreateTableOptions | None = None,
195 ) -> ApdbCassandraConfig:
196 """Initialize new APDB instance and make configuration object for it.
198 Parameters
199 ----------
200 hosts : `tuple` [`str`, ...]
201 List of host names or IP addresses for Cassandra cluster.
202 keyspace : `str`
203 Name of the keyspace for APDB tables.
204 schema_file : `str`, optional
205 Location of (YAML) configuration file with APDB schema. If not
206 specified then default location will be used.
207 ss_schema_file : `str`, optional
208 Location of (YAML) configuration file with SSO schema. If not
209 specified then default location will be used.
210 read_sources_months : `int`, optional
211 Number of months of history to read from DiaSource.
212 read_forced_sources_months : `int`, optional
213 Number of months of history to read from DiaForcedSource.
214 enable_replica : `bool`, optional
215 If True, make additional tables used for replication to PPDB.
216 replica_skips_diaobjects : `bool`, optional
217 If `True` then do not fill regular ``DiaObject`` table when
218 ``enable_replica`` is `True`.
219 port : `int`, optional
220 Port number to use for Cassandra connections.
221 username : `str`, optional
222 User name for Cassandra connections.
223 prefix : `str`, optional
224 Optional prefix for all table names.
225 part_pixelization : `str`, optional
226 Name of the MOC pixelization used for partitioning.
227 part_pix_level : `int`, optional
228 Pixelization level.
229 time_partition_tables : `bool`, optional
230 Create per-partition tables.
231 time_partition_start : `str`, optional
232 Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss
233 format, in TAI.
234 time_partition_end : `str`, optional
235 Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss
236 format, in TAI.
237 read_consistency : `str`, optional
238 Name of the consistency level for read operations.
239 write_consistency : `str`, optional
240 Name of the consistency level for write operations.
241 read_timeout : `int`, optional
242 Read timeout in seconds.
243 write_timeout : `int`, optional
244 Write timeout in seconds.
245 ra_dec_columns : `tuple` [`str`, `str`], optional
246 Names of ra/dec columns in DiaObject table.
247 replication_factor : `int`, optional
248 Replication factor used when creating new keyspace, if keyspace
249 already exists its replication factor is not changed.
250 drop : `bool`, optional
251 If `True` then drop existing tables before re-creating the schema.
252 table_options : `CreateTableOptions`, optional
253 Options used when creating Cassandra tables.
255 Returns
256 -------
257 config : `ApdbCassandraConfig`
258 Resulting configuration object for a created APDB instance.
259 """
260 # Some non-standard defaults for connection parameters, these can be
261 # changed later in generated config. Check Cassandra driver
262 # documentation for what these parameters do. These parameters are not
263 # used during database initialization, but they will be saved with
264 # generated config.
265 connection_config = ApdbCassandraConnectionConfig(
266 extra_parameters={
267 "idle_heartbeat_interval": 0,
268 "idle_heartbeat_timeout": 30,
269 "control_connection_timeout": 100,
270 },
271 )
272 config = ApdbCassandraConfig(
273 contact_points=hosts,
274 keyspace=keyspace,
275 enable_replica=enable_replica,
276 replica_skips_diaobjects=replica_skips_diaobjects,
277 connection_config=connection_config,
278 )
279 config.partitioning.time_partition_tables = time_partition_tables
280 if schema_file is not None:
281 config.schema_file = schema_file
282 if ss_schema_file is not None:
283 config.ss_schema_file = ss_schema_file
284 if read_sources_months is not None:
285 config.read_sources_months = read_sources_months
286 if read_forced_sources_months is not None:
287 config.read_forced_sources_months = read_forced_sources_months
288 if port is not None:
289 config.connection_config.port = port
290 if username is not None:
291 config.connection_config.username = username
292 if prefix is not None:
293 config.prefix = prefix
294 if part_pixelization is not None:
295 config.partitioning.part_pixelization = part_pixelization
296 if part_pix_level is not None:
297 config.partitioning.part_pix_level = part_pix_level
298 if time_partition_start is not None:
299 config.partitioning.time_partition_start = time_partition_start
300 if time_partition_end is not None:
301 config.partitioning.time_partition_end = time_partition_end
302 if read_consistency is not None:
303 config.connection_config.read_consistency = read_consistency
304 if write_consistency is not None:
305 config.connection_config.write_consistency = write_consistency
306 if read_timeout is not None:
307 config.connection_config.read_timeout = read_timeout
308 if write_timeout is not None:
309 config.connection_config.write_timeout = write_timeout
310 if ra_dec_columns is not None:
311 config.ra_dec_columns = ra_dec_columns
313 cls._makeSchema(config, drop=drop, replication_factor=replication_factor, table_options=table_options)
315 return config
    def get_replica(self) -> ApdbCassandraReplica:
        """Return `ApdbReplica` instance for this database.

        Returns
        -------
        replica : `ApdbCassandraReplica`
            Replica instance bound to this APDB object.
        """
        # Note that this instance has to stay alive while replica exists, so
        # we pass reference to self.
        return ApdbCassandraReplica(self)
    @classmethod
    def _makeSchema(
        cls,
        config: ApdbConfig,
        *,
        drop: bool = False,
        replication_factor: int | None = None,
        table_options: CreateTableOptions | None = None,
    ) -> None:
        # docstring is inherited from a base class

        # Only the Cassandra flavor of the config carries the parameters
        # needed below (keyspace, partitioning, replica flags).
        if not isinstance(config, ApdbCassandraConfig):
            raise TypeError(f"Unexpected type of configuration object: {type(config)}")

        simple_schema = ApdbSchema(config.schema_file, config.ss_schema_file)

        # Short-lived session used only for schema creation.
        with SessionContext(config) as session:
            schema = ApdbCassandraSchema(
                session=session,
                keyspace=config.keyspace,
                table_schemas=simple_schema.tableSchemas,
                prefix=config.prefix,
                time_partition_tables=config.partitioning.time_partition_tables,
                enable_replica=config.enable_replica,
                replica_skips_diaobjects=config.replica_skips_diaobjects,
            )

            # Ask schema to create all tables.
            part_range_config: ApdbCassandraTimePartitionRange | None = None
            if config.partitioning.time_partition_tables:
                # Convert the configured start/end times (ISO strings, TAI)
                # into time-partition indices for the per-partition tables.
                partitioner = Partitioner(config)
                time_partition_start = astropy.time.Time(
                    config.partitioning.time_partition_start, format="isot", scale="tai"
                )
                time_partition_end = astropy.time.Time(
                    config.partitioning.time_partition_end, format="isot", scale="tai"
                )
                part_range_config = ApdbCassandraTimePartitionRange(
                    start=partitioner.time_partition(time_partition_start),
                    end=partitioner.time_partition(time_partition_end),
                )
                schema.makeSchema(
                    drop=drop,
                    part_range=part_range_config,
                    replication_factor=replication_factor,
                    table_options=table_options,
                )
            else:
                schema.makeSchema(
                    drop=drop, replication_factor=replication_factor, table_options=table_options
                )

            meta_table_name = ApdbTables.metadata.table_name(config.prefix)
            metadata = ApdbMetadataCassandra(
                session, meta_table_name, config.keyspace, "read_tuples", "write"
            )

            # Fill version numbers, overrides if they existed before.
            metadata.set(
                ConnectionContext.metadataSchemaVersionKey, str(simple_schema.schemaVersion()), force=True
            )
            metadata.set(
                ConnectionContext.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True
            )

            if config.enable_replica:
                # Only store replica code version if replica is enabled.
                metadata.set(
                    ConnectionContext.metadataReplicaVersionKey,
                    str(ApdbCassandraReplica.apdbReplicaImplementationVersion()),
                    force=True,
                )

            # Store frozen part of a configuration in metadata.
            freezer = ApdbConfigFreezer[ApdbCassandraConfig](ConnectionContext.frozen_parameters)
            metadata.set(ConnectionContext.metadataConfigKey, freezer.to_json(config), force=True)

            # Store time partition range.
            if part_range_config:
                part_range_config.save_to_meta(metadata)
    def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
        # docstring is inherited from a base class
        context = self._context
        config = context.config

        # One WHERE clause per spatial partition overlapping the region.
        sp_where, num_sp_part = context.partitioner.spatial_where(region)
        _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))

        # We need to exclude extra partitioning columns from result.
        column_names = context.schema.apdbColumnNames(ApdbTables.DiaObjectLast)
        table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
        query = Select(self._keyspace, table_name, column_names)
        statements: list[tuple] = []
        for where_clause in sp_where:
            full_query = query.where(where_clause)
            statements.append(context.stmt_factory.with_params(full_query, prepare=True))
        _LOG.debug("getDiaObjects: #queries: %s", len(statements))

        # Per-partition queries are executed concurrently and merged.
        with self._timer("select_time", tags={"table": "DiaObject", "method": "getDiaObjects"}) as timer:
            raw_objects = cast(
                ApdbCassandraTableData,
                select_concurrent(
                    context.session,
                    statements,
                    "read_raw_multi",
                    config.connection_config.read_concurrency,
                ),
            )
            objects = raw_objects.to_pandas(context.schema._table_schema(ApdbTables.DiaObjectLast))
            timer.add_values(row_count=len(objects), num_sp_part=num_sp_part, num_queries=len(statements))

        _LOG.debug("found %s DiaObjects", objects.shape[0])
        return objects
438 def getDiaSources(
439 self,
440 region: sphgeom.Region,
441 object_ids: Iterable[int] | None,
442 visit_time: astropy.time.Time,
443 start_time: astropy.time.Time | None = None,
444 ) -> pandas.DataFrame | None:
445 # docstring is inherited from a base class
446 context = self._context
447 config = context.config
449 months = config.read_sources_months
450 if start_time is None and months == 0:
451 return None
453 mjd_end = float(visit_time.tai.mjd)
454 if start_time is None:
455 mjd_start = mjd_end - months * 30
456 else:
457 mjd_start = float(start_time.tai.mjd)
459 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)
461 def getDiaForcedSources(
462 self,
463 region: sphgeom.Region,
464 object_ids: Iterable[int] | None,
465 visit_time: astropy.time.Time,
466 start_time: astropy.time.Time | None = None,
467 ) -> pandas.DataFrame | None:
468 # docstring is inherited from a base class
469 context = self._context
470 config = context.config
472 months = config.read_forced_sources_months
473 if start_time is None and months == 0:
474 return None
476 mjd_end = float(visit_time.tai.mjd)
477 if start_time is None:
478 mjd_start = mjd_end - months * 30
479 else:
480 mjd_start = float(start_time.tai.mjd)
482 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)
    def getDiaObjectsForDedup(self, since: astropy.time.Time | None = None) -> pandas.DataFrame:
        # docstring is inherited from a base class
        context = self._context
        config = context.config

        if not context.has_dedup_table:
            raise TypeError("DiaObjectDedup table does not exist in this APDB instance.")

        if since is None:
            # Read last deduplication time from metadata.
            dedup_str = context.metadata.get(context.metadataDedupKey)
            if dedup_str is not None:
                dedup_state = json.loads(dedup_str)
                dedup_time_str = dedup_state["dedup_time_iso_tai"]
                since = astropy.time.Time(dedup_time_str, format="iso", scale="tai")

        column_names = context.schema.apdbColumnNames(ExtraTables.DiaObjectDedup)

        validity_start_column = self._timestamp_column_name("validityStart")
        timestamp = None if since is None else self._timestamp_column_value(since)

        table_name = context.schema.tableName(ExtraTables.DiaObjectDedup)
        query = Select(self._keyspace, table_name, column_names, extra_clause="ALLOW FILTERING")
        # NOTE(review): the literal values in these WHERE clauses (0) appear
        # to act as placeholders — the statement is executed below with a
        # per-partition params tuple (dedup_part[, timestamp]); confirm this
        # is how stmt_factory/select_concurrent bind values.
        query = query.where(C("dedup_part") == 0)
        if since is not None:
            query = query.where(C(validity_start_column) >= 0)

        statement = context.stmt_factory(query, prepare=False)

        # Fan the same statement out over all dedup partitions.
        num_part = config.partitioning.num_part_dedup
        statements = []
        for dedup_part in range(num_part):
            params = (dedup_part,) if timestamp is None else (dedup_part, timestamp)
            statements.append((statement, params))

        with self._timer(
            "select_time", tags={"table": "DiaObjectDedup", "method": "getDiaObjectsForDedup"}
        ) as timer:
            objects_raw = cast(
                ApdbCassandraTableData,
                select_concurrent(
                    context.session,
                    statements,
                    "read_raw_multi_dedup",
                    config.connection_config.read_concurrency,
                ),
            )
            objects = objects_raw.to_pandas(context.schema._table_schema(ExtraTables.DiaObjectDedup))
            timer.add_values(row_count=len(objects), num_queries=num_part)

        _LOG.debug("found %s DiaObjectDedup records", objects.shape[0])
        return objects
    def getDiaSourcesForDiaObjects(
        self, objects: list[DiaObjectId], start_time: astropy.time.Time, max_dist_arcsec: float = 1.0
    ) -> pandas.DataFrame:
        # docstring is inherited from a base class
        context = self._context
        config = context.config

        # Which tables to query and temporal constraints.
        end_time = self._current_time()
        tables, temporal_where = context.partitioner.temporal_where(
            ApdbTables.DiaSource,
            start_time,
            end_time,
            partitons_range=context.time_partitions_range,
            query_per_time_part=False,
        )
        if not tables:
            # Not an error: the query simply returns an empty result below.
            warnings.warn(
                f"Query time range ({start_time.isot} - {end_time.isot}) does not overlap database "
                "time partitions."
            )

        # Group DiaObjects by partition.
        partitioned_object_ids = self._group_dia_objects_by_partition(
            context.partitioner, objects, max_dist_arcsec
        )

        # Columns to return.
        column_names = context.schema.apdbColumnNames(ApdbTables.DiaSource)

        # Make a bunch of queries: one per (spatial partition, table,
        # 10k-ID chunk, combined-where clause) combination.
        statements = []
        for apdb_part, diaObjectIds in partitioned_object_ids.items():
            spatial_where = [C("apdb_part") == apdb_part]
            for table in tables:
                query = Select(self._keyspace, table, column_names, extra_clause="ALLOW FILTERING")
                for id_chunk in chunk_iterable(diaObjectIds, 10_000):
                    id_where = C("diaObjectId").in_(id_chunk)
                    for clause in QExpr.combine(spatial_where, temporal_where, extra=id_where):
                        statements.append(
                            context.stmt_factory.with_params(query.where(clause), prepare=False)
                        )

        _LOG.debug("getDiaSourcesForDiaObjects #queries: %s", len(statements))

        with self._timer(
            "select_time", tags={"table": "DiaSource", "method": "getDiaSourcesForDiaObjects"}
        ) as timer:
            table_data_raw = cast(
                ApdbCassandraTableData,
                select_concurrent(
                    context.session,
                    statements,
                    "read_raw_multi",
                    config.connection_config.read_concurrency,
                ),
            )
            catalog = table_data_raw.to_pandas(context.schema._table_schema(ApdbTables.DiaSource))
            timer.add_values(row_count_from_db=len(catalog), num_queries=len(statements))

            # precise filtering on midpointMjdTai (time partitions are
            # coarser than the requested start time)
            catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] >= start_time.tai.mjd])

            timer.add_values(row_count=len(catalog))

        _LOG.debug("found %d DiaSources", len(catalog))
        return catalog
    def containsVisitDetector(
        self,
        visit: int,
        detector: int,
        region: sphgeom.Region,
        visit_time: astropy.time.Time,
    ) -> bool:
        # docstring is inherited from a base class
        # Note: `region` and `visit_time` are part of the inherited interface
        # but are not used by this implementation — the dedicated
        # ApdbVisitDetector table is keyed on (visit, detector) only.
        context = self._context

        table_name = context.schema.tableName(ExtraTables.ApdbVisitDetector)
        query = Select(self._keyspace, table_name, [ColumnExpr("count(*)")])
        query = query.where((C("visit") == visit) & (C("detector") == detector))
        stmt, params = context.stmt_factory.with_params(query, prepare=False)

        with self._timer("contains_visit_detector_time", tags={"table": table_name}):
            result = context.session.execute(stmt, params)
            # count(*) > 0 means at least one write was attempted.
            return bool(result.one()[0])
    def store(
        self,
        visit_time: astropy.time.Time,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class
        context = self._context
        config = context.config

        # Store visit/detector in a special table, this has to be done
        # before all other writes so if there is a failure at any point
        # later we still have a record for attempted write.
        visit_detector: set[tuple[int, int]] = set()
        for df in sources, forced_sources:
            if df is not None and not df.empty:
                df = df[["visit", "detector"]]
                for visit, detector in df.itertuples(index=False):
                    visit_detector.add((visit, detector))

        if visit_detector:
            # Typically there is only one entry, do not bother with
            # concurrency.
            table_name = context.schema.tableName(ExtraTables.ApdbVisitDetector)
            query = Insert(self._keyspace, table_name, ("visit", "detector"))
            stmt = context.stmt_factory(query)
            for item in visit_detector:
                context.session.execute(stmt, item, execution_profile="write")

        # Normalize timestamp columns before any table writes.
        objects = self._fix_input_timestamps(objects)
        if sources is not None:
            sources = self._fix_input_timestamps(sources)
        if forced_sources is not None:
            forced_sources = self._fix_input_timestamps(forced_sources)

        # The replica chunk (if replication is on) must exist before the
        # data writes that reference it.
        replica_chunk: ReplicaChunk | None = None
        if context.schema.replication_enabled:
            replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, config.replica_chunk_seconds)
            self._storeReplicaChunk(replica_chunk)

        # fill region partition column for DiaObjects
        objects = self._add_apdb_part(objects)
        self._storeDiaObjects(objects, visit_time, replica_chunk)

        if sources is not None and len(sources) > 0:
            # copy apdb_part column from DiaObjects to DiaSources
            sources = self._add_apdb_part(sources)
            subchunk = self._storeDiaSources(ApdbTables.DiaSource, sources, replica_chunk)
            self._storeDiaSourcesPartitions(sources, visit_time, replica_chunk, subchunk)

        if forced_sources is not None and len(forced_sources) > 0:
            forced_sources = self._add_apdb_part(forced_sources)
            self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, replica_chunk)
    def reassignDiaSourcesToDiaObjects(
        self,
        idMap: Mapping[DiaSourceId, int],
        *,
        increment_nDiaSources: bool = True,
        decrement_nDiaSources: bool = True,
    ) -> None:
        # docstring is inherited from a base class
        context = self._context
        config = context.config

        source_ids = {source_id.diaSourceId for source_id in idMap}

        # Find all DiaSources.
        found_sources = self._get_diasource_data(
            idMap, "apdb_part", "diaObjectId", "ra", "dec", "midpointMjdTai"
        )

        # Fail up-front if any requested source is absent; nothing has been
        # written at this point.
        if missing_ids := (source_ids - {row.diaSourceId for row in found_sources}):
            raise LookupError(f"Some source IDs were not found in DiaSource table: {missing_ids}")

        found_sources_by_id = {row.diaSourceId: row for row in found_sources}

        # Make sure that all DiaObjects exist, we also want to know
        # nDiaSources count for current and new records because we want to
        # send updated values to replica.
        current_object_ids = {
            DiaObjectId(diaObjectId=row.diaObjectId, ra=row.ra, dec=row.dec) for row in found_sources
        }
        # Assume that DiaSource ra/dec are very close to re-assigned objects.
        new_object_ids = {
            DiaObjectId(diaObjectId=diaObjectId, ra=source_id.ra, dec=source_id.dec)
            for source_id, diaObjectId in idMap.items()
        }
        all_object_ids = new_object_ids | current_object_ids
        found_objects = self._get_diaobject_data(all_object_ids, "apdb_part", "ra", "dec", "nDiaSources")

        if missing_ids := (
            {row.diaObjectId for row in all_object_ids} - {row.diaObjectId for row in found_objects}
        ):
            raise LookupError(f"Some object IDs were not found in DiaObjectLast table: {missing_ids}")

        # Replica update records share one timestamp; update_order keeps
        # their relative ordering within that timestamp.
        update_records: list[ApdbUpdateRecord] = []
        update_order = 0
        current_time = self._current_time()
        current_time_ns = int(current_time.unix_tai * 1e9)

        # Update DiaSources.
        statements: list[tuple] = []
        for source_id, diaObjectId in idMap.items():
            source_row = found_sources_by_id[source_id.diaSourceId]
            apdb_part = source_row.apdb_part
            time_part = context.partitioner.time_partition(source_row.midpointMjdTai)

            # With per-time-partition tables the partition is encoded in the
            # table name; otherwise it is an explicit apdb_time_part column.
            if config.partitioning.time_partition_tables:
                table_name = context.schema.tableName(ApdbTables.DiaSource, time_part)
                update = (
                    Update(self._keyspace, table_name)
                    .values(C("diaObjectId").update(diaObjectId))
                    .where(C("apdb_part") == apdb_part)
                    .where(C("diaSourceId") == source_id.diaSourceId)
                )
            else:
                table_name = context.schema.tableName(ApdbTables.DiaSource)
                update = (
                    Update(self._keyspace, table_name)
                    .values(C("diaObjectId").update(diaObjectId))
                    .where(C("apdb_part") == apdb_part)
                    .where(C("apdb_time_part") == time_part)
                    .where(C("diaSourceId") == source_id.diaSourceId)
                )
            statements.append(context.stmt_factory.with_params(update, prepare=True))

            if context.schema.replication_enabled:
                update_records.append(
                    ApdbReassignDiaSourceToDiaObjectRecord(
                        diaSourceId=source_id.diaSourceId,
                        ra=source_id.ra,
                        dec=source_id.dec,
                        midpointMjdTai=source_id.midpointMjdTai,
                        diaObjectId=diaObjectId,
                        update_time_ns=current_time_ns,
                        update_order=update_order,
                    )
                )
                update_order += 1

        with self._timer(
            "update_time", tags={"table": "DiaSource", "method": "reassignDiaSourcesToDiaObjects"}
        ) as timer:
            execute_concurrent(context.session, statements, execution_profile="write")
            timer.add_values(num_queries=len(statements))

        # Update nDiaSources in DiaObjectLast. We do not update DiaObject table
        # here because it may not even exist. PPDB updates DiaObject from
        # update records.
        if increment_nDiaSources or decrement_nDiaSources:
            table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
            # The -1 literals look like dummy placeholder values for the
            # prepared statement; actual values are supplied per-row in the
            # params tuples below (nDiaSources, apdb_part, diaObjectId).
            update = (
                Update(self._keyspace, table_name)
                .values(C("nDiaSources").update(-1))
                .where(C("apdb_part") == -1)
                .where(C("diaObjectId") == -1)
            )
            statement = context.stmt_factory(update, prepare=True)
            statements = []

            # Calculate increments/decrements for all affected DiaObjects.
            increments: Counter = Counter()
            if increment_nDiaSources:
                increments.update(idMap.values())
            if decrement_nDiaSources:
                increments.subtract(row.diaObjectId for row in found_sources)

            for row in found_objects:
                # Skip objects whose net change is zero (or absent).
                if increments.get(row.diaObjectId):
                    nDiaSources = row.nDiaSources + increments[row.diaObjectId]
                    statements.append((statement, (nDiaSources, row.apdb_part, row.diaObjectId)))

                    # Also send updated values to replica.
                    if context.schema.replication_enabled:
                        update_records.append(
                            ApdbUpdateNDiaSourcesRecord(
                                diaObjectId=row.diaObjectId,
                                ra=row.ra,
                                dec=row.dec,
                                nDiaSources=nDiaSources,
                                update_time_ns=current_time_ns,
                                update_order=update_order,
                            )
                        )
                        update_order += 1

            if statements:
                with self._timer(
                    "update_time", tags={"table": table_name, "method": "reassignDiaSourcesToDiaObjects"}
                ) as timer:
                    execute_concurrent(context.session, statements, execution_profile="write")
                    timer.add_values(num_queries=len(statements))

        if update_records:
            replica_chunk = ReplicaChunk.make_replica_chunk(current_time, config.replica_chunk_seconds)
            self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True)
    def setValidityEnd(
        self, objects: list[DiaObjectId], validityEnd: astropy.time.Time, raise_on_missing_id: bool = False
    ) -> int:
        # docstring is inherited from a base class
        if not objects:
            return 0

        context = self._context
        config = context.config

        # Padding used when mapping object positions to spatial partitions.
        pad_arcsec = 1.0
        partitioned_object_ids = self._group_dia_objects_by_partition(
            context.partitioner, objects, pad_arcsec
        )

        # Check that all objects exist.
        table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
        statements: list[tuple] = []
        for apdb_part, diaObjectIds in partitioned_object_ids.items():
            query = Select(self._keyspace, table_name, ["apdb_part", "diaObjectId"])
            query = query.where(C("apdb_part") == apdb_part)
            query = query.where(C("diaObjectId").in_(diaObjectIds))
            statements.append(context.stmt_factory.with_params(query, prepare=False))

        with self._timer("select_time", tags={"table": table_name, "method": "setValidityEnd"}) as timer:
            # Each row is an (apdb_part, diaObjectId) tuple.
            records = cast(
                list[tuple[int, int]],
                select_concurrent(
                    context.session,
                    statements,
                    "read_tuples",
                    config.connection_config.read_concurrency,
                ),
            )
            timer.add_values(row_count=len(objects))

        requested_ids = {obj.diaObjectId for obj in objects}
        found_ids = {rec[1] for rec in records}
        # Finding IDs we did not ask for means the same object appears in
        # more than one partition — a consistency error.
        if extra_ids := (found_ids - requested_ids):
            raise RuntimeError(f"Consistency error - found duplicate records for object IDs: {extra_ids}")
        if raise_on_missing_id:
            if missing_ids := (requested_ids - found_ids):
                raise LookupError(f"Some object IDs are missing from DiaObjectLast table: {missing_ids}")

        # Filter existing records.
        if len(objects) != len(found_ids):
            objects = [obj for obj in objects if obj.diaObjectId in found_ids]

        if not objects:
            return 0

        # Group by partitions again.
        grouped_object_ids: dict[int, list[int]] = defaultdict(list)
        for apdb_part, diaObjectId in records:
            grouped_object_ids[apdb_part].append(diaObjectId)

        # Remove all matching rows from DiaObjectLast.
        statements = []
        for apdb_part, diaObjectIds in grouped_object_ids.items():
            delete = (
                Delete(self._keyspace, table_name)
                .where(C("apdb_part") == apdb_part)
                .where(C("diaObjectId").in_(diaObjectIds))
            )
            statements.append(context.stmt_factory.with_params(delete))

        # Also remove from DiaObjectLastToPartition.
        reverse_table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition)
        delete = Delete(self._keyspace, reverse_table_name).where(
            C("diaObjectId").in_([rec[1] for rec in records])
        )
        statements.append(context.stmt_factory.with_params(delete))

        with self._timer("delete_time", tags={"table": table_name, "method": "setValidityEnd"}) as timer:
            execute_concurrent(context.session, statements, execution_profile="write")
            timer.add_values(row_count=len(records))

        # If replication is enabled then send all updates.
        if context.schema.replication_enabled:
            current_time = self._current_time()
            current_time_ns = int(current_time.unix_tai * 1e9)
            replica_chunk = ReplicaChunk.make_replica_chunk(current_time, config.replica_chunk_seconds)

            update_records = [
                ApdbCloseDiaObjectValidityRecord(
                    diaObjectId=obj.diaObjectId,
                    ra=obj.ra,
                    dec=obj.dec,
                    update_time_ns=current_time_ns,
                    update_order=index,
                    validityEndMjdTai=float(validityEnd.tai.mjd),
                    nDiaSources=None,
                )
                for index, obj in enumerate(objects)
            ]

            self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True)

        # Number of objects actually closed (after filtering missing IDs).
        return len(objects)
    def resetDedup(self, dedup_time: astropy.time.Time | None = None) -> None:
        # docstring is inherited from a base class
        context = self._context

        if not context.has_dedup_table:
            raise TypeError("DiaObjectDedup table does not exist in this APDB instance.")

        # Default to "now" when no explicit cutoff time is given.
        if dedup_time is None:
            dedup_time = self._current_time()

        validity_start_column = self._timestamp_column_name("validityStart")

        # Find latest timestamp in deduplication table.
        table_name = context.schema.tableName(ExtraTables.DiaObjectDedup)
        query = Select(self._keyspace, table_name, [ColumnExpr(f'MAX("{validity_start_column}")')])
        stmt = context.stmt_factory(query, prepare=False)

        result = context.session.execute(stmt, execution_profile="read_tuples")
        # NOTE(review): MAX() yields NULL for an empty table, so max_value may
        # be None and the Time() construction below would fail - confirm this
        # method is never called on an empty dedup table.
        max_value = result.one()[0]
        if self._schema.has_mjd_timestamps:
            max_validity_start = astropy.time.Time(max_value, format="mjd", scale="tai")
        else:
            max_validity_start = astropy.time.Time(max_value, format="datetime", scale="tai")

        # If max time is lower than dedup time we can do TRUNCATE.
        if dedup_time >= max_validity_start:
            query_str = f'TRUNCATE TABLE "{self._keyspace}"."{table_name}"'
            context.session.execute(query_str, execution_profile="write")
        else:
            # Otherwise delete only the rows older than the cutoff.
            dedup_time_value = self._timestamp_column_value(dedup_time)
            delete = Delete(self._keyspace, table_name).where(C(validity_start_column) < dedup_time_value)
            stmt, params = context.stmt_factory.with_params(delete)
            context.session.execute(stmt, params, execution_profile="write")

        # Store dedup time in metadata so it can be discovered later.
        data = {"dedup_time_iso_tai": dedup_time.tai.to_value("iso")}
        data_json = json.dumps(data)
        context.metadata.set(context.metadataDedupKey, data_json, force=True)
    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        # idMap maps diaSourceId to the ssObjectId it is being re-associated
        # with; the matching diaObjectId is cleared on each updated row.
        context = self._context
        config = context.config

        now = self._current_time()
        reassign_time_column = self._timestamp_column_name("ssObjectReassocTime")
        reassignTime = self._timestamp_column_value(now)

        # To update a record we need to know its exact primary key (including
        # partition key) so we start by querying for diaSourceId to find the
        # primary keys.

        table_name = context.schema.tableName(ExtraTables.DiaSourceToPartition)
        # split it into 1k IDs per query
        selects: list[tuple] = []
        columns = ["diaSourceId", "apdb_part", "apdb_time_part", "apdb_replica_chunk"]
        query = Select(self._keyspace, table_name, columns)
        for ids in chunk_iterable(idMap.keys(), 1_000):
            full_query = query.where(C("diaSourceId").in_(ids))
            selects.append(context.stmt_factory.with_params(full_query, prepare=False))

        # No need for DataFrame here, read data as tuples.
        result = cast(
            list[tuple[int, int, int, int | None]],
            select_concurrent(
                context.session, selects, "read_tuples", config.connection_config.read_concurrency
            ),
        )

        # Make mapping from source ID to its partition.
        # Row layout: (diaSourceId, apdb_part, apdb_time_part, replica_chunk).
        id2partitions: dict[int, tuple[int, int]] = {}
        id2chunk_id: dict[int, int] = {}
        for row in result:
            id2partitions[row[0]] = row[1:3]
            if row[3] is not None:
                id2chunk_id[row[0]] = row[3]

        # make sure we know partitions for each ID
        if set(id2partitions) != set(idMap):
            missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
            raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

        # Reassign in standard tables
        queries: list[tuple[cassandra.query.PreparedStatement, tuple]] = []
        for diaSourceId, ssObjectId in idMap.items():
            apdb_part, apdb_time_part = id2partitions[diaSourceId]
            if config.partitioning.time_partition_tables:
                # Per-partition table layout: the time partition selects the
                # table name, so it is not part of the WHERE clause.
                table_name = context.schema.tableName(ApdbTables.DiaSource, apdb_time_part)
                update = (
                    Update(self._keyspace, table_name)
                    .values(
                        C("ssObjectId").update(ssObjectId),
                        C("diaObjectId").update(None),
                        C(reassign_time_column).update(reassignTime),
                    )
                    .where(C("apdb_part") == apdb_part)
                    .where(C("diaSourceId") == diaSourceId)
                )
            else:
                # Single-table layout: apdb_time_part is a clustering column
                # and must appear in the WHERE clause.
                table_name = context.schema.tableName(ApdbTables.DiaSource)
                update = (
                    Update(self._keyspace, table_name)
                    .values(
                        C("ssObjectId").update(ssObjectId),
                        C("diaObjectId").update(None),
                        C(reassign_time_column).update(reassignTime),
                    )
                    .where(C("apdb_part") == apdb_part)
                    .where(C("apdb_time_part") == apdb_time_part)
                    .where(C("diaSourceId") == diaSourceId)
                )
            queries.append(context.stmt_factory.with_params(update, prepare=True))

        # TODO: (DM-50190) Replication for updated records is not implemented.
        if id2chunk_id:
            warnings.warn("Replication of reassigned DiaSource records is not implemented.", stacklevel=2)

        _LOG.debug("%s: will update %d records", table_name, len(idMap))
        with self._timer("source_reassign_time") as timer:
            execute_concurrent(context.session, queries, execution_profile="write")
            timer.add_values(source_count=len(idMap))
1045 def countUnassociatedObjects(self) -> int:
1046 # docstring is inherited from a base class
1048 # It's too inefficient to implement it for Cassandra in current schema.
1049 raise NotImplementedError()
    @property
    def schema(self) -> ApdbSchema:
        # docstring is inherited from a base class
        # ``_schema`` is initialized elsewhere in this class (not visible in
        # this section); this property only exposes it read-only.
        return self._schema
1056 @property
1057 def metadata(self) -> ApdbMetadata:
1058 # docstring is inherited from a base class
1059 context = self._context
1060 return context.metadata
    @property
    def admin(self) -> ApdbCassandraAdmin:
        # docstring is inherited from a base class
        # A new admin wrapper is created on every access; it keeps a
        # reference back to this Apdb instance.
        return ApdbCassandraAdmin(self)
    def _getSources(
        self,
        region: sphgeom.Region,
        object_ids: Iterable[int] | None,
        mjd_start: float,
        mjd_end: float,
        table_name: ApdbTables,
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Spherical region.
        object_ids : `~collections.abc.Iterable` [`int`] or `None`
            Collection of DiaObject IDs; `None` means no filtering on object
            ID is performed.
        mjd_start : `float`
            Lower bound of time interval.
        mjd_end : `float`
            Upper bound of time interval.
        table_name : `ApdbTables`
            Name of the table.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. Empty catalog is returned if
            ``object_ids`` is empty.
        """
        context = self._context
        config = context.config

        object_id_set: Set[int] = set()
        if object_ids is not None:
            object_id_set = set(object_ids)
            # Empty (non-None) ID collection means nothing can match.
            if len(object_id_set) == 0:
                return self._make_empty_catalog(table_name)

        sp_where, num_sp_part = context.partitioner.spatial_where(region)
        tables, temporal_where = context.partitioner.temporal_where(
            table_name, mjd_start, mjd_end, partitons_range=context.time_partitions_range
        )
        if not tables:
            start = astropy.time.Time(mjd_start, format="mjd", scale="tai")
            end = astropy.time.Time(mjd_end, format="mjd", scale="tai")
            warnings.warn(
                f"Query time range ({start.isot} - {end.isot}) does not overlap database time partitions."
            )

        # We need to exclude extra partitioning columns from result.
        column_names = context.schema.apdbColumnNames(table_name)

        # Build all queries: one per combination of table, spatial constraint
        # and temporal constraint.
        statements: list[tuple] = []
        for table in tables:
            query = Select(self._keyspace, table, column_names)
            for clause in QExpr.combine(sp_where, temporal_where):
                statements.append(context.stmt_factory.with_params(query.where(clause), prepare=True))
        _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))

        with self._timer("select_time", tags={"table": table_name.name, "method": "_getSources"}) as timer:
            table_data_raw = cast(
                ApdbCassandraTableData,
                select_concurrent(
                    context.session,
                    statements,
                    "read_raw_multi",
                    config.connection_config.read_concurrency,
                ),
            )
            catalog = table_data_raw.to_pandas(context.schema._table_schema(table_name))
            timer.add_values(
                row_count_from_db=len(catalog), num_sp_part=num_sp_part, num_queries=len(statements)
            )

            # filter by given object IDs
            if len(object_id_set) > 0:
                catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])

            # precise filtering on midpointMjdTai; partition-level selection
            # is coarse. NOTE(review): only the lower bound is re-checked
            # here - presumably the upper bound is guaranteed by the temporal
            # constraints above; confirm.
            catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start])

            timer.add_values(row_count=len(catalog))

        _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
        return catalog
1154 def _storeReplicaChunk(self, replica_chunk: ReplicaChunk) -> None:
1155 context = self._context
1156 config = context.config
1158 # Cassandra timestamp uses milliseconds since epoch
1159 timestamp = int(replica_chunk.last_update_time.unix_tai * 1000)
1161 # everything goes into a single partition
1162 partition = 0
1164 table_name = context.schema.tableName(ExtraTables.ApdbReplicaChunks)
1166 columns = ["partition", "apdb_replica_chunk", "last_update_time", "unique_id"]
1167 values = [partition, replica_chunk.id, timestamp, replica_chunk.unique_id]
1168 if context.has_chunk_sub_partitions:
1169 columns.append("has_subchunks")
1170 values.append(True)
1172 query = Insert(self._keyspace, table_name, columns)
1173 stmt = context.stmt_factory(query)
1175 context.session.execute(
1176 stmt,
1177 values,
1178 timeout=config.connection_config.write_timeout,
1179 execution_profile="write",
1180 )
    def _queryDiaObjectLastPartitions(self, ids: Iterable[int]) -> Mapping[int, int]:
        """Return existing mapping of diaObjectId to its last partition.

        Parameters
        ----------
        ids : `~collections.abc.Iterable` [`int`]
            DiaObject IDs to look up.

        Returns
        -------
        mapping : `~collections.abc.Mapping` [`int`, `int`]
            Mapping of diaObjectId to its spatial partition; IDs not present
            in the DiaObjectLastToPartition table are absent from the result.
        """
        context = self._context
        config = context.config

        table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition)
        # Split the lookup into chunks to keep each IN() list bounded.
        queries = []
        object_count = 0
        for id_chunk in chunk_iterable(ids, 10_000):
            id_chunk_list = tuple(id_chunk)
            query = Select(self._keyspace, table_name, ("diaObjectId", "apdb_part"))
            query = query.where(C("diaObjectId").in_(id_chunk_list))
            queries.append(context.stmt_factory.with_params(query, prepare=False))
            object_count += len(id_chunk_list)

        with self._timer("query_object_last_partitions", tags={"table": table_name}) as timer:
            data = cast(
                ApdbTableData,
                select_concurrent(
                    context.session,
                    queries,
                    "read_raw_multi",
                    config.connection_config.read_concurrency,
                ),
            )
            timer.add_values(object_count=object_count, row_count=len(data.rows()))

        # Sanity check on the shape of the raw result before unpacking rows.
        if data.column_names() != ["diaObjectId", "apdb_part"]:
            raise RuntimeError(f"Unexpected column names in query result: {data.column_names()}")

        return {row[0]: row[1] for row in data.rows()}
    def _deleteMovingObjects(self, objs: pandas.DataFrame) -> None:
        """Objects in DiaObjectsLast can move from one spatial partition to
        another. For those objects inserting new version does not replace old
        one, so we need to explicitly remove old versions before inserting new
        ones.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            New DiaObject records; must contain ``diaObjectId`` and
            ``apdb_part`` columns.
        """
        context = self._context

        # Extract all object IDs.
        new_partitions = dict(zip(objs["diaObjectId"], objs["apdb_part"]))
        old_partitions = self._queryDiaObjectLastPartitions(objs["diaObjectId"])

        # Collect objects whose partition changed: oid -> (old, new).
        moved_oids: dict[int, tuple[int, int]] = {}
        for oid, old_part in old_partitions.items():
            new_part = new_partitions.get(oid, old_part)
            if new_part != old_part:
                moved_oids[oid] = (old_part, new_part)
        _LOG.debug("DiaObject IDs that moved to new partition: %s", moved_oids)

        if moved_oids:
            # Delete old records from DiaObjectLast.
            table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
            query = Delete(self._keyspace, table_name)
            # The (-1, -1) values are placeholders only; real values are
            # bound per-query below when the prepared statement is executed.
            query = query.where('apdb_part = {} AND "diaObjectId" = {}', (-1, -1))
            statement = context.stmt_factory(query, prepare=True)
            queries = []
            for oid, (old_part, _) in moved_oids.items():
                queries.append((statement, (old_part, oid)))
            with self._timer("delete_object_last", tags={"table": table_name}) as timer:
                execute_concurrent(context.session, queries, execution_profile="write")
                timer.add_values(row_count=len(moved_oids))

        # Add all new records to the map. This is done for every object, not
        # just the moved ones, so the map stays current.
        table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition)
        insert = Insert(self._keyspace, table_name, ("diaObjectId", "apdb_part"))
        statement = context.stmt_factory(insert, prepare=True)

        queries = []
        for oid, new_part in new_partitions.items():
            queries.append((statement, (oid, new_part)))

        with self._timer("update_object_last_partition", tags={"table": table_name}) as timer:
            execute_concurrent(context.session, queries, execution_profile="write")
            timer.add_values(row_count=len(queries))
    def _storeDiaObjects(
        self, objs: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `astropy.time.Time`
            Time of the current visit.
        replica_chunk : `ReplicaChunk` or `None`
            Replica chunk identifier if replication is configured.
        """
        if len(objs) == 0:
            _LOG.debug("No objects to write to database.")
            return

        context = self._context
        config = context.config

        # Objects that moved to a different spatial partition need their old
        # DiaObjectLast rows removed first, otherwise both copies survive.
        self._deleteMovingObjects(objs)

        validity_start_column = self._timestamp_column_name("validityStart")
        timestamp = self._timestamp_column_value(visit_time)

        # DiaObjectLast did not have this column in the past.
        extra_columns: dict[str, Any] = {}
        if context.schema.check_column(ApdbTables.DiaObjectLast, validity_start_column):
            extra_columns[validity_start_column] = timestamp

        self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)

        # The DiaObject table gets validityStart unconditionally.
        extra_columns[validity_start_column] = timestamp
        visit_time_part = context.partitioner.time_partition(visit_time)
        time_part: int | None = visit_time_part
        if (time_partitions_range := context.time_partitions_range) is not None:
            self._check_time_partitions([visit_time_part], time_partitions_range)
        if not config.partitioning.time_partition_tables:
            # Single-table layout stores the time partition as a column
            # instead of using a per-partition table.
            extra_columns["apdb_time_part"] = time_part
            time_part = None

        # Only store DiaObjects if not doing replication or explicitly
        # configured to always store them.
        if replica_chunk is None or not config.replica_skips_diaobjects:
            self._storeObjectsPandas(
                objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part
            )

        if replica_chunk is not None:
            extra_columns = {"apdb_replica_chunk": replica_chunk.id, validity_start_column: timestamp}
            table = ExtraTables.DiaObjectChunks
            if context.has_chunk_sub_partitions:
                table = ExtraTables.DiaObjectChunks2
                # Use a random number for a second part of partitioning key so
                # that different clients could write to different partitions.
                # This makes it not exactly reproducible.
                extra_columns["apdb_replica_subchunk"] = random.randrange(config.replica_sub_chunk_count)
            self._storeObjectsPandas(objs, table, extra_columns=extra_columns)

        # Store copy of the records in dedup table.
        if context.has_dedup_table:
            table = ExtraTables.DiaObjectDedup
            extra_columns = {
                "dedup_part": random.randrange(config.partitioning.num_part_dedup),
                validity_start_column: timestamp,
            }
            self._storeObjectsPandas(objs, table, extra_columns=extra_columns)
    def _storeDiaSources(
        self,
        table_name: ApdbTables,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
    ) -> int | None:
        """Store catalog of DIASources or DIAForcedSources from current visit.

        Parameters
        ----------
        table_name : `ApdbTables`
            Table where to store the data.
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        replica_chunk : `ReplicaChunk` or `None`
            Replica chunk identifier if replication is configured.

        Returns
        -------
        subchunk : `int` or `None`
            Subchunk number for resulting replica data, `None` if replication
            is not enabled or subchunking is not enabled.
        """
        context = self._context
        config = context.config

        # Time partitioning has to be based on midpointMjdTai, not visit_time
        # as visit_time is not really a visit time.
        tp_sources = sources.copy(deep=False)
        tp_sources["apdb_time_part"] = tp_sources["midpointMjdTai"].apply(context.partitioner.time_partition)
        if (time_partitions_range := context.time_partitions_range) is not None:
            self._check_time_partitions(tp_sources["apdb_time_part"], time_partitions_range)
        extra_columns: dict[str, Any] = {}
        if not config.partitioning.time_partition_tables:
            # Single-table layout stores apdb_time_part as a regular column.
            self._storeObjectsPandas(tp_sources, table_name)
        else:
            # Group by time partition
            partitions = set(tp_sources["apdb_time_part"])
            if len(partitions) == 1:
                # Single partition - just save the whole thing.
                time_part = partitions.pop()
                self._storeObjectsPandas(sources, table_name, time_part=time_part)
            else:
                # group by time partition.
                for time_part, sub_frame in tp_sources.groupby(by="apdb_time_part"):
                    sub_frame.drop(columns="apdb_time_part", inplace=True)
                    self._storeObjectsPandas(sub_frame, table_name, time_part=time_part)

        subchunk: int | None = None
        if replica_chunk is not None:
            extra_columns = {"apdb_replica_chunk": replica_chunk.id}
            if context.has_chunk_sub_partitions:
                # Random sub-partition spreads chunk data between partitions.
                subchunk = random.randrange(config.replica_sub_chunk_count)
                extra_columns["apdb_replica_subchunk"] = subchunk
                if table_name is ApdbTables.DiaSource:
                    extra_table = ExtraTables.DiaSourceChunks2
                else:
                    extra_table = ExtraTables.DiaForcedSourceChunks2
            else:
                if table_name is ApdbTables.DiaSource:
                    extra_table = ExtraTables.DiaSourceChunks
                else:
                    extra_table = ExtraTables.DiaForcedSourceChunks
            self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns)

        return subchunk
1397 def _check_time_partitions(
1398 self, partitions: Iterable[int], time_partitions_range: ApdbCassandraTimePartitionRange
1399 ) -> None:
1400 """Check that time partitons for new data actually exist.
1402 Parameters
1403 ----------
1404 partitions : `~collections.abc.Iterable` [`int`]
1405 Time partitions for new data.
1406 time_partitions_range : `ApdbCassandraTimePartitionRange`
1407 Currrent time partition range.
1408 """
1409 partitions = set(partitions)
1410 min_part = min(partitions)
1411 max_part = max(partitions)
1412 if min_part < time_partitions_range.start or max_part > time_partitions_range.end:
1413 raise ValueError(
1414 "Attempt to store data for time partitions that do not yet exist. "
1415 f"Partitons for new records: {min_part}-{max_part}. "
1416 f"Database partitons: {time_partitions_range.start}-{time_partitions_range.end}."
1417 )
1418 # Make a noise when writing to the last partition.
1419 if max_part == time_partitions_range.end:
1420 warnings.warn(
1421 "Writing into the last temporal partition. Partition range needs to be extended soon.",
1422 stacklevel=3,
1423 )
1425 def _storeDiaSourcesPartitions(
1426 self,
1427 sources: pandas.DataFrame,
1428 visit_time: astropy.time.Time,
1429 replica_chunk: ReplicaChunk | None,
1430 subchunk: int | None,
1431 ) -> None:
1432 """Store mapping of diaSourceId to its partitioning values.
1434 Parameters
1435 ----------
1436 sources : `pandas.DataFrame`
1437 Catalog containing DiaSource records
1438 visit_time : `astropy.time.Time`
1439 Time of the current visit.
1440 replica_chunk : `ReplicaChunk` or `None`
1441 Replication chunk, or `None` when replication is disabled.
1442 subchunk : `int` or `None`
1443 Replication sub-chunk, or `None` when replication is disabled or
1444 sub-chunking is not used.
1445 """
1446 context = self._context
1448 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
1449 extra_columns = {
1450 "apdb_time_part": context.partitioner.time_partition(visit_time),
1451 "apdb_replica_chunk": replica_chunk.id if replica_chunk is not None else None,
1452 }
1453 if context.has_chunk_sub_partitions:
1454 extra_columns["apdb_replica_subchunk"] = subchunk
1456 self._storeObjectsPandas(
1457 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
1458 )
    def _storeObjectsPandas(
        self,
        records: pandas.DataFrame,
        table_name: ApdbTables | ExtraTables,
        extra_columns: Mapping | None = None,
        time_part: int | None = None,
    ) -> None:
        """Store generic objects.

        Takes Pandas catalog and stores a bunch of records in a table.

        Parameters
        ----------
        records : `pandas.DataFrame`
            Catalog containing object records
        table_name : `ApdbTables`
            Name of the table as defined in APDB schema.
        extra_columns : `dict`, optional
            Mapping (column_name, column_value) which gives fixed values for
            columns in each row, overrides values in ``records`` if matching
            columns exist there.
        time_part : `int`, optional
            If not `None` then insert into a per-partition table.

        Notes
        -----
        If Pandas catalog contains additional columns not defined in table
        schema they are ignored. Catalog does not have to contain all columns
        defined in a table, but partition and clustering keys must be present
        in a catalog or ``extra_columns``.
        """
        context = self._context

        # use extra columns if specified
        if extra_columns is None:
            extra_columns = {}
        extra_fields = list(extra_columns.keys())

        # Fields that will come from dataframe.
        df_fields = [column for column in records.columns if column not in extra_fields]

        column_map = context.schema.getColumnMap(table_name)
        # list of columns (as in felis schema)
        fields = [column_map[field].name for field in df_fields if field in column_map]
        fields += extra_fields

        # check that all partitioning and clustering columns are defined
        partition_columns = context.schema.partitionColumns(table_name)
        required_columns = partition_columns + context.schema.clusteringColumns(table_name)
        missing_columns = [column for column in required_columns if column not in fields]
        if missing_columns:
            raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")

        batch_size = self._batch_size(table_name)

        with self._timer("insert_build_time", tags={"table": table_name.name}):
            # Multi-partition batches are problematic in general, so we want to
            # group records in a batch by their partition key.
            values_by_key: dict[tuple, list[list]] = defaultdict(list)
            for rec in records.itertuples(index=False):
                values = []
                partitioning_values: dict[str, Any] = {}
                for field in df_fields:
                    if field not in column_map:
                        continue
                    value = getattr(rec, field)
                    # Normalize timestamp representations to what the
                    # Cassandra driver expects.
                    if column_map[field].datatype is felis.datamodel.DataType.timestamp:
                        if isinstance(value, pandas.Timestamp):
                            value = value.to_pydatetime()
                        elif value is pandas.NaT:
                            value = None
                        else:
                            # Assume it's seconds since epoch, Cassandra
                            # datetime is in milliseconds
                            value = int(value * 1000)
                    value = literal(value)
                    # UNSET_VALUE keeps missing values from overwriting
                    # existing cells in Cassandra.
                    values.append(UNSET_VALUE if value is None else value)
                    if field in partition_columns:
                        partitioning_values[field] = value
                for field in extra_fields:
                    value = literal(extra_columns[field])
                    values.append(UNSET_VALUE if value is None else value)
                    if field in partition_columns:
                        partitioning_values[field] = value

                key = tuple(partitioning_values[field] for field in partition_columns)
                values_by_key[key].append(values)

            table = context.schema.tableName(table_name, time_part)

            query = Insert(self._keyspace, table, fields)
            statement = context.stmt_factory(query, prepare=True)
            # Cassandra has 64k limit on batch size, normally that should be
            # enough but some tests generate too many forced sources.
            queries = []
            for key_values in values_by_key.values():
                for values_chunk in chunk_iterable(key_values, batch_size):
                    batch = cassandra.query.BatchStatement()
                    for row_values in values_chunk:
                        batch.add(statement, row_values)
                    queries.append((batch, None))
                    # Routing key must be derivable from the batch, otherwise
                    # partition-grouping above did not work.
                    assert batch.routing_key is not None and batch.keyspace is not None

        _LOG.debug("%s: will store %d records", context.schema.tableName(table_name), records.shape[0])
        with self._timer(
            "insert_time", tags={"table": table_name.name, "method": "_storeObjectsPandas"}
        ) as timer:
            execute_concurrent(context.session, queries, execution_profile="write")
            timer.add_values(row_count=len(records), num_batches=len(queries))
    def _storeUpdateRecords(
        self, records: Iterable[ApdbUpdateRecord], chunk: ReplicaChunk, *, store_chunk: bool = False
    ) -> None:
        """Store ApdbUpdateRecords in the replica table for those records.

        Parameters
        ----------
        records : `~collections.abc.Iterable` [`ApdbUpdateRecord`]
            Records to store.
        chunk : `ReplicaChunk`
            Replica chunk for these records.
        store_chunk : `bool`
            If True then also store replica chunk.

        Raises
        ------
        TypeError
            Raised if replication is not enabled for this instance.
        """
        context = self._context
        config = context.config

        if not context.schema.replication_enabled:
            raise TypeError("Replication is not enabled for this APDB instance.")

        if store_chunk:
            self._storeReplicaChunk(chunk)

        apdb_replica_chunk = chunk.id
        # Do not use unique_id from ReplicaChunk as it could be reused in
        # multiple calls to this method.
        update_unique_id = uuid.uuid4()

        rows = []
        for record in records:
            rows.append(
                [
                    apdb_replica_chunk,
                    record.update_time_ns,
                    record.update_order,
                    update_unique_id,
                    record.to_json(),
                ]
            )
        # Column order must match the row lists built above.
        columns = [
            "apdb_replica_chunk",
            "update_time_ns",
            "update_order",
            "update_unique_id",
            "update_payload",
        ]
        if context.has_chunk_sub_partitions:
            # All rows from a single call share one random sub-partition.
            subchunk = random.randrange(config.replica_sub_chunk_count)
            for row in rows:
                row.append(subchunk)
            columns.append("apdb_replica_subchunk")

        table_name = context.schema.tableName(ExtraTables.ApdbUpdateRecordChunks)
        query = Insert(self._keyspace, table_name, columns)
        stmt = context.stmt_factory(query)
        queries = [(stmt, row) for row in rows]

        with self._timer("store_update_record", tags={"table": table_name}) as timer:
            execute_concurrent(context.session, queries, execution_profile="write")
            timer.add_values(row_count=len(queries))
1636 def _add_apdb_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
1637 """Calculate spatial partition for each record and add it to a
1638 DataFrame.
1640 Parameters
1641 ----------
1642 df : `pandas.DataFrame`
1643 DataFrame which has to contain ra/dec columns, names of these
1644 columns are defined by configuration ``ra_dec_columns`` field.
1646 Returns
1647 -------
1648 df : `pandas.DataFrame`
1649 DataFrame with ``apdb_part`` column which contains pixel index
1650 for ra/dec coordinates.
1652 Notes
1653 -----
1654 This overrides any existing column in a DataFrame with the same name
1655 (``apdb_part``). Original DataFrame is not changed, copy of a DataFrame
1656 is returned.
1657 """
1658 context = self._context
1659 config = context.config
1661 # Calculate pixelization index for every record.
1662 apdb_part = np.zeros(df.shape[0], dtype=np.int64)
1663 ra_col, dec_col = config.ra_dec_columns
1664 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
1665 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
1666 idx = context.partitioner.pixel(uv3d)
1667 apdb_part[i] = idx
1668 df = df.copy()
1669 df["apdb_part"] = apdb_part
1670 return df
1672 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
1673 """Make an empty catalog for a table with a given name.
1675 Parameters
1676 ----------
1677 table_name : `ApdbTables`
1678 Name of the table.
1680 Returns
1681 -------
1682 catalog : `pandas.DataFrame`
1683 An empty catalog.
1684 """
1685 table = self.schema.tableSchemas[table_name]
1687 data = {columnDef.name: pandas.Series(dtype=columnDef.pandas_type) for columnDef in table.columns}
1688 return pandas.DataFrame(data)
1690 def _fix_input_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame:
1691 """Update timestamp columns in input DataFrame to be naive datetime
1692 type.
1694 Clients may or may not generate aware timestamps, code in this class
1695 assumes that timestamps are naive, so we convert them to UTC and
1696 drop timezone.
1697 """
1698 # Find all columns with aware timestamps.
1699 columns = [column for column, dtype in df.dtypes.items() if isinstance(dtype, pandas.DatetimeTZDtype)]
1700 for column in columns:
1701 # tz_convert(None) will convert to UTC and drop timezone.
1702 df[column] = df[column].dt.tz_convert(None)
1703 return df
1705 def _batch_size(self, table: ApdbTables | ExtraTables) -> int:
1706 """Calculate batch size based on config parameters."""
1707 context = self._context
1708 config = context.config
1710 # Cassandra limit on number of statements in a batch is 64k.
1711 batch_size = 65_535
1712 if 0 < config.batch_statement_limit < batch_size:
1713 batch_size = config.batch_statement_limit
1714 if config.batch_size_limit > 0:
1715 # The purpose of this limit is to try not to exceed batch size
1716 # threshold which is set on server side. Cassandra wire protocol
1717 # for prepared queries (and batches) only sends column values with
1718 # with an additional 4 bytes per value specifying size. Value is
1719 # not included for NULL or NOT_SET values, but the size is always
1720 # there. There is additional small per-query overhead, which we
1721 # ignore.
1722 row_size = context.schema.table_row_size(table)
1723 row_size += 4 * len(context.schema.getColumnMap(table))
1724 batch_size = min(batch_size, (config.batch_size_limit // row_size) + 1)
1725 return batch_size
1727 def _group_dia_objects_by_partition(
1728 self, partitioner: Partitioner, objects: list[DiaObjectId], pad_arcsec: float
1729 ) -> Mapping[int, list[int]]:
1730 """Group DiaObjects by partition.
1732 Parameters
1733 ----------
1734 partitioner : `Partitioner`
1735 Objects which knows how to partition things.
1736 objects : `list` [`DiaObjectId`]
1737 Collection of objects to partition.
1738 pad_arcsec : `float`
1739 Additional padding around object position.
1741 Returns
1742 -------
1743 grouped_objects
1744 Mapping of spatial patition ID to list ob object IDs that it
1745 contains. Some objects may belong to more than one partition.
1746 """
1747 partitioned_object_ids: dict[int, list[int]] = defaultdict(list)
1748 for obj_id in objects:
1749 partitions = partitioner.pixelization.circle_pixels(obj_id.ra, obj_id.dec, pad_arcsec)
1750 for pixel in partitions:
1751 partitioned_object_ids[pixel].append(obj_id.diaObjectId)
1752 return partitioned_object_ids
1754 def _timestamp_column_name(self, column: str) -> str:
1755 """Return column name before/after schema migration to MJD TAI."""
1756 return self._schema.timestamp_column_name(column)
1758 def _timestamp_column_value(self, time: astropy.time.Time) -> float | int:
1759 """Return column value before/after schema migration to MJD TAI."""
1760 if self._schema.has_mjd_timestamps:
1761 return float(time.tai.mjd)
1762 else:
1763 return int(time.datetime.astimezone(tz=datetime.UTC).timestamp() * 1000)
def _get_diasource_data(self, source_ids: Iterable[DiaSourceId], *columns: str) -> list:
    """Select records from DiaSource table by diaSourceId and return all
    records as a list of named tuples.
    """
    context = self._context
    partitioner = context.partitioner

    # diaSourceId is always needed to identify the returned rows.
    select_columns = ("diaSourceId",) + columns

    # Allow some uncertainty for coordinates and time when calculating
    # partitions.
    search_radius_arcsec = 1.0
    time_pad_days = 10 / (24 * 3600)
    queries: list[tuple] = []
    for src in source_ids:
        direction = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(src.ra, src.dec))
        search_region = sphgeom.Circle(
            direction, sphgeom.Angle.fromDegrees(search_radius_arcsec / 3600.0)
        )
        spatial_where, _ = partitioner.spatial_where(search_region)
        # NOTE: "partitons_range" spelling matches the callee's keyword.
        tables, temporal_where = partitioner.temporal_where(
            ApdbTables.DiaSource,
            src.midpointMjdTai - time_pad_days,
            src.midpointMjdTai + time_pad_days,
            partitons_range=context.time_partitions_range,
            query_per_time_part=True,
        )
        id_clause = QExpr('"diaSourceId" = {}', (src.diaSourceId,))
        # One statement per (time-partition table, where-combination).
        for table in tables:
            base_query = Select(self._keyspace, table, select_columns)
            for where_clause in QExpr.combine(spatial_where, temporal_where, extra=id_clause):
                queries.append(
                    context.stmt_factory.with_params(base_query.where(where_clause), prepare=True)
                )

    tags = {"table": "DiaSource", "method": "_get_diasource_data"}
    with self._timer("select_time", tags=tags) as timer:
        rows = cast(
            list[tuple],
            select_concurrent(
                context.session,
                queries,
                "read_named_tuples",
                context.config.connection_config.read_concurrency,
            ),
        )
        timer.add_values(row_count=len(rows), num_queries=len(queries))

    return rows
def _get_diaobject_data(self, object_ids: Iterable[DiaObjectId], *columns: str) -> list:
    """Select records from DiaObjectLast table by diaObjectId and return
    all records as a list of named tuples.
    """
    context = self._context
    partitioner = context.partitioner

    last_table = context.schema.tableName(ApdbTables.DiaObjectLast)
    # diaObjectId is always needed to identify the returned rows.
    select_columns = ("diaObjectId",) + columns

    # Allow some uncertainty for coordinates when calculating partitions.
    search_radius_arcsec = 1.0
    grouped_ids = defaultdict(list)
    for obj in object_ids:
        for pixel in partitioner.pixelization.circle_pixels(obj.ra, obj.dec, search_radius_arcsec):
            grouped_ids[pixel].append(obj.diaObjectId)

    # One statement per spatial partition, selecting all of its IDs.
    queries: list[tuple] = []
    for partition, ids in grouped_ids.items():
        select = (
            Select(self._keyspace, last_table, select_columns)
            .where(C("apdb_part") == partition)
            .where(C("diaObjectId").in_(ids))
        )
        queries.append(context.stmt_factory.with_params(select, prepare=False))

    tags = {"table": "DiaObjectLast", "method": "_get_diaobject_data"}
    with self._timer("select_time", tags=tags) as timer:
        rows = cast(
            list[tuple],
            select_concurrent(
                context.session,
                queries,
                "read_named_tuples",
                context.config.connection_config.read_concurrency,
            ),
        )
        timer.add_values(row_count=len(rows), num_queries=len(queries))

    return rows