Coverage for python/lsst/dax/apdb/cassandra/apdbCassandra.py: 18%

599 statements  

coverage.py v7.5.1, created at 2024-05-11 03:30 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26import dataclasses 

27import json 

28import logging 

29from collections.abc import Iterable, Iterator, Mapping, Set 

30from typing import TYPE_CHECKING, Any, cast 

31 

32import numpy as np 

33import pandas 

34 

35# If cassandra-driver is not there the module can still be imported 

36# but ApdbCassandra cannot be instantiated. 

37try: 

38 import cassandra 

39 import cassandra.query 

40 from cassandra.auth import AuthProvider, PlainTextAuthProvider 

41 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, Session 

42 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

43 

44 CASSANDRA_IMPORTED = True 

45except ImportError: 

46 CASSANDRA_IMPORTED = False 

47 

48import astropy.time 

49import felis.datamodel 

50from lsst import sphgeom 

51from lsst.pex.config import ChoiceField, Field, ListField 

52from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError 

53from lsst.utils.iteration import chunk_iterable 

54 

55from ..apdb import Apdb, ApdbConfig 

56from ..apdbConfigFreezer import ApdbConfigFreezer 

57from ..apdbReplica import ReplicaChunk 

58from ..apdbSchema import ApdbTables 

59from ..monitor import MonAgent 

60from ..pixelization import Pixelization 

61from ..schema_model import Table 

62from ..timer import Timer 

63from ..versionTuple import IncompatibleVersionError, VersionTuple 

64from .apdbCassandraReplica import ApdbCassandraReplica 

65from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables 

66from .apdbMetadataCassandra import ApdbMetadataCassandra 

67from .cassandra_utils import ( 

68 PreparedStatementCache, 

69 literal, 

70 pandas_dataframe_factory, 

71 quote_id, 

72 raw_data_factory, 

73 select_concurrent, 

74) 

75 

76if TYPE_CHECKING: 

77 from ..apdbMetadata import ApdbMetadata 

78 

79_LOG = logging.getLogger(__name__) 

80 

81_MON = MonAgent(__name__) 

82 

83VERSION = VersionTuple(0, 1, 0) 

84"""Version for the code controlling non-replication tables. This needs to be 

85updated following compatibility rules when schema produced by this code 

86changes. 

87""" 

88 

89# Copied from daf_butler. 

90DB_AUTH_ENVVAR = "LSST_DB_AUTH" 

91"""Default name of the environmental variable that will be used to locate DB 

92credentials configuration file. """ 

93 

94DB_AUTH_PATH = "~/.lsst/db-auth.yaml" 

95"""Default path at which it is expected that DB credentials are found.""" 

96 

97 

98class CassandraMissingError(Exception): 

99 def __init__(self) -> None: 

100 super().__init__("cassandra-driver module cannot be imported") 

101 

102 

103class ApdbCassandraConfig(ApdbConfig): 

104 """Configuration class for Cassandra-based APDB implementation.""" 

105 

106 contact_points = ListField[str]( 

107 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"] 

108 ) 

109 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[]) 

110 port = Field[int](doc="Port number to connect to.", default=9042) 

111 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb") 

112 username = Field[str]( 

113 doc=f"Cassandra user name, if empty then {DB_AUTH_PATH} has to provide it with password.", 

114 default="", 

115 ) 

116 read_consistency = Field[str]( 

117 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM" 

118 ) 

119 write_consistency = Field[str]( 

120 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM" 

121 ) 

122 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0) 

123 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=60.0) 

124 remove_timeout = Field[float](doc="Timeout in seconds for remove operations.", default=600.0) 

125 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500) 

126 protocol_version = Field[int]( 

127 doc="Cassandra protocol version to use, default is V4", 

128 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0, 

129 ) 

130 dia_object_columns = ListField[str]( 

131 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[] 

132 ) 

133 prefix = Field[str](doc="Prefix to add to table names", default="") 

134 part_pixelization = ChoiceField[str]( 

135 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

136 doc="Pixelization used for partitioning index.", 

137 default="mq3c", 

138 ) 

139 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10) 

140 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64) 

141 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table") 

142 timer = Field[bool](doc="If True then print/log timing information", default=False) 

143 time_partition_tables = Field[bool]( 

144 doc="Use per-partition tables for sources instead of partitioning by time", default=False 

145 ) 

146 time_partition_days = Field[int]( 

147 doc=( 

148 "Time partitioning granularity in days, this value must not be changed after database is " 

149 "initialized" 

150 ), 

151 default=30, 

152 ) 

153 time_partition_start = Field[str]( 

154 doc=( 

155 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

156 "This is used only when time_partition_tables is True." 

157 ), 

158 default="2018-12-01T00:00:00", 

159 ) 

160 time_partition_end = Field[str]( 

161 doc=( 

162 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

163 "This is used only when time_partition_tables is True." 

164 ), 

165 default="2030-01-01T00:00:00", 

166 ) 

167 query_per_time_part = Field[bool]( 

168 default=False, 

169 doc=( 

170 "If True then build separate query for each time partition, otherwise build one single query. " 

171 "This is only used when time_partition_tables is False in schema config." 

172 ), 

173 ) 

174 query_per_spatial_part = Field[bool]( 

175 default=False, 

176 doc="If True then build one query per spatial partition, otherwise build single query.", 

177 ) 

178 use_insert_id_skips_diaobjects = Field[bool]( 

179 default=False, 

180 doc=( 

181 "If True then do not store DiaObjects when use_insert_id is True " 

182 "(DiaObjectsChunks has the same data)." 

183 ), 

184 ) 

185 

186 
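A minimal sketch of building this configuration programmatically; the field names come from `ApdbCassandraConfig` above, while the host names and keyspace value are illustrative assumptions.

# Hypothetical values; only the field names are taken from ApdbCassandraConfig above.
config = ApdbCassandraConfig()
config.contact_points = ["cassandra-host-1", "cassandra-host-2"]
config.keyspace = "apdb"
config.part_pixelization = "mq3c"
config.part_pix_level = 10
config.time_partition_tables = True
config.validate()  # may require additional fields depending on ApdbConfig defaults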

187@dataclasses.dataclass 

188class _FrozenApdbCassandraConfig: 

189 """Part of the configuration that is saved in metadata table and read back. 

190 

191 The attributes are a subset of attributes in `ApdbCassandraConfig` class. 

192 

193 Parameters 

194 ---------- 

195 config : `ApdbCassandraConfig` 

196 Configuration used to copy initial values of attributes. 

197 """ 

198 

199 use_insert_id: bool 

200 part_pixelization: str 

201 part_pix_level: int 

202 ra_dec_columns: list[str] 

203 time_partition_tables: bool 

204 time_partition_days: int 

205 use_insert_id_skips_diaobjects: bool 

206 

207 def __init__(self, config: ApdbCassandraConfig): 

208 self.use_insert_id = config.use_insert_id 

209 self.part_pixelization = config.part_pixelization 

210 self.part_pix_level = config.part_pix_level 

211 self.ra_dec_columns = list(config.ra_dec_columns) 

212 self.time_partition_tables = config.time_partition_tables 

213 self.time_partition_days = config.time_partition_days 

214 self.use_insert_id_skips_diaobjects = config.use_insert_id_skips_diaobjects 

215 

216 def to_json(self) -> str: 

217 """Convert this instance to JSON representation.""" 

218 return json.dumps(dataclasses.asdict(self)) 

219 

220 def update(self, json_str: str) -> None: 

221 """Update attribute values from a JSON string. 

222 

223 Parameters 

224 ---------- 

225 json_str : str 

226 String containing JSON representation of configuration. 

227 """ 

228 data = json.loads(json_str) 

229 if not isinstance(data, dict): 

230 raise TypeError(f"JSON string must be convertible to object: {json_str!r}") 

231 allowed_names = {field.name for field in dataclasses.fields(self)} 

232 for key, value in data.items(): 

233 if key not in allowed_names: 

234 raise ValueError(f"JSON object contains unknown key: {key}") 

235 setattr(self, key, value) 

236 
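A short sketch of the JSON round trip implemented by ``to_json`` and ``update`` above, assuming ``config`` is a populated `ApdbCassandraConfig` instance.

# Assumes `config` is a populated ApdbCassandraConfig.
frozen = _FrozenApdbCassandraConfig(config)
json_str = frozen.to_json()   # e.g. '{"use_insert_id": false, "part_pixelization": "mq3c", ...}'
frozen.update(json_str)       # the round trip restores the same attribute values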

237 

238if CASSANDRA_IMPORTED: 238 ↛ 240: line 238 didn't jump to line 240, because the condition on line 238 was never true

239 

240 class _AddressTranslator(AddressTranslator): 

241 """Translate internal IP address to external. 

242 

243 Only used for docker-based setups; not a viable long-term solution. 

244 """ 

245 

246 def __init__(self, public_ips: list[str], private_ips: list[str]): 

247 self._map = dict((k, v) for k, v in zip(private_ips, public_ips)) 

248 

249 def translate(self, private_ip: str) -> str: 

250 return self._map.get(private_ip, private_ip) 

251 

252 

253class ApdbCassandra(Apdb): 

254 """Implementation of APDB database on to of Apache Cassandra. 

255 

256 The implementation is configured via standard ``pex_config`` mechanism 

257 using `ApdbCassandraConfig` configuration class. For examples of 

258 different configurations check the config/ folder. 

259 

260 Parameters 

261 ---------- 

262 config : `ApdbCassandraConfig` 

263 Configuration object. 

264 """ 

265 

266 metadataSchemaVersionKey = "version:schema" 

267 """Name of the metadata key to store schema version number.""" 

268 

269 metadataCodeVersionKey = "version:ApdbCassandra" 

270 """Name of the metadata key to store code version number.""" 

271 

272 metadataReplicaVersionKey = "version:ApdbCassandraReplica" 

273 """Name of the metadata key to store replica code version number.""" 

274 

275 metadataConfigKey = "config:apdb-cassandra.json" 

276 """Name of the metadata key to store code version number.""" 

277 

278 _frozen_parameters = ( 

279 "use_insert_id", 

280 "part_pixelization", 

281 "part_pix_level", 

282 "ra_dec_columns", 

283 "time_partition_tables", 

284 "time_partition_days", 

285 "use_insert_id_skips_diaobjects", 

286 ) 

287 """Names of the config parameters to be frozen in metadata table.""" 

288 

289 partition_zero_epoch = astropy.time.Time(0, format="unix_tai") 

290 """Start time for partition 0, this should never be changed.""" 

291 

292 def __init__(self, config: ApdbCassandraConfig): 

293 if not CASSANDRA_IMPORTED: 

294 raise CassandraMissingError() 

295 

296 self._keyspace = config.keyspace 

297 

298 self._cluster, self._session = self._make_session(config) 

299 

300 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

301 self._metadata = ApdbMetadataCassandra( 

302 self._session, meta_table_name, config.keyspace, "read_tuples", "write" 

303 ) 

304 

305 # Read frozen config from metadata. 

306 config_json = self._metadata.get(self.metadataConfigKey) 

307 if config_json is not None: 

308 # Update config from metadata. 

309 freezer = ApdbConfigFreezer[ApdbCassandraConfig](self._frozen_parameters) 

310 self.config = freezer.update(config, config_json) 

311 else: 

312 self.config = config 

313 self.config.validate() 

314 

315 self._pixelization = Pixelization( 

316 self.config.part_pixelization, 

317 self.config.part_pix_level, 

318 config.part_pix_max_ranges, 

319 ) 

320 

321 self._schema = ApdbCassandraSchema( 

322 session=self._session, 

323 keyspace=self._keyspace, 

324 schema_file=self.config.schema_file, 

325 schema_name=self.config.schema_name, 

326 prefix=self.config.prefix, 

327 time_partition_tables=self.config.time_partition_tables, 

328 enable_replica=self.config.use_insert_id, 

329 ) 

330 self._partition_zero_epoch_mjd = float(self.partition_zero_epoch.mjd) 

331 

332 if self._metadata.table_exists(): 

333 self._versionCheck(self._metadata) 

334 

335 # Cache for prepared statements 

336 self._preparer = PreparedStatementCache(self._session) 

337 

338 _LOG.debug("ApdbCassandra Configuration:") 

339 for key, value in self.config.items(): 

340 _LOG.debug(" %s: %s", key, value) 

341 

342 self._timer_args: list[MonAgent | logging.Logger] = [_MON] 

343 if self.config.timer: 

344 self._timer_args.append(_LOG) 

345 

346 def __del__(self) -> None: 

347 if hasattr(self, "_cluster"): 

348 self._cluster.shutdown() 

349 

350 def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer: 

351 """Create `Timer` instance given its name.""" 

352 return Timer(name, *self._timer_args, tags=tags) 

353 

354 @classmethod 

355 def _make_session(cls, config: ApdbCassandraConfig) -> tuple[Cluster, Session]: 

356 """Make Cassandra session.""" 

357 addressTranslator: AddressTranslator | None = None 

358 if config.private_ips: 

359 addressTranslator = _AddressTranslator(list(config.contact_points), list(config.private_ips)) 

360 

361 cluster = Cluster( 

362 execution_profiles=cls._makeProfiles(config), 

363 contact_points=config.contact_points, 

364 port=config.port, 

365 address_translator=addressTranslator, 

366 protocol_version=config.protocol_version, 

367 auth_provider=cls._make_auth_provider(config), 

368 ) 

369 session = cluster.connect() 

370 # Disable result paging 

371 session.default_fetch_size = None 

372 

373 return cluster, session 

374 

375 @classmethod 

376 def _make_auth_provider(cls, config: ApdbCassandraConfig) -> AuthProvider | None: 

377 """Make Cassandra authentication provider instance.""" 

378 try: 

379 dbauth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR) 

380 except DbAuthNotFoundError: 

381 # Credentials file doesn't exist, use anonymous login. 

382 return None 

383 

384 empty_username = True 

385 # Try every contact point in turn. 

386 for hostname in config.contact_points: 

387 try: 

388 username, password = dbauth.getAuth( 

389 "cassandra", config.username, hostname, config.port, config.keyspace 

390 ) 

391 if not username: 

392 # Password without user name, try next hostname, but give 

393 # warning later if no better match is found. 

394 empty_username = True 

395 else: 

396 return PlainTextAuthProvider(username=username, password=password) 

397 except DbAuthNotFoundError: 

398 pass 

399 

400 if empty_username: 

401 _LOG.warning( 

402 f"Credentials file ({DB_AUTH_PATH} or ${DB_AUTH_ENVVAR}) provided password but not " 

403 f"user name, anonymous Cassandra logon will be attempted." 

404 ) 

405 

406 return None 

407 

408 def _versionCheck(self, metadata: ApdbMetadataCassandra) -> None: 

409 """Check schema version compatibility.""" 

410 

411 def _get_version(key: str, default: VersionTuple) -> VersionTuple: 

412 """Retrieve version number from given metadata key.""" 

413 if metadata.table_exists(): 

414 version_str = metadata.get(key) 

415 if version_str is None: 

416 # Should not happen with existing metadata table. 

417 raise RuntimeError(f"Version key {key!r} does not exist in metadata table.") 

418 return VersionTuple.fromString(version_str) 

419 return default 

420 

421 # For old databases where metadata table does not exist we assume that 

422 # version of both code and schema is 0.1.0. 

423 initial_version = VersionTuple(0, 1, 0) 

424 db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version) 

425 db_code_version = _get_version(self.metadataCodeVersionKey, initial_version) 

426 

427 # For now there is no way to make read-only APDB instances, assume that 

428 # any access can do updates. 

429 if not self._schema.schemaVersion().checkCompatibility(db_schema_version): 

430 raise IncompatibleVersionError( 

431 f"Configured schema version {self._schema.schemaVersion()} " 

432 f"is not compatible with database version {db_schema_version}" 

433 ) 

434 if not self.apdbImplementationVersion().checkCompatibility(db_code_version): 

435 raise IncompatibleVersionError( 

436 f"Current code version {self.apdbImplementationVersion()} " 

437 f"is not compatible with database version {db_code_version}" 

438 ) 

439 

440 # Check replica code version only if replica is enabled. 

441 if self._schema.has_replica_chunks: 

442 db_replica_version = _get_version(self.metadataReplicaVersionKey, initial_version) 

443 code_replica_version = ApdbCassandraReplica.apdbReplicaImplementationVersion() 

444 if not code_replica_version.checkCompatibility(db_replica_version): 

445 raise IncompatibleVersionError( 

446 f"Current replication code version {code_replica_version} " 

447 f"is not compatible with database version {db_replica_version}" 

448 ) 

449 
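A small sketch of the compatibility test performed above; ``VersionTuple.fromString`` and ``checkCompatibility`` are the same calls used in ``_get_version`` and ``_versionCheck``, and the version string is an example value.

# "0.1.0" stands in for a version string read from the metadata table.
db_version = VersionTuple.fromString("0.1.0")
if not VERSION.checkCompatibility(db_version):
    raise IncompatibleVersionError(f"Code version {VERSION} is not compatible with {db_version}")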

450 @classmethod 

451 def apdbImplementationVersion(cls) -> VersionTuple: 

452 """Return version number for current APDB implementation. 

453 

454 Returns 

455 ------- 

456 version : `VersionTuple` 

457 Version of the code defined in implementation class. 

458 """ 

459 return VERSION 

460 

461 def tableDef(self, table: ApdbTables) -> Table | None: 

462 # docstring is inherited from a base class 

463 return self._schema.tableSchemas.get(table) 

464 

465 @classmethod 

466 def init_database( 

467 cls, 

468 hosts: list[str], 

469 keyspace: str, 

470 *, 

471 schema_file: str | None = None, 

472 schema_name: str | None = None, 

473 read_sources_months: int | None = None, 

474 read_forced_sources_months: int | None = None, 

475 use_insert_id: bool = False, 

476 use_insert_id_skips_diaobjects: bool = False, 

477 port: int | None = None, 

478 username: str | None = None, 

479 prefix: str | None = None, 

480 part_pixelization: str | None = None, 

481 part_pix_level: int | None = None, 

482 time_partition_tables: bool = True, 

483 time_partition_start: str | None = None, 

484 time_partition_end: str | None = None, 

485 read_consistency: str | None = None, 

486 write_consistency: str | None = None, 

487 read_timeout: int | None = None, 

488 write_timeout: int | None = None, 

489 ra_dec_columns: list[str] | None = None, 

490 replication_factor: int | None = None, 

491 drop: bool = False, 

492 ) -> ApdbCassandraConfig: 

493 """Initialize new APDB instance and make configuration object for it. 

494 

495 Parameters 

496 ---------- 

497 hosts : `list` [`str`] 

498 List of host names or IP addresses for Cassandra cluster. 

499 keyspace : `str` 

500 Name of the keyspace for APDB tables. 

501 schema_file : `str`, optional 

502 Location of (YAML) configuration file with APDB schema. If not 

503 specified then default location will be used. 

504 schema_name : `str`, optional 

505 Name of the schema in YAML configuration file. If not specified 

506 then default name will be used. 

507 read_sources_months : `int`, optional 

508 Number of months of history to read from DiaSource. 

509 read_forced_sources_months : `int`, optional 

510 Number of months of history to read from DiaForcedSource. 

511 use_insert_id : `bool`, optional 

512 If True, make additional tables used for replication to PPDB. 

513 use_insert_id_skips_diaobjects : `bool`, optional 

514 If `True` then do not fill regular ``DiaObject`` table when 

515 ``use_insert_id`` is `True`. 

516 port : `int`, optional 

517 Port number to use for Cassandra connections. 

518 username : `str`, optional 

519 User name for Cassandra connections. 

520 prefix : `str`, optional 

521 Optional prefix for all table names. 

522 part_pixelization : `str`, optional 

523 Name of the MOC pixelization used for partitioning. 

524 part_pix_level : `int`, optional 

525 Pixelization level. 

526 time_partition_tables : `bool`, optional 

527 Create per-partition tables. 

528 time_partition_start : `str`, optional 

529 Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

530 format, in TAI. 

531 time_partition_end : `str`, optional 

532 Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

533 format, in TAI. 

534 read_consistency : `str`, optional 

535 Name of the consistency level for read operations. 

536 write_consistency : `str`, optional 

537 Name of the consistency level for write operations. 

538 read_timeout : `int`, optional 

539 Read timeout in seconds. 

540 write_timeout : `int`, optional 

541 Write timeout in seconds. 

542 ra_dec_columns : `list` [`str`], optional 

543 Names of ra/dec columns in DiaObject table. 

544 replication_factor : `int`, optional 

545 Replication factor used when creating a new keyspace; if the keyspace 

546 already exists its replication factor is not changed. 

547 drop : `bool`, optional 

548 If `True` then drop existing tables before re-creating the schema. 

549 

550 Returns 

551 ------- 

552 config : `ApdbCassandraConfig` 

553 Resulting configuration object for a created APDB instance. 

554 """ 

555 config = ApdbCassandraConfig( 

556 contact_points=hosts, 

557 keyspace=keyspace, 

558 use_insert_id=use_insert_id, 

559 use_insert_id_skips_diaobjects=use_insert_id_skips_diaobjects, 

560 time_partition_tables=time_partition_tables, 

561 ) 

562 if schema_file is not None: 

563 config.schema_file = schema_file 

564 if schema_name is not None: 

565 config.schema_name = schema_name 

566 if read_sources_months is not None: 

567 config.read_sources_months = read_sources_months 

568 if read_forced_sources_months is not None: 

569 config.read_forced_sources_months = read_forced_sources_months 

570 if port is not None: 

571 config.port = port 

572 if username is not None: 

573 config.username = username 

574 if prefix is not None: 

575 config.prefix = prefix 

576 if part_pixelization is not None: 

577 config.part_pixelization = part_pixelization 

578 if part_pix_level is not None: 

579 config.part_pix_level = part_pix_level 

580 if time_partition_start is not None: 

581 config.time_partition_start = time_partition_start 

582 if time_partition_end is not None: 

583 config.time_partition_end = time_partition_end 

584 if read_consistency is not None: 

585 config.read_consistency = read_consistency 

586 if write_consistency is not None: 

587 config.write_consistency = write_consistency 

588 if read_timeout is not None: 

589 config.read_timeout = read_timeout 

590 if write_timeout is not None: 

591 config.write_timeout = write_timeout 

592 if ra_dec_columns is not None: 

593 config.ra_dec_columns = ra_dec_columns 

594 

595 cls._makeSchema(config, drop=drop, replication_factor=replication_factor) 

596 

597 return config 

598 
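An illustrative call to ``init_database``; the host name, keyspace, and keyword values are placeholders mirroring the parameters documented above.

# Placeholder host/keyspace; the remaining keywords are optional.
config = ApdbCassandra.init_database(
    ["cassandra-host-1"],
    "apdb",
    use_insert_id=True,
    time_partition_tables=True,
    replication_factor=3,
)
apdb = ApdbCassandra(config)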

599 def get_replica(self) -> ApdbCassandraReplica: 

600 """Return `ApdbReplica` instance for this database.""" 

601 # Note that this instance has to stay alive while replica exists, so 

602 # we pass reference to self. 

603 return ApdbCassandraReplica(self, self._schema, self._session) 

604 

605 @classmethod 

606 def _makeSchema( 

607 cls, config: ApdbConfig, *, drop: bool = False, replication_factor: int | None = None 

608 ) -> None: 

609 # docstring is inherited from a base class 

610 

611 if not isinstance(config, ApdbCassandraConfig): 

612 raise TypeError(f"Unexpected type of configuration object: {type(config)}") 

613 

614 cluster, session = cls._make_session(config) 

615 

616 schema = ApdbCassandraSchema( 

617 session=session, 

618 keyspace=config.keyspace, 

619 schema_file=config.schema_file, 

620 schema_name=config.schema_name, 

621 prefix=config.prefix, 

622 time_partition_tables=config.time_partition_tables, 

623 enable_replica=config.use_insert_id, 

624 ) 

625 

626 # Ask schema to create all tables. 

627 if config.time_partition_tables: 

628 time_partition_start = astropy.time.Time(config.time_partition_start, format="isot", scale="tai") 

629 time_partition_end = astropy.time.Time(config.time_partition_end, format="isot", scale="tai") 

630 part_epoch = float(cls.partition_zero_epoch.mjd) 

631 part_days = config.time_partition_days 

632 part_range = ( 

633 cls._time_partition_cls(time_partition_start, part_epoch, part_days), 

634 cls._time_partition_cls(time_partition_end, part_epoch, part_days) + 1, 

635 ) 

636 schema.makeSchema(drop=drop, part_range=part_range, replication_factor=replication_factor) 

637 else: 

638 schema.makeSchema(drop=drop, replication_factor=replication_factor) 

639 

640 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

641 metadata = ApdbMetadataCassandra(session, meta_table_name, config.keyspace, "read_tuples", "write") 

642 

643 # Fill version numbers, overriding them if they existed before. 

644 if metadata.table_exists(): 

645 metadata.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True) 

646 metadata.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True) 

647 

648 if config.use_insert_id: 

649 # Only store replica code version if replica is enabled. 

650 metadata.set( 

651 cls.metadataReplicaVersionKey, 

652 str(ApdbCassandraReplica.apdbReplicaImplementationVersion()), 

653 force=True, 

654 ) 

655 

656 # Store frozen part of a configuration in metadata. 

657 freezer = ApdbConfigFreezer[ApdbCassandraConfig](cls._frozen_parameters) 

658 metadata.set(cls.metadataConfigKey, freezer.to_json(config), force=True) 

659 

660 cluster.shutdown() 

661 

662 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

663 # docstring is inherited from a base class 

664 

665 sp_where = self._spatial_where(region) 

666 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

667 

668 # We need to exclude extra partitioning columns from result. 

669 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast) 

670 what = ",".join(quote_id(column) for column in column_names) 

671 

672 table_name = self._schema.tableName(ApdbTables.DiaObjectLast) 

673 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"' 

674 statements: list[tuple] = [] 

675 for where, params in sp_where: 

676 full_query = f"{query} WHERE {where}" 

677 if params: 

678 statement = self._preparer.prepare(full_query) 

679 else: 

680 # If there are no params then it is likely that query has a 

681 # bunch of literals rendered already, no point trying to 

682 # prepare it because it's not reusable. 

683 statement = cassandra.query.SimpleStatement(full_query) 

684 statements.append((statement, params)) 

685 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

686 

687 with _MON.context_tags({"table": "DiaObject"}): 

688 _MON.add_record( 

689 "select_query_stats", values={"num_sp_part": len(sp_where), "num_queries": len(statements)} 

690 ) 

691 with self._timer("select_time"): 

692 objects = cast( 

693 pandas.DataFrame, 

694 select_concurrent( 

695 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

696 ), 

697 ) 

698 

699 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

700 return objects 

701 

702 def getDiaSources( 

703 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time 

704 ) -> pandas.DataFrame | None: 

705 # docstring is inherited from a base class 

706 months = self.config.read_sources_months 

707 if months == 0: 

708 return None 

709 mjd_end = visit_time.mjd 

710 mjd_start = mjd_end - months * 30 

711 

712 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

713 

714 def getDiaForcedSources( 

715 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time 

716 ) -> pandas.DataFrame | None: 

717 # docstring is inherited from a base class 

718 months = self.config.read_forced_sources_months 

719 if months == 0: 

720 return None 

721 mjd_end = visit_time.mjd 

722 mjd_start = mjd_end - months * 30 

723 

724 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

725 
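A reading sketch combining the three getters above; the region is built with lsst.sphgeom, and the coordinates and radius are arbitrary assumptions.

# Circular search region; the coordinates and 0.1-degree radius are example values.
center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
visit_time = astropy.time.Time.now()
objects = apdb.getDiaObjects(region)
sources = apdb.getDiaSources(region, objects["diaObjectId"], visit_time)
forced = apdb.getDiaForcedSources(region, objects["diaObjectId"], visit_time)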

726 def containsVisitDetector(self, visit: int, detector: int) -> bool: 

727 # docstring is inherited from a base class 

728 raise NotImplementedError() 

729 

730 def getSSObjects(self) -> pandas.DataFrame: 

731 # docstring is inherited from a base class 

732 tableName = self._schema.tableName(ApdbTables.SSObject) 

733 query = f'SELECT * from "{self._keyspace}"."{tableName}"' 

734 

735 objects = None 

736 with self._timer("select_time", tags={"table": "SSObject"}): 

737 result = self._session.execute(query, execution_profile="read_pandas") 

738 objects = result._current_rows 

739 

740 _LOG.debug("found %s SSObjects", objects.shape[0]) 

741 return objects 

742 

743 def store( 

744 self, 

745 visit_time: astropy.time.Time, 

746 objects: pandas.DataFrame, 

747 sources: pandas.DataFrame | None = None, 

748 forced_sources: pandas.DataFrame | None = None, 

749 ) -> None: 

750 # docstring is inherited from a base class 

751 

752 replica_chunk: ReplicaChunk | None = None 

753 if self._schema.has_replica_chunks: 

754 replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, self.config.replica_chunk_seconds) 

755 self._storeReplicaChunk(replica_chunk, visit_time) 

756 

757 # fill region partition column for DiaObjects 

758 objects = self._add_obj_part(objects) 

759 self._storeDiaObjects(objects, visit_time, replica_chunk) 

760 

761 if sources is not None: 

762 # copy apdb_part column from DiaObjects to DiaSources 

763 sources = self._add_src_part(sources, objects) 

764 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, replica_chunk) 

765 self._storeDiaSourcesPartitions(sources, visit_time, replica_chunk) 

766 

767 if forced_sources is not None: 

768 forced_sources = self._add_fsrc_part(forced_sources, objects) 

769 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, replica_chunk) 

770 
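A hedged usage sketch of ``store()``; the DataFrames are assumed to follow the APDB schema, including ``diaObjectId`` and the configured ra/dec columns.

# `objects`, `sources` and `forced_sources` are pandas DataFrames following the APDB schema.
visit_time = astropy.time.Time("2025-01-01T00:00:00", format="isot", scale="tai")
apdb.store(visit_time, objects, sources=sources, forced_sources=forced_sources)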

771 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

772 # docstring is inherited from a base class 

773 self._storeObjectsPandas(objects, ApdbTables.SSObject) 

774 

775 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

776 # docstring is inherited from a base class 

777 

778 # To update a record we need to know its exact primary key (including 

779 # partition key) so we start by querying for diaSourceId to find the 

780 # primary keys. 

781 

782 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition) 

783 # split it into 1k IDs per query 

784 selects: list[tuple] = [] 

785 for ids in chunk_iterable(idMap.keys(), 1_000): 

786 ids_str = ",".join(str(item) for item in ids) 

787 selects.append( 

788 ( 

789 ( 

790 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "apdb_replica_chunk" ' 

791 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})' 

792 ), 

793 {}, 

794 ) 

795 ) 

796 

797 # No need for DataFrame here, read data as tuples. 

798 result = cast( 

799 list[tuple[int, int, int, int | None]], 

800 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency), 

801 ) 

802 

803 # Make mapping from source ID to its partition. 

804 id2partitions: dict[int, tuple[int, int]] = {} 

805 id2chunk_id: dict[int, int] = {} 

806 for row in result: 

807 id2partitions[row[0]] = row[1:3] 

808 if row[3] is not None: 

809 id2chunk_id[row[0]] = row[3] 

810 

811 # make sure we know partitions for each ID 

812 if set(id2partitions) != set(idMap): 

813 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

814 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

815 

816 # Reassign in standard tables 

817 queries = cassandra.query.BatchStatement() 

818 table_name = self._schema.tableName(ApdbTables.DiaSource) 

819 for diaSourceId, ssObjectId in idMap.items(): 

820 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

821 values: tuple 

822 if self.config.time_partition_tables: 

823 query = ( 

824 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"' 

825 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

826 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?' 

827 ) 

828 values = (ssObjectId, apdb_part, diaSourceId) 

829 else: 

830 query = ( 

831 f'UPDATE "{self._keyspace}"."{table_name}"' 

832 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

833 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?' 

834 ) 

835 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId) 

836 queries.add(self._preparer.prepare(query), values) 

837 

838 # Reassign in replica tables, only if replication is enabled 

839 if id2chunk_id: 

840 # Filter out chunks that have been deleted already. There is a 

841 # potential race with concurrent removal of chunks, but it 

842 # should be handled by WHERE in UPDATE. 

843 known_ids = set() 

844 if replica_chunks := self.get_replica().getReplicaChunks(): 

845 known_ids = set(replica_chunk.id for replica_chunk in replica_chunks) 

846 id2chunk_id = {key: value for key, value in id2chunk_id.items() if value in known_ids} 

847 if id2chunk_id: 

848 table_name = self._schema.tableName(ExtraTables.DiaSourceChunks) 

849 for diaSourceId, ssObjectId in idMap.items(): 

850 if replica_chunk := id2chunk_id.get(diaSourceId): 

851 query = ( 

852 f'UPDATE "{self._keyspace}"."{table_name}" ' 

853 ' SET "ssObjectId" = ?, "diaObjectId" = NULL ' 

854 'WHERE "apdb_replica_chunk" = ? AND "diaSourceId" = ?' 

855 ) 

856 values = (ssObjectId, replica_chunk, diaSourceId) 

857 queries.add(self._preparer.prepare(query), values) 

858 

859 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

860 with self._timer("update_time", tags={"table": table_name}): 

861 self._session.execute(queries, execution_profile="write") 

862 

863 def dailyJob(self) -> None: 

864 # docstring is inherited from a base class 

865 pass 

866 

867 def countUnassociatedObjects(self) -> int: 

868 # docstring is inherited from a base class 

869 

870 # It's too inefficient to implement it for Cassandra in current schema. 

871 raise NotImplementedError() 

872 

873 @property 

874 def metadata(self) -> ApdbMetadata: 

875 # docstring is inherited from a base class 

876 if self._metadata is None: 

877 raise RuntimeError("Database schema was not initialized.") 

878 return self._metadata 

879 

880 @classmethod 

881 def _makeProfiles(cls, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]: 

882 """Make all execution profiles used in the code.""" 

883 if config.private_ips: 

884 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

885 else: 

886 loadBalancePolicy = RoundRobinPolicy() 

887 

888 read_tuples_profile = ExecutionProfile( 

889 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

890 request_timeout=config.read_timeout, 

891 row_factory=cassandra.query.tuple_factory, 

892 load_balancing_policy=loadBalancePolicy, 

893 ) 

894 read_pandas_profile = ExecutionProfile( 

895 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

896 request_timeout=config.read_timeout, 

897 row_factory=pandas_dataframe_factory, 

898 load_balancing_policy=loadBalancePolicy, 

899 ) 

900 read_raw_profile = ExecutionProfile( 

901 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

902 request_timeout=config.read_timeout, 

903 row_factory=raw_data_factory, 

904 load_balancing_policy=loadBalancePolicy, 

905 ) 

906 # Profile to use with select_concurrent to return pandas data frame 

907 read_pandas_multi_profile = ExecutionProfile( 

908 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

909 request_timeout=config.read_timeout, 

910 row_factory=pandas_dataframe_factory, 

911 load_balancing_policy=loadBalancePolicy, 

912 ) 

913 # Profile to use with select_concurrent to return raw data (columns and 

914 # rows) 

915 read_raw_multi_profile = ExecutionProfile( 

916 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

917 request_timeout=config.read_timeout, 

918 row_factory=raw_data_factory, 

919 load_balancing_policy=loadBalancePolicy, 

920 ) 

921 write_profile = ExecutionProfile( 

922 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency), 

923 request_timeout=config.write_timeout, 

924 load_balancing_policy=loadBalancePolicy, 

925 ) 

926 # To replace default DCAwareRoundRobinPolicy 

927 default_profile = ExecutionProfile( 

928 load_balancing_policy=loadBalancePolicy, 

929 ) 

930 return { 

931 "read_tuples": read_tuples_profile, 

932 "read_pandas": read_pandas_profile, 

933 "read_raw": read_raw_profile, 

934 "read_pandas_multi": read_pandas_multi_profile, 

935 "read_raw_multi": read_raw_multi_profile, 

936 "write": write_profile, 

937 EXEC_PROFILE_DEFAULT: default_profile, 

938 } 

939 

940 def _getSources( 

941 self, 

942 region: sphgeom.Region, 

943 object_ids: Iterable[int] | None, 

944 mjd_start: float, 

945 mjd_end: float, 

946 table_name: ApdbTables, 

947 ) -> pandas.DataFrame: 

948 """Return catalog of DiaSource instances given set of DiaObject IDs. 

949 

950 Parameters 

951 ---------- 

952 region : `lsst.sphgeom.Region` 

953 Spherical region. 

954 object_ids : `iterable` [`int`], or `None` 

955 Collection of DiaObject IDs. 

956 mjd_start : `float` 

957 Lower bound of time interval. 

958 mjd_end : `float` 

959 Upper bound of time interval. 

960 table_name : `ApdbTables` 

961 Name of the table. 

962 

963 Returns 

964 ------- 

965 catalog : `pandas.DataFrame` 

966 Catalog containing DiaSource records. Empty catalog is returned if 

967 ``object_ids`` is empty. 

968 """ 

969 object_id_set: Set[int] = set() 

970 if object_ids is not None: 

971 object_id_set = set(object_ids) 

972 if len(object_id_set) == 0: 

973 return self._make_empty_catalog(table_name) 

974 

975 sp_where = self._spatial_where(region) 

976 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end) 

977 

978 # We need to exclude extra partitioning columns from result. 

979 column_names = self._schema.apdbColumnNames(table_name) 

980 what = ",".join(quote_id(column) for column in column_names) 

981 

982 # Build all queries 

983 statements: list[tuple] = [] 

984 for table in tables: 

985 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"' 

986 statements += list(self._combine_where(prefix, sp_where, temporal_where)) 

987 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

988 

989 with _MON.context_tags({"table": table_name.name}): 

990 _MON.add_record( 

991 "select_query_stats", values={"num_sp_part": len(sp_where), "num_queries": len(statements)} 

992 ) 

993 with self._timer("select_time"): 

994 catalog = cast( 

995 pandas.DataFrame, 

996 select_concurrent( 

997 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

998 ), 

999 ) 

1000 

1001 # filter by given object IDs 

1002 if len(object_id_set) > 0: 

1003 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

1004 

1005 # precise filtering on midpointMjdTai 

1006 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start]) 

1007 

1008 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

1009 return catalog 

1010 

1011 def _storeReplicaChunk(self, replica_chunk: ReplicaChunk, visit_time: astropy.time.Time) -> None: 

1012 # Cassandra timestamp uses milliseconds since epoch 

1013 timestamp = int(replica_chunk.last_update_time.unix_tai * 1000) 

1014 

1015 # everything goes into a single partition 

1016 partition = 0 

1017 

1018 table_name = self._schema.tableName(ExtraTables.ApdbReplicaChunks) 

1019 query = ( 

1020 f'INSERT INTO "{self._keyspace}"."{table_name}" ' 

1021 "(partition, apdb_replica_chunk, last_update_time, unique_id) " 

1022 "VALUES (?, ?, ?, ?)" 

1023 ) 

1024 

1025 self._session.execute( 

1026 self._preparer.prepare(query), 

1027 (partition, replica_chunk.id, timestamp, replica_chunk.unique_id), 

1028 timeout=self.config.write_timeout, 

1029 execution_profile="write", 

1030 ) 

1031 

1032 def _storeDiaObjects( 

1033 self, objs: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None 

1034 ) -> None: 

1035 """Store catalog of DiaObjects from current visit. 

1036 

1037 Parameters 

1038 ---------- 

1039 objs : `pandas.DataFrame` 

1040 Catalog with DiaObject records 

1041 visit_time : `astropy.time.Time` 

1042 Time of the current visit. 

1043 replica_chunk : `ReplicaChunk` or `None` 

1044 Replica chunk identifier if replication is configured. 

1045 """ 

1046 if len(objs) == 0: 

1047 _LOG.debug("No objects to write to database.") 

1048 return 

1049 

1050 visit_time_dt = visit_time.datetime 

1051 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

1052 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

1053 

1054 extra_columns["validityStart"] = visit_time_dt 

1055 time_part: int | None = self._time_partition(visit_time) 

1056 if not self.config.time_partition_tables: 

1057 extra_columns["apdb_time_part"] = time_part 

1058 time_part = None 

1059 

1060 # Only store DiaObjects if not doing replication or explicitly 

1061 # configured to always store them. 

1062 if replica_chunk is None or not self.config.use_insert_id_skips_diaobjects: 

1063 self._storeObjectsPandas( 

1064 objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part 

1065 ) 

1066 

1067 if replica_chunk is not None: 

1068 extra_columns = dict(apdb_replica_chunk=replica_chunk.id, validityStart=visit_time_dt) 

1069 self._storeObjectsPandas(objs, ExtraTables.DiaObjectChunks, extra_columns=extra_columns) 

1070 

1071 def _storeDiaSources( 

1072 self, 

1073 table_name: ApdbTables, 

1074 sources: pandas.DataFrame, 

1075 visit_time: astropy.time.Time, 

1076 replica_chunk: ReplicaChunk | None, 

1077 ) -> None: 

1078 """Store catalog of DIASources or DIAForcedSources from current visit. 

1079 

1080 Parameters 

1081 ---------- 

1082 table_name : `ApdbTables` 

1083 Table where to store the data. 

1084 sources : `pandas.DataFrame` 

1085 Catalog containing DiaSource records 

1086 visit_time : `astropy.time.Time` 

1087 Time of the current visit. 

1088 replica_chunk : `ReplicaChunk` or `None` 

1089 Replica chunk identifier if replication is configured. 

1090 """ 

1091 time_part: int | None = self._time_partition(visit_time) 

1092 extra_columns: dict[str, Any] = {} 

1093 if not self.config.time_partition_tables: 

1094 extra_columns["apdb_time_part"] = time_part 

1095 time_part = None 

1096 

1097 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part) 

1098 

1099 if replica_chunk is not None: 

1100 extra_columns = dict(apdb_replica_chunk=replica_chunk.id) 

1101 if table_name is ApdbTables.DiaSource: 

1102 extra_table = ExtraTables.DiaSourceChunks 

1103 else: 

1104 extra_table = ExtraTables.DiaForcedSourceChunks 

1105 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns) 

1106 

1107 def _storeDiaSourcesPartitions( 

1108 self, sources: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None 

1109 ) -> None: 

1110 """Store mapping of diaSourceId to its partitioning values. 

1111 

1112 Parameters 

1113 ---------- 

1114 sources : `pandas.DataFrame` 

1115 Catalog containing DiaSource records 

1116 visit_time : `astropy.time.Time` 

1117 Time of the current visit. 

1118 """ 

1119 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

1120 extra_columns = { 

1121 "apdb_time_part": self._time_partition(visit_time), 

1122 "apdb_replica_chunk": replica_chunk.id if replica_chunk is not None else None, 

1123 } 

1124 

1125 self._storeObjectsPandas( 

1126 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

1127 ) 

1128 

1129 def _storeObjectsPandas( 

1130 self, 

1131 records: pandas.DataFrame, 

1132 table_name: ApdbTables | ExtraTables, 

1133 extra_columns: Mapping | None = None, 

1134 time_part: int | None = None, 

1135 ) -> None: 

1136 """Store generic objects. 

1137 

1138 Takes Pandas catalog and stores a bunch of records in a table. 

1139 

1140 Parameters 

1141 ---------- 

1142 records : `pandas.DataFrame` 

1143 Catalog containing object records 

1144 table_name : `ApdbTables` or `ExtraTables` 

1145 Name of the table as defined in APDB schema. 

1146 extra_columns : `dict`, optional 

1147 Mapping (column_name, column_value) which gives fixed values for 

1148 columns in each row, overrides values in ``records`` if matching 

1149 columns exist there. 

1150 time_part : `int`, optional 

1151 If not `None` then insert into a per-partition table. 

1152 

1153 Notes 

1154 ----- 

1155 If Pandas catalog contains additional columns not defined in table 

1156 schema they are ignored. Catalog does not have to contain all columns 

1157 defined in a table, but partition and clustering keys must be present 

1158 in a catalog or ``extra_columns``. 

1159 """ 

1160 # use extra columns if specified 

1161 if extra_columns is None: 

1162 extra_columns = {} 

1163 extra_fields = list(extra_columns.keys()) 

1164 

1165 # Fields that will come from dataframe. 

1166 df_fields = [column for column in records.columns if column not in extra_fields] 

1167 

1168 column_map = self._schema.getColumnMap(table_name) 

1169 # list of columns (as in felis schema) 

1170 fields = [column_map[field].name for field in df_fields if field in column_map] 

1171 fields += extra_fields 

1172 

1173 # check that all partitioning and clustering columns are defined 

1174 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns( 

1175 table_name 

1176 ) 

1177 missing_columns = [column for column in required_columns if column not in fields] 

1178 if missing_columns: 

1179 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

1180 

1181 qfields = [quote_id(field) for field in fields] 

1182 qfields_str = ",".join(qfields) 

1183 

1184 with self._timer("insert_build_time", tags={"table": table_name.name}): 

1185 table = self._schema.tableName(table_name) 

1186 if time_part is not None: 

1187 table = f"{table}_{time_part}" 

1188 

1189 holders = ",".join(["?"] * len(qfields)) 

1190 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})' 

1191 statement = self._preparer.prepare(query) 

1192 queries = cassandra.query.BatchStatement() 

1193 for rec in records.itertuples(index=False): 

1194 values = [] 

1195 for field in df_fields: 

1196 if field not in column_map: 

1197 continue 

1198 value = getattr(rec, field) 

1199 if column_map[field].datatype is felis.datamodel.DataType.timestamp: 

1200 if isinstance(value, pandas.Timestamp): 

1201 value = literal(value.to_pydatetime()) 

1202 else: 

1203 # Assume it's seconds since epoch; Cassandra 

1204 # timestamps are in milliseconds. 

1205 value = int(value * 1000) 

1206 values.append(literal(value)) 

1207 for field in extra_fields: 

1208 value = extra_columns[field] 

1209 values.append(literal(value)) 

1210 queries.add(statement, values) 

1211 

1212 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0]) 

1213 with self._timer("insert_time", tags={"table": table_name.name}): 

1214 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write") 

1215 

1216 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1217 """Calculate spatial partition for each record and add it to a 

1218 DataFrame. 

1219 

1220 Notes 

1221 ----- 

1222 This overrides any existing column in a DataFrame with the same name 

1223 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1224 returned. 

1225 """ 

1226 # calculate HTM index for every DiaObject 

1227 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

1228 ra_col, dec_col = self.config.ra_dec_columns 

1229 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

1230 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1231 idx = self._pixelization.pixel(uv3d) 

1232 apdb_part[i] = idx 

1233 df = df.copy() 

1234 df["apdb_part"] = apdb_part 

1235 return df 

1236 

1237 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1238 """Add apdb_part column to DiaSource catalog. 

1239 

1240 Notes 

1241 ----- 

1242 This method copies apdb_part value from a matching DiaObject record. 

1243 DiaObject catalog needs to have an apdb_part column filled by 

1244 ``_add_obj_part`` method and DiaSource records need to be 

1245 associated to DiaObjects via ``diaObjectId`` column. 

1246 

1247 This overrides any existing column in a DataFrame with the same name 

1248 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1249 returned. 

1250 """ 

1251 pixel_id_map: dict[int, int] = { 

1252 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

1253 } 

1254 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

1255 ra_col, dec_col = self.config.ra_dec_columns 

1256 for i, (diaObjId, ra, dec) in enumerate( 

1257 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col]) 

1258 ): 

1259 if diaObjId == 0: 

1260 # DiaSources associated with SolarSystemObjects do not have an 

1261 # associated DiaObject hence we skip them and set partition 

1262 # based on its own ra/dec 

1263 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1264 idx = self._pixelization.pixel(uv3d) 

1265 apdb_part[i] = idx 

1266 else: 

1267 apdb_part[i] = pixel_id_map[diaObjId] 

1268 sources = sources.copy() 

1269 sources["apdb_part"] = apdb_part 

1270 return sources 

1271 

1272 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1273 """Add apdb_part column to DiaForcedSource catalog. 

1274 

1275 Notes 

1276 ----- 

1277 This method copies apdb_part value from a matching DiaObject record. 

1278 DiaObject catalog needs to have an apdb_part column filled by 

1279 ``_add_obj_part`` method and DiaForcedSource records need to be 

1280 associated to DiaObjects via ``diaObjectId`` column. 

1281 

1282 This overrides any existing column in a DataFrame with the same name 

1283 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1284 returned. 

1285 """ 

1286 pixel_id_map: dict[int, int] = { 

1287 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

1288 } 

1289 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

1290 for i, diaObjId in enumerate(sources["diaObjectId"]): 

1291 apdb_part[i] = pixel_id_map[diaObjId] 

1292 sources = sources.copy() 

1293 sources["apdb_part"] = apdb_part 

1294 return sources 

1295 

1296 @classmethod 

1297 def _time_partition_cls(cls, time: float | astropy.time.Time, epoch_mjd: float, part_days: int) -> int: 

1298 """Calculate time partition number for a given time. 

1299 

1300 Parameters 

1301 ---------- 

1302 time : `float` or `astropy.time.Time` 

1303 Time for which to calculate partition number. Can be float to mean 

1304 MJD or `astropy.time.Time` 

1305 epoch_mjd : `float` 

1306 Epoch time for partition 0. 

1307 part_days : `int` 

1308 Number of days per partition. 

1309 

1310 Returns 

1311 ------- 

1312 partition : `int` 

1313 Partition number for a given time. 

1314 """ 

1315 if isinstance(time, astropy.time.Time): 

1316 mjd = float(time.mjd) 

1317 else: 

1318 mjd = time 

1319 days_since_epoch = mjd - epoch_mjd 

1320 partition = int(days_since_epoch) // part_days 

1321 return partition 

1322 
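A worked example of the partition arithmetic above, using the class's partition-zero epoch and the default 30-day granularity.

# With 30-day partitions, a time 365 days after the partition-zero epoch falls
# into partition 365 // 30 == 12.
epoch_mjd = float(ApdbCassandra.partition_zero_epoch.mjd)
partition = ApdbCassandra._time_partition_cls(epoch_mjd + 365.0, epoch_mjd, 30)
assert partition == 12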

1323 def _time_partition(self, time: float | astropy.time.Time) -> int: 

1324 """Calculate time partition number for a given time. 

1325 

1326 Parameters 

1327 ---------- 

1328 time : `float` or `astropy.time.Time` 

1329 Time for which to calculate partition number. Can be float to mean 

1330 MJD or `astropy.time.Time` 

1331 

1332 Returns 

1333 ------- 

1334 partition : `int` 

1335 Partition number for a given time. 

1336 """ 

1337 if isinstance(time, astropy.time.Time): 

1338 mjd = float(time.mjd) 

1339 else: 

1340 mjd = time 

1341 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

1342 partition = int(days_since_epoch) // self.config.time_partition_days 

1343 return partition 

1344 

1345 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

1346 """Make an empty catalog for a table with a given name. 

1347 

1348 Parameters 

1349 ---------- 

1350 table_name : `ApdbTables` 

1351 Name of the table. 

1352 

1353 Returns 

1354 ------- 

1355 catalog : `pandas.DataFrame` 

1356 An empty catalog. 

1357 """ 

1358 table = self._schema.tableSchemas[table_name] 

1359 

1360 data = { 

1361 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype)) 

1362 for columnDef in table.columns 

1363 } 

1364 return pandas.DataFrame(data) 

1365 

1366 def _combine_where( 

1367 self, 

1368 prefix: str, 

1369 where1: list[tuple[str, tuple]], 

1370 where2: list[tuple[str, tuple]], 

1371 suffix: str | None = None, 

1372 ) -> Iterator[tuple[cassandra.query.Statement, tuple]]: 

1373 """Make cartesian product of two parts of WHERE clause into a series 

1374 of statements to execute. 

1375 

1376 Parameters 

1377 ---------- 

1378 prefix : `str` 

1379 Initial statement prefix that comes before WHERE clause, e.g. 

1380 "SELECT * from Table" 

1381 """ 

1382 # If lists are empty use special sentinels. 

1383 if not where1: 

1384 where1 = [("", ())] 

1385 if not where2: 

1386 where2 = [("", ())] 

1387 

1388 for expr1, params1 in where1: 

1389 for expr2, params2 in where2: 

1390 full_query = prefix 

1391 wheres = [] 

1392 if expr1: 

1393 wheres.append(expr1) 

1394 if expr2: 

1395 wheres.append(expr2) 

1396 if wheres: 

1397 full_query += " WHERE " + " AND ".join(wheres) 

1398 if suffix: 

1399 full_query += " " + suffix 

1400 params = params1 + params2 

1401 if params: 

1402 statement = self._preparer.prepare(full_query) 

1403 else: 

1404 # If there are no params then it is likely that query 

1405 # has a bunch of literals rendered already, no point 

1406 # trying to prepare it. 

1407 statement = cassandra.query.SimpleStatement(full_query) 

1408 yield (statement, params) 

1409 
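A self-contained, pure-Python sketch of the cartesian product performed above, leaving out statement preparation to focus on how spatial and temporal WHERE fragments combine.

# No Cassandra statements involved; the fragment strings mimic those built elsewhere in this class.
spatial = [('"apdb_part" = ?', (1234,)), ('"apdb_part" = ?', (1235,))]
temporal = [('"apdb_time_part" IN (100,101)', ())]
queries = [
    (f"SELECT ... WHERE {expr1} AND {expr2}", params1 + params2)
    for expr1, params1 in spatial
    for expr2, params2 in temporal
]
# -> two queries, one per spatial partition, each carrying its bound parameters.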

1410 def _spatial_where( 

1411 self, region: sphgeom.Region | None, use_ranges: bool = False 

1412 ) -> list[tuple[str, tuple]]: 

1413 """Generate expressions for spatial part of WHERE clause. 

1414 

1415 Parameters 

1416 ---------- 

1417 region : `sphgeom.Region` 

1418 Spatial region for query results. 

1419 use_ranges : `bool` 

1420 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <= 

1421 p2") instead of exact list of pixels. Should be set to True for 

1422 large regions covering very many pixels. 

1423 

1424 Returns 

1425 ------- 

1426 expressions : `list` [ `tuple` ] 

1427 Empty list is returned if ``region`` is `None`, otherwise a list 

1428 of one or more (expression, parameters) tuples 

1429 """ 

1430 if region is None: 

1431 return [] 

1432 if use_ranges: 

1433 pixel_ranges = self._pixelization.envelope(region) 

1434 expressions: list[tuple[str, tuple]] = [] 

1435 for lower, upper in pixel_ranges: 

1436 upper -= 1 

1437 if lower == upper: 

1438 expressions.append(('"apdb_part" = ?', (lower,))) 

1439 else: 

1440 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper))) 

1441 return expressions 

1442 else: 

1443 pixels = self._pixelization.pixels(region) 

1444 if self.config.query_per_spatial_part: 

1445 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels] 

1446 else: 

1447 pixels_str = ",".join([str(pix) for pix in pixels]) 

1448 return [(f'"apdb_part" IN ({pixels_str})', ())] 

1449 
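Illustrative shapes of the expressions returned above; the pixel indices are made up, but the three forms correspond to ``use_ranges``, the default IN-list, and ``query_per_spatial_part``.

# Pixel indices below are arbitrary examples.
ranges_form = [('"apdb_part" >= ? AND "apdb_part" <= ?', (12320, 12335))]
in_list_form = [('"apdb_part" IN (12320,12321,12322)', ())]
per_pixel_form = [('"apdb_part" = ?', (pixel,)) for pixel in (12320, 12321, 12322)]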

1450 def _temporal_where( 

1451 self, 

1452 table: ApdbTables, 

1453 start_time: float | astropy.time.Time, 

1454 end_time: float | astropy.time.Time, 

1455 query_per_time_part: bool | None = None, 

1456 ) -> tuple[list[str], list[tuple[str, tuple]]]: 

1457 """Generate table names and expressions for temporal part of WHERE 

1458 clauses. 

1459 

1460 Parameters 

1461 ---------- 

1462 table : `ApdbTables` 

1463 Table to select from. 

1464 start_time : `astropy.time.Time` or `float` 

1465 Starting Datetime of MJD value of the time range. 

1466 end_time : `astropy.time.Time` or `float` 

1467 Starting Datetime of MJD value of the time range. 

1468 query_per_time_part : `bool`, optional 

1469 If None then use ``query_per_time_part`` from configuration. 

1470 

1471 Returns 

1472 ------- 

1473 tables : `list` [ `str` ] 

1474 List of the table names to query. 

1475 expressions : `list` [ `tuple` ] 

1476 A list of zero or more (expression, parameters) tuples. 

1477 """ 

1478 tables: list[str] 

1479 temporal_where: list[tuple[str, tuple]] = [] 

1480 table_name = self._schema.tableName(table) 

1481 time_part_start = self._time_partition(start_time) 

1482 time_part_end = self._time_partition(end_time) 

1483 time_parts = list(range(time_part_start, time_part_end + 1)) 

1484 if self.config.time_partition_tables: 

1485 tables = [f"{table_name}_{part}" for part in time_parts] 

1486 else: 

1487 tables = [table_name] 

1488 if query_per_time_part is None: 

1489 query_per_time_part = self.config.query_per_time_part 

1490 if query_per_time_part: 

1491 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts] 

1492 else: 

1493 time_part_list = ",".join([str(part) for part in time_parts]) 

1494 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())] 

1495 

1496 return tables, temporal_where