Coverage for python/lsst/dax/apdb/apdbCassandra.py: 18%
619 statements
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"]
26import dataclasses
27import json
28import logging
29import uuid
30from collections.abc import Iterable, Iterator, Mapping, Set
31from typing import TYPE_CHECKING, Any, cast
33import numpy as np
34import pandas
36# If cassandra-driver is not installed the module can still be imported,
37# but ApdbCassandra cannot be instantiated.
38try:
39 import cassandra
40 import cassandra.query
41 from cassandra.auth import AuthProvider, PlainTextAuthProvider
42 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, Session
43 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy
45 CASSANDRA_IMPORTED = True
46except ImportError:
47 CASSANDRA_IMPORTED = False
49import astropy.time
50import felis.types
51from felis.simple import Table
52from lsst import sphgeom
53from lsst.pex.config import ChoiceField, Field, ListField
54from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError
55from lsst.utils.iteration import chunk_iterable
57from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
58from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables
59from .apdbConfigFreezer import ApdbConfigFreezer
60from .apdbMetadataCassandra import ApdbMetadataCassandra
61from .apdbSchema import ApdbTables
62from .cassandra_utils import (
63 ApdbCassandraTableData,
64 PreparedStatementCache,
65 literal,
66 pandas_dataframe_factory,
67 quote_id,
68 raw_data_factory,
69 select_concurrent,
70)
71from .pixelization import Pixelization
72from .timer import Timer
73from .versionTuple import IncompatibleVersionError, VersionTuple
75if TYPE_CHECKING:
76 from .apdbMetadata import ApdbMetadata
78_LOG = logging.getLogger(__name__)
80VERSION = VersionTuple(0, 1, 0)
81"""Version for the code defined in this module. This needs to be updated
82(following compatibility rules) when the schema produced by this code changes.
83"""
85# Copied from daf_butler.
86DB_AUTH_ENVVAR = "LSST_DB_AUTH"
87"""Default name of the environmental variable that will be used to locate DB
88credentials configuration file. """
90DB_AUTH_PATH = "~/.lsst/db-auth.yaml"
91"""Default path at which it is expected that DB credentials are found."""
94class CassandraMissingError(Exception):
95 def __init__(self) -> None:
96 super().__init__("cassandra-driver module cannot be imported")
99class ApdbCassandraConfig(ApdbConfig):
100 """Configuration class for Cassandra-based APDB implementation."""
102 contact_points = ListField[str](
103 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"]
104 )
105 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[])
106 port = Field[int](doc="Port number to connect to.", default=9042)
107 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb")
108 username = Field[str](
109 doc=f"Cassandra user name, if empty then {DB_AUTH_PATH} has to provide it with password.",
110 default="",
111 )
112 read_consistency = Field[str](
113 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM"
114 )
115 write_consistency = Field[str](
116 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM"
117 )
118 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0)
119 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=10.0)
120 remove_timeout = Field[float](doc="Timeout in seconds for remove operations.", default=600.0)
121 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500)
122 protocol_version = Field[int](
123 doc="Cassandra protocol version to use, default is V4",
124 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0,
125 )
126 dia_object_columns = ListField[str](
127 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[]
128 )
129 prefix = Field[str](doc="Prefix to add to table names", default="")
130 part_pixelization = ChoiceField[str](
131 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"),
132 doc="Pixelization used for partitioning index.",
133 default="mq3c",
134 )
135 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10)
136 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64)
137 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
138 timer = Field[bool](doc="If True then print/log timing information", default=False)
139 time_partition_tables = Field[bool](
140 doc="Use per-partition tables for sources instead of partitioning by time", default=False
141 )
142 time_partition_days = Field[int](
143 doc=(
144 "Time partitioning granularity in days, this value must not be changed after database is "
145 "initialized"
146 ),
147 default=30,
148 )
149 time_partition_start = Field[str](
150 doc=(
151 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. "
152 "This is used only when time_partition_tables is True."
153 ),
154 default="2018-12-01T00:00:00",
155 )
156 time_partition_end = Field[str](
157 doc=(
158 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. "
159 "This is used only when time_partition_tables is True."
160 ),
161 default="2030-01-01T00:00:00",
162 )
163 query_per_time_part = Field[bool](
164 default=False,
165 doc=(
166 "If True then build separate query for each time partition, otherwise build one single query. "
167 "This is only used when time_partition_tables is False in schema config."
168 ),
169 )
170 query_per_spatial_part = Field[bool](
171 default=False,
172 doc="If True then build one query per spatial partition, otherwise build single query.",
173 )
174 use_insert_id_skips_diaobjects = Field[bool](
175 default=False,
176 doc=(
177 "If True then do not store DiaObjects when use_insert_id is True "
178 "(DiaObjectsInsertId has the same data)."
179 ),
180 )
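# Illustrative sketch (not part of the original module): how a client might
# override a few of the config fields defined above before using them; the
# contact point and keyspace names are hypothetical.
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["cassandra.example.org"]
#     config.keyspace = "apdb_test"
#     config.read_consistency = "ONE"
#     config.time_partition_tables = True
#     config.validate()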
183@dataclasses.dataclass
184class _FrozenApdbCassandraConfig:
185 """Part of the configuration that is saved in metadata table and read back.
187 The attributes are a subset of attributes in `ApdbCassandraConfig` class.
189 Parameters
190 ----------
191 config : `ApdbSqlConfig`
192 Configuration used to copy initial values of attributes.
193 """
195 use_insert_id: bool
196 part_pixelization: str
197 part_pix_level: int
198 ra_dec_columns: list[str]
199 time_partition_tables: bool
200 time_partition_days: int
201 use_insert_id_skips_diaobjects: bool
203 def __init__(self, config: ApdbCassandraConfig):
204 self.use_insert_id = config.use_insert_id
205 self.part_pixelization = config.part_pixelization
206 self.part_pix_level = config.part_pix_level
207 self.ra_dec_columns = list(config.ra_dec_columns)
208 self.time_partition_tables = config.time_partition_tables
209 self.time_partition_days = config.time_partition_days
210 self.use_insert_id_skips_diaobjects = config.use_insert_id_skips_diaobjects
212 def to_json(self) -> str:
213 """Convert this instance to JSON representation."""
214 return json.dumps(dataclasses.asdict(self))
216 def update(self, json_str: str) -> None:
217 """Update attribute values from a JSON string.
219 Parameters
220 ----------
221 json_str : `str`
222 String containing JSON representation of configuration.
223 """
224 data = json.loads(json_str)
225 if not isinstance(data, dict):
226 raise TypeError(f"JSON string must be convertible to object: {json_str!r}")
227 allowed_names = {field.name for field in dataclasses.fields(self)}
228 for key, value in data.items():
229 if key not in allowed_names:
230 raise ValueError(f"JSON object contains unknown key: {key}")
231 setattr(self, key, value)
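# Illustrative sketch (not part of the original module) of the JSON round trip
# implemented above, assuming an existing ApdbCassandraConfig instance named
# ``config``.
#
#     frozen = _FrozenApdbCassandraConfig(config)
#     json_str = frozen.to_json()           # e.g. '{"use_insert_id": false, ...}'
#     frozen.update(json_str)               # no-op round trip
#     frozen.update('{"no_such_key": 1}')   # raises ValueError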
234if CASSANDRA_IMPORTED:
236 class _AddressTranslator(AddressTranslator):
237 """Translate internal IP address to external.
239 Only used for docker-based setups; not a viable long-term solution.
240 """
242 def __init__(self, public_ips: list[str], private_ips: list[str]):
243 self._map = dict(zip(private_ips, public_ips))
245 def translate(self, private_ip: str) -> str:
246 return self._map.get(private_ip, private_ip)
249class ApdbCassandra(Apdb):
250 """Implementation of APDB database on to of Apache Cassandra.
252 The implementation is configured via standard ``pex_config`` mechanism
253 using `ApdbCassandraConfig` configuration class. For an example of
254 different configurations check config/ folder.
256 Parameters
257 ----------
258 config : `ApdbCassandraConfig`
259 Configuration object.
260 """
262 metadataSchemaVersionKey = "version:schema"
263 """Name of the metadata key to store schema version number."""
265 metadataCodeVersionKey = "version:ApdbCassandra"
266 """Name of the metadata key to store code version number."""
268 metadataConfigKey = "config:apdb-cassandra.json"
269 """Name of the metadata key to store code version number."""
271 _frozen_parameters = (
272 "use_insert_id",
273 "part_pixelization",
274 "part_pix_level",
275 "ra_dec_columns",
276 "time_partition_tables",
277 "time_partition_days",
278 "use_insert_id_skips_diaobjects",
279 )
280 """Names of the config parameters to be frozen in metadata table."""
282 partition_zero_epoch = astropy.time.Time(0, format="unix_tai")
283 """Start time for partition 0, this should never be changed."""
285 def __init__(self, config: ApdbCassandraConfig):
286 if not CASSANDRA_IMPORTED:
287 raise CassandraMissingError()
289 self._keyspace = config.keyspace
291 self._cluster, self._session = self._make_session(config)
293 meta_table_name = ApdbTables.metadata.table_name(config.prefix)
294 self._metadata = ApdbMetadataCassandra(
295 self._session, meta_table_name, config.keyspace, "read_tuples", "write"
296 )
298 # Read frozen config from metadata.
299 config_json = self._metadata.get(self.metadataConfigKey)
300 if config_json is not None:
301 # Update config from metadata.
302 freezer = ApdbConfigFreezer[ApdbCassandraConfig](self._frozen_parameters)
303 self.config = freezer.update(config, config_json)
304 else:
305 self.config = config
306 self.config.validate()
308 self._pixelization = Pixelization(
309 self.config.part_pixelization,
310 self.config.part_pix_level,
311 config.part_pix_max_ranges,
312 )
314 self._schema = ApdbCassandraSchema(
315 session=self._session,
316 keyspace=self._keyspace,
317 schema_file=self.config.schema_file,
318 schema_name=self.config.schema_name,
319 prefix=self.config.prefix,
320 time_partition_tables=self.config.time_partition_tables,
321 use_insert_id=self.config.use_insert_id,
322 )
323 self._partition_zero_epoch_mjd = float(self.partition_zero_epoch.mjd)
325 if self._metadata.table_exists():
326 self._versionCheck(self._metadata)
328 # Cache for prepared statements
329 self._preparer = PreparedStatementCache(self._session)
331 _LOG.debug("ApdbCassandra Configuration:")
332 for key, value in self.config.items():
333 _LOG.debug(" %s: %s", key, value)
335 def __del__(self) -> None:
336 if hasattr(self, "_cluster"):
337 self._cluster.shutdown()
339 @classmethod
340 def _make_session(cls, config: ApdbCassandraConfig) -> tuple[Cluster, Session]:
341 """Make Cassandra session."""
342 addressTranslator: AddressTranslator | None = None
343 if config.private_ips:
344 addressTranslator = _AddressTranslator(list(config.contact_points), list(config.private_ips))
346 cluster = Cluster(
347 execution_profiles=cls._makeProfiles(config),
348 contact_points=config.contact_points,
349 port=config.port,
350 address_translator=addressTranslator,
351 protocol_version=config.protocol_version,
352 auth_provider=cls._make_auth_provider(config),
353 )
354 session = cluster.connect()
355 # Disable result paging
356 session.default_fetch_size = None
358 return cluster, session
360 @classmethod
361 def _make_auth_provider(cls, config: ApdbCassandraConfig) -> AuthProvider | None:
362 """Make Cassandra authentication provider instance."""
363 try:
364 dbauth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR)
365 except DbAuthNotFoundError:
366 # Credentials file doesn't exist, use anonymous login.
367 return None
369 empty_username = False
370 # Try every contact point in turn.
371 for hostname in config.contact_points:
372 try:
373 username, password = dbauth.getAuth(
374 "cassandra", config.username, hostname, config.port, config.keyspace
375 )
376 if not username:
377 # Password without user name, try next hostname, but give
378 # warning later if no better match is found.
379 empty_username = True
380 else:
381 return PlainTextAuthProvider(username=username, password=password)
382 except DbAuthNotFoundError:
383 pass
385 if empty_username:
386 _LOG.warning(
387 f"Credentials file ({DB_AUTH_PATH} or ${DB_AUTH_ENVVAR}) provided password but not "
388 f"user name, anonymous Cassandra logon will be attempted."
389 )
391 return None
393 def _versionCheck(self, metadata: ApdbMetadataCassandra) -> None:
394 """Check schema version compatibility."""
396 def _get_version(key: str, default: VersionTuple) -> VersionTuple:
397 """Retrieve version number from given metadata key."""
398 if metadata.table_exists():
399 version_str = metadata.get(key)
400 if version_str is None:
401 # Should not happen with existing metadata table.
402 raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
403 return VersionTuple.fromString(version_str)
404 return default
406 # For old databases where metadata table does not exist we assume that
407 # version of both code and schema is 0.1.0.
408 initial_version = VersionTuple(0, 1, 0)
409 db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version)
410 db_code_version = _get_version(self.metadataCodeVersionKey, initial_version)
412 # For now there is no way to make read-only APDB instances, assume that
413 # any access can do updates.
414 if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True):
415 raise IncompatibleVersionError(
416 f"Configured schema version {self._schema.schemaVersion()} "
417 f"is not compatible with database version {db_schema_version}"
418 )
419 if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True):
420 raise IncompatibleVersionError(
421 f"Current code version {self.apdbImplementationVersion()} "
422 f"is not compatible with database version {db_code_version}"
423 )
425 @classmethod
426 def apdbImplementationVersion(cls) -> VersionTuple:
427 # Docstring inherited from base class.
428 return VERSION
430 def apdbSchemaVersion(self) -> VersionTuple:
431 # Docstring inherited from base class.
432 return self._schema.schemaVersion()
434 def tableDef(self, table: ApdbTables) -> Table | None:
435 # docstring is inherited from a base class
436 return self._schema.tableSchemas.get(table)
438 @classmethod
439 def init_database(
440 cls,
441 hosts: list[str],
442 keyspace: str,
443 *,
444 schema_file: str | None = None,
445 schema_name: str | None = None,
446 read_sources_months: int | None = None,
447 read_forced_sources_months: int | None = None,
448 use_insert_id: bool = False,
449 use_insert_id_skips_diaobjects: bool = False,
450 port: int | None = None,
451 username: str | None = None,
452 prefix: str | None = None,
453 part_pixelization: str | None = None,
454 part_pix_level: int | None = None,
455 time_partition_tables: bool = True,
456 time_partition_start: str | None = None,
457 time_partition_end: str | None = None,
458 read_consistency: str | None = None,
459 write_consistency: str | None = None,
460 read_timeout: int | None = None,
461 write_timeout: int | None = None,
462 ra_dec_columns: list[str] | None = None,
463 replication_factor: int | None = None,
464 drop: bool = False,
465 ) -> ApdbCassandraConfig:
466 """Initialize new APDB instance and make configuration object for it.
468 Parameters
469 ----------
470 hosts : `list` [`str`]
471 List of host names or IP addresses for Cassandra cluster.
472 keyspace : `str`
473 Name of the keyspace for APDB tables.
474 schema_file : `str`, optional
475 Location of (YAML) configuration file with APDB schema. If not
476 specified then default location will be used.
477 schema_name : `str`, optional
478 Name of the schema in YAML configuration file. If not specified
479 then default name will be used.
480 read_sources_months : `int`, optional
481 Number of months of history to read from DiaSource.
482 read_forced_sources_months : `int`, optional
483 Number of months of history to read from DiaForcedSource.
484 use_insert_id : `bool`, optional
485 If True, make additional tables used for replication to PPDB.
486 use_insert_id_skips_diaobjects : `bool`, optional
487 If `True` then do not fill regular ``DiaObject`` table when
488 ``use_insert_id`` is `True`.
489 port : `int`, optional
490 Port number to use for Cassandra connections.
491 username : `str`, optional
492 User name for Cassandra connections.
493 prefix : `str`, optional
494 Optional prefix for all table names.
495 part_pixelization : `str`, optional
496 Name of the pixelization scheme used for partitioning (``htm``, ``q3c``, or ``mq3c``).
497 part_pix_level : `int`, optional
498 Pixelization level.
499 time_partition_tables : `bool`, optional
500 Create per-partition tables.
501 time_partition_start : `str`, optional
502 Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss
503 format, in TAI.
504 time_partition_end : `str`, optional
505 Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss
506 format, in TAI.
507 read_consistency : `str`, optional
508 Name of the consistency level for read operations.
509 write_consistency : `str`, optional
510 Name of the consistency level for write operations.
511 read_timeout : `int`, optional
512 Read timeout in seconds.
513 write_timeout : `int`, optional
514 Write timeout in seconds.
515 ra_dec_columns : `list` [`str`], optional
516 Names of ra/dec columns in DiaObject table.
517 replication_factor : `int`, optional
518 Replication factor used when creating a new keyspace; if the keyspace
519 already exists its replication factor is not changed.
520 drop : `bool`, optional
521 If `True` then drop existing tables before re-creating the schema.
523 Returns
524 -------
525 config : `ApdbCassandraConfig`
526 Resulting configuration object for a created APDB instance.
527 """
528 config = ApdbCassandraConfig(
529 contact_points=hosts,
530 keyspace=keyspace,
531 use_insert_id=use_insert_id,
532 use_insert_id_skips_diaobjects=use_insert_id_skips_diaobjects,
533 time_partition_tables=time_partition_tables,
534 )
535 if schema_file is not None:
536 config.schema_file = schema_file
537 if schema_name is not None:
538 config.schema_name = schema_name
539 if read_sources_months is not None:
540 config.read_sources_months = read_sources_months
541 if read_forced_sources_months is not None:
542 config.read_forced_sources_months = read_forced_sources_months
543 if port is not None:
544 config.port = port
545 if username is not None:
546 config.username = username
547 if prefix is not None:
548 config.prefix = prefix
549 if part_pixelization is not None:
550 config.part_pixelization = part_pixelization
551 if part_pix_level is not None:
552 config.part_pix_level = part_pix_level
553 if time_partition_start is not None:
554 config.time_partition_start = time_partition_start
555 if time_partition_end is not None:
556 config.time_partition_end = time_partition_end
557 if read_consistency is not None:
558 config.read_consistency = read_consistency
559 if write_consistency is not None:
560 config.write_consistency = write_consistency
561 if read_timeout is not None:
562 config.read_timeout = read_timeout
563 if write_timeout is not None:
564 config.write_timeout = write_timeout
565 if ra_dec_columns is not None:
566 config.ra_dec_columns = ra_dec_columns
568 cls._makeSchema(config, drop=drop, replication_factor=replication_factor)
570 return config
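    # Illustrative sketch (not part of the original module) of creating a new
    # APDB instance with init_database(); the host names, keyspace, and
    # replication factor below are hypothetical.
    #
    #     config = ApdbCassandra.init_database(
    #         hosts=["cassandra1.example.org", "cassandra2.example.org"],
    #         keyspace="apdb_example",
    #         use_insert_id=True,
    #         time_partition_tables=True,
    #         replication_factor=3,
    #     )
    #     apdb = ApdbCassandra(config)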
572 @classmethod
573 def _makeSchema(
574 cls, config: ApdbConfig, *, drop: bool = False, replication_factor: int | None = None
575 ) -> None:
576 # docstring is inherited from a base class
578 if not isinstance(config, ApdbCassandraConfig):
579 raise TypeError(f"Unexpected type of configuration object: {type(config)}")
581 cluster, session = cls._make_session(config)
583 schema = ApdbCassandraSchema(
584 session=session,
585 keyspace=config.keyspace,
586 schema_file=config.schema_file,
587 schema_name=config.schema_name,
588 prefix=config.prefix,
589 time_partition_tables=config.time_partition_tables,
590 use_insert_id=config.use_insert_id,
591 )
593 # Ask schema to create all tables.
594 if config.time_partition_tables:
595 time_partition_start = astropy.time.Time(config.time_partition_start, format="isot", scale="tai")
596 time_partition_end = astropy.time.Time(config.time_partition_end, format="isot", scale="tai")
597 part_epoch = float(cls.partition_zero_epoch.mjd)
598 part_days = config.time_partition_days
599 part_range = (
600 cls._time_partition_cls(time_partition_start, part_epoch, part_days),
601 cls._time_partition_cls(time_partition_end, part_epoch, part_days) + 1,
602 )
603 schema.makeSchema(drop=drop, part_range=part_range, replication_factor=replication_factor)
604 else:
605 schema.makeSchema(drop=drop, replication_factor=replication_factor)
607 meta_table_name = ApdbTables.metadata.table_name(config.prefix)
608 metadata = ApdbMetadataCassandra(session, meta_table_name, config.keyspace, "read_tuples", "write")
610 # Fill version numbers, overriding them if they existed before.
611 if metadata.table_exists():
612 metadata.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True)
613 metadata.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True)
615 # Store frozen part of a configuration in metadata.
616 freezer = ApdbConfigFreezer[ApdbCassandraConfig](cls._frozen_parameters)
617 metadata.set(cls.metadataConfigKey, freezer.to_json(config), force=True)
619 cluster.shutdown()
621 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
622 # docstring is inherited from a base class
624 sp_where = self._spatial_where(region)
625 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))
627 # We need to exclude extra partitioning columns from result.
628 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast)
629 what = ",".join(quote_id(column) for column in column_names)
631 table_name = self._schema.tableName(ApdbTables.DiaObjectLast)
632 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"'
633 statements: list[tuple] = []
634 for where, params in sp_where:
635 full_query = f"{query} WHERE {where}"
636 if params:
637 statement = self._preparer.prepare(full_query)
638 else:
639 # If there are no params then it is likely that query has a
640 # bunch of literals rendered already, no point trying to
641 # prepare it because it's not reusable.
642 statement = cassandra.query.SimpleStatement(full_query)
643 statements.append((statement, params))
644 _LOG.debug("getDiaObjects: #queries: %s", len(statements))
646 with Timer("DiaObject select", self.config.timer):
647 objects = cast(
648 pandas.DataFrame,
649 select_concurrent(
650 self._session, statements, "read_pandas_multi", self.config.read_concurrency
651 ),
652 )
654 _LOG.debug("found %s DiaObjects", objects.shape[0])
655 return objects
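    # Illustrative sketch (not part of the original module) of querying
    # DiaObjects for a small sky region; the circle center and radius are
    # hypothetical and ``apdb`` is an existing ApdbCassandra instance.
    #
    #     center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
    #     region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.5))
    #     objects = apdb.getDiaObjects(region)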
657 def getDiaSources(
658 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
659 ) -> pandas.DataFrame | None:
660 # docstring is inherited from a base class
661 months = self.config.read_sources_months
662 if months == 0:
663 return None
664 mjd_end = visit_time.mjd
665 mjd_start = mjd_end - months * 30
667 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)
669 def getDiaForcedSources(
670 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
671 ) -> pandas.DataFrame | None:
672 # docstring is inherited from a base class
673 months = self.config.read_forced_sources_months
674 if months == 0:
675 return None
676 mjd_end = visit_time.mjd
677 mjd_start = mjd_end - months * 30
679 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)
681 def containsVisitDetector(self, visit: int, detector: int) -> bool:
682 # docstring is inherited from a base class
683 raise NotImplementedError()
685 def getInsertIds(self) -> list[ApdbInsertId] | None:
686 # docstring is inherited from a base class
687 if not self._schema.has_insert_id:
688 return None
690 # everything goes into a single partition
691 partition = 0
693 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
694 query = f'SELECT insert_time, insert_id FROM "{self._keyspace}"."{table_name}" WHERE partition = ?'
696 result = self._session.execute(
697 self._preparer.prepare(query),
698 (partition,),
699 timeout=self.config.read_timeout,
700 execution_profile="read_tuples",
701 )
702 # order by insert_time
703 rows = sorted(result)
704 return [
705 ApdbInsertId(id=row[1], insert_time=astropy.time.Time(row[0].timestamp(), format="unix_tai"))
706 for row in rows
707 ]
709 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
710 # docstring is inherited from a base class
711 if not self._schema.has_insert_id:
712 raise ValueError("APDB is not configured for history storage")
714 all_insert_ids = [id.id for id in ids]
715 # There is a 64k limit on the number of markers in Cassandra CQL.
716 for insert_ids in chunk_iterable(all_insert_ids, 20_000):
717 params = ",".join("?" * len(insert_ids))
719 # everything goes into a single partition
720 partition = 0
722 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
723 query = (
724 f'DELETE FROM "{self._keyspace}"."{table_name}" '
725 f"WHERE partition = ? AND insert_id IN ({params})"
726 )
728 self._session.execute(
729 self._preparer.prepare(query),
730 [partition] + list(insert_ids),
731 timeout=self.config.remove_timeout,
732 )
734 # Also remove those insert_ids from Dia*InsertId tables.
735 for table in (
736 ExtraTables.DiaObjectInsertId,
737 ExtraTables.DiaSourceInsertId,
738 ExtraTables.DiaForcedSourceInsertId,
739 ):
740 table_name = self._schema.tableName(table)
741 query = f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})'
742 self._session.execute(
743 self._preparer.prepare(query),
744 insert_ids,
745 timeout=self.config.remove_timeout,
746 )
748 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
749 # docstring is inherited from a base class
750 return self._get_history(ExtraTables.DiaObjectInsertId, ids)
752 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
753 # docstring is inherited from a base class
754 return self._get_history(ExtraTables.DiaSourceInsertId, ids)
756 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
757 # docstring is inherited from a base class
758 return self._get_history(ExtraTables.DiaForcedSourceInsertId, ids)
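    # Illustrative sketch (not part of the original module) of a replication
    # workflow built from the methods above, assuming the APDB was created
    # with ``use_insert_id=True`` and ``apdb`` is an ApdbCassandra instance.
    #
    #     if insert_ids := apdb.getInsertIds():
    #         chunk = insert_ids[:10]
    #         sources = apdb.getDiaSourcesHistory(chunk)
    #         # ... copy ``sources`` to the destination, then drop the IDs ...
    #         apdb.deleteInsertIds(chunk)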
760 def getSSObjects(self) -> pandas.DataFrame:
761 # docstring is inherited from a base class
762 tableName = self._schema.tableName(ApdbTables.SSObject)
763 query = f'SELECT * from "{self._keyspace}"."{tableName}"'
765 objects = None
766 with Timer("SSObject select", self.config.timer):
767 result = self._session.execute(query, execution_profile="read_pandas")
768 objects = result._current_rows
770 _LOG.debug("found %s SSObjects", objects.shape[0])
771 return objects
773 def store(
774 self,
775 visit_time: astropy.time.Time,
776 objects: pandas.DataFrame,
777 sources: pandas.DataFrame | None = None,
778 forced_sources: pandas.DataFrame | None = None,
779 ) -> None:
780 # docstring is inherited from a base class
782 insert_id: ApdbInsertId | None = None
783 if self._schema.has_insert_id:
784 insert_id = ApdbInsertId.new_insert_id(visit_time)
785 self._storeInsertId(insert_id, visit_time)
787 # fill region partition column for DiaObjects
788 objects = self._add_obj_part(objects)
789 self._storeDiaObjects(objects, visit_time, insert_id)
791 if sources is not None:
792 # copy apdb_part column from DiaObjects to DiaSources
793 sources = self._add_src_part(sources, objects)
794 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, insert_id)
795 self._storeDiaSourcesPartitions(sources, visit_time, insert_id)
797 if forced_sources is not None:
798 forced_sources = self._add_fsrc_part(forced_sources, objects)
799 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, insert_id)
801 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
802 # docstring is inherited from a base class
803 self._storeObjectsPandas(objects, ApdbTables.SSObject)
805 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
806 # docstring is inherited from a base class
808 # To update a record we need to know its exact primary key (including
809 # partition key) so we start by querying for diaSourceId to find the
810 # primary keys.
812 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition)
813 # split it into 1k IDs per query
814 selects: list[tuple] = []
815 for ids in chunk_iterable(idMap.keys(), 1_000):
816 ids_str = ",".join(str(item) for item in ids)
817 selects.append(
818 (
819 (
820 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "insert_id" '
821 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})'
822 ),
823 {},
824 )
825 )
827 # No need for DataFrame here, read data as tuples.
828 result = cast(
829 list[tuple[int, int, int, uuid.UUID | None]],
830 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency),
831 )
833 # Make mapping from source ID to its partition.
834 id2partitions: dict[int, tuple[int, int]] = {}
835 id2insert_id: dict[int, uuid.UUID] = {}
836 for row in result:
837 id2partitions[row[0]] = row[1:3]
838 if row[3] is not None:
839 id2insert_id[row[0]] = row[3]
841 # make sure we know partitions for each ID
842 if set(id2partitions) != set(idMap):
843 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
844 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
846 # Reassign in standard tables
847 queries = cassandra.query.BatchStatement()
848 table_name = self._schema.tableName(ApdbTables.DiaSource)
849 for diaSourceId, ssObjectId in idMap.items():
850 apdb_part, apdb_time_part = id2partitions[diaSourceId]
851 values: tuple
852 if self.config.time_partition_tables:
853 query = (
854 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"'
855 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
856 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?'
857 )
858 values = (ssObjectId, apdb_part, diaSourceId)
859 else:
860 query = (
861 f'UPDATE "{self._keyspace}"."{table_name}"'
862 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
863 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?'
864 )
865 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId)
866 queries.add(self._preparer.prepare(query), values)
868 # Reassign in history tables, only if history is enabled
869 if id2insert_id:
870 # Filter out insert ids that have been deleted already. There is a
871 # potential race with concurrent removal of insert IDs, but it
872 # should be handled by WHERE in UPDATE.
873 known_ids = set()
874 if insert_ids := self.getInsertIds():
875 known_ids = set(insert_id.id for insert_id in insert_ids)
876 id2insert_id = {key: value for key, value in id2insert_id.items() if value in known_ids}
877 if id2insert_id:
878 table_name = self._schema.tableName(ExtraTables.DiaSourceInsertId)
879 for diaSourceId, ssObjectId in idMap.items():
880 if insert_id := id2insert_id.get(diaSourceId):
881 query = (
882 f'UPDATE "{self._keyspace}"."{table_name}" '
883 ' SET "ssObjectId" = ?, "diaObjectId" = NULL '
884 'WHERE "insert_id" = ? AND "diaSourceId" = ?'
885 )
886 values = (ssObjectId, insert_id, diaSourceId)
887 queries.add(self._preparer.prepare(query), values)
889 _LOG.debug("%s: will update %d records", table_name, len(idMap))
890 with Timer(table_name + " update", self.config.timer):
891 self._session.execute(queries, execution_profile="write")
893 def dailyJob(self) -> None:
894 # docstring is inherited from a base class
895 pass
897 def countUnassociatedObjects(self) -> int:
898 # docstring is inherited from a base class
900 # It's too inefficient to implement it for Cassandra in current schema.
901 raise NotImplementedError()
903 @property
904 def metadata(self) -> ApdbMetadata:
905 # docstring is inherited from a base class
906 if self._metadata is None:
907 raise RuntimeError("Database schema was not initialized.")
908 return self._metadata
910 @classmethod
911 def _makeProfiles(cls, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]:
912 """Make all execution profiles used in the code."""
913 if config.private_ips:
914 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
915 else:
916 loadBalancePolicy = RoundRobinPolicy()
918 read_tuples_profile = ExecutionProfile(
919 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
920 request_timeout=config.read_timeout,
921 row_factory=cassandra.query.tuple_factory,
922 load_balancing_policy=loadBalancePolicy,
923 )
924 read_pandas_profile = ExecutionProfile(
925 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
926 request_timeout=config.read_timeout,
927 row_factory=pandas_dataframe_factory,
928 load_balancing_policy=loadBalancePolicy,
929 )
930 read_raw_profile = ExecutionProfile(
931 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
932 request_timeout=config.read_timeout,
933 row_factory=raw_data_factory,
934 load_balancing_policy=loadBalancePolicy,
935 )
936 # Profile to use with select_concurrent to return pandas data frame
937 read_pandas_multi_profile = ExecutionProfile(
938 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
939 request_timeout=config.read_timeout,
940 row_factory=pandas_dataframe_factory,
941 load_balancing_policy=loadBalancePolicy,
942 )
943 # Profile to use with select_concurrent to return raw data (columns and
944 # rows)
945 read_raw_multi_profile = ExecutionProfile(
946 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
947 request_timeout=config.read_timeout,
948 row_factory=raw_data_factory,
949 load_balancing_policy=loadBalancePolicy,
950 )
951 write_profile = ExecutionProfile(
952 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency),
953 request_timeout=config.write_timeout,
954 load_balancing_policy=loadBalancePolicy,
955 )
956 # To replace default DCAwareRoundRobinPolicy
957 default_profile = ExecutionProfile(
958 load_balancing_policy=loadBalancePolicy,
959 )
960 return {
961 "read_tuples": read_tuples_profile,
962 "read_pandas": read_pandas_profile,
963 "read_raw": read_raw_profile,
964 "read_pandas_multi": read_pandas_multi_profile,
965 "read_raw_multi": read_raw_multi_profile,
966 "write": write_profile,
967 EXEC_PROFILE_DEFAULT: default_profile,
968 }
970 def _getSources(
971 self,
972 region: sphgeom.Region,
973 object_ids: Iterable[int] | None,
974 mjd_start: float,
975 mjd_end: float,
976 table_name: ApdbTables,
977 ) -> pandas.DataFrame:
978 """Return catalog of DiaSource instances given set of DiaObject IDs.
980 Parameters
981 ----------
982 region : `lsst.sphgeom.Region`
983 Spherical region.
984 object_ids : iterable [`int`] or `None`
985 Collection of DiaObject IDs. If `None` then no filtering on IDs is done.
986 mjd_start : `float`
987 Lower bound of time interval.
988 mjd_end : `float`
989 Upper bound of time interval.
990 table_name : `ApdbTables`
991 Name of the table.
993 Returns
994 -------
995 catalog : `pandas.DataFrame`
996 Catalog containing DiaSource records. Empty catalog is returned if
997 ``object_ids`` is empty.
998 """
999 object_id_set: Set[int] = set()
1000 if object_ids is not None:
1001 object_id_set = set(object_ids)
1002 if len(object_id_set) == 0:
1003 return self._make_empty_catalog(table_name)
1005 sp_where = self._spatial_where(region)
1006 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end)
1008 # We need to exclude extra partitioning columns from result.
1009 column_names = self._schema.apdbColumnNames(table_name)
1010 what = ",".join(quote_id(column) for column in column_names)
1012 # Build all queries
1013 statements: list[tuple] = []
1014 for table in tables:
1015 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"'
1016 statements += list(self._combine_where(prefix, sp_where, temporal_where))
1017 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))
1019 with Timer(table_name.name + " select", self.config.timer):
1020 catalog = cast(
1021 pandas.DataFrame,
1022 select_concurrent(
1023 self._session, statements, "read_pandas_multi", self.config.read_concurrency
1024 ),
1025 )
1027 # filter by given object IDs
1028 if len(object_id_set) > 0:
1029 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])
1031 # precise filtering on midpointMjdTai
1032 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start])
1034 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
1035 return catalog
1037 def _get_history(self, table: ExtraTables, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
1038 """Return records from a particular table given set of insert IDs."""
1039 if not self._schema.has_insert_id:
1040 raise ValueError("APDB is not configured for history retrieval")
1042 insert_ids = [id.id for id in ids]
1043 params = ",".join("?" * len(insert_ids))
1045 table_name = self._schema.tableName(table)
1046 # I know that history table schema has only regular APDB columns plus
1047 # an insert_id column, and this is exactly what we need to return from
1048 # this method, so selecting a star is fine here.
1049 query = f'SELECT * FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})'
1050 statement = self._preparer.prepare(query)
1052 with Timer("DiaObject history", self.config.timer):
1053 result = self._session.execute(statement, insert_ids, execution_profile="read_raw")
1054 table_data = cast(ApdbCassandraTableData, result._current_rows)
1055 return table_data
1057 def _storeInsertId(self, insert_id: ApdbInsertId, visit_time: astropy.time.Time) -> None:
1058 # Cassandra timestamp uses milliseconds since epoch; unix_tai is seconds.
1059 timestamp = int(insert_id.insert_time.unix_tai * 1000)
1061 # everything goes into a single partition
1062 partition = 0
1064 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
1065 query = (
1066 f'INSERT INTO "{self._keyspace}"."{table_name}" (partition, insert_id, insert_time) '
1067 "VALUES (?, ?, ?)"
1068 )
1070 self._session.execute(
1071 self._preparer.prepare(query),
1072 (partition, insert_id.id, timestamp),
1073 timeout=self.config.write_timeout,
1074 execution_profile="write",
1075 )
1077 def _storeDiaObjects(
1078 self, objs: pandas.DataFrame, visit_time: astropy.time.Time, insert_id: ApdbInsertId | None
1079 ) -> None:
1080 """Store catalog of DiaObjects from current visit.
1082 Parameters
1083 ----------
1084 objs : `pandas.DataFrame`
1085 Catalog with DiaObject records
1086 visit_time : `astropy.time.Time`
1087 Time of the current visit.
 insert_id : `ApdbInsertId` or `None`
 Insert identifier to store in the replication tables, or `None` if
 replication is not enabled.
1088 """
1089 if len(objs) == 0:
1090 _LOG.debug("No objects to write to database.")
1091 return
1093 visit_time_dt = visit_time.datetime
1094 extra_columns = dict(lastNonForcedSource=visit_time_dt)
1095 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)
1097 extra_columns["validityStart"] = visit_time_dt
1098 time_part: int | None = self._time_partition(visit_time)
1099 if not self.config.time_partition_tables:
1100 extra_columns["apdb_time_part"] = time_part
1101 time_part = None
1103 # Only store DiaObjects if not storing insert_ids or explicitly
1104 # configured to always store them.
1105 if insert_id is None or not self.config.use_insert_id_skips_diaobjects:
1106 self._storeObjectsPandas(
1107 objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part
1108 )
1110 if insert_id is not None:
1111 extra_columns = dict(insert_id=insert_id.id, validityStart=visit_time_dt)
1112 self._storeObjectsPandas(objs, ExtraTables.DiaObjectInsertId, extra_columns=extra_columns)
1114 def _storeDiaSources(
1115 self,
1116 table_name: ApdbTables,
1117 sources: pandas.DataFrame,
1118 visit_time: astropy.time.Time,
1119 insert_id: ApdbInsertId | None,
1120 ) -> None:
1121 """Store catalog of DIASources or DIAForcedSources from current visit.
1123 Parameters
1124 ----------
 table_name : `ApdbTables`
 Target table, either ``DiaSource`` or ``DiaForcedSource``.
1125 sources : `pandas.DataFrame`
1126 Catalog containing DiaSource or DiaForcedSource records.
1127 visit_time : `astropy.time.Time`
1128 Time of the current visit.
 insert_id : `ApdbInsertId` or `None`
 Insert identifier to store in the replication tables, or `None` if
 replication is not enabled.
1129 """
1130 time_part: int | None = self._time_partition(visit_time)
1131 extra_columns: dict[str, Any] = {}
1132 if not self.config.time_partition_tables:
1133 extra_columns["apdb_time_part"] = time_part
1134 time_part = None
1136 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part)
1138 if insert_id is not None:
1139 extra_columns = dict(insert_id=insert_id.id)
1140 if table_name is ApdbTables.DiaSource:
1141 extra_table = ExtraTables.DiaSourceInsertId
1142 else:
1143 extra_table = ExtraTables.DiaForcedSourceInsertId
1144 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns)
1146 def _storeDiaSourcesPartitions(
1147 self, sources: pandas.DataFrame, visit_time: astropy.time.Time, insert_id: ApdbInsertId | None
1148 ) -> None:
1149 """Store mapping of diaSourceId to its partitioning values.
1151 Parameters
1152 ----------
1153 sources : `pandas.DataFrame`
1154 Catalog containing DiaSource records
1155 visit_time : `astropy.time.Time`
1156 Time of the current visit.
1157 """
1158 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
1159 extra_columns = {
1160 "apdb_time_part": self._time_partition(visit_time),
1161 "insert_id": insert_id.id if insert_id is not None else None,
1162 }
1164 self._storeObjectsPandas(
1165 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
1166 )
1168 def _storeObjectsPandas(
1169 self,
1170 records: pandas.DataFrame,
1171 table_name: ApdbTables | ExtraTables,
1172 extra_columns: Mapping | None = None,
1173 time_part: int | None = None,
1174 ) -> None:
1175 """Store generic objects.
1177 Takes Pandas catalog and stores a bunch of records in a table.
1179 Parameters
1180 ----------
1181 records : `pandas.DataFrame`
1182 Catalog containing object records
1183 table_name : `ApdbTables` or `ExtraTables`
1184 Name of the table as defined in APDB schema.
1185 extra_columns : `dict`, optional
1186 Mapping (column_name, column_value) which gives fixed values for
1187 columns in each row, overrides values in ``records`` if matching
1188 columns exist there.
1189 time_part : `int`, optional
1190 If not `None` then insert into a per-partition table.
1192 Notes
1193 -----
1194 If the Pandas catalog contains additional columns not defined in the
1195 table schema they are ignored. The catalog does not have to contain all
1196 columns defined in a table, but partition and clustering keys must be
1197 present either in the catalog or in ``extra_columns``.
1198 """
1199 # use extra columns if specified
1200 if extra_columns is None:
1201 extra_columns = {}
1202 extra_fields = list(extra_columns.keys())
1204 # Fields that will come from dataframe.
1205 df_fields = [column for column in records.columns if column not in extra_fields]
1207 column_map = self._schema.getColumnMap(table_name)
1208 # list of columns (as in felis schema)
1209 fields = [column_map[field].name for field in df_fields if field in column_map]
1210 fields += extra_fields
1212 # check that all partitioning and clustering columns are defined
1213 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns(
1214 table_name
1215 )
1216 missing_columns = [column for column in required_columns if column not in fields]
1217 if missing_columns:
1218 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")
1220 qfields = [quote_id(field) for field in fields]
1221 qfields_str = ",".join(qfields)
1223 with Timer(table_name.name + " query build", self.config.timer):
1224 table = self._schema.tableName(table_name)
1225 if time_part is not None:
1226 table = f"{table}_{time_part}"
1228 holders = ",".join(["?"] * len(qfields))
1229 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})'
1230 statement = self._preparer.prepare(query)
1231 queries = cassandra.query.BatchStatement()
1232 for rec in records.itertuples(index=False):
1233 values = []
1234 for field in df_fields:
1235 if field not in column_map:
1236 continue
1237 value = getattr(rec, field)
1238 if column_map[field].datatype is felis.types.Timestamp:
1239 if isinstance(value, pandas.Timestamp):
1240 value = literal(value.to_pydatetime())
1241 else:
1242 # Assume it's seconds since epoch, Cassandra
1243 # datetime is in milliseconds
1244 value = int(value * 1000)
1245 values.append(literal(value))
1246 for field in extra_fields:
1247 value = extra_columns[field]
1248 values.append(literal(value))
1249 queries.add(statement, values)
1251 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0])
1252 with Timer(table_name.name + " insert", self.config.timer):
1253 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write")
1255 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
1256 """Calculate spatial partition for each record and add it to a
1257 DataFrame.
1259 Notes
1260 -----
1261 This overrides any existing column in a DataFrame with the same name
1262 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
1263 returned.
1264 """
1265 # calculate spatial partition (pixel) index for every DiaObject
1266 apdb_part = np.zeros(df.shape[0], dtype=np.int64)
1267 ra_col, dec_col = self.config.ra_dec_columns
1268 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
1269 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
1270 idx = self._pixelization.pixel(uv3d)
1271 apdb_part[i] = idx
1272 df = df.copy()
1273 df["apdb_part"] = apdb_part
1274 return df
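    # Illustrative sketch (not part of the original module): the same pixel
    # index can be computed standalone with the Pixelization helper, using
    # arguments that mirror the call in __init__ and the config defaults
    # (mq3c, level 10, 64 ranges); the coordinates are hypothetical.
    #
    #     pix = Pixelization("mq3c", 10, 64)
    #     uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
    #     apdb_part = pix.pixel(uv3d)  # integer spatial partition index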
1276 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
1277 """Add apdb_part column to DiaSource catalog.
1279 Notes
1280 -----
1281 This method copies the apdb_part value from a matching DiaObject record.
1282 The DiaObject catalog needs to have an apdb_part column filled by the
1283 ``_add_obj_part`` method and DiaSource records need to be
1284 associated with DiaObjects via the ``diaObjectId`` column.
1286 This overrides any existing column in a DataFrame with the same name
1287 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
1288 returned.
1289 """
1290 pixel_id_map: dict[int, int] = {
1291 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"])
1292 }
1293 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
1294 ra_col, dec_col = self.config.ra_dec_columns
1295 for i, (diaObjId, ra, dec) in enumerate(
1296 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col])
1297 ):
1298 if diaObjId == 0:
1299 # DiaSources associated with SolarSystemObjects do not have an
1300 # associated DiaObject, so we set their partition based on the
1301 # source's own ra/dec.
1302 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
1303 idx = self._pixelization.pixel(uv3d)
1304 apdb_part[i] = idx
1305 else:
1306 apdb_part[i] = pixel_id_map[diaObjId]
1307 sources = sources.copy()
1308 sources["apdb_part"] = apdb_part
1309 return sources
1311 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
1312 """Add apdb_part column to DiaForcedSource catalog.
1314 Notes
1315 -----
1316 This method copies the apdb_part value from a matching DiaObject record.
1317 The DiaObject catalog needs to have an apdb_part column filled by the
1318 ``_add_obj_part`` method and DiaForcedSource records need to be
1319 associated with DiaObjects via the ``diaObjectId`` column.
1321 This overrides any existing column in a DataFrame with the same name
1322 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
1323 returned.
1324 """
1325 pixel_id_map: dict[int, int] = {
1326 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"])
1327 }
1328 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
1329 for i, diaObjId in enumerate(sources["diaObjectId"]):
1330 apdb_part[i] = pixel_id_map[diaObjId]
1331 sources = sources.copy()
1332 sources["apdb_part"] = apdb_part
1333 return sources
1335 @classmethod
1336 def _time_partition_cls(cls, time: float | astropy.time.Time, epoch_mjd: float, part_days: int) -> int:
1337 """Calculate time partition number for a given time.
1339 Parameters
1340 ----------
1341 time : `float` or `astropy.time.Time`
1342 Time for which to calculate partition number; a `float` value is
1343 interpreted as MJD.
1344 epoch_mjd : `float`
1345 Epoch time for partition 0.
1346 part_days : `int`
1347 Number of days per partition.
1349 Returns
1350 -------
1351 partition : `int`
1352 Partition number for a given time.
1353 """
1354 if isinstance(time, astropy.time.Time):
1355 mjd = float(time.mjd)
1356 else:
1357 mjd = time
1358 days_since_epoch = mjd - epoch_mjd
1359 partition = int(days_since_epoch) // part_days
1360 return partition
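    # Worked example (not part of the original module) of the arithmetic
    # above, assuming the default 30-day partitions and the partition-zero
    # epoch of 1970-01-01 (MJD 40587):
    #
    #     mjd = 60000.0                          # hypothetical observation time
    #     days_since_epoch = 60000.0 - 40587.0   # = 19413.0
    #     partition = int(19413.0) // 30         # = 647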
1362 def _time_partition(self, time: float | astropy.time.Time) -> int:
1363 """Calculate time partition number for a given time.
1365 Parameters
1366 ----------
1367 time : `float` or `astropy.time.Time`
1368 Time for which to calculate partition number; a `float` value is
1369 interpreted as MJD.
1371 Returns
1372 -------
1373 partition : `int`
1374 Partition number for a given time.
1375 """
1376 if isinstance(time, astropy.time.Time):
1377 mjd = float(time.mjd)
1378 else:
1379 mjd = time
1380 days_since_epoch = mjd - self._partition_zero_epoch_mjd
1381 partition = int(days_since_epoch) // self.config.time_partition_days
1382 return partition
1384 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
1385 """Make an empty catalog for a table with a given name.
1387 Parameters
1388 ----------
1389 table_name : `ApdbTables`
1390 Name of the table.
1392 Returns
1393 -------
1394 catalog : `pandas.DataFrame`
1395 An empty catalog.
1396 """
1397 table = self._schema.tableSchemas[table_name]
1399 data = {
1400 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype))
1401 for columnDef in table.columns
1402 }
1403 return pandas.DataFrame(data)
1405 def _combine_where(
1406 self,
1407 prefix: str,
1408 where1: list[tuple[str, tuple]],
1409 where2: list[tuple[str, tuple]],
1410 suffix: str | None = None,
1411 ) -> Iterator[tuple[cassandra.query.Statement, tuple]]:
1412 """Make cartesian product of two parts of WHERE clause into a series
1413 of statements to execute.
1415 Parameters
1416 ----------
1417 prefix : `str`
1418 Initial statement prefix that comes before the WHERE clause, e.g.
1419 "SELECT * from Table".
 where1, where2 : `list` [ `tuple` ]
 Two lists of (expression, parameters) tuples to combine.
 suffix : `str`, optional
 Optional clause appended after the WHERE clause.
1420 """
1421 # If lists are empty use special sentinels.
1422 if not where1:
1423 where1 = [("", ())]
1424 if not where2:
1425 where2 = [("", ())]
1427 for expr1, params1 in where1:
1428 for expr2, params2 in where2:
1429 full_query = prefix
1430 wheres = []
1431 if expr1:
1432 wheres.append(expr1)
1433 if expr2:
1434 wheres.append(expr2)
1435 if wheres:
1436 full_query += " WHERE " + " AND ".join(wheres)
1437 if suffix:
1438 full_query += " " + suffix
1439 params = params1 + params2
1440 if params:
1441 statement = self._preparer.prepare(full_query)
1442 else:
1443 # If there are no params then it is likely that query
1444 # has a bunch of literals rendered already, no point
1445 # trying to prepare it.
1446 statement = cassandra.query.SimpleStatement(full_query)
1447 yield (statement, params)
1449 def _spatial_where(
1450 self, region: sphgeom.Region | None, use_ranges: bool = False
1451 ) -> list[tuple[str, tuple]]:
1452 """Generate expressions for spatial part of WHERE clause.
1454 Parameters
1455 ----------
1456 region : `sphgeom.Region` or `None`
1457 Spatial region for query results.
1458 use_ranges : `bool`
1459 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <=
1460 p2") instead of exact list of pixels. Should be set to True for
1461 large regions covering very many pixels.
1463 Returns
1464 -------
1465 expressions : `list` [ `tuple` ]
1466 Empty list is returned if ``region`` is `None`, otherwise a list
1467 of one or more (expression, parameters) tuples
1468 """
1469 if region is None:
1470 return []
1471 if use_ranges:
1472 pixel_ranges = self._pixelization.envelope(region)
1473 expressions: list[tuple[str, tuple]] = []
1474 for lower, upper in pixel_ranges:
1475 upper -= 1
1476 if lower == upper:
1477 expressions.append(('"apdb_part" = ?', (lower,)))
1478 else:
1479 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper)))
1480 return expressions
1481 else:
1482 pixels = self._pixelization.pixels(region)
1483 if self.config.query_per_spatial_part:
1484 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels]
1485 else:
1486 pixels_str = ",".join([str(pix) for pix in pixels])
1487 return [(f'"apdb_part" IN ({pixels_str})', ())]
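    # Illustrative sketch (not part of the original module) of the two output
    # shapes produced above for a hypothetical three-pixel region, depending
    # on ``query_per_spatial_part``:
    #
    #     [('"apdb_part" IN (100,101,102)', ())]   # one IN query (default)
    #     [('"apdb_part" = ?', (100,)),            # or one query per
    #      ('"apdb_part" = ?', (101,)),            # spatial partition
    #      ('"apdb_part" = ?', (102,))]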
1489 def _temporal_where(
1490 self,
1491 table: ApdbTables,
1492 start_time: float | astropy.time.Time,
1493 end_time: float | astropy.time.Time,
1494 query_per_time_part: bool | None = None,
1495 ) -> tuple[list[str], list[tuple[str, tuple]]]:
1496 """Generate table names and expressions for temporal part of WHERE
1497 clauses.
1499 Parameters
1500 ----------
1501 table : `ApdbTables`
1502 Table to select from.
1503 start_time : `astropy.time.Time` or `float`
1504 Starting time of the range, as `astropy.time.Time` or MJD value.
1505 end_time : `astropy.time.Time` or `float`
1506 Ending time of the range, as `astropy.time.Time` or MJD value.
1507 query_per_time_part : `bool`, optional
1508 If None then use ``query_per_time_part`` from configuration.
1510 Returns
1511 -------
1512 tables : `list` [ `str` ]
1513 List of the table names to query.
1514 expressions : `list` [ `tuple` ]
1515 A list of zero or more (expression, parameters) tuples.
1516 """
1517 tables: list[str]
1518 temporal_where: list[tuple[str, tuple]] = []
1519 table_name = self._schema.tableName(table)
1520 time_part_start = self._time_partition(start_time)
1521 time_part_end = self._time_partition(end_time)
1522 time_parts = list(range(time_part_start, time_part_end + 1))
1523 if self.config.time_partition_tables:
1524 tables = [f"{table_name}_{part}" for part in time_parts]
1525 else:
1526 tables = [table_name]
1527 if query_per_time_part is None:
1528 query_per_time_part = self.config.query_per_time_part
1529 if query_per_time_part:
1530 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts]
1531 else:
1532 time_part_list = ",".join([str(part) for part in time_parts])
1533 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())]
1535 return tables, temporal_where
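    # Illustrative sketch (not part of the original module) of how the helpers
    # above are combined when building queries (mirroring _getSources), for a
    # hypothetical region and a 30-day window ending at MJD 60000:
    #
    #     sp_where = self._spatial_where(region)
    #     tables, tm_where = self._temporal_where(ApdbTables.DiaSource, 59970.0, 60000.0)
    #     for table in tables:
    #         prefix = f'SELECT * from "{self._keyspace}"."{table}"'
    #         statements = list(self._combine_where(prefix, sp_where, tm_where))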