Coverage for python/lsst/dax/apdb/apdbCassandra.py: 19%


1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26import dataclasses 

27import json 

28import logging 

29import uuid 

30from collections.abc import Iterable, Iterator, Mapping, Set 

31from typing import TYPE_CHECKING, Any, cast 

32 

33import numpy as np 

34import pandas 

35 

36# If cassandra-driver is not there the module can still be imported 

37# but ApdbCassandra cannot be instantiated. 

38try: 

39 import cassandra 

40 import cassandra.query 

41 from cassandra.auth import AuthProvider, PlainTextAuthProvider 

42 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, Session 

43 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

44 

45 CASSANDRA_IMPORTED = True 

46except ImportError: 

47 CASSANDRA_IMPORTED = False 

48 

49import felis.types 

50import lsst.daf.base as dafBase 

51from felis.simple import Table 

52from lsst import sphgeom 

53from lsst.pex.config import ChoiceField, Field, ListField 

54from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError 

55from lsst.utils.iteration import chunk_iterable 

56 

57from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData 

58from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables 

59from .apdbConfigFreezer import ApdbConfigFreezer 

60from .apdbMetadataCassandra import ApdbMetadataCassandra 

61from .apdbSchema import ApdbTables 

62from .cassandra_utils import ( 

63 ApdbCassandraTableData, 

64 PreparedStatementCache, 

65 literal, 

66 pandas_dataframe_factory, 

67 quote_id, 

68 raw_data_factory, 

69 select_concurrent, 

70) 

71from .pixelization import Pixelization 

72from .timer import Timer 

73from .versionTuple import IncompatibleVersionError, VersionTuple 

74 

75if TYPE_CHECKING:

76 from .apdbMetadata import ApdbMetadata 

77 

78_LOG = logging.getLogger(__name__) 

79 

80VERSION = VersionTuple(0, 1, 0) 

81"""Version for the code defined in this module. This needs to be updated 

82(following compatibility rules) when schema produced by this code changes. 

83""" 

84 

85# Copied from daf_butler. 

86DB_AUTH_ENVVAR = "LSST_DB_AUTH" 

87"""Default name of the environmental variable that will be used to locate DB 

88credentials configuration file. """ 

89 

90DB_AUTH_PATH = "~/.lsst/db-auth.yaml" 

91"""Default path at which it is expected that DB credentials are found.""" 

92 

93 

94class CassandraMissingError(Exception): 

95 def __init__(self) -> None: 

96 super().__init__("cassandra-driver module cannot be imported") 

97 

98 

99class ApdbCassandraConfig(ApdbConfig): 

100 """Configuration class for Cassandra-based APDB implementation.""" 

101 

102 contact_points = ListField[str]( 

103 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"] 

104 ) 

105 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[]) 

106 port = Field[int](doc="Port number to connect to.", default=9042) 

107 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb") 

108 username = Field[str]( 

109 doc=f"Cassandra user name, if empty then {DB_AUTH_PATH} has to provide it with password.", 

110 default="", 

111 ) 

112 read_consistency = Field[str]( 

113 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM" 

114 ) 

115 write_consistency = Field[str]( 

116 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM" 

117 ) 

118 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0) 

119 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=10.0) 

120 remove_timeout = Field[float](doc="Timeout in seconds for remove operations.", default=600.0) 

121 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500) 

122 protocol_version = Field[int]( 

123 doc="Cassandra protocol version to use, default is V4", 

124 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0, 

125 ) 

126 dia_object_columns = ListField[str]( 

127 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[] 

128 ) 

129 prefix = Field[str](doc="Prefix to add to table names", default="") 

130 part_pixelization = ChoiceField[str]( 

131 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

132 doc="Pixelization used for partitioning index.", 

133 default="mq3c", 

134 ) 

135 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10) 

136 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64) 

137 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table") 

138 timer = Field[bool](doc="If True then print/log timing information", default=False) 

139 time_partition_tables = Field[bool]( 

140 doc="Use per-partition tables for sources instead of partitioning by time", default=False 

141 ) 

142 time_partition_days = Field[int]( 

143 doc=( 

144 "Time partitioning granularity in days, this value must not be changed after database is " 

145 "initialized" 

146 ), 

147 default=30, 

148 ) 

149 time_partition_start = Field[str]( 

150 doc=( 

151 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

152 "This is used only when time_partition_tables is True." 

153 ), 

154 default="2018-12-01T00:00:00", 

155 ) 

156 time_partition_end = Field[str]( 

157 doc=( 

158 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

159 "This is used only when time_partition_tables is True." 

160 ), 

161 default="2030-01-01T00:00:00", 

162 ) 

163 query_per_time_part = Field[bool]( 

164 default=False, 

165 doc=( 

166 "If True then build separate query for each time partition, otherwise build one single query. " 

167 "This is only used when time_partition_tables is False in schema config." 

168 ), 

169 ) 

170 query_per_spatial_part = Field[bool]( 

171 default=False, 

172 doc="If True then build one query per spatial partition, otherwise build single query.", 

173 ) 

174 use_insert_id_skips_diaobjects = Field[bool]( 

175 default=False, 

176 doc=( 

177 "If True then do not store DiaObjects when use_insert_id is True " 

178 "(DiaObjectsInsertId has the same data)." 

179 ), 

180 ) 
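
# Typical configuration usage (illustrative sketch only; the contact points,
# keyspace, and partitioning choices below are placeholders rather than
# recommended values):
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["10.0.0.1", "10.0.0.2"]
#     config.keyspace = "apdb"
#     config.time_partition_tables = True
#     ApdbCassandra.makeSchema(config)   # create the tables once
#     apdb = ApdbCassandra(config)       # then connect for regular operations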

181 

182 

183@dataclasses.dataclass 

184class _FrozenApdbCassandraConfig: 

185 """Part of the configuration that is saved in metadata table and read back. 

186 

187 The attributes are a subset of attributes in `ApdbCassandraConfig` class. 

188 

189 Parameters 

190 ---------- 

191 config : `ApdbCassandraConfig`

192 Configuration used to copy initial values of attributes. 

193 """ 

194 

195 use_insert_id: bool 

196 part_pixelization: str 

197 part_pix_level: int 

198 ra_dec_columns: list[str] 

199 time_partition_tables: bool 

200 time_partition_days: int 

201 use_insert_id_skips_diaobjects: bool 

202 

203 def __init__(self, config: ApdbCassandraConfig): 

204 self.use_insert_id = config.use_insert_id 

205 self.part_pixelization = config.part_pixelization 

206 self.part_pix_level = config.part_pix_level 

207 self.ra_dec_columns = list(config.ra_dec_columns) 

208 self.time_partition_tables = config.time_partition_tables 

209 self.time_partition_days = config.time_partition_days 

210 self.use_insert_id_skips_diaobjects = config.use_insert_id_skips_diaobjects 

211 

212 def to_json(self) -> str: 

213 """Convert this instance to JSON representation.""" 

214 return json.dumps(dataclasses.asdict(self)) 

215 

216 def update(self, json_str: str) -> None: 

217 """Update attribute values from a JSON string. 

218 

219 Parameters 

220 ---------- 

221 json_str : str 

222 String containing JSON representation of configuration. 

223 """ 

224 data = json.loads(json_str) 

225 if not isinstance(data, dict): 

226 raise TypeError(f"JSON string must be convertible to object: {json_str!r}") 

227 allowed_names = {field.name for field in dataclasses.fields(self)} 

228 for key, value in data.items(): 

229 if key not in allowed_names: 

230 raise ValueError(f"JSON object contains unknown key: {key}") 

231 setattr(self, key, value) 
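
# Round-trip sketch for the frozen-config helper (values are illustrative):
#
#     frozen = _FrozenApdbCassandraConfig(config)
#     json_str = frozen.to_json()
#     # e.g. '{"use_insert_id": false, "part_pixelization": "mq3c", ...}'
#     frozen.update('{"part_pix_level": 12}')   # unknown keys raise ValueError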

232 

233 

234if CASSANDRA_IMPORTED:

235 

236 class _AddressTranslator(AddressTranslator): 

237 """Translate internal IP address to external. 

238 

239 Only used for docker-based setups; not a viable long-term solution.

240 """ 

241 

242 def __init__(self, public_ips: list[str], private_ips: list[str]): 

243 self._map = dict((k, v) for k, v in zip(private_ips, public_ips)) 

244 

245 def translate(self, private_ip: str) -> str: 

246 return self._map.get(private_ip, private_ip) 

247 

248 

249class ApdbCassandra(Apdb): 

250 """Implementation of APDB database on to of Apache Cassandra. 

251 

252 The implementation is configured via standard ``pex_config`` mechanism 

253 using the `ApdbCassandraConfig` configuration class. For examples of

254 different configurations check the config/ folder.

255 

256 Parameters 

257 ---------- 

258 config : `ApdbCassandraConfig` 

259 Configuration object. 

260 """ 

261 

262 metadataSchemaVersionKey = "version:schema" 

263 """Name of the metadata key to store schema version number.""" 

264 

265 metadataCodeVersionKey = "version:ApdbCassandra" 

266 """Name of the metadata key to store code version number.""" 

267 

268 metadataConfigKey = "config:apdb-cassandra.json" 

269 """Name of the metadata key to store code version number.""" 

270 

271 _frozen_parameters = ( 

272 "use_insert_id", 

273 "part_pixelization", 

274 "part_pix_level", 

275 "ra_dec_columns", 

276 "time_partition_tables", 

277 "time_partition_days", 

278 "use_insert_id_skips_diaobjects", 

279 ) 

280 """Names of the config parameters to be frozen in metadata table.""" 

281 

282 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI) 

283 """Start time for partition 0, this should never be changed.""" 

284 

285 def __init__(self, config: ApdbCassandraConfig): 

286 if not CASSANDRA_IMPORTED: 

287 raise CassandraMissingError() 

288 

289 self._keyspace = config.keyspace 

290 

291 self._cluster, self._session = self._make_session(config) 

292 

293 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

294 self._metadata = ApdbMetadataCassandra( 

295 self._session, meta_table_name, config.keyspace, "read_tuples", "write" 

296 ) 

297 

298 # Read frozen config from metadata. 

299 config_json = self._metadata.get(self.metadataConfigKey) 

300 if config_json is not None: 

301 # Update config from metadata. 

302 freezer = ApdbConfigFreezer[ApdbCassandraConfig](self._frozen_parameters) 

303 self.config = freezer.update(config, config_json) 

304 else: 

305 self.config = config 

306 self.config.validate() 

307 

308 self._pixelization = Pixelization( 

309 self.config.part_pixelization, 

310 self.config.part_pix_level, 

311 config.part_pix_max_ranges, 

312 ) 

313 

314 self._schema = ApdbCassandraSchema( 

315 session=self._session, 

316 keyspace=self._keyspace, 

317 schema_file=self.config.schema_file, 

318 schema_name=self.config.schema_name, 

319 prefix=self.config.prefix, 

320 time_partition_tables=self.config.time_partition_tables, 

321 use_insert_id=self.config.use_insert_id, 

322 ) 

323 self._partition_zero_epoch_mjd: float = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD) 

324 

325 if self._metadata.table_exists(): 

326 self._versionCheck(self._metadata) 

327 

328 # Cache for prepared statements 

329 self._preparer = PreparedStatementCache(self._session) 

330 

331 _LOG.debug("ApdbCassandra Configuration:") 

332 for key, value in self.config.items(): 

333 _LOG.debug(" %s: %s", key, value) 

334 

335 def __del__(self) -> None: 

336 if hasattr(self, "_cluster"): 

337 self._cluster.shutdown() 

338 

339 @classmethod 

340 def _make_session(cls, config: ApdbCassandraConfig) -> tuple[Cluster, Session]: 

341 """Make Cassandra session.""" 

342 addressTranslator: AddressTranslator | None = None 

343 if config.private_ips: 

344 addressTranslator = _AddressTranslator(list(config.contact_points), list(config.private_ips)) 

345 

346 cluster = Cluster( 

347 execution_profiles=cls._makeProfiles(config), 

348 contact_points=config.contact_points, 

349 port=config.port, 

350 address_translator=addressTranslator, 

351 protocol_version=config.protocol_version, 

352 auth_provider=cls._make_auth_provider(config), 

353 ) 

354 session = cluster.connect() 

355 # Disable result paging 

356 session.default_fetch_size = None 

357 

358 return cluster, session 

359 

360 @classmethod 

361 def _make_auth_provider(cls, config: ApdbCassandraConfig) -> AuthProvider | None: 

362 """Make Cassandra authentication provider instance.""" 

363 try: 

364 dbauth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR) 

365 except DbAuthNotFoundError: 

366 # Credentials file doesn't exist, use anonymous login. 

367 return None 

368 

369 empty_username = False

370 # Try every contact point in turn. 

371 for hostname in config.contact_points: 

372 try: 

373 username, password = dbauth.getAuth( 

374 "cassandra", config.username, hostname, config.port, config.keyspace 

375 ) 

376 if not username: 

377 # Password without user name, try next hostname, but give 

378 # warning later if no better match is found. 

379 empty_username = True 

380 else: 

381 return PlainTextAuthProvider(username=username, password=password) 

382 except DbAuthNotFoundError: 

383 pass 

384 

385 if empty_username: 

386 _LOG.warning( 

387 f"Credentials file ({DB_AUTH_PATH} or ${DB_AUTH_ENVVAR}) provided password but not " 

388 f"user name, anonymous Cassandra logon will be attempted." 

389 ) 

390 

391 return None 

392 

393 def _versionCheck(self, metadata: ApdbMetadataCassandra) -> None: 

394 """Check schema version compatibility.""" 

395 

396 def _get_version(key: str, default: VersionTuple) -> VersionTuple: 

397 """Retrieve version number from given metadata key.""" 

398 if metadata.table_exists(): 

399 version_str = metadata.get(key) 

400 if version_str is None: 

401 # Should not happen with existing metadata table. 

402 raise RuntimeError(f"Version key {key!r} does not exist in metadata table.") 

403 return VersionTuple.fromString(version_str) 

404 return default 

405 

406 # For old databases where metadata table does not exist we assume that 

407 # version of both code and schema is 0.1.0. 

408 initial_version = VersionTuple(0, 1, 0) 

409 db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version) 

410 db_code_version = _get_version(self.metadataCodeVersionKey, initial_version) 

411 

412 # For now there is no way to make read-only APDB instances, assume that 

413 # any access can do updates. 

414 if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True): 

415 raise IncompatibleVersionError( 

416 f"Configured schema version {self._schema.schemaVersion()} " 

417 f"is not compatible with database version {db_schema_version}" 

418 ) 

419 if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True): 

420 raise IncompatibleVersionError( 

421 f"Current code version {self.apdbImplementationVersion()} " 

422 f"is not compatible with database version {db_code_version}" 

423 ) 

424 

425 @classmethod 

426 def apdbImplementationVersion(cls) -> VersionTuple: 

427 # Docstring inherited from base class. 

428 return VERSION 

429 

430 def apdbSchemaVersion(self) -> VersionTuple: 

431 # Docstring inherited from base class. 

432 return self._schema.schemaVersion() 

433 

434 def tableDef(self, table: ApdbTables) -> Table | None: 

435 # docstring is inherited from a base class 

436 return self._schema.tableSchemas.get(table) 

437 

438 @classmethod 

439 def makeSchema(cls, config: ApdbConfig, *, drop: bool = False) -> None: 

440 # docstring is inherited from a base class 

441 

442 if not isinstance(config, ApdbCassandraConfig): 

443 raise TypeError(f"Unexpected type of configuration object: {type(config)}") 

444 

445 cluster, session = cls._make_session(config) 

446 

447 schema = ApdbCassandraSchema( 

448 session=session, 

449 keyspace=config.keyspace, 

450 schema_file=config.schema_file, 

451 schema_name=config.schema_name, 

452 prefix=config.prefix, 

453 time_partition_tables=config.time_partition_tables, 

454 use_insert_id=config.use_insert_id, 

455 ) 

456 

457 # Ask schema to create all tables. 

458 if config.time_partition_tables: 

459 time_partition_start = dafBase.DateTime(config.time_partition_start, dafBase.DateTime.TAI) 

460 time_partition_end = dafBase.DateTime(config.time_partition_end, dafBase.DateTime.TAI) 

461 part_epoch: float = cls.partition_zero_epoch.get(system=dafBase.DateTime.MJD) 

462 part_days = config.time_partition_days 

463 part_range = ( 

464 cls._time_partition_cls(time_partition_start, part_epoch, part_days), 

465 cls._time_partition_cls(time_partition_end, part_epoch, part_days) + 1, 

466 ) 

467 schema.makeSchema(drop=drop, part_range=part_range) 

468 else: 

469 schema.makeSchema(drop=drop) 

470 

471 meta_table_name = ApdbTables.metadata.table_name(config.prefix) 

472 metadata = ApdbMetadataCassandra(session, meta_table_name, config.keyspace, "read_tuples", "write") 

473 

474 # Fill version numbers, overriding them if they existed before.

475 if metadata.table_exists(): 

476 metadata.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True) 

477 metadata.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True) 

478 

479 # Store frozen part of a configuration in metadata. 

480 freezer = ApdbConfigFreezer[ApdbCassandraConfig](cls._frozen_parameters) 

481 metadata.set(cls.metadataConfigKey, freezer.to_json(config), force=True) 

482 

483 cluster.shutdown() 

484 

485 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

486 # docstring is inherited from a base class 

487 

488 sp_where = self._spatial_where(region) 

489 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

490 

491 # We need to exclude extra partitioning columns from result. 

492 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast) 

493 what = ",".join(quote_id(column) for column in column_names) 

494 

495 table_name = self._schema.tableName(ApdbTables.DiaObjectLast) 

496 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"' 

497 statements: list[tuple] = [] 

498 for where, params in sp_where: 

499 full_query = f"{query} WHERE {where}" 

500 if params: 

501 statement = self._preparer.prepare(full_query) 

502 else: 

503 # If there are no params then it is likely that query has a 

504 # bunch of literals rendered already, no point trying to 

505 # prepare it because it's not reusable. 

506 statement = cassandra.query.SimpleStatement(full_query) 

507 statements.append((statement, params)) 

508 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

509 

510 with Timer("DiaObject select", self.config.timer): 

511 objects = cast( 

512 pandas.DataFrame, 

513 select_concurrent( 

514 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

515 ), 

516 ) 

517 

518 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

519 return objects 

520 

521 def getDiaSources( 

522 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime 

523 ) -> pandas.DataFrame | None: 

524 # docstring is inherited from a base class 

525 months = self.config.read_sources_months 

526 if months == 0: 

527 return None 

528 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

529 mjd_start = mjd_end - months * 30 

530 

531 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

532 

533 def getDiaForcedSources( 

534 self, region: sphgeom.Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime 

535 ) -> pandas.DataFrame | None: 

536 # docstring is inherited from a base class 

537 months = self.config.read_forced_sources_months 

538 if months == 0: 

539 return None 

540 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

541 mjd_start = mjd_end - months * 30 

542 

543 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

544 

545 def containsVisitDetector(self, visit: int, detector: int) -> bool: 

546 # docstring is inherited from a base class 

547 raise NotImplementedError() 

548 

549 def getInsertIds(self) -> list[ApdbInsertId] | None: 

550 # docstring is inherited from a base class 

551 if not self._schema.has_insert_id: 

552 return None 

553 

554 # everything goes into a single partition 

555 partition = 0 

556 

557 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

558 query = f'SELECT insert_time, insert_id FROM "{self._keyspace}"."{table_name}" WHERE partition = ?' 

559 

560 result = self._session.execute( 

561 self._preparer.prepare(query), 

562 (partition,), 

563 timeout=self.config.read_timeout, 

564 execution_profile="read_tuples", 

565 ) 

566 # order by insert_time 

567 rows = sorted(result) 

568 return [ 

569 ApdbInsertId(id=row[1], insert_time=dafBase.DateTime(int(row[0].timestamp() * 1e9))) 

570 for row in rows 

571 ] 
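
    # History-retrieval sketch built on getInsertIds() (only meaningful when
    # use_insert_id is enabled in the configuration; the slice size is
    # arbitrary):
    #
    #     insert_ids = apdb.getInsertIds() or []
    #     recent = insert_ids[-10:]                     # last ten stored visits
    #     sources = apdb.getDiaSourcesHistory(recent)   # ApdbTableData
    #     apdb.deleteInsertIds(recent)                  # drop them from history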

572 

573 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None: 

574 # docstring is inherited from a base class 

575 if not self._schema.has_insert_id: 

576 raise ValueError("APDB is not configured for history storage") 

577 

578 all_insert_ids = [id.id for id in ids] 

579 # There is a 64k limit on the number of markers in Cassandra CQL.

580 for insert_ids in chunk_iterable(all_insert_ids, 20_000): 

581 params = ",".join("?" * len(insert_ids)) 

582 

583 # everything goes into a single partition 

584 partition = 0 

585 

586 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

587 query = ( 

588 f'DELETE FROM "{self._keyspace}"."{table_name}" ' 

589 f"WHERE partition = ? AND insert_id IN ({params})" 

590 ) 

591 

592 self._session.execute( 

593 self._preparer.prepare(query), 

594 [partition] + list(insert_ids), 

595 timeout=self.config.remove_timeout, 

596 ) 

597 

598 # Also remove those insert_ids from Dia*InsertId tables.

599 for table in ( 

600 ExtraTables.DiaObjectInsertId, 

601 ExtraTables.DiaSourceInsertId, 

602 ExtraTables.DiaForcedSourceInsertId, 

603 ): 

604 table_name = self._schema.tableName(table) 

605 query = f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})' 

606 self._session.execute( 

607 self._preparer.prepare(query), 

608 insert_ids, 

609 timeout=self.config.remove_timeout, 

610 ) 

611 

612 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

613 # docstring is inherited from a base class 

614 return self._get_history(ExtraTables.DiaObjectInsertId, ids) 

615 

616 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

617 # docstring is inherited from a base class 

618 return self._get_history(ExtraTables.DiaSourceInsertId, ids) 

619 

620 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

621 # docstring is inherited from a base class 

622 return self._get_history(ExtraTables.DiaForcedSourceInsertId, ids) 

623 

624 def getSSObjects(self) -> pandas.DataFrame: 

625 # docstring is inherited from a base class 

626 tableName = self._schema.tableName(ApdbTables.SSObject) 

627 query = f'SELECT * from "{self._keyspace}"."{tableName}"' 

628 

629 objects = None 

630 with Timer("SSObject select", self.config.timer): 

631 result = self._session.execute(query, execution_profile="read_pandas") 

632 objects = result._current_rows 

633 

634 _LOG.debug("found %s SSObjects", objects.shape[0])

635 return objects 

636 

637 def store( 

638 self, 

639 visit_time: dafBase.DateTime, 

640 objects: pandas.DataFrame, 

641 sources: pandas.DataFrame | None = None, 

642 forced_sources: pandas.DataFrame | None = None, 

643 ) -> None: 

644 # docstring is inherited from a base class 

645 

646 insert_id: ApdbInsertId | None = None 

647 if self._schema.has_insert_id: 

648 insert_id = ApdbInsertId.new_insert_id(visit_time) 

649 self._storeInsertId(insert_id, visit_time) 

650 

651 # fill region partition column for DiaObjects 

652 objects = self._add_obj_part(objects) 

653 self._storeDiaObjects(objects, visit_time, insert_id) 

654 

655 if sources is not None: 

656 # copy apdb_part column from DiaObjects to DiaSources 

657 sources = self._add_src_part(sources, objects) 

658 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, insert_id) 

659 self._storeDiaSourcesPartitions(sources, visit_time, insert_id) 

660 

661 if forced_sources is not None: 

662 forced_sources = self._add_fsrc_part(forced_sources, objects) 

663 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, insert_id) 
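
    # Calling sketch for store() (abbreviated; real catalogs carry the full
    # set of columns defined by the APDB schema, and the identifiers below
    # are made up):
    #
    #     visit_time = dafBase.DateTime("2024-03-22T06:00:00", dafBase.DateTime.TAI)
    #     objects = pandas.DataFrame(
    #         {"diaObjectId": [1], "ra": [10.0], "dec": [-5.0]}
    #     )
    #     apdb.store(visit_time, objects)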

664 

665 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

666 # docstring is inherited from a base class 

667 self._storeObjectsPandas(objects, ApdbTables.SSObject) 

668 

669 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

670 # docstring is inherited from a base class 

671 

672 # To update a record we need to know its exact primary key (including 

673 # partition key) so we start by querying for diaSourceId to find the 

674 # primary keys. 

675 

676 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition) 

677 # split it into 1k IDs per query 

678 selects: list[tuple] = [] 

679 for ids in chunk_iterable(idMap.keys(), 1_000): 

680 ids_str = ",".join(str(item) for item in ids) 

681 selects.append( 

682 ( 

683 ( 

684 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "insert_id" ' 

685 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})' 

686 ), 

687 {}, 

688 ) 

689 ) 

690 

691 # No need for DataFrame here, read data as tuples. 

692 result = cast( 

693 list[tuple[int, int, int, uuid.UUID | None]], 

694 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency), 

695 ) 

696 

697 # Make mapping from source ID to its partition. 

698 id2partitions: dict[int, tuple[int, int]] = {} 

699 id2insert_id: dict[int, uuid.UUID] = {} 

700 for row in result: 

701 id2partitions[row[0]] = row[1:3] 

702 if row[3] is not None: 

703 id2insert_id[row[0]] = row[3] 

704 

705 # make sure we know partitions for each ID 

706 if set(id2partitions) != set(idMap): 

707 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

708 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

709 

710 # Reassign in standard tables 

711 queries = cassandra.query.BatchStatement() 

712 table_name = self._schema.tableName(ApdbTables.DiaSource) 

713 for diaSourceId, ssObjectId in idMap.items(): 

714 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

715 values: tuple 

716 if self.config.time_partition_tables: 

717 query = ( 

718 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"' 

719 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

720 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?' 

721 ) 

722 values = (ssObjectId, apdb_part, diaSourceId) 

723 else: 

724 query = ( 

725 f'UPDATE "{self._keyspace}"."{table_name}"' 

726 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

727 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?' 

728 ) 

729 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId) 

730 queries.add(self._preparer.prepare(query), values) 

731 

732 # Reassign in history tables, only if history is enabled 

733 if id2insert_id: 

734 # Filter out insert ids that have been deleted already. There is a 

735 # potential race with concurrent removal of insert IDs, but it 

736 # should be handled by WHERE in UPDATE. 

737 known_ids = set() 

738 if insert_ids := self.getInsertIds(): 

739 known_ids = set(insert_id.id for insert_id in insert_ids) 

740 id2insert_id = {key: value for key, value in id2insert_id.items() if value in known_ids} 

741 if id2insert_id: 

742 table_name = self._schema.tableName(ExtraTables.DiaSourceInsertId) 

743 for diaSourceId, ssObjectId in idMap.items(): 

744 if insert_id := id2insert_id.get(diaSourceId): 

745 query = ( 

746 f'UPDATE "{self._keyspace}"."{table_name}" ' 

747 ' SET "ssObjectId" = ?, "diaObjectId" = NULL ' 

748 'WHERE "insert_id" = ? AND "diaSourceId" = ?' 

749 ) 

750 values = (ssObjectId, insert_id, diaSourceId) 

751 queries.add(self._preparer.prepare(query), values) 

752 

753 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

754 with Timer(table_name + " update", self.config.timer): 

755 self._session.execute(queries, execution_profile="write") 
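
    # Calling sketch for reassignDiaSources() (identifiers are made up):
    #
    #     apdb.reassignDiaSources({111222333: 5001, 111222334: 5001})
    #
    # Keys are diaSourceId values that must already exist in the database and
    # values are the ssObjectId to associate them with; unknown keys raise
    # ValueError.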

756 

757 def dailyJob(self) -> None: 

758 # docstring is inherited from a base class 

759 pass 

760 

761 def countUnassociatedObjects(self) -> int: 

762 # docstring is inherited from a base class 

763 

764 # It's too inefficient to implement it for Cassandra in current schema. 

765 raise NotImplementedError() 

766 

767 @property 

768 def metadata(self) -> ApdbMetadata: 

769 # docstring is inherited from a base class 

770 if self._metadata is None: 

771 raise RuntimeError("Database schema was not initialized.") 

772 return self._metadata 

773 

774 @classmethod 

775 def _makeProfiles(cls, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]: 

776 """Make all execution profiles used in the code.""" 

777 if config.private_ips: 

778 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

779 else: 

780 loadBalancePolicy = RoundRobinPolicy() 

781 

782 read_tuples_profile = ExecutionProfile( 

783 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

784 request_timeout=config.read_timeout, 

785 row_factory=cassandra.query.tuple_factory, 

786 load_balancing_policy=loadBalancePolicy, 

787 ) 

788 read_pandas_profile = ExecutionProfile( 

789 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

790 request_timeout=config.read_timeout, 

791 row_factory=pandas_dataframe_factory, 

792 load_balancing_policy=loadBalancePolicy, 

793 ) 

794 read_raw_profile = ExecutionProfile( 

795 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

796 request_timeout=config.read_timeout, 

797 row_factory=raw_data_factory, 

798 load_balancing_policy=loadBalancePolicy, 

799 ) 

800 # Profile to use with select_concurrent to return pandas data frame 

801 read_pandas_multi_profile = ExecutionProfile( 

802 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

803 request_timeout=config.read_timeout, 

804 row_factory=pandas_dataframe_factory, 

805 load_balancing_policy=loadBalancePolicy, 

806 ) 

807 # Profile to use with select_concurrent to return raw data (columns and 

808 # rows) 

809 read_raw_multi_profile = ExecutionProfile( 

810 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

811 request_timeout=config.read_timeout, 

812 row_factory=raw_data_factory, 

813 load_balancing_policy=loadBalancePolicy, 

814 ) 

815 write_profile = ExecutionProfile( 

816 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency), 

817 request_timeout=config.write_timeout, 

818 load_balancing_policy=loadBalancePolicy, 

819 ) 

820 # To replace default DCAwareRoundRobinPolicy 

821 default_profile = ExecutionProfile( 

822 load_balancing_policy=loadBalancePolicy, 

823 ) 

824 return { 

825 "read_tuples": read_tuples_profile, 

826 "read_pandas": read_pandas_profile, 

827 "read_raw": read_raw_profile, 

828 "read_pandas_multi": read_pandas_multi_profile, 

829 "read_raw_multi": read_raw_multi_profile, 

830 "write": write_profile, 

831 EXEC_PROFILE_DEFAULT: default_profile, 

832 } 

833 

834 def _getSources( 

835 self, 

836 region: sphgeom.Region, 

837 object_ids: Iterable[int] | None, 

838 mjd_start: float, 

839 mjd_end: float, 

840 table_name: ApdbTables, 

841 ) -> pandas.DataFrame: 

842 """Return catalog of DiaSource instances given set of DiaObject IDs. 

843 

844 Parameters 

845 ---------- 

846 region : `lsst.sphgeom.Region` 

847 Spherical region. 

848 object_ids : iterable [`int`] or `None`

849 Collection of DiaObject IDs.

850 mjd_start : `float` 

851 Lower bound of time interval. 

852 mjd_end : `float` 

853 Upper bound of time interval. 

854 table_name : `ApdbTables` 

855 Name of the table. 

856 

857 Returns 

858 ------- 

859 catalog : `pandas.DataFrame`

860 Catalog containing DiaSource records. Empty catalog is returned if 

861 ``object_ids`` is empty. 

862 """ 

863 object_id_set: Set[int] = set() 

864 if object_ids is not None: 

865 object_id_set = set(object_ids) 

866 if len(object_id_set) == 0: 

867 return self._make_empty_catalog(table_name) 

868 

869 sp_where = self._spatial_where(region) 

870 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end) 

871 

872 # We need to exclude extra partitioning columns from result. 

873 column_names = self._schema.apdbColumnNames(table_name) 

874 what = ",".join(quote_id(column) for column in column_names) 

875 

876 # Build all queries 

877 statements: list[tuple] = [] 

878 for table in tables: 

879 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"' 

880 statements += list(self._combine_where(prefix, sp_where, temporal_where)) 

881 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

882 

883 with Timer(table_name.name + " select", self.config.timer): 

884 catalog = cast( 

885 pandas.DataFrame, 

886 select_concurrent( 

887 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

888 ), 

889 ) 

890 

891 # filter by given object IDs 

892 if len(object_id_set) > 0: 

893 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

894 

895 # precise filtering on midpointMjdTai 

896 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start]) 

897 

898 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

899 return catalog 

900 

901 def _get_history(self, table: ExtraTables, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

902 """Return records from a particular table given set of insert IDs.""" 

903 if not self._schema.has_insert_id: 

904 raise ValueError("APDB is not configured for history retrieval") 

905 

906 insert_ids = [id.id for id in ids] 

907 params = ",".join("?" * len(insert_ids)) 

908 

909 table_name = self._schema.tableName(table) 

910 # I know that history table schema has only regular APDB columns plus 

911 # an insert_id column, and this is exactly what we need to return from 

912 # this method, so selecting a star is fine here. 

913 query = f'SELECT * FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})' 

914 statement = self._preparer.prepare(query) 

915 

916 with Timer(table_name + " history", self.config.timer):

917 result = self._session.execute(statement, insert_ids, execution_profile="read_raw") 

918 table_data = cast(ApdbCassandraTableData, result._current_rows) 

919 return table_data 

920 

921 def _storeInsertId(self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime) -> None: 

922 # Cassandra timestamp uses milliseconds since epoch 

923 timestamp = insert_id.insert_time.nsecs() // 1_000_000 

924 

925 # everything goes into a single partition 

926 partition = 0 

927 

928 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

929 query = ( 

930 f'INSERT INTO "{self._keyspace}"."{table_name}" (partition, insert_id, insert_time) ' 

931 "VALUES (?, ?, ?)" 

932 ) 

933 

934 self._session.execute( 

935 self._preparer.prepare(query), 

936 (partition, insert_id.id, timestamp), 

937 timeout=self.config.write_timeout, 

938 execution_profile="write", 

939 ) 

940 

941 def _storeDiaObjects( 

942 self, objs: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None 

943 ) -> None: 

944 """Store catalog of DiaObjects from current visit. 

945 

946 Parameters 

947 ---------- 

948 objs : `pandas.DataFrame` 

949 Catalog with DiaObject records 

950 visit_time : `lsst.daf.base.DateTime` 

951 Time of the current visit. 

952 """ 

953 visit_time_dt = visit_time.toPython() 

954 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

955 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

956 

957 extra_columns["validityStart"] = visit_time_dt 

958 time_part: int | None = self._time_partition(visit_time) 

959 if not self.config.time_partition_tables: 

960 extra_columns["apdb_time_part"] = time_part 

961 time_part = None 

962 

963 # Only store DiaObjects if not storing insert_ids or explicitly

964 # configured to always store them 

965 if insert_id is None or not self.config.use_insert_id_skips_diaobjects: 

966 self._storeObjectsPandas( 

967 objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part 

968 ) 

969 

970 if insert_id is not None: 

971 extra_columns = dict(insert_id=insert_id.id, validityStart=visit_time_dt) 

972 self._storeObjectsPandas(objs, ExtraTables.DiaObjectInsertId, extra_columns=extra_columns) 

973 

974 def _storeDiaSources( 

975 self, 

976 table_name: ApdbTables, 

977 sources: pandas.DataFrame, 

978 visit_time: dafBase.DateTime, 

979 insert_id: ApdbInsertId | None, 

980 ) -> None: 

981 """Store catalog of DIASources or DIAForcedSources from current visit. 

982 

983 Parameters 

984 ---------- 

985 sources : `pandas.DataFrame` 

986 Catalog containing DiaSource records 

987 visit_time : `lsst.daf.base.DateTime` 

988 Time of the current visit. 

989 """ 

990 time_part: int | None = self._time_partition(visit_time) 

991 extra_columns: dict[str, Any] = {} 

992 if not self.config.time_partition_tables: 

993 extra_columns["apdb_time_part"] = time_part 

994 time_part = None 

995 

996 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part) 

997 

998 if insert_id is not None: 

999 extra_columns = dict(insert_id=insert_id.id) 

1000 if table_name is ApdbTables.DiaSource: 

1001 extra_table = ExtraTables.DiaSourceInsertId 

1002 else: 

1003 extra_table = ExtraTables.DiaForcedSourceInsertId 

1004 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns) 

1005 

1006 def _storeDiaSourcesPartitions( 

1007 self, sources: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None 

1008 ) -> None: 

1009 """Store mapping of diaSourceId to its partitioning values. 

1010 

1011 Parameters 

1012 ---------- 

1013 sources : `pandas.DataFrame` 

1014 Catalog containing DiaSource records 

1015 visit_time : `lsst.daf.base.DateTime` 

1016 Time of the current visit. 

1017 """ 

1018 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

1019 extra_columns = { 

1020 "apdb_time_part": self._time_partition(visit_time), 

1021 "insert_id": insert_id.id if insert_id is not None else None, 

1022 } 

1023 

1024 self._storeObjectsPandas( 

1025 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

1026 ) 

1027 

1028 def _storeObjectsPandas( 

1029 self, 

1030 records: pandas.DataFrame, 

1031 table_name: ApdbTables | ExtraTables, 

1032 extra_columns: Mapping | None = None, 

1033 time_part: int | None = None, 

1034 ) -> None: 

1035 """Store generic objects. 

1036 

1037 Takes Pandas catalog and stores a bunch of records in a table. 

1038 

1039 Parameters 

1040 ---------- 

1041 records : `pandas.DataFrame` 

1042 Catalog containing object records 

1043 table_name : `ApdbTables` 

1044 Name of the table as defined in APDB schema. 

1045 extra_columns : `dict`, optional 

1046 Mapping (column_name, column_value) which gives fixed values for 

1047 columns in each row, overrides values in ``records`` if matching 

1048 columns exist there. 

1049 time_part : `int`, optional 

1050 If not `None` then insert into a per-partition table. 

1051 

1052 Notes 

1053 ----- 

1054 If Pandas catalog contains additional columns not defined in table 

1055 schema they are ignored. Catalog does not have to contain all columns 

1056 defined in a table, but partition and clustering keys must be present 

1057 in a catalog or ``extra_columns``. 

1058 """ 

1059 # use extra columns if specified 

1060 if extra_columns is None: 

1061 extra_columns = {} 

1062 extra_fields = list(extra_columns.keys()) 

1063 

1064 # Fields that will come from dataframe. 

1065 df_fields = [column for column in records.columns if column not in extra_fields] 

1066 

1067 column_map = self._schema.getColumnMap(table_name) 

1068 # list of columns (as in felis schema) 

1069 fields = [column_map[field].name for field in df_fields if field in column_map] 

1070 fields += extra_fields 

1071 

1072 # check that all partitioning and clustering columns are defined 

1073 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns( 

1074 table_name 

1075 ) 

1076 missing_columns = [column for column in required_columns if column not in fields] 

1077 if missing_columns: 

1078 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

1079 

1080 qfields = [quote_id(field) for field in fields] 

1081 qfields_str = ",".join(qfields) 

1082 

1083 with Timer(table_name.name + " query build", self.config.timer): 

1084 table = self._schema.tableName(table_name) 

1085 if time_part is not None: 

1086 table = f"{table}_{time_part}" 

1087 

1088 holders = ",".join(["?"] * len(qfields)) 

1089 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})' 

1090 statement = self._preparer.prepare(query) 

1091 queries = cassandra.query.BatchStatement() 

1092 for rec in records.itertuples(index=False): 

1093 values = [] 

1094 for field in df_fields: 

1095 if field not in column_map: 

1096 continue 

1097 value = getattr(rec, field) 

1098 if column_map[field].datatype is felis.types.Timestamp: 

1099 if isinstance(value, pandas.Timestamp): 

1100 value = literal(value.to_pydatetime()) 

1101 else: 

1102 # Assume it's seconds since epoch, Cassandra 

1103 # datetime is in milliseconds 

1104 value = int(value * 1000) 

1105 values.append(literal(value)) 

1106 for field in extra_fields: 

1107 value = extra_columns[field] 

1108 values.append(literal(value)) 

1109 queries.add(statement, values) 

1110 

1111 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0]) 

1112 with Timer(table_name.name + " insert", self.config.timer): 

1113 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write") 

1114 

1115 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1116 """Calculate spatial partition for each record and add it to a 

1117 DataFrame. 

1118 

1119 Notes 

1120 ----- 

1121 This overrides any existing column in a DataFrame with the same name 

1122 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1123 returned. 

1124 """ 

1125 # calculate spatial partition pixel index for every DiaObject

1126 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

1127 ra_col, dec_col = self.config.ra_dec_columns 

1128 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

1129 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1130 idx = self._pixelization.pixel(uv3d) 

1131 apdb_part[i] = idx 

1132 df = df.copy() 

1133 df["apdb_part"] = apdb_part 

1134 return df 

1135 

1136 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1137 """Add apdb_part column to DiaSource catalog. 

1138 

1139 Notes 

1140 ----- 

1141 This method copies apdb_part value from a matching DiaObject record. 

1142 DiaObject catalog needs to have an apdb_part column filled by

1143 ``_add_obj_part`` method and DiaSource records need to be 

1144 associated to DiaObjects via ``diaObjectId`` column. 

1145 

1146 This overrides any existing column in a DataFrame with the same name 

1147 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1148 returned. 

1149 """ 

1150 pixel_id_map: dict[int, int] = { 

1151 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

1152 } 

1153 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

1154 ra_col, dec_col = self.config.ra_dec_columns 

1155 for i, (diaObjId, ra, dec) in enumerate( 

1156 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col]) 

1157 ): 

1158 if diaObjId == 0: 

1159 # DiaSources associated with SolarSystemObjects do not have an 

1160 # associated DiaObject, so their partition is computed from the

1161 # source's own ra/dec instead.

1162 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

1163 idx = self._pixelization.pixel(uv3d) 

1164 apdb_part[i] = idx 

1165 else: 

1166 apdb_part[i] = pixel_id_map[diaObjId] 

1167 sources = sources.copy() 

1168 sources["apdb_part"] = apdb_part 

1169 return sources 

1170 

1171 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1172 """Add apdb_part column to DiaForcedSource catalog. 

1173 

1174 Notes 

1175 ----- 

1176 This method copies apdb_part value from a matching DiaObject record. 

1177 DiaObject catalog needs to have an apdb_part column filled by

1178 ``_add_obj_part`` method and DiaForcedSource records need to be

1179 associated to DiaObjects via ``diaObjectId`` column. 

1180 

1181 This overrides any existing column in a DataFrame with the same name 

1182 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

1183 returned. 

1184 """ 

1185 pixel_id_map: dict[int, int] = { 

1186 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

1187 } 

1188 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

1189 for i, diaObjId in enumerate(sources["diaObjectId"]): 

1190 apdb_part[i] = pixel_id_map[diaObjId] 

1191 sources = sources.copy() 

1192 sources["apdb_part"] = apdb_part 

1193 return sources 

1194 

1195 @classmethod 

1196 def _time_partition_cls(cls, time: float | dafBase.DateTime, epoch_mjd: float, part_days: int) -> int: 

1197 """Calculate time partition number for a given time. 

1198 

1199 Parameters 

1200 ---------- 

1201 time : `float` or `lsst.daf.base.DateTime` 

1202 Time for which to calculate partition number, either as a float

1203 MJD value or `lsst.daf.base.DateTime`.

1204 epoch_mjd : `float` 

1205 Epoch time for partition 0. 

1206 part_days : `int` 

1207 Number of days per partition. 

1208 

1209 Returns 

1210 ------- 

1211 partition : `int` 

1212 Partition number for a given time. 

1213 """ 

1214 if isinstance(time, dafBase.DateTime): 

1215 mjd = time.get(system=dafBase.DateTime.MJD) 

1216 else: 

1217 mjd = time 

1218 days_since_epoch = mjd - epoch_mjd 

1219 partition = int(days_since_epoch) // part_days 

1220 return partition 

1221 

1222 def _time_partition(self, time: float | dafBase.DateTime) -> int: 

1223 """Calculate time partition number for a given time. 

1224 

1225 Parameters 

1226 ---------- 

1227 time : `float` or `lsst.daf.base.DateTime` 

1229 Time for which to calculate partition number, either as a float

1229 MJD value or `lsst.daf.base.DateTime`.

1230 

1231 Returns 

1232 ------- 

1233 partition : `int` 

1234 Partition number for a given time. 

1235 """ 

1236 if isinstance(time, dafBase.DateTime): 

1237 mjd = time.get(system=dafBase.DateTime.MJD) 

1238 else: 

1239 mjd = time 

1240 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

1241 partition = int(days_since_epoch) // self.config.time_partition_days 

1242 return partition 
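
    # Worked example for _time_partition() with the default 30-day partitions
    # (partition_zero_epoch 1970-01-01 TAI corresponds to MJD 40587; the visit
    # time is an arbitrary illustration):
    #
    #     mjd = 60000.0
    #     days_since_epoch = 60000.0 - 40587.0   # 19413.0
    #     partition = int(19413.0) // 30         # 647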

1243 

1244 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

1245 """Make an empty catalog for a table with a given name. 

1246 

1247 Parameters 

1248 ---------- 

1249 table_name : `ApdbTables` 

1250 Name of the table. 

1251 

1252 Returns 

1253 ------- 

1254 catalog : `pandas.DataFrame` 

1255 An empty catalog. 

1256 """ 

1257 table = self._schema.tableSchemas[table_name] 

1258 

1259 data = { 

1260 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype)) 

1261 for columnDef in table.columns 

1262 } 

1263 return pandas.DataFrame(data) 

1264 

1265 def _combine_where( 

1266 self, 

1267 prefix: str, 

1268 where1: list[tuple[str, tuple]], 

1269 where2: list[tuple[str, tuple]], 

1270 suffix: str | None = None, 

1271 ) -> Iterator[tuple[cassandra.query.Statement, tuple]]: 

1272 """Make cartesian product of two parts of WHERE clause into a series 

1273 of statements to execute. 

1274 

1275 Parameters 

1276 ---------- 

1277 prefix : `str` 

1278 Initial statement prefix that comes before WHERE clause, e.g. 

1279 "SELECT * from Table" 

1280 """ 

1281 # If lists are empty use special sentinels. 

1282 if not where1: 

1283 where1 = [("", ())] 

1284 if not where2: 

1285 where2 = [("", ())] 

1286 

1287 for expr1, params1 in where1: 

1288 for expr2, params2 in where2: 

1289 full_query = prefix 

1290 wheres = [] 

1291 if expr1: 

1292 wheres.append(expr1) 

1293 if expr2: 

1294 wheres.append(expr2) 

1295 if wheres: 

1296 full_query += " WHERE " + " AND ".join(wheres) 

1297 if suffix: 

1298 full_query += " " + suffix 

1299 params = params1 + params2 

1300 if params: 

1301 statement = self._preparer.prepare(full_query) 

1302 else: 

1303 # If there are no params then it is likely that query 

1304 # has a bunch of literals rendered already, no point 

1305 # trying to prepare it. 

1306 statement = cassandra.query.SimpleStatement(full_query) 

1307 yield (statement, params) 
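
    # Expansion sketch for _combine_where(): two spatial expressions combined
    # with one temporal expression yield two statements (values are made up):
    #
    #     where1 = [('"apdb_part" = ?', (10,)), ('"apdb_part" = ?', (11,))]
    #     where2 = [('"apdb_time_part" = ?', (647,))]
    #     # -> '<prefix> WHERE "apdb_part" = ? AND "apdb_time_part" = ?'
    #     #    prepared twice, with parameters (10, 647) and (11, 647)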

1308 

1309 def _spatial_where( 

1310 self, region: sphgeom.Region | None, use_ranges: bool = False 

1311 ) -> list[tuple[str, tuple]]: 

1312 """Generate expressions for spatial part of WHERE clause. 

1313 

1314 Parameters 

1315 ---------- 

1316 region : `sphgeom.Region` 

1317 Spatial region for query results. 

1318 use_ranges : `bool` 

1319 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <= 

1320 p2") instead of exact list of pixels. Should be set to True for 

1321 large regions covering very many pixels. 

1322 

1323 Returns 

1324 ------- 

1325 expressions : `list` [ `tuple` ] 

1326 Empty list is returned if ``region`` is `None`, otherwise a list 

1327 of one or more (expression, parameters) tuples 

1328 """ 

1329 if region is None: 

1330 return [] 

1331 if use_ranges: 

1332 pixel_ranges = self._pixelization.envelope(region) 

1333 expressions: list[tuple[str, tuple]] = [] 

1334 for lower, upper in pixel_ranges: 

1335 upper -= 1 

1336 if lower == upper: 

1337 expressions.append(('"apdb_part" = ?', (lower,))) 

1338 else: 

1339 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper))) 

1340 return expressions 

1341 else: 

1342 pixels = self._pixelization.pixels(region) 

1343 if self.config.query_per_spatial_part: 

1344 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels] 

1345 else: 

1346 pixels_str = ",".join([str(pix) for pix in pixels]) 

1347 return [(f'"apdb_part" IN ({pixels_str})', ())] 
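
    # Shapes of the expressions returned by _spatial_where() (pixel indices
    # are made up):
    #
    #     [('"apdb_part" IN (1200,1201,1205)', ())]                    # default
    #     [('"apdb_part" = ?', (1200,)), ('"apdb_part" = ?', (1201,))] # query_per_spatial_part
    #     [('"apdb_part" >= ? AND "apdb_part" <= ?', (1200, 1210))]    # use_ranges=True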

1348 

1349 def _temporal_where( 

1350 self, 

1351 table: ApdbTables, 

1352 start_time: float | dafBase.DateTime, 

1353 end_time: float | dafBase.DateTime, 

1354 query_per_time_part: bool | None = None, 

1355 ) -> tuple[list[str], list[tuple[str, tuple]]]: 

1356 """Generate table names and expressions for temporal part of WHERE 

1357 clauses. 

1358 

1359 Parameters 

1360 ---------- 

1361 table : `ApdbTables` 

1362 Table to select from. 

1363 start_time : `dafBase.DateTime` or `float` 

1364 Starting DateTime or MJD value of the time range.

1365 end_time : `dafBase.DateTime` or `float`

1366 Ending DateTime or MJD value of the time range.

1367 query_per_time_part : `bool`, optional 

1368 If None then use ``query_per_time_part`` from configuration. 

1369 

1370 Returns 

1371 ------- 

1372 tables : `list` [ `str` ] 

1373 List of the table names to query. 

1374 expressions : `list` [ `tuple` ] 

1375 A list of zero or more (expression, parameters) tuples. 

1376 """ 

1377 tables: list[str] 

1378 temporal_where: list[tuple[str, tuple]] = [] 

1379 table_name = self._schema.tableName(table) 

1380 time_part_start = self._time_partition(start_time) 

1381 time_part_end = self._time_partition(end_time) 

1382 time_parts = list(range(time_part_start, time_part_end + 1)) 

1383 if self.config.time_partition_tables: 

1384 tables = [f"{table_name}_{part}" for part in time_parts] 

1385 else: 

1386 tables = [table_name] 

1387 if query_per_time_part is None: 

1388 query_per_time_part = self.config.query_per_time_part 

1389 if query_per_time_part: 

1390 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts] 

1391 else: 

1392 time_part_list = ",".join([str(part) for part in time_parts]) 

1393 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())] 

1394 

1395 return tables, temporal_where
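
    # Shapes of the values returned by _temporal_where() for an interval that
    # spans two time partitions (table and partition names are illustrative):
    #
    #     time_partition_tables=True:
    #         (["DiaSource_647", "DiaSource_648"], [])
    #     otherwise:
    #         (["DiaSource"], [('"apdb_time_part" IN (647,648)', ())])
    #     otherwise, with query_per_time_part=True:
    #         (["DiaSource"], [('"apdb_time_part" = ?', (647,)),
    #                          ('"apdb_time_part" = ?', (648,))])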