Coverage for python / lsst / dax / apdb / cassandra / apdbCassandra.py: 10%

621 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 08:49 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandra"] 

25 

26import datetime 

27import logging 

28import random 

29import uuid 

30import warnings 

31from collections import defaultdict 

32from collections.abc import Iterable, Iterator, Mapping, Set 

33from typing import TYPE_CHECKING, Any, cast 

34 

35import numpy as np 

36import pandas 

37 

38# If cassandra-driver is not there the module can still be imported 

39# but ApdbCassandra cannot be instantiated. 

40try: 

41 import cassandra 

42 import cassandra.query 

43 from cassandra.query import UNSET_VALUE 

44 

45 CASSANDRA_IMPORTED = True 

46except ImportError: 

47 CASSANDRA_IMPORTED = False 

48 

49import astropy.time 

50import felis.datamodel 

51 

52from lsst import sphgeom 

53from lsst.utils.iteration import chunk_iterable 

54 

55from ..apdb import Apdb, ApdbConfig 

56from ..apdbConfigFreezer import ApdbConfigFreezer 

57from ..apdbReplica import ApdbTableData, ReplicaChunk 

58from ..apdbSchema import ApdbSchema, ApdbTables 

59from ..monitor import MonAgent 

60from ..schema_model import Table 

61from ..timer import Timer 

62from ..versionTuple import IncompatibleVersionError, VersionTuple 

63from .apdbCassandraAdmin import ApdbCassandraAdmin 

64from .apdbCassandraReplica import ApdbCassandraReplica 

65from .apdbCassandraSchema import ApdbCassandraSchema, CreateTableOptions, ExtraTables 

66from .apdbMetadataCassandra import ApdbMetadataCassandra 

67from .cassandra_utils import ( 

68 execute_concurrent, 

69 literal, 

70 quote_id, 

71 select_concurrent, 

72) 

73from .config import ApdbCassandraConfig, ApdbCassandraConnectionConfig, ApdbCassandraTimePartitionRange 

74from .connectionContext import ConnectionContext, DbVersions 

75from .exceptions import CassandraMissingError 

76from .partitioner import Partitioner 

77from .sessionFactory import SessionContext, SessionFactory 

78 

79if TYPE_CHECKING: 

80 from ..apdbMetadata import ApdbMetadata 

81 from ..apdbUpdateRecord import ApdbUpdateRecord 

82 

83_LOG = logging.getLogger(__name__) 

84 

85_MON = MonAgent(__name__) 

86 

87VERSION = VersionTuple(1, 2, 1) 

88"""Version for the code controlling non-replication tables. This needs to be 

89updated following compatibility rules when schema produced by this code 

90changes. 

91""" 

92 

93 

94class ApdbCassandra(Apdb): 

95 """Implementation of APDB database with Apache Cassandra backend. 

96 

97 Parameters 

98 ---------- 

99 config : `ApdbCassandraConfig` 

100 Configuration object. 

101 """ 

102 

    def __init__(self, config: ApdbCassandraConfig):
        """Construct an APDB instance backed by Cassandra; the actual
        connection is established lazily on first use.
        """
        # Fail early if the optional cassandra-driver package is missing.
        if not CASSANDRA_IMPORTED:
            raise CassandraMissingError()

        self._config = config
        # Keyspace name appears in every generated query; keep a shortcut.
        self._keyspace = config.keyspace
        # Table definitions parsed from the (YAML) schema files.
        self._schema = ApdbSchema(config.schema_file, config.ss_schema_file)

        # Connection is deferred until the `_context` property is accessed.
        self._session_factory = SessionFactory(config)
        self._connection_context: ConnectionContext | None = None

113 

114 @property 

115 def _context(self) -> ConnectionContext: 

116 """Establish connection if not established and return context.""" 

117 if self._connection_context is None: 

118 session = self._session_factory.session() 

119 self._connection_context = ConnectionContext(session, self._config, self._schema.tableSchemas) 

120 

121 # Check version compatibility 

122 current_versions = DbVersions( 

123 schema_version=self._schema.schemaVersion(), 

124 code_version=self.apdbImplementationVersion(), 

125 replica_version=( 

126 ApdbCassandraReplica.apdbReplicaImplementationVersion() 

127 if self._connection_context.config.enable_replica 

128 else None 

129 ), 

130 ) 

131 _LOG.debug("Current versions: %s", current_versions) 

132 self._versionCheck(current_versions, self._connection_context.db_versions) 

133 

134 if _LOG.isEnabledFor(logging.DEBUG): 

135 _LOG.debug("ApdbCassandra Configuration: %s", self._connection_context.config.model_dump()) 

136 

137 return self._connection_context 

138 

139 def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer: 

140 """Create `Timer` instance given its name.""" 

141 return Timer(name, _MON, tags=tags) 

142 

143 def _versionCheck(self, current_versions: DbVersions, db_versions: DbVersions) -> None: 

144 """Check schema version compatibility.""" 

145 if not current_versions.schema_version.checkCompatibility(db_versions.schema_version): 

146 raise IncompatibleVersionError( 

147 f"Configured schema version {current_versions.schema_version} " 

148 f"is not compatible with database version {db_versions.schema_version}" 

149 ) 

150 if not current_versions.code_version.checkCompatibility(db_versions.code_version): 

151 raise IncompatibleVersionError( 

152 f"Current code version {current_versions.code_version} " 

153 f"is not compatible with database version {db_versions.code_version}" 

154 ) 

155 

156 # Check replica code version only if replica is enabled. 

157 match current_versions.replica_version, db_versions.replica_version: 

158 case None, None: 

159 pass 

160 case VersionTuple() as current, VersionTuple() as stored: 

161 if not current.checkCompatibility(stored): 

162 raise IncompatibleVersionError( 

163 f"Current replication code version {current} " 

164 f"is not compatible with database version {stored}" 

165 ) 

166 case _: 

167 raise IncompatibleVersionError( 

168 f"Current replication code version {current_versions.replica_version} " 

169 f"is not compatible with database version {db_versions.replica_version}" 

170 ) 

171 

172 @classmethod 

173 def apdbImplementationVersion(cls) -> VersionTuple: 

174 """Return version number for current APDB implementation. 

175 

176 Returns 

177 ------- 

178 version : `VersionTuple` 

179 Version of the code defined in implementation class. 

180 """ 

181 return VERSION 

182 

183 def getConfig(self) -> ApdbCassandraConfig: 

184 # docstring is inherited from a base class 

185 return self._context.config 

186 

187 def tableDef(self, table: ApdbTables) -> Table | None: 

188 # docstring is inherited from a base class 

189 return self._schema.tableSchemas.get(table) 

190 

191 @classmethod 

192 def init_database( 

193 cls, 

194 hosts: tuple[str, ...], 

195 keyspace: str, 

196 *, 

197 schema_file: str | None = None, 

198 ss_schema_file: str | None = None, 

199 read_sources_months: int | None = None, 

200 read_forced_sources_months: int | None = None, 

201 enable_replica: bool = False, 

202 replica_skips_diaobjects: bool = False, 

203 port: int | None = None, 

204 username: str | None = None, 

205 prefix: str | None = None, 

206 part_pixelization: str | None = None, 

207 part_pix_level: int | None = None, 

208 time_partition_tables: bool = True, 

209 time_partition_start: str | None = None, 

210 time_partition_end: str | None = None, 

211 read_consistency: str | None = None, 

212 write_consistency: str | None = None, 

213 read_timeout: int | None = None, 

214 write_timeout: int | None = None, 

215 ra_dec_columns: tuple[str, str] | None = None, 

216 replication_factor: int | None = None, 

217 drop: bool = False, 

218 table_options: CreateTableOptions | None = None, 

219 ) -> ApdbCassandraConfig: 

220 """Initialize new APDB instance and make configuration object for it. 

221 

222 Parameters 

223 ---------- 

224 hosts : `tuple` [`str`, ...] 

225 List of host names or IP addresses for Cassandra cluster. 

226 keyspace : `str` 

227 Name of the keyspace for APDB tables. 

228 schema_file : `str`, optional 

229 Location of (YAML) configuration file with APDB schema. If not 

230 specified then default location will be used. 

231 ss_schema_file : `str`, optional 

232 Location of (YAML) configuration file with SSO schema. If not 

233 specified then default location will be used. 

234 read_sources_months : `int`, optional 

235 Number of months of history to read from DiaSource. 

236 read_forced_sources_months : `int`, optional 

237 Number of months of history to read from DiaForcedSource. 

238 enable_replica : `bool`, optional 

239 If True, make additional tables used for replication to PPDB. 

240 replica_skips_diaobjects : `bool`, optional 

241 If `True` then do not fill regular ``DiaObject`` table when 

242 ``enable_replica`` is `True`. 

243 port : `int`, optional 

244 Port number to use for Cassandra connections. 

245 username : `str`, optional 

246 User name for Cassandra connections. 

247 prefix : `str`, optional 

248 Optional prefix for all table names. 

249 part_pixelization : `str`, optional 

250 Name of the MOC pixelization used for partitioning. 

251 part_pix_level : `int`, optional 

252 Pixelization level. 

253 time_partition_tables : `bool`, optional 

254 Create per-partition tables. 

255 time_partition_start : `str`, optional 

256 Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

257 format, in TAI. 

258 time_partition_end : `str`, optional 

259 Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

260 format, in TAI. 

261 read_consistency : `str`, optional 

262 Name of the consistency level for read operations. 

263 write_consistency : `str`, optional 

264 Name of the consistency level for write operations. 

265 read_timeout : `int`, optional 

266 Read timeout in seconds. 

267 write_timeout : `int`, optional 

268 Write timeout in seconds. 

269 ra_dec_columns : `tuple` [`str`, `str`], optional 

270 Names of ra/dec columns in DiaObject table. 

271 replication_factor : `int`, optional 

272 Replication factor used when creating new keyspace, if keyspace 

273 already exists its replication factor is not changed. 

274 drop : `bool`, optional 

275 If `True` then drop existing tables before re-creating the schema. 

276 table_options : `CreateTableOptions`, optional 

277 Options used when creating Cassandra tables. 

278 

279 Returns 

280 ------- 

281 config : `ApdbCassandraConfig` 

282 Resulting configuration object for a created APDB instance. 

283 """ 

284 # Some non-standard defaults for connection parameters, these can be 

285 # changed later in generated config. Check Cassandra driver 

286 # documentation for what these parameters do. These parameters are not 

287 # used during database initialization, but they will be saved with 

288 # generated config. 

289 connection_config = ApdbCassandraConnectionConfig( 

290 extra_parameters={ 

291 "idle_heartbeat_interval": 0, 

292 "idle_heartbeat_timeout": 30, 

293 "control_connection_timeout": 100, 

294 }, 

295 ) 

296 config = ApdbCassandraConfig( 

297 contact_points=hosts, 

298 keyspace=keyspace, 

299 enable_replica=enable_replica, 

300 replica_skips_diaobjects=replica_skips_diaobjects, 

301 connection_config=connection_config, 

302 ) 

303 config.partitioning.time_partition_tables = time_partition_tables 

304 if schema_file is not None: 

305 config.schema_file = schema_file 

306 if ss_schema_file is not None: 

307 config.ss_schema_file = ss_schema_file 

308 if read_sources_months is not None: 

309 config.read_sources_months = read_sources_months 

310 if read_forced_sources_months is not None: 

311 config.read_forced_sources_months = read_forced_sources_months 

312 if port is not None: 

313 config.connection_config.port = port 

314 if username is not None: 

315 config.connection_config.username = username 

316 if prefix is not None: 

317 config.prefix = prefix 

318 if part_pixelization is not None: 

319 config.partitioning.part_pixelization = part_pixelization 

320 if part_pix_level is not None: 

321 config.partitioning.part_pix_level = part_pix_level 

322 if time_partition_start is not None: 

323 config.partitioning.time_partition_start = time_partition_start 

324 if time_partition_end is not None: 

325 config.partitioning.time_partition_end = time_partition_end 

326 if read_consistency is not None: 

327 config.connection_config.read_consistency = read_consistency 

328 if write_consistency is not None: 

329 config.connection_config.write_consistency = write_consistency 

330 if read_timeout is not None: 

331 config.connection_config.read_timeout = read_timeout 

332 if write_timeout is not None: 

333 config.connection_config.write_timeout = write_timeout 

334 if ra_dec_columns is not None: 

335 config.ra_dec_columns = ra_dec_columns 

336 

337 cls._makeSchema(config, drop=drop, replication_factor=replication_factor, table_options=table_options) 

338 

339 return config 

340 

341 def get_replica(self) -> ApdbCassandraReplica: 

342 """Return `ApdbReplica` instance for this database.""" 

343 # Note that this instance has to stay alive while replica exists, so 

344 # we pass reference to self. 

345 return ApdbCassandraReplica(self) 

346 

    @classmethod
    def _makeSchema(
        cls,
        config: ApdbConfig,
        *,
        drop: bool = False,
        replication_factor: int | None = None,
        table_options: CreateTableOptions | None = None,
    ) -> None:
        # docstring is inherited from a base class

        # Only the Cassandra flavor of the config makes sense here.
        if not isinstance(config, ApdbCassandraConfig):
            raise TypeError(f"Unexpected type of configuration object: {type(config)}")

        simple_schema = ApdbSchema(config.schema_file, config.ss_schema_file)

        with SessionContext(config) as session:
            schema = ApdbCassandraSchema(
                session=session,
                keyspace=config.keyspace,
                table_schemas=simple_schema.tableSchemas,
                prefix=config.prefix,
                time_partition_tables=config.partitioning.time_partition_tables,
                enable_replica=config.enable_replica,
                replica_skips_diaobjects=config.replica_skips_diaobjects,
            )

            # Ask schema to create all tables.
            part_range_config: ApdbCassandraTimePartitionRange | None = None
            if config.partitioning.time_partition_tables:
                partitioner = Partitioner(config)
                # Partition boundaries are configured as ISO time strings in
                # TAI scale; convert them to time-partition indices.
                time_partition_start = astropy.time.Time(
                    config.partitioning.time_partition_start, format="isot", scale="tai"
                )
                time_partition_end = astropy.time.Time(
                    config.partitioning.time_partition_end, format="isot", scale="tai"
                )
                part_range_config = ApdbCassandraTimePartitionRange(
                    start=partitioner.time_partition(time_partition_start),
                    end=partitioner.time_partition(time_partition_end),
                )
                schema.makeSchema(
                    drop=drop,
                    part_range=part_range_config,
                    replication_factor=replication_factor,
                    table_options=table_options,
                )
            else:
                schema.makeSchema(
                    drop=drop, replication_factor=replication_factor, table_options=table_options
                )

            meta_table_name = ApdbTables.metadata.table_name(config.prefix)
            metadata = ApdbMetadataCassandra(
                session, meta_table_name, config.keyspace, "read_tuples", "write"
            )

            # Fill version numbers, overrides if they existed before.
            metadata.set(
                ConnectionContext.metadataSchemaVersionKey, str(simple_schema.schemaVersion()), force=True
            )
            metadata.set(
                ConnectionContext.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True
            )

            if config.enable_replica:
                # Only store replica code version if replica is enabled.
                metadata.set(
                    ConnectionContext.metadataReplicaVersionKey,
                    str(ApdbCassandraReplica.apdbReplicaImplementationVersion()),
                    force=True,
                )

            # Store frozen part of a configuration in metadata, so that later
            # connections read back the same frozen parameters.
            freezer = ApdbConfigFreezer[ApdbCassandraConfig](ConnectionContext.frozen_parameters)
            metadata.set(ConnectionContext.metadataConfigKey, freezer.to_json(config), force=True)

            # Store time partition range (only set for per-partition tables).
            if part_range_config:
                part_range_config.save_to_meta(metadata)

427 

428 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

429 # docstring is inherited from a base class 

430 context = self._context 

431 config = context.config 

432 

433 sp_where, num_sp_part = context.partitioner.spatial_where(region, for_prepare=True) 

434 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

435 

436 # We need to exclude extra partitioning columns from result. 

437 column_names = context.schema.apdbColumnNames(ApdbTables.DiaObjectLast) 

438 what = ",".join(quote_id(column) for column in column_names) 

439 

440 table_name = context.schema.tableName(ApdbTables.DiaObjectLast) 

441 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"' 

442 statements: list[tuple] = [] 

443 for where, params in sp_where: 

444 full_query = f"{query} WHERE {where}" 

445 if params: 

446 statement = context.preparer.prepare(full_query) 

447 else: 

448 # If there are no params then it is likely that query has a 

449 # bunch of literals rendered already, no point trying to 

450 # prepare it because it's not reusable. 

451 statement = cassandra.query.SimpleStatement(full_query) 

452 statements.append((statement, params)) 

453 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

454 

455 with _MON.context_tags({"table": "DiaObject"}): 

456 _MON.add_record( 

457 "select_query_stats", values={"num_sp_part": num_sp_part, "num_queries": len(statements)} 

458 ) 

459 with self._timer("select_time") as timer: 

460 objects = cast( 

461 pandas.DataFrame, 

462 select_concurrent( 

463 context.session, 

464 statements, 

465 "read_pandas_multi", 

466 config.connection_config.read_concurrency, 

467 ), 

468 ) 

469 timer.add_values(row_count=len(objects)) 

470 

471 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

472 return objects 

473 

474 def getDiaSources( 

475 self, 

476 region: sphgeom.Region, 

477 object_ids: Iterable[int] | None, 

478 visit_time: astropy.time.Time, 

479 start_time: astropy.time.Time | None = None, 

480 ) -> pandas.DataFrame | None: 

481 # docstring is inherited from a base class 

482 context = self._context 

483 config = context.config 

484 

485 months = config.read_sources_months 

486 if start_time is None and months == 0: 

487 return None 

488 

489 mjd_end = float(visit_time.tai.mjd) 

490 if start_time is None: 

491 mjd_start = mjd_end - months * 30 

492 else: 

493 mjd_start = float(start_time.tai.mjd) 

494 

495 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

496 

497 def getDiaForcedSources( 

498 self, 

499 region: sphgeom.Region, 

500 object_ids: Iterable[int] | None, 

501 visit_time: astropy.time.Time, 

502 start_time: astropy.time.Time | None = None, 

503 ) -> pandas.DataFrame | None: 

504 # docstring is inherited from a base class 

505 context = self._context 

506 config = context.config 

507 

508 months = config.read_forced_sources_months 

509 if start_time is None and months == 0: 

510 return None 

511 

512 mjd_end = float(visit_time.tai.mjd) 

513 if start_time is None: 

514 mjd_start = mjd_end - months * 30 

515 else: 

516 mjd_start = float(start_time.tai.mjd) 

517 

518 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

519 

    def containsVisitDetector(
        self,
        visit: int,
        detector: int,
        region: sphgeom.Region,
        visit_time: astropy.time.Time,
    ) -> bool:
        # docstring is inherited from a base class
        context = self._context
        config = context.config

        # If ApdbVisitDetector table exists just check it.
        if context.has_visit_detector_table:
            table_name = context.schema.tableName(ExtraTables.ApdbVisitDetector)
            query = (
                f'SELECT count(*) FROM "{self._keyspace}"."{table_name}" WHERE visit = %s AND detector = %s'
            )
            with self._timer("contains_visit_detector_time"):
                result = context.session.execute(query, (visit, detector))
            return bool(result.one()[0])

        # Fallback: scan source tables for any row with this visit/detector.
        # The order of checks corresponds to order in store(), on potential
        # store failure earlier tables have higher probability containing
        # stored records. With per-partition tables there will be many tables
        # in the list, but it is unlikely that we'll use that setup in
        # production.
        sp_where, _ = context.partitioner.spatial_where(region, use_ranges=True, for_prepare=True)
        visit_detector_where = ("visit = ? AND detector = ?", (visit, detector))

        # Sources are partitioned on their midPointMjdTai. To avoid precision
        # issues add some fuzziness (one hour on each side) to visit time.
        mjd_start = float(visit_time.tai.mjd) - 1.0 / 24
        mjd_end = float(visit_time.tai.mjd) + 1.0 / 24

        statements: list[tuple] = []
        for table_type in ApdbTables.DiaSource, ApdbTables.DiaForcedSource:
            tables, temporal_where = context.partitioner.temporal_where(
                table_type, mjd_start, mjd_end, query_per_time_part=True, for_prepare=True
            )
            for table in tables:
                prefix = f'SELECT apdb_part FROM "{self._keyspace}"."{table}"'
                # Needs ALLOW FILTERING as there is no PK constraint.
                suffix = "PER PARTITION LIMIT 1 LIMIT 1 ALLOW FILTERING"
                statements += list(
                    self._combine_where(prefix, sp_where, temporal_where, visit_detector_where, suffix)
                )

        with self._timer("contains_visit_detector_time"):
            result = cast(
                list[tuple[int] | None],
                select_concurrent(
                    context.session,
                    statements,
                    "read_tuples",
                    config.connection_config.read_concurrency,
                ),
            )
        # Any returned row means at least one matching source exists.
        return bool(result)

578 

    def store(
        self,
        visit_time: astropy.time.Time,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class
        # NOTE: the ordering of the writes below is intentional and must be
        # preserved; containsVisitDetector() relies on it for crash recovery.
        context = self._context
        config = context.config

        if context.has_visit_detector_table:
            # Store visit/detector in a special table, this has to be done
            # before all other writes so if there is a failure at any point
            # later we still have a record for attempted write.
            visit_detector: set[tuple[int, int]] = set()
            for df in sources, forced_sources:
                if df is not None and not df.empty:
                    df = df[["visit", "detector"]]
                    for visit, detector in df.itertuples(index=False):
                        visit_detector.add((visit, detector))

            if visit_detector:
                # Typically there is only one entry, do not bother with
                # concurrency.
                table_name = context.schema.tableName(ExtraTables.ApdbVisitDetector)
                query = f'INSERT INTO "{self._keyspace}"."{table_name}" (visit, detector) VALUES (%s, %s)'
                for item in visit_detector:
                    context.session.execute(query, item, execution_profile="write")

        # Fix timestamp columns in the input catalogs (see
        # _fix_input_timestamps for details).
        objects = self._fix_input_timestamps(objects)
        if sources is not None:
            sources = self._fix_input_timestamps(sources)
        if forced_sources is not None:
            forced_sources = self._fix_input_timestamps(forced_sources)

        # With replication enabled every store() is associated with a replica
        # chunk derived from the visit time; the chunk row is written first.
        replica_chunk: ReplicaChunk | None = None
        if context.schema.replication_enabled:
            replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, config.replica_chunk_seconds)
            self._storeReplicaChunk(replica_chunk)

        # fill region partition column for DiaObjects
        objects = self._add_apdb_part(objects)
        self._storeDiaObjects(objects, visit_time, replica_chunk)

        if sources is not None and len(sources) > 0:
            # copy apdb_part column from DiaObjects to DiaSources
            sources = self._add_apdb_part(sources)
            subchunk = self._storeDiaSources(ApdbTables.DiaSource, sources, replica_chunk)
            self._storeDiaSourcesPartitions(sources, visit_time, replica_chunk, subchunk)

        if forced_sources is not None and len(forced_sources) > 0:
            forced_sources = self._add_apdb_part(forced_sources)
            self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, replica_chunk)

633 

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class
        context = self._context
        config = context.config

        # Pick the reassignment-time column and value format matching the
        # schema flavor: MJD float vs. milliseconds-since-epoch timestamp.
        if self._schema.has_mjd_timestamps:
            reassign_time_column = "ssObjectReassocTimeMjdTai"
            reassignTime = float(astropy.time.Time.now().tai.mjd)
        else:
            reassign_time_column = "ssObjectReassocTime"
            # Current time as milliseconds since epoch.
            reassignTime = int(datetime.datetime.now(tz=datetime.UTC).timestamp() * 1000)

        # To update a record we need to know its exact primary key (including
        # partition key) so we start by querying for diaSourceId to find the
        # primary keys.

        table_name = context.schema.tableName(ExtraTables.DiaSourceToPartition)
        # split it into 1k IDs per query
        selects: list[tuple] = []
        for ids in chunk_iterable(idMap.keys(), 1_000):
            ids_str = ",".join(str(item) for item in ids)
            selects.append(
                (
                    (
                        'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "apdb_replica_chunk" '
                        f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})'
                    ),
                    {},
                )
            )

        # No need for DataFrame here, read data as tuples.
        result = cast(
            list[tuple[int, int, int, int | None]],
            select_concurrent(
                context.session, selects, "read_tuples", config.connection_config.read_concurrency
            ),
        )

        # Make mapping from source ID to its partition.
        id2partitions: dict[int, tuple[int, int]] = {}
        id2chunk_id: dict[int, int] = {}
        for row in result:
            id2partitions[row[0]] = row[1:3]
            if row[3] is not None:
                id2chunk_id[row[0]] = row[3]

        # make sure we know partitions for each ID
        if set(id2partitions) != set(idMap):
            missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
            raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

        # Reassign in standard tables. The UPDATE shape depends on whether
        # time partitioning is done via separate tables (partition index is
        # part of the table name) or via an apdb_time_part column.
        queries: list[tuple[cassandra.query.PreparedStatement, tuple]] = []
        for diaSourceId, ssObjectId in idMap.items():
            apdb_part, apdb_time_part = id2partitions[diaSourceId]
            values: tuple
            if config.partitioning.time_partition_tables:
                table_name = context.schema.tableName(ApdbTables.DiaSource, apdb_time_part)
                query = (
                    f'UPDATE "{self._keyspace}"."{table_name}"'
                    f' SET "ssObjectId" = ?, "diaObjectId" = NULL, "{reassign_time_column}" = ?'
                    ' WHERE "apdb_part" = ? AND "diaSourceId" = ?'
                )
                values = (ssObjectId, reassignTime, apdb_part, diaSourceId)
            else:
                table_name = context.schema.tableName(ApdbTables.DiaSource)
                query = (
                    f'UPDATE "{self._keyspace}"."{table_name}"'
                    f' SET "ssObjectId" = ?, "diaObjectId" = NULL, "{reassign_time_column}" = ?'
                    ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?'
                )
                values = (ssObjectId, reassignTime, apdb_part, apdb_time_part, diaSourceId)
            queries.append((context.preparer.prepare(query), values))

        # TODO: (DM-50190) Replication for updated records is not implemented.
        if id2chunk_id:
            warnings.warn("Replication of reassigned DiaSource records is not implemented.", stacklevel=2)

        _LOG.debug("%s: will update %d records", table_name, len(idMap))
        with self._timer("source_reassign_time") as timer:
            execute_concurrent(context.session, queries, execution_profile="write")
            timer.add_values(source_count=len(idMap))

718 

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        # Intentionally a no-op for this backend.
        pass

722 

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # It's too inefficient to implement it for Cassandra in current
        # schema; callers must be prepared for this exception.
        raise NotImplementedError()

728 

729 @property 

730 def metadata(self) -> ApdbMetadata: 

731 # docstring is inherited from a base class 

732 context = self._context 

733 return context.metadata 

734 

735 @property 

736 def admin(self) -> ApdbCassandraAdmin: 

737 # docstring is inherited from a base class 

738 return ApdbCassandraAdmin(self) 

739 

    def _getSources(
        self,
        region: sphgeom.Region,
        object_ids: Iterable[int] | None,
        mjd_start: float,
        mjd_end: float,
        table_name: ApdbTables,
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Spherical region.
        object_ids : `~collections.abc.Iterable` [`int`] or `None`
            Collection of DiaObject IDs; `None` means no per-object filtering.
        mjd_start : `float`
            Lower bound of time interval.
        mjd_end : `float`
            Upper bound of time interval.
        table_name : `ApdbTables`
            Name of the table.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaSource records. Empty catalog is returned if
            ``object_ids`` is empty.
        """
        context = self._context
        config = context.config

        object_id_set: Set[int] = set()
        if object_ids is not None:
            object_id_set = set(object_ids)
            # Empty (but not None) ID collection means an empty result.
            if len(object_id_set) == 0:
                return self._make_empty_catalog(table_name)

        sp_where, num_sp_part = context.partitioner.spatial_where(region, for_prepare=True)
        tables, temporal_where = context.partitioner.temporal_where(
            table_name, mjd_start, mjd_end, for_prepare=True, partitons_range=context.time_partitions_range
        )
        if not tables:
            # No table covers the requested interval; warn (with times in
            # ISO format for readability) and fall through to an empty query.
            start = astropy.time.Time(mjd_start, format="mjd", scale="tai")
            end = astropy.time.Time(mjd_end, format="mjd", scale="tai")
            warnings.warn(
                f"Query time range ({start.isot} - {end.isot}) does not overlap database time partitions."
            )

        # We need to exclude extra partitioning columns from result.
        column_names = context.schema.apdbColumnNames(table_name)
        what = ",".join(quote_id(column) for column in column_names)

        # Build all queries
        statements: list[tuple] = []
        for table in tables:
            prefix = f'SELECT {what} from "{self._keyspace}"."{table}"'
            statements += list(self._combine_where(prefix, sp_where, temporal_where))
        _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))

        with _MON.context_tags({"table": table_name.name}):
            _MON.add_record(
                "select_query_stats", values={"num_sp_part": num_sp_part, "num_queries": len(statements)}
            )
            with self._timer("select_time") as timer:
                catalog = cast(
                    pandas.DataFrame,
                    select_concurrent(
                        context.session,
                        statements,
                        "read_pandas_multi",
                        config.connection_config.read_concurrency,
                    ),
                )
                timer.add_values(row_count_from_db=len(catalog))

                # filter by given object IDs
                if len(object_id_set) > 0:
                    catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])

                # precise filtering on midpointMjdTai (partition granularity
                # is coarser than the requested interval)
                catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start])

                timer.add_values(row_count=len(catalog))

        _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
        return catalog

827 

828 def _storeReplicaChunk(self, replica_chunk: ReplicaChunk) -> None: 

829 context = self._context 

830 config = context.config 

831 

832 # Cassandra timestamp uses milliseconds since epoch 

833 timestamp = int(replica_chunk.last_update_time.unix_tai * 1000) 

834 

835 # everything goes into a single partition 

836 partition = 0 

837 

838 table_name = context.schema.tableName(ExtraTables.ApdbReplicaChunks) 

839 

840 columns = ["partition", "apdb_replica_chunk", "last_update_time", "unique_id"] 

841 values = [partition, replica_chunk.id, timestamp, replica_chunk.unique_id] 

842 if context.has_chunk_sub_partitions: 

843 columns.append("has_subchunks") 

844 values.append(True) 

845 

846 column_list = ", ".join(columns) 

847 placeholders = ",".join(["%s"] * len(columns)) 

848 query = f'INSERT INTO "{self._keyspace}"."{table_name}" ({column_list}) VALUES ({placeholders})' 

849 

850 context.session.execute( 

851 query, 

852 values, 

853 timeout=config.connection_config.write_timeout, 

854 execution_profile="write", 

855 ) 

856 

857 def _queryDiaObjectLastPartitions(self, ids: Iterable[int]) -> Mapping[int, int]: 

858 """Return existing mapping of diaObjectId to its last partition.""" 

859 context = self._context 

860 config = context.config 

861 

862 table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition) 

863 queries = [] 

864 object_count = 0 

865 for id_chunk in chunk_iterable(ids, 10_000): 

866 id_chunk_list = list(id_chunk) 

867 query = ( 

868 f'SELECT "diaObjectId", apdb_part FROM "{self._keyspace}"."{table_name}" ' 

869 f'WHERE "diaObjectId" in ({",".join(str(oid) for oid in id_chunk_list)})' 

870 ) 

871 queries.append((query, ())) 

872 object_count += len(id_chunk_list) 

873 

874 with self._timer("query_object_last_partitions") as timer: 

875 data = cast( 

876 ApdbTableData, 

877 select_concurrent( 

878 context.session, 

879 queries, 

880 "read_raw_multi", 

881 config.connection_config.read_concurrency, 

882 ), 

883 ) 

884 timer.add_values(object_count=object_count, row_count=len(data.rows())) 

885 

886 if data.column_names() != ["diaObjectId", "apdb_part"]: 

887 raise RuntimeError(f"Unexpected column names in query result: {data.column_names()}") 

888 

889 return {row[0]: row[1] for row in data.rows()} 

890 

    def _deleteMovingObjects(self, objs: pandas.DataFrame) -> None:
        """Objects in DiaObjectsLast can move from one spatial partition to
        another. For those objects inserting new version does not replace old
        one, so we need to explicitly remove old versions before inserting new
        ones.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog of DiaObjects about to be stored; must contain
            ``diaObjectId`` and ``apdb_part`` columns.
        """
        context = self._context

        # Extract all object IDs.
        new_partitions = dict(zip(objs["diaObjectId"], objs["apdb_part"]))
        old_partitions = self._queryDiaObjectLastPartitions(objs["diaObjectId"])

        # Collect IDs whose partition changed: oid -> (old_part, new_part).
        # IDs absent from the new catalog default to their old partition and
        # are therefore not treated as moved.
        moved_oids: dict[int, tuple[int, int]] = {}
        for oid, old_part in old_partitions.items():
            new_part = new_partitions.get(oid, old_part)
            if new_part != old_part:
                moved_oids[oid] = (old_part, new_part)
        _LOG.debug("DiaObject IDs that moved to new partition: %s", moved_oids)

        if moved_oids:
            # Delete old records from DiaObjectLast.
            table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
            query = f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE apdb_part = ? AND "diaObjectId" = ?'
            statement = context.preparer.prepare(query)
            queries = []
            for oid, (old_part, _) in moved_oids.items():
                queries.append((statement, (old_part, oid)))
            with self._timer("delete_object_last") as timer:
                execute_concurrent(context.session, queries, execution_profile="write")
                timer.add_values(row_count=len(moved_oids))

        # Add all new records to the map.
        table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition)
        query = f'INSERT INTO "{self._keyspace}"."{table_name}" ("diaObjectId", apdb_part) VALUES (?,?)'
        statement = context.preparer.prepare(query)

        queries = []
        for oid, new_part in new_partitions.items():
            queries.append((statement, (oid, new_part)))

        with self._timer("update_object_last_partition") as timer:
            execute_concurrent(context.session, queries, execution_profile="write")
            timer.add_values(row_count=len(queries))

934 

    def _storeDiaObjects(
        self, objs: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records
        visit_time : `astropy.time.Time`
            Time of the current visit.
        replica_chunk : `ReplicaChunk` or `None`
            Replica chunk identifier if replication is configured.
        """
        if len(objs) == 0:
            _LOG.debug("No objects to write to database.")
            return

        context = self._context
        config = context.config

        # Remove stale DiaObjectLast rows for objects that changed spatial
        # partition, otherwise the new insert would not replace them.
        if context.has_dia_object_last_to_partition:
            self._deleteMovingObjects(objs)

        # Validity start is stored either as MJD TAI float or as naive
        # datetime depending on the schema version.
        timestamp: float | datetime.datetime
        if self._schema.has_mjd_timestamps:
            validity_start_column = "validityStartMjdTai"
            timestamp = float(visit_time.tai.mjd)
        else:
            validity_start_column = "validityStart"
            timestamp = visit_time.datetime

        # DiaObjectLast did not have this column in the past.
        extra_columns: dict[str, Any] = {}
        if context.schema.check_column(ApdbTables.DiaObjectLast, validity_start_column):
            extra_columns[validity_start_column] = timestamp

        self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)

        # The regular DiaObject table always gets the validity start column.
        extra_columns[validity_start_column] = timestamp
        visit_time_part = context.partitioner.time_partition(visit_time)
        time_part: int | None = visit_time_part
        if (time_partitions_range := context.time_partitions_range) is not None:
            self._check_time_partitions([visit_time_part], time_partitions_range)
        if not config.partitioning.time_partition_tables:
            # Time partition is a regular column rather than a table suffix.
            extra_columns["apdb_time_part"] = time_part
            time_part = None

        # Only store DiaObjects if not doing replication or explicitly
        # configured to always store them.
        if replica_chunk is None or not config.replica_skips_diaobjects:
            self._storeObjectsPandas(
                objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part
            )

        if replica_chunk is not None:
            extra_columns = {"apdb_replica_chunk": replica_chunk.id, validity_start_column: timestamp}
            table = ExtraTables.DiaObjectChunks
            if context.has_chunk_sub_partitions:
                table = ExtraTables.DiaObjectChunks2
                # Use a random number for a second part of partitioning key so
                # that different clients could write to different partitions.
                # This makes it not exactly reproducible.
                extra_columns["apdb_replica_subchunk"] = random.randrange(config.replica_sub_chunk_count)
            self._storeObjectsPandas(objs, table, extra_columns=extra_columns)

1000 

    def _storeDiaSources(
        self,
        table_name: ApdbTables,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
    ) -> int | None:
        """Store catalog of DIASources or DIAForcedSources from current visit.

        Parameters
        ----------
        table_name : `ApdbTables`
            Table where to store the data.
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        replica_chunk : `ReplicaChunk` or `None`
            Replica chunk identifier if replication is configured.

        Returns
        -------
        subchunk : `int` or `None`
            Subchunk number for resulting replica data, `None` if replication
            is not enabled or subchunking is not enabled.
        """
        context = self._context
        config = context.config

        # Time partitioning has to be based on midpointMjdTai, not visit_time
        # as visit_time is not really a visit time.
        tp_sources = sources.copy(deep=False)
        tp_sources["apdb_time_part"] = tp_sources["midpointMjdTai"].apply(context.partitioner.time_partition)
        if (time_partitions_range := context.time_partitions_range) is not None:
            self._check_time_partitions(tp_sources["apdb_time_part"], time_partitions_range)
        extra_columns: dict[str, Any] = {}
        if not config.partitioning.time_partition_tables:
            # Partition value is stored as a regular column.
            self._storeObjectsPandas(tp_sources, table_name)
        else:
            # Group by time partition
            partitions = set(tp_sources["apdb_time_part"])
            if len(partitions) == 1:
                # Single partition - just save the whole thing.
                time_part = partitions.pop()
                self._storeObjectsPandas(sources, table_name, time_part=time_part)
            else:
                # group by time partition.
                for time_part, sub_frame in tp_sources.groupby(by="apdb_time_part"):
                    sub_frame.drop(columns="apdb_time_part", inplace=True)
                    self._storeObjectsPandas(sub_frame, table_name, time_part=time_part)

        subchunk: int | None = None
        if replica_chunk is not None:
            extra_columns = {"apdb_replica_chunk": replica_chunk.id}
            if context.has_chunk_sub_partitions:
                subchunk = random.randrange(config.replica_sub_chunk_count)
                extra_columns["apdb_replica_subchunk"] = subchunk
                if table_name is ApdbTables.DiaSource:
                    extra_table = ExtraTables.DiaSourceChunks2
                else:
                    extra_table = ExtraTables.DiaForcedSourceChunks2
            else:
                if table_name is ApdbTables.DiaSource:
                    extra_table = ExtraTables.DiaSourceChunks
                else:
                    extra_table = ExtraTables.DiaForcedSourceChunks
            self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns)

        return subchunk

1069 

1070 def _check_time_partitions( 

1071 self, partitions: Iterable[int], time_partitions_range: ApdbCassandraTimePartitionRange 

1072 ) -> None: 

1073 """Check that time partitons for new data actually exist. 

1074 

1075 Parameters 

1076 ---------- 

1077 partitions : `~collections.abc.Iterable` [`int`] 

1078 Time partitions for new data. 

1079 time_partitions_range : `ApdbCassandraTimePartitionRange` 

1080 Currrent time partition range. 

1081 """ 

1082 partitions = set(partitions) 

1083 min_part = min(partitions) 

1084 max_part = max(partitions) 

1085 if min_part < time_partitions_range.start or max_part > time_partitions_range.end: 

1086 raise ValueError( 

1087 "Attempt to store data for time partitions that do not yet exist. " 

1088 f"Partitons for new records: {min_part}-{max_part}. " 

1089 f"Database partitons: {time_partitions_range.start}-{time_partitions_range.end}." 

1090 ) 

1091 # Make a noise when writing to the last partition. 

1092 if max_part == time_partitions_range.end: 

1093 warnings.warn( 

1094 "Writing into the last temporal partition. Partition range needs to be extended soon.", 

1095 stacklevel=3, 

1096 ) 

1097 

1098 def _storeDiaSourcesPartitions( 

1099 self, 

1100 sources: pandas.DataFrame, 

1101 visit_time: astropy.time.Time, 

1102 replica_chunk: ReplicaChunk | None, 

1103 subchunk: int | None, 

1104 ) -> None: 

1105 """Store mapping of diaSourceId to its partitioning values. 

1106 

1107 Parameters 

1108 ---------- 

1109 sources : `pandas.DataFrame` 

1110 Catalog containing DiaSource records 

1111 visit_time : `astropy.time.Time` 

1112 Time of the current visit. 

1113 replica_chunk : `ReplicaChunk` or `None` 

1114 Replication chunk, or `None` when replication is disabled. 

1115 subchunk : `int` or `None` 

1116 Replication sub-chunk, or `None` when replication is disabled or 

1117 sub-chunking is not used. 

1118 """ 

1119 context = self._context 

1120 

1121 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

1122 extra_columns = { 

1123 "apdb_time_part": context.partitioner.time_partition(visit_time), 

1124 "apdb_replica_chunk": replica_chunk.id if replica_chunk is not None else None, 

1125 } 

1126 if context.has_chunk_sub_partitions: 

1127 extra_columns["apdb_replica_subchunk"] = subchunk 

1128 

1129 self._storeObjectsPandas( 

1130 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

1131 ) 

1132 

    def _storeObjectsPandas(
        self,
        records: pandas.DataFrame,
        table_name: ApdbTables | ExtraTables,
        extra_columns: Mapping | None = None,
        time_part: int | None = None,
    ) -> None:
        """Store generic objects.

        Takes Pandas catalog and stores a bunch of records in a table.

        Parameters
        ----------
        records : `pandas.DataFrame`
            Catalog containing object records
        table_name : `ApdbTables`
            Name of the table as defined in APDB schema.
        extra_columns : `dict`, optional
            Mapping (column_name, column_value) which gives fixed values for
            columns in each row, overrides values in ``records`` if matching
            columns exist there.
        time_part : `int`, optional
            If not `None` then insert into a per-partition table.

        Notes
        -----
        If Pandas catalog contains additional columns not defined in table
        schema they are ignored. Catalog does not have to contain all columns
        defined in a table, but partition and clustering keys must be present
        in a catalog or ``extra_columns``.
        """
        context = self._context

        # use extra columns if specified
        if extra_columns is None:
            extra_columns = {}
        extra_fields = list(extra_columns.keys())

        # Fields that will come from dataframe.
        df_fields = [column for column in records.columns if column not in extra_fields]

        column_map = context.schema.getColumnMap(table_name)
        # list of columns (as in felis schema)
        fields = [column_map[field].name for field in df_fields if field in column_map]
        fields += extra_fields

        # check that all partitioning and clustering columns are defined
        partition_columns = context.schema.partitionColumns(table_name)
        required_columns = partition_columns + context.schema.clusteringColumns(table_name)
        missing_columns = [column for column in required_columns if column not in fields]
        if missing_columns:
            raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")

        qfields = [quote_id(field) for field in fields]
        qfields_str = ",".join(qfields)

        batch_size = self._batch_size(table_name)

        with self._timer("insert_build_time", tags={"table": table_name.name}):
            # Multi-partition batches are problematic in general, so we want to
            # group records in a batch by their partition key.
            values_by_key: dict[tuple, list[list]] = defaultdict(list)
            for rec in records.itertuples(index=False):
                values = []
                partitioning_values: dict[str, Any] = {}
                for field in df_fields:
                    if field not in column_map:
                        # Column not in schema - ignore it.
                        continue
                    value = getattr(rec, field)
                    if column_map[field].datatype is felis.datamodel.DataType.timestamp:
                        if isinstance(value, pandas.Timestamp):
                            value = value.to_pydatetime()
                        elif value is pandas.NaT:
                            value = None
                        else:
                            # Assume it's seconds since epoch, Cassandra
                            # datetime is in milliseconds
                            value = int(value * 1000)
                    value = literal(value)
                    # None values are sent as UNSET to avoid writing
                    # tombstones.
                    values.append(UNSET_VALUE if value is None else value)
                    if field in partition_columns:
                        partitioning_values[field] = value
                for field in extra_fields:
                    value = literal(extra_columns[field])
                    values.append(UNSET_VALUE if value is None else value)
                    if field in partition_columns:
                        partitioning_values[field] = value

                key = tuple(partitioning_values[field] for field in partition_columns)
                values_by_key[key].append(values)

            table = context.schema.tableName(table_name, time_part)

            holders = ",".join(["?"] * len(qfields))
            query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})'
            statement = context.preparer.prepare(query)
            # Cassandra has 64k limit on batch size, normally that should be
            # enough but some tests generate too many forced sources.
            queries = []
            for key_values in values_by_key.values():
                for values_chunk in chunk_iterable(key_values, batch_size):
                    batch = cassandra.query.BatchStatement()
                    for row_values in values_chunk:
                        batch.add(statement, row_values)
                    queries.append((batch, None))
                    # Sanity check that the batch carries routing information
                    # (it comes from the prepared statement's bound values).
                    assert batch.routing_key is not None and batch.keyspace is not None

        _LOG.debug("%s: will store %d records", context.schema.tableName(table_name), records.shape[0])
        with self._timer("insert_time", tags={"table": table_name.name}) as timer:
            execute_concurrent(context.session, queries, execution_profile="write")
            timer.add_values(row_count=len(records), num_batches=len(queries))

1244 

    def _storeUpdateRecords(
        self, records: Iterable[ApdbUpdateRecord], chunk: ReplicaChunk, *, store_chunk: bool = False
    ) -> None:
        """Store ApdbUpdateRecords in the replica table for those records.

        Parameters
        ----------
        records : `list` [`ApdbUpdateRecord`]
            Records to store.
        chunk : `ReplicaChunk`
            Replica chunk for these records.
        store_chunk : `bool`
            If True then also store replica chunk.

        Raises
        ------
        TypeError
            Raised if replication is not enabled for this instance.
        """
        context = self._context
        config = context.config

        if not context.schema.replication_enabled:
            raise TypeError("Replication is not enabled for this APDB instance.")

        if store_chunk:
            self._storeReplicaChunk(chunk)

        apdb_replica_chunk = chunk.id
        # Do not use unique_id from ReplicaChunk as it could be reused in
        # multiple calls to this method.
        update_unique_id = uuid.uuid4()

        rows = []
        for record in records:
            rows.append(
                [
                    apdb_replica_chunk,
                    record.update_time_ns,
                    record.update_order,
                    update_unique_id,
                    record.to_json(),
                ]
            )
        columns = [
            "apdb_replica_chunk",
            "update_time_ns",
            "update_order",
            "update_unique_id",
            "update_payload",
        ]
        if context.has_chunk_sub_partitions:
            # All rows stored by this call share one random sub-chunk.
            subchunk = random.randrange(config.replica_sub_chunk_count)
            for row in rows:
                row.append(subchunk)
            columns.append("apdb_replica_subchunk")

        table_name = context.schema.tableName(ExtraTables.ApdbUpdateRecordChunks)
        placeholders = ", ".join(["%s"] * len(columns))
        columns_str = ", ".join(columns)
        query = f'INSERT INTO "{self._keyspace}"."{table_name}" ({columns_str}) VALUES ({placeholders})'
        queries = [(query, row) for row in rows]

        with self._timer("store_update_record") as timer:
            execute_concurrent(context.session, queries, execution_profile="write")
            timer.add_values(row_count=len(queries))

1311 

1312 def _add_apdb_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1313 """Calculate spatial partition for each record and add it to a 

1314 DataFrame. 

1315 

1316 Parameters 

1317 ---------- 

1318 df : `pandas.DataFrame` 

1319 DataFrame which has to contain ra/dec columns, names of these 

1320 columns are defined by configuration ``ra_dec_columns`` field. 

1321 

1322 Returns 

1323 ------- 

1324 df : `pandas.DataFrame` 

1325 DataFrame with ``apdb_part`` column which contains pixel index 

1326 for ra/dec coordinates. 

1327 

1328 Notes 

1329 ----- 

1330 This overrides any existing column in a DataFrame with the same name 

1331 (``apdb_part``). Original DataFrame is not changed, copy of a DataFrame 

1332 is returned. 

1333 """ 

1334 context = self._context 

1335 config = context.config 

1336 

1337 # Calculate pixelization index for every record. 

1338 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

1339 ra_col, dec_col = config.ra_dec_columns 

1340 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

1341 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1342 idx = context.partitioner.pixel(uv3d) 

1343 apdb_part[i] = idx 

1344 df = df.copy() 

1345 df["apdb_part"] = apdb_part 

1346 return df 

1347 

1348 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

1349 """Make an empty catalog for a table with a given name. 

1350 

1351 Parameters 

1352 ---------- 

1353 table_name : `ApdbTables` 

1354 Name of the table. 

1355 

1356 Returns 

1357 ------- 

1358 catalog : `pandas.DataFrame` 

1359 An empty catalog. 

1360 """ 

1361 table = self._schema.tableSchemas[table_name] 

1362 

1363 data = { 

1364 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype)) 

1365 for columnDef in table.columns 

1366 } 

1367 return pandas.DataFrame(data) 

1368 

    def _combine_where(
        self,
        prefix: str,
        where1: list[tuple[str, tuple]],
        where2: list[tuple[str, tuple]],
        where3: tuple[str, tuple] | None = None,
        suffix: str | None = None,
    ) -> Iterator[tuple[cassandra.query.Statement, tuple]]:
        """Make cartesian product of two parts of WHERE clause into a series
        of statements to execute.

        Parameters
        ----------
        prefix : `str`
            Initial statement prefix that comes before WHERE clause, e.g.
            "SELECT * from Table"
        where1 : `list` [`tuple`]
            First set of (expression, parameters) pairs; an empty list means
            no constraint from this set.
        where2 : `list` [`tuple`]
            Second set of (expression, parameters) pairs; combined with every
            element of ``where1``.
        where3 : `tuple`, optional
            Single extra (expression, parameters) pair added to every
            combination.
        suffix : `str`, optional
            Text appended to each query after the WHERE clause.

        Yields
        ------
        statement : `cassandra.query.Statement`
            Statement to execute; prepared when it has bound parameters,
            otherwise a simple statement.
        params : `tuple`
            Bound parameters for the statement (may be empty).
        """
        context = self._context

        # If lists are empty use special sentinels.
        if not where1:
            where1 = [("", ())]
        if not where2:
            where2 = [("", ())]

        for expr1, params1 in where1:
            for expr2, params2 in where2:
                full_query = prefix
                wheres = []
                params = params1 + params2
                if expr1:
                    wheres.append(expr1)
                if expr2:
                    wheres.append(expr2)
                if where3:
                    wheres.append(where3[0])
                    params += where3[1]
                if wheres:
                    full_query += " WHERE " + " AND ".join(wheres)
                if suffix:
                    full_query += " " + suffix
                if params:
                    statement = context.preparer.prepare(full_query)
                else:
                    # If there are no params then it is likely that query
                    # has a bunch of literals rendered already, no point
                    # trying to prepare it.
                    statement = cassandra.query.SimpleStatement(full_query)
                yield (statement, params)

1418 

1419 def _fix_input_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1420 """Update timestamp columns in input DataFrame to be naive datetime 

1421 type. 

1422 

1423 Clients may or may not generate aware timestamps, code in this class 

1424 assumes that timestamps are naive, so we convert them to UTC and 

1425 drop timezone. 

1426 """ 

1427 # Find all columns with aware timestamps. 

1428 columns = [column for column, dtype in df.dtypes.items() if isinstance(dtype, pandas.DatetimeTZDtype)] 

1429 for column in columns: 

1430 # tz_convert(None) will convert to UTC and drop timezone. 

1431 df[column] = df[column].dt.tz_convert(None) 

1432 return df 

1433 

1434 def _batch_size(self, table: ApdbTables | ExtraTables) -> int: 

1435 """Calculate batch size based on config parameters.""" 

1436 context = self._context 

1437 config = context.config 

1438 

1439 # Cassandra limit on number of statements in a batch is 64k. 

1440 batch_size = 65_535 

1441 if 0 < config.batch_statement_limit < batch_size: 

1442 batch_size = config.batch_statement_limit 

1443 if config.batch_size_limit > 0: 

1444 # The purpose of this limit is to try not to exceed batch size 

1445 # threshold which is set on server side. Cassandra wire protocol 

1446 # for prepared queries (and batches) only sends column values with 

1447 # with an additional 4 bytes per value specifying size. Value is 

1448 # not included for NULL or NOT_SET values, but the size is always 

1449 # there. There is additional small per-query overhead, which we 

1450 # ignore. 

1451 row_size = context.schema.table_row_size(table) 

1452 row_size += 4 * len(context.schema.getColumnMap(table)) 

1453 batch_size = min(batch_size, (config.batch_size_limit // row_size) + 1) 

1454 return batch_size