Coverage for python / lsst / dax / apdb / cassandra / apdbCassandraAdmin.py: 14%
275 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 10:35 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbCassandraAdmin"]
26import dataclasses
27import itertools
28import logging
29import warnings
30from collections import defaultdict
31from collections.abc import Iterable, Mapping
32from typing import TYPE_CHECKING, Protocol
34import astropy.time
36from lsst.sphgeom import LonLat, UnitVector3d
37from lsst.utils.iteration import chunk_iterable
39try:
40 import cassandra
41except ImportError:
42 pass
44from ..apdbAdmin import ApdbAdmin, DiaForcedSourceLocator, DiaObjectLocator, DiaSourceLocator
45from ..apdbSchema import ApdbTables
46from ..monitor import MonAgent
47from ..timer import Timer
48from .cassandra_utils import StatementFactory, execute_concurrent, quote_id
49from .config import ApdbCassandraConfig, ApdbCassandraTimePartitionRange
50from .queries import Column as C # noqa: N817
51from .queries import Delete, Select
52from .sessionFactory import SessionContext
54if TYPE_CHECKING:
55 from .apdbCassandra import ApdbCassandra
56 from .partitioner import Partitioner
# Module-level logger used throughout this module.
_LOG = logging.getLogger(__name__)

# Monitoring agent; passed together with _LOG to Timer instances (see _timer).
_MON = MonAgent(__name__)
class ConfirmDeletePartitions(Protocol):
    """Protocol for callable which confirms deletion of partitions.

    The callable receives the partition numbers and table names that are
    about to be dropped, plus the `Partitioner` in use; partitions are
    deleted only when it returns `True`.
    """

    def __call__(self, *, partitions: list[int], tables: list[str], partitioner: Partitioner) -> bool: ...
@dataclasses.dataclass
class DatabaseInfo:
    """Collection of information about a specific database."""

    name: str
    """Keyspace name."""

    permissions: dict[str, set[str]] | None = None
    """Roles that can access the database and their permissions.

    `None` means that authentication information is not accessible due to
    system table permissions. If anonymous access is enabled then dictionary
    will be empty but not `None`.
    """
class ApdbCassandraAdmin(ApdbAdmin):
    """Implementation of `ApdbAdmin` for Cassandra backend.

    Parameters
    ----------
    apdb : `ApdbCassandra`
        APDB implementation.
    """

    def __init__(self, apdb: ApdbCassandra):
        self._apdb = apdb

    def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer:
        """Create `Timer` instance given its name."""
        return Timer(name, _MON, _LOG, tags=tags)

    @classmethod
    def list_databases(cls, host: str) -> Iterable[DatabaseInfo]:
        """Return the list of keyspaces with APDB databases.

        Parameters
        ----------
        host : `str`
            Name of one of the hosts in Cassandra cluster.

        Returns
        -------
        databases : `~collections.abc.Iterable` [`DatabaseInfo`]
            Information about databases that contain APDB instance.
        """
        # For DbAuth we need to use database name "*" to try to match any
        # database.
        config = ApdbCassandraConfig(contact_points=(host,), keyspace="*")
        with SessionContext(config) as session:
            stmt_factory = StatementFactory(session)

            # Get names of all keyspaces containing DiaSource table; its
            # presence is used as the marker of an APDB instance.
            table_name = ApdbTables.DiaSource.table_name()
            query = Select("system_schema", "tables", ["keyspace_name"], extra_clause="ALLOW FILTERING")
            query = query.where(C("table_name") == table_name)
            stmt, params = stmt_factory.with_params(query)

            result = session.execute(stmt, params)
            keyspaces = [row[0] for row in result.all()]

            if not keyspaces:
                return []

            # Retrieve roles for each keyspace. Resource names in the
            # system_auth table have the form "data/<keyspace>".
            resources = [f"data/{keyspace}" for keyspace in keyspaces]
            query = Select(
                "system_auth",
                "role_permissions",
                ("resource", "role", "permissions"),
                extra_clause="ALLOW FILTERING",
            )
            query = query.where(C("resource").in_(resources))
            stmt, params = stmt_factory.with_params(query)

            try:
                result = session.execute(stmt, params)
                # If anonymous access is enabled then result will be empty,
                # set infos to have empty permissions dict in that case.
                infos = {keyspace: DatabaseInfo(name=keyspace, permissions={}) for keyspace in keyspaces}
                for row in result:
                    # Strip the "data/" prefix to recover the keyspace name.
                    _, _, keyspace = row[0].partition("/")
                    role: str = row[1]
                    role_permissions: set[str] = set(row[2])
                    infos[keyspace].permissions[role] = role_permissions  # type: ignore[index]
            except cassandra.Unauthorized as exc:
                # Likely that access to role_permissions is not granted for
                # current user; fall back to reporting no permission info.
                warnings.warn(
                    f"Authentication information is not accessible to current user - {exc}", stacklevel=2
                )
                infos = {keyspace: DatabaseInfo(name=keyspace) for keyspace in keyspaces}

            # Would be nice to get size estimate, but this is not available
            # via CQL queries.
            return infos.values()

    @classmethod
    def delete_database(cls, host: str, keyspace: str, *, timeout: int = 3600) -> None:
        """Delete APDB database by dropping its keyspace.

        Parameters
        ----------
        host : `str`
            Name of one of the hosts in Cassandra cluster.
        keyspace : `str`
            Name of keyspace to delete.
        timeout : `int`, optional
            Timeout for delete operation in seconds. Dropping a large keyspace
            can be a long operation, but this default value of one hour should
            be sufficient for most or all cases.
        """
        # For DbAuth we need to use database name "*" to try to match any
        # database.
        config = ApdbCassandraConfig(contact_points=(host,), keyspace="*")
        with SessionContext(config) as session:
            query = f"DROP KEYSPACE {quote_id(keyspace)}"
            session.execute(query, timeout=timeout)

    @property
    def partitioner(self) -> Partitioner:
        """Partitioner used by this APDB instance (`Partitioner`)."""
        context = self._apdb._context
        return context.partitioner

    def apdb_part(self, ra: float, dec: float) -> int:
        # docstring is inherited from a base class
        uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
        return self.partitioner.pixel(uv3d)

    def apdb_time_part(self, midpointMjdTai: float) -> int:
        # docstring is inherited from a base class
        return self.partitioner.time_partition(midpointMjdTai)

    def delete_records(
        self,
        objects: Iterable[DiaObjectLocator],
        sources: Iterable[DiaSourceLocator],
        forced_sources: Iterable[DiaForcedSourceLocator],
    ) -> None:
        # docstring is inherited from a base class
        context = self._apdb._context
        config = context.config
        keyspace = self._apdb._keyspace
        # DiaObject (history) table does not exist when replication is
        # enabled and configured to skip DiaObjects.
        has_dia_object_table = not (config.enable_replica and config.replica_skips_diaobjects)

        # Group objects by spatial partition.
        partitions: defaultdict[int, list[int]] = defaultdict(list)
        for object in objects:
            apdb_part = self.apdb_part(object.ra, object.dec)
            partitions[apdb_part].append(object.diaObjectId)
        object_ids = set(itertools.chain.from_iterable(partitions.values()))

        # Group sources by associated object ID; sources whose object is not
        # being deleted are ignored.
        source_groups: defaultdict[int, list[DiaSourceLocator]] = defaultdict(list)
        for source in sources:
            if source.diaObjectId in object_ids:
                source_groups[source.diaObjectId].append(source)

        object_deletes = []
        object_count = 0
        # Delete from DiaObjectLast table. IN-lists are chunked to keep
        # individual statements to a reasonable size.
        for apdb_part, oids in partitions.items():
            oids = sorted(oids)
            object_count += len(oids)
            for oid_chunk in chunk_iterable(oids, 1000):
                query = (
                    Delete(keyspace, "DiaObjectLast")
                    .where(C("apdb_part") == apdb_part)
                    .where(C("diaObjectId").in_(oid_chunk))
                )
                object_deletes.append(context.stmt_factory.with_params(query))

        # If DiaObject is in use then delete from that too.
        if has_dia_object_table:
            # Need temporal partitions for DiaObject, the only source for that
            # is the timestamp of the associated DiaSource. Problem here is
            # that DiaObject temporal partitioning is based on validityStart,
            # which is "visit_time", but DiaSource does not record visit_time,
            # it is partitioned on midpointMjdTai. There is time_processed
            # defined for DiaSource but it does not match "visit_time" though
            # it is close. I use midpointMjdTai as approximation for
            # validityStart, this may skip some DiaObjects, but in production
            # we are not going to have DiaObjects table at all. There is also
            # a chance that DiaObject moves from one spatial partition to
            # another with the same consequences, which we also ignore.
            oids_by_partition: dict[tuple[int, int], list[int]] = defaultdict(list)
            for apdb_part, oids in partitions.items():
                for oid in oids:
                    temporal_partitions = {
                        self.apdb_time_part(src.midpointMjdTai) for src in source_groups.get(oid, [])
                    }
                    for time_part in temporal_partitions:
                        oids_by_partition[(apdb_part, time_part)].append(oid)
            for (apdb_part, time_part), oids in oids_by_partition.items():
                for oid_chunk in chunk_iterable(oids, 1000):
                    if config.partitioning.time_partition_tables:
                        # One table per time partition; partition number is
                        # encoded in the table name.
                        table_name = context.schema.tableName(ApdbTables.DiaObject, time_part)
                        query = (
                            Delete(keyspace, table_name)
                            .where(C("apdb_part") == apdb_part)
                            .where(C("diaObjectId").in_(oid_chunk))
                        )
                        object_deletes.append(context.stmt_factory.with_params(query))
                    else:
                        # Single table; time partition is a column.
                        table_name = context.schema.tableName(ApdbTables.DiaObject)
                        query = (
                            Delete(keyspace, table_name)
                            .where(C("apdb_part") == apdb_part)
                            .where(C("apdb_time_part") == time_part)
                            .where(C("diaObjectId").in_(oid_chunk))
                        )
                        object_deletes.append(context.stmt_factory.with_params(query))

        # Delete from DiaObjectLastToPartition table.
        for oid_chunk in chunk_iterable(sorted(object_ids), 1000):
            query = Delete(keyspace, "DiaObjectLastToPartition").where(C("diaObjectId").in_(oid_chunk))
            object_deletes.append(context.stmt_factory.with_params(query))

        # Group sources by (spatial, temporal) partition.
        source_partitions: defaultdict[tuple[int, int], list[DiaSourceLocator]] = defaultdict(list)
        for source in itertools.chain.from_iterable(source_groups.values()):
            apdb_part = self.apdb_part(source.ra, source.dec)
            apdb_time_part = self.apdb_time_part(source.midpointMjdTai)
            source_partitions[(apdb_part, apdb_time_part)].append(source)

        source_deletes = []
        source_count = 0
        for (apdb_part, apdb_time_part), source_list in source_partitions.items():
            source_ids = sorted(source.diaSourceId for source in source_list)
            source_count += len(source_ids)
            for id_chunk in chunk_iterable(source_ids, 1000):
                if config.partitioning.time_partition_tables:
                    table_name = context.schema.tableName(ApdbTables.DiaSource, apdb_time_part)
                    query = (
                        Delete(keyspace, table_name)
                        .where(C("apdb_part") == apdb_part)
                        .where(C("diaSourceId").in_(id_chunk))
                    )
                    source_deletes.append(context.stmt_factory.with_params(query))
                else:
                    table_name = context.schema.tableName(ApdbTables.DiaSource)
                    query = (
                        Delete(keyspace, table_name)
                        .where(C("apdb_part") == apdb_part)
                        .where(C("apdb_time_part") == apdb_time_part)
                        .where(C("diaSourceId").in_(id_chunk))
                    )
                    source_deletes.append(context.stmt_factory.with_params(query))

        # Group forced sources by (spatial, temporal) partition.
        forced_source_partitions: defaultdict[tuple[int, int], list[DiaForcedSourceLocator]] = defaultdict(
            list
        )
        for forced_source in forced_sources:
            if forced_source.diaObjectId in object_ids:
                apdb_part = self.apdb_part(forced_source.ra, forced_source.dec)
                apdb_time_part = self.apdb_time_part(forced_source.midpointMjdTai)
                forced_source_partitions[(apdb_part, apdb_time_part)].append(forced_source)

        forced_source_deletes = []
        forced_source_count = 0
        for (apdb_part, apdb_time_part), forced_source_list in forced_source_partitions.items():
            # Forced sources are identified by a composite clustering key.
            clustering_keys = sorted(
                (fsource.diaObjectId, fsource.visit, fsource.detector) for fsource in forced_source_list
            )
            forced_source_count += len(clustering_keys)
            for key_chunk in chunk_iterable(clustering_keys, 1000):
                # Build a multi-column IN list as a raw string, e.g.
                # ("diaObjectId", visit, detector) IN ((1, 2, 3), ...).
                cl_str = ",".join(f"({oid}, {v}, {d})" for oid, v, d in key_chunk)
                if config.partitioning.time_partition_tables:
                    table_name = context.schema.tableName(ApdbTables.DiaForcedSource, apdb_time_part)
                    query = (
                        Delete(keyspace, table_name)
                        .where(C("apdb_part") == apdb_part)
                        .where(f'("diaObjectId", visit, detector) IN ({cl_str})')
                    )
                    forced_source_deletes.append(context.stmt_factory.with_params(query))
                else:
                    table_name = context.schema.tableName(ApdbTables.DiaForcedSource)
                    query = (
                        Delete(keyspace, table_name)
                        .where(C("apdb_part") == apdb_part)
                        .where(C("apdb_time_part") == apdb_time_part)
                        .where(f'("diaObjectId", visit, detector) IN ({cl_str})')
                    )
                    forced_source_deletes.append(context.stmt_factory.with_params(query))

        _LOG.info(
            "Deleting %d objects, %d sources, and %d forced sources",
            object_count,
            source_count,
            forced_source_count,
        )

        # Now run all queries.
        with self._timer("delete_forced_sources"):
            execute_concurrent(context.session, forced_source_deletes)
        with self._timer("delete_sources"):
            execute_concurrent(context.session, source_deletes)
        with self._timer("delete_objects"):
            execute_concurrent(context.session, object_deletes)

    def time_partitions(self) -> ApdbCassandraTimePartitionRange:
        """Return range of existing time partitions.

        Returns
        -------
        range : `ApdbCassandraTimePartitionRange`
            Time partition range.

        Raises
        ------
        TypeError
            Raised if APDB instance does not use time-partition tables.
        """
        context = self._apdb._context
        part_range = context.time_partitions_range
        if not part_range:
            raise TypeError("This APDB instance does not use time-partitioned tables.")
        return part_range

    def extend_time_partitions(
        self,
        time: astropy.time.Time,
        forward: bool = True,
        max_delta: astropy.time.TimeDelta | None = None,
    ) -> list[int]:
        """Extend set of time-partitioned tables to include specified time.

        Parameters
        ----------
        time : `astropy.time.Time`
            Time to which to extend partitions.
        forward : `bool`, optional
            If `True` then extend partitions into the future, time should be
            later than the end time of the last existing partition. If `False`
            then extend partitions into the past, time should be earlier than
            the start time of the first existing partition.
        max_delta : `astropy.time.TimeDelta`, optional
            Maximum possible extension of the partitions, default is 365 days.

        Returns
        -------
        partitions : `list` [`int`]
            List of partitions added to the database, empty list returned if
            ``time`` is already in the existing partition range.

        Raises
        ------
        TypeError
            Raised if APDB instance does not use time-partition tables.
        ValueError
            Raised if extension request exceeds time limit of ``max_delta``.
        """
        if max_delta is None:
            max_delta = astropy.time.TimeDelta(365, format="jd")

        context = self._apdb._context

        # Get current partitions.
        part_range = context.time_partitions_range
        if not part_range:
            raise TypeError("This APDB instance does not use time-partitioned tables.")

        # Partitions that we need to create.
        partitions = self._partitions_to_add(time, forward, max_delta)
        if not partitions:
            return []

        _LOG.debug("New partitions to create: %s", partitions)

        # Tables that are time-partitioned.
        keyspace = self._apdb._keyspace
        tables = context.schema.time_partitioned_tables()

        # Easiest way to create new tables is to take DDL from existing one
        # and update table name.
        table_name_token = "%TABLE_NAME%"
        table_schemas = {}
        for table in tables:
            existing_table_name = context.schema.tableName(table, part_range.end)
            query = f'DESCRIBE TABLE "{keyspace}"."{existing_table_name}"'
            result = context.session.execute(query).one()
            if not result:
                raise LookupError(f'Failed to read schema for table "{keyspace}"."{existing_table_name}"')
            schema: str = result.create_statement
            schema = schema.replace(existing_table_name, table_name_token)
            table_schemas[table] = schema

        # Be paranoid and check that none of the new tables exist.
        exsisting_tables = context.schema.existing_tables(*tables)
        for table in tables:
            new_tables = {context.schema.tableName(table, partition) for partition in partitions}
            old_tables = new_tables.intersection(exsisting_tables[table])
            if old_tables:
                raise ValueError(f"Some to be created tables already exist: {old_tables}")

        # Now can create all of them.
        for table, schema in table_schemas.items():
            for partition in partitions:
                new_table_name = context.schema.tableName(table, partition)
                _LOG.debug("Creating table %s", new_table_name)
                new_ddl = schema.replace(table_name_token, new_table_name)
                context.session.execute(new_ddl)

        # Update metadata with the new range boundary.
        if forward:
            part_range.end = max(partitions)
        else:
            part_range.start = min(partitions)
        part_range.save_to_meta(context.metadata)

        return partitions

    def _partitions_to_add(
        self,
        time: astropy.time.Time,
        forward: bool,
        max_delta: astropy.time.TimeDelta,
    ) -> list[int]:
        """Make the list of time partitions to add to current range."""
        context = self._apdb._context
        part_range = context.time_partitions_range
        assert part_range is not None

        new_partition = context.partitioner.time_partition(time)
        if forward:
            # Nothing to do if the requested time is already covered.
            if new_partition <= part_range.end:
                _LOG.debug(
                    "Partition for time=%s (%d) is below existing end (%d)",
                    time,
                    new_partition,
                    part_range.end,
                )
                return []
            _, end = context.partitioner.partition_period(part_range.end)
            if time - end > max_delta:
                raise ValueError(
                    f"Extension exceeds limit: current end time = {end.isot}, new end time = {time.isot}, "
                    f"limit = {max_delta.jd} days"
                )
            partitions = list(range(part_range.end + 1, new_partition + 1))
        else:
            # Nothing to do if the requested time is already covered.
            if new_partition >= part_range.start:
                _LOG.debug(
                    "Partition for time=%s (%d) is above existing start (%d)",
                    time,
                    new_partition,
                    part_range.start,
                )
                return []
            start, _ = context.partitioner.partition_period(part_range.start)
            if start - time > max_delta:
                raise ValueError(
                    f"Extension exceeds limit: current start time = {start.isot}, "
                    f"new start time = {time.isot}, "
                    f"limit = {max_delta.jd} days"
                )
            partitions = list(range(new_partition, part_range.start))

        return partitions

    def delete_time_partitions(
        self, time: astropy.time.Time, after: bool = False, *, confirm: ConfirmDeletePartitions | None = None
    ) -> list[int]:
        """Delete time-partitioned tables before or after specified time.

        Parameters
        ----------
        time : `astropy.time.Time`
            Time before or after which to remove partitions. Partition that
            includes this time is not deleted.
        after : `bool`, optional
            If `True` then delete partitions after the specified time. Default
            is to delete partitions before this time.
        confirm : `~collections.abc.Callable`, optional
            A callable that will be called to confirm deletion of the
            partitions. The callable needs to accept three keyword arguments:

            - `partitions` - a list of partition numbers to be deleted,
            - `tables` - a list of table names to be deleted,
            - `partitioner` - a `Partitioner` instance.

            Partitions are deleted only if callable returns `True`.

        Returns
        -------
        partitions : `list` [`int`]
            List of partitions deleted from the database, empty list returned
            if nothing is deleted.

        Raises
        ------
        TypeError
            Raised if APDB instance does not use time-partition tables.
        ValueError
            Raised if requested to delete all partitions.
        """
        context = self._apdb._context

        # Get current partitions.
        part_range = context.time_partitions_range
        if not part_range:
            raise TypeError("This APDB instance does not use time-partitioned tables.")

        partitions = self._partitions_to_delete(time, after)
        if not partitions:
            return []

        # Cannot delete all partitions.
        if min(partitions) == part_range.start and max(partitions) == part_range.end:
            raise ValueError("Cannot delete all partitions.")

        # Tables that are time-partitioned.
        keyspace = self._apdb._keyspace
        tables = context.schema.time_partitioned_tables()

        table_names = []
        for table in tables:
            for partition in partitions:
                table_names.append(context.schema.tableName(table, partition))

        if confirm is not None:
            # It can raise an exception, but at this point it's completely
            # harmless.
            answer = confirm(partitions=partitions, tables=table_names, partitioner=context.partitioner)
            if not answer:
                return []

        for table_name in table_names:
            _LOG.debug("Dropping table %s", table_name)
            # Use IF EXISTS just in case.
            query = f'DROP TABLE IF EXISTS "{keyspace}"."{table_name}"'
            context.session.execute(query)

        # Update metadata; deleted partitions are contiguous at one end of
        # the range, so shrink that end past them.
        if after:
            part_range.end = min(partitions) - 1
        else:
            part_range.start = max(partitions) + 1
        part_range.save_to_meta(context.metadata)

        return partitions

    def _partitions_to_delete(
        self,
        time: astropy.time.Time,
        after: bool = False,
    ) -> list[int]:
        """Make the list of time partitions to delete."""
        context = self._apdb._context
        part_range = context.time_partitions_range
        assert part_range is not None

        partition = context.partitioner.time_partition(time)
        if after:
            # Partitions strictly after the one containing ``time``.
            return list(range(max(partition + 1, part_range.start), part_range.end + 1))
        else:
            # Partitions strictly before the one containing ``time``.
            return list(range(part_range.start, min(partition, part_range.end + 1)))