Coverage for python/lsst/dax/apdb/cassandra/apdbCassandra.py: 18%

601 statements  

coverage.py v7.5.1, created at 2024-05-08 02:52 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26import dataclasses 

27import json 

28import logging 

29from collections.abc import Iterable, Iterator, Mapping, Set 

30from typing import TYPE_CHECKING, Any, cast 

31 

32import numpy as np 

33import pandas 

34 

35# If cassandra-driver is not there the module can still be imported 

36# but ApdbCassandra cannot be instantiated. 

37try: 

38 import cassandra 

39 import cassandra.query 

40 from cassandra.auth import AuthProvider, PlainTextAuthProvider 

41 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, Session 

42 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

43 

44 CASSANDRA_IMPORTED = True 

45except ImportError: 

46 CASSANDRA_IMPORTED = False 

47 

48import astropy.time 

49import felis.datamodel 

50from lsst import sphgeom 

51from lsst.pex.config import ChoiceField, Field, ListField 

52from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError 

53from lsst.utils.iteration import chunk_iterable 

54 

55from ..apdb import Apdb, ApdbConfig 

56from ..apdbConfigFreezer import ApdbConfigFreezer 

57from ..apdbReplica import ReplicaChunk 

58from ..apdbSchema import ApdbTables 

59from ..monitor import MonAgent 

60from ..pixelization import Pixelization 

61from ..schema_model import Table 

62from ..timer import Timer 

63from ..versionTuple import IncompatibleVersionError, VersionTuple 

64from .apdbCassandraReplica import ApdbCassandraReplica 

65from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables 

66from .apdbMetadataCassandra import ApdbMetadataCassandra 

67from .cassandra_utils import ( 

68 PreparedStatementCache, 

69 literal, 

70 pandas_dataframe_factory, 

71 quote_id, 

72 raw_data_factory, 

73 select_concurrent, 

74) 

75 

76if TYPE_CHECKING: 

77 from ..apdbMetadata import ApdbMetadata 

78 

79_LOG = logging.getLogger(__name__) 

80 

81_MON = MonAgent(__name__) 

82 

83VERSION = VersionTuple(0, 1, 0) 

84"""Version for the code controlling non-replication tables. This needs to be 

85updated following compatibility rules when schema produced by this code 

86changes. 

87""" 

88 

89# Copied from daf_butler. 

90DB_AUTH_ENVVAR = "LSST_DB_AUTH" 

91"""Default name of the environmental variable that will be used to locate DB 

92credentials configuration file. """ 

93 

94DB_AUTH_PATH = "~/.lsst/db-auth.yaml" 

95"""Default path at which it is expected that DB credentials are found.""" 

96 

97 

98class CassandraMissingError(Exception): 

99 def __init__(self) -> None: 

100 super().__init__("cassandra-driver module cannot be imported") 

101 

102 

103class ApdbCassandraConfig(ApdbConfig): 

104 """Configuration class for Cassandra-based APDB implementation.""" 

105 

106 contact_points = ListField[str]( 

107 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"] 

108 ) 

109 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[]) 

110 port = Field[int](doc="Port number to connect to.", default=9042) 

111 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb") 

112 username = Field[str]( 

113 doc=f"Cassandra user name, if empty then {DB_AUTH_PATH} has to provide it with password.", 

114 default="", 

115 ) 

116 read_consistency = Field[str]( 

117 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM" 

118 ) 

119 write_consistency = Field[str]( 

120 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM" 

121 ) 

122 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0) 

123 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=60.0) 

124 remove_timeout = Field[float](doc="Timeout in seconds for remove operations.", default=600.0) 

125 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500) 

126 protocol_version = Field[int]( 

127 doc="Cassandra protocol version to use, default is V4", 

128 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0, 

129 ) 

130 dia_object_columns = ListField[str]( 

131 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[] 

132 ) 

133 prefix = Field[str](doc="Prefix to add to table names", default="") 

134 part_pixelization = ChoiceField[str]( 

135 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

136 doc="Pixelization used for partitioning index.", 

137 default="mq3c", 

138 ) 

139 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10) 

140 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64) 

141 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table") 

142 timer = Field[bool](doc="If True then print/log timing information", default=False) 

143 time_partition_tables = Field[bool]( 

144 doc="Use per-partition tables for sources instead of partitioning by time", default=False 

145 ) 

146 time_partition_days = Field[int]( 

147 doc=( 

148 "Time partitioning granularity in days, this value must not be changed after database is " 

149 "initialized" 

150 ), 

151 default=30, 

152 ) 

153 time_partition_start = Field[str]( 

154 doc=( 

155 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

156 "This is used only when time_partition_tables is True." 

157 ), 

158 default="2018-12-01T00:00:00", 

159 ) 

160 time_partition_end = Field[str]( 

161 doc=( 

162 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

163 "This is used only when time_partition_tables is True." 

164 ), 

165 default="2030-01-01T00:00:00", 

166 ) 

167 query_per_time_part = Field[bool]( 

168 default=False, 

169 doc=( 

170 "If True then build separate query for each time partition, otherwise build one single query. " 

171 "This is only used when time_partition_tables is False in schema config." 

172 ), 

173 ) 

174 query_per_spatial_part = Field[bool]( 

175 default=False, 

176 doc="If True then build one query per spatial partition, otherwise build single query.", 

177 ) 

178 use_insert_id_skips_diaobjects = Field[bool]( 

179 default=False, 

180 doc=( 

181 "If True then do not store DiaObjects when use_insert_id is True " 

182 "(DiaObjectsChunks has the same data)." 

183 ), 

184 ) 

185 
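# Editor's sketch (not part of the original module): a minimal example of how
# the configuration class above might be filled in.  Host names and field
# values are assumptions chosen purely for illustration.
#
#   >>> config = ApdbCassandraConfig()
#   >>> config.contact_points = ["cassandra-1.example.org", "cassandra-2.example.org"]
#   >>> config.keyspace = "apdb"
#   >>> config.part_pixelization = "mq3c"     # one of "htm", "q3c", "mq3c"
#   >>> config.part_pix_level = 10
#   >>> config.time_partition_tables = True
#   >>> config.validate()                     # standard pex_config validation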

186 

187@dataclasses.dataclass 

188class _FrozenApdbCassandraConfig: 

189 """Part of the configuration that is saved in metadata table and read back. 

190 

191 The attributes are a subset of attributes in `ApdbCassandraConfig` class. 

192 

193 Parameters 

194 ---------- 

195 config : `ApdbCassandraConfig` 

196 Configuration used to copy initial values of attributes. 

197 """ 

198 

199 use_insert_id: bool 

200 part_pixelization: str 

201 part_pix_level: int 

202 ra_dec_columns: list[str] 

203 time_partition_tables: bool 

204 time_partition_days: int 

205 use_insert_id_skips_diaobjects: bool 

206 

207 def __init__(self, config: ApdbCassandraConfig): 

208 self.use_insert_id = config.use_insert_id 

209 self.part_pixelization = config.part_pixelization 

210 self.part_pix_level = config.part_pix_level 

211 self.ra_dec_columns = list(config.ra_dec_columns) 

212 self.time_partition_tables = config.time_partition_tables 

213 self.time_partition_days = config.time_partition_days 

214 self.use_insert_id_skips_diaobjects = config.use_insert_id_skips_diaobjects 

215 

216 def to_json(self) -> str: 

217 """Convert this instance to JSON representation.""" 

218 return json.dumps(dataclasses.asdict(self)) 

219 

220 def update(self, json_str: str) -> None: 

221 """Update attribute values from a JSON string. 

222 

223 Parameters 

224 ---------- 

225 json_str : str 

226 String containing JSON representation of configuration. 

227 """ 

228 data = json.loads(json_str) 

229 if not isinstance(data, dict): 

230 raise TypeError(f"JSON string must be convertible to object: {json_str!r}") 

231 allowed_names = {field.name for field in dataclasses.fields(self)} 

232 for key, value in data.items(): 

233 if key not in allowed_names: 

234 raise ValueError(f"JSON object contains unknown key: {key}") 

235 setattr(self, key, value) 

236 
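# Editor's sketch (not part of the original module): the frozen-config helper
# above round-trips through JSON roughly as follows; the values are
# illustrative assumptions.
#
#   >>> frozen = _FrozenApdbCassandraConfig(config)
#   >>> json_str = frozen.to_json()
#   >>> frozen.update('{"part_pix_level": 12, "time_partition_days": 30}')
#   >>> frozen.part_pix_level
#   12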

237 

238if CASSANDRA_IMPORTED: 

239 

240 class _AddressTranslator(AddressTranslator): 

241 """Translate internal IP address to external. 

242 

243 Only used for docker-based setup, not viable long-term solution. 

244 """ 

245 

246 def __init__(self, public_ips: list[str], private_ips: list[str]): 

247 self._map = dict((k, v) for k, v in zip(private_ips, public_ips)) 

248 

249 def translate(self, private_ip: str) -> str: 

250 return self._map.get(private_ip, private_ip) 

251 

252 

253class ApdbCassandra(Apdb): 

254 """Implementation of APDB database on to of Apache Cassandra. 

255 

256 The implementation is configured via standard ``pex_config`` mechanism 

257 using the `ApdbCassandraConfig` configuration class. For examples of 

258 different configurations, check the config/ folder. 

259 

260 Parameters 

261 ---------- 

262 config : `ApdbCassandraConfig` 

263 Configuration object. 

264 """ 

265 

266 metadataSchemaVersionKey = "version:schema" 

267 """Name of the metadata key to store schema version number.""" 

268 

269 metadataCodeVersionKey = "version:ApdbCassandra" 

270 """Name of the metadata key to store code version number.""" 

271 

272 metadataReplicaVersionKey = "version:ApdbCassandraReplica" 

273 """Name of the metadata key to store replica code version number.""" 

274 

275 metadataConfigKey = "config:apdb-cassandra.json" 

276 """Name of the metadata key to store code version number.""" 

277 

278 _frozen_parameters = ( 

279 "use_insert_id", 

280 "part_pixelization", 

281 "part_pix_level", 

282 "ra_dec_columns", 

283 "time_partition_tables", 

284 "time_partition_days", 

285 "use_insert_id_skips_diaobjects", 

286 ) 

287 """Names of the config parameters to be frozen in metadata table.""" 

288 

289 partition_zero_epoch = astropy.time.Time(0, format="unix_tai") 

290 """Start time for partition 0, this should never be changed.""" 

291 

292 def __init__(self, config: ApdbCassandraConfig): 

293 if not CASSANDRA_IMPORTED: 

294 raise CassandraMissingError() 

295 

296 self._keyspace = config.keyspace 

297 

298 self._cluster, self._session = self._make_session(config) 

299 

300 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

301 self._metadata = ApdbMetadataCassandra( 

302 self._session, meta_table_name, config.keyspace, "read_tuples", "write" 

303 ) 

304 

305 # Read frozen config from metadata. 

306 config_json = self._metadata.get(self.metadataConfigKey) 

307 if config_json is not None: 

308 # Update config from metadata. 

309 freezer = ApdbConfigFreezer[ApdbCassandraConfig](self._frozen_parameters) 

310 self.config = freezer.update(config, config_json) 

311 else: 

312 self.config = config 

313 self.config.validate() 

314 

315 self._pixelization = Pixelization( 

316 self.config.part_pixelization, 

317 self.config.part_pix_level, 

318 config.part_pix_max_ranges, 

319 ) 

320 

321 self._schema = ApdbCassandraSchema( 

322 session=self._session, 

323 keyspace=self._keyspace, 

324 schema_file=self.config.schema_file, 

325 schema_name=self.config.schema_name, 

326 prefix=self.config.prefix, 

327 time_partition_tables=self.config.time_partition_tables, 

328 enable_replica=self.config.use_insert_id, 

329 ) 

330 self._partition_zero_epoch_mjd = float(self.partition_zero_epoch.mjd) 

331 

332 if self._metadata.table_exists(): 

333 self._versionCheck(self._metadata) 

334 

335 # Cache for prepared statements 

336 self._preparer = PreparedStatementCache(self._session) 

337 

338 _LOG.debug("ApdbCassandra Configuration:") 

339 for key, value in self.config.items(): 

340 _LOG.debug(" %s: %s", key, value) 

341 

342 self._timer_args: list[MonAgent | logging.Logger] = [_MON] 

343 if self.config.timer: 

344 self._timer_args.append(_LOG) 

345 

346 def __del__(self) -> None: 

347 if hasattr(self, "_cluster"): 

348 self._cluster.shutdown() 

349 

350 def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer: 

351 """Create `Timer` instance given its name.""" 

352 return Timer(name, *self._timer_args, tags=tags) 

353 

354 @classmethod 

355 def _make_session(cls, config: ApdbCassandraConfig) -> tuple[Cluster, Session]: 

356 """Make Cassandra session.""" 

357 addressTranslator: AddressTranslator | None = None 

358 if config.private_ips: 

359 addressTranslator = _AddressTranslator(list(config.contact_points), list(config.private_ips)) 

360 

361 cluster = Cluster( 

362 execution_profiles=cls._makeProfiles(config), 

363 contact_points=config.contact_points, 

364 port=config.port, 

365 address_translator=addressTranslator, 

366 protocol_version=config.protocol_version, 

367 auth_provider=cls._make_auth_provider(config), 

368 ) 

369 session = cluster.connect() 

370 # Disable result paging 

371 session.default_fetch_size = None 

372 

373 return cluster, session 

374 

375 @classmethod 

376 def _make_auth_provider(cls, config: ApdbCassandraConfig) -> AuthProvider | None: 

377 """Make Cassandra authentication provider instance.""" 

378 try: 

379 dbauth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR) 

380 except DbAuthNotFoundError: 

381 # Credentials file doesn't exist, use anonymous login. 

382 return None 

383 

384 empty_username = True 

385 # Try every contact point in turn. 

386 for hostname in config.contact_points: 

387 try: 

388 username, password = dbauth.getAuth( 

389 "cassandra", config.username, hostname, config.port, config.keyspace 

390 ) 

391 if not username: 

392 # Password without user name, try next hostname, but give 

393 # warning later if no better match is found. 

394 empty_username = True 

395 else: 

396 return PlainTextAuthProvider(username=username, password=password) 

397 except DbAuthNotFoundError: 

398 pass 

399 

400 if empty_username: 

401 _LOG.warning( 

402 f"Credentials file ({DB_AUTH_PATH} or ${DB_AUTH_ENVVAR}) provided password but not " 

403 f"user name, anonymous Cassandra logon will be attempted." 

404 ) 

405 

406 return None 

407 
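# Editor's note (not part of the original module): the credentials lookup above
# uses lsst.utils.db_auth.DbAuth, which reads ~/.lsst/db-auth.yaml (or the file
# named by $LSST_DB_AUTH).  Based on how getAuth() is called above, a matching
# entry might look like the following; the exact URL pattern is an assumption.
#
#   - url: cassandra://cassandra-1.example.org:9042/apdb
#     username: apdb_user
#     password: not-a-real-password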

408 def _versionCheck(self, metadata: ApdbMetadataCassandra) -> None: 

409 """Check schema version compatibility.""" 

410 

411 def _get_version(key: str, default: VersionTuple) -> VersionTuple: 

412 """Retrieve version number from given metadata key.""" 

413 if metadata.table_exists(): 

414 version_str = metadata.get(key) 

415 if version_str is None: 

416 # Should not happen with existing metadata table. 

417 raise RuntimeError(f"Version key {key!r} does not exist in metadata table.") 

418 return VersionTuple.fromString(version_str) 

419 return default 

420 

421 # For old databases where metadata table does not exist we assume that 

422 # version of both code and schema is 0.1.0. 

423 initial_version = VersionTuple(0, 1, 0) 

424 db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version) 

425 db_code_version = _get_version(self.metadataCodeVersionKey, initial_version) 

426 

427 # For now there is no way to make read-only APDB instances, assume that 

428 # any access can do updates. 

429 if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True): 

430 raise IncompatibleVersionError( 

431 f"Configured schema version {self._schema.schemaVersion()} " 

432 f"is not compatible with database version {db_schema_version}" 

433 ) 

434 if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True): 

435 raise IncompatibleVersionError( 

436 f"Current code version {self.apdbImplementationVersion()} " 

437 f"is not compatible with database version {db_code_version}" 

438 ) 

439 

440 # Check replica code version only if replica is enabled. 

441 if self._schema.has_replica_chunks: 

442 db_replica_version = _get_version(self.metadataReplicaVersionKey, initial_version) 

443 code_replica_version = ApdbCassandraReplica.apdbReplicaImplementationVersion() 

444 if not code_replica_version.checkCompatibility(db_replica_version, True): 

445 raise IncompatibleVersionError( 

446 f"Current replication code version {code_replica_version} " 

447 f"is not compatible with database version {db_replica_version}" 

448 ) 

449 

450 @classmethod 

451 def apdbImplementationVersion(cls) -> VersionTuple: 

452 # Docstring inherited from base class. 

453 return VERSION 

454 

455 def apdbSchemaVersion(self) -> VersionTuple: 

456 # Docstring inherited from base class. 

457 return self._schema.schemaVersion() 

458 

459 def tableDef(self, table: ApdbTables) -> Table | None: 

460 # docstring is inherited from a base class 

461 return self._schema.tableSchemas.get(table) 

462 

463 @classmethod 

464 def init_database( 

465 cls, 

466 hosts: list[str], 

467 keyspace: str, 

468 *, 

469 schema_file: str | None = None, 

470 schema_name: str | None = None, 

471 read_sources_months: int | None = None, 

472 read_forced_sources_months: int | None = None, 

473 use_insert_id: bool = False, 

474 use_insert_id_skips_diaobjects: bool = False, 

475 port: int | None = None, 

476 username: str | None = None, 

477 prefix: str | None = None, 

478 part_pixelization: str | None = None, 

479 part_pix_level: int | None = None, 

480 time_partition_tables: bool = True, 

481 time_partition_start: str | None = None, 

482 time_partition_end: str | None = None, 

483 read_consistency: str | None = None, 

484 write_consistency: str | None = None, 

485 read_timeout: int | None = None, 

486 write_timeout: int | None = None, 

487 ra_dec_columns: list[str] | None = None, 

488 replication_factor: int | None = None, 

489 drop: bool = False, 

490 ) -> ApdbCassandraConfig: 

491 """Initialize new APDB instance and make configuration object for it. 

492 

493 Parameters 

494 ---------- 

495 hosts : `list` [`str`] 

496 List of host names or IP addresses for Cassandra cluster. 

497 keyspace : `str` 

498 Name of the keyspace for APDB tables. 

499 schema_file : `str`, optional 

500 Location of (YAML) configuration file with APDB schema. If not 

501 specified then default location will be used. 

502 schema_name : `str`, optional 

503 Name of the schema in YAML configuration file. If not specified 

504 then default name will be used. 

505 read_sources_months : `int`, optional 

506 Number of months of history to read from DiaSource. 

507 read_forced_sources_months : `int`, optional 

508 Number of months of history to read from DiaForcedSource. 

509 use_insert_id : `bool`, optional 

510 If True, make additional tables used for replication to PPDB. 

511 use_insert_id_skips_diaobjects : `bool`, optional 

512 If `True` then do not fill regular ``DiaObject`` table when 

513 ``use_insert_id`` is `True`. 

514 port : `int`, optional 

515 Port number to use for Cassandra connections. 

516 username : `str`, optional 

517 User name for Cassandra connections. 

518 prefix : `str`, optional 

519 Optional prefix for all table names. 

520 part_pixelization : `str`, optional 

521 Name of the pixelization used for spatial partitioning. 

522 part_pix_level : `int`, optional 

523 Pixelization level. 

524 time_partition_tables : `bool`, optional 

525 Create per-partition tables. 

526 time_partition_start : `str`, optional 

527 Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

528 format, in TAI. 

529 time_partition_end : `str`, optional 

530 Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

531 format, in TAI. 

532 read_consistency : `str`, optional 

533 Name of the consistency level for read operations. 

534 write_consistency : `str`, optional 

535 Name of the consistency level for write operations. 

536 read_timeout : `int`, optional 

537 Read timeout in seconds. 

538 write_timeout : `int`, optional 

539 Write timeout in seconds. 

540 ra_dec_columns : `list` [`str`], optional 

541 Names of ra/dec columns in DiaObject table. 

542 replication_factor : `int`, optional 

543 Replication factor used when creating new keyspace; if the keyspace 

544 already exists, its replication factor is not changed. 

545 drop : `bool`, optional 

546 If `True` then drop existing tables before re-creating the schema. 

547 

548 Returns 

549 ------- 

550 config : `ApdbCassandraConfig` 

551 Resulting configuration object for a created APDB instance. 

552 """ 

553 config = ApdbCassandraConfig( 

554 contact_points=hosts, 

555 keyspace=keyspace, 

556 use_insert_id=use_insert_id, 

557 use_insert_id_skips_diaobjects=use_insert_id_skips_diaobjects, 

558 time_partition_tables=time_partition_tables, 

559 ) 

560 if schema_file is not None: 

561 config.schema_file = schema_file 

562 if schema_name is not None: 

563 config.schema_name = schema_name 

564 if read_sources_months is not None: 

565 config.read_sources_months = read_sources_months 

566 if read_forced_sources_months is not None: 

567 config.read_forced_sources_months = read_forced_sources_months 

568 if port is not None: 

569 config.port = port 

570 if username is not None: 

571 config.username = username 

572 if prefix is not None: 

573 config.prefix = prefix 

574 if part_pixelization is not None: 

575 config.part_pixelization = part_pixelization 

576 if part_pix_level is not None: 

577 config.part_pix_level = part_pix_level 

578 if time_partition_start is not None: 

579 config.time_partition_start = time_partition_start 

580 if time_partition_end is not None: 

581 config.time_partition_end = time_partition_end 

582 if read_consistency is not None: 

583 config.read_consistency = read_consistency 

584 if write_consistency is not None: 

585 config.write_consistency = write_consistency 

586 if read_timeout is not None: 

587 config.read_timeout = read_timeout 

588 if write_timeout is not None: 

589 config.write_timeout = write_timeout 

590 if ra_dec_columns is not None: 

591 config.ra_dec_columns = ra_dec_columns 

592 

593 cls._makeSchema(config, drop=drop, replication_factor=replication_factor) 

594 

595 return config 

596 
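# Editor's sketch (not part of the original module): a typical way to create a
# new APDB instance with init_database() above and then open it.  The host and
# keyspace names are assumptions.
#
#   >>> config = ApdbCassandra.init_database(
#   ...     hosts=["cassandra-1.example.org"],
#   ...     keyspace="apdb_test",
#   ...     use_insert_id=True,
#   ...     time_partition_tables=True,
#   ... )
#   >>> apdb = ApdbCassandra(config)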

597 def get_replica(self) -> ApdbCassandraReplica: 

598 """Return `ApdbReplica` instance for this database.""" 

599 # Note that this instance has to stay alive while replica exists, so 

600 # we pass a reference to self. 

601 return ApdbCassandraReplica(self, self._schema, self._session) 

602 

603 @classmethod 

604 def _makeSchema( 

605 cls, config: ApdbConfig, *, drop: bool = False, replication_factor: int | None = None 

606 ) -> None: 

607 # docstring is inherited from a base class 

608 

609 if not isinstance(config, ApdbCassandraConfig): 

610 raise TypeError(f"Unexpected type of configuration object: {type(config)}") 

611 

612 cluster, session = cls._make_session(config) 

613 

614 schema = ApdbCassandraSchema( 

615 session=session, 

616 keyspace=config.keyspace, 

617 schema_file=config.schema_file, 

618 schema_name=config.schema_name, 

619 prefix=config.prefix, 

620 time_partition_tables=config.time_partition_tables, 

621 enable_replica=config.use_insert_id, 

622 ) 

623 

624 # Ask schema to create all tables. 

625 if config.time_partition_tables: 

626 time_partition_start = astropy.time.Time(config.time_partition_start, format="isot", scale="tai") 

627 time_partition_end = astropy.time.Time(config.time_partition_end, format="isot", scale="tai") 

628 part_epoch = float(cls.partition_zero_epoch.mjd) 

629 part_days = config.time_partition_days 

630 part_range = ( 

631 cls._time_partition_cls(time_partition_start, part_epoch, part_days), 

632 cls._time_partition_cls(time_partition_end, part_epoch, part_days) + 1, 

633 ) 

634 schema.makeSchema(drop=drop, part_range=part_range, replication_factor=replication_factor) 

635 else: 

636 schema.makeSchema(drop=drop, replication_factor=replication_factor) 

637 

638 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

639 metadata = ApdbMetadataCassandra(session, meta_table_name, config.keyspace, "read_tuples", "write") 

640 

641 # Fill version numbers, overriding any values that existed before. 

642 if metadata.table_exists(): 

643 metadata.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True) 

644 metadata.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True) 

645 

646 if config.use_insert_id: 

647 # Only store replica code version if replica is enabled. 

648 metadata.set( 

649 cls.metadataReplicaVersionKey, 

650 str(ApdbCassandraReplica.apdbReplicaImplementationVersion()), 

651 force=True, 

652 ) 

653 

654 # Store frozen part of a configuration in metadata. 

655 freezer = ApdbConfigFreezer[ApdbCassandraConfig](cls._frozen_parameters) 

656 metadata.set(cls.metadataConfigKey, freezer.to_json(config), force=True) 

657 

658 cluster.shutdown() 

659 

660 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

661 # docstring is inherited from a base class 

662 

663 sp_where = self._spatial_where(region) 

664 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

665 

666 # We need to exclude extra partitioning columns from result. 

667 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast) 

668 what = ",".join(quote_id(column) for column in column_names) 

669 

670 table_name = self._schema.tableName(ApdbTables.DiaObjectLast) 

671 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"' 

672 statements: list[tuple] = [] 

673 for where, params in sp_where: 

674 full_query = f"{query} WHERE {where}" 

675 if params: 

676 statement = self._preparer.prepare(full_query) 

677 else: 

678 # If there are no params then it is likely that query has a 

679 # bunch of literals rendered already, no point trying to 

680 # prepare it because it's not reusable. 

681 statement = cassandra.query.SimpleStatement(full_query) 

682 statements.append((statement, params)) 

683 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

684 

685 with _MON.context_tags({"table": "DiaObject"}): 

686 _MON.add_record( 

687 "select_query_stats", values={"num_sp_part": len(sp_where), "num_queries": len(statements)} 

688 ) 

689 with self._timer("select_time"): 

690 objects = cast( 

691 pandas.DataFrame, 

692 select_concurrent( 

693 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

694 ), 

695 ) 

696 

697 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

698 return objects 

699 
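# Editor's sketch (not part of the original module): given an ApdbCassandra
# instance ``apdb``, reading the latest DiaObjects for a small sky region might
# look like this; the coordinates and radius are arbitrary assumptions.
#
#   >>> from lsst import sphgeom
#   >>> center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(10.0, -5.0))
#   >>> region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
#   >>> objects = apdb.getDiaObjects(region)   # pandas.DataFrame, one row per DiaObject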

700 def getDiaSources( 

701 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time 

702 ) -> pandas.DataFrame | None: 

703 # docstring is inherited from a base class 

704 months = self.config.read_sources_months 

705 if months == 0: 

706 return None 

707 mjd_end = visit_time.mjd 

708 mjd_start = mjd_end - months * 30 

709 

710 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

711 

712 def getDiaForcedSources( 

713 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time 

714 ) -> pandas.DataFrame | None: 

715 # docstring is inherited from a base class 

716 months = self.config.read_forced_sources_months 

717 if months == 0: 

718 return None 

719 mjd_end = visit_time.mjd 

720 mjd_start = mjd_end - months * 30 

721 

722 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

723 

724 def containsVisitDetector(self, visit: int, detector: int) -> bool: 

725 # docstring is inherited from a base class 

726 raise NotImplementedError() 

727 

728 def getSSObjects(self) -> pandas.DataFrame: 

729 # docstring is inherited from a base class 

730 tableName = self._schema.tableName(ApdbTables.SSObject) 

731 query = f'SELECT * from "{self._keyspace}"."{tableName}"' 

732 

733 objects = None 

734 with self._timer("select_time", tags={"table": "SSObject"}): 

735 result = self._session.execute(query, execution_profile="read_pandas") 

736 objects = result._current_rows 

737 

738 _LOG.debug("found %s SSObjects", objects.shape[0]) 

739 return objects 

740 

741 def store( 

742 self, 

743 visit_time: astropy.time.Time, 

744 objects: pandas.DataFrame, 

745 sources: pandas.DataFrame | None = None, 

746 forced_sources: pandas.DataFrame | None = None, 

747 ) -> None: 

748 # docstring is inherited from a base class 

749 

750 replica_chunk: ReplicaChunk | None = None 

751 if self._schema.has_replica_chunks: 

752 replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, self.config.replica_chunk_seconds) 

753 self._storeReplicaChunk(replica_chunk, visit_time) 

754 

755 # fill region partition column for DiaObjects 

756 objects = self._add_obj_part(objects) 

757 self._storeDiaObjects(objects, visit_time, replica_chunk) 

758 

759 if sources is not None: 

760 # copy apdb_part column from DiaObjects to DiaSources 

761 sources = self._add_src_part(sources, objects) 

762 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, replica_chunk) 

763 self._storeDiaSourcesPartitions(sources, visit_time, replica_chunk) 

764 

765 if forced_sources is not None: 

766 forced_sources = self._add_fsrc_part(forced_sources, objects) 

767 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, replica_chunk) 

768 
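# Editor's sketch (not part of the original module): one visit worth of data is
# written with store() above.  ``objects_df``, ``sources_df`` and
# ``forced_sources_df`` are hypothetical DataFrames that must carry the schema
# columns (including the configured ra/dec columns).
#
#   >>> import astropy.time
#   >>> visit_time = astropy.time.Time("2025-01-01T00:00:00", format="isot", scale="tai")
#   >>> apdb.store(visit_time, objects_df, sources_df, forced_sources_df)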

769 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

770 # docstring is inherited from a base class 

771 self._storeObjectsPandas(objects, ApdbTables.SSObject) 

772 

773 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

774 # docstring is inherited from a base class 

775 

776 # To update a record we need to know its exact primary key (including 

777 # partition key) so we start by querying for diaSourceId to find the 

778 # primary keys. 

779 

780 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition) 

781 # split it into 1k IDs per query 

782 selects: list[tuple] = [] 

783 for ids in chunk_iterable(idMap.keys(), 1_000): 

784 ids_str = ",".join(str(item) for item in ids) 

785 selects.append( 

786 ( 

787 ( 

788 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "apdb_replica_chunk" ' 

789 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})' 

790 ), 

791 {}, 

792 ) 

793 ) 

794 

795 # No need for DataFrame here, read data as tuples. 

796 result = cast( 

797 list[tuple[int, int, int, int | None]], 

798 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency), 

799 ) 

800 

801 # Make mapping from source ID to its partition. 

802 id2partitions: dict[int, tuple[int, int]] = {} 

803 id2chunk_id: dict[int, int] = {} 

804 for row in result: 

805 id2partitions[row[0]] = row[1:3] 

806 if row[3] is not None: 

807 id2chunk_id[row[0]] = row[3] 

808 

809 # make sure we know partitions for each ID 

810 if set(id2partitions) != set(idMap): 

811 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

812 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

813 

814 # Reassign in standard tables 

815 queries = cassandra.query.BatchStatement() 

816 table_name = self._schema.tableName(ApdbTables.DiaSource) 

817 for diaSourceId, ssObjectId in idMap.items(): 

818 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

819 values: tuple 

820 if self.config.time_partition_tables: 

821 query = ( 

822 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"' 

823 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

824 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?' 

825 ) 

826 values = (ssObjectId, apdb_part, diaSourceId) 

827 else: 

828 query = ( 

829 f'UPDATE "{self._keyspace}"."{table_name}"' 

830 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

831 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?' 

832 ) 

833 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId) 

834 queries.add(self._preparer.prepare(query), values) 

835 

836 # Reassign in replica tables, only if replication is enabled 

837 if id2chunk_id: 

838 # Filter out chunks that have been deleted already. There is a 

839 # potential race with concurrent removal of chunks, but it 

840 # should be handled by WHERE in UPDATE. 

841 known_ids = set() 

842 if replica_chunks := self.get_replica().getReplicaChunks(): 

843 known_ids = set(replica_chunk.id for replica_chunk in replica_chunks) 

844 id2chunk_id = {key: value for key, value in id2chunk_id.items() if value in known_ids} 

845 if id2chunk_id: 

846 table_name = self._schema.tableName(ExtraTables.DiaSourceChunks) 

847 for diaSourceId, ssObjectId in idMap.items(): 

848 if replica_chunk := id2chunk_id.get(diaSourceId): 

849 query = ( 

850 f'UPDATE "{self._keyspace}"."{table_name}" ' 

851 ' SET "ssObjectId" = ?, "diaObjectId" = NULL ' 

852 'WHERE "apdb_replica_chunk" = ? AND "diaSourceId" = ?' 

853 ) 

854 values = (ssObjectId, replica_chunk, diaSourceId) 

855 queries.add(self._preparer.prepare(query), values) 

856 

857 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

858 with self._timer("update_time", tags={"table": table_name}): 

859 self._session.execute(queries, execution_profile="write") 

860 
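# Editor's sketch (not part of the original module): reassigning DiaSources to
# a solar-system object uses a plain {diaSourceId: ssObjectId} mapping; the IDs
# below are made up.
#
#   >>> apdb.reassignDiaSources({123456789: 42, 123456790: 42})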

861 def dailyJob(self) -> None: 

862 # docstring is inherited from a base class 

863 pass 

864 

865 def countUnassociatedObjects(self) -> int: 

866 # docstring is inherited from a base class 

867 

868 # It's too inefficient to implement this for Cassandra with the current schema. 

869 raise NotImplementedError() 

870 

871 @property 

872 def metadata(self) -> ApdbMetadata: 

873 # docstring is inherited from a base class 

874 if self._metadata is None: 

875 raise RuntimeError("Database schema was not initialized.") 

876 return self._metadata 

877 

878 @classmethod 

879 def _makeProfiles(cls, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]: 

880 """Make all execution profiles used in the code.""" 

881 if config.private_ips: 

882 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

883 else: 

884 loadBalancePolicy = RoundRobinPolicy() 

885 

886 read_tuples_profile = ExecutionProfile( 

887 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

888 request_timeout=config.read_timeout, 

889 row_factory=cassandra.query.tuple_factory, 

890 load_balancing_policy=loadBalancePolicy, 

891 ) 

892 read_pandas_profile = ExecutionProfile( 

893 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

894 request_timeout=config.read_timeout, 

895 row_factory=pandas_dataframe_factory, 

896 load_balancing_policy=loadBalancePolicy, 

897 ) 

898 read_raw_profile = ExecutionProfile( 

899 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

900 request_timeout=config.read_timeout, 

901 row_factory=raw_data_factory, 

902 load_balancing_policy=loadBalancePolicy, 

903 ) 

904 # Profile to use with select_concurrent to return pandas data frame 

905 read_pandas_multi_profile = ExecutionProfile( 

906 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

907 request_timeout=config.read_timeout, 

908 row_factory=pandas_dataframe_factory, 

909 load_balancing_policy=loadBalancePolicy, 

910 ) 

911 # Profile to use with select_concurrent to return raw data (columns and 

912 # rows) 

913 read_raw_multi_profile = ExecutionProfile( 

914 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

915 request_timeout=config.read_timeout, 

916 row_factory=raw_data_factory, 

917 load_balancing_policy=loadBalancePolicy, 

918 ) 

919 write_profile = ExecutionProfile( 

920 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency), 

921 request_timeout=config.write_timeout, 

922 load_balancing_policy=loadBalancePolicy, 

923 ) 

924 # To replace default DCAwareRoundRobinPolicy 

925 default_profile = ExecutionProfile( 

926 load_balancing_policy=loadBalancePolicy, 

927 ) 

928 return { 

929 "read_tuples": read_tuples_profile, 

930 "read_pandas": read_pandas_profile, 

931 "read_raw": read_raw_profile, 

932 "read_pandas_multi": read_pandas_multi_profile, 

933 "read_raw_multi": read_raw_multi_profile, 

934 "write": write_profile, 

935 EXEC_PROFILE_DEFAULT: default_profile, 

936 } 

937 

938 def _getSources( 

939 self, 

940 region: sphgeom.Region, 

941 object_ids: Iterable[int] | None, 

942 mjd_start: float, 

943 mjd_end: float, 

944 table_name: ApdbTables, 

945 ) -> pandas.DataFrame: 

946 """Return catalog of DiaSource instances given set of DiaObject IDs. 

947 

948 Parameters 

949 ---------- 

950 region : `lsst.sphgeom.Region` 

951 Spherical region. 

952 object_ids : iterable [`int`] or `None` 

953 Collection of DiaObject IDs. 

954 mjd_start : `float` 

955 Lower bound of time interval. 

956 mjd_end : `float` 

957 Upper bound of time interval. 

958 table_name : `ApdbTables` 

959 Name of the table. 

960 

961 Returns 

962 ------- 

963 catalog : `pandas.DataFrame` 

964 Catalog containing DiaSource records. Empty catalog is returned if 

965 ``object_ids`` is empty. 

966 """ 

967 object_id_set: Set[int] = set() 

968 if object_ids is not None: 

969 object_id_set = set(object_ids) 

970 if len(object_id_set) == 0: 

971 return self._make_empty_catalog(table_name) 

972 

973 sp_where = self._spatial_where(region) 

974 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end) 

975 

976 # We need to exclude extra partitioning columns from result. 

977 column_names = self._schema.apdbColumnNames(table_name) 

978 what = ",".join(quote_id(column) for column in column_names) 

979 

980 # Build all queries 

981 statements: list[tuple] = [] 

982 for table in tables: 

983 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"' 

984 statements += list(self._combine_where(prefix, sp_where, temporal_where)) 

985 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

986 

987 with _MON.context_tags({"table": table_name.name}): 

988 _MON.add_record( 

989 "select_query_stats", values={"num_sp_part": len(sp_where), "num_queries": len(statements)} 

990 ) 

991 with self._timer("select_time"): 

992 catalog = cast( 

993 pandas.DataFrame, 

994 select_concurrent( 

995 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

996 ), 

997 ) 

998 

999 # filter by given object IDs 

1000 if len(object_id_set) > 0: 

1001 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

1002 

1003 # precise filtering on midpointMjdTai 

1004 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start]) 

1005 

1006 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

1007 return catalog 

1008 

1009 def _storeReplicaChunk(self, replica_chunk: ReplicaChunk, visit_time: astropy.time.Time) -> None: 

1010 # Cassandra timestamp uses milliseconds since epoch 

1011 timestamp = int(replica_chunk.last_update_time.unix_tai * 1000) 

1012 

1013 # everything goes into a single partition 

1014 partition = 0 

1015 

1016 table_name = self._schema.tableName(ExtraTables.ApdbReplicaChunks) 

1017 query = ( 

1018 f'INSERT INTO "{self._keyspace}"."{table_name}" ' 

1019 "(partition, apdb_replica_chunk, last_update_time, unique_id) " 

1020 "VALUES (?, ?, ?, ?)" 

1021 ) 

1022 

1023 self._session.execute( 

1024 self._preparer.prepare(query), 

1025 (partition, replica_chunk.id, timestamp, replica_chunk.unique_id), 

1026 timeout=self.config.write_timeout, 

1027 execution_profile="write", 

1028 ) 

1029 

1030 def _storeDiaObjects( 

1031 self, objs: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None 

1032 ) -> None: 

1033 """Store catalog of DiaObjects from current visit. 

1034 

1035 Parameters 

1036 ---------- 

1037 objs : `pandas.DataFrame` 

1038 Catalog with DiaObject records 

1039 visit_time : `astropy.time.Time` 

1040 Time of the current visit. 

1041 replica_chunk : `ReplicaChunk` or `None` 

1042 Replica chunk identifier if replication is configured. 

1043 """ 

1044 if len(objs) == 0: 

1045 _LOG.debug("No objects to write to database.") 

1046 return 

1047 

1048 visit_time_dt = visit_time.datetime 

1049 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

1050 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

1051 

1052 extra_columns["validityStart"] = visit_time_dt 

1053 time_part: int | None = self._time_partition(visit_time) 

1054 if not self.config.time_partition_tables: 

1055 extra_columns["apdb_time_part"] = time_part 

1056 time_part = None 

1057 

1058 # Only store DiaObjects if not doing replication or explicitly 

1059 # configured to always store them. 

1060 if replica_chunk is None or not self.config.use_insert_id_skips_diaobjects: 

1061 self._storeObjectsPandas( 

1062 objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part 

1063 ) 

1064 

1065 if replica_chunk is not None: 

1066 extra_columns = dict(apdb_replica_chunk=replica_chunk.id, validityStart=visit_time_dt) 

1067 self._storeObjectsPandas(objs, ExtraTables.DiaObjectChunks, extra_columns=extra_columns) 

1068 

1069 def _storeDiaSources( 

1070 self, 

1071 table_name: ApdbTables, 

1072 sources: pandas.DataFrame, 

1073 visit_time: astropy.time.Time, 

1074 replica_chunk: ReplicaChunk | None, 

1075 ) -> None: 

1076 """Store catalog of DIASources or DIAForcedSources from current visit. 

1077 

1078 Parameters 

1079 ---------- 

1080 table_name : `ApdbTables` 

1081 Table where to store the data. 

1082 sources : `pandas.DataFrame` 

1083 Catalog containing DiaSource records 

1084 visit_time : `astropy.time.Time` 

1085 Time of the current visit. 

1086 replica_chunk : `ReplicaChunk` or `None` 

1087 Replica chunk identifier if replication is configured. 

1088 """ 

1089 time_part: int | None = self._time_partition(visit_time) 

1090 extra_columns: dict[str, Any] = {} 

1091 if not self.config.time_partition_tables: 

1092 extra_columns["apdb_time_part"] = time_part 

1093 time_part = None 

1094 

1095 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part) 

1096 

1097 if replica_chunk is not None: 

1098 extra_columns = dict(apdb_replica_chunk=replica_chunk.id) 

1099 if table_name is ApdbTables.DiaSource: 

1100 extra_table = ExtraTables.DiaSourceChunks 

1101 else: 

1102 extra_table = ExtraTables.DiaForcedSourceChunks 

1103 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns) 

1104 

1105 def _storeDiaSourcesPartitions( 

1106 self, sources: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None 

1107 ) -> None: 

1108 """Store mapping of diaSourceId to its partitioning values. 

1109 

1110 Parameters 

1111 ---------- 

1112 sources : `pandas.DataFrame` 

1113 Catalog containing DiaSource records 

1114 visit_time : `astropy.time.Time` 

1115 Time of the current visit. 

1116 """ 

1117 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

1118 extra_columns = { 

1119 "apdb_time_part": self._time_partition(visit_time), 

1120 "apdb_replica_chunk": replica_chunk.id if replica_chunk is not None else None, 

1121 } 

1122 

1123 self._storeObjectsPandas( 

1124 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

1125 ) 

1126 

1127 def _storeObjectsPandas( 

1128 self, 

1129 records: pandas.DataFrame, 

1130 table_name: ApdbTables | ExtraTables, 

1131 extra_columns: Mapping | None = None, 

1132 time_part: int | None = None, 

1133 ) -> None: 

1134 """Store generic objects. 

1135 

1136 Takes Pandas catalog and stores a bunch of records in a table. 

1137 

1138 Parameters 

1139 ---------- 

1140 records : `pandas.DataFrame` 

1141 Catalog containing object records 

1142 table_name : `ApdbTables` 

1143 Name of the table as defined in APDB schema. 

1144 extra_columns : `dict`, optional 

1145 Mapping (column_name, column_value) which gives fixed values for 

1146 columns in each row, overrides values in ``records`` if matching 

1147 columns exist there. 

1148 time_part : `int`, optional 

1149 If not `None` then insert into a per-partition table. 

1150 

1151 Notes 

1152 ----- 

1153 If Pandas catalog contains additional columns not defined in table 

1154 schema they are ignored. Catalog does not have to contain all columns 

1155 defined in a table, but partition and clustering keys must be present 

1156 in a catalog or ``extra_columns``. 

1157 """ 

1158 # use extra columns if specified 

1159 if extra_columns is None: 

1160 extra_columns = {} 

1161 extra_fields = list(extra_columns.keys()) 

1162 

1163 # Fields that will come from dataframe. 

1164 df_fields = [column for column in records.columns if column not in extra_fields] 

1165 

1166 column_map = self._schema.getColumnMap(table_name) 

1167 # list of columns (as in felis schema) 

1168 fields = [column_map[field].name for field in df_fields if field in column_map] 

1169 fields += extra_fields 

1170 

1171 # check that all partitioning and clustering columns are defined 

1172 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns( 

1173 table_name 

1174 ) 

1175 missing_columns = [column for column in required_columns if column not in fields] 

1176 if missing_columns: 

1177 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

1178 

1179 qfields = [quote_id(field) for field in fields] 

1180 qfields_str = ",".join(qfields) 

1181 

1182 with self._timer("insert_build_time", tags={"table": table_name.name}): 

1183 table = self._schema.tableName(table_name) 

1184 if time_part is not None: 

1185 table = f"{table}_{time_part}" 

1186 

1187 holders = ",".join(["?"] * len(qfields)) 

1188 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})' 

1189 statement = self._preparer.prepare(query) 

1190 queries = cassandra.query.BatchStatement() 

1191 for rec in records.itertuples(index=False): 

1192 values = [] 

1193 for field in df_fields: 

1194 if field not in column_map: 

1195 continue 

1196 value = getattr(rec, field) 

1197 if column_map[field].datatype is felis.datamodel.DataType.timestamp: 

1198 if isinstance(value, pandas.Timestamp): 

1199 value = literal(value.to_pydatetime()) 

1200 else: 

1201 # Assume it's seconds since epoch, Cassandra 

1202 # datetime is in milliseconds 

1203 value = int(value * 1000) 

1204 values.append(literal(value)) 

1205 for field in extra_fields: 

1206 value = extra_columns[field] 

1207 values.append(literal(value)) 

1208 queries.add(statement, values) 

1209 

1210 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0]) 

1211 with self._timer("insert_time", tags={"table": table_name.name}): 

1212 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write") 

1213 

1214 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1215 """Calculate spatial partition for each record and add it to a 

1216 DataFrame. 

1217 

1218 Notes 

1219 ----- 

1220 This overrides any existing column in a DataFrame with the same name 

1221 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1222 returned. 

1223 """ 

1224 # calculate pixelization index for every DiaObject 

1225 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

1226 ra_col, dec_col = self.config.ra_dec_columns 

1227 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

1228 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1229 idx = self._pixelization.pixel(uv3d) 

1230 apdb_part[i] = idx 

1231 df = df.copy() 

1232 df["apdb_part"] = apdb_part 

1233 return df 

1234 

1235 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1236 """Add apdb_part column to DiaSource catalog. 

1237 

1238 Notes 

1239 ----- 

1240 This method copies apdb_part value from a matching DiaObject record. 

1241 DiaObject catalog needs to have an apdb_part column filled by 

1242 ``_add_obj_part`` method and DiaSource records need to be 

1243 associated with DiaObjects via ``diaObjectId`` column. 

1244 

1245 This overrides any existing column in a DataFrame with the same name 

1246 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1247 returned. 

1248 """ 

1249 pixel_id_map: dict[int, int] = { 

1250 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

1251 } 

1252 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

1253 ra_col, dec_col = self.config.ra_dec_columns 

1254 for i, (diaObjId, ra, dec) in enumerate( 

1255 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col]) 

1256 ): 

1257 if diaObjId == 0: 

1258 # DiaSources associated with SolarSystemObjects do not have an 

1259 # associated DiaObject hence we skip them and set partition 

1260 # based on its own ra/dec 

1261 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1262 idx = self._pixelization.pixel(uv3d) 

1263 apdb_part[i] = idx 

1264 else: 

1265 apdb_part[i] = pixel_id_map[diaObjId] 

1266 sources = sources.copy() 

1267 sources["apdb_part"] = apdb_part 

1268 return sources 

1269 

1270 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1271 """Add apdb_part column to DiaForcedSource catalog. 

1272 

1273 Notes 

1274 ----- 

1275 This method copies apdb_part value from a matching DiaObject record. 

1276 DiaObject catalog needs to have an apdb_part column filled by 

1277 ``_add_obj_part`` method and DiaForcedSource records need to be 

1278 associated with DiaObjects via ``diaObjectId`` column. 

1279 

1280 This overrides any existing column in a DataFrame with the same name 

1281 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1282 returned. 

1283 """ 

1284 pixel_id_map: dict[int, int] = { 

1285 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

1286 } 

1287 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

1288 for i, diaObjId in enumerate(sources["diaObjectId"]): 

1289 apdb_part[i] = pixel_id_map[diaObjId] 

1290 sources = sources.copy() 

1291 sources["apdb_part"] = apdb_part 

1292 return sources 

1293 

1294 @classmethod 

1295 def _time_partition_cls(cls, time: float | astropy.time.Time, epoch_mjd: float, part_days: int) -> int: 

1296 """Calculate time partition number for a given time. 

1297 

1298 Parameters 

1299 ---------- 

1300 time : `float` or `astropy.time.Time` 

1301 Time for which to calculate partition number. Can be float to mean 

1302 MJD or `astropy.time.Time` 

1303 epoch_mjd : `float` 

1304 Epoch time for partition 0. 

1305 part_days : `int` 

1306 Number of days per partition. 

1307 

1308 Returns 

1309 ------- 

1310 partition : `int` 

1311 Partition number for a given time. 

1312 """ 

1313 if isinstance(time, astropy.time.Time): 

1314 mjd = float(time.mjd) 

1315 else: 

1316 mjd = time 

1317 days_since_epoch = mjd - epoch_mjd 

1318 partition = int(days_since_epoch) // part_days 

1319 return partition 

1320 
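# Editor's note (not part of the original module): a worked example of the
# partition arithmetic above.  partition_zero_epoch is the Unix epoch, roughly
# MJD 40587, so with the default 30-day granularity an observation at MJD 60000
# lands in partition (60000 - 40587) // 30 = 19413 // 30 = 647.
#
#   >>> ApdbCassandra._time_partition_cls(60000.0, 40587.0, 30)
#   647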

1321 def _time_partition(self, time: float | astropy.time.Time) -> int: 

1322 """Calculate time partition number for a given time. 

1323 

1324 Parameters 

1325 ---------- 

1326 time : `float` or `astropy.time.Time` 

1327 Time for which to calculate partition number. Can be float to mean 

1328 MJD or `astropy.time.Time` 

1329 

1330 Returns 

1331 ------- 

1332 partition : `int` 

1333 Partition number for a given time. 

1334 """ 

1335 if isinstance(time, astropy.time.Time): 

1336 mjd = float(time.mjd) 

1337 else: 

1338 mjd = time 

1339 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

1340 partition = int(days_since_epoch) // self.config.time_partition_days 

1341 return partition 

1342 

1343 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

1344 """Make an empty catalog for a table with a given name. 

1345 

1346 Parameters 

1347 ---------- 

1348 table_name : `ApdbTables` 

1349 Name of the table. 

1350 

1351 Returns 

1352 ------- 

1353 catalog : `pandas.DataFrame` 

1354 An empty catalog. 

1355 """ 

1356 table = self._schema.tableSchemas[table_name] 

1357 

1358 data = { 

1359 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype)) 

1360 for columnDef in table.columns 

1361 } 

1362 return pandas.DataFrame(data) 

1363 

1364 def _combine_where( 

1365 self, 

1366 prefix: str, 

1367 where1: list[tuple[str, tuple]], 

1368 where2: list[tuple[str, tuple]], 

1369 suffix: str | None = None, 

1370 ) -> Iterator[tuple[cassandra.query.Statement, tuple]]: 

1371 """Make cartesian product of two parts of WHERE clause into a series 

1372 of statements to execute. 

1373 

1374 Parameters 

1375 ---------- 

1376 prefix : `str` 

1377 Initial statement prefix that comes before WHERE clause, e.g. 

1378 "SELECT * from Table" 

1379 """ 

1380 # If lists are empty use special sentinels. 

1381 if not where1: 

1382 where1 = [("", ())] 

1383 if not where2: 

1384 where2 = [("", ())] 

1385 

1386 for expr1, params1 in where1: 

1387 for expr2, params2 in where2: 

1388 full_query = prefix 

1389 wheres = [] 

1390 if expr1: 

1391 wheres.append(expr1) 

1392 if expr2: 

1393 wheres.append(expr2) 

1394 if wheres: 

1395 full_query += " WHERE " + " AND ".join(wheres) 

1396 if suffix: 

1397 full_query += " " + suffix 

1398 params = params1 + params2 

1399 if params: 

1400 statement = self._preparer.prepare(full_query) 

1401 else: 

1402 # If there are no params then it is likely that query 

1403 # has a bunch of literals rendered already, no point 

1404 # trying to prepare it. 

1405 statement = cassandra.query.SimpleStatement(full_query) 

1406 yield (statement, params) 

1407 

1408 def _spatial_where( 

1409 self, region: sphgeom.Region | None, use_ranges: bool = False 

1410 ) -> list[tuple[str, tuple]]: 

1411 """Generate expressions for spatial part of WHERE clause. 

1412 

1413 Parameters 

1414 ---------- 

1415 region : `sphgeom.Region` 

1416 Spatial region for query results. 

1417 use_ranges : `bool` 

1418 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <= 

1419 p2") instead of exact list of pixels. Should be set to True for 

1420 large regions covering very many pixels. 

1421 

1422 Returns 

1423 ------- 

1424 expressions : `list` [ `tuple` ] 

1425 Empty list is returned if ``region`` is `None`, otherwise a list 

1426 of one or more (expression, parameters) tuples 

1427 """ 

1428 if region is None: 

1429 return [] 

1430 if use_ranges: 

1431 pixel_ranges = self._pixelization.envelope(region) 

1432 expressions: list[tuple[str, tuple]] = [] 

1433 for lower, upper in pixel_ranges: 

1434 upper -= 1 

1435 if lower == upper: 

1436 expressions.append(('"apdb_part" = ?', (lower,))) 

1437 else: 

1438 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper))) 

1439 return expressions 

1440 else: 

1441 pixels = self._pixelization.pixels(region) 

1442 if self.config.query_per_spatial_part: 

1443 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels] 

1444 else: 

1445 pixels_str = ",".join([str(pix) for pix in pixels]) 

1446 return [(f'"apdb_part" IN ({pixels_str})', ())] 

1447 
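# Editor's note (not part of the original module): for illustration, with
# query_per_spatial_part=False the method above collapses the pixel list into a
# single IN expression, e.g.
#
#   [('"apdb_part" IN (12345,12346,12350)', ())]
#
# while query_per_spatial_part=True yields one ('"apdb_part" = ?', (pixel,))
# tuple per pixel.  The pixel indices shown are made-up values.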

1448 def _temporal_where( 

1449 self, 

1450 table: ApdbTables, 

1451 start_time: float | astropy.time.Time, 

1452 end_time: float | astropy.time.Time, 

1453 query_per_time_part: bool | None = None, 

1454 ) -> tuple[list[str], list[tuple[str, tuple]]]: 

1455 """Generate table names and expressions for temporal part of WHERE 

1456 clauses. 

1457 

1458 Parameters 

1459 ---------- 

1460 table : `ApdbTables` 

1461 Table to select from. 

1462 start_time : `astropy.time.Time` or `float` 

1463 Starting Datetime or MJD value of the time range. 

1464 end_time : `astropy.time.Time` or `float` 

1465 Ending Datetime or MJD value of the time range. 

1466 query_per_time_part : `bool`, optional 

1467 If None then use ``query_per_time_part`` from configuration. 

1468 

1469 Returns 

1470 ------- 

1471 tables : `list` [ `str` ] 

1472 List of the table names to query. 

1473 expressions : `list` [ `tuple` ] 

1474 A list of zero or more (expression, parameters) tuples. 

1475 """ 

1476 tables: list[str] 

1477 temporal_where: list[tuple[str, tuple]] = [] 

1478 table_name = self._schema.tableName(table) 

1479 time_part_start = self._time_partition(start_time) 

1480 time_part_end = self._time_partition(end_time) 

1481 time_parts = list(range(time_part_start, time_part_end + 1)) 

1482 if self.config.time_partition_tables: 

1483 tables = [f"{table_name}_{part}" for part in time_parts] 

1484 else: 

1485 tables = [table_name] 

1486 if query_per_time_part is None: 

1487 query_per_time_part = self.config.query_per_time_part 

1488 if query_per_time_part: 

1489 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts] 

1490 else: 

1491 time_part_list = ",".join([str(part) for part in time_parts]) 

1492 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())] 

1493 

1494 return tables, temporal_where