Coverage for python/lsst/dax/apdb/apdbCassandra.py: 18% (619 statements)

coverage.py v7.4.4, created at 2024-04-13 09:59 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26import dataclasses 

27import json 

28import logging 

29import uuid 

30from collections.abc import Iterable, Iterator, Mapping, Set 

31from typing import TYPE_CHECKING, Any, cast 

32 

33import numpy as np 

34import pandas 

35 

36# If cassandra-driver is not there the module can still be imported 

37# but ApdbCassandra cannot be instantiated. 

38try: 

39 import cassandra 

40 import cassandra.query 

41 from cassandra.auth import AuthProvider, PlainTextAuthProvider 

42 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, Session 

43 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

44 

45 CASSANDRA_IMPORTED = True 

46except ImportError: 

47 CASSANDRA_IMPORTED = False 

48 

49import astropy.time 

50import felis.types 

51from felis.simple import Table 

52from lsst import sphgeom 

53from lsst.pex.config import ChoiceField, Field, ListField 

54from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError 

55from lsst.utils.iteration import chunk_iterable 

56 

57from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData 

58from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables 

59from .apdbConfigFreezer import ApdbConfigFreezer 

60from .apdbMetadataCassandra import ApdbMetadataCassandra 

61from .apdbSchema import ApdbTables 

62from .cassandra_utils import ( 

63 ApdbCassandraTableData, 

64 PreparedStatementCache, 

65 literal, 

66 pandas_dataframe_factory, 

67 quote_id, 

68 raw_data_factory, 

69 select_concurrent, 

70) 

71from .pixelization import Pixelization 

72from .timer import Timer 

73from .versionTuple import IncompatibleVersionError, VersionTuple 

74 

75if TYPE_CHECKING: 

76 from .apdbMetadata import ApdbMetadata 

77 

78_LOG = logging.getLogger(__name__) 

79 

80VERSION = VersionTuple(0, 1, 0) 

81"""Version for the code defined in this module. This needs to be updated 

82(following compatibility rules) when schema produced by this code changes. 

83""" 

84 

85# Copied from daf_butler. 

86DB_AUTH_ENVVAR = "LSST_DB_AUTH" 

87"""Default name of the environmental variable that will be used to locate DB 

88credentials configuration file. """ 

89 

90DB_AUTH_PATH = "~/.lsst/db-auth.yaml" 

91"""Default path at which it is expected that DB credentials are found.""" 

92 

93 

94class CassandraMissingError(Exception): 

95 def __init__(self) -> None: 

96 super().__init__("cassandra-driver module cannot be imported") 

97 

98 

99class ApdbCassandraConfig(ApdbConfig): 

100 """Configuration class for Cassandra-based APDB implementation.""" 

101 

102 contact_points = ListField[str]( 

103 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"] 

104 ) 

105 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[]) 

106 port = Field[int](doc="Port number to connect to.", default=9042) 

107 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb") 

108 username = Field[str]( 

109 doc=f"Cassandra user name, if empty then {DB_AUTH_PATH} has to provide it with password.", 

110 default="", 

111 ) 

112 read_consistency = Field[str]( 

113 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM" 

114 ) 

115 write_consistency = Field[str]( 

116 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM" 

117 ) 

118 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0) 

119 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=10.0) 

120 remove_timeout = Field[float](doc="Timeout in seconds for remove operations.", default=600.0) 

121 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500) 

122 protocol_version = Field[int]( 

123 doc="Cassandra protocol version to use, default is V4", 

124 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0, 

125 ) 

126 dia_object_columns = ListField[str]( 

127 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[] 

128 ) 

129 prefix = Field[str](doc="Prefix to add to table names", default="") 

130 part_pixelization = ChoiceField[str]( 

131 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

132 doc="Pixelization used for partitioning index.", 

133 default="mq3c", 

134 ) 

135 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10) 

136 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64) 

137 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table") 

138 timer = Field[bool](doc="If True then print/log timing information", default=False) 

139 time_partition_tables = Field[bool]( 

140 doc="Use per-partition tables for sources instead of partitioning by time", default=False 

141 ) 

142 time_partition_days = Field[int]( 

143 doc=( 

144 "Time partitioning granularity in days, this value must not be changed after database is " 

145 "initialized" 

146 ), 

147 default=30, 

148 ) 

149 time_partition_start = Field[str]( 

150 doc=( 

151 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

152 "This is used only when time_partition_tables is True." 

153 ), 

154 default="2018-12-01T00:00:00", 

155 ) 

156 time_partition_end = Field[str]( 

157 doc=( 

158 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

159 "This is used only when time_partition_tables is True." 

160 ), 

161 default="2030-01-01T00:00:00", 

162 ) 

163 query_per_time_part = Field[bool]( 

164 default=False, 

165 doc=( 

166 "If True then build separate query for each time partition, otherwise build one single query. " 

167 "This is only used when time_partition_tables is False in schema config." 

168 ), 

169 ) 

170 query_per_spatial_part = Field[bool]( 

171 default=False, 

172 doc="If True then build one query per spatial partition, otherwise build single query.", 

173 ) 

174 use_insert_id_skips_diaobjects = Field[bool]( 

175 default=False, 

176 doc=( 

177 "If True then do not store DiaObjects when use_insert_id is True " 

178 "(DiaObjectsInsertId has the same data)." 

179 ), 

180 ) 
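# The fields above are regular pex_config fields, so a configuration can be
# assembled in code or loaded from a config file. A minimal sketch (host and
# keyspace names are placeholders, not a recommended setup):
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["cassandra-host.example.com"]
#     config.keyspace = "apdb"
#     config.read_consistency = "ONE"
#     config.time_partition_tables = True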

181 

182 

183@dataclasses.dataclass 

184class _FrozenApdbCassandraConfig: 

185 """Part of the configuration that is saved in metadata table and read back. 

186 

187 The attributes are a subset of attributes in `ApdbCassandraConfig` class. 

188 

189 Parameters 

190 ---------- 

191 config : `ApdbCassandraConfig` 

192 Configuration used to copy initial values of attributes. 

193 """ 

194 

195 use_insert_id: bool 

196 part_pixelization: str 

197 part_pix_level: int 

198 ra_dec_columns: list[str] 

199 time_partition_tables: bool 

200 time_partition_days: int 

201 use_insert_id_skips_diaobjects: bool 

202 

203 def __init__(self, config: ApdbCassandraConfig): 

204 self.use_insert_id = config.use_insert_id 

205 self.part_pixelization = config.part_pixelization 

206 self.part_pix_level = config.part_pix_level 

207 self.ra_dec_columns = list(config.ra_dec_columns) 

208 self.time_partition_tables = config.time_partition_tables 

209 self.time_partition_days = config.time_partition_days 

210 self.use_insert_id_skips_diaobjects = config.use_insert_id_skips_diaobjects 

211 

212 def to_json(self) -> str: 

213 """Convert this instance to JSON representation.""" 

214 return json.dumps(dataclasses.asdict(self)) 

215 

216 def update(self, json_str: str) -> None: 

217 """Update attribute values from a JSON string. 

218 

219 Parameters 

220 ---------- 

221 json_str : str 

222 String containing JSON representation of configuration. 

223 """ 

224 data = json.loads(json_str) 

225 if not isinstance(data, dict): 

226 raise TypeError(f"JSON string must be convertible to object: {json_str!r}") 

227 allowed_names = {field.name for field in dataclasses.fields(self)} 

228 for key, value in data.items(): 

229 if key not in allowed_names: 

230 raise ValueError(f"JSON object contains unknown key: {key}") 

231 setattr(self, key, value) 
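# Sketch of the freeze/thaw round trip provided by this dataclass
# (illustration only; ApdbCassandra itself goes through ApdbConfigFreezer):
#
#     frozen = _FrozenApdbCassandraConfig(config)
#     json_str = frozen.to_json()
#     # ... later, e.g. after reading json_str back from the metadata table ...
#     restored = _FrozenApdbCassandraConfig(config)
#     restored.update(json_str)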

232 

233 

234if CASSANDRA_IMPORTED: 

235 

236 class _AddressTranslator(AddressTranslator): 

237 """Translate internal IP address to external. 

238 

239 Only used for docker-based setup, not viable long-term solution. 

240 """ 

241 

242 def __init__(self, public_ips: list[str], private_ips: list[str]): 

243 self._map = dict((k, v) for k, v in zip(private_ips, public_ips)) 

244 

245 def translate(self, private_ip: str) -> str: 

246 return self._map.get(private_ip, private_ip) 

247 

248 

249class ApdbCassandra(Apdb): 

250 """Implementation of APDB database on to of Apache Cassandra. 

251 

252 The implementation is configured via standard ``pex_config`` mechanism 

253 using `ApdbCassandraConfig` configuration class. For an example of 

254 different configurations check config/ folder. 

255 

256 Parameters 

257 ---------- 

258 config : `ApdbCassandraConfig` 

259 Configuration object. 

260 """ 

261 

262 metadataSchemaVersionKey = "version:schema" 

263 """Name of the metadata key to store schema version number.""" 

264 

265 metadataCodeVersionKey = "version:ApdbCassandra" 

266 """Name of the metadata key to store code version number.""" 

267 

268 metadataConfigKey = "config:apdb-cassandra.json" 

269 """Name of the metadata key to store code version number.""" 

270 

271 _frozen_parameters = ( 

272 "use_insert_id", 

273 "part_pixelization", 

274 "part_pix_level", 

275 "ra_dec_columns", 

276 "time_partition_tables", 

277 "time_partition_days", 

278 "use_insert_id_skips_diaobjects", 

279 ) 

280 """Names of the config parameters to be frozen in metadata table.""" 

281 

282 partition_zero_epoch = astropy.time.Time(0, format="unix_tai") 

283 """Start time for partition 0, this should never be changed.""" 

284 

285 def __init__(self, config: ApdbCassandraConfig): 

286 if not CASSANDRA_IMPORTED: 

287 raise CassandraMissingError() 

288 

289 self._keyspace = config.keyspace 

290 

291 self._cluster, self._session = self._make_session(config) 

292 

293 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

294 self._metadata = ApdbMetadataCassandra( 

295 self._session, meta_table_name, config.keyspace, "read_tuples", "write" 

296 ) 

297 

298 # Read frozen config from metadata. 

299 config_json = self._metadata.get(self.metadataConfigKey) 

300 if config_json is not None: 

301 # Update config from metadata. 

302 freezer = ApdbConfigFreezer[ApdbCassandraConfig](self._frozen_parameters) 

303 self.config = freezer.update(config, config_json) 

304 else: 

305 self.config = config 

306 self.config.validate() 

307 

308 self._pixelization = Pixelization( 

309 self.config.part_pixelization, 

310 self.config.part_pix_level, 

311 config.part_pix_max_ranges, 

312 ) 

313 

314 self._schema = ApdbCassandraSchema( 

315 session=self._session, 

316 keyspace=self._keyspace, 

317 schema_file=self.config.schema_file, 

318 schema_name=self.config.schema_name, 

319 prefix=self.config.prefix, 

320 time_partition_tables=self.config.time_partition_tables, 

321 use_insert_id=self.config.use_insert_id, 

322 ) 

323 self._partition_zero_epoch_mjd = float(self.partition_zero_epoch.mjd) 

324 

325 if self._metadata.table_exists(): 

326 self._versionCheck(self._metadata) 

327 

328 # Cache for prepared statements 

329 self._preparer = PreparedStatementCache(self._session) 

330 

331 _LOG.debug("ApdbCassandra Configuration:") 

332 for key, value in self.config.items(): 

333 _LOG.debug(" %s: %s", key, value) 

334 

335 def __del__(self) -> None: 

336 if hasattr(self, "_cluster"): 

337 self._cluster.shutdown() 

338 

339 @classmethod 

340 def _make_session(cls, config: ApdbCassandraConfig) -> tuple[Cluster, Session]: 

341 """Make Cassandra session.""" 

342 addressTranslator: AddressTranslator | None = None 

343 if config.private_ips: 

344 addressTranslator = _AddressTranslator(list(config.contact_points), list(config.private_ips)) 

345 

346 cluster = Cluster( 

347 execution_profiles=cls._makeProfiles(config), 

348 contact_points=config.contact_points, 

349 port=config.port, 

350 address_translator=addressTranslator, 

351 protocol_version=config.protocol_version, 

352 auth_provider=cls._make_auth_provider(config), 

353 ) 

354 session = cluster.connect() 

355 # Disable result paging 

356 session.default_fetch_size = None 

357 

358 return cluster, session 

359 

360 @classmethod 

361 def _make_auth_provider(cls, config: ApdbCassandraConfig) -> AuthProvider | None: 

362 """Make Cassandra authentication provider instance.""" 

363 try: 

364 dbauth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR) 

365 except DbAuthNotFoundError: 

366 # Credentials file doesn't exist, use anonymous login. 

367 return None 

368 

369 empty_username = False  # set to True if a password without user name is found 

370 # Try every contact point in turn. 

371 for hostname in config.contact_points: 

372 try: 

373 username, password = dbauth.getAuth( 

374 "cassandra", config.username, hostname, config.port, config.keyspace 

375 ) 

376 if not username: 

377 # Password without user name, try next hostname, but give 

378 # warning later if no better match is found. 

379 empty_username = True 

380 else: 

381 return PlainTextAuthProvider(username=username, password=password) 

382 except DbAuthNotFoundError: 

383 pass 

384 

385 if empty_username: 

386 _LOG.warning( 

387 f"Credentials file ({DB_AUTH_PATH} or ${DB_AUTH_ENVVAR}) provided password but not " 

388 f"user name, anonymous Cassandra logon will be attempted." 

389 ) 

390 

391 return None 

392 

393 def _versionCheck(self, metadata: ApdbMetadataCassandra) -> None: 

394 """Check schema version compatibility.""" 

395 

396 def _get_version(key: str, default: VersionTuple) -> VersionTuple: 

397 """Retrieve version number from given metadata key.""" 

398 if metadata.table_exists(): 

399 version_str = metadata.get(key) 

400 if version_str is None: 

401 # Should not happen with existing metadata table. 

402 raise RuntimeError(f"Version key {key!r} does not exist in metadata table.") 

403 return VersionTuple.fromString(version_str) 

404 return default 

405 

406 # For old databases where metadata table does not exist we assume that 

407 # version of both code and schema is 0.1.0. 

408 initial_version = VersionTuple(0, 1, 0) 

409 db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version) 

410 db_code_version = _get_version(self.metadataCodeVersionKey, initial_version) 

411 

412 # For now there is no way to make read-only APDB instances, assume that 

413 # any access can do updates. 

414 if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True): 

415 raise IncompatibleVersionError( 

416 f"Configured schema version {self._schema.schemaVersion()} " 

417 f"is not compatible with database version {db_schema_version}" 

418 ) 

419 if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True): 

420 raise IncompatibleVersionError( 

421 f"Current code version {self.apdbImplementationVersion()} " 

422 f"is not compatible with database version {db_code_version}" 

423 ) 

424 

425 @classmethod 

426 def apdbImplementationVersion(cls) -> VersionTuple: 

427 # Docstring inherited from base class. 

428 return VERSION 

429 

430 def apdbSchemaVersion(self) -> VersionTuple: 

431 # Docstring inherited from base class. 

432 return self._schema.schemaVersion() 

433 

434 def tableDef(self, table: ApdbTables) -> Table | None: 

435 # docstring is inherited from a base class 

436 return self._schema.tableSchemas.get(table) 

437 

438 @classmethod 

439 def init_database( 

440 cls, 

441 hosts: list[str], 

442 keyspace: str, 

443 *, 

444 schema_file: str | None = None, 

445 schema_name: str | None = None, 

446 read_sources_months: int | None = None, 

447 read_forced_sources_months: int | None = None, 

448 use_insert_id: bool = False, 

449 use_insert_id_skips_diaobjects: bool = False, 

450 port: int | None = None, 

451 username: str | None = None, 

452 prefix: str | None = None, 

453 part_pixelization: str | None = None, 

454 part_pix_level: int | None = None, 

455 time_partition_tables: bool = True, 

456 time_partition_start: str | None = None, 

457 time_partition_end: str | None = None, 

458 read_consistency: str | None = None, 

459 write_consistency: str | None = None, 

460 read_timeout: int | None = None, 

461 write_timeout: int | None = None, 

462 ra_dec_columns: list[str] | None = None, 

463 replication_factor: int | None = None, 

464 drop: bool = False, 

465 ) -> ApdbCassandraConfig: 

466 """Initialize new APDB instance and make configuration object for it. 

467 

468 Parameters 

469 ---------- 

470 hosts : `list` [`str`] 

471 List of host names or IP addresses for Cassandra cluster. 

472 keyspace : `str` 

473 Name of the keyspace for APDB tables. 

474 schema_file : `str`, optional 

475 Location of (YAML) configuration file with APDB schema. If not 

476 specified then default location will be used. 

477 schema_name : `str`, optional 

478 Name of the schema in YAML configuration file. If not specified 

479 then default name will be used. 

480 read_sources_months : `int`, optional 

481 Number of months of history to read from DiaSource. 

482 read_forced_sources_months : `int`, optional 

483 Number of months of history to read from DiaForcedSource. 

484 use_insert_id : `bool`, optional 

485 If True, make additional tables used for replication to PPDB. 

486 use_insert_id_skips_diaobjects : `bool`, optional 

487 If `True` then do not fill regular ``DiaObject`` table when 

488 ``use_insert_id`` is `True`. 

489 port : `int`, optional 

490 Port number to use for Cassandra connections. 

491 username : `str`, optional 

492 User name for Cassandra connections. 

493 prefix : `str`, optional 

494 Optional prefix for all table names. 

495 part_pixelization : `str`, optional 

496 Name of the pixelization used for partitioning ("htm", "q3c" or "mq3c"). 

497 part_pix_level : `int`, optional 

498 Pixelization level. 

499 time_partition_tables : `bool`, optional 

500 Create per-partition tables. 

501 time_partition_start : `str`, optional 

502 Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

503 format, in TAI. 

504 time_partition_end : `str`, optional 

505 Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss 

506 format, in TAI. 

507 read_consistency : `str`, optional 

508 Name of the consistency level for read operations. 

509 write_consistency : `str`, optional 

510 Name of the consistency level for write operations. 

511 read_timeout : `int`, optional 

512 Read timeout in seconds. 

513 write_timeout : `int`, optional 

514 Write timeout in seconds. 

515 ra_dec_columns : `list` [`str`], optional 

516 Names of ra/dec columns in DiaObject table. 

517 replication_factor : `int`, optional 

518 Replication factor used when creating new keyspace, if keyspace 

519 already exists its replication factor is not changed. 

520 drop : `bool`, optional 

521 If `True` then drop existing tables before re-creating the schema. 

522 

523 Returns 

524 ------- 

525 config : `ApdbCassandraConfig` 

526 Resulting configuration object for a created APDB instance. 

527 """ 

528 config = ApdbCassandraConfig( 

529 contact_points=hosts, 

530 keyspace=keyspace, 

531 use_insert_id=use_insert_id, 

532 use_insert_id_skips_diaobjects=use_insert_id_skips_diaobjects, 

533 time_partition_tables=time_partition_tables, 

534 ) 

535 if schema_file is not None: 

536 config.schema_file = schema_file 

537 if schema_name is not None: 

538 config.schema_name = schema_name 

539 if read_sources_months is not None: 

540 config.read_sources_months = read_sources_months 

541 if read_forced_sources_months is not None: 

542 config.read_forced_sources_months = read_forced_sources_months 

543 if port is not None: 

544 config.port = port 

545 if username is not None: 

546 config.username = username 

547 if prefix is not None: 

548 config.prefix = prefix 

549 if part_pixelization is not None: 

550 config.part_pixelization = part_pixelization 

551 if part_pix_level is not None: 

552 config.part_pix_level = part_pix_level 

553 if time_partition_start is not None: 

554 config.time_partition_start = time_partition_start 

555 if time_partition_end is not None: 

556 config.time_partition_end = time_partition_end 

557 if read_consistency is not None: 

558 config.read_consistency = read_consistency 

559 if write_consistency is not None: 

560 config.write_consistency = write_consistency 

561 if read_timeout is not None: 

562 config.read_timeout = read_timeout 

563 if write_timeout is not None: 

564 config.write_timeout = write_timeout 

565 if ra_dec_columns is not None: 

566 config.ra_dec_columns = ra_dec_columns 

567 

568 cls._makeSchema(config, drop=drop, replication_factor=replication_factor) 

569 

570 return config 
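# Sketch of initializing a brand new APDB instance with this method (host,
# keyspace and replication factor are placeholder values):
#
#     config = ApdbCassandra.init_database(
#         hosts=["cassandra-host.example.com"],
#         keyspace="apdb",
#         replication_factor=3,
#     )
#     apdb = ApdbCassandra(config)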

571 

572 @classmethod 

573 def _makeSchema( 

574 cls, config: ApdbConfig, *, drop: bool = False, replication_factor: int | None = None 

575 ) -> None: 

576 # docstring is inherited from a base class 

577 

578 if not isinstance(config, ApdbCassandraConfig): 

579 raise TypeError(f"Unexpected type of configuration object: {type(config)}") 

580 

581 cluster, session = cls._make_session(config) 

582 

583 schema = ApdbCassandraSchema( 

584 session=session, 

585 keyspace=config.keyspace, 

586 schema_file=config.schema_file, 

587 schema_name=config.schema_name, 

588 prefix=config.prefix, 

589 time_partition_tables=config.time_partition_tables, 

590 use_insert_id=config.use_insert_id, 

591 ) 

592 

593 # Ask schema to create all tables. 

594 if config.time_partition_tables: 

595 time_partition_start = astropy.time.Time(config.time_partition_start, format="isot", scale="tai") 

596 time_partition_end = astropy.time.Time(config.time_partition_end, format="isot", scale="tai") 

597 part_epoch = float(cls.partition_zero_epoch.mjd) 

598 part_days = config.time_partition_days 

599 part_range = ( 

600 cls._time_partition_cls(time_partition_start, part_epoch, part_days), 

601 cls._time_partition_cls(time_partition_end, part_epoch, part_days) + 1, 

602 ) 

603 schema.makeSchema(drop=drop, part_range=part_range, replication_factor=replication_factor) 

604 else: 

605 schema.makeSchema(drop=drop, replication_factor=replication_factor) 

606 

607 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

608 metadata = ApdbMetadataCassandra(session, meta_table_name, config.keyspace, "read_tuples", "write") 

609 

610 # Fill version numbers, overrides if they existed before. 

611 if metadata.table_exists(): 

612 metadata.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True) 

613 metadata.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True) 

614 

615 # Store frozen part of a configuration in metadata. 

616 freezer = ApdbConfigFreezer[ApdbCassandraConfig](cls._frozen_parameters) 

617 metadata.set(cls.metadataConfigKey, freezer.to_json(config), force=True) 

618 

619 cluster.shutdown() 

620 

621 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

622 # docstring is inherited from a base class 

623 

624 sp_where = self._spatial_where(region) 

625 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

626 

627 # We need to exclude extra partitioning columns from result. 

628 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast) 

629 what = ",".join(quote_id(column) for column in column_names) 

630 

631 table_name = self._schema.tableName(ApdbTables.DiaObjectLast) 

632 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"' 

633 statements: list[tuple] = [] 

634 for where, params in sp_where: 

635 full_query = f"{query} WHERE {where}" 

636 if params: 

637 statement = self._preparer.prepare(full_query) 

638 else: 

639 # If there are no params then it is likely that query has a 

640 # bunch of literals rendered already, no point trying to 

641 # prepare it because it's not reusable. 

642 statement = cassandra.query.SimpleStatement(full_query) 

643 statements.append((statement, params)) 

644 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

645 

646 with Timer("DiaObject select", self.config.timer): 

647 objects = cast( 

648 pandas.DataFrame, 

649 select_concurrent( 

650 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

651 ), 

652 ) 

653 

654 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

655 return objects 
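# Example of building the region argument with lsst.sphgeom, here a small
# circle around an RA/Dec position (purely illustrative values), given an
# ApdbCassandra instance ``apdb``:
#
#     center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(35.0, -4.5))
#     region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
#     objects = apdb.getDiaObjects(region)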

656 

657 def getDiaSources( 

658 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time 

659 ) -> pandas.DataFrame | None: 

660 # docstring is inherited from a base class 

661 months = self.config.read_sources_months 

662 if months == 0: 

663 return None 

664 mjd_end = visit_time.mjd 

665 mjd_start = mjd_end - months * 30 

666 

667 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

668 

669 def getDiaForcedSources( 

670 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time 

671 ) -> pandas.DataFrame | None: 

672 # docstring is inherited from a base class 

673 months = self.config.read_forced_sources_months 

674 if months == 0: 

675 return None 

676 mjd_end = visit_time.mjd 

677 mjd_start = mjd_end - months * 30 

678 

679 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

680 

681 def containsVisitDetector(self, visit: int, detector: int) -> bool: 

682 # docstring is inherited from a base class 

683 raise NotImplementedError() 

684 

685 def getInsertIds(self) -> list[ApdbInsertId] | None: 

686 # docstring is inherited from a base class 

687 if not self._schema.has_insert_id: 

688 return None 

689 

690 # everything goes into a single partition 

691 partition = 0 

692 

693 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

694 query = f'SELECT insert_time, insert_id FROM "{self._keyspace}"."{table_name}" WHERE partition = ?' 

695 

696 result = self._session.execute( 

697 self._preparer.prepare(query), 

698 (partition,), 

699 timeout=self.config.read_timeout, 

700 execution_profile="read_tuples", 

701 ) 

702 # order by insert_time 

703 rows = sorted(result) 

704 return [ 

705 ApdbInsertId(id=row[1], insert_time=astropy.time.Time(row[0].timestamp(), format="unix_tai")) 

706 for row in rows 

707 ] 

708 

709 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None: 

710 # docstring is inherited from a base class 

711 if not self._schema.has_insert_id: 

712 raise ValueError("APDB is not configured for history storage") 

713 

714 all_insert_ids = [id.id for id in ids] 

715 # There is a 64k limit on the number of markers in Cassandra CQL. 

716 for insert_ids in chunk_iterable(all_insert_ids, 20_000): 

717 params = ",".join("?" * len(insert_ids)) 

718 

719 # everything goes into a single partition 

720 partition = 0 

721 

722 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

723 query = ( 

724 f'DELETE FROM "{self._keyspace}"."{table_name}" ' 

725 f"WHERE partition = ? AND insert_id IN ({params})" 

726 ) 

727 

728 self._session.execute( 

729 self._preparer.prepare(query), 

730 [partition] + list(insert_ids), 

731 timeout=self.config.remove_timeout, 

732 ) 

733 

734 # Also remove those insert_ids from Dia*InsertId tables. 

735 for table in ( 

736 ExtraTables.DiaObjectInsertId, 

737 ExtraTables.DiaSourceInsertId, 

738 ExtraTables.DiaForcedSourceInsertId, 

739 ): 

740 table_name = self._schema.tableName(table) 

741 query = f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})' 

742 self._session.execute( 

743 self._preparer.prepare(query), 

744 insert_ids, 

745 timeout=self.config.remove_timeout, 

746 ) 

747 

748 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

749 # docstring is inherited from a base class 

750 return self._get_history(ExtraTables.DiaObjectInsertId, ids) 

751 

752 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

753 # docstring is inherited from a base class 

754 return self._get_history(ExtraTables.DiaSourceInsertId, ids) 

755 

756 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

757 # docstring is inherited from a base class 

758 return self._get_history(ExtraTables.DiaForcedSourceInsertId, ids) 

759 

760 def getSSObjects(self) -> pandas.DataFrame: 

761 # docstring is inherited from a base class 

762 tableName = self._schema.tableName(ApdbTables.SSObject) 

763 query = f'SELECT * from "{self._keyspace}"."{tableName}"' 

764 

765 objects = None 

766 with Timer("SSObject select", self.config.timer): 

767 result = self._session.execute(query, execution_profile="read_pandas") 

768 objects = result._current_rows 

769 

770 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

771 return objects 

772 

773 def store( 

774 self, 

775 visit_time: astropy.time.Time, 

776 objects: pandas.DataFrame, 

777 sources: pandas.DataFrame | None = None, 

778 forced_sources: pandas.DataFrame | None = None, 

779 ) -> None: 

780 # docstring is inherited from a base class 

781 

782 insert_id: ApdbInsertId | None = None 

783 if self._schema.has_insert_id: 

784 insert_id = ApdbInsertId.new_insert_id(visit_time) 

785 self._storeInsertId(insert_id, visit_time) 

786 

787 # fill region partition column for DiaObjects 

788 objects = self._add_obj_part(objects) 

789 self._storeDiaObjects(objects, visit_time, insert_id) 

790 

791 if sources is not None: 

792 # copy apdb_part column from DiaObjects to DiaSources 

793 sources = self._add_src_part(sources, objects) 

794 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, insert_id) 

795 self._storeDiaSourcesPartitions(sources, visit_time, insert_id) 

796 

797 if forced_sources is not None: 

798 forced_sources = self._add_fsrc_part(forced_sources, objects) 

799 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, insert_id) 
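# Typical call during AP processing (sketch; the DataFrames must follow the
# APDB schema and contain the ra/dec and ID columns used by the helpers above):
#
#     visit_time = astropy.time.Time("2025-01-01T00:00:00", format="isot", scale="tai")
#     apdb.store(visit_time, objects_df, sources_df, forced_sources_df)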

800 

801 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

802 # docstring is inherited from a base class 

803 self._storeObjectsPandas(objects, ApdbTables.SSObject) 

804 

805 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

806 # docstring is inherited from a base class 

807 

808 # To update a record we need to know its exact primary key (including 

809 # partition key) so we start by querying for diaSourceId to find the 

810 # primary keys. 

811 

812 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition) 

813 # split it into 1k IDs per query 

814 selects: list[tuple] = [] 

815 for ids in chunk_iterable(idMap.keys(), 1_000): 

816 ids_str = ",".join(str(item) for item in ids) 

817 selects.append( 

818 ( 

819 ( 

820 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "insert_id" ' 

821 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})' 

822 ), 

823 {}, 

824 ) 

825 ) 

826 

827 # No need for DataFrame here, read data as tuples. 

828 result = cast( 

829 list[tuple[int, int, int, uuid.UUID | None]], 

830 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency), 

831 ) 

832 

833 # Make mapping from source ID to its partition. 

834 id2partitions: dict[int, tuple[int, int]] = {} 

835 id2insert_id: dict[int, uuid.UUID] = {} 

836 for row in result: 

837 id2partitions[row[0]] = row[1:3] 

838 if row[3] is not None: 

839 id2insert_id[row[0]] = row[3] 

840 

841 # make sure we know partitions for each ID 

842 if set(id2partitions) != set(idMap): 

843 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

844 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

845 

846 # Reassign in standard tables 

847 queries = cassandra.query.BatchStatement() 

848 table_name = self._schema.tableName(ApdbTables.DiaSource) 

849 for diaSourceId, ssObjectId in idMap.items(): 

850 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

851 values: tuple 

852 if self.config.time_partition_tables: 

853 query = ( 

854 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"' 

855 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

856 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?' 

857 ) 

858 values = (ssObjectId, apdb_part, diaSourceId) 

859 else: 

860 query = ( 

861 f'UPDATE "{self._keyspace}"."{table_name}"' 

862 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

863 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?' 

864 ) 

865 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId) 

866 queries.add(self._preparer.prepare(query), values) 

867 

868 # Reassign in history tables, only if history is enabled 

869 if id2insert_id: 

870 # Filter out insert ids that have been deleted already. There is a 

871 # potential race with concurrent removal of insert IDs, but it 

872 # should be handled by WHERE in UPDATE. 

873 known_ids = set() 

874 if insert_ids := self.getInsertIds(): 

875 known_ids = set(insert_id.id for insert_id in insert_ids) 

876 id2insert_id = {key: value for key, value in id2insert_id.items() if value in known_ids} 

877 if id2insert_id: 

878 table_name = self._schema.tableName(ExtraTables.DiaSourceInsertId) 

879 for diaSourceId, ssObjectId in idMap.items(): 

880 if insert_id := id2insert_id.get(diaSourceId): 

881 query = ( 

882 f'UPDATE "{self._keyspace}"."{table_name}" ' 

883 ' SET "ssObjectId" = ?, "diaObjectId" = NULL ' 

884 'WHERE "insert_id" = ? AND "diaSourceId" = ?' 

885 ) 

886 values = (ssObjectId, insert_id, diaSourceId) 

887 queries.add(self._preparer.prepare(query), values) 

888 

889 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

890 with Timer(table_name + " update", self.config.timer): 

891 self._session.execute(queries, execution_profile="write") 

892 

893 def dailyJob(self) -> None: 

894 # docstring is inherited from a base class 

895 pass 

896 

897 def countUnassociatedObjects(self) -> int: 

898 # docstring is inherited from a base class 

899 

900 # It's too inefficient to implement it for Cassandra in the current schema. 

901 raise NotImplementedError() 

902 

903 @property 

904 def metadata(self) -> ApdbMetadata: 

905 # docstring is inherited from a base class 

906 if self._metadata is None: 

907 raise RuntimeError("Database schema was not initialized.") 

908 return self._metadata 

909 

910 @classmethod 

911 def _makeProfiles(cls, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]: 

912 """Make all execution profiles used in the code.""" 

913 if config.private_ips: 

914 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

915 else: 

916 loadBalancePolicy = RoundRobinPolicy() 

917 

918 read_tuples_profile = ExecutionProfile( 

919 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

920 request_timeout=config.read_timeout, 

921 row_factory=cassandra.query.tuple_factory, 

922 load_balancing_policy=loadBalancePolicy, 

923 ) 

924 read_pandas_profile = ExecutionProfile( 

925 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

926 request_timeout=config.read_timeout, 

927 row_factory=pandas_dataframe_factory, 

928 load_balancing_policy=loadBalancePolicy, 

929 ) 

930 read_raw_profile = ExecutionProfile( 

931 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

932 request_timeout=config.read_timeout, 

933 row_factory=raw_data_factory, 

934 load_balancing_policy=loadBalancePolicy, 

935 ) 

936 # Profile to use with select_concurrent to return pandas data frame 

937 read_pandas_multi_profile = ExecutionProfile( 

938 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

939 request_timeout=config.read_timeout, 

940 row_factory=pandas_dataframe_factory, 

941 load_balancing_policy=loadBalancePolicy, 

942 ) 

943 # Profile to use with select_concurrent to return raw data (columns and 

944 # rows) 

945 read_raw_multi_profile = ExecutionProfile( 

946 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

947 request_timeout=config.read_timeout, 

948 row_factory=raw_data_factory, 

949 load_balancing_policy=loadBalancePolicy, 

950 ) 

951 write_profile = ExecutionProfile( 

952 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency), 

953 request_timeout=config.write_timeout, 

954 load_balancing_policy=loadBalancePolicy, 

955 ) 

956 # To replace default DCAwareRoundRobinPolicy 

957 default_profile = ExecutionProfile( 

958 load_balancing_policy=loadBalancePolicy, 

959 ) 
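# The profile names in the mapping below ("read_tuples", "read_pandas",
# "read_raw", "read_pandas_multi", "read_raw_multi", "write") are referenced
# by name via the execution_profile argument of Session.execute() and
# select_concurrent() throughout this class, so renaming a key here requires
# updating those call sites as well.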

960 return { 

961 "read_tuples": read_tuples_profile, 

962 "read_pandas": read_pandas_profile, 

963 "read_raw": read_raw_profile, 

964 "read_pandas_multi": read_pandas_multi_profile, 

965 "read_raw_multi": read_raw_multi_profile, 

966 "write": write_profile, 

967 EXEC_PROFILE_DEFAULT: default_profile, 

968 } 

969 

970 def _getSources( 

971 self, 

972 region: sphgeom.Region, 

973 object_ids: Iterable[int] | None, 

974 mjd_start: float, 

975 mjd_end: float, 

976 table_name: ApdbTables, 

977 ) -> pandas.DataFrame: 

978 """Return catalog of DiaSource instances given set of DiaObject IDs. 

979 

980 Parameters 

981 ---------- 

982 region : `lsst.sphgeom.Region` 

983 Spherical region. 

984 object_ids : `collections.abc.Iterable` [`int`] or `None` 

985 Collection of DiaObject IDs. 

986 mjd_start : `float` 

987 Lower bound of time interval. 

988 mjd_end : `float` 

989 Upper bound of time interval. 

990 table_name : `ApdbTables` 

991 Name of the table. 

992 

993 Returns 

994 ------- 

995 catalog : `pandas.DataFrame` 

996 Catalog containing DiaSource records. Empty catalog is returned if 

997 ``object_ids`` is empty. 

998 """ 

999 object_id_set: Set[int] = set() 

1000 if object_ids is not None: 

1001 object_id_set = set(object_ids) 

1002 if len(object_id_set) == 0: 

1003 return self._make_empty_catalog(table_name) 

1004 

1005 sp_where = self._spatial_where(region) 

1006 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end) 

1007 

1008 # We need to exclude extra partitioning columns from result. 

1009 column_names = self._schema.apdbColumnNames(table_name) 

1010 what = ",".join(quote_id(column) for column in column_names) 

1011 

1012 # Build all queries 

1013 statements: list[tuple] = [] 

1014 for table in tables: 

1015 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"' 

1016 statements += list(self._combine_where(prefix, sp_where, temporal_where)) 

1017 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

1018 

1019 with Timer(table_name.name + " select", self.config.timer): 

1020 catalog = cast( 

1021 pandas.DataFrame, 

1022 select_concurrent( 

1023 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

1024 ), 

1025 ) 

1026 

1027 # filter by given object IDs 

1028 if len(object_id_set) > 0: 

1029 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

1030 

1031 # precise filtering on midpointMjdTai 

1032 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start]) 

1033 

1034 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

1035 return catalog 

1036 

1037 def _get_history(self, table: ExtraTables, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

1038 """Return records from a particular table given set of insert IDs.""" 

1039 if not self._schema.has_insert_id: 

1040 raise ValueError("APDB is not configured for history retrieval") 

1041 

1042 insert_ids = [id.id for id in ids] 

1043 params = ",".join("?" * len(insert_ids)) 

1044 

1045 table_name = self._schema.tableName(table) 

1046 # I know that history table schema has only regular APDB columns plus 

1047 # an insert_id column, and this is exactly what we need to return from 

1048 # this method, so selecting a star is fine here. 

1049 query = f'SELECT * FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})' 

1050 statement = self._preparer.prepare(query) 

1051 

1052 with Timer("DiaObject history", self.config.timer): 

1053 result = self._session.execute(statement, insert_ids, execution_profile="read_raw") 

1054 table_data = cast(ApdbCassandraTableData, result._current_rows) 

1055 return table_data 

1056 

1057 def _storeInsertId(self, insert_id: ApdbInsertId, visit_time: astropy.time.Time) -> None: 

1058 # Cassandra timestamp uses milliseconds since epoch 

1059 timestamp = int(insert_id.insert_time.unix_tai * 1000) 

1060 

1061 # everything goes into a single partition 

1062 partition = 0 

1063 

1064 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

1065 query = ( 

1066 f'INSERT INTO "{self._keyspace}"."{table_name}" (partition, insert_id, insert_time) ' 

1067 "VALUES (?, ?, ?)" 

1068 ) 

1069 

1070 self._session.execute( 

1071 self._preparer.prepare(query), 

1072 (partition, insert_id.id, timestamp), 

1073 timeout=self.config.write_timeout, 

1074 execution_profile="write", 

1075 ) 

1076 

1077 def _storeDiaObjects( 

1078 self, objs: pandas.DataFrame, visit_time: astropy.time.Time, insert_id: ApdbInsertId | None 

1079 ) -> None: 

1080 """Store catalog of DiaObjects from current visit. 

1081 

1082 Parameters 

1083 ---------- 

1084 objs : `pandas.DataFrame` 

1085 Catalog with DiaObject records 

1086 visit_time : `astropy.time.Time` 

1087 Time of the current visit. 

1088 """ 

1089 if len(objs) == 0: 

1090 _LOG.debug("No objects to write to database.") 

1091 return 

1092 

1093 visit_time_dt = visit_time.datetime 

1094 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

1095 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

1096 

1097 extra_columns["validityStart"] = visit_time_dt 

1098 time_part: int | None = self._time_partition(visit_time) 

1099 if not self.config.time_partition_tables: 

1100 extra_columns["apdb_time_part"] = time_part 

1101 time_part = None 

1102 

1103 # Only store DiaObjects if not storing insert_ids or if explicitly 

1104 # configured to always store them. 

1105 if insert_id is None or not self.config.use_insert_id_skips_diaobjects: 

1106 self._storeObjectsPandas( 

1107 objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part 

1108 ) 

1109 

1110 if insert_id is not None: 

1111 extra_columns = dict(insert_id=insert_id.id, validityStart=visit_time_dt) 

1112 self._storeObjectsPandas(objs, ExtraTables.DiaObjectInsertId, extra_columns=extra_columns) 

1113 

1114 def _storeDiaSources( 

1115 self, 

1116 table_name: ApdbTables, 

1117 sources: pandas.DataFrame, 

1118 visit_time: astropy.time.Time, 

1119 insert_id: ApdbInsertId | None, 

1120 ) -> None: 

1121 """Store catalog of DIASources or DIAForcedSources from current visit. 

1122 

1123 Parameters 

1124 ---------- 

1125 sources : `pandas.DataFrame` 

1126 Catalog containing DiaSource records 

1127 visit_time : `astropy.time.Time` 

1128 Time of the current visit. 

1129 """ 

1130 time_part: int | None = self._time_partition(visit_time) 

1131 extra_columns: dict[str, Any] = {} 

1132 if not self.config.time_partition_tables: 

1133 extra_columns["apdb_time_part"] = time_part 

1134 time_part = None 

1135 

1136 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part) 

1137 

1138 if insert_id is not None: 

1139 extra_columns = dict(insert_id=insert_id.id) 

1140 if table_name is ApdbTables.DiaSource: 

1141 extra_table = ExtraTables.DiaSourceInsertId 

1142 else: 

1143 extra_table = ExtraTables.DiaForcedSourceInsertId 

1144 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns) 

1145 

1146 def _storeDiaSourcesPartitions( 

1147 self, sources: pandas.DataFrame, visit_time: astropy.time.Time, insert_id: ApdbInsertId | None 

1148 ) -> None: 

1149 """Store mapping of diaSourceId to its partitioning values. 

1150 

1151 Parameters 

1152 ---------- 

1153 sources : `pandas.DataFrame` 

1154 Catalog containing DiaSource records 

1155 visit_time : `astropy.time.Time` 

1156 Time of the current visit. 

1157 """ 

1158 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

1159 extra_columns = { 

1160 "apdb_time_part": self._time_partition(visit_time), 

1161 "insert_id": insert_id.id if insert_id is not None else None, 

1162 } 

1163 

1164 self._storeObjectsPandas( 

1165 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

1166 ) 

1167 

1168 def _storeObjectsPandas( 

1169 self, 

1170 records: pandas.DataFrame, 

1171 table_name: ApdbTables | ExtraTables, 

1172 extra_columns: Mapping | None = None, 

1173 time_part: int | None = None, 

1174 ) -> None: 

1175 """Store generic objects. 

1176 

1177 Takes Pandas catalog and stores a bunch of records in a table. 

1178 

1179 Parameters 

1180 ---------- 

1181 records : `pandas.DataFrame` 

1182 Catalog containing object records 

1183 table_name : `ApdbTables` 

1184 Name of the table as defined in APDB schema. 

1185 extra_columns : `dict`, optional 

1186 Mapping (column_name, column_value) which gives fixed values for 

1187 columns in each row, overrides values in ``records`` if matching 

1188 columns exist there. 

1189 time_part : `int`, optional 

1190 If not `None` then insert into a per-partition table. 

1191 

1192 Notes 

1193 ----- 

1194 If Pandas catalog contains additional columns not defined in table 

1195 schema they are ignored. Catalog does not have to contain all columns 

1196 defined in a table, but partition and clustering keys must be present 

1197 in a catalog or ``extra_columns``. 

1198 """ 

1199 # use extra columns if specified 

1200 if extra_columns is None: 

1201 extra_columns = {} 

1202 extra_fields = list(extra_columns.keys()) 

1203 

1204 # Fields that will come from dataframe. 

1205 df_fields = [column for column in records.columns if column not in extra_fields] 

1206 

1207 column_map = self._schema.getColumnMap(table_name) 

1208 # list of columns (as in felis schema) 

1209 fields = [column_map[field].name for field in df_fields if field in column_map] 

1210 fields += extra_fields 

1211 

1212 # check that all partitioning and clustering columns are defined 

1213 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns( 

1214 table_name 

1215 ) 

1216 missing_columns = [column for column in required_columns if column not in fields] 

1217 if missing_columns: 

1218 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

1219 

1220 qfields = [quote_id(field) for field in fields] 

1221 qfields_str = ",".join(qfields) 

1222 

1223 with Timer(table_name.name + " query build", self.config.timer): 

1224 table = self._schema.tableName(table_name) 

1225 if time_part is not None: 

1226 table = f"{table}_{time_part}" 

1227 

1228 holders = ",".join(["?"] * len(qfields)) 

1229 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})' 

1230 statement = self._preparer.prepare(query) 

1231 queries = cassandra.query.BatchStatement() 

1232 for rec in records.itertuples(index=False): 

1233 values = [] 

1234 for field in df_fields: 

1235 if field not in column_map: 

1236 continue 

1237 value = getattr(rec, field) 

1238 if column_map[field].datatype is felis.types.Timestamp: 

1239 if isinstance(value, pandas.Timestamp): 

1240 value = literal(value.to_pydatetime()) 

1241 else: 

1242 # Assume it's seconds since epoch, Cassandra 

1243 # datetime is in milliseconds 

1244 value = int(value * 1000) 

1245 values.append(literal(value)) 

1246 for field in extra_fields: 

1247 value = extra_columns[field] 

1248 values.append(literal(value)) 

1249 queries.add(statement, values) 

1250 

1251 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0]) 

1252 with Timer(table_name.name + " insert", self.config.timer): 

1253 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write") 

1254 

1255 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1256 """Calculate spatial partition for each record and add it to a 

1257 DataFrame. 

1258 

1259 Notes 

1260 ----- 

1261 This overrides any existing column in a DataFrame with the same name 

1262 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1263 returned. 

1264 """ 

1265 # calculate spatial partition (pixel) index for every DiaObject 

1266 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

1267 ra_col, dec_col = self.config.ra_dec_columns 

1268 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

1269 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1270 idx = self._pixelization.pixel(uv3d) 

1271 apdb_part[i] = idx 

1272 df = df.copy() 

1273 df["apdb_part"] = apdb_part 

1274 return df 

1275 

1276 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1277 """Add apdb_part column to DiaSource catalog. 

1278 

1279 Notes 

1280 ----- 

1281 This method copies apdb_part value from a matching DiaObject record. 

1282 DiaObject catalog needs to have a apdb_part column filled by 

1283 ``_add_obj_part`` method and DiaSource records need to be 

1284 associated to DiaObjects via ``diaObjectId`` column. 

1285 

1286 This overrides any existing column in a DataFrame with the same name 

1287 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1288 returned. 

1289 """ 

1290 pixel_id_map: dict[int, int] = { 

1291 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

1292 } 

1293 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

1294 ra_col, dec_col = self.config.ra_dec_columns 

1295 for i, (diaObjId, ra, dec) in enumerate( 

1296 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col]) 

1297 ): 

1298 if diaObjId == 0: 

1299 # DiaSources associated with SolarSystemObjects do not have an 

1300 # associated DiaObject hence we skip them and set partition 

1301 # based on its own ra/dec 

1302 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1303 idx = self._pixelization.pixel(uv3d) 

1304 apdb_part[i] = idx 

1305 else: 

1306 apdb_part[i] = pixel_id_map[diaObjId] 

1307 sources = sources.copy() 

1308 sources["apdb_part"] = apdb_part 

1309 return sources 

1310 

1311 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1312 """Add apdb_part column to DiaForcedSource catalog. 

1313 

1314 Notes 

1315 ----- 

1316 This method copies apdb_part value from a matching DiaObject record. 

1317 DiaObject catalog needs to have an ``apdb_part`` column filled by the 

1318 ``_add_obj_part`` method and DiaForcedSource records need to be 

1319 associated with DiaObjects via the ``diaObjectId`` column. 

1320 

1321 This overrides any existing column in a DataFrame with the same name 

1322 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1323 returned. 

1324 """ 

1325 pixel_id_map: dict[int, int] = { 

1326 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

1327 } 

1328 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

1329 for i, diaObjId in enumerate(sources["diaObjectId"]): 

1330 apdb_part[i] = pixel_id_map[diaObjId] 

1331 sources = sources.copy() 

1332 sources["apdb_part"] = apdb_part 

1333 return sources 

1334 

1335 @classmethod 

1336 def _time_partition_cls(cls, time: float | astropy.time.Time, epoch_mjd: float, part_days: int) -> int: 

1337 """Calculate time partition number for a given time. 

1338 

1339 Parameters 

1340 ---------- 

1341 time : `float` or `astropy.time.Time` 

1342 Time for which to calculate partition number. Can be float to mean 

1343 MJD or `astropy.time.Time` 

1344 epoch_mjd : `float` 

1345 Epoch time for partition 0. 

1346 part_days : `int` 

1347 Number of days per partition. 

1348 

1349 Returns 

1350 ------- 

1351 partition : `int` 

1352 Partition number for a given time. 

1353 """ 

1354 if isinstance(time, astropy.time.Time): 

1355 mjd = float(time.mjd) 

1356 else: 

1357 mjd = time 

1358 days_since_epoch = mjd - epoch_mjd 

1359 partition = int(days_since_epoch) // part_days 

1360 return partition 
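# Worked example: partition_zero_epoch is 1970-01-01 TAI, i.e. MJD 40587.0.
# For an observation at MJD 60200.0 with 30-day partitions:
#     int(60200.0 - 40587.0) // 30 = 19613 // 30 = 653
# so the record lands in time partition 653.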

1361 

1362 def _time_partition(self, time: float | astropy.time.Time) -> int: 

1363 """Calculate time partition number for a given time. 

1364 

1365 Parameters 

1366 ---------- 

1367 time : `float` or `astropy.time.Time` 

1368 Time for which to calculate partition number. Can be float to mean 

1369 MJD or `astropy.time.Time` 

1370 

1371 Returns 

1372 ------- 

1373 partition : `int` 

1374 Partition number for a given time. 

1375 """ 

1376 if isinstance(time, astropy.time.Time): 

1377 mjd = float(time.mjd) 

1378 else: 

1379 mjd = time 

1380 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

1381 partition = int(days_since_epoch) // self.config.time_partition_days 

1382 return partition 

1383 

1384 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

1385 """Make an empty catalog for a table with a given name. 

1386 

1387 Parameters 

1388 ---------- 

1389 table_name : `ApdbTables` 

1390 Name of the table. 

1391 

1392 Returns 

1393 ------- 

1394 catalog : `pandas.DataFrame` 

1395 An empty catalog. 

1396 """ 

1397 table = self._schema.tableSchemas[table_name] 

1398 

1399 data = { 

1400 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype)) 

1401 for columnDef in table.columns 

1402 } 

1403 return pandas.DataFrame(data) 

1404 

1405 def _combine_where( 

1406 self, 

1407 prefix: str, 

1408 where1: list[tuple[str, tuple]], 

1409 where2: list[tuple[str, tuple]], 

1410 suffix: str | None = None, 

1411 ) -> Iterator[tuple[cassandra.query.Statement, tuple]]: 

1412 """Make cartesian product of two parts of WHERE clause into a series 

1413 of statements to execute. 

1414 

1415 Parameters 

1416 ---------- 

1417 prefix : `str` 

1418 Initial statement prefix that comes before WHERE clause, e.g. 

1419 "SELECT * from Table" 

1420 """ 

1421 # If lists are empty use special sentinels. 

1422 if not where1: 

1423 where1 = [("", ())] 

1424 if not where2: 

1425 where2 = [("", ())] 

1426 

1427 for expr1, params1 in where1: 

1428 for expr2, params2 in where2: 

1429 full_query = prefix 

1430 wheres = [] 

1431 if expr1: 

1432 wheres.append(expr1) 

1433 if expr2: 

1434 wheres.append(expr2) 

1435 if wheres: 

1436 full_query += " WHERE " + " AND ".join(wheres) 

1437 if suffix: 

1438 full_query += " " + suffix 

1439 params = params1 + params2 

1440 if params: 

1441 statement = self._preparer.prepare(full_query) 

1442 else: 

1443 # If there are no params then it is likely that query 

1444 # has a bunch of literals rendered already, no point 

1445 # trying to prepare it. 

1446 statement = cassandra.query.SimpleStatement(full_query) 

1447 yield (statement, params) 

1448 

1449 def _spatial_where( 

1450 self, region: sphgeom.Region | None, use_ranges: bool = False 

1451 ) -> list[tuple[str, tuple]]: 

1452 """Generate expressions for spatial part of WHERE clause. 

1453 

1454 Parameters 

1455 ---------- 

1456 region : `sphgeom.Region` 

1457 Spatial region for query results. 

1458 use_ranges : `bool` 

1459 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <= 

1460 p2") instead of exact list of pixels. Should be set to True for 

1461 large regions covering very many pixels. 

1462 

1463 Returns 

1464 ------- 

1465 expressions : `list` [ `tuple` ] 

1466 Empty list is returned if ``region`` is `None`, otherwise a list 

1467 of one or more (expression, parameters) tuples 

1468 """ 

1469 if region is None: 

1470 return [] 

1471 if use_ranges: 

1472 pixel_ranges = self._pixelization.envelope(region) 

1473 expressions: list[tuple[str, tuple]] = [] 

1474 for lower, upper in pixel_ranges: 

1475 upper -= 1 

1476 if lower == upper: 

1477 expressions.append(('"apdb_part" = ?', (lower,))) 

1478 else: 

1479 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper))) 

1480 return expressions 

1481 else: 

1482 pixels = self._pixelization.pixels(region) 

1483 if self.config.query_per_spatial_part: 

1484 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels] 

1485 else: 

1486 pixels_str = ",".join([str(pix) for pix in pixels]) 

1487 return [(f'"apdb_part" IN ({pixels_str})', ())] 
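# Examples of the expressions produced above (pixel values are illustrative):
#     use_ranges=True:             [('"apdb_part" >= ? AND "apdb_part" <= ?', (1000, 1015))]
#     query_per_spatial_part=True: [('"apdb_part" = ?', (1000,)), ('"apdb_part" = ?', (1001,))]
#     otherwise:                   [('"apdb_part" IN (1000,1001,1002)', ())]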

1488 

1489 def _temporal_where( 

1490 self, 

1491 table: ApdbTables, 

1492 start_time: float | astropy.time.Time, 

1493 end_time: float | astropy.time.Time, 

1494 query_per_time_part: bool | None = None, 

1495 ) -> tuple[list[str], list[tuple[str, tuple]]]: 

1496 """Generate table names and expressions for temporal part of WHERE 

1497 clauses. 

1498 

1499 Parameters 

1500 ---------- 

1501 table : `ApdbTables` 

1502 Table to select from. 

1503 start_time : `astropy.time.Time` or `float` 

1504 Starting Datetime or MJD value of the time range. 

1505 end_time : `astropy.time.Time` or `float` 

1506 Ending Datetime or MJD value of the time range. 

1507 query_per_time_part : `bool`, optional 

1508 If None then use ``query_per_time_part`` from configuration. 

1509 

1510 Returns 

1511 ------- 

1512 tables : `list` [ `str` ] 

1513 List of the table names to query. 

1514 expressions : `list` [ `tuple` ] 

1515 A list of zero or more (expression, parameters) tuples. 

1516 """ 

1517 tables: list[str] 

1518 temporal_where: list[tuple[str, tuple]] = [] 

1519 table_name = self._schema.tableName(table) 

1520 time_part_start = self._time_partition(start_time) 

1521 time_part_end = self._time_partition(end_time) 

1522 time_parts = list(range(time_part_start, time_part_end + 1)) 

1523 if self.config.time_partition_tables: 

1524 tables = [f"{table_name}_{part}" for part in time_parts] 

1525 else: 

1526 tables = [table_name] 

1527 if query_per_time_part is None: 

1528 query_per_time_part = self.config.query_per_time_part 

1529 if query_per_time_part: 

1530 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts] 

1531 else: 

1532 time_part_list = ",".join([str(part) for part in time_parts]) 

1533 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())] 

1534 

1535 return tables, temporal_where
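# Example of the output for a query spanning time partitions 653 and 654
# (the "DiaSource" table name is illustrative, the real name comes from the
# schema and prefix):
#     time_partition_tables=True:  tables=["DiaSource_653", "DiaSource_654"], expressions=[]
#     time_partition_tables=False: tables=["DiaSource"],
#         expressions=[('"apdb_time_part" IN (653,654)', ())] in single-query mode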