Coverage for python / lsst / dax / apdb / cassandra / apdbCassandraAdmin.py: 14%

275 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:58 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraAdmin"] 

25 

26import dataclasses 

27import itertools 

28import logging 

29import warnings 

30from collections import defaultdict 

31from collections.abc import Iterable, Mapping 

32from typing import TYPE_CHECKING, Protocol 

33 

34import astropy.time 

35 

36from lsst.sphgeom import LonLat, UnitVector3d 

37from lsst.utils.iteration import chunk_iterable 

38 

39try: 

40 import cassandra 

41except ImportError: 

42 pass 

43 

44from ..apdbAdmin import ApdbAdmin, DiaForcedSourceLocator, DiaObjectLocator, DiaSourceLocator 

45from ..apdbSchema import ApdbTables 

46from ..monitor import MonAgent 

47from ..timer import Timer 

48from .cassandra_utils import StatementFactory, execute_concurrent, quote_id 

49from .config import ApdbCassandraConfig, ApdbCassandraTimePartitionRange 

50from .queries import Column as C # noqa: N817 

51from .queries import Delete, Select 

52from .sessionFactory import SessionContext 

53 

54if TYPE_CHECKING: 

55 from .apdbCassandra import ApdbCassandra 

56 from .partitioner import Partitioner 

57 

# Module-level logger for diagnostic/debug messages from this module.
_LOG = logging.getLogger(__name__)

# Monitoring agent passed to `Timer` instances to report timing metrics.
_MON = MonAgent(__name__)

61 

62 

class ConfirmDeletePartitions(Protocol):
    """Protocol for callable which confirms deletion of partitions.

    The callable is invoked with keyword-only arguments describing what is
    about to be dropped; it must return `True` to allow the deletion to
    proceed (see `ApdbCassandraAdmin.delete_time_partitions`).
    """

    def __call__(self, *, partitions: list[int], tables: list[str], partitioner: Partitioner) -> bool: ...

67 

68 

@dataclasses.dataclass
class DatabaseInfo:
    """Collection of information about a specific database."""

    name: str
    """Keyspace name."""

    permissions: dict[str, set[str]] | None = None
    """Roles that can access the database and their permissions.

    `None` means that authentication information is not accessible due to
    system table permissions. If anonymous access is enabled then dictionary
    will be empty but not `None`.
    """

83 

84 

class ApdbCassandraAdmin(ApdbAdmin):
    """Implementation of `ApdbAdmin` for Cassandra backend.

    Parameters
    ----------
    apdb : `ApdbCassandra`
        APDB implementation.
    """

    def __init__(self, apdb: ApdbCassandra):
        self._apdb = apdb

    def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer:
        """Create `Timer` instance given its name."""
        return Timer(name, _MON, _LOG, tags=tags)

    @classmethod
    def list_databases(cls, host: str) -> Iterable[DatabaseInfo]:
        """Return the list of keyspaces with APDB databases.

        Parameters
        ----------
        host : `str`
            Name of one of the hosts in Cassandra cluster.

        Returns
        -------
        databases : `~collections.abc.Iterable` [`DatabaseInfo`]
            Information about databases that contain APDB instance.
        """
        # For DbAuth we need to use database name "*" to try to match any
        # database.
        config = ApdbCassandraConfig(contact_points=(host,), keyspace="*")
        with SessionContext(config) as session:
            stmt_factory = StatementFactory(session)

            # Get names of all keyspaces containing DiaSource table
            table_name = ApdbTables.DiaSource.table_name()
            query = Select("system_schema", "tables", ["keyspace_name"], extra_clause="ALLOW FILTERING")
            query = query.where(C("table_name") == table_name)
            stmt, params = stmt_factory.with_params(query)

            result = session.execute(stmt, params)
            keyspaces = [row[0] for row in result.all()]

            if not keyspaces:
                return []

            # Retrieve roles for each keyspace.
            resources = [f"data/{keyspace}" for keyspace in keyspaces]
            query = Select(
                "system_auth",
                "role_permissions",
                ("resource", "role", "permissions"),
                extra_clause="ALLOW FILTERING",
            )
            query = query.where(C("resource").in_(resources))
            stmt, params = stmt_factory.with_params(query)

            try:
                result = session.execute(stmt, params)
                # If anonymous access is enabled then result will be empty,
                # set infos to have empty permissions dict in that case.
                infos = {keyspace: DatabaseInfo(name=keyspace, permissions={}) for keyspace in keyspaces}
                for row in result:
                    _, _, keyspace = row[0].partition("/")
                    role: str = row[1]
                    role_permissions: set[str] = set(row[2])
                    infos[keyspace].permissions[role] = role_permissions  # type: ignore[index]
            except cassandra.Unauthorized as exc:
                # Likely that access to role_permissions is not granted for
                # current user.
                warnings.warn(
                    f"Authentication information is not accessible to current user - {exc}", stacklevel=2
                )
                infos = {keyspace: DatabaseInfo(name=keyspace) for keyspace in keyspaces}

            # Would be nice to get size estimate, but this is not available
            # via CQL queries.
            return infos.values()

    @classmethod
    def delete_database(cls, host: str, keyspace: str, *, timeout: int = 3600) -> None:
        """Delete APDB database by dropping its keyspace.

        Parameters
        ----------
        host : `str`
            Name of one of the hosts in Cassandra cluster.
        keyspace : `str`
            Name of keyspace to delete.
        timeout : `int`, optional
            Timeout for delete operation in seconds. Dropping a large keyspace
            can be a long operation, but this default value of one hour should
            be sufficient for most or all cases.
        """
        # For DbAuth we need to use database name "*" to try to match any
        # database.
        config = ApdbCassandraConfig(contact_points=(host,), keyspace="*")
        with SessionContext(config) as session:
            query = f"DROP KEYSPACE {quote_id(keyspace)}"
            session.execute(query, timeout=timeout)

    @property
    def partitioner(self) -> Partitioner:
        """Partitioner used by this APDB instance (`Partitioner`)."""
        context = self._apdb._context
        return context.partitioner

    def apdb_part(self, ra: float, dec: float) -> int:
        # docstring is inherited from a base class
        uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
        return self.partitioner.pixel(uv3d)

    def apdb_time_part(self, midpointMjdTai: float) -> int:
        # docstring is inherited from a base class
        return self.partitioner.time_partition(midpointMjdTai)

    def delete_records(
        self,
        objects: Iterable[DiaObjectLocator],
        sources: Iterable[DiaSourceLocator],
        forced_sources: Iterable[DiaForcedSourceLocator],
    ) -> None:
        # docstring is inherited from a base class
        context = self._apdb._context
        config = context.config
        keyspace = self._apdb._keyspace
        has_dia_object_table = not (config.enable_replica and config.replica_skips_diaobjects)

        # Group objects by partition.
        partitions = defaultdict(list)
        for obj in objects:
            apdb_part = self.apdb_part(obj.ra, obj.dec)
            partitions[apdb_part].append(obj.diaObjectId)
        object_ids = set(itertools.chain.from_iterable(partitions.values()))

        # Group sources by associated object ID.
        source_groups = defaultdict(list)
        for source in sources:
            if source.diaObjectId in object_ids:
                source_groups[source.diaObjectId].append(source)

        object_deletes = []
        object_count = 0
        # Delete from DiaObjectLast table.
        for apdb_part, oids in partitions.items():
            oids = sorted(oids)
            object_count += len(oids)
            for oid_chunk in chunk_iterable(oids, 1000):
                query = (
                    Delete(keyspace, "DiaObjectLast")
                    .where(C("apdb_part") == apdb_part)
                    .where(C("diaObjectId").in_(oid_chunk))
                )
                object_deletes.append(context.stmt_factory.with_params(query))

        # If DiaObject is in use then delete from that too.
        if has_dia_object_table:
            # Need temporal partitions for DiaObject, the only source for that
            # is the timestamp of the associated DiaSource. Problem here is
            # that DiaObject temporal partitioning is based on validityStart,
            # which is "visit_time", but DiaSource does not record visit_time,
            # it is partitioned on midpointMjdTai. There is time_processed
            # defined for DiaSource but it does not match "visit_time" though
            # it is close. I use midpointMjdTai as approximation for
            # validityStart, this may skip some DiaObjects, but in production
            # we are not going to have DiaObjects table at all. There is also
            # a chance that DiaObject moves from one spatial partition to
            # another with the same consequences, which we also ignore.
            oids_by_partition: dict[tuple[int, int], list[int]] = defaultdict(list)
            for apdb_part, oids in partitions.items():
                for oid in oids:
                    temporal_partitions = {
                        self.apdb_time_part(src.midpointMjdTai) for src in source_groups.get(oid, [])
                    }
                    for time_part in temporal_partitions:
                        oids_by_partition[(apdb_part, time_part)].append(oid)
            for (apdb_part, time_part), oids in oids_by_partition.items():
                for oid_chunk in chunk_iterable(oids, 1000):
                    if config.partitioning.time_partition_tables:
                        table_name = context.schema.tableName(ApdbTables.DiaObject, time_part)
                        query = (
                            Delete(keyspace, table_name)
                            .where(C("apdb_part") == apdb_part)
                            .where(C("diaObjectId").in_(oid_chunk))
                        )
                        object_deletes.append(context.stmt_factory.with_params(query))
                    else:
                        table_name = context.schema.tableName(ApdbTables.DiaObject)
                        query = (
                            Delete(keyspace, table_name)
                            .where(C("apdb_part") == apdb_part)
                            .where(C("apdb_time_part") == time_part)
                            .where(C("diaObjectId").in_(oid_chunk))
                        )
                        object_deletes.append(context.stmt_factory.with_params(query))

        # Delete from DiaObjectLastToPartition table.
        for oid_chunk in chunk_iterable(sorted(object_ids), 1000):
            query = Delete(keyspace, "DiaObjectLastToPartition").where(C("diaObjectId").in_(oid_chunk))
            object_deletes.append(context.stmt_factory.with_params(query))

        # Group sources by partition.
        source_partitions = defaultdict(list)
        for source in itertools.chain.from_iterable(source_groups.values()):
            apdb_part = self.apdb_part(source.ra, source.dec)
            apdb_time_part = self.apdb_time_part(source.midpointMjdTai)
            source_partitions[(apdb_part, apdb_time_part)].append(source)

        source_deletes = []
        source_count = 0
        for (apdb_part, apdb_time_part), source_list in source_partitions.items():
            source_ids = sorted(source.diaSourceId for source in source_list)
            source_count += len(source_ids)
            for id_chunk in chunk_iterable(source_ids, 1000):
                if config.partitioning.time_partition_tables:
                    table_name = context.schema.tableName(ApdbTables.DiaSource, apdb_time_part)
                    query = (
                        Delete(keyspace, table_name)
                        .where(C("apdb_part") == apdb_part)
                        .where(C("diaSourceId").in_(id_chunk))
                    )
                    source_deletes.append(context.stmt_factory.with_params(query))
                else:
                    table_name = context.schema.tableName(ApdbTables.DiaSource)
                    query = (
                        Delete(keyspace, table_name)
                        .where(C("apdb_part") == apdb_part)
                        .where(C("apdb_time_part") == apdb_time_part)
                        .where(C("diaSourceId").in_(id_chunk))
                    )
                    source_deletes.append(context.stmt_factory.with_params(query))

        # Group forced sources by partition.
        forced_source_partitions = defaultdict(list)
        for forced_source in forced_sources:
            if forced_source.diaObjectId in object_ids:
                apdb_part = self.apdb_part(forced_source.ra, forced_source.dec)
                apdb_time_part = self.apdb_time_part(forced_source.midpointMjdTai)
                forced_source_partitions[(apdb_part, apdb_time_part)].append(forced_source)

        forced_source_deletes = []
        forced_source_count = 0
        for (apdb_part, apdb_time_part), forced_source_list in forced_source_partitions.items():
            clustering_keys = sorted(
                (fsource.diaObjectId, fsource.visit, fsource.detector) for fsource in forced_source_list
            )
            forced_source_count += len(clustering_keys)
            for key_chunk in chunk_iterable(clustering_keys, 1000):
                cl_str = ",".join(f"({oid}, {v}, {d})" for oid, v, d in key_chunk)
                if config.partitioning.time_partition_tables:
                    table_name = context.schema.tableName(ApdbTables.DiaForcedSource, apdb_time_part)
                    query = (
                        Delete(keyspace, table_name)
                        .where(C("apdb_part") == apdb_part)
                        .where(f'("diaObjectId", visit, detector) IN ({cl_str})')
                    )
                    forced_source_deletes.append(context.stmt_factory.with_params(query))
                else:
                    table_name = context.schema.tableName(ApdbTables.DiaForcedSource)
                    query = (
                        Delete(keyspace, table_name)
                        .where(C("apdb_part") == apdb_part)
                        .where(C("apdb_time_part") == apdb_time_part)
                        .where(f'("diaObjectId", visit, detector) IN ({cl_str})')
                    )
                    forced_source_deletes.append(context.stmt_factory.with_params(query))

        _LOG.info(
            "Deleting %d objects, %d sources, and %d forced sources",
            object_count,
            source_count,
            forced_source_count,
        )

        # Now run all queries.
        with self._timer("delete_forced_sources"):
            execute_concurrent(context.session, forced_source_deletes)
        with self._timer("delete_sources"):
            execute_concurrent(context.session, source_deletes)
        with self._timer("delete_objects"):
            execute_concurrent(context.session, object_deletes)

    def time_partitions(self) -> ApdbCassandraTimePartitionRange:
        """Return range of existing time partitions.

        Returns
        -------
        range : `ApdbCassandraTimePartitionRange`
            Time partition range.

        Raises
        ------
        TypeError
            Raised if APDB instance does not use time-partition tables.
        """
        context = self._apdb._context
        part_range = context.time_partitions_range
        if not part_range:
            raise TypeError("This APDB instance does not use time-partitioned tables.")
        return part_range

    def extend_time_partitions(
        self,
        time: astropy.time.Time,
        forward: bool = True,
        max_delta: astropy.time.TimeDelta | None = None,
    ) -> list[int]:
        """Extend set of time-partitioned tables to include specified time.

        Parameters
        ----------
        time : `astropy.time.Time`
            Time to which to extend partitions.
        forward : `bool`, optional
            If `True` then extend partitions into the future, time should be
            later than the end time of the last existing partition. If `False`
            then extend partitions into the past, time should be earlier than
            the start time of the first existing partition.
        max_delta : `astropy.time.TimeDelta`, optional
            Maximum possible extension of the partitions, default is 365 days.

        Returns
        -------
        partitions : `list` [`int`]
            List of partitions added to the database, empty list returned if
            ``time`` is already in the existing partition range.

        Raises
        ------
        TypeError
            Raised if APDB instance does not use time-partition tables.
        ValueError
            Raised if extension request exceeds time limit of ``max_delta``.
        """
        if max_delta is None:
            max_delta = astropy.time.TimeDelta(365, format="jd")

        context = self._apdb._context

        # Get current partitions.
        part_range = context.time_partitions_range
        if not part_range:
            raise TypeError("This APDB instance does not use time-partitioned tables.")

        # Partitions that we need to create.
        partitions = self._partitions_to_add(time, forward, max_delta)
        if not partitions:
            return []

        _LOG.debug("New partitions to create: %s", partitions)

        # Tables that are time-partitioned.
        keyspace = self._apdb._keyspace
        tables = context.schema.time_partitioned_tables()

        # Easiest way to create new tables is to take DDL from existing one
        # and update table name.
        table_name_token = "%TABLE_NAME%"
        table_schemas = {}
        for table in tables:
            existing_table_name = context.schema.tableName(table, part_range.end)
            query = f'DESCRIBE TABLE "{keyspace}"."{existing_table_name}"'
            result = context.session.execute(query).one()
            if not result:
                raise LookupError(f'Failed to read schema for table "{keyspace}"."{existing_table_name}"')
            schema: str = result.create_statement
            schema = schema.replace(existing_table_name, table_name_token)
            table_schemas[table] = schema

        # Be paranoid and check that none of the new tables exist.
        existing_tables = context.schema.existing_tables(*tables)
        for table in tables:
            new_tables = {context.schema.tableName(table, partition) for partition in partitions}
            old_tables = new_tables.intersection(existing_tables[table])
            if old_tables:
                raise ValueError(f"Some to be created tables already exist: {old_tables}")

        # Now can create all of them.
        for table, schema in table_schemas.items():
            for partition in partitions:
                new_table_name = context.schema.tableName(table, partition)
                _LOG.debug("Creating table %s", new_table_name)
                new_ddl = schema.replace(table_name_token, new_table_name)
                context.session.execute(new_ddl)

        # Update metadata.
        if forward:
            part_range.end = max(partitions)
        else:
            part_range.start = min(partitions)
        part_range.save_to_meta(context.metadata)

        return partitions

    def _partitions_to_add(
        self,
        time: astropy.time.Time,
        forward: bool,
        max_delta: astropy.time.TimeDelta,
    ) -> list[int]:
        """Make the list of time partitions to add to current range."""
        context = self._apdb._context
        part_range = context.time_partitions_range
        assert part_range is not None

        new_partition = context.partitioner.time_partition(time)
        if forward:
            if new_partition <= part_range.end:
                _LOG.debug(
                    "Partition for time=%s (%d) is below existing end (%d)",
                    time,
                    new_partition,
                    part_range.end,
                )
                return []
            _, end = context.partitioner.partition_period(part_range.end)
            if time - end > max_delta:
                raise ValueError(
                    f"Extension exceeds limit: current end time = {end.isot}, new end time = {time.isot}, "
                    f"limit = {max_delta.jd} days"
                )
            partitions = list(range(part_range.end + 1, new_partition + 1))
        else:
            if new_partition >= part_range.start:
                _LOG.debug(
                    "Partition for time=%s (%d) is above existing start (%d)",
                    time,
                    new_partition,
                    part_range.start,
                )
                return []
            start, _ = context.partitioner.partition_period(part_range.start)
            if start - time > max_delta:
                raise ValueError(
                    f"Extension exceeds limit: current start time = {start.isot}, "
                    f"new start time = {time.isot}, "
                    f"limit = {max_delta.jd} days"
                )
            partitions = list(range(new_partition, part_range.start))

        return partitions

    def delete_time_partitions(
        self, time: astropy.time.Time, after: bool = False, *, confirm: ConfirmDeletePartitions | None = None
    ) -> list[int]:
        """Delete time-partitioned tables before or after specified time.

        Parameters
        ----------
        time : `astropy.time.Time`
            Time before or after which to remove partitions. Partition that
            includes this time is not deleted.
        after : `bool`, optional
            If `True` then delete partitions after the specified time. Default
            is to delete partitions before this time.
        confirm : `~collections.abc.Callable`, optional
            A callable that will be called to confirm deletion of the
            partitions. The callable needs to accept three keyword arguments:

            - `partitions` - a list of partition numbers to be deleted,
            - `tables` - a list of table names to be deleted,
            - `partitioner` - a `Partitioner` instance.

            Partitions are deleted only if callable returns `True`.

        Returns
        -------
        partitions : `list` [`int`]
            List of partitions deleted from the database, empty list returned
            if nothing is deleted.

        Raises
        ------
        TypeError
            Raised if APDB instance does not use time-partition tables.
        ValueError
            Raised if requested to delete all partitions.
        """
        context = self._apdb._context

        # Get current partitions.
        part_range = context.time_partitions_range
        if not part_range:
            raise TypeError("This APDB instance does not use time-partitioned tables.")

        partitions = self._partitions_to_delete(time, after)
        if not partitions:
            return []

        # Cannot delete all partitions.
        if min(partitions) == part_range.start and max(partitions) == part_range.end:
            raise ValueError("Cannot delete all partitions.")

        # Tables that are time-partitioned.
        keyspace = self._apdb._keyspace
        tables = context.schema.time_partitioned_tables()

        table_names = []
        for table in tables:
            for partition in partitions:
                table_names.append(context.schema.tableName(table, partition))

        if confirm is not None:
            # It can raise an exception, but at this point it's completely
            # harmless.
            answer = confirm(partitions=partitions, tables=table_names, partitioner=context.partitioner)
            if not answer:
                return []

        for table_name in table_names:
            _LOG.debug("Dropping table %s", table_name)
            # Use IF EXISTS just in case.
            query = f'DROP TABLE IF EXISTS "{keyspace}"."{table_name}"'
            context.session.execute(query)

        # Update metadata.
        if after:
            part_range.end = min(partitions) - 1
        else:
            part_range.start = max(partitions) + 1
        part_range.save_to_meta(context.metadata)

        return partitions

    def _partitions_to_delete(
        self,
        time: astropy.time.Time,
        after: bool = False,
    ) -> list[int]:
        """Make the list of time partitions to delete."""
        context = self._apdb._context
        part_range = context.time_partitions_range
        assert part_range is not None

        partition = context.partitioner.time_partition(time)
        if after:
            return list(range(max(partition + 1, part_range.start), part_range.end + 1))
        else:
            return list(range(part_range.start, min(partition, part_range.end + 1)))