Coverage for python/lsst/dax/apdb/apdbCassandra.py: 15%

475 statements  

coverage.py v7.3.2, created at 2023-10-26 10:23 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26import logging 

27import uuid 

28from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union, cast 

29 

30import numpy as np 

31import pandas 

32 

33# If cassandra-driver is not there the module can still be imported 

34# but ApdbCassandra cannot be instantiated. 

35try: 

36 import cassandra 

37 import cassandra.query 

38 from cassandra.auth import AuthProvider, PlainTextAuthProvider 

39 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile 

40 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

41 

42 CASSANDRA_IMPORTED = True 

43except ImportError: 

44 CASSANDRA_IMPORTED = False 

45 

46import felis.types 

47import lsst.daf.base as dafBase 

48from felis.simple import Table 

49from lsst import sphgeom 

50from lsst.pex.config import ChoiceField, Field, ListField 

51from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError 

52from lsst.utils.iteration import chunk_iterable 

53 

54from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData 

55from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables 

56from .apdbSchema import ApdbTables 

57from .cassandra_utils import ( 

58 ApdbCassandraTableData, 

59 literal, 

60 pandas_dataframe_factory, 

61 quote_id, 

62 raw_data_factory, 

63 select_concurrent, 

64) 

65from .pixelization import Pixelization 

66from .timer import Timer 

67 

68_LOG = logging.getLogger(__name__) 

69 

70# Copied from daf_butler. 

71DB_AUTH_ENVVAR = "LSST_DB_AUTH" 

72"""Default name of the environmental variable that will be used to locate DB 

73credentials configuration file. """ 

74 

75DB_AUTH_PATH = "~/.lsst/db-auth.yaml" 

76"""Default path at which it is expected that DB credentials are found.""" 

77 

78 

79class CassandraMissingError(Exception): 

80 def __init__(self) -> None: 

81 super().__init__("cassandra-driver module cannot be imported") 

82 

83 

84class ApdbCassandraConfig(ApdbConfig): 

85 """Configuration class for Cassandra-based APDB implementation.""" 

86 

87 contact_points = ListField[str]( 

88 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"] 

89 ) 

90 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[]) 

91 port = Field[int](doc="Port number to connect to.", default=9042) 

92 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb") 

93 username = Field[str]( 

94 doc=f"Cassandra user name, if empty then {DB_AUTH_PATH} has to provide it with password.", 

95 default="", 

96 ) 

97 read_consistency = Field[str]( 

98 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM" 

99 ) 

100 write_consistency = Field[str]( 

101 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM" 

102 ) 

103 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0) 

104 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=10.0) 

105 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500) 

106 protocol_version = Field[int]( 

107 doc="Cassandra protocol version to use, default is V4", 

108 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0, 

109 ) 

110 dia_object_columns = ListField[str]( 

111 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[] 

112 ) 

113 prefix = Field[str](doc="Prefix to add to table names", default="") 

114 part_pixelization = ChoiceField[str]( 

115 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

116 doc="Pixelization used for partitioning index.", 

117 default="mq3c", 

118 ) 

119 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10) 

120 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64) 

121 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table") 

122 timer = Field[bool](doc="If True then print/log timing information", default=False) 

123 time_partition_tables = Field[bool]( 

124 doc="Use per-partition tables for sources instead of partitioning by time", default=True 

125 ) 

126 time_partition_days = Field[int]( 

127 doc=( 

128 "Time partitioning granularity in days, this value must not be changed after database is " 

129 "initialized" 

130 ), 

131 default=30, 

132 ) 

133 time_partition_start = Field[str]( 

134 doc=( 

135 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

136 "This is used only when time_partition_tables is True." 

137 ), 

138 default="2018-12-01T00:00:00", 

139 ) 

140 time_partition_end = Field[str]( 

141 doc=( 

142 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

143 "This is used only when time_partition_tables is True." 

144 ), 

145 default="2030-01-01T00:00:00", 

146 ) 

147 query_per_time_part = Field[bool]( 

148 default=False, 

149 doc=( 

150 "If True then build separate query for each time partition, otherwise build one single query. " 

151 "This is only used when time_partition_tables is False in schema config." 

152 ), 

153 ) 

154 query_per_spatial_part = Field[bool]( 

155 default=False, 

156 doc="If True then build one query per spatial partition, otherwise build single query.", 

157 ) 

158 
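The fields above are standard ``pex_config`` fields, so a configuration can be assembled and overridden in Python before the database object is created. A minimal sketch; the host names are placeholders and only a few of the fields are shown:

    config = ApdbCassandraConfig()
    config.contact_points = ["cassandra-1.example.org", "cassandra-2.example.org"]  # placeholders
    config.keyspace = "apdb"
    config.part_pixelization = "mq3c"        # one of "htm", "q3c", "mq3c"
    config.time_partition_tables = True      # per-partition source tables
    config.read_consistency = "QUORUM"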

159 

160if CASSANDRA_IMPORTED:    160 ↛ 175 (line 160 didn't jump to line 175 because the condition on line 160 was never false)

161 

162 class _AddressTranslator(AddressTranslator): 

163 """Translate internal IP address to external. 

164 

165 Only used for docker-based setup, not a viable long-term solution. 

166 """ 

167 

168 def __init__(self, public_ips: List[str], private_ips: List[str]): 

169 self._map = dict(zip(private_ips, public_ips)) 

170 

171 def translate(self, private_ip: str) -> str: 

172 return self._map.get(private_ip, private_ip) 

173 

174 

175def _quote_column(name: str) -> str: 

176 """Quote column name""" 

177 if name.islower(): 

178 return name 

179 else: 

180 return f'"{name}"' 

181 

182 

183class ApdbCassandra(Apdb): 

184 """Implementation of APDB database on to of Apache Cassandra. 

185 

186 The implementation is configured via standard ``pex_config`` mechanism 

187 using the `ApdbCassandraConfig` configuration class. For examples of 

188 different configurations, check the config/ folder. 

189 

190 Parameters 

191 ---------- 

192 config : `ApdbCassandraConfig` 

193 Configuration object. 

194 """ 

195 

196 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI) 

197 """Start time for partition 0, this should never be changed.""" 

198 

199 def __init__(self, config: ApdbCassandraConfig): 

200 if not CASSANDRA_IMPORTED: 

201 raise CassandraMissingError() 

202 

203 config.validate() 

204 self.config = config 

205 

206 _LOG.debug("ApdbCassandra Configuration:") 

207 for key, value in self.config.items(): 

208 _LOG.debug(" %s: %s", key, value) 

209 

210 self._pixelization = Pixelization( 

211 config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges 

212 ) 

213 

214 addressTranslator: Optional[AddressTranslator] = None 

215 if config.private_ips: 

216 addressTranslator = _AddressTranslator(list(config.contact_points), list(config.private_ips)) 

217 

218 self._keyspace = config.keyspace 

219 

220 self._cluster = Cluster( 

221 execution_profiles=self._makeProfiles(config), 

222 contact_points=self.config.contact_points, 

223 port=self.config.port, 

224 address_translator=addressTranslator, 

225 protocol_version=self.config.protocol_version, 

226 auth_provider=self._make_auth_provider(config), 

227 ) 

228 self._session = self._cluster.connect() 

229 # Disable result paging 

230 self._session.default_fetch_size = None 

231 

232 self._schema = ApdbCassandraSchema( 

233 session=self._session, 

234 keyspace=self._keyspace, 

235 schema_file=self.config.schema_file, 

236 schema_name=self.config.schema_name, 

237 prefix=self.config.prefix, 

238 time_partition_tables=self.config.time_partition_tables, 

239 use_insert_id=self.config.use_insert_id, 

240 ) 

241 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD) 

242 

243 # Cache for prepared statements 

244 self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {} 

245 

246 def __del__(self) -> None: 

247 if hasattr(self, "_cluster"): 

248 self._cluster.shutdown() 

249 

250 def _make_auth_provider(self, config: ApdbCassandraConfig) -> AuthProvider | None: 

251 """Make Cassandra authentication provider instance.""" 

252 try: 

253 dbauth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR) 

254 except DbAuthNotFoundError: 

255 # Credentials file doesn't exist, use anonymous login. 

256 return None 

257 

258 empty_username = False 

259 # Try every contact point in turn. 

260 for hostname in config.contact_points: 

261 try: 

262 username, password = dbauth.getAuth( 

263 "cassandra", config.username, hostname, config.port, config.keyspace 

264 ) 

265 if not username: 

266 # Password without user name, try next hostname, but give 

267 # warning later if no better match is found. 

268 empty_username = True 

269 else: 

270 return PlainTextAuthProvider(username=username, password=password) 

271 except DbAuthNotFoundError: 

272 pass 

273 

274 if empty_username: 

275 _LOG.warning( 

276 f"Credentials file ({DB_AUTH_PATH} or ${DB_AUTH_ENVVAR}) provided password but not " 

277 f"user name, anonymous Cassandra logon will be attempted." 

278 ) 

279 

280 return None 

281 
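The credential lookup above can be exercised on its own; a minimal sketch using the defaults defined in this module (the host, port and keyspace values are simply the configuration defaults, not a recommendation):

    from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError

    try:
        db_auth = DbAuth(DB_AUTH_PATH, DB_AUTH_ENVVAR)
        # Same call signature as used in _make_auth_provider above.
        username, password = db_auth.getAuth("cassandra", "", "127.0.0.1", 9042, "apdb")
    except DbAuthNotFoundError:
        username = password = None  # fall back to anonymous login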

282 def tableDef(self, table: ApdbTables) -> Optional[Table]: 

283 # docstring is inherited from a base class 

284 return self._schema.tableSchemas.get(table) 

285 

286 def makeSchema(self, drop: bool = False) -> None: 

287 # docstring is inherited from a base class 

288 

289 if self.config.time_partition_tables: 

290 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI) 

291 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI) 

292 part_range = ( 

293 self._time_partition(time_partition_start), 

294 self._time_partition(time_partition_end) + 1, 

295 ) 

296 self._schema.makeSchema(drop=drop, part_range=part_range) 

297 else: 

298 self._schema.makeSchema(drop=drop) 

299 
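Together with the constructor, schema creation is a two-step call; a sketch that assumes a reachable Cassandra cluster and the ``config`` object from the configuration example above:

    apdb = ApdbCassandra(config)
    # With time_partition_tables=True this also creates the per-partition
    # DiaSource/DiaForcedSource tables covering the configured time range.
    apdb.makeSchema(drop=False)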

300 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

301 # docstring is inherited from a base class 

302 

303 sp_where = self._spatial_where(region) 

304 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

305 

306 # We need to exclude extra partitioning columns from result. 

307 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast) 

308 what = ",".join(_quote_column(column) for column in column_names) 

309 

310 table_name = self._schema.tableName(ApdbTables.DiaObjectLast) 

311 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"' 

312 statements: List[Tuple] = [] 

313 for where, params in sp_where: 

314 full_query = f"{query} WHERE {where}" 

315 if params: 

316 statement = self._prep_statement(full_query) 

317 else: 

318 # If there are no params then it is likely that query has a 

319 # bunch of literals rendered already, no point trying to 

320 # prepare it because it's not reusable. 

321 statement = cassandra.query.SimpleStatement(full_query) 

322 statements.append((statement, params)) 

323 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

324 

325 with Timer("DiaObject select", self.config.timer): 

326 objects = cast( 

327 pandas.DataFrame, 

328 select_concurrent( 

329 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

330 ), 

331 ) 

332 

333 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

334 return objects 

335 
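A usage sketch for the read path; any `lsst.sphgeom.Region` works, and a small circle is used here purely for illustration (the coordinates and radius are placeholders):

    from lsst import sphgeom

    center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(35.0, -4.5))
    region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(1.0))
    dia_objects = apdb.getDiaObjects(region)  # pandas.DataFrame of latest DiaObject versions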

336 def getDiaSources( 

337 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime 

338 ) -> Optional[pandas.DataFrame]: 

339 # docstring is inherited from a base class 

340 months = self.config.read_sources_months 

341 if months == 0: 

342 return None 

343 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

344 mjd_start = mjd_end - months * 30 

345 

346 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

347 

348 def getDiaForcedSources( 

349 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime 

350 ) -> Optional[pandas.DataFrame]: 

351 # docstring is inherited from a base class 

352 months = self.config.read_forced_sources_months 

353 if months == 0: 

354 return None 

355 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

356 mjd_start = mjd_end - months * 30 

357 

358 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

359 

360 def getInsertIds(self) -> list[ApdbInsertId] | None: 

361 # docstring is inherited from a base class 

362 if not self._schema.has_insert_id: 

363 return None 

364 

365 # everything goes into a single partition 

366 partition = 0 

367 

368 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

369 query = f'SELECT insert_time, insert_id FROM "{self._keyspace}"."{table_name}" WHERE partition = ?' 

370 

371 result = self._session.execute( 

372 self._prep_statement(query), 

373 (partition,), 

374 timeout=self.config.read_timeout, 

375 execution_profile="read_tuples", 

376 ) 

377 # order by insert_time 

378 rows = sorted(result) 

379 return [ApdbInsertId(row[1]) for row in rows] 

380 

381 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None: 

382 # docstring is inherited from a base class 

383 if not self._schema.has_insert_id: 

384 raise ValueError("APDB is not configured for history storage") 

385 

386 insert_ids = [id.id for id in ids] 

387 params = ",".join("?" * len(insert_ids)) 

388 

389 # everything goes into a single partition 

390 partition = 0 

391 

392 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

393 query = ( 

394 f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE partition = ? and insert_id IN ({params})' 

395 ) 

396 

397 self._session.execute( 

398 self._prep_statement(query), 

399 [partition] + insert_ids, 

400 timeout=self.config.write_timeout, 

401 ) 

402 

403 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

404 # docstring is inherited from a base class 

405 return self._get_history(ExtraTables.DiaObjectInsertId, ids) 

406 

407 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

408 # docstring is inherited from a base class 

409 return self._get_history(ExtraTables.DiaSourceInsertId, ids) 

410 

411 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

412 # docstring is inherited from a base class 

413 return self._get_history(ExtraTables.DiaForcedSourceInsertId, ids) 

414 
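When ``use_insert_id`` is enabled in the configuration, the methods above combine into a simple retrieve-then-prune loop; a sketch, assuming history storage is enabled:

    insert_ids = apdb.getInsertIds()
    if insert_ids is not None:
        object_history = apdb.getDiaObjectsHistory(insert_ids)   # ApdbTableData
        source_history = apdb.getDiaSourcesHistory(insert_ids)
        forced_history = apdb.getDiaForcedSourcesHistory(insert_ids)
        apdb.deleteInsertIds(insert_ids)  # drop the processed insert IDs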

415 def getSSObjects(self) -> pandas.DataFrame: 

416 # docstring is inherited from a base class 

417 tableName = self._schema.tableName(ApdbTables.SSObject) 

418 query = f'SELECT * from "{self._keyspace}"."{tableName}"' 

419 

420 objects = None 

421 with Timer("SSObject select", self.config.timer): 

422 result = self._session.execute(query, execution_profile="read_pandas") 

423 objects = result._current_rows 

424 

425 _LOG.debug("found %s SSObjects", objects.shape[0]) 

426 return objects 

427 

428 def store( 

429 self, 

430 visit_time: dafBase.DateTime, 

431 objects: pandas.DataFrame, 

432 sources: Optional[pandas.DataFrame] = None, 

433 forced_sources: Optional[pandas.DataFrame] = None, 

434 ) -> None: 

435 # docstring is inherited from a base class 

436 

437 insert_id: ApdbInsertId | None = None 

438 if self._schema.has_insert_id: 

439 insert_id = ApdbInsertId.new_insert_id() 

440 self._storeInsertId(insert_id, visit_time) 

441 

442 # fill region partition column for DiaObjects 

443 objects = self._add_obj_part(objects) 

444 self._storeDiaObjects(objects, visit_time, insert_id) 

445 

446 if sources is not None: 

447 # copy apdb_part column from DiaObjects to DiaSources 

448 sources = self._add_src_part(sources, objects) 

449 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, insert_id) 

450 self._storeDiaSourcesPartitions(sources, visit_time, insert_id) 

451 

452 if forced_sources is not None: 

453 forced_sources = self._add_fsrc_part(forced_sources, objects) 

454 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, insert_id) 

455 
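A sketch of the write path. The catalogs below are deliberately minimal and only illustrate the call shape; real catalogs must carry every column required by the APDB schema, and ``ra``/``dec``/``diaObjectId`` are needed by the partitioning helpers in this class:

    import pandas
    import lsst.daf.base as dafBase

    objects = pandas.DataFrame({"diaObjectId": [1], "ra": [35.0], "dec": [-4.5]})
    visit_time = dafBase.DateTime("2023-10-26T00:00:00", dafBase.DateTime.TAI)
    apdb.store(visit_time, objects)  # sources and forced_sources are optional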

456 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

457 # docstring is inherited from a base class 

458 self._storeObjectsPandas(objects, ApdbTables.SSObject) 

459 

460 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

461 # docstring is inherited from a base class 

462 

463 # To update a record we need to know its exact primary key (including 

464 # partition key) so we start by querying for diaSourceId to find the 

465 # primary keys. 

466 

467 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition) 

468 # split it into 1k IDs per query 

469 selects: List[Tuple] = [] 

470 for ids in chunk_iterable(idMap.keys(), 1_000): 

471 ids_str = ",".join(str(item) for item in ids) 

472 selects.append( 

473 ( 

474 ( 

475 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "insert_id" ' 

476 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})' 

477 ), 

478 {}, 

479 ) 

480 ) 

481 

482 # No need for DataFrame here, read data as tuples. 

483 result = cast( 

484 List[Tuple[int, int, int, uuid.UUID | None]], 

485 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency), 

486 ) 

487 

488 # Make mapping from source ID to its partition. 

489 id2partitions: Dict[int, Tuple[int, int]] = {} 

490 id2insert_id: Dict[int, ApdbInsertId] = {} 

491 for row in result: 

492 id2partitions[row[0]] = row[1:3] 

493 if row[3] is not None: 

494 id2insert_id[row[0]] = ApdbInsertId(row[3]) 

495 

496 # make sure we know partitions for each ID 

497 if set(id2partitions) != set(idMap): 

498 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

499 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

500 

501 # Reassign in standard tables 

502 queries = cassandra.query.BatchStatement() 

503 table_name = self._schema.tableName(ApdbTables.DiaSource) 

504 for diaSourceId, ssObjectId in idMap.items(): 

505 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

506 values: Tuple 

507 if self.config.time_partition_tables: 

508 query = ( 

509 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"' 

510 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

511 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?' 

512 ) 

513 values = (ssObjectId, apdb_part, diaSourceId) 

514 else: 

515 query = ( 

516 f'UPDATE "{self._keyspace}"."{table_name}"' 

517 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

518 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?' 

519 ) 

520 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId) 

521 queries.add(self._prep_statement(query), values) 

522 

523 # Reassign in history tables, only if history is enabled 

524 if id2insert_id: 

525 # Filter out insert ids that have been deleted already. There is a 

526 # potential race with concurrent removal of insert IDs, but it 

527 # should be handled by WHERE in UPDATE. 

528 known_ids = set() 

529 if insert_ids := self.getInsertIds(): 

530 known_ids = set(insert_ids) 

531 id2insert_id = {key: value for key, value in id2insert_id.items() if value in known_ids} 

532 if id2insert_id: 

533 table_name = self._schema.tableName(ExtraTables.DiaSourceInsertId) 

534 for diaSourceId, ssObjectId in idMap.items(): 

535 if insert_id := id2insert_id.get(diaSourceId): 

536 query = ( 

537 f'UPDATE "{self._keyspace}"."{table_name}" ' 

538 ' SET "ssObjectId" = ?, "diaObjectId" = NULL ' 

539 'WHERE "insert_id" = ? AND "diaSourceId" = ?' 

540 ) 

541 values = (ssObjectId, insert_id.id, diaSourceId) 

542 queries.add(self._prep_statement(query), values) 

543 

544 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

545 with Timer(table_name + " update", self.config.timer): 

546 self._session.execute(queries, execution_profile="write") 

547 
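From the caller's point of view the reassignment above is driven by a plain mapping of DiaSource ID to SSObject ID; the IDs here are purely illustrative:

    # Both DiaSources become associated with SSObject 42; their diaObjectId is cleared.
    apdb.reassignDiaSources({123456789: 42, 123456790: 42})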

548 def dailyJob(self) -> None: 

549 # docstring is inherited from a base class 

550 pass 

551 

552 def countUnassociatedObjects(self) -> int: 

553 # docstring is inherited from a base class 

554 

555 # It's too inefficient to implement it for Cassandra in the current schema. 

556 raise NotImplementedError() 

557 

558 def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]: 

559 """Make all execution profiles used in the code.""" 

560 if config.private_ips: 

561 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

562 else: 

563 loadBalancePolicy = RoundRobinPolicy() 

564 

565 read_tuples_profile = ExecutionProfile( 

566 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

567 request_timeout=config.read_timeout, 

568 row_factory=cassandra.query.tuple_factory, 

569 load_balancing_policy=loadBalancePolicy, 

570 ) 

571 read_pandas_profile = ExecutionProfile( 

572 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

573 request_timeout=config.read_timeout, 

574 row_factory=pandas_dataframe_factory, 

575 load_balancing_policy=loadBalancePolicy, 

576 ) 

577 read_raw_profile = ExecutionProfile( 

578 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

579 request_timeout=config.read_timeout, 

580 row_factory=raw_data_factory, 

581 load_balancing_policy=loadBalancePolicy, 

582 ) 

583 # Profile to use with select_concurrent to return pandas data frame 

584 read_pandas_multi_profile = ExecutionProfile( 

585 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

586 request_timeout=config.read_timeout, 

587 row_factory=pandas_dataframe_factory, 

588 load_balancing_policy=loadBalancePolicy, 

589 ) 

590 # Profile to use with select_concurrent to return raw data (columns and 

591 # rows) 

592 read_raw_multi_profile = ExecutionProfile( 

593 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

594 request_timeout=config.read_timeout, 

595 row_factory=raw_data_factory, 

596 load_balancing_policy=loadBalancePolicy, 

597 ) 

598 write_profile = ExecutionProfile( 

599 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency), 

600 request_timeout=config.write_timeout, 

601 load_balancing_policy=loadBalancePolicy, 

602 ) 

603 # To replace default DCAwareRoundRobinPolicy 

604 default_profile = ExecutionProfile( 

605 load_balancing_policy=loadBalancePolicy, 

606 ) 

607 return { 

608 "read_tuples": read_tuples_profile, 

609 "read_pandas": read_pandas_profile, 

610 "read_raw": read_raw_profile, 

611 "read_pandas_multi": read_pandas_multi_profile, 

612 "read_raw_multi": read_raw_multi_profile, 

613 "write": write_profile, 

614 EXEC_PROFILE_DEFAULT: default_profile, 

615 } 

616 

617 def _getSources( 

618 self, 

619 region: sphgeom.Region, 

620 object_ids: Optional[Iterable[int]], 

621 mjd_start: float, 

622 mjd_end: float, 

623 table_name: ApdbTables, 

624 ) -> pandas.DataFrame: 

625 """Return catalog of DiaSource instances given set of DiaObject IDs. 

626 

627 Parameters 

628 ---------- 

629 region : `lsst.sphgeom.Region` 

630 Spherical region. 

631 object_ids : iterable [ `int` ], optional 

632 Collection of DiaObject IDs. 

633 mjd_start : `float` 

634 Lower bound of time interval. 

635 mjd_end : `float` 

636 Upper bound of time interval. 

637 table_name : `ApdbTables` 

638 Name of the table. 

639 

640 Returns 

641 ------- 

642 catalog : `pandas.DataFrame` 

643 Catalog containing DiaSource records. Empty catalog is returned if 

644 ``object_ids`` is empty. 

645 """ 

646 object_id_set: Set[int] = set() 

647 if object_ids is not None: 

648 object_id_set = set(object_ids) 

649 if len(object_id_set) == 0: 

650 return self._make_empty_catalog(table_name) 

651 

652 sp_where = self._spatial_where(region) 

653 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end) 

654 

655 # We need to exclude extra partitioning columns from result. 

656 column_names = self._schema.apdbColumnNames(table_name) 

657 what = ",".join(_quote_column(column) for column in column_names) 

658 

659 # Build all queries 

660 statements: List[Tuple] = [] 

661 for table in tables: 

662 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"' 

663 statements += list(self._combine_where(prefix, sp_where, temporal_where)) 

664 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

665 

666 with Timer(table_name.name + " select", self.config.timer): 

667 catalog = cast( 

668 pandas.DataFrame, 

669 select_concurrent( 

670 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

671 ), 

672 ) 

673 

674 # filter by given object IDs 

675 if len(object_id_set) > 0: 

676 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

677 

678 # precise filtering on midpointMjdTai 

679 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start]) 

680 

681 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

682 return catalog 

683 

684 def _get_history(self, table: ExtraTables, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

685 """Return records from a particular table given set of insert IDs.""" 

686 if not self._schema.has_insert_id: 

687 raise ValueError("APDB is not configured for history retrieval") 

688 

689 insert_ids = [id.id for id in ids] 

690 params = ",".join("?" * len(insert_ids)) 

691 

692 table_name = self._schema.tableName(table) 

693 # I know that history table schema has only regular APDB columns plus 

694 # an insert_id column, and this is exactly what we need to return from 

695 # this method, so selecting a star is fine here. 

696 query = f'SELECT * FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})' 

697 statement = self._prep_statement(query) 

698 

699 with Timer(table_name + " history", self.config.timer): 

700 result = self._session.execute(statement, insert_ids, execution_profile="read_raw") 

701 table_data = cast(ApdbCassandraTableData, result._current_rows) 

702 return table_data 

703 

704 def _storeInsertId(self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime) -> None: 

705 # Cassandra timestamp uses milliseconds since epoch 

706 timestamp = visit_time.nsecs() // 1_000_000 

707 

708 # everything goes into a single partition 

709 partition = 0 

710 

711 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

712 query = ( 

713 f'INSERT INTO "{self._keyspace}"."{table_name}" (partition, insert_id, insert_time) ' 

714 "VALUES (?, ?, ?)" 

715 ) 

716 

717 self._session.execute( 

718 self._prep_statement(query), 

719 (partition, insert_id.id, timestamp), 

720 timeout=self.config.write_timeout, 

721 execution_profile="write", 

722 ) 

723 

724 def _storeDiaObjects( 

725 self, objs: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None 

726 ) -> None: 

727 """Store catalog of DiaObjects from current visit. 

728 

729 Parameters 

730 ---------- 

731 objs : `pandas.DataFrame` 

732 Catalog with DiaObject records 

733 visit_time : `lsst.daf.base.DateTime` 

734 Time of the current visit. 

735 """ 

736 visit_time_dt = visit_time.toPython() 

737 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

738 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

739 

740 extra_columns["validityStart"] = visit_time_dt 

741 time_part: Optional[int] = self._time_partition(visit_time) 

742 if not self.config.time_partition_tables: 

743 extra_columns["apdb_time_part"] = time_part 

744 time_part = None 

745 

746 self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part) 

747 

748 if insert_id is not None: 

749 extra_columns = dict(insert_id=insert_id.id, validityStart=visit_time_dt) 

750 self._storeObjectsPandas(objs, ExtraTables.DiaObjectInsertId, extra_columns=extra_columns) 

751 

752 def _storeDiaSources( 

753 self, 

754 table_name: ApdbTables, 

755 sources: pandas.DataFrame, 

756 visit_time: dafBase.DateTime, 

757 insert_id: ApdbInsertId | None, 

758 ) -> None: 

759 """Store catalog of DIASources or DIAForcedSources from current visit. 

760 

761 Parameters 

762 ---------- 

763 sources : `pandas.DataFrame` 

764 Catalog containing DiaSource records 

765 visit_time : `lsst.daf.base.DateTime` 

766 Time of the current visit. 

767 """ 

768 time_part: Optional[int] = self._time_partition(visit_time) 

769 extra_columns: dict[str, Any] = {} 

770 if not self.config.time_partition_tables: 

771 extra_columns["apdb_time_part"] = time_part 

772 time_part = None 

773 

774 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part) 

775 

776 if insert_id is not None: 

777 extra_columns = dict(insert_id=insert_id.id) 

778 if table_name is ApdbTables.DiaSource: 

779 extra_table = ExtraTables.DiaSourceInsertId 

780 else: 

781 extra_table = ExtraTables.DiaForcedSourceInsertId 

782 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns) 

783 

784 def _storeDiaSourcesPartitions( 

785 self, sources: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None 

786 ) -> None: 

787 """Store mapping of diaSourceId to its partitioning values. 

788 

789 Parameters 

790 ---------- 

791 sources : `pandas.DataFrame` 

792 Catalog containing DiaSource records 

793 visit_time : `lsst.daf.base.DateTime` 

794 Time of the current visit. 

795 """ 

796 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

797 extra_columns = { 

798 "apdb_time_part": self._time_partition(visit_time), 

799 "insert_id": insert_id.id if insert_id is not None else None, 

800 } 

801 

802 self._storeObjectsPandas( 

803 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

804 ) 

805 

806 def _storeObjectsPandas( 

807 self, 

808 records: pandas.DataFrame, 

809 table_name: Union[ApdbTables, ExtraTables], 

810 extra_columns: Optional[Mapping] = None, 

811 time_part: Optional[int] = None, 

812 ) -> None: 

813 """Store generic objects. 

814 

815 Takes Pandas catalog and stores a bunch of records in a table. 

816 

817 Parameters 

818 ---------- 

819 records : `pandas.DataFrame` 

820 Catalog containing object records 

821 table_name : `ApdbTables` or `ExtraTables` 

822 Name of the table as defined in APDB schema. 

823 extra_columns : `dict`, optional 

824 Mapping (column_name, column_value) which gives fixed values for 

825 columns in each row, overrides values in ``records`` if matching 

826 columns exist there. 

827 time_part : `int`, optional 

828 If not `None` then insert into a per-partition table. 

829 

830 Notes 

831 ----- 

832 If Pandas catalog contains additional columns not defined in table 

833 schema they are ignored. Catalog does not have to contain all columns 

834 defined in a table, but partition and clustering keys must be present 

835 in a catalog or ``extra_columns``. 

836 """ 

837 # use extra columns if specified 

838 if extra_columns is None: 

839 extra_columns = {} 

840 extra_fields = list(extra_columns.keys()) 

841 

842 # Fields that will come from dataframe. 

843 df_fields = [column for column in records.columns if column not in extra_fields] 

844 

845 column_map = self._schema.getColumnMap(table_name) 

846 # list of columns (as in felis schema) 

847 fields = [column_map[field].name for field in df_fields if field in column_map] 

848 fields += extra_fields 

849 

850 # check that all partitioning and clustering columns are defined 

851 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns( 

852 table_name 

853 ) 

854 missing_columns = [column for column in required_columns if column not in fields] 

855 if missing_columns: 

856 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

857 

858 qfields = [quote_id(field) for field in fields] 

859 qfields_str = ",".join(qfields) 

860 

861 with Timer(table_name.name + " query build", self.config.timer): 

862 table = self._schema.tableName(table_name) 

863 if time_part is not None: 

864 table = f"{table}_{time_part}" 

865 

866 holders = ",".join(["?"] * len(qfields)) 

867 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})' 

868 statement = self._prep_statement(query) 

869 queries = cassandra.query.BatchStatement() 

870 for rec in records.itertuples(index=False): 

871 values = [] 

872 for field in df_fields: 

873 if field not in column_map: 

874 continue 

875 value = getattr(rec, field) 

876 if column_map[field].datatype is felis.types.Timestamp: 

877 if isinstance(value, pandas.Timestamp): 

878 value = literal(value.to_pydatetime()) 

879 else: 

880 # Assume it's seconds since epoch, Cassandra 

881 # datetime is in milliseconds 

882 value = int(value * 1000) 

883 values.append(literal(value)) 

884 for field in extra_fields: 

885 value = extra_columns[field] 

886 values.append(literal(value)) 

887 queries.add(statement, values) 

888 

889 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0]) 

890 with Timer(table_name.name + " insert", self.config.timer): 

891 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write") 

892 

893 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

894 """Calculate spatial partition for each record and add it to a 

895 DataFrame. 

896 

897 Notes 

898 ----- 

899 This overrides any existing column in a DataFrame with the same name 

900 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

901 returned. 

902 """ 

903 # calculate HTM index for every DiaObject 

904 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

905 ra_col, dec_col = self.config.ra_dec_columns 

906 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

907 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

908 idx = self._pixelization.pixel(uv3d) 

909 apdb_part[i] = idx 

910 df = df.copy() 

911 df["apdb_part"] = apdb_part 

912 return df 

913 

914 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

915 """Add apdb_part column to DiaSource catalog. 

916 

917 Notes 

918 ----- 

919 This method copies apdb_part value from a matching DiaObject record. 

920 DiaObject catalog needs to have an apdb_part column filled by 

921 ``_add_obj_part`` method and DiaSource records need to be 

922 associated to DiaObjects via ``diaObjectId`` column. 

923 

924 This overrides any existing column in a DataFrame with the same name 

925 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

926 returned. 

927 """ 

928 pixel_id_map: Dict[int, int] = { 

929 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

930 } 

931 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

932 ra_col, dec_col = self.config.ra_dec_columns 

933 for i, (diaObjId, ra, dec) in enumerate( 

934 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col]) 

935 ): 

936 if diaObjId == 0: 

937 # DiaSources associated with SolarSystemObjects do not have an 

938 # associated DiaObject hence we skip them and set partition 

939 # based on its own ra/dec 

940 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

941 idx = self._pixelization.pixel(uv3d) 

942 apdb_part[i] = idx 

943 else: 

944 apdb_part[i] = pixel_id_map[diaObjId] 

945 sources = sources.copy() 

946 sources["apdb_part"] = apdb_part 

947 return sources 

948 

949 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

950 """Add apdb_part column to DiaForcedSource catalog. 

951 

952 Notes 

953 ----- 

954 This method copies apdb_part value from a matching DiaObject record. 

955 DiaObject catalog needs to have an apdb_part column filled by 

956 ``_add_obj_part`` method and DiaForcedSource records need to be 

957 associated to DiaObjects via ``diaObjectId`` column. 

958 

959 This overrides any existing column in a DataFrame with the same name 

960 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

961 returned. 

962 """ 

963 pixel_id_map: Dict[int, int] = { 

964 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

965 } 

966 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

967 for i, diaObjId in enumerate(sources["diaObjectId"]): 

968 apdb_part[i] = pixel_id_map[diaObjId] 

969 sources = sources.copy() 

970 sources["apdb_part"] = apdb_part 

971 return sources 

972 

973 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int: 

974 """Calculate time partiton number for a given time. 

975 

976 Parameters 

977 ---------- 

978 time : `float` or `lsst.daf.base.DateTime` 

979 Time for which to calculate partition number. Can be float to mean 

980 MJD or `lsst.daf.base.DateTime` 

981 

982 Returns 

983 ------- 

984 partition : `int` 

985 Partition number for a given time. 

986 """ 

987 if isinstance(time, dafBase.DateTime): 

988 mjd = time.get(system=dafBase.DateTime.MJD) 

989 else: 

990 mjd = time 

991 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

992 partition = int(days_since_epoch) // self.config.time_partition_days 

993 return partition 

994 
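A worked example of the arithmetic above with the defaults (``partition_zero_epoch`` is 1970-01-01 TAI, i.e. MJD 40587, and ``time_partition_days`` is 30):

    mjd = 60243.0                     # 2023-10-26
    days_since_epoch = mjd - 40587    # 19656.0 days since 1970-01-01
    partition = int(days_since_epoch) // 30
    assert partition == 655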

995 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

996 """Make an empty catalog for a table with a given name. 

997 

998 Parameters 

999 ---------- 

1000 table_name : `ApdbTables` 

1001 Name of the table. 

1002 

1003 Returns 

1004 ------- 

1005 catalog : `pandas.DataFrame` 

1006 An empty catalog. 

1007 """ 

1008 table = self._schema.tableSchemas[table_name] 

1009 

1010 data = { 

1011 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype)) 

1012 for columnDef in table.columns 

1013 } 

1014 return pandas.DataFrame(data) 

1015 

1016 def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement: 

1017 """Convert query string into prepared statement.""" 

1018 stmt = self._prepared_statements.get(query) 

1019 if stmt is None: 

1020 stmt = self._session.prepare(query) 

1021 self._prepared_statements[query] = stmt 

1022 return stmt 

1023 

1024 def _combine_where( 

1025 self, 

1026 prefix: str, 

1027 where1: List[Tuple[str, Tuple]], 

1028 where2: List[Tuple[str, Tuple]], 

1029 suffix: Optional[str] = None, 

1030 ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]: 

1031 """Make cartesian product of two parts of WHERE clause into a series 

1032 of statements to execute. 

1033 

1034 Parameters 

1035 ---------- 

1036 prefix : `str` 

1037 Initial statement prefix that comes before WHERE clause, e.g. 

1038 "SELECT * from Table" 

1039 """ 

1040 # If lists are empty use special sentinels. 

1041 if not where1: 

1042 where1 = [("", ())] 

1043 if not where2: 

1044 where2 = [("", ())] 

1045 

1046 for expr1, params1 in where1: 

1047 for expr2, params2 in where2: 

1048 full_query = prefix 

1049 wheres = [] 

1050 if expr1: 

1051 wheres.append(expr1) 

1052 if expr2: 

1053 wheres.append(expr2) 

1054 if wheres: 

1055 full_query += " WHERE " + " AND ".join(wheres) 

1056 if suffix: 

1057 full_query += " " + suffix 

1058 params = params1 + params2 

1059 if params: 

1060 statement = self._prep_statement(full_query) 

1061 else: 

1062 # If there are no params then it is likely that query 

1063 # has a bunch of literals rendered already, no point 

1064 # trying to prepare it. 

1065 statement = cassandra.query.SimpleStatement(full_query) 

1066 yield (statement, params) 

1067 
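A self-contained rendering of the combination above for the common case where both WHERE lists are non-empty (the expressions, pixel values and table name are placeholders):

    sp_where = [('"apdb_part" = ?', (100,)), ('"apdb_part" = ?', (101,))]
    temporal_where = [('"apdb_time_part" IN (655,656)', ())]
    prefix = 'SELECT * from "apdb"."DiaSource"'
    for expr1, params1 in sp_where:
        for expr2, params2 in temporal_where:
            print(f"{prefix} WHERE {expr1} AND {expr2}", params1 + params2)
    # Two statements result, one per spatial pixel; each has parameters, so
    # _combine_where would return them as prepared statements.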

1068 def _spatial_where( 

1069 self, region: Optional[sphgeom.Region], use_ranges: bool = False 

1070 ) -> List[Tuple[str, Tuple]]: 

1071 """Generate expressions for spatial part of WHERE clause. 

1072 

1073 Parameters 

1074 ---------- 

1075 region : `sphgeom.Region` or `None` 

1076 Spatial region for query results. 

1077 use_ranges : `bool` 

1078 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <= 

1079 p2") instead of exact list of pixels. Should be set to True for 

1080 large regions covering very many pixels. 

1081 

1082 Returns 

1083 ------- 

1084 expressions : `list` [ `tuple` ] 

1085 Empty list is returned if ``region`` is `None`, otherwise a list 

1086 of one or more (expression, parameters) tuples 

1087 """ 

1088 if region is None: 

1089 return [] 

1090 if use_ranges: 

1091 pixel_ranges = self._pixelization.envelope(region) 

1092 expressions: List[Tuple[str, Tuple]] = [] 

1093 for lower, upper in pixel_ranges: 

1094 upper -= 1 

1095 if lower == upper: 

1096 expressions.append(('"apdb_part" = ?', (lower,))) 

1097 else: 

1098 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper))) 

1099 return expressions 

1100 else: 

1101 pixels = self._pixelization.pixels(region) 

1102 if self.config.query_per_spatial_part: 

1103 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels] 

1104 else: 

1105 pixels_str = ",".join([str(pix) for pix in pixels]) 

1106 return [(f'"apdb_part" IN ({pixels_str})', ())] 

1107 

1108 def _temporal_where( 

1109 self, 

1110 table: ApdbTables, 

1111 start_time: Union[float, dafBase.DateTime], 

1112 end_time: Union[float, dafBase.DateTime], 

1113 query_per_time_part: Optional[bool] = None, 

1114 ) -> Tuple[List[str], List[Tuple[str, Tuple]]]: 

1115 """Generate table names and expressions for temporal part of WHERE 

1116 clauses. 

1117 

1118 Parameters 

1119 ---------- 

1120 table : `ApdbTables` 

1121 Table to select from. 

1122 start_time : `dafBase.DateTime` or `float` 

1123 Starting DateTime or MJD value of the time range. 

1124 end_time : `dafBase.DateTime` or `float` 

1125 Ending DateTime or MJD value of the time range. 

1126 query_per_time_part : `bool`, optional 

1127 If None then use ``query_per_time_part`` from configuration. 

1128 

1129 Returns 

1130 ------- 

1131 tables : `list` [ `str` ] 

1132 List of the table names to query. 

1133 expressions : `list` [ `tuple` ] 

1134 A list of zero or more (expression, parameters) tuples. 

1135 """ 

1136 tables: List[str] 

1137 temporal_where: List[Tuple[str, Tuple]] = [] 

1138 table_name = self._schema.tableName(table) 

1139 time_part_start = self._time_partition(start_time) 

1140 time_part_end = self._time_partition(end_time) 

1141 time_parts = list(range(time_part_start, time_part_end + 1)) 

1142 if self.config.time_partition_tables: 

1143 tables = [f"{table_name}_{part}" for part in time_parts] 

1144 else: 

1145 tables = [table_name] 

1146 if query_per_time_part is None: 

1147 query_per_time_part = self.config.query_per_time_part 

1148 if query_per_time_part: 

1149 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts] 

1150 else: 

1151 time_part_list = ",".join([str(part) for part in time_parts]) 

1152 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())] 

1153 

1154 return tables, temporal_where
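
A worked example of the partition bookkeeping above, using the defaults (30-day partitions, epoch MJD 40587) and a one-day window in October 2023; the table name assumes an empty prefix:

    mjd_start, mjd_end = 60243.0, 60244.0
    part_start = int(mjd_start - 40587) // 30   # 655
    part_end = int(mjd_end - 40587) // 30       # 655
    time_parts = list(range(part_start, part_end + 1))   # [655]
    # time_partition_tables=True  -> tables = ["DiaSource_655"], temporal_where = []
    # time_partition_tables=False -> tables = ["DiaSource"] and, with
    #   query_per_time_part=False, temporal_where = [('"apdb_time_part" IN (655)', ())]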