Coverage for python/lsst/dax/apdb/cassandra/apdbCassandra.py: 9%

809 statements  

coverage.py v7.13.5, created at 2026-04-28 08:43 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandra"] 

25 

26import datetime 

27import json 

28import logging 

29import random 

30import uuid 

31import warnings 

32from collections import Counter, defaultdict 

33from collections.abc import Iterable, Mapping, Set 

34from typing import TYPE_CHECKING, Any, cast 

35 

36import numpy as np 

37import pandas 

38 

39# If cassandra-driver is not installed the module can still be imported,

40# but ApdbCassandra cannot be instantiated.

41try: 

42 import cassandra 

43 import cassandra.query 

44 from cassandra.query import UNSET_VALUE 

45 

46 CASSANDRA_IMPORTED = True 

47except ImportError: 

48 CASSANDRA_IMPORTED = False 

49 

50import astropy.time 

51import felis.datamodel 

52 

53from lsst import sphgeom 

54from lsst.utils.iteration import chunk_iterable 

55 

56from ..apdb import Apdb, ApdbConfig 

57from ..apdbConfigFreezer import ApdbConfigFreezer 

58from ..apdbReplica import ApdbTableData, ReplicaChunk 

59from ..apdbSchema import ApdbSchema, ApdbTables 

60from ..apdbUpdateRecord import ( 

61 ApdbCloseDiaObjectValidityRecord, 

62 ApdbReassignDiaSourceToDiaObjectRecord, 

63 ApdbUpdateNDiaSourcesRecord, 

64) 

65from ..monitor import MonAgent 

66from ..recordIds import DiaObjectId, DiaSourceId 

67from ..schema_model import Table 

68from ..timer import Timer 

69from ..versionTuple import VersionTuple 

70from .apdbCassandraAdmin import ApdbCassandraAdmin 

71from .apdbCassandraReplica import ApdbCassandraReplica 

72from .apdbCassandraSchema import ApdbCassandraSchema, CreateTableOptions, ExtraTables 

73from .apdbMetadataCassandra import ApdbMetadataCassandra 

74from .cassandra_utils import ( 

75 ApdbCassandraTableData, 

76 execute_concurrent, 

77 literal, 

78 select_concurrent, 

79) 

80from .config import ApdbCassandraConfig, ApdbCassandraConnectionConfig, ApdbCassandraTimePartitionRange 

81from .connectionContext import ConnectionContext, DbVersions 

82from .exceptions import CassandraMissingError 

83from .partitioner import Partitioner 

84from .queries import Column as C # noqa: N817 

85from .queries import ColumnExpr, Delete, Insert, QExpr, Select, Update 

86from .sessionFactory import SessionContext, SessionFactory 

87 

88if TYPE_CHECKING: 

89 from ..apdbMetadata import ApdbMetadata 

90 from ..apdbUpdateRecord import ApdbUpdateRecord 

91 

92_LOG = logging.getLogger(__name__) 

93 

94_MON = MonAgent(__name__) 

95 

96VERSION = VersionTuple(1, 3, 0) 

97"""Version for the code controlling non-replication tables. This needs to be 

98updated following compatibility rules when the schema produced by this code

99changes.

100""" 

101 

102 

103class ApdbCassandra(Apdb): 

104 """Implementation of APDB database with Apache Cassandra backend. 

105 

106 Parameters 

107 ---------- 

108 config : `ApdbCassandraConfig` 

109 Configuration object. 
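
    Examples
    --------
    A minimal usage sketch (illustrative only); it assumes ``config`` is an
    `ApdbCassandraConfig` for an existing APDB instance, e.g. as returned by
    `init_database`, and ``region`` is an `lsst.sphgeom.Region`::

        apdb = ApdbCassandra(config)
        objects = apdb.getDiaObjects(region)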

110 """ 

111 

112 def __init__(self, config: ApdbCassandraConfig): 

113 if not CASSANDRA_IMPORTED: 

114 raise CassandraMissingError() 

115 

116 self._config = config 

117 self._keyspace = config.keyspace 

118 self._schema = ApdbSchema(config.schema_file, config.ss_schema_file) 

119 

120 self._session_factory = SessionFactory(config) 

121 self._connection_context: ConnectionContext | None = None 

122 

123 @property 

124 def _context(self) -> ConnectionContext: 

125 """Establish connection if not established and return context.""" 

126 if self._connection_context is None: 

127 current_versions = DbVersions( 

128 schema_version=self.schema.schemaVersion(), 

129 code_version=self.apdbImplementationVersion(), 

130 replica_version=ApdbCassandraReplica.apdbReplicaImplementationVersion(), 

131 ) 

132 _LOG.debug("Current versions: %s", current_versions) 

133 

134 session = self._session_factory.session() 

135 self._connection_context = ConnectionContext( 

136 session, self._config, self.schema.tableSchemas, current_versions 

137 ) 

138 

139 if _LOG.isEnabledFor(logging.DEBUG): 

140 _LOG.debug("ApdbCassandra Configuration: %s", self._connection_context.config.model_dump()) 

141 

142 return self._connection_context 

143 

144 def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer: 

145 """Create `Timer` instance given its name.""" 

146 return Timer(name, _MON, tags=tags) 

147 

148 @classmethod 

149 def apdbImplementationVersion(cls) -> VersionTuple: 

150 """Return version number for current APDB implementation. 

151 

152 Returns 

153 ------- 

154 version : `VersionTuple` 

155 Version of the code defined in implementation class. 

156 """ 

157 return VERSION 

158 

159 def getConfig(self) -> ApdbCassandraConfig: 

160 # docstring is inherited from a base class 

161 return self._context.config 

162 

163 def tableDef(self, table: ApdbTables) -> Table | None: 

164 # docstring is inherited from a base class 

165 return self.schema.tableSchemas.get(table) 

166 

167 @classmethod 

168 def init_database( 

169 cls, 

170 hosts: tuple[str, ...], 

171 keyspace: str, 

172 *, 

173 schema_file: str | None = None, 

174 ss_schema_file: str | None = None, 

175 read_sources_months: int | None = None, 

176 read_forced_sources_months: int | None = None, 

177 enable_replica: bool = False, 

178 replica_skips_diaobjects: bool = False, 

179 port: int | None = None, 

180 username: str | None = None, 

181 prefix: str | None = None, 

182 part_pixelization: str | None = None, 

183 part_pix_level: int | None = None, 

184 time_partition_tables: bool = True, 

185 time_partition_start: str | None = None, 

186 time_partition_end: str | None = None, 

187 read_consistency: str | None = None, 

188 write_consistency: str | None = None, 

189 read_timeout: int | None = None, 

190 write_timeout: int | None = None, 

191 ra_dec_columns: tuple[str, str] | None = None, 

192 replication_factor: int | None = None, 

193 drop: bool = False, 

194 table_options: CreateTableOptions | None = None, 

195 ) -> ApdbCassandraConfig: 

196 """Initialize new APDB instance and make configuration object for it. 

197 

198 Parameters 

199 ---------- 

200 hosts : `tuple` [`str`, ...] 

201 List of host names or IP addresses for Cassandra cluster. 

202 keyspace : `str` 

203 Name of the keyspace for APDB tables. 

204 schema_file : `str`, optional 

205 Location of (YAML) configuration file with APDB schema. If not 

206 specified then default location will be used. 

207 ss_schema_file : `str`, optional 

208 Location of (YAML) configuration file with SSO schema. If not 

209 specified then default location will be used. 

210 read_sources_months : `int`, optional 

211 Number of months of history to read from DiaSource. 

212 read_forced_sources_months : `int`, optional 

213 Number of months of history to read from DiaForcedSource. 

214 enable_replica : `bool`, optional 

215 If True, make additional tables used for replication to PPDB. 

216 replica_skips_diaobjects : `bool`, optional 

217 If `True` then do not fill regular ``DiaObject`` table when 

218 ``enable_replica`` is `True`. 

219 port : `int`, optional 

220 Port number to use for Cassandra connections. 

221 username : `str`, optional 

222 User name for Cassandra connections. 

223 prefix : `str`, optional 

224 Optional prefix for all table names. 

225 part_pixelization : `str`, optional 

226 Name of the MOC pixelization used for partitioning. 

227 part_pix_level : `int`, optional 

228 Pixelization level. 

229 time_partition_tables : `bool`, optional 

230 Create per-partition tables. 

231 time_partition_start : `str`, optional 

232 Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

233 format, in TAI. 

234 time_partition_end : `str`, optional 

235 Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

236 format, in TAI. 

237 read_consistency : `str`, optional 

238 Name of the consistency level for read operations. 

239 write_consistency : `str`, optional 

240 Name of the consistency level for write operations. 

241 read_timeout : `int`, optional 

242 Read timeout in seconds. 

243 write_timeout : `int`, optional 

244 Write timeout in seconds. 

245 ra_dec_columns : `tuple` [`str`, `str`], optional 

246 Names of ra/dec columns in DiaObject table. 

247 replication_factor : `int`, optional 

248 Replication factor used when creating a new keyspace; if the keyspace

249 already exists its replication factor is not changed.

250 drop : `bool`, optional 

251 If `True` then drop existing tables before re-creating the schema. 

252 table_options : `CreateTableOptions`, optional 

253 Options used when creating Cassandra tables. 

254 

255 Returns 

256 ------- 

257 config : `ApdbCassandraConfig` 

258 Resulting configuration object for a created APDB instance. 
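
        Examples
        --------
        Illustrative sketch only; host names, keyspace, and option values are
        placeholders that must match an actual Cassandra deployment::

            config = ApdbCassandra.init_database(
                hosts=("cassandra-host-1", "cassandra-host-2"),
                keyspace="apdb",
                enable_replica=True,
                time_partition_tables=True,
                replication_factor=3,
            )
            apdb = ApdbCassandra(config)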

259 """ 

260 # Some non-standard defaults for connection parameters; these can be

261 # changed later in the generated config. Check the Cassandra driver

262 # documentation for what these parameters do. These parameters are not

263 # used during database initialization, but they will be saved with the

264 # generated config.

265 connection_config = ApdbCassandraConnectionConfig( 

266 extra_parameters={ 

267 "idle_heartbeat_interval": 0, 

268 "idle_heartbeat_timeout": 30, 

269 "control_connection_timeout": 100, 

270 }, 

271 ) 

272 config = ApdbCassandraConfig( 

273 contact_points=hosts, 

274 keyspace=keyspace, 

275 enable_replica=enable_replica, 

276 replica_skips_diaobjects=replica_skips_diaobjects, 

277 connection_config=connection_config, 

278 ) 

279 config.partitioning.time_partition_tables = time_partition_tables 

280 if schema_file is not None: 

281 config.schema_file = schema_file 

282 if ss_schema_file is not None: 

283 config.ss_schema_file = ss_schema_file 

284 if read_sources_months is not None: 

285 config.read_sources_months = read_sources_months 

286 if read_forced_sources_months is not None: 

287 config.read_forced_sources_months = read_forced_sources_months 

288 if port is not None: 

289 config.connection_config.port = port 

290 if username is not None: 

291 config.connection_config.username = username 

292 if prefix is not None: 

293 config.prefix = prefix 

294 if part_pixelization is not None: 

295 config.partitioning.part_pixelization = part_pixelization 

296 if part_pix_level is not None: 

297 config.partitioning.part_pix_level = part_pix_level 

298 if time_partition_start is not None: 

299 config.partitioning.time_partition_start = time_partition_start 

300 if time_partition_end is not None: 

301 config.partitioning.time_partition_end = time_partition_end 

302 if read_consistency is not None: 

303 config.connection_config.read_consistency = read_consistency 

304 if write_consistency is not None: 

305 config.connection_config.write_consistency = write_consistency 

306 if read_timeout is not None: 

307 config.connection_config.read_timeout = read_timeout 

308 if write_timeout is not None: 

309 config.connection_config.write_timeout = write_timeout 

310 if ra_dec_columns is not None: 

311 config.ra_dec_columns = ra_dec_columns 

312 

313 cls._makeSchema(config, drop=drop, replication_factor=replication_factor, table_options=table_options) 

314 

315 return config 

316 

317 def get_replica(self) -> ApdbCassandraReplica: 

318 """Return `ApdbReplica` instance for this database.""" 

319 # Note that this instance has to stay alive while the replica exists, so

320 # we pass a reference to self.

321 return ApdbCassandraReplica(self) 

322 

323 @classmethod 

324 def _makeSchema( 

325 cls, 

326 config: ApdbConfig, 

327 *, 

328 drop: bool = False, 

329 replication_factor: int | None = None, 

330 table_options: CreateTableOptions | None = None, 

331 ) -> None: 

332 # docstring is inherited from a base class 

333 

334 if not isinstance(config, ApdbCassandraConfig): 

335 raise TypeError(f"Unexpected type of configuration object: {type(config)}") 

336 

337 simple_schema = ApdbSchema(config.schema_file, config.ss_schema_file) 

338 

339 with SessionContext(config) as session: 

340 schema = ApdbCassandraSchema( 

341 session=session, 

342 keyspace=config.keyspace, 

343 table_schemas=simple_schema.tableSchemas, 

344 prefix=config.prefix, 

345 time_partition_tables=config.partitioning.time_partition_tables, 

346 enable_replica=config.enable_replica, 

347 replica_skips_diaobjects=config.replica_skips_diaobjects, 

348 ) 

349 

350 # Ask schema to create all tables. 

351 part_range_config: ApdbCassandraTimePartitionRange | None = None 

352 if config.partitioning.time_partition_tables: 

353 partitioner = Partitioner(config) 

354 time_partition_start = astropy.time.Time( 

355 config.partitioning.time_partition_start, format="isot", scale="tai" 

356 ) 

357 time_partition_end = astropy.time.Time( 

358 config.partitioning.time_partition_end, format="isot", scale="tai" 

359 ) 

360 part_range_config = ApdbCassandraTimePartitionRange( 

361 start=partitioner.time_partition(time_partition_start), 

362 end=partitioner.time_partition(time_partition_end), 

363 ) 

364 schema.makeSchema( 

365 drop=drop, 

366 part_range=part_range_config, 

367 replication_factor=replication_factor, 

368 table_options=table_options, 

369 ) 

370 else: 

371 schema.makeSchema( 

372 drop=drop, replication_factor=replication_factor, table_options=table_options 

373 ) 

374 

375 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

376 metadata = ApdbMetadataCassandra( 

377 session, meta_table_name, config.keyspace, "read_tuples", "write" 

378 ) 

379 

380 # Fill version numbers, overriding any that existed before.

381 metadata.set( 

382 ConnectionContext.metadataSchemaVersionKey, str(simple_schema.schemaVersion()), force=True 

383 ) 

384 metadata.set( 

385 ConnectionContext.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True 

386 ) 

387 

388 if config.enable_replica: 

389 # Only store replica code version if replica is enabled. 

390 metadata.set( 

391 ConnectionContext.metadataReplicaVersionKey, 

392 str(ApdbCassandraReplica.apdbReplicaImplementationVersion()), 

393 force=True, 

394 ) 

395 

396 # Store frozen part of a configuration in metadata. 

397 freezer = ApdbConfigFreezer[ApdbCassandraConfig](ConnectionContext.frozen_parameters) 

398 metadata.set(ConnectionContext.metadataConfigKey, freezer.to_json(config), force=True) 

399 

400 # Store time partition range. 

401 if part_range_config: 

402 part_range_config.save_to_meta(metadata) 

403 

404 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

405 # docstring is inherited from a base class 

406 context = self._context 

407 config = context.config 

408 

409 sp_where, num_sp_part = context.partitioner.spatial_where(region) 

410 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

411 

412 # We need to exclude extra partitioning columns from result. 

413 column_names = context.schema.apdbColumnNames(ApdbTables.DiaObjectLast) 

414 table_name = context.schema.tableName(ApdbTables.DiaObjectLast) 

415 query = Select(self._keyspace, table_name, column_names) 

416 statements: list[tuple] = [] 

417 for where_clause in sp_where: 

418 full_query = query.where(where_clause) 

419 statements.append(context.stmt_factory.with_params(full_query, prepare=True)) 

420 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

421 

422 with self._timer("select_time", tags={"table": "DiaObject", "method": "getDiaObjects"}) as timer: 

423 raw_objects = cast( 

424 ApdbCassandraTableData, 

425 select_concurrent( 

426 context.session, 

427 statements, 

428 "read_raw_multi", 

429 config.connection_config.read_concurrency, 

430 ), 

431 ) 

432 objects = raw_objects.to_pandas(context.schema._table_schema(ApdbTables.DiaObjectLast)) 

433 timer.add_values(row_count=len(objects), num_sp_part=num_sp_part, num_queries=len(statements)) 

434 

435 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

436 return objects 
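
    # Usage sketch (illustrative only; coordinates are arbitrary and the
    # region construction assumes the standard lsst.sphgeom primitives):
    #
    #     center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(35.0, -5.0))
    #     region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
    #     objects = apdb.getDiaObjects(region)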

437 

438 def getDiaSources( 

439 self, 

440 region: sphgeom.Region, 

441 object_ids: Iterable[int] | None, 

442 visit_time: astropy.time.Time, 

443 start_time: astropy.time.Time | None = None, 

444 ) -> pandas.DataFrame | None: 

445 # docstring is inherited from a base class 

446 context = self._context 

447 config = context.config 

448 

449 months = config.read_sources_months 

450 if start_time is None and months == 0: 

451 return None 

452 

453 mjd_end = float(visit_time.tai.mjd) 

454 if start_time is None: 

455 mjd_start = mjd_end - months * 30 

456 else: 

457 mjd_start = float(start_time.tai.mjd) 

458 

459 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 
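
    # Window calculation sketch: with read_sources_months=12 and no explicit
    # start_time, sources are read from roughly visit_time - 360 days
    # (12 * 30) up to visit_time.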

460 

461 def getDiaForcedSources( 

462 self, 

463 region: sphgeom.Region, 

464 object_ids: Iterable[int] | None, 

465 visit_time: astropy.time.Time, 

466 start_time: astropy.time.Time | None = None, 

467 ) -> pandas.DataFrame | None: 

468 # docstring is inherited from a base class 

469 context = self._context 

470 config = context.config 

471 

472 months = config.read_forced_sources_months 

473 if start_time is None and months == 0: 

474 return None 

475 

476 mjd_end = float(visit_time.tai.mjd) 

477 if start_time is None: 

478 mjd_start = mjd_end - months * 30 

479 else: 

480 mjd_start = float(start_time.tai.mjd) 

481 

482 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

483 

484 def getDiaObjectsForDedup(self, since: astropy.time.Time | None = None) -> pandas.DataFrame: 

485 # docstring is inherited from a base class 

486 context = self._context 

487 config = context.config 

488 

489 if not context.has_dedup_table: 

490 raise TypeError("DiaObjectDedup table does not exist in this APDB instance.") 

491 

492 if since is None: 

493 # Read last deduplication time from metadata. 

494 dedup_str = context.metadata.get(context.metadataDedupKey) 

495 if dedup_str is not None: 

496 dedup_state = json.loads(dedup_str) 

497 dedup_time_str = dedup_state["dedup_time_iso_tai"] 

498 since = astropy.time.Time(dedup_time_str, format="iso", scale="tai") 

499 

500 column_names = context.schema.apdbColumnNames(ExtraTables.DiaObjectDedup) 

501 

502 validity_start_column = self._timestamp_column_name("validityStart") 

503 timestamp = None if since is None else self._timestamp_column_value(since) 

504 

505 table_name = context.schema.tableName(ExtraTables.DiaObjectDedup) 

506 query = Select(self._keyspace, table_name, column_names, extra_clause="ALLOW FILTERING") 

507 query = query.where(C("dedup_part") == 0) 

508 if since is not None: 

509 query = query.where(C(validity_start_column) >= 0) 

510 

511 statement = context.stmt_factory(query, prepare=False) 

512 

513 num_part = config.partitioning.num_part_dedup 

514 statements = [] 

515 for dedup_part in range(num_part): 

516 params = (dedup_part,) if timestamp is None else (dedup_part, timestamp) 

517 statements.append((statement, params)) 

518 

519 with self._timer( 

520 "select_time", tags={"table": "DiaObjectDedup", "method": "getDiaObjectsForDedup"} 

521 ) as timer: 

522 objects_raw = cast( 

523 ApdbCassandraTableData, 

524 select_concurrent( 

525 context.session, 

526 statements, 

527 "read_raw_multi_dedup", 

528 config.connection_config.read_concurrency, 

529 ), 

530 ) 

531 objects = objects_raw.to_pandas(context.schema._table_schema(ExtraTables.DiaObjectDedup)) 

532 timer.add_values(row_count=len(objects), num_queries=num_part) 

533 

534 _LOG.debug("found %s DiaObjectDedup records", objects.shape[0]) 

535 return objects 
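
    # Deduplication workflow sketch (illustrative only; assumes the
    # DiaObjectDedup table exists and ``stale_objects`` is a list of
    # DiaObjectId chosen by the caller):
    #
    #     candidates = apdb.getDiaObjectsForDedup()  # records since last reset
    #     ...select which duplicate DiaObject records to close...
    #     apdb.setValidityEnd(stale_objects, validityEnd=astropy.time.Time.now().tai)
    #     apdb.resetDedup()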

536 

537 def getDiaSourcesForDiaObjects( 

538 self, objects: list[DiaObjectId], start_time: astropy.time.Time, max_dist_arcsec: float = 1.0 

539 ) -> pandas.DataFrame: 

540 # docstring is inherited from a base class 

541 context = self._context 

542 config = context.config 

543 

544 # Determine which tables to query and the temporal constraints.

545 end_time = self._current_time() 

546 tables, temporal_where = context.partitioner.temporal_where( 

547 ApdbTables.DiaSource, 

548 start_time, 

549 end_time, 

550 partitons_range=context.time_partitions_range, 

551 query_per_time_part=False, 

552 ) 

553 if not tables: 

554 warnings.warn( 

555 f"Query time range ({start_time.isot} - {end_time.isot}) does not overlap database " 

556 "time partitions." 

557 ) 

558 

559 # Group DiaObjects by partition. 

560 partitioned_object_ids = self._group_dia_objects_by_partition( 

561 context.partitioner, objects, max_dist_arcsec 

562 ) 

563 

564 # Columns to return. 

565 column_names = context.schema.apdbColumnNames(ApdbTables.DiaSource) 

566 

567 # Make a bunch of queries. 

568 statements = [] 

569 for apdb_part, diaObjectIds in partitioned_object_ids.items(): 

570 spatial_where = [C("apdb_part") == apdb_part] 

571 for table in tables: 

572 query = Select(self._keyspace, table, column_names, extra_clause="ALLOW FILTERING") 

573 for id_chunk in chunk_iterable(diaObjectIds, 10_000): 

574 id_where = C("diaObjectId").in_(id_chunk) 

575 for clause in QExpr.combine(spatial_where, temporal_where, extra=id_where): 

576 statements.append( 

577 context.stmt_factory.with_params(query.where(clause), prepare=False) 

578 ) 

579 

580 _LOG.debug("getDiaSourcesForDiaObjects #queries: %s", len(statements)) 

581 

582 with self._timer( 

583 "select_time", tags={"table": "DiaSource", "method": "getDiaSourcesForDiaObjects"} 

584 ) as timer: 

585 table_data_raw = cast( 

586 ApdbCassandraTableData, 

587 select_concurrent( 

588 context.session, 

589 statements, 

590 "read_raw_multi", 

591 config.connection_config.read_concurrency, 

592 ), 

593 ) 

594 catalog = table_data_raw.to_pandas(context.schema._table_schema(ApdbTables.DiaSource)) 

595 timer.add_values(row_count_from_db=len(catalog), num_queries=len(statements)) 

596 

597 # precise filtering on midpointMjdTai 

598 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] >= start_time.tai.mjd]) 

599 

600 timer.add_values(row_count=len(catalog)) 

601 

602 _LOG.debug("found %d DiaSources", len(catalog)) 

603 return catalog 

604 

605 def containsVisitDetector( 

606 self, 

607 visit: int, 

608 detector: int, 

609 region: sphgeom.Region, 

610 visit_time: astropy.time.Time, 

611 ) -> bool: 

612 # docstring is inherited from a base class 

613 context = self._context 

614 

615 table_name = context.schema.tableName(ExtraTables.ApdbVisitDetector) 

616 query = Select(self._keyspace, table_name, [ColumnExpr("count(*)")]) 

617 query = query.where((C("visit") == visit) & (C("detector") == detector)) 

618 stmt, params = context.stmt_factory.with_params(query, prepare=False) 

619 

620 with self._timer("contains_visit_detector_time", tags={"table": table_name}): 

621 result = context.session.execute(stmt, params) 

622 return bool(result.one()[0]) 
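
    # Usage sketch (visit/detector values are arbitrary):
    #
    #     already_stored = apdb.containsVisitDetector(
    #         visit=12345, detector=42, region=region, visit_time=visit_time
    #     )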

623 

624 def store( 

625 self, 

626 visit_time: astropy.time.Time, 

627 objects: pandas.DataFrame, 

628 sources: pandas.DataFrame | None = None, 

629 forced_sources: pandas.DataFrame | None = None, 

630 ) -> None: 

631 # docstring is inherited from a base class 

632 context = self._context 

633 config = context.config 

634 

635 # Store visit/detector in a special table; this has to be done

636 # before all other writes so that if there is a failure at any later

637 # point we still have a record of the attempted write.

638 visit_detector: set[tuple[int, int]] = set() 

639 for df in sources, forced_sources: 

640 if df is not None and not df.empty: 

641 df = df[["visit", "detector"]] 

642 for visit, detector in df.itertuples(index=False): 

643 visit_detector.add((visit, detector)) 

644 

645 if visit_detector: 

646 # Typically there is only one entry, so do not bother with

647 # concurrency.

648 table_name = context.schema.tableName(ExtraTables.ApdbVisitDetector) 

649 query = Insert(self._keyspace, table_name, ("visit", "detector")) 

650 stmt = context.stmt_factory(query) 

651 for item in visit_detector: 

652 context.session.execute(stmt, item, execution_profile="write") 

653 

654 objects = self._fix_input_timestamps(objects) 

655 if sources is not None: 

656 sources = self._fix_input_timestamps(sources) 

657 if forced_sources is not None: 

658 forced_sources = self._fix_input_timestamps(forced_sources) 

659 

660 replica_chunk: ReplicaChunk | None = None 

661 if context.schema.replication_enabled: 

662 replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, config.replica_chunk_seconds) 

663 self._storeReplicaChunk(replica_chunk) 

664 

665 # fill region partition column for DiaObjects 

666 objects = self._add_apdb_part(objects) 

667 self._storeDiaObjects(objects, visit_time, replica_chunk) 

668 

669 if sources is not None and len(sources) > 0: 

670 # copy apdb_part column from DiaObjects to DiaSources 

671 sources = self._add_apdb_part(sources) 

672 subchunk = self._storeDiaSources(ApdbTables.DiaSource, sources, replica_chunk) 

673 self._storeDiaSourcesPartitions(sources, visit_time, replica_chunk, subchunk) 

674 

675 if forced_sources is not None and len(forced_sources) > 0: 

676 forced_sources = self._add_apdb_part(forced_sources) 

677 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, replica_chunk) 

678 

679 def reassignDiaSourcesToDiaObjects( 

680 self, 

681 idMap: Mapping[DiaSourceId, int], 

682 *, 

683 increment_nDiaSources: bool = True, 

684 decrement_nDiaSources: bool = True, 

685 ) -> None: 

686 # docstring is inherited from a base class 

687 context = self._context 

688 config = context.config 

689 

690 source_ids = {source_id.diaSourceId for source_id in idMap} 

691 

692 # Find all DiaSources. 

693 found_sources = self._get_diasource_data( 

694 idMap, "apdb_part", "diaObjectId", "ra", "dec", "midpointMjdTai" 

695 ) 

696 

697 if missing_ids := (source_ids - {row.diaSourceId for row in found_sources}): 

698 raise LookupError(f"Some source IDs were not found in DiaSource table: {missing_ids}") 

699 

700 found_sources_by_id = {row.diaSourceId: row for row in found_sources} 

701 

702 # Make sure that all DiaObjects exist; we also want to know the

703 # nDiaSources count for current and new records because we send

704 # updated values to the replica.

705 current_object_ids = { 

706 DiaObjectId(diaObjectId=row.diaObjectId, ra=row.ra, dec=row.dec) for row in found_sources 

707 } 

708 # Assume that DiaSource ra/dec are very close to re-assigned objects. 

709 new_object_ids = { 

710 DiaObjectId(diaObjectId=diaObjectId, ra=source_id.ra, dec=source_id.dec) 

711 for source_id, diaObjectId in idMap.items() 

712 } 

713 all_object_ids = new_object_ids | current_object_ids 

714 found_objects = self._get_diaobject_data(all_object_ids, "apdb_part", "ra", "dec", "nDiaSources") 

715 

716 if missing_ids := ( 

717 {row.diaObjectId for row in all_object_ids} - {row.diaObjectId for row in found_objects} 

718 ): 

719 raise LookupError(f"Some object IDs were not found in DiaObjectLast table: {missing_ids}") 

720 

721 update_records: list[ApdbUpdateRecord] = [] 

722 update_order = 0 

723 current_time = self._current_time() 

724 current_time_ns = int(current_time.unix_tai * 1e9) 

725 

726 # Update DiaSources. 

727 statements: list[tuple] = [] 

728 for source_id, diaObjectId in idMap.items(): 

729 source_row = found_sources_by_id[source_id.diaSourceId] 

730 apdb_part = source_row.apdb_part 

731 time_part = context.partitioner.time_partition(source_row.midpointMjdTai) 

732 

733 if config.partitioning.time_partition_tables: 

734 table_name = context.schema.tableName(ApdbTables.DiaSource, time_part) 

735 update = ( 

736 Update(self._keyspace, table_name) 

737 .values(C("diaObjectId").update(diaObjectId)) 

738 .where(C("apdb_part") == apdb_part) 

739 .where(C("diaSourceId") == source_id.diaSourceId) 

740 ) 

741 else: 

742 table_name = context.schema.tableName(ApdbTables.DiaSource) 

743 update = ( 

744 Update(self._keyspace, table_name) 

745 .values(C("diaObjectId").update(diaObjectId)) 

746 .where(C("apdb_part") == apdb_part) 

747 .where(C("apdb_time_part") == time_part) 

748 .where(C("diaSourceId") == source_id.diaSourceId) 

749 ) 

750 statements.append(context.stmt_factory.with_params(update, prepare=True)) 

751 

752 if context.schema.replication_enabled: 

753 update_records.append( 

754 ApdbReassignDiaSourceToDiaObjectRecord( 

755 diaSourceId=source_id.diaSourceId, 

756 ra=source_id.ra, 

757 dec=source_id.dec, 

758 midpointMjdTai=source_id.midpointMjdTai, 

759 diaObjectId=diaObjectId, 

760 update_time_ns=current_time_ns, 

761 update_order=update_order, 

762 ) 

763 ) 

764 update_order += 1 

765 

766 with self._timer( 

767 "update_time", tags={"table": "DiaSource", "method": "reassignDiaSourcesToDiaObjects"} 

768 ) as timer: 

769 execute_concurrent(context.session, statements, execution_profile="write") 

770 timer.add_values(num_queries=len(statements)) 

771 

772 # Update nDiaSources in DiaObjectLast. We do not update the DiaObject

773 # table here because it may not even exist. PPDB updates DiaObject from

774 # update records.

775 if increment_nDiaSources or decrement_nDiaSources: 

776 table_name = context.schema.tableName(ApdbTables.DiaObjectLast) 

777 update = ( 

778 Update(self._keyspace, table_name) 

779 .values(C("nDiaSources").update(-1)) 

780 .where(C("apdb_part") == -1) 

781 .where(C("diaObjectId") == -1) 

782 ) 

783 statement = context.stmt_factory(update, prepare=True) 

784 statements = [] 

785 

786 # Calculate increments/decrements for all affected DiaObjects. 

787 increments: Counter = Counter() 

788 if increment_nDiaSources: 

789 increments.update(idMap.values()) 

790 if decrement_nDiaSources: 

791 increments.subtract(row.diaObjectId for row in found_sources) 

792 

793 for row in found_objects: 

794 if increments.get(row.diaObjectId): 

795 nDiaSources = row.nDiaSources + increments[row.diaObjectId] 

796 statements.append((statement, (nDiaSources, row.apdb_part, row.diaObjectId))) 

797 

798 # Also send updated values to replica. 

799 if context.schema.replication_enabled: 

800 update_records.append( 

801 ApdbUpdateNDiaSourcesRecord( 

802 diaObjectId=row.diaObjectId, 

803 ra=row.ra, 

804 dec=row.dec, 

805 nDiaSources=nDiaSources, 

806 update_time_ns=current_time_ns, 

807 update_order=update_order, 

808 ) 

809 ) 

810 update_order += 1 

811 

812 if statements: 

813 with self._timer( 

814 "update_time", tags={"table": table_name, "method": "reassignDiaSourcesToDiaObjects"} 

815 ) as timer: 

816 execute_concurrent(context.session, statements, execution_profile="write") 

817 timer.add_values(num_queries=len(statements)) 

818 

819 if update_records: 

820 replica_chunk = ReplicaChunk.make_replica_chunk(current_time, config.replica_chunk_seconds) 

821 self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True) 

822 

823 def setValidityEnd( 

824 self, objects: list[DiaObjectId], validityEnd: astropy.time.Time, raise_on_missing_id: bool = False 

825 ) -> int: 

826 # docstring is inherited from a base class 

827 if not objects: 

828 return 0 

829 

830 context = self._context 

831 config = context.config 

832 

833 pad_arcsec = 1.0 

834 partitioned_object_ids = self._group_dia_objects_by_partition( 

835 context.partitioner, objects, pad_arcsec 

836 ) 

837 

838 # Check that all objects exist. 

839 table_name = context.schema.tableName(ApdbTables.DiaObjectLast) 

840 statements: list[tuple] = [] 

841 for apdb_part, diaObjectIds in partitioned_object_ids.items(): 

842 query = Select(self._keyspace, table_name, ["apdb_part", "diaObjectId"]) 

843 query = query.where(C("apdb_part") == apdb_part) 

844 query = query.where(C("diaObjectId").in_(diaObjectIds)) 

845 statements.append(context.stmt_factory.with_params(query, prepare=False)) 

846 

847 with self._timer("select_time", tags={"table": table_name, "method": "setValidityEnd"}) as timer: 

848 records = cast( 

849 list[tuple[int, int]], 

850 select_concurrent( 

851 context.session, 

852 statements, 

853 "read_tuples", 

854 config.connection_config.read_concurrency, 

855 ), 

856 ) 

857 timer.add_values(row_count=len(objects)) 

858 

859 requested_ids = {obj.diaObjectId for obj in objects} 

860 found_ids = {rec[1] for rec in records} 

861 if extra_ids := (found_ids - requested_ids): 

862 raise RuntimeError(f"Consistency error - found duplicate records for object IDs: {extra_ids}") 

863 if raise_on_missing_id: 

864 if missing_ids := (requested_ids - found_ids): 

865 raise LookupError(f"Some object IDs are missing from DiaObjectLast table: {missing_ids}") 

866 

867 # Filter existing records. 

868 if len(objects) != len(found_ids): 

869 objects = [obj for obj in objects if obj.diaObjectId in found_ids] 

870 

871 if not objects: 

872 return 0 

873 

874 # Group by partitions again. 

875 grouped_object_ids: dict[int, list[int]] = defaultdict(list) 

876 for apdb_part, diaObjectId in records: 

877 grouped_object_ids[apdb_part].append(diaObjectId) 

878 

879 # Remove all matching rows from DiaObjectLast. 

880 statements = [] 

881 for apdb_part, diaObjectIds in grouped_object_ids.items(): 

882 delete = ( 

883 Delete(self._keyspace, table_name) 

884 .where(C("apdb_part") == apdb_part) 

885 .where(C("diaObjectId").in_(diaObjectIds)) 

886 ) 

887 statements.append(context.stmt_factory.with_params(delete)) 

888 

889 # Also remove from DiaObjectLastToPartition. 

890 reverse_table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition) 

891 delete = Delete(self._keyspace, reverse_table_name).where( 

892 C("diaObjectId").in_([rec[1] for rec in records]) 

893 ) 

894 statements.append(context.stmt_factory.with_params(delete)) 

895 

896 with self._timer("delete_time", tags={"table": table_name, "method": "setValidityEnd"}) as timer: 

897 execute_concurrent(context.session, statements, execution_profile="write") 

898 timer.add_values(row_count=len(records)) 

899 

900 # If replication is enabled then send all updates.

901 if context.schema.replication_enabled: 

902 current_time = self._current_time() 

903 current_time_ns = int(current_time.unix_tai * 1e9) 

904 replica_chunk = ReplicaChunk.make_replica_chunk(current_time, config.replica_chunk_seconds) 

905 

906 update_records = [ 

907 ApdbCloseDiaObjectValidityRecord( 

908 diaObjectId=obj.diaObjectId, 

909 ra=obj.ra, 

910 dec=obj.dec, 

911 update_time_ns=current_time_ns, 

912 update_order=index, 

913 validityEndMjdTai=float(validityEnd.tai.mjd), 

914 nDiaSources=None, 

915 ) 

916 for index, obj in enumerate(objects) 

917 ] 

918 

919 self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True) 

920 

921 return len(objects) 

922 

923 def resetDedup(self, dedup_time: astropy.time.Time | None = None) -> None: 

924 # docstring is inherited from a base class 

925 context = self._context 

926 

927 if not context.has_dedup_table: 

928 raise TypeError("DiaObjectDedup table does not exist in this APDB instance.") 

929 

930 if dedup_time is None: 

931 dedup_time = self._current_time() 

932 

933 validity_start_column = self._timestamp_column_name("validityStart") 

934 

935 # Find latest timestamp in deduplication table. 

936 table_name = context.schema.tableName(ExtraTables.DiaObjectDedup) 

937 query = Select(self._keyspace, table_name, [ColumnExpr(f'MAX("{validity_start_column}")')]) 

938 stmt = context.stmt_factory(query, prepare=False) 

939 

940 result = context.session.execute(stmt, execution_profile="read_tuples") 

941 max_value = result.one()[0] 

942 if self._schema.has_mjd_timestamps: 

943 max_validity_start = astropy.time.Time(max_value, format="mjd", scale="tai") 

944 else: 

945 max_validity_start = astropy.time.Time(max_value, format="datetime", scale="tai") 

946 

947 # If the latest validity start is not later than the dedup time we can TRUNCATE.

948 if dedup_time >= max_validity_start: 

949 query_str = f'TRUNCATE TABLE "{self._keyspace}"."{table_name}"' 

950 context.session.execute(query_str, execution_profile="write") 

951 else: 

952 dedup_time_value = self._timestamp_column_value(dedup_time) 

953 delete = Delete(self._keyspace, table_name).where(C(validity_start_column) < dedup_time_value) 

954 stmt, params = context.stmt_factory.with_params(delete) 

955 context.session.execute(stmt, params, execution_profile="write") 

956 

957 # Store dedup time. 

958 data = {"dedup_time_iso_tai": dedup_time.tai.to_value("iso")} 

959 data_json = json.dumps(data) 

960 context.metadata.set(context.metadataDedupKey, data_json, force=True) 

961 

962 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

963 # docstring is inherited from a base class 

964 context = self._context 

965 config = context.config 

966 

967 now = self._current_time() 

968 reassign_time_column = self._timestamp_column_name("ssObjectReassocTime") 

969 reassignTime = self._timestamp_column_value(now) 

970 

971 # To update a record we need to know its exact primary key (including 

972 # partition key) so we start by querying for diaSourceId to find the 

973 # primary keys. 

974 

975 table_name = context.schema.tableName(ExtraTables.DiaSourceToPartition) 

976 # split it into 1k IDs per query 

977 selects: list[tuple] = [] 

978 columns = ["diaSourceId", "apdb_part", "apdb_time_part", "apdb_replica_chunk"] 

979 query = Select(self._keyspace, table_name, columns) 

980 for ids in chunk_iterable(idMap.keys(), 1_000): 

981 full_query = query.where(C("diaSourceId").in_(ids)) 

982 selects.append(context.stmt_factory.with_params(full_query, prepare=False)) 

983 

984 # No need for a DataFrame here; read data as tuples.

985 result = cast( 

986 list[tuple[int, int, int, int | None]], 

987 select_concurrent( 

988 context.session, selects, "read_tuples", config.connection_config.read_concurrency 

989 ), 

990 ) 

991 

992 # Make mapping from source ID to its partition. 

993 id2partitions: dict[int, tuple[int, int]] = {} 

994 id2chunk_id: dict[int, int] = {} 

995 for row in result: 

996 id2partitions[row[0]] = row[1:3] 

997 if row[3] is not None: 

998 id2chunk_id[row[0]] = row[3] 

999 

1000 # make sure we know partitions for each ID 

1001 if set(id2partitions) != set(idMap): 

1002 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

1003 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

1004 

1005 # Reassign in standard tables 

1006 queries: list[tuple[cassandra.query.PreparedStatement, tuple]] = [] 

1007 for diaSourceId, ssObjectId in idMap.items(): 

1008 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

1009 if config.partitioning.time_partition_tables: 

1010 table_name = context.schema.tableName(ApdbTables.DiaSource, apdb_time_part) 

1011 update = ( 

1012 Update(self._keyspace, table_name) 

1013 .values( 

1014 C("ssObjectId").update(ssObjectId), 

1015 C("diaObjectId").update(None), 

1016 C(reassign_time_column).update(reassignTime), 

1017 ) 

1018 .where(C("apdb_part") == apdb_part) 

1019 .where(C("diaSourceId") == diaSourceId) 

1020 ) 

1021 else: 

1022 table_name = context.schema.tableName(ApdbTables.DiaSource) 

1023 update = ( 

1024 Update(self._keyspace, table_name) 

1025 .values( 

1026 C("ssObjectId").update(ssObjectId), 

1027 C("diaObjectId").update(None), 

1028 C(reassign_time_column).update(reassignTime), 

1029 ) 

1030 .where(C("apdb_part") == apdb_part) 

1031 .where(C("apdb_time_part") == apdb_time_part) 

1032 .where(C("diaSourceId") == diaSourceId) 

1033 ) 

1034 queries.append(context.stmt_factory.with_params(update, prepare=True)) 

1035 

1036 # TODO: (DM-50190) Replication for updated records is not implemented. 

1037 if id2chunk_id: 

1038 warnings.warn("Replication of reassigned DiaSource records is not implemented.", stacklevel=2) 

1039 

1040 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

1041 with self._timer("source_reassign_time") as timer: 

1042 execute_concurrent(context.session, queries, execution_profile="write") 

1043 timer.add_values(source_count=len(idMap)) 

1044 

1045 def countUnassociatedObjects(self) -> int: 

1046 # docstring is inherited from a base class 

1047 

1048 # It's too inefficient to implement it for Cassandra in current schema. 

1049 raise NotImplementedError() 

1050 

1051 @property 

1052 def schema(self) -> ApdbSchema: 

1053 # docstring is inherited from a base class 

1054 return self._schema 

1055 

1056 @property 

1057 def metadata(self) -> ApdbMetadata: 

1058 # docstring is inherited from a base class 

1059 context = self._context 

1060 return context.metadata 

1061 

1062 @property 

1063 def admin(self) -> ApdbCassandraAdmin: 

1064 # docstring is inherited from a base class 

1065 return ApdbCassandraAdmin(self) 

1066 

1067 def _getSources( 

1068 self, 

1069 region: sphgeom.Region, 

1070 object_ids: Iterable[int] | None, 

1071 mjd_start: float, 

1072 mjd_end: float, 

1073 table_name: ApdbTables, 

1074 ) -> pandas.DataFrame: 

1075 """Return catalog of DiaSource instances given set of DiaObject IDs. 

1076 

1077 Parameters 

1078 ---------- 

1079 region : `lsst.sphgeom.Region` 

1080 Spherical region. 

1081 object_ids : `~collections.abc.Iterable` [`int`] or `None`

1082 Collection of DiaObject IDs.

1083 mjd_start : `float` 

1084 Lower bound of time interval. 

1085 mjd_end : `float` 

1086 Upper bound of time interval. 

1087 table_name : `ApdbTables` 

1088 Name of the table. 

1089 

1090 Returns 

1091 ------- 

1092 catalog : `pandas.DataFrame`

1093 Catalog containing DiaSource records. An empty catalog is returned

1094 if ``object_ids`` is empty.

1095 """ 

1096 context = self._context 

1097 config = context.config 

1098 

1099 object_id_set: Set[int] = set() 

1100 if object_ids is not None: 

1101 object_id_set = set(object_ids) 

1102 if len(object_id_set) == 0: 

1103 return self._make_empty_catalog(table_name) 

1104 

1105 sp_where, num_sp_part = context.partitioner.spatial_where(region) 

1106 tables, temporal_where = context.partitioner.temporal_where( 

1107 table_name, mjd_start, mjd_end, partitons_range=context.time_partitions_range 

1108 ) 

1109 if not tables: 

1110 start = astropy.time.Time(mjd_start, format="mjd", scale="tai") 

1111 end = astropy.time.Time(mjd_end, format="mjd", scale="tai") 

1112 warnings.warn( 

1113 f"Query time range ({start.isot} - {end.isot}) does not overlap database time partitions." 

1114 ) 

1115 

1116 # We need to exclude extra partitioning columns from result. 

1117 column_names = context.schema.apdbColumnNames(table_name) 

1118 

1119 # Build all queries 

1120 statements: list[tuple] = [] 

1121 for table in tables: 

1122 query = Select(self._keyspace, table, column_names) 

1123 for clause in QExpr.combine(sp_where, temporal_where): 

1124 statements.append(context.stmt_factory.with_params(query.where(clause), prepare=True)) 

1125 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

1126 

1127 with self._timer("select_time", tags={"table": table_name.name, "method": "_getSources"}) as timer: 

1128 table_data_raw = cast( 

1129 ApdbCassandraTableData, 

1130 select_concurrent( 

1131 context.session, 

1132 statements, 

1133 "read_raw_multi", 

1134 config.connection_config.read_concurrency, 

1135 ), 

1136 ) 

1137 catalog = table_data_raw.to_pandas(context.schema._table_schema(table_name)) 

1138 timer.add_values( 

1139 row_count_from_db=len(catalog), num_sp_part=num_sp_part, num_queries=len(statements) 

1140 ) 

1141 

1142 # filter by given object IDs 

1143 if len(object_id_set) > 0: 

1144 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

1145 

1146 # precise filtering on midpointMjdTai 

1147 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start]) 

1148 

1149 timer.add_values(row_count=len(catalog)) 

1150 

1151 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

1152 return catalog 

1153 

1154 def _storeReplicaChunk(self, replica_chunk: ReplicaChunk) -> None: 

1155 context = self._context 

1156 config = context.config 

1157 

1158 # Cassandra timestamp uses milliseconds since epoch 

1159 timestamp = int(replica_chunk.last_update_time.unix_tai * 1000) 

1160 

1161 # everything goes into a single partition 

1162 partition = 0 

1163 

1164 table_name = context.schema.tableName(ExtraTables.ApdbReplicaChunks) 

1165 

1166 columns = ["partition", "apdb_replica_chunk", "last_update_time", "unique_id"] 

1167 values = [partition, replica_chunk.id, timestamp, replica_chunk.unique_id] 

1168 if context.has_chunk_sub_partitions: 

1169 columns.append("has_subchunks") 

1170 values.append(True) 

1171 

1172 query = Insert(self._keyspace, table_name, columns) 

1173 stmt = context.stmt_factory(query) 

1174 

1175 context.session.execute( 

1176 stmt, 

1177 values, 

1178 timeout=config.connection_config.write_timeout, 

1179 execution_profile="write", 

1180 ) 

1181 

1182 def _queryDiaObjectLastPartitions(self, ids: Iterable[int]) -> Mapping[int, int]: 

1183 """Return existing mapping of diaObjectId to its last partition.""" 

1184 context = self._context 

1185 config = context.config 

1186 

1187 table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition) 

1188 queries = [] 

1189 object_count = 0 

1190 for id_chunk in chunk_iterable(ids, 10_000): 

1191 id_chunk_list = tuple(id_chunk) 

1192 query = Select(self._keyspace, table_name, ("diaObjectId", "apdb_part")) 

1193 query = query.where(C("diaObjectId").in_(id_chunk_list)) 

1194 queries.append(context.stmt_factory.with_params(query, prepare=False)) 

1195 object_count += len(id_chunk_list) 

1196 

1197 with self._timer("query_object_last_partitions", tags={"table": table_name}) as timer: 

1198 data = cast( 

1199 ApdbTableData, 

1200 select_concurrent( 

1201 context.session, 

1202 queries, 

1203 "read_raw_multi", 

1204 config.connection_config.read_concurrency, 

1205 ), 

1206 ) 

1207 timer.add_values(object_count=object_count, row_count=len(data.rows())) 

1208 

1209 if data.column_names() != ["diaObjectId", "apdb_part"]: 

1210 raise RuntimeError(f"Unexpected column names in query result: {data.column_names()}") 

1211 

1212 return {row[0]: row[1] for row in data.rows()} 

1213 

1214 def _deleteMovingObjects(self, objs: pandas.DataFrame) -> None: 

1215 """Objects in DiaObjectsLast can move from one spatial partition to 

1216 another. For those objects inserting new version does not replace old 

1217 one, so we need to explicitly remove old versions before inserting new 

1218 ones. 

1219 """ 

1220 context = self._context 

1221 

1222 # Extract all object IDs. 

1223 new_partitions = dict(zip(objs["diaObjectId"], objs["apdb_part"])) 

1224 old_partitions = self._queryDiaObjectLastPartitions(objs["diaObjectId"]) 

1225 

1226 moved_oids: dict[int, tuple[int, int]] = {} 

1227 for oid, old_part in old_partitions.items(): 

1228 new_part = new_partitions.get(oid, old_part) 

1229 if new_part != old_part: 

1230 moved_oids[oid] = (old_part, new_part) 

1231 _LOG.debug("DiaObject IDs that moved to new partition: %s", moved_oids) 

1232 

1233 if moved_oids: 

1234 # Delete old records from DiaObjectLast. 

1235 table_name = context.schema.tableName(ApdbTables.DiaObjectLast) 

1236 query = Delete(self._keyspace, table_name) 

1237 query = query.where('apdb_part = {} AND "diaObjectId" = {}', (-1, -1)) 

1238 statement = context.stmt_factory(query, prepare=True) 

1239 queries = [] 

1240 for oid, (old_part, _) in moved_oids.items(): 

1241 queries.append((statement, (old_part, oid))) 

1242 with self._timer("delete_object_last", tags={"table": table_name}) as timer: 

1243 execute_concurrent(context.session, queries, execution_profile="write") 

1244 timer.add_values(row_count=len(moved_oids)) 

1245 

1246 # Add all new records to the map. 

1247 table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition) 

1248 insert = Insert(self._keyspace, table_name, ("diaObjectId", "apdb_part")) 

1249 statement = context.stmt_factory(insert, prepare=True) 

1250 

1251 queries = [] 

1252 for oid, new_part in new_partitions.items(): 

1253 queries.append((statement, (oid, new_part))) 

1254 

1255 with self._timer("update_object_last_partition", tags={"table": table_name}) as timer: 

1256 execute_concurrent(context.session, queries, execution_profile="write") 

1257 timer.add_values(row_count=len(queries)) 

1258 

1259 def _storeDiaObjects( 

1260 self, objs: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None 

1261 ) -> None: 

1262 """Store catalog of DiaObjects from current visit. 

1263 

1264 Parameters 

1265 ---------- 

1266 objs : `pandas.DataFrame` 

1267 Catalog with DiaObject records 

1268 visit_time : `astropy.time.Time` 

1269 Time of the current visit. 

1270 replica_chunk : `ReplicaChunk` or `None` 

1271 Replica chunk identifier if replication is configured. 

1272 """ 

1273 if len(objs) == 0: 

1274 _LOG.debug("No objects to write to database.") 

1275 return 

1276 

1277 context = self._context 

1278 config = context.config 

1279 

1280 self._deleteMovingObjects(objs) 

1281 

1282 validity_start_column = self._timestamp_column_name("validityStart") 

1283 timestamp = self._timestamp_column_value(visit_time) 

1284 

1285 # DiaObjectLast did not have this column in the past. 

1286 extra_columns: dict[str, Any] = {} 

1287 if context.schema.check_column(ApdbTables.DiaObjectLast, validity_start_column): 

1288 extra_columns[validity_start_column] = timestamp 

1289 

1290 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

1291 

1292 extra_columns[validity_start_column] = timestamp 

1293 visit_time_part = context.partitioner.time_partition(visit_time) 

1294 time_part: int | None = visit_time_part 

1295 if (time_partitions_range := context.time_partitions_range) is not None: 

1296 self._check_time_partitions([visit_time_part], time_partitions_range) 

1297 if not config.partitioning.time_partition_tables: 

1298 extra_columns["apdb_time_part"] = time_part 

1299 time_part = None 

1300 

1301 # Only store DiaObjects if not doing replication or explicitly

1302 # configured to always store them.

1303 if replica_chunk is None or not config.replica_skips_diaobjects: 

1304 self._storeObjectsPandas( 

1305 objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part 

1306 ) 

1307 

1308 if replica_chunk is not None: 

1309 extra_columns = {"apdb_replica_chunk": replica_chunk.id, validity_start_column: timestamp} 

1310 table = ExtraTables.DiaObjectChunks 

1311 if context.has_chunk_sub_partitions: 

1312 table = ExtraTables.DiaObjectChunks2 

1313 # Use a random number for the second part of the partitioning key so

1314 # that different clients can write to different partitions.

1315 # This makes it not exactly reproducible.

1316 extra_columns["apdb_replica_subchunk"] = random.randrange(config.replica_sub_chunk_count) 

1317 self._storeObjectsPandas(objs, table, extra_columns=extra_columns) 

1318 

1319 # Store copy of the records in dedup table. 

1320 if context.has_dedup_table: 

1321 table = ExtraTables.DiaObjectDedup 

1322 extra_columns = { 

1323 "dedup_part": random.randrange(config.partitioning.num_part_dedup), 

1324 validity_start_column: timestamp, 

1325 } 

1326 self._storeObjectsPandas(objs, table, extra_columns=extra_columns) 

1327 

1328 def _storeDiaSources( 

1329 self, 

1330 table_name: ApdbTables, 

1331 sources: pandas.DataFrame, 

1332 replica_chunk: ReplicaChunk | None, 

1333 ) -> int | None: 

1334 """Store catalog of DIASources or DIAForcedSources from current visit. 

1335 

1336 Parameters 

1337 ---------- 

1338 table_name : `ApdbTables` 

1339 Table where to store the data. 

1340 sources : `pandas.DataFrame` 

1341 Catalog containing DiaSource records 

1344 replica_chunk : `ReplicaChunk` or `None` 

1345 Replica chunk identifier if replication is configured. 

1346 

1347 Returns 

1348 ------- 

1349 subchunk : `int` or `None` 

1350 Subchunk number for resulting replica data, or `None` if replication

1351 is not enabled or subchunking is not enabled.

1352 """ 

1353 context = self._context 

1354 config = context.config 

1355 

1356 # Time partitioning has to be based on midpointMjdTai, not visit_time 

1357 # as visit_time is not really a visit time. 

1358 tp_sources = sources.copy(deep=False) 

1359 tp_sources["apdb_time_part"] = tp_sources["midpointMjdTai"].apply(context.partitioner.time_partition) 

1360 if (time_partitions_range := context.time_partitions_range) is not None: 

1361 self._check_time_partitions(tp_sources["apdb_time_part"], time_partitions_range) 

1362 extra_columns: dict[str, Any] = {} 

1363 if not config.partitioning.time_partition_tables: 

1364 self._storeObjectsPandas(tp_sources, table_name) 

1365 else: 

1366 # Group by time partition 

1367 partitions = set(tp_sources["apdb_time_part"]) 

1368 if len(partitions) == 1: 

1369 # Single partition - just save the whole thing. 

1370 time_part = partitions.pop() 

1371 self._storeObjectsPandas(sources, table_name, time_part=time_part) 

1372 else: 

1373 # group by time partition. 

1374 for time_part, sub_frame in tp_sources.groupby(by="apdb_time_part"): 

1375 sub_frame.drop(columns="apdb_time_part", inplace=True) 

1376 self._storeObjectsPandas(sub_frame, table_name, time_part=time_part) 

1377 

1378 subchunk: int | None = None 

1379 if replica_chunk is not None: 

1380 extra_columns = {"apdb_replica_chunk": replica_chunk.id} 

1381 if context.has_chunk_sub_partitions: 

1382 subchunk = random.randrange(config.replica_sub_chunk_count) 

1383 extra_columns["apdb_replica_subchunk"] = subchunk 

1384 if table_name is ApdbTables.DiaSource: 

1385 extra_table = ExtraTables.DiaSourceChunks2 

1386 else: 

1387 extra_table = ExtraTables.DiaForcedSourceChunks2 

1388 else: 

1389 if table_name is ApdbTables.DiaSource: 

1390 extra_table = ExtraTables.DiaSourceChunks 

1391 else: 

1392 extra_table = ExtraTables.DiaForcedSourceChunks 

1393 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns) 

1394 

1395 return subchunk 

1396 

1397 def _check_time_partitions( 

1398 self, partitions: Iterable[int], time_partitions_range: ApdbCassandraTimePartitionRange 

1399 ) -> None: 

1400 """Check that time partitons for new data actually exist. 

1401 

1402 Parameters 

1403 ---------- 

1404 partitions : `~collections.abc.Iterable` [`int`] 

1405 Time partitions for new data. 

1406 time_partitions_range : `ApdbCassandraTimePartitionRange` 

1407 Current time partition range.

1408 """ 

1409 partitions = set(partitions) 

1410 min_part = min(partitions) 

1411 max_part = max(partitions) 

1412 if min_part < time_partitions_range.start or max_part > time_partitions_range.end: 

1413 raise ValueError( 

1414 "Attempt to store data for time partitions that do not yet exist. " 

1415 f"Partitons for new records: {min_part}-{max_part}. " 

1416 f"Database partitons: {time_partitions_range.start}-{time_partitions_range.end}." 

1417 ) 

1418 # Warn when writing to the last partition.

1419 if max_part == time_partitions_range.end: 

1420 warnings.warn( 

1421 "Writing into the last temporal partition. Partition range needs to be extended soon.", 

1422 stacklevel=3, 

1423 ) 

1424 

1425 def _storeDiaSourcesPartitions( 

1426 self, 

1427 sources: pandas.DataFrame, 

1428 visit_time: astropy.time.Time, 

1429 replica_chunk: ReplicaChunk | None, 

1430 subchunk: int | None, 

1431 ) -> None: 

1432 """Store mapping of diaSourceId to its partitioning values. 

1433 

1434 Parameters 

1435 ---------- 

1436 sources : `pandas.DataFrame` 

1437 Catalog containing DiaSource records. 

1438 visit_time : `astropy.time.Time` 

1439 Time of the current visit. 

1440 replica_chunk : `ReplicaChunk` or `None` 

1441 Replication chunk, or `None` when replication is disabled. 

1442 subchunk : `int` or `None` 

1443 Replication sub-chunk, or `None` when replication is disabled or 

1444 sub-chunking is not used. 

1445 """ 

1446 context = self._context 

1447 

1448 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

1449 extra_columns = { 

1450 "apdb_time_part": context.partitioner.time_partition(visit_time), 

1451 "apdb_replica_chunk": replica_chunk.id if replica_chunk is not None else None, 

1452 } 

1453 if context.has_chunk_sub_partitions: 

1454 extra_columns["apdb_replica_subchunk"] = subchunk 

1455 

1456 self._storeObjectsPandas( 

1457 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

1458 ) 

1459 

1460 def _storeObjectsPandas( 

1461 self, 

1462 records: pandas.DataFrame, 

1463 table_name: ApdbTables | ExtraTables, 

1464 extra_columns: Mapping | None = None, 

1465 time_part: int | None = None, 

1466 ) -> None: 

1467 """Store generic objects. 

1468 

1469 Takes a Pandas catalog and stores its records in a table. 

1470 

1471 Parameters 

1472 ---------- 

1473 records : `pandas.DataFrame` 

1474 Catalog containing object records. 

1475 table_name : `ApdbTables` or `ExtraTables` 

1476 Name of the table as defined in APDB schema. 

1477 extra_columns : `dict`, optional 

1478 Mapping (column_name, column_value) which gives fixed values for 

1479 columns in each row, overrides values in ``records`` if matching 

1480 columns exist there. 

1481 time_part : `int`, optional 

1482 If not `None` then insert into a per-partition table. 

1483 

1484 Notes 

1485 ----- 

1486 If the Pandas catalog contains additional columns not defined in the 

1487 table schema, they are ignored. The catalog does not have to contain 

1488 all columns defined in a table, but partition and clustering keys must 

1489 be present in the catalog or ``extra_columns``. 

1490 """ 

1491 context = self._context 

1492 

1493 # use extra columns if specified 

1494 if extra_columns is None: 

1495 extra_columns = {} 

1496 extra_fields = list(extra_columns.keys()) 

1497 

1498 # Fields that will come from dataframe. 

1499 df_fields = [column for column in records.columns if column not in extra_fields] 

1500 

1501 column_map = context.schema.getColumnMap(table_name) 

1502 # list of columns (as in felis schema) 

1503 fields = [column_map[field].name for field in df_fields if field in column_map] 

1504 fields += extra_fields 

1505 

1506 # check that all partitioning and clustering columns are defined 

1507 partition_columns = context.schema.partitionColumns(table_name) 

1508 required_columns = partition_columns + context.schema.clusteringColumns(table_name) 

1509 missing_columns = [column for column in required_columns if column not in fields] 

1510 if missing_columns: 

1511 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

1512 

1513 batch_size = self._batch_size(table_name) 

1514 

1515 with self._timer("insert_build_time", tags={"table": table_name.name}): 

1516 # Multi-partition batches are problematic in general, so we want to 

1517 # group records in a batch by their partition key. 

1518 values_by_key: dict[tuple, list[list]] = defaultdict(list) 
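# Illustrative shape (not part of the original source): each key is the
# tuple of partition-column values for a row, e.g. (apdb_part,) or
# (apdb_part, apdb_time_part), and the mapped value is the list of
# per-row column-value lists sharing that partition key.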

1519 for rec in records.itertuples(index=False): 

1520 values = [] 

1521 partitioning_values: dict[str, Any] = {} 

1522 for field in df_fields: 

1523 if field not in column_map: 

1524 continue 

1525 value = getattr(rec, field) 

1526 if column_map[field].datatype is felis.datamodel.DataType.timestamp: 

1527 if isinstance(value, pandas.Timestamp): 

1528 value = value.to_pydatetime() 

1529 elif value is pandas.NaT: 

1530 value = None 

1531 else: 

1532 # Assume it's seconds since epoch; Cassandra 

1533 # timestamps are in milliseconds. 

1534 value = int(value * 1000) 
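# Illustrative conversion (not part of the original source):
# 1735689600.0 seconds since epoch -> 1735689600000 ms
# (2025-01-01T00:00:00Z).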

1535 value = literal(value) 

1536 values.append(UNSET_VALUE if value is None else value) 

1537 if field in partition_columns: 

1538 partitioning_values[field] = value 

1539 for field in extra_fields: 

1540 value = literal(extra_columns[field]) 

1541 values.append(UNSET_VALUE if value is None else value) 

1542 if field in partition_columns: 

1543 partitioning_values[field] = value 

1544 

1545 key = tuple(partitioning_values[field] for field in partition_columns) 

1546 values_by_key[key].append(values) 

1547 

1548 table = context.schema.tableName(table_name, time_part) 

1549 

1550 query = Insert(self._keyspace, table, fields) 

1551 statement = context.stmt_factory(query, prepare=True) 

1552 # Cassandra has a 64k limit on the number of statements in a batch; 

1553 # normally that is enough, but some tests generate too many forced sources. 

1554 queries = [] 

1555 for key_values in values_by_key.values(): 

1556 for values_chunk in chunk_iterable(key_values, batch_size): 

1557 batch = cassandra.query.BatchStatement() 

1558 for row_values in values_chunk: 

1559 batch.add(statement, row_values) 

1560 queries.append((batch, None)) 

1561 assert batch.routing_key is not None and batch.keyspace is not None 

1562 

1563 _LOG.debug("%s: will store %d records", context.schema.tableName(table_name), records.shape[0]) 

1564 with self._timer( 

1565 "insert_time", tags={"table": table_name.name, "method": "_storeObjectsPandas"} 

1566 ) as timer: 

1567 execute_concurrent(context.session, queries, execution_profile="write") 

1568 timer.add_values(row_count=len(records), num_batches=len(queries)) 

1569 

1570 def _storeUpdateRecords( 

1571 self, records: Iterable[ApdbUpdateRecord], chunk: ReplicaChunk, *, store_chunk: bool = False 

1572 ) -> None: 

1573 """Store ApdbUpdateRecords in the replica table for those records. 

1574 

1575 Parameters 

1576 ---------- 

1577 records : `~collections.abc.Iterable` [`ApdbUpdateRecord`] 

1578 Records to store. 

1579 chunk : `ReplicaChunk` 

1580 Replica chunk for these records. 

1581 store_chunk : `bool`, optional 

1582 If `True` then also store the replica chunk. 

1583 

1584 Raises 

1585 ------ 

1586 TypeError 

1587 Raised if replication is not enabled for this instance. 

1588 """ 

1589 context = self._context 

1590 config = context.config 

1591 

1592 if not context.schema.replication_enabled: 

1593 raise TypeError("Replication is not enabled for this APDB instance.") 

1594 

1595 if store_chunk: 

1596 self._storeReplicaChunk(chunk) 

1597 

1598 apdb_replica_chunk = chunk.id 

1599 # Do not use unique_id from ReplicaChunk as it could be reused in 

1600 # multiple calls to this method. 

1601 update_unique_id = uuid.uuid4() 

1602 

1603 rows = [] 

1604 for record in records: 

1605 rows.append( 

1606 [ 

1607 apdb_replica_chunk, 

1608 record.update_time_ns, 

1609 record.update_order, 

1610 update_unique_id, 

1611 record.to_json(), 

1612 ] 

1613 ) 

1614 columns = [ 

1615 "apdb_replica_chunk", 

1616 "update_time_ns", 

1617 "update_order", 

1618 "update_unique_id", 

1619 "update_payload", 

1620 ] 
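# Note (illustrative, not part of the original source): each entry in
# ``rows`` is ordered to match ``columns`` above; the optional
# sub-chunk column is appended to both below when sub-partitioning is
# enabled.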

1621 if context.has_chunk_sub_partitions: 

1622 subchunk = random.randrange(config.replica_sub_chunk_count) 
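# Illustrative example (not part of the original source): with
# replica_sub_chunk_count = 64 this picks a value in [0, 63]; all rows
# written by this call share that sub-chunk value.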

1623 for row in rows: 

1624 row.append(subchunk) 

1625 columns.append("apdb_replica_subchunk") 

1626 

1627 table_name = context.schema.tableName(ExtraTables.ApdbUpdateRecordChunks) 

1628 query = Insert(self._keyspace, table_name, columns) 

1629 stmt = context.stmt_factory(query) 

1630 queries = [(stmt, row) for row in rows] 

1631 

1632 with self._timer("store_update_record", tags={"table": table_name}) as timer: 

1633 execute_concurrent(context.session, queries, execution_profile="write") 

1634 timer.add_values(row_count=len(queries)) 

1635 

1636 def _add_apdb_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1637 """Calculate spatial partition for each record and add it to a 

1638 DataFrame. 

1639 

1640 Parameters 

1641 ---------- 

1642 df : `pandas.DataFrame` 

1643 DataFrame which has to contain ra/dec columns; the names of these 

1644 columns are defined by the configuration ``ra_dec_columns`` field. 

1645 

1646 Returns 

1647 ------- 

1648 df : `pandas.DataFrame` 

1649 DataFrame with ``apdb_part`` column which contains pixel index 

1650 for ra/dec coordinates. 

1651 

1652 Notes 

1653 ----- 

1654 This overrides any existing column in the DataFrame with the same name 

1655 (``apdb_part``). The original DataFrame is not changed; a copy with the 

1656 new column is returned. 

1657 """ 

1658 context = self._context 

1659 config = context.config 

1660 

1661 # Calculate pixelization index for every record. 

1662 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

1663 ra_col, dec_col = config.ra_dec_columns 

1664 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

1665 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1666 idx = context.partitioner.pixel(uv3d) 

1667 apdb_part[i] = idx 

1668 df = df.copy() 

1669 df["apdb_part"] = apdb_part 

1670 return df 

1671 

1672 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

1673 """Make an empty catalog for a table with a given name. 

1674 

1675 Parameters 

1676 ---------- 

1677 table_name : `ApdbTables` 

1678 Name of the table. 

1679 

1680 Returns 

1681 ------- 

1682 catalog : `pandas.DataFrame` 

1683 An empty catalog. 

1684 """ 

1685 table = self.schema.tableSchemas[table_name] 

1686 

1687 data = {columnDef.name: pandas.Series(dtype=columnDef.pandas_type) for columnDef in table.columns} 

1688 return pandas.DataFrame(data) 

1689 

1690 def _fix_input_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1691 """Update timestamp columns in input DataFrame to be naive datetime 

1692 type. 

1693 

1694 Clients may or may not generate timezone-aware timestamps; code in 

1695 this class assumes that timestamps are naive, so we convert them to 

1696 UTC and drop the timezone. 

1697 """ 

1698 # Find all columns with aware timestamps. 

1699 columns = [column for column, dtype in df.dtypes.items() if isinstance(dtype, pandas.DatetimeTZDtype)] 
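# Illustrative example (not part of the original source): a column with
# dtype "datetime64[ns, UTC]" becomes naive "datetime64[ns]" holding the
# same instants expressed in UTC.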

1700 for column in columns: 

1701 # tz_convert(None) will convert to UTC and drop timezone. 

1702 df[column] = df[column].dt.tz_convert(None) 

1703 return df 

1704 

1705 def _batch_size(self, table: ApdbTables | ExtraTables) -> int: 

1706 """Calculate batch size based on config parameters.""" 

1707 context = self._context 

1708 config = context.config 

1709 

1710 # Cassandra limit on number of statements in a batch is 64k. 

1711 batch_size = 65_535 

1712 if 0 < config.batch_statement_limit < batch_size: 

1713 batch_size = config.batch_statement_limit 

1714 if config.batch_size_limit > 0: 

1715 # The purpose of this limit is to try not to exceed the batch size 

1716 # threshold which is set on the server side. The Cassandra wire protocol 

1717 # for prepared queries (and batches) only sends column values, with 

1718 # an additional 4 bytes per value specifying its size. The value is 

1719 # not included for NULL or NOT_SET values, but the size field is always 

1720 # there. There is an additional small per-query overhead, which we 

1721 # ignore. 

1722 row_size = context.schema.table_row_size(table) 

1723 row_size += 4 * len(context.schema.getColumnMap(table)) 

1724 batch_size = min(batch_size, (config.batch_size_limit // row_size) + 1) 
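# Illustrative arithmetic (numbers are made up): with
# batch_size_limit = 262144 and an estimated row of 60 columns totalling
# ~500 bytes of values, row_size ~= 500 + 4 * 60 = 740, so the cap is
# 262144 // 740 + 1 = 355 rows per batch.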

1725 return batch_size 

1726 

1727 def _group_dia_objects_by_partition( 

1728 self, partitioner: Partitioner, objects: list[DiaObjectId], pad_arcsec: float 

1729 ) -> Mapping[int, list[int]]: 

1730 """Group DiaObjects by partition. 

1731 

1732 Parameters 

1733 ---------- 

1734 partitioner : `Partitioner` 

1735 Object which knows how to partition things. 

1736 objects : `list` [`DiaObjectId`] 

1737 Collection of objects to partition. 

1738 pad_arcsec : `float` 

1739 Additional padding around object position. 

1740 

1741 Returns 

1742 ------- 

1743 grouped_objects : `~collections.abc.Mapping` [`int`, `list` [`int`]] 

1744 Mapping of spatial partition ID to list of object IDs that it 

1745 contains. Some objects may belong to more than one partition. 

1746 """ 

1747 partitioned_object_ids: dict[int, list[int]] = defaultdict(list) 

1748 for obj_id in objects: 

1749 partitions = partitioner.pixelization.circle_pixels(obj_id.ra, obj_id.dec, pad_arcsec) 

1750 for pixel in partitions: 

1751 partitioned_object_ids[pixel].append(obj_id.diaObjectId) 

1752 return partitioned_object_ids 

1753 

1754 def _timestamp_column_name(self, column: str) -> str: 

1755 """Return column name before/after schema migration to MJD TAI.""" 

1756 return self._schema.timestamp_column_name(column) 

1757 

1758 def _timestamp_column_value(self, time: astropy.time.Time) -> float | int: 

1759 """Return column value before/after schema migration to MJD TAI.""" 

1760 if self._schema.has_mjd_timestamps: 

1761 return float(time.tai.mjd) 

1762 else: 

1763 return int(time.datetime.astimezone(tz=datetime.UTC).timestamp() * 1000) 

1764 

1765 def _get_diasource_data(self, source_ids: Iterable[DiaSourceId], *columns: str) -> list: 

1766 """Select records from DiaSource table by diaSourceId and return all 

1767 records as a list of named tuples. 

1768 """ 

1769 context = self._context 

1770 config = context.config 

1771 partitioner = context.partitioner 

1772 

1773 columns = ("diaSourceId",) + columns 

1774 

1775 # Allow some uncertainty for coordinates and time when calculating 

1776 # partitions. 

1777 statements: list[tuple] = [] 

1778 pad_arcsec = 1.0 

1779 pad_time_day = 10 / (24 * 3600) 
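# i.e. ten seconds expressed as a fraction of a day (~1.16e-4).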

1780 for source_id in source_ids: 

1781 center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(source_id.ra, source_id.dec)) 

1782 region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(pad_arcsec / 3600.0)) 

1783 spatial_where, _ = partitioner.spatial_where(region) 

1784 

1785 tables, temporal_where = partitioner.temporal_where( 

1786 ApdbTables.DiaSource, 

1787 source_id.midpointMjdTai - pad_time_day, 

1788 source_id.midpointMjdTai + pad_time_day, 

1789 partitons_range=context.time_partitions_range, 

1790 query_per_time_part=True, 

1791 ) 

1792 

1793 id_where = QExpr('"diaSourceId" = {}', (source_id.diaSourceId,)) 

1794 

1795 for table in tables: 

1796 query = Select(self._keyspace, table, columns) 

1797 for clause in QExpr.combine(spatial_where, temporal_where, extra=id_where): 

1798 statements.append(context.stmt_factory.with_params(query.where(clause), prepare=True)) 

1799 

1800 with self._timer( 

1801 "select_time", tags={"table": "DiaSource", "method": "_get_diasource_data"} 

1802 ) as timer: 

1803 result = cast( 

1804 list[tuple], 

1805 select_concurrent( 

1806 context.session, 

1807 statements, 

1808 "read_named_tuples", 

1809 config.connection_config.read_concurrency, 

1810 ), 

1811 ) 

1812 timer.add_values(row_count=len(result), num_queries=len(statements)) 

1813 

1814 return result 

1815 

1816 def _get_diaobject_data(self, object_ids: Iterable[DiaObjectId], *columns: str) -> list: 

1817 """Select records from DiaObjectLast table by diaObjectId and return 

1818 all records as a list of named tuples. 

1819 """ 

1820 context = self._context 

1821 config = context.config 

1822 partitioner = context.partitioner 

1823 

1824 table_name = context.schema.tableName(ApdbTables.DiaObjectLast) 

1825 columns = ("diaObjectId",) + columns 

1826 

1827 # Allow some uncertainty for coordinates when calculating partitions. 

1828 pad_arcsec = 1.0 

1829 ids_by_partition = defaultdict(list) 
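# Illustrative grouping (not part of the original source): an object
# whose padded circle overlaps two pixels contributes its ID to both
# partition buckets, so the same diaObjectId may be queried more than
# once.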

1830 for object_id in object_ids: 

1831 pixels = partitioner.pixelization.circle_pixels(object_id.ra, object_id.dec, pad_arcsec) 

1832 for pixel in pixels: 

1833 ids_by_partition[pixel].append(object_id.diaObjectId) 

1834 

1835 statements: list[tuple] = [] 

1836 for apdb_part, diaObjectIds in ids_by_partition.items(): 

1837 query = Select(self._keyspace, table_name, columns) 

1838 query = query.where(C("apdb_part") == apdb_part) 

1839 query = query.where(C("diaObjectId").in_(diaObjectIds)) 

1840 statements.append(context.stmt_factory.with_params(query, prepare=False)) 

1841 

1842 with self._timer( 

1843 "select_time", tags={"table": "DiaObjectLast", "method": "_get_diaobject_data"} 

1844 ) as timer: 

1845 result = cast( 

1846 list[tuple], 

1847 select_concurrent( 

1848 context.session, 

1849 statements, 

1850 "read_named_tuples", 

1851 config.connection_config.read_concurrency, 

1852 ), 

1853 ) 

1854 timer.add_values(row_count=len(result), num_queries=len(statements)) 

1855 

1856 return result