Coverage for python/lsst/dax/apdb/cassandra/apdbCassandra.py: 18%

613 statements  

coverage.py v7.5.1, created at 2024-05-16 03:20 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26import dataclasses 

27import json 

28import logging 

29from collections.abc import Iterable, Iterator, Mapping, Set 

30from typing import TYPE_CHECKING, Any, cast 

31 

32import numpy as np 

33import pandas 

34 

35# If cassandra-driver is not installed the module can still be imported 

36# but ApdbCassandra cannot be instantiated. 

37try: 

38 import cassandra 

39 import cassandra.query 

40 from cassandra.auth import AuthProvider, PlainTextAuthProvider 

41 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, Session 

42 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

43 

44 CASSANDRA_IMPORTED = True 

45except ImportError: 

46 CASSANDRA_IMPORTED = False 

47 

48import astropy.time 

49import felis.datamodel 

50from lsst import sphgeom 

51from lsst.pex.config import ChoiceField, Field, ListField 

52from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError 

53from lsst.utils.iteration import chunk_iterable 

54 

55from ..apdb import Apdb, ApdbConfig 

56from ..apdbConfigFreezer import ApdbConfigFreezer 

57from ..apdbReplica import ReplicaChunk 

58from ..apdbSchema import ApdbTables 

59from ..monitor import MonAgent 

60from ..pixelization import Pixelization 

61from ..schema_model import Table 

62from ..timer import Timer 

63from ..versionTuple import IncompatibleVersionError, VersionTuple 

64from .apdbCassandraReplica import ApdbCassandraReplica 

65from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables 

66from .apdbMetadataCassandra import ApdbMetadataCassandra 

67from .cassandra_utils import ( 

68 PreparedStatementCache, 

69 literal, 

70 pandas_dataframe_factory, 

71 quote_id, 

72 raw_data_factory, 

73 select_concurrent, 

74) 

75 

76if TYPE_CHECKING: 

77 from ..apdbMetadata import ApdbMetadata 

78 

79_LOG = logging.getLogger(__name__) 

80 

81_MON = MonAgent(__name__) 

82 

83VERSION = VersionTuple(0, 1, 0) 

84"""Version for the code controlling non-replication tables. This needs to be 

85updated following compatibility rules when schema produced by this code 

86changes. 

87""" 

88 

89# Copied from daf_butler. 

90DB_AUTH_ENVVAR = "LSST_DB_AUTH" 

91"""Default name of the environmental variable that will be used to locate DB 

92credentials configuration file. """ 

93 

94DB_AUTH_PATH = "~/.lsst/db-auth.yaml" 

95"""Default path at which it is expected that DB credentials are found.""" 

96 

97 

98class CassandraMissingError(Exception): 

99 def __init__(self) -> None: 

100 super().__init__("cassandra-driver module cannot be imported") 

101 

102 

103class ApdbCassandraConfig(ApdbConfig): 

104 """Configuration class for Cassandra-based APDB implementation.""" 

105 

106 contact_points = ListField[str]( 

107 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"] 

108 ) 

109 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[]) 

110 port = Field[int](doc="Port number to connect to.", default=9042) 

111 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb") 

112 username = Field[str]( 

113 doc=f"Cassandra user name, if empty then {DB_AUTH_PATH} has to provide it with password.", 

114 default="", 

115 ) 

116 read_consistency = Field[str]( 

117 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM" 

118 ) 

119 write_consistency = Field[str]( 

120 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM" 

121 ) 

122 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0) 

123 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=60.0) 

124 remove_timeout = Field[float](doc="Timeout in seconds for remove operations.", default=600.0) 

125 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500) 

126 protocol_version = Field[int]( 

127 doc="Cassandra protocol version to use, default is V4", 

128 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0, 

129 ) 

130 dia_object_columns = ListField[str]( 

131 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[] 

132 ) 

133 prefix = Field[str](doc="Prefix to add to table names", default="") 

134 part_pixelization = ChoiceField[str]( 

135 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

136 doc="Pixelization used for partitioning index.", 

137 default="mq3c", 

138 ) 

139 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10) 

140 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64) 

141 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table") 

142 timer = Field[bool](doc="If True then print/log timing information", default=False) 

143 time_partition_tables = Field[bool]( 

144 doc="Use per-partition tables for sources instead of partitioning by time", default=False 

145 ) 

146 time_partition_days = Field[int]( 

147 doc=( 

148 "Time partitioning granularity in days, this value must not be changed after database is " 

149 "initialized" 

150 ), 

151 default=30, 

152 ) 

153 time_partition_start = Field[str]( 

154 doc=( 

155 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

156 "This is used only when time_partition_tables is True." 

157 ), 

158 default="2018-12-01T00:00:00", 

159 ) 

160 time_partition_end = Field[str]( 

161 doc=( 

162 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

163 "This is used only when time_partition_tables is True." 

164 ), 

165 default="2030-01-01T00:00:00", 

166 ) 

167 query_per_time_part = Field[bool]( 

168 default=False, 

169 doc=( 

170 "If True then build separate query for each time partition, otherwise build one single query. " 

171 "This is only used when time_partition_tables is False in schema config." 

172 ), 

173 ) 

174 query_per_spatial_part = Field[bool]( 

175 default=False, 

176 doc="If True then build one query per spatial partition, otherwise build single query.", 

177 ) 

178 use_insert_id_skips_diaobjects = Field[bool]( 

179 default=False, 

180 doc=( 

181 "If True then do not store DiaObjects when use_insert_id is True " 

182 "(DiaObjectsChunks has the same data)." 

183 ), 

184 ) 

185 

186 
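# For illustration only: a minimal sketch of overriding a few of the fields
# above through standard pex_config assignment (all values are hypothetical).
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["10.0.0.1", "10.0.0.2"]
#     config.keyspace = "apdb_dev"
#     config.read_consistency = "ONE"
#     config.time_partition_tables = True
#     config.validate()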

187@dataclasses.dataclass 

188class _FrozenApdbCassandraConfig: 

189 """Part of the configuration that is saved in metadata table and read back. 

190 

191 The attributes are a subset of attributes in `ApdbCassandraConfig` class. 

192 

193 Parameters 

194 ---------- 

195 config : `ApdbCassandraConfig` 

196 Configuration used to copy initial values of attributes. 

197 """ 

198 

199 use_insert_id: bool 

200 part_pixelization: str 

201 part_pix_level: int 

202 ra_dec_columns: list[str] 

203 time_partition_tables: bool 

204 time_partition_days: int 

205 use_insert_id_skips_diaobjects: bool 

206 

207 def __init__(self, config: ApdbCassandraConfig): 

208 self.use_insert_id = config.use_insert_id 

209 self.part_pixelization = config.part_pixelization 

210 self.part_pix_level = config.part_pix_level 

211 self.ra_dec_columns = list(config.ra_dec_columns) 

212 self.time_partition_tables = config.time_partition_tables 

213 self.time_partition_days = config.time_partition_days 

214 self.use_insert_id_skips_diaobjects = config.use_insert_id_skips_diaobjects 

215 

216 def to_json(self) -> str: 

217 """Convert this instance to JSON representation.""" 

218 return json.dumps(dataclasses.asdict(self)) 

219 

220 def update(self, json_str: str) -> None: 

221 """Update attribute values from a JSON string. 

222 

223 Parameters 

224 ---------- 

225 json_str : str 

226 String containing JSON representation of configuration. 

227 """ 

228 data = json.loads(json_str) 

229 if not isinstance(data, dict): 

230 raise TypeError(f"JSON string must be convertible to object: {json_str!r}") 

231 allowed_names = {field.name for field in dataclasses.fields(self)} 

232 for key, value in data.items(): 

233 if key not in allowed_names: 

234 raise ValueError(f"JSON object contains unknown key: {key}") 

235 setattr(self, key, value) 

236 

237 
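# A minimal sketch of the intended round trip for the frozen configuration;
# the JSON string normally lives in the metadata table, and the values used
# here are hypothetical:
#
#     frozen = _FrozenApdbCassandraConfig(config)
#     json_str = frozen.to_json()
#     # ... json_str is written to and later read back from metadata ...
#     frozen.update(json_str)  # restores the same attribute values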

238if CASSANDRA_IMPORTED: 

239 

240 class _AddressTranslator(AddressTranslator): 

241 """Translate internal IP address to external. 

242 

243 Only used for docker-based setup, not a viable long-term solution. 

244 """ 

245 

246 def __init__(self, public_ips: list[str], private_ips: list[str]): 

247 self._map = dict((k, v) for k, v in zip(private_ips, public_ips)) 

248 

249 def translate(self, private_ip: str) -> str: 

250 return self._map.get(private_ip, private_ip) 

251 

252 

253class ApdbCassandra(Apdb): 

254 """Implementation of APDB database on to of Apache Cassandra. 

255 

256 The implementation is configured via standard ``pex_config`` mechanism 

257 using `ApdbCassandraConfig` configuration class. For an example of 

258 different configurations check config/ folder. 

259 

260 Parameters 

261 ---------- 

262 config : `ApdbCassandraConfig` 

263 Configuration object. 

264 """ 

265 

266 metadataSchemaVersionKey = "version:schema" 

267 """Name of the metadata key to store schema version number.""" 

268 

269 metadataCodeVersionKey = "version:ApdbCassandra" 

270 """Name of the metadata key to store code version number.""" 

271 

272 metadataReplicaVersionKey = "version:ApdbCassandraReplica" 

273 """Name of the metadata key to store replica code version number.""" 

274 

275 metadataConfigKey = "config:apdb-cassandra.json" 

276 """Name of the metadata key to store code version number.""" 

277 

278 _frozen_parameters = ( 

279 "use_insert_id", 

280 "part_pixelization", 

281 "part_pix_level", 

282 "ra_dec_columns", 

283 "time_partition_tables", 

284 "time_partition_days", 

285 "use_insert_id_skips_diaobjects", 

286 ) 

287 """Names of the config parameters to be frozen in metadata table.""" 

288 

289 partition_zero_epoch = astropy.time.Time(0, format="unix_tai") 

290 """Start time for partition 0, this should never be changed.""" 

291 

292 def __init__(self, config: ApdbCassandraConfig): 

293 if not CASSANDRA_IMPORTED: 

294 raise CassandraMissingError() 

295 

296 self._keyspace = config.keyspace 

297 

298 self._cluster, self._session = self._make_session(config) 

299 

300 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

301 self._metadata = ApdbMetadataCassandra( 

302 self._session, meta_table_name, config.keyspace, "read_tuples", "write" 

303 ) 

304 

305 # Read frozen config from metadata. 

306 config_json = self._metadata.get(self.metadataConfigKey) 

307 if config_json is not None: 

308 # Update config from metadata. 

309 freezer = ApdbConfigFreezer[ApdbCassandraConfig](self._frozen_parameters) 

310 self.config = freezer.update(config, config_json) 

311 else: 

312 self.config = config 

313 self.config.validate() 

314 

315 self._pixelization = Pixelization( 

316 self.config.part_pixelization, 

317 self.config.part_pix_level, 

318 config.part_pix_max_ranges, 

319 ) 

320 

321 self._schema = ApdbCassandraSchema( 

322 session=self._session, 

323 keyspace=self._keyspace, 

324 schema_file=self.config.schema_file, 

325 schema_name=self.config.schema_name, 

326 prefix=self.config.prefix, 

327 time_partition_tables=self.config.time_partition_tables, 

328 enable_replica=self.config.use_insert_id, 

329 ) 

330 self._partition_zero_epoch_mjd = float(self.partition_zero_epoch.mjd) 

331 

332 if self._metadata.table_exists(): 

333 self._versionCheck(self._metadata) 

334 

335 # Cache for prepared statements 

336 self._preparer = PreparedStatementCache(self._session) 

337 

338 _LOG.debug("ApdbCassandra Configuration:") 

339 for key, value in self.config.items(): 

340 _LOG.debug(" %s: %s", key, value) 

341 

342 self._timer_args: list[MonAgent | logging.Logger] = [_MON] 

343 if self.config.timer: 

344 self._timer_args.append(_LOG) 

345 

346 def __del__(self) -> None: 

347 if hasattr(self, "_cluster"): 

348 self._cluster.shutdown() 

349 

350 def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer: 

351 """Create `Timer` instance given its name.""" 

352 return Timer(name, *self._timer_args, tags=tags) 

353 

354 @classmethod 

355 def _make_session(cls, config: ApdbCassandraConfig) -> tuple[Cluster, Session]: 

356 """Make Cassandra session.""" 

357 addressTranslator: AddressTranslator | None = None 

358 if config.private_ips: 

359 addressTranslator = _AddressTranslator(list(config.contact_points), list(config.private_ips)) 

360 

361 cluster = Cluster( 

362 execution_profiles=cls._makeProfiles(config), 

363 contact_points=config.contact_points, 

364 port=config.port, 

365 address_translator=addressTranslator, 

366 protocol_version=config.protocol_version, 

367 auth_provider=cls._make_auth_provider(config), 

368 ) 

369 session = cluster.connect() 

370 # Disable result paging 

371 session.default_fetch_size = None 

372 

373 return cluster, session 

374 

375 @classmethod 

376 def _make_auth_provider(cls, config: ApdbCassandraConfig) -> AuthProvider | None: 

377 """Make Cassandra authentication provider instance.""" 

378 try: 

379 dbauth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR) 

380 except DbAuthNotFoundError: 

381 # Credentials file doesn't exist, use anonymous login. 

382 return None 

383 

384 empty_username = False 

385 # Try every contact point in turn. 

386 for hostname in config.contact_points: 

387 try: 

388 username, password = dbauth.getAuth( 

389 "cassandra", config.username, hostname, config.port, config.keyspace 

390 ) 

391 if not username: 

392 # Password without user name, try next hostname, but give 

393 # warning later if no better match is found. 

394 empty_username = True 

395 else: 

396 return PlainTextAuthProvider(username=username, password=password) 

397 except DbAuthNotFoundError: 

398 pass 

399 

400 if empty_username: 

401 _LOG.warning( 

402 f"Credentials file ({DB_AUTH_PATH} or ${DB_AUTH_ENVVAR}) provided password but not " 

403 f"user name, anonymous Cassandra logon will be attempted." 

404 ) 

405 

406 return None 

407 
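# For reference, a hypothetical ~/.lsst/db-auth.yaml entry that the
# DbAuth.getAuth() lookup above could match; the exact matching rules are
# defined by lsst.utils.db_auth, so treat this as a sketch only:
#
#     - url: cassandra://cassandra-host.example.org:9042/apdb
#       username: apdb_user
#       password: not-a-real-password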

408 def _versionCheck(self, metadata: ApdbMetadataCassandra) -> None: 

409 """Check schema version compatibility.""" 

410 

411 def _get_version(key: str, default: VersionTuple) -> VersionTuple: 

412 """Retrieve version number from given metadata key.""" 

413 if metadata.table_exists(): 

414 version_str = metadata.get(key) 

415 if version_str is None: 

416 # Should not happen with existing metadata table. 

417 raise RuntimeError(f"Version key {key!r} does not exist in metadata table.") 

418 return VersionTuple.fromString(version_str) 

419 return default 

420 

421 # For old databases where metadata table does not exist we assume that 

422 # version of both code and schema is 0.1.0. 

423 initial_version = VersionTuple(0, 1, 0) 

424 db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version) 

425 db_code_version = _get_version(self.metadataCodeVersionKey, initial_version) 

426 

427 # For now there is no way to make read-only APDB instances, assume that 

428 # any access can do updates. 

429 if not self._schema.schemaVersion().checkCompatibility(db_schema_version): 

430 raise IncompatibleVersionError( 

431 f"Configured schema version {self._schema.schemaVersion()} " 

432 f"is not compatible with database version {db_schema_version}" 

433 ) 

434 if not self.apdbImplementationVersion().checkCompatibility(db_code_version): 

435 raise IncompatibleVersionError( 

436 f"Current code version {self.apdbImplementationVersion()} " 

437 f"is not compatible with database version {db_code_version}" 

438 ) 

439 

440 # Check replica code version only if replica is enabled. 

441 if self._schema.has_replica_chunks: 

442 db_replica_version = _get_version(self.metadataReplicaVersionKey, initial_version) 

443 code_replica_version = ApdbCassandraReplica.apdbReplicaImplementationVersion() 

444 if not code_replica_version.checkCompatibility(db_replica_version): 

445 raise IncompatibleVersionError( 

446 f"Current replication code version {code_replica_version} " 

447 f"is not compatible with database version {db_replica_version}" 

448 ) 

449 

450 @classmethod 

451 def apdbImplementationVersion(cls) -> VersionTuple: 

452 """Return version number for current APDB implementation. 

453 

454 Returns 

455 ------- 

456 version : `VersionTuple` 

457 Version of the code defined in implementation class. 

458 """ 

459 return VERSION 

460 

461 def tableDef(self, table: ApdbTables) -> Table | None: 

462 # docstring is inherited from a base class 

463 return self._schema.tableSchemas.get(table) 

464 

465 @classmethod 

466 def init_database( 

467 cls, 

468 hosts: list[str], 

469 keyspace: str, 

470 *, 

471 schema_file: str | None = None, 

472 schema_name: str | None = None, 

473 read_sources_months: int | None = None, 

474 read_forced_sources_months: int | None = None, 

475 use_insert_id: bool = False, 

476 use_insert_id_skips_diaobjects: bool = False, 

477 port: int | None = None, 

478 username: str | None = None, 

479 prefix: str | None = None, 

480 part_pixelization: str | None = None, 

481 part_pix_level: int | None = None, 

482 time_partition_tables: bool = True, 

483 time_partition_start: str | None = None, 

484 time_partition_end: str | None = None, 

485 read_consistency: str | None = None, 

486 write_consistency: str | None = None, 

487 read_timeout: int | None = None, 

488 write_timeout: int | None = None, 

489 ra_dec_columns: list[str] | None = None, 

490 replication_factor: int | None = None, 

491 drop: bool = False, 

492 ) -> ApdbCassandraConfig: 

493 """Initialize new APDB instance and make configuration object for it. 

494 

495 Parameters 

496 ---------- 

497 hosts : `list` [`str`] 

498 List of host names or IP addresses for Cassandra cluster. 

499 keyspace : `str` 

500 Name of the keyspace for APDB tables. 

501 schema_file : `str`, optional 

502 Location of (YAML) configuration file with APDB schema. If not 

503 specified then default location will be used. 

504 schema_name : `str`, optional 

505 Name of the schema in YAML configuration file. If not specified 

506 then default name will be used. 

507 read_sources_months : `int`, optional 

508 Number of months of history to read from DiaSource. 

509 read_forced_sources_months : `int`, optional 

510 Number of months of history to read from DiaForcedSource. 

511 use_insert_id : `bool`, optional 

512 If True, make additional tables used for replication to PPDB. 

513 use_insert_id_skips_diaobjects : `bool`, optional 

514 If `True` then do not fill regular ``DiaObject`` table when 

515 ``use_insert_id`` is `True`. 

516 port : `int`, optional 

517 Port number to use for Cassandra connections. 

518 username : `str`, optional 

519 User name for Cassandra connections. 

520 prefix : `str`, optional 

521 Optional prefix for all table names. 

522 part_pixelization : `str`, optional 

523 Name of the MOC pixelization used for partitioning. 

524 part_pix_level : `int`, optional 

525 Pixelization level. 

526 time_partition_tables : `bool`, optional 

527 Create per-partition tables. 

528 time_partition_start : `str`, optional 

529 Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

530 format, in TAI. 

531 time_partition_end : `str`, optional 

532 Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

533 format, in TAI. 

534 read_consistency : `str`, optional 

535 Name of the consistency level for read operations. 

536 write_consistency : `str`, optional 

537 Name of the consistency level for write operations. 

538 read_timeout : `int`, optional 

539 Read timeout in seconds. 

540 write_timeout : `int`, optional 

541 Write timeout in seconds. 

542 ra_dec_columns : `list` [`str`], optional 

543 Names of ra/dec columns in DiaObject table. 

544 replication_factor : `int`, optional 

545 Replication factor used when creating a new keyspace; if the keyspace 

546 already exists its replication factor is not changed. 

547 drop : `bool`, optional 

548 If `True` then drop existing tables before re-creating the schema. 

549 

550 Returns 

551 ------- 

552 config : `ApdbCassandraConfig` 

553 Resulting configuration object for a created APDB instance. 

554 """ 

555 config = ApdbCassandraConfig( 

556 contact_points=hosts, 

557 keyspace=keyspace, 

558 use_insert_id=use_insert_id, 

559 use_insert_id_skips_diaobjects=use_insert_id_skips_diaobjects, 

560 time_partition_tables=time_partition_tables, 

561 ) 

562 if schema_file is not None: 

563 config.schema_file = schema_file 

564 if schema_name is not None: 

565 config.schema_name = schema_name 

566 if read_sources_months is not None: 

567 config.read_sources_months = read_sources_months 

568 if read_forced_sources_months is not None: 

569 config.read_forced_sources_months = read_forced_sources_months 

570 if port is not None: 

571 config.port = port 

572 if username is not None: 

573 config.username = username 

574 if prefix is not None: 

575 config.prefix = prefix 

576 if part_pixelization is not None: 

577 config.part_pixelization = part_pixelization 

578 if part_pix_level is not None: 

579 config.part_pix_level = part_pix_level 

580 if time_partition_start is not None: 

581 config.time_partition_start = time_partition_start 

582 if time_partition_end is not None: 

583 config.time_partition_end = time_partition_end 

584 if read_consistency is not None: 

585 config.read_consistency = read_consistency 

586 if write_consistency is not None: 

587 config.write_consistency = write_consistency 

588 if read_timeout is not None: 

589 config.read_timeout = read_timeout 

590 if write_timeout is not None: 

591 config.write_timeout = write_timeout 

592 if ra_dec_columns is not None: 

593 config.ra_dec_columns = ra_dec_columns 

594 

595 cls._makeSchema(config, drop=drop, replication_factor=replication_factor) 

596 

597 return config 

598 
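# A minimal usage sketch for init_database(); host, keyspace and options are
# hypothetical:
#
#     config = ApdbCassandra.init_database(
#         hosts=["cassandra.example.org"],
#         keyspace="apdb",
#         use_insert_id=True,
#         replication_factor=3,
#     )
#     apdb = ApdbCassandra(config)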

599 def get_replica(self) -> ApdbCassandraReplica: 

600 """Return `ApdbReplica` instance for this database.""" 

601 # Note that this instance has to stay alive while replica exists, so 

602 # we pass reference to self. 

603 return ApdbCassandraReplica(self, self._schema, self._session) 

604 

605 @classmethod 

606 def _makeSchema( 

607 cls, config: ApdbConfig, *, drop: bool = False, replication_factor: int | None = None 

608 ) -> None: 

609 # docstring is inherited from a base class 

610 

611 if not isinstance(config, ApdbCassandraConfig): 

612 raise TypeError(f"Unexpected type of configuration object: {type(config)}") 

613 

614 cluster, session = cls._make_session(config) 

615 

616 schema = ApdbCassandraSchema( 

617 session=session, 

618 keyspace=config.keyspace, 

619 schema_file=config.schema_file, 

620 schema_name=config.schema_name, 

621 prefix=config.prefix, 

622 time_partition_tables=config.time_partition_tables, 

623 enable_replica=config.use_insert_id, 

624 ) 

625 

626 # Ask schema to create all tables. 

627 if config.time_partition_tables: 

628 time_partition_start = astropy.time.Time(config.time_partition_start, format="isot", scale="tai") 

629 time_partition_end = astropy.time.Time(config.time_partition_end, format="isot", scale="tai") 

630 part_epoch = float(cls.partition_zero_epoch.mjd) 

631 part_days = config.time_partition_days 

632 part_range = ( 

633 cls._time_partition_cls(time_partition_start, part_epoch, part_days), 

634 cls._time_partition_cls(time_partition_end, part_epoch, part_days) + 1, 

635 ) 

636 schema.makeSchema(drop=drop, part_range=part_range, replication_factor=replication_factor) 

637 else: 

638 schema.makeSchema(drop=drop, replication_factor=replication_factor) 

639 

640 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

641 metadata = ApdbMetadataCassandra(session, meta_table_name, config.keyspace, "read_tuples", "write") 

642 

643 # Fill version numbers, overriding any that existed before. 

644 if metadata.table_exists(): 

645 metadata.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True) 

646 metadata.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True) 

647 

648 if config.use_insert_id: 

649 # Only store replica code version if replica is enabled. 

650 metadata.set( 

651 cls.metadataReplicaVersionKey, 

652 str(ApdbCassandraReplica.apdbReplicaImplementationVersion()), 

653 force=True, 

654 ) 

655 

656 # Store frozen part of a configuration in metadata. 

657 freezer = ApdbConfigFreezer[ApdbCassandraConfig](cls._frozen_parameters) 

658 metadata.set(cls.metadataConfigKey, freezer.to_json(config), force=True) 

659 

660 cluster.shutdown() 

661 

662 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

663 # docstring is inherited from a base class 

664 

665 sp_where = self._spatial_where(region) 

666 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

667 

668 # We need to exclude extra partitioning columns from result. 

669 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast) 

670 what = ",".join(quote_id(column) for column in column_names) 

671 

672 table_name = self._schema.tableName(ApdbTables.DiaObjectLast) 

673 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"' 

674 statements: list[tuple] = [] 

675 for where, params in sp_where: 

676 full_query = f"{query} WHERE {where}" 

677 if params: 

678 statement = self._preparer.prepare(full_query) 

679 else: 

680 # If there are no params then it is likely that query has a 

681 # bunch of literals rendered already, no point trying to 

682 # prepare it because it's not reusable. 

683 statement = cassandra.query.SimpleStatement(full_query) 

684 statements.append((statement, params)) 

685 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

686 

687 with _MON.context_tags({"table": "DiaObject"}): 

688 _MON.add_record( 

689 "select_query_stats", values={"num_sp_part": len(sp_where), "num_queries": len(statements)} 

690 ) 

691 with self._timer("select_time"): 

692 objects = cast( 

693 pandas.DataFrame, 

694 select_concurrent( 

695 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

696 ), 

697 ) 

698 

699 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

700 return objects 

701 

702 def getDiaSources( 

703 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time 

704 ) -> pandas.DataFrame | None: 

705 # docstring is inherited from a base class 

706 months = self.config.read_sources_months 

707 if months == 0: 

708 return None 

709 mjd_end = visit_time.mjd 

710 mjd_start = mjd_end - months * 30 

711 

712 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

713 
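# For example, with read_sources_months = 12 the read window above covers
# 12 * 30 = 360 days ending at visit_time.mjd.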

714 def getDiaForcedSources( 

715 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time 

716 ) -> pandas.DataFrame | None: 

717 # docstring is inherited from a base class 

718 months = self.config.read_forced_sources_months 

719 if months == 0: 

720 return None 

721 mjd_end = visit_time.mjd 

722 mjd_start = mjd_end - months * 30 

723 

724 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

725 

726 def containsVisitDetector(self, visit: int, detector: int) -> bool: 

727 # docstring is inherited from a base class 

728 # The order of checks corresponds to order in store(), on potential 

729 # store failure earlier tables have a higher probability of containing 

730 # stored records. With per-partition tables there will be many tables 

731 # in the list, but it is unlikely that we'll use that setup in 

732 # production. 

733 existing_tables = self._schema.existing_tables(ApdbTables.DiaSource, ApdbTables.DiaForcedSource) 

734 tables_to_check = existing_tables[ApdbTables.DiaSource][:] 

735 if self.config.use_insert_id: 

736 tables_to_check.append(self._schema.tableName(ExtraTables.DiaSourceChunks)) 

737 tables_to_check.extend(existing_tables[ApdbTables.DiaForcedSource]) 

738 if self.config.use_insert_id: 

739 tables_to_check.append(self._schema.tableName(ExtraTables.DiaForcedSourceChunks)) 

740 

741 # I do not want to run concurrent queries as they are all full-scan 

742 # queries, so we run them one by one. 

743 for table_name in tables_to_check: 

744 # Try to find a single record with given visit/detector. This is a 

745 # full scan query so ALLOW FILTERING is needed. It will probably 

746 # guess PER PARTITION LIMIT itself, but let's help it. 

747 query = ( 

748 f'SELECT * from "{self._keyspace}"."{table_name}" ' 

749 "WHERE visit = ? AND detector = ? " 

750 "PER PARTITION LIMIT 1 LIMIT 1 ALLOW FILTERING" 

751 ) 

752 with self._timer("contains_visit_detector_time", tags={"table": table_name}): 

753 result = self._session.execute(self._preparer.prepare(query), (visit, detector)) 

754 if result.one() is not None: 

755 # There is a result. 

756 return True 

757 return False 

758 

759 def getSSObjects(self) -> pandas.DataFrame: 

760 # docstring is inherited from a base class 

761 tableName = self._schema.tableName(ApdbTables.SSObject) 

762 query = f'SELECT * from "{self._keyspace}"."{tableName}"' 

763 

764 objects = None 

765 with self._timer("select_time", tags={"table": "SSObject"}): 

766 result = self._session.execute(query, execution_profile="read_pandas") 

767 objects = result._current_rows 

768 

769 _LOG.debug("found %s SSObjects", objects.shape[0]) 

770 return objects 

771 

772 def store( 

773 self, 

774 visit_time: astropy.time.Time, 

775 objects: pandas.DataFrame, 

776 sources: pandas.DataFrame | None = None, 

777 forced_sources: pandas.DataFrame | None = None, 

778 ) -> None: 

779 # docstring is inherited from a base class 

780 

781 replica_chunk: ReplicaChunk | None = None 

782 if self._schema.has_replica_chunks: 

783 replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, self.config.replica_chunk_seconds) 

784 self._storeReplicaChunk(replica_chunk, visit_time) 

785 

786 # fill region partition column for DiaObjects 

787 objects = self._add_obj_part(objects) 

788 self._storeDiaObjects(objects, visit_time, replica_chunk) 

789 

790 if sources is not None: 

791 # copy apdb_part column from DiaObjects to DiaSources 

792 sources = self._add_src_part(sources, objects) 

793 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, replica_chunk) 

794 self._storeDiaSourcesPartitions(sources, visit_time, replica_chunk) 

795 

796 if forced_sources is not None: 

797 forced_sources = self._add_fsrc_part(forced_sources, objects) 

798 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, replica_chunk) 

799 
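# A typical call stores all per-visit catalogs together; the DataFrames here
# are hypothetical placeholders:
#
#     apdb.store(visit_time, objects=dia_objects, sources=dia_sources,
#                forced_sources=dia_forced_sources)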

800 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

801 # docstring is inherited from a base class 

802 self._storeObjectsPandas(objects, ApdbTables.SSObject) 

803 

804 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

805 # docstring is inherited from a base class 

806 

807 # To update a record we need to know its exact primary key (including 

808 # partition key) so we start by querying for diaSourceId to find the 

809 # primary keys. 

810 

811 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition) 

812 # split it into 1k IDs per query 

813 selects: list[tuple] = [] 

814 for ids in chunk_iterable(idMap.keys(), 1_000): 

815 ids_str = ",".join(str(item) for item in ids) 

816 selects.append( 

817 ( 

818 ( 

819 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "apdb_replica_chunk" ' 

820 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})' 

821 ), 

822 {}, 

823 ) 

824 ) 

825 

826 # No need for DataFrame here, read data as tuples. 

827 result = cast( 

828 list[tuple[int, int, int, int | None]], 

829 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency), 

830 ) 

831 

832 # Make mapping from source ID to its partition. 

833 id2partitions: dict[int, tuple[int, int]] = {} 

834 id2chunk_id: dict[int, int] = {} 

835 for row in result: 

836 id2partitions[row[0]] = row[1:3] 

837 if row[3] is not None: 

838 id2chunk_id[row[0]] = row[3] 

839 

840 # make sure we know partitions for each ID 

841 if set(id2partitions) != set(idMap): 

842 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

843 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

844 

845 # Reassign in standard tables 

846 queries = cassandra.query.BatchStatement() 

847 table_name = self._schema.tableName(ApdbTables.DiaSource) 

848 for diaSourceId, ssObjectId in idMap.items(): 

849 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

850 values: tuple 

851 if self.config.time_partition_tables: 

852 query = ( 

853 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"' 

854 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

855 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?' 

856 ) 

857 values = (ssObjectId, apdb_part, diaSourceId) 

858 else: 

859 query = ( 

860 f'UPDATE "{self._keyspace}"."{table_name}"' 

861 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

862 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?' 

863 ) 

864 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId) 

865 queries.add(self._preparer.prepare(query), values) 

866 

867 # Reassign in replica tables, only if replication is enabled 

868 if id2chunk_id: 

869 # Filter out chunks that have been deleted already. There is a 

870 # potential race with concurrent removal of chunks, but it 

871 # should be handled by WHERE in UPDATE. 

872 known_ids = set() 

873 if replica_chunks := self.get_replica().getReplicaChunks(): 

874 known_ids = set(replica_chunk.id for replica_chunk in replica_chunks) 

875 id2chunk_id = {key: value for key, value in id2chunk_id.items() if value in known_ids} 

876 if id2chunk_id: 

877 table_name = self._schema.tableName(ExtraTables.DiaSourceChunks) 

878 for diaSourceId, ssObjectId in idMap.items(): 

879 if replica_chunk := id2chunk_id.get(diaSourceId): 

880 query = ( 

881 f'UPDATE "{self._keyspace}"."{table_name}" ' 

882 ' SET "ssObjectId" = ?, "diaObjectId" = NULL ' 

883 'WHERE "apdb_replica_chunk" = ? AND "diaSourceId" = ?' 

884 ) 

885 values = (ssObjectId, replica_chunk, diaSourceId) 

886 queries.add(self._preparer.prepare(query), values) 

887 

888 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

889 with self._timer("update_time", tags={"table": table_name}): 

890 self._session.execute(queries, execution_profile="write") 

891 
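# For illustration, idMap maps diaSourceId to the new ssObjectId, e.g. a
# hypothetical call: apdb.reassignDiaSources({123456789: 20230001}).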

892 def dailyJob(self) -> None: 

893 # docstring is inherited from a base class 

894 pass 

895 

896 def countUnassociatedObjects(self) -> int: 

897 # docstring is inherited from a base class 

898 

899 # It's too inefficient to implement it for Cassandra in current schema. 

900 raise NotImplementedError() 

901 

902 @property 

903 def metadata(self) -> ApdbMetadata: 

904 # docstring is inherited from a base class 

905 if self._metadata is None: 

906 raise RuntimeError("Database schema was not initialized.") 

907 return self._metadata 

908 

909 @classmethod 

910 def _makeProfiles(cls, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]: 

911 """Make all execution profiles used in the code.""" 

912 if config.private_ips: 

913 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

914 else: 

915 loadBalancePolicy = RoundRobinPolicy() 

916 

917 read_tuples_profile = ExecutionProfile( 

918 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

919 request_timeout=config.read_timeout, 

920 row_factory=cassandra.query.tuple_factory, 

921 load_balancing_policy=loadBalancePolicy, 

922 ) 

923 read_pandas_profile = ExecutionProfile( 

924 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

925 request_timeout=config.read_timeout, 

926 row_factory=pandas_dataframe_factory, 

927 load_balancing_policy=loadBalancePolicy, 

928 ) 

929 read_raw_profile = ExecutionProfile( 

930 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

931 request_timeout=config.read_timeout, 

932 row_factory=raw_data_factory, 

933 load_balancing_policy=loadBalancePolicy, 

934 ) 

935 # Profile to use with select_concurrent to return pandas data frame 

936 read_pandas_multi_profile = ExecutionProfile( 

937 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

938 request_timeout=config.read_timeout, 

939 row_factory=pandas_dataframe_factory, 

940 load_balancing_policy=loadBalancePolicy, 

941 ) 

942 # Profile to use with select_concurrent to return raw data (columns and 

943 # rows) 

944 read_raw_multi_profile = ExecutionProfile( 

945 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

946 request_timeout=config.read_timeout, 

947 row_factory=raw_data_factory, 

948 load_balancing_policy=loadBalancePolicy, 

949 ) 

950 write_profile = ExecutionProfile( 

951 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency), 

952 request_timeout=config.write_timeout, 

953 load_balancing_policy=loadBalancePolicy, 

954 ) 

955 # To replace default DCAwareRoundRobinPolicy 

956 default_profile = ExecutionProfile( 

957 load_balancing_policy=loadBalancePolicy, 

958 ) 

959 return { 

960 "read_tuples": read_tuples_profile, 

961 "read_pandas": read_pandas_profile, 

962 "read_raw": read_raw_profile, 

963 "read_pandas_multi": read_pandas_multi_profile, 

964 "read_raw_multi": read_raw_multi_profile, 

965 "write": write_profile, 

966 EXEC_PROFILE_DEFAULT: default_profile, 

967 } 

968 

969 def _getSources( 

970 self, 

971 region: sphgeom.Region, 

972 object_ids: Iterable[int] | None, 

973 mjd_start: float, 

974 mjd_end: float, 

975 table_name: ApdbTables, 

976 ) -> pandas.DataFrame: 

977 """Return catalog of DiaSource instances given set of DiaObject IDs. 

978 

979 Parameters 

980 ---------- 

981 region : `lsst.sphgeom.Region` 

982 Spherical region. 

983 object_ids : iterable [`int`] or `None` 

984 Collection of DiaObject IDs. 

985 mjd_start : `float` 

986 Lower bound of time interval. 

987 mjd_end : `float` 

988 Upper bound of time interval. 

989 table_name : `ApdbTables` 

990 Name of the table. 

991 

992 Returns 

993 ------- 

994 catalog : `pandas.DataFrame` 

995 Catalog containing DiaSource records. Empty catalog is returned if 

996 ``object_ids`` is empty. 

997 """ 

998 object_id_set: Set[int] = set() 

999 if object_ids is not None: 

1000 object_id_set = set(object_ids) 

1001 if len(object_id_set) == 0: 

1002 return self._make_empty_catalog(table_name) 

1003 

1004 sp_where = self._spatial_where(region) 

1005 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end) 

1006 

1007 # We need to exclude extra partitioning columns from result. 

1008 column_names = self._schema.apdbColumnNames(table_name) 

1009 what = ",".join(quote_id(column) for column in column_names) 

1010 

1011 # Build all queries 

1012 statements: list[tuple] = [] 

1013 for table in tables: 

1014 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"' 

1015 statements += list(self._combine_where(prefix, sp_where, temporal_where)) 

1016 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

1017 

1018 with _MON.context_tags({"table": table_name.name}): 

1019 _MON.add_record( 

1020 "select_query_stats", values={"num_sp_part": len(sp_where), "num_queries": len(statements)} 

1021 ) 

1022 with self._timer("select_time"): 

1023 catalog = cast( 

1024 pandas.DataFrame, 

1025 select_concurrent( 

1026 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

1027 ), 

1028 ) 

1029 

1030 # filter by given object IDs 

1031 if len(object_id_set) > 0: 

1032 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

1033 

1034 # precise filtering on midpointMjdTai 

1035 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start]) 

1036 

1037 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

1038 return catalog 

1039 

1040 def _storeReplicaChunk(self, replica_chunk: ReplicaChunk, visit_time: astropy.time.Time) -> None: 

1041 # Cassandra timestamp uses milliseconds since epoch 

1042 timestamp = int(replica_chunk.last_update_time.unix_tai * 1000) 

1043 

1044 # everything goes into a single partition 

1045 partition = 0 

1046 

1047 table_name = self._schema.tableName(ExtraTables.ApdbReplicaChunks) 

1048 query = ( 

1049 f'INSERT INTO "{self._keyspace}"."{table_name}" ' 

1050 "(partition, apdb_replica_chunk, last_update_time, unique_id) " 

1051 "VALUES (?, ?, ?, ?)" 

1052 ) 

1053 

1054 self._session.execute( 

1055 self._preparer.prepare(query), 

1056 (partition, replica_chunk.id, timestamp, replica_chunk.unique_id), 

1057 timeout=self.config.write_timeout, 

1058 execution_profile="write", 

1059 ) 

1060 

1061 def _storeDiaObjects( 

1062 self, objs: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None 

1063 ) -> None: 

1064 """Store catalog of DiaObjects from current visit. 

1065 

1066 Parameters 

1067 ---------- 

1068 objs : `pandas.DataFrame` 

1069 Catalog with DiaObject records 

1070 visit_time : `astropy.time.Time` 

1071 Time of the current visit. 

1072 replica_chunk : `ReplicaChunk` or `None` 

1073 Replica chunk identifier if replication is configured. 

1074 """ 

1075 if len(objs) == 0: 

1076 _LOG.debug("No objects to write to database.") 

1077 return 

1078 

1079 visit_time_dt = visit_time.datetime 

1080 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

1081 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

1082 

1083 extra_columns["validityStart"] = visit_time_dt 

1084 time_part: int | None = self._time_partition(visit_time) 

1085 if not self.config.time_partition_tables: 

1086 extra_columns["apdb_time_part"] = time_part 

1087 time_part = None 

1088 

1089 # Only store DiaObjects if not doing replication or explicitly 

1090 # configured to always store them. 

1091 if replica_chunk is None or not self.config.use_insert_id_skips_diaobjects: 

1092 self._storeObjectsPandas( 

1093 objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part 

1094 ) 

1095 

1096 if replica_chunk is not None: 

1097 extra_columns = dict(apdb_replica_chunk=replica_chunk.id, validityStart=visit_time_dt) 

1098 self._storeObjectsPandas(objs, ExtraTables.DiaObjectChunks, extra_columns=extra_columns) 

1099 

1100 def _storeDiaSources( 

1101 self, 

1102 table_name: ApdbTables, 

1103 sources: pandas.DataFrame, 

1104 visit_time: astropy.time.Time, 

1105 replica_chunk: ReplicaChunk | None, 

1106 ) -> None: 

1107 """Store catalog of DIASources or DIAForcedSources from current visit. 

1108 

1109 Parameters 

1110 ---------- 

1111 table_name : `ApdbTables` 

1112 Table where to store the data. 

1113 sources : `pandas.DataFrame` 

1114 Catalog containing DiaSource records 

1115 visit_time : `astropy.time.Time` 

1116 Time of the current visit. 

1117 replica_chunk : `ReplicaChunk` or `None` 

1118 Replica chunk identifier if replication is configured. 

1119 """ 

1120 time_part: int | None = self._time_partition(visit_time) 

1121 extra_columns: dict[str, Any] = {} 

1122 if not self.config.time_partition_tables: 

1123 extra_columns["apdb_time_part"] = time_part 

1124 time_part = None 

1125 

1126 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part) 

1127 

1128 if replica_chunk is not None: 

1129 extra_columns = dict(apdb_replica_chunk=replica_chunk.id) 

1130 if table_name is ApdbTables.DiaSource: 

1131 extra_table = ExtraTables.DiaSourceChunks 

1132 else: 

1133 extra_table = ExtraTables.DiaForcedSourceChunks 

1134 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns) 

1135 

1136 def _storeDiaSourcesPartitions( 

1137 self, sources: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None 

1138 ) -> None: 

1139 """Store mapping of diaSourceId to its partitioning values. 

1140 

1141 Parameters 

1142 ---------- 

1143 sources : `pandas.DataFrame` 

1144 Catalog containing DiaSource records 

1145 visit_time : `astropy.time.Time` 

1146 Time of the current visit. 

replica_chunk : `ReplicaChunk` or `None` 

Replica chunk identifier if replication is configured. 

1147 """ 

1148 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

1149 extra_columns = { 

1150 "apdb_time_part": self._time_partition(visit_time), 

1151 "apdb_replica_chunk": replica_chunk.id if replica_chunk is not None else None, 

1152 } 

1153 

1154 self._storeObjectsPandas( 

1155 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

1156 ) 

1157 

1158 def _storeObjectsPandas( 

1159 self, 

1160 records: pandas.DataFrame, 

1161 table_name: ApdbTables | ExtraTables, 

1162 extra_columns: Mapping | None = None, 

1163 time_part: int | None = None, 

1164 ) -> None: 

1165 """Store generic objects. 

1166 

1167 Takes Pandas catalog and stores a bunch of records in a table. 

1168 

1169 Parameters 

1170 ---------- 

1171 records : `pandas.DataFrame` 

1172 Catalog containing object records 

1173 table_name : `ApdbTables` 

1174 Name of the table as defined in APDB schema. 

1175 extra_columns : `dict`, optional 

1176 Mapping (column_name, column_value) which gives fixed values for 

1177 columns in each row, overrides values in ``records`` if matching 

1178 columns exist there. 

1179 time_part : `int`, optional 

1180 If not `None` then insert into a per-partition table. 

1181 

1182 Notes 

1183 ----- 

1184 If Pandas catalog contains additional columns not defined in table 

1185 schema they are ignored. Catalog does not have to contain all columns 

1186 defined in a table, but partition and clustering keys must be present 

1187 in a catalog or ``extra_columns``. 

1188 """ 

1189 # use extra columns if specified 

1190 if extra_columns is None: 

1191 extra_columns = {} 

1192 extra_fields = list(extra_columns.keys()) 

1193 

1194 # Fields that will come from dataframe. 

1195 df_fields = [column for column in records.columns if column not in extra_fields] 

1196 

1197 column_map = self._schema.getColumnMap(table_name) 

1198 # list of columns (as in felis schema) 

1199 fields = [column_map[field].name for field in df_fields if field in column_map] 

1200 fields += extra_fields 

1201 

1202 # check that all partitioning and clustering columns are defined 

1203 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns( 

1204 table_name 

1205 ) 

1206 missing_columns = [column for column in required_columns if column not in fields] 

1207 if missing_columns: 

1208 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

1209 

1210 qfields = [quote_id(field) for field in fields] 

1211 qfields_str = ",".join(qfields) 

1212 

1213 with self._timer("insert_build_time", tags={"table": table_name.name}): 

1214 table = self._schema.tableName(table_name) 

1215 if time_part is not None: 

1216 table = f"{table}_{time_part}" 

1217 

1218 holders = ",".join(["?"] * len(qfields)) 

1219 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})' 

1220 statement = self._preparer.prepare(query) 

1221 queries = cassandra.query.BatchStatement() 

1222 for rec in records.itertuples(index=False): 

1223 values = [] 

1224 for field in df_fields: 

1225 if field not in column_map: 

1226 continue 

1227 value = getattr(rec, field) 

1228 if column_map[field].datatype is felis.datamodel.DataType.timestamp: 

1229 if isinstance(value, pandas.Timestamp): 

1230 value = literal(value.to_pydatetime()) 

1231 else: 

1232 # Assume it's seconds since epoch, Cassandra 

1233 # datetime is in milliseconds 

1234 value = int(value * 1000) 

1235 values.append(literal(value)) 

1236 for field in extra_fields: 

1237 value = extra_columns[field] 

1238 values.append(literal(value)) 

1239 queries.add(statement, values) 

1240 

1241 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0]) 

1242 with self._timer("insert_time", tags={"table": table_name.name}): 

1243 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write") 

1244 

1245 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1246 """Calculate spatial partition for each record and add it to a 

1247 DataFrame. 

1248 

1249 Notes 

1250 ----- 

1251 This overrides any existing column in a DataFrame with the same name 

1252 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1253 returned. 

1254 """ 

1255 # calculate spatial partition index for every DiaObject 

1256 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

1257 ra_col, dec_col = self.config.ra_dec_columns 

1258 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

1259 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1260 idx = self._pixelization.pixel(uv3d) 

1261 apdb_part[i] = idx 

1262 df = df.copy() 

1263 df["apdb_part"] = apdb_part 

1264 return df 

1265 

1266 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1267 """Add apdb_part column to DiaSource catalog. 

1268 

1269 Notes 

1270 ----- 

1271 This method copies apdb_part value from a matching DiaObject record. 

1272 DiaObject catalog needs to have an apdb_part column filled by 

1273 ``_add_obj_part`` method and DiaSource records need to be 

1274 associated to DiaObjects via ``diaObjectId`` column. 

1275 

1276 This overrides any existing column in a DataFrame with the same name 

1277 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1278 returned. 

1279 """ 

1280 pixel_id_map: dict[int, int] = { 

1281 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

1282 } 

1283 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

1284 ra_col, dec_col = self.config.ra_dec_columns 

1285 for i, (diaObjId, ra, dec) in enumerate( 

1286 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col]) 

1287 ): 

1288 if diaObjId == 0: 

1289 # DiaSources associated with SolarSystemObjects do not have an 

1290 # associated DiaObject hence we skip them and set partition 

1291 # based on their own ra/dec 

1292 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1293 idx = self._pixelization.pixel(uv3d) 

1294 apdb_part[i] = idx 

1295 else: 

1296 apdb_part[i] = pixel_id_map[diaObjId] 

1297 sources = sources.copy() 

1298 sources["apdb_part"] = apdb_part 

1299 return sources 

1300 

1301 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1302 """Add apdb_part column to DiaForcedSource catalog. 

1303 

1304 Notes 

1305 ----- 

1306 This method copies apdb_part value from a matching DiaObject record. 

1307 DiaObject catalog needs to have an apdb_part column filled by 

1308 ``_add_obj_part`` method and DiaForcedSource records need to be 

1309 associated to DiaObjects via ``diaObjectId`` column. 

1310 

1311 This overrides any existing column in a DataFrame with the same name 

1312 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1313 returned. 

1314 """ 

1315 pixel_id_map: dict[int, int] = { 

1316 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

1317 } 

1318 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

1319 for i, diaObjId in enumerate(sources["diaObjectId"]): 

1320 apdb_part[i] = pixel_id_map[diaObjId] 

1321 sources = sources.copy() 

1322 sources["apdb_part"] = apdb_part 

1323 return sources 

1324 

1325 @classmethod 

1326 def _time_partition_cls(cls, time: float | astropy.time.Time, epoch_mjd: float, part_days: int) -> int: 

1327 """Calculate time partition number for a given time. 

1328 

1329 Parameters 

1330 ---------- 

1331 time : `float` or `astropy.time.Time` 

1332 Time for which to calculate partition number. Can be float to mean 

1333 MJD or `astropy.time.Time` 

1334 epoch_mjd : `float` 

1335 Epoch time for partition 0. 

1336 part_days : `int` 

1337 Number of days per partition. 

1338 

1339 Returns 

1340 ------- 

1341 partition : `int` 

1342 Partition number for a given time. 

1343 """ 

1344 if isinstance(time, astropy.time.Time): 

1345 mjd = float(time.mjd) 

1346 else: 

1347 mjd = time 

1348 days_since_epoch = mjd - epoch_mjd 

1349 partition = int(days_since_epoch) // part_days 

1350 return partition 

1351 
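# Worked example with hypothetical numbers: for epoch_mjd = 58300.0 and
# part_days = 30, a time of MJD 58365.5 gives
#     int(58365.5 - 58300.0) // 30 = 65 // 30 = 2, i.e. partition 2.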

1352 def _time_partition(self, time: float | astropy.time.Time) -> int: 

1353 """Calculate time partition number for a given time. 

1354 

1355 Parameters 

1356 ---------- 

1357 time : `float` or `astropy.time.Time` 

1358 Time for which to calculate partition number. Can be float to mean 

1359 MJD or `astropy.time.Time` 

1360 

1361 Returns 

1362 ------- 

1363 partition : `int` 

1364 Partition number for a given time. 

1365 """ 

1366 if isinstance(time, astropy.time.Time): 

1367 mjd = float(time.mjd) 

1368 else: 

1369 mjd = time 

1370 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

1371 partition = int(days_since_epoch) // self.config.time_partition_days 

1372 return partition 

1373 

1374 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

1375 """Make an empty catalog for a table with a given name. 

1376 

1377 Parameters 

1378 ---------- 

1379 table_name : `ApdbTables` 

1380 Name of the table. 

1381 

1382 Returns 

1383 ------- 

1384 catalog : `pandas.DataFrame` 

1385 An empty catalog. 

1386 """ 

1387 table = self._schema.tableSchemas[table_name] 

1388 

1389 data = { 

1390 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype)) 

1391 for columnDef in table.columns 

1392 } 

1393 return pandas.DataFrame(data) 

1394 

1395 def _combine_where( 

1396 self, 

1397 prefix: str, 

1398 where1: list[tuple[str, tuple]], 

1399 where2: list[tuple[str, tuple]], 

1400 suffix: str | None = None, 

1401 ) -> Iterator[tuple[cassandra.query.Statement, tuple]]: 

1402 """Make cartesian product of two parts of WHERE clause into a series 

1403 of statements to execute. 

1404 

1405 Parameters 

1406 ---------- 

1407 prefix : `str` 

1408 Initial statement prefix that comes before WHERE clause, e.g. 

1409 "SELECT * from Table" 

1410 """ 

1411 # If lists are empty use special sentinels. 

1412 if not where1: 

1413 where1 = [("", ())] 

1414 if not where2: 

1415 where2 = [("", ())] 

1416 

1417 for expr1, params1 in where1: 

1418 for expr2, params2 in where2: 

1419 full_query = prefix 

1420 wheres = [] 

1421 if expr1: 

1422 wheres.append(expr1) 

1423 if expr2: 

1424 wheres.append(expr2) 

1425 if wheres: 

1426 full_query += " WHERE " + " AND ".join(wheres) 

1427 if suffix: 

1428 full_query += " " + suffix 

1429 params = params1 + params2 

1430 if params: 

1431 statement = self._preparer.prepare(full_query) 

1432 else: 

1433 # If there are no params then it is likely that query 

1434 # has a bunch of literals rendered already, no point 

1435 # trying to prepare it. 

1436 statement = cassandra.query.SimpleStatement(full_query) 

1437 yield (statement, params) 

1438 
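# For illustration, with prefix = 'SELECT * from "t"',
# where1 = [('"apdb_part" = ?', (1,))] and where2 = [('"apdb_time_part" = ?', (7,))]
# this yields one prepared statement equivalent to
#     SELECT * from "t" WHERE "apdb_part" = ? AND "apdb_time_part" = ?
# with params (1, 7); empty input lists fall back to the ("", ()) sentinel.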

1439 def _spatial_where( 

1440 self, region: sphgeom.Region | None, use_ranges: bool = False 

1441 ) -> list[tuple[str, tuple]]: 

1442 """Generate expressions for spatial part of WHERE clause. 

1443 

1444 Parameters 

1445 ---------- 

1446 region : `sphgeom.Region` 

1447 Spatial region for query results. 

1448 use_ranges : `bool` 

1449 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <= 

1450 p2") instead of exact list of pixels. Should be set to True for 

1451 large regions covering very many pixels. 

1452 

1453 Returns 

1454 ------- 

1455 expressions : `list` [ `tuple` ] 

1456 Empty list is returned if ``region`` is `None`, otherwise a list 

1457 of one or more (expression, parameters) tuples 

1458 """ 

1459 if region is None: 

1460 return [] 

1461 if use_ranges: 

1462 pixel_ranges = self._pixelization.envelope(region) 

1463 expressions: list[tuple[str, tuple]] = [] 

1464 for lower, upper in pixel_ranges: 

1465 upper -= 1 

1466 if lower == upper: 

1467 expressions.append(('"apdb_part" = ?', (lower,))) 

1468 else: 

1469 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper))) 

1470 return expressions 

1471 else: 

1472 pixels = self._pixelization.pixels(region) 

1473 if self.config.query_per_spatial_part: 

1474 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels] 

1475 else: 

1476 pixels_str = ",".join([str(pix) for pix in pixels]) 

1477 return [(f'"apdb_part" IN ({pixels_str})', ())] 

1478 
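# Shapes produced above, with hypothetical pixel indices:
#     use_ranges=True                          -> [('"apdb_part" >= ? AND "apdb_part" <= ?', (1024, 1040))]
#     use_ranges=False (single IN query)       -> [('"apdb_part" IN (1024,1025,1026)', ())]
#     use_ranges=False, query_per_spatial_part -> [('"apdb_part" = ?', (1024,)), ...]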

1479 def _temporal_where( 

1480 self, 

1481 table: ApdbTables, 

1482 start_time: float | astropy.time.Time, 

1483 end_time: float | astropy.time.Time, 

1484 query_per_time_part: bool | None = None, 

1485 ) -> tuple[list[str], list[tuple[str, tuple]]]: 

1486 """Generate table names and expressions for temporal part of WHERE 

1487 clauses. 

1488 

1489 Parameters 

1490 ---------- 

1491 table : `ApdbTables` 

1492 Table to select from. 

1493 start_time : `astropy.time.Time` or `float` 

1494 Starting Datetime or MJD value of the time range. 

1495 end_time : `astropy.time.Time` or `float` 

1496 Ending Datetime or MJD value of the time range. 

1497 query_per_time_part : `bool`, optional 

1498 If None then use ``query_per_time_part`` from configuration. 

1499 

1500 Returns 

1501 ------- 

1502 tables : `list` [ `str` ] 

1503 List of the table names to query. 

1504 expressions : `list` [ `tuple` ] 

1505 A list of zero or more (expression, parameters) tuples. 

1506 """ 

1507 tables: list[str] 

1508 temporal_where: list[tuple[str, tuple]] = [] 

1509 table_name = self._schema.tableName(table) 

1510 time_part_start = self._time_partition(start_time) 

1511 time_part_end = self._time_partition(end_time) 

1512 time_parts = list(range(time_part_start, time_part_end + 1)) 

1513 if self.config.time_partition_tables: 

1514 tables = [f"{table_name}_{part}" for part in time_parts] 

1515 else: 

1516 tables = [table_name] 

1517 if query_per_time_part is None: 

1518 query_per_time_part = self.config.query_per_time_part 

1519 if query_per_time_part: 

1520 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts] 

1521 else: 

1522 time_part_list = ",".join([str(part) for part in time_parts]) 

1523 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())] 

1524 

1525 return tables, temporal_where