Coverage for python/lsst/dax/apdb/apdbCassandra.py: 14%

449 statements  

coverage.py v6.5.0, created at 2023-04-05 01:34 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26import logging 

27import uuid 

28from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union, cast 

29 

30import numpy as np 

31import pandas 

32 

33# If cassandra-driver is not there the module can still be imported 

34# but ApdbCassandra cannot be instantiated. 

35try: 

36 import cassandra 

37 import cassandra.query 

38 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile 

39 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

40 

41 CASSANDRA_IMPORTED = True 

42except ImportError: 

43 CASSANDRA_IMPORTED = False 

44 

45import felis.types 

46import lsst.daf.base as dafBase 

47from felis.simple import Table 

48from lsst import sphgeom 

49from lsst.pex.config import ChoiceField, Field, ListField 

50from lsst.utils.iteration import chunk_iterable 

51 

52from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData 

53from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables 

54from .apdbSchema import ApdbTables 

55from .cassandra_utils import ( 

56 ApdbCassandraTableData, 

57 literal, 

58 pandas_dataframe_factory, 

59 quote_id, 

60 raw_data_factory, 

61 select_concurrent, 

62) 

63from .pixelization import Pixelization 

64from .timer import Timer 

65 

66_LOG = logging.getLogger(__name__) 

67 

68 

69class CassandraMissingError(Exception): 

70 def __init__(self) -> None: 

71 super().__init__("cassandra-driver module cannot be imported") 

72 

73 

74class ApdbCassandraConfig(ApdbConfig): 

75 contact_points = ListField[str]( 

76 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"] 

77 ) 

78 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[]) 

79 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb") 

80 read_consistency = Field[str]( 

81 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM" 

82 ) 

83 write_consistency = Field[str]( 

84 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM" 

85 ) 

86 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0) 

87 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=10.0) 

88 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500) 

89 protocol_version = Field[int]( 

90 doc="Cassandra protocol version to use, default is V4", 

91 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0, 

92 ) 

93 dia_object_columns = ListField[str]( 

94 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[] 

95 ) 

96 prefix = Field[str](doc="Prefix to add to table names", default="") 

97 part_pixelization = ChoiceField[str]( 

98 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

99 doc="Pixelization used for partitioning index.", 

100 default="mq3c", 

101 ) 

102 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10) 

103 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64) 

104 ra_dec_columns = ListField[str](default=["ra", "decl"], doc="Names of ra/dec columns in DiaObject table") 

105 timer = Field[bool](doc="If True then print/log timing information", default=False) 

106 time_partition_tables = Field[bool]( 

107 doc="Use per-partition tables for sources instead of partitioning by time", default=True 

108 ) 

109 time_partition_days = Field[int]( 

110 doc=( 

111 "Time partitioning granularity in days, this value must not be changed after database is " 

112 "initialized" 

113 ), 

114 default=30, 

115 ) 

116 time_partition_start = Field[str]( 

117 doc=( 

118 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

119 "This is used only when time_partition_tables is True." 

120 ), 

121 default="2018-12-01T00:00:00", 

122 ) 

123 time_partition_end = Field[str]( 

124 doc=( 

125 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

126 "This is used only when time_partition_tables is True." 

127 ), 

128 default="2030-01-01T00:00:00", 

129 ) 

130 query_per_time_part = Field[bool]( 

131 default=False, 

132 doc=( 

133 "If True then build separate query for each time partition, otherwise build one single query. " 

134 "This is only used when time_partition_tables is False in schema config." 

135 ), 

136 ) 

137 query_per_spatial_part = Field[bool]( 

138 default=False, 

139 doc="If True then build one query per spatial partition, otherwise build single query.", 

140 ) 

141 
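Editorial note, not part of the covered module: the fields above are standard ``pex_config`` fields, so a configuration can be built and validated programmatically before constructing `ApdbCassandra`. A minimal sketch, with placeholder contact point and keyspace:

# How an external caller would obtain the class:
# from lsst.dax.apdb.apdbCassandra import ApdbCassandraConfig

config = ApdbCassandraConfig()
config.contact_points = ["127.0.0.1"]  # node(s) used for cluster discovery
config.keyspace = "apdb"               # keyspace used for all operations
config.read_consistency = "ONE"        # default is QUORUM
config.time_partition_tables = True    # per-time-partition source tables
config.validate()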

142 

143if CASSANDRA_IMPORTED: 143 ↛ 145 (line 143 didn't jump to line 145, because the condition on line 143 was never true)

144 

145 class _AddressTranslator(AddressTranslator): 

146 """Translate internal IP address to external. 

147 

148 Only used for docker-based setup, not a viable long-term solution. 

149 """ 

150 

151 def __init__(self, public_ips: List[str], private_ips: List[str]): 

152 self._map = dict((k, v) for k, v in zip(private_ips, public_ips)) 

153 

154 def translate(self, private_ip: str) -> str: 

155 return self._map.get(private_ip, private_ip) 

156 

157 

158def _quote_column(name: str) -> str: 

159 """Quote column name""" 

160 if name.islower(): 

161 return name 

162 else: 

163 return f'"{name}"' 
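Editorial note, not part of the covered source: the helper above implements CQL case-sensitivity quoting, leaving all-lowercase names bare and double-quoting mixed-case ones. For instance:

# all-lowercase names need no quoting; mixed-case names must be double-quoted
assert _quote_column("ra") == "ra"
assert _quote_column("diaObjectId") == '"diaObjectId"'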

164 

165 

166class ApdbCassandra(Apdb): 

167 """Implementation of APDB database on to of Apache Cassandra. 

168 

169 The implementation is configured via standard ``pex_config`` mechanism 

170 using the `ApdbCassandraConfig` configuration class. For examples of 

171 different configurations check the config/ folder. 

172 

173 Parameters 

174 ---------- 

175 config : `ApdbCassandraConfig` 

176 Configuration object. 

177 """ 

178 

179 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI) 

180 """Start time for partition 0, this should never be changed.""" 

181 

182 def __init__(self, config: ApdbCassandraConfig): 

183 if not CASSANDRA_IMPORTED: 

184 raise CassandraMissingError() 

185 

186 config.validate() 

187 self.config = config 

188 

189 _LOG.debug("ApdbCassandra Configuration:") 

190 for key, value in self.config.items(): 

191 _LOG.debug(" %s: %s", key, value) 

192 

193 self._pixelization = Pixelization( 

194 config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges 

195 ) 

196 

197 addressTranslator: Optional[AddressTranslator] = None 

198 if config.private_ips: 

199 addressTranslator = _AddressTranslator(config.contact_points, config.private_ips) 

200 

201 self._keyspace = config.keyspace 

202 

203 self._cluster = Cluster( 

204 execution_profiles=self._makeProfiles(config), 

205 contact_points=self.config.contact_points, 

206 address_translator=addressTranslator, 

207 protocol_version=self.config.protocol_version, 

208 ) 

209 self._session = self._cluster.connect() 

210 # Disable result paging 

211 self._session.default_fetch_size = None 

212 

213 self._schema = ApdbCassandraSchema( 

214 session=self._session, 

215 keyspace=self._keyspace, 

216 schema_file=self.config.schema_file, 

217 schema_name=self.config.schema_name, 

218 prefix=self.config.prefix, 

219 time_partition_tables=self.config.time_partition_tables, 

220 use_insert_id=self.config.use_insert_id, 

221 ) 

222 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD) 

223 

224 # Cache for prepared statements 

225 self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {} 

226 

227 def __del__(self) -> None: 

228 self._cluster.shutdown() 

229 

230 def tableDef(self, table: ApdbTables) -> Optional[Table]: 

231 # docstring is inherited from a base class 

232 return self._schema.tableSchemas.get(table) 

233 

234 def makeSchema(self, drop: bool = False) -> None: 

235 # docstring is inherited from a base class 

236 

237 if self.config.time_partition_tables: 

238 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI) 

239 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI) 

240 part_range = ( 

241 self._time_partition(time_partition_start), 

242 self._time_partition(time_partition_end) + 1, 

243 ) 

244 self._schema.makeSchema(drop=drop, part_range=part_range) 

245 else: 

246 self._schema.makeSchema(drop=drop) 

247 

248 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

249 # docstring is inherited from a base class 

250 

251 sp_where = self._spatial_where(region) 

252 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

253 

254 # We need to exclude extra partitioning columns from result. 

255 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast) 

256 what = ",".join(_quote_column(column) for column in column_names) 

257 

258 table_name = self._schema.tableName(ApdbTables.DiaObjectLast) 

259 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"' 

260 statements: List[Tuple] = [] 

261 for where, params in sp_where: 

262 full_query = f"{query} WHERE {where}" 

263 if params: 

264 statement = self._prep_statement(full_query) 

265 else: 

266 # If there are no params then it is likely that query has a 

267 # bunch of literals rendered already, no point trying to 

268 # prepare it because it's not reusable. 

269 statement = cassandra.query.SimpleStatement(full_query) 

270 statements.append((statement, params)) 

271 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

272 

273 with Timer("DiaObject select", self.config.timer): 

274 objects = cast( 

275 pandas.DataFrame, 

276 select_concurrent( 

277 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

278 ), 

279 ) 

280 

281 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

282 return objects 

283 

284 def getDiaSources( 

285 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime 

286 ) -> Optional[pandas.DataFrame]: 

287 # docstring is inherited from a base class 

288 months = self.config.read_sources_months 

289 if months == 0: 

290 return None 

291 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

292 mjd_start = mjd_end - months * 30 

293 

294 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

295 

296 def getDiaForcedSources( 

297 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime 

298 ) -> Optional[pandas.DataFrame]: 

299 # docstring is inherited from a base class 

300 months = self.config.read_forced_sources_months 

301 if months == 0: 

302 return None 

303 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

304 mjd_start = mjd_end - months * 30 

305 

306 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

307 

308 def getInsertIds(self) -> list[ApdbInsertId] | None: 

309 # docstring is inherited from a base class 

310 if not self._schema.has_insert_id: 

311 return None 

312 

313 # everything goes into a single partition 

314 partition = 0 

315 

316 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

317 query = f'SELECT insert_time, insert_id FROM "{self._keyspace}"."{table_name}" WHERE partition = ?' 

318 

319 result = self._session.execute( 

320 self._prep_statement(query), 

321 (partition,), 

322 timeout=self.config.read_timeout, 

323 execution_profile="read_tuples", 

324 ) 

325 # order by insert_time 

326 rows = sorted(result) 

327 return [ApdbInsertId(row[1]) for row in rows] 

328 

329 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None: 

330 # docstring is inherited from a base class 

331 if not self._schema.has_insert_id: 

332 raise ValueError("APDB is not configured for history storage") 

333 

334 insert_ids = [id.id for id in ids] 

335 params = ",".join("?" * len(insert_ids)) 

336 

337 # everything goes into a single partition 

338 partition = 0 

339 

340 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

341 query = ( 

342 f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE partition = ? and insert_id IN ({params})' 

343 ) 

344 

345 self._session.execute( 

346 self._prep_statement(query), 

347 [partition] + insert_ids, 

348 timeout=self.config.write_timeout, 

349 ) 

350 

351 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

352 # docstring is inherited from a base class 

353 return self._get_history(ExtraTables.DiaObjectInsertId, ids) 

354 

355 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

356 # docstring is inherited from a base class 

357 return self._get_history(ExtraTables.DiaSourceInsertId, ids) 

358 

359 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

360 # docstring is inherited from a base class 

361 return self._get_history(ExtraTables.DiaForcedSourceInsertId, ids) 

362 
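    # Editorial illustration, not part of the covered source: the insert-id
    # methods above are typically used together (``apdb`` is an
    # ``ApdbCassandra`` instance):
    #
    #     if insert_ids := apdb.getInsertIds():
    #         history = apdb.getDiaSourcesHistory(insert_ids)
    #         # ... consume ``history`` (an ApdbTableData) ...
    #         apdb.deleteInsertIds(insert_ids)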

363 def getSSObjects(self) -> pandas.DataFrame: 

364 # docstring is inherited from a base class 

365 tableName = self._schema.tableName(ApdbTables.SSObject) 

366 query = f'SELECT * from "{self._keyspace}"."{tableName}"' 

367 

368 objects = None 

369 with Timer("SSObject select", self.config.timer): 

370 result = self._session.execute(query, execution_profile="read_pandas") 

371 objects = result._current_rows 

372 

373 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

374 return objects 

375 

376 def store( 

377 self, 

378 visit_time: dafBase.DateTime, 

379 objects: pandas.DataFrame, 

380 sources: Optional[pandas.DataFrame] = None, 

381 forced_sources: Optional[pandas.DataFrame] = None, 

382 ) -> None: 

383 # docstring is inherited from a base class 

384 

385 insert_id: ApdbInsertId | None = None 

386 if self._schema.has_insert_id: 

387 insert_id = ApdbInsertId.new_insert_id() 

388 self._storeInsertId(insert_id, visit_time) 

389 

390 # fill region partition column for DiaObjects 

391 objects = self._add_obj_part(objects) 

392 self._storeDiaObjects(objects, visit_time, insert_id) 

393 

394 if sources is not None: 

395 # copy apdb_part column from DiaObjects to DiaSources 

396 sources = self._add_src_part(sources, objects) 

397 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, insert_id) 

398 self._storeDiaSourcesPartitions(sources, visit_time, insert_id) 

399 

400 if forced_sources is not None: 

401 forced_sources = self._add_fsrc_part(forced_sources, objects) 

402 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, insert_id) 

403 

404 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

405 # docstring is inherited from a base class 

406 self._storeObjectsPandas(objects, ApdbTables.SSObject) 

407 

408 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

409 # docstring is inherited from a base class 

410 

411 # To update a record we need to know its exact primary key (including 

412 # partition key) so we start by querying for diaSourceId to find the 

413 # primary keys. 

414 

415 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition) 

416 # split it into 1k IDs per query 

417 selects: List[Tuple] = [] 

418 for ids in chunk_iterable(idMap.keys(), 1_000): 

419 ids_str = ",".join(str(item) for item in ids) 

420 selects.append( 

421 ( 

422 ( 

423 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "insert_id" ' 

424 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})' 

425 ), 

426 {}, 

427 ) 

428 ) 

429 

430 # No need for DataFrame here, read data as tuples. 

431 result = cast( 

432 List[Tuple[int, int, int, uuid.UUID | None]], 

433 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency), 

434 ) 

435 

436 # Make mapping from source ID to its partition. 

437 id2partitions: Dict[int, Tuple[int, int]] = {} 

438 id2insert_id: Dict[int, ApdbInsertId] = {} 

439 for row in result: 

440 id2partitions[row[0]] = row[1:3] 

441 if row[3] is not None: 

442 id2insert_id[row[0]] = ApdbInsertId(row[3]) 

443 

444 # make sure we know partitions for each ID 

445 if set(id2partitions) != set(idMap): 

446 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

447 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

448 

449 # Reassign in standard tables 

450 queries = cassandra.query.BatchStatement() 

451 table_name = self._schema.tableName(ApdbTables.DiaSource) 

452 for diaSourceId, ssObjectId in idMap.items(): 

453 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

454 values: Tuple 

455 if self.config.time_partition_tables: 

456 query = ( 

457 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"' 

458 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

459 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?' 

460 ) 

461 values = (ssObjectId, apdb_part, diaSourceId) 

462 else: 

463 query = ( 

464 f'UPDATE "{self._keyspace}"."{table_name}"' 

465 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

466 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?' 

467 ) 

468 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId) 

469 queries.add(self._prep_statement(query), values) 

470 

471 # Reassign in history tables, only if history is enabled 

472 if id2insert_id: 

473 # Filter out insert ids that have been deleted already. There is a 

474 # potential race with concurrent removal of insert IDs, but it 

475 # should be handled by WHERE in UPDATE. 

476 known_ids = set() 

477 if insert_ids := self.getInsertIds(): 

478 known_ids = set(insert_ids) 

479 id2insert_id = {key: value for key, value in id2insert_id.items() if value in known_ids} 

480 if id2insert_id: 

481 table_name = self._schema.tableName(ExtraTables.DiaSourceInsertId) 

482 for diaSourceId, ssObjectId in idMap.items(): 

483 if insert_id := id2insert_id.get(diaSourceId): 

484 query = ( 

485 f'UPDATE "{self._keyspace}"."{table_name}" ' 

486 ' SET "ssObjectId" = ?, "diaObjectId" = NULL ' 

487 'WHERE "insert_id" = ? AND "diaSourceId" = ?' 

488 ) 

489 values = (ssObjectId, insert_id.id, diaSourceId) 

490 queries.add(self._prep_statement(query), values) 

491 

492 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

493 with Timer(table_name + " update", self.config.timer): 

494 self._session.execute(queries, execution_profile="write") 

495 
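    # Editorial illustration, not part of the covered source: callers pass a
    # mapping from diaSourceId to the ssObjectId it should be re-associated
    # with (the IDs below are placeholders), e.g.
    #
    #     apdb.reassignDiaSources({123456789: 987654321})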

496 def dailyJob(self) -> None: 

497 # docstring is inherited from a base class 

498 pass 

499 

500 def countUnassociatedObjects(self) -> int: 

501 # docstring is inherited from a base class 

502 

503 # It's too inefficient to implement it for Cassandra in current schema. 

504 raise NotImplementedError() 

505 

506 def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]: 

507 """Make all execution profiles used in the code.""" 

508 

509 if config.private_ips: 

510 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

511 else: 

512 loadBalancePolicy = RoundRobinPolicy() 

513 

514 read_tuples_profile = ExecutionProfile( 

515 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

516 request_timeout=config.read_timeout, 

517 row_factory=cassandra.query.tuple_factory, 

518 load_balancing_policy=loadBalancePolicy, 

519 ) 

520 read_pandas_profile = ExecutionProfile( 

521 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

522 request_timeout=config.read_timeout, 

523 row_factory=pandas_dataframe_factory, 

524 load_balancing_policy=loadBalancePolicy, 

525 ) 

526 read_raw_profile = ExecutionProfile( 

527 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

528 request_timeout=config.read_timeout, 

529 row_factory=raw_data_factory, 

530 load_balancing_policy=loadBalancePolicy, 

531 ) 

532 # Profile to use with select_concurrent to return pandas data frame 

533 read_pandas_multi_profile = ExecutionProfile( 

534 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

535 request_timeout=config.read_timeout, 

536 row_factory=pandas_dataframe_factory, 

537 load_balancing_policy=loadBalancePolicy, 

538 ) 

539 # Profile to use with select_concurrent to return raw data (columns and 

540 # rows) 

541 read_raw_multi_profile = ExecutionProfile( 

542 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

543 request_timeout=config.read_timeout, 

544 row_factory=raw_data_factory, 

545 load_balancing_policy=loadBalancePolicy, 

546 ) 

547 write_profile = ExecutionProfile( 

548 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency), 

549 request_timeout=config.write_timeout, 

550 load_balancing_policy=loadBalancePolicy, 

551 ) 

552 # To replace default DCAwareRoundRobinPolicy 

553 default_profile = ExecutionProfile( 

554 load_balancing_policy=loadBalancePolicy, 

555 ) 

556 return { 

557 "read_tuples": read_tuples_profile, 

558 "read_pandas": read_pandas_profile, 

559 "read_raw": read_raw_profile, 

560 "read_pandas_multi": read_pandas_multi_profile, 

561 "read_raw_multi": read_raw_multi_profile, 

562 "write": write_profile, 

563 EXEC_PROFILE_DEFAULT: default_profile, 

564 } 

565 

566 def _getSources( 

567 self, 

568 region: sphgeom.Region, 

569 object_ids: Optional[Iterable[int]], 

570 mjd_start: float, 

571 mjd_end: float, 

572 table_name: ApdbTables, 

573 ) -> pandas.DataFrame: 

574 """Returns catalog of DiaSource instances given set of DiaObject IDs. 

575 

576 Parameters 

577 ---------- 

578 region : `lsst.sphgeom.Region` 

579 Spherical region. 

580 object_ids : iterable [ `int` ], or `None` 

581 Collection of DiaObject IDs. 

582 mjd_start : `float` 

583 Lower bound of time interval. 

584 mjd_end : `float` 

585 Upper bound of time interval. 

586 table_name : `ApdbTables` 

587 Name of the table. 

588 

589 Returns 

590 ------- 

591 catalog : `pandas.DataFrame` 

592 Catalog containing DiaSource records. An empty catalog is returned if 

593 ``object_ids`` is empty. 

594 """ 

595 object_id_set: Set[int] = set() 

596 if object_ids is not None: 

597 object_id_set = set(object_ids) 

598 if len(object_id_set) == 0: 

599 return self._make_empty_catalog(table_name) 

600 

601 sp_where = self._spatial_where(region) 

602 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end) 

603 

604 # We need to exclude extra partitioning columns from result. 

605 column_names = self._schema.apdbColumnNames(table_name) 

606 what = ",".join(_quote_column(column) for column in column_names) 

607 

608 # Build all queries 

609 statements: List[Tuple] = [] 

610 for table in tables: 

611 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"' 

612 statements += list(self._combine_where(prefix, sp_where, temporal_where)) 

613 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

614 

615 with Timer(table_name.name + " select", self.config.timer): 

616 catalog = cast( 

617 pandas.DataFrame, 

618 select_concurrent( 

619 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

620 ), 

621 ) 

622 

623 # filter by given object IDs 

624 if len(object_id_set) > 0: 

625 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

626 

627 # precise filtering on midPointTai 

628 catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_start]) 

629 

630 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

631 return catalog 

632 

633 def _get_history(self, table: ExtraTables, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

634 """Return records from a particular table given set of insert IDs.""" 

635 if not self._schema.has_insert_id: 

636 raise ValueError("APDB is not configured for history retrieval") 

637 

638 insert_ids = [id.id for id in ids] 

639 params = ",".join("?" * len(insert_ids)) 

640 

641 table_name = self._schema.tableName(table) 

642 # I know that history table schema has only regular APDB columns plus 

643 # an insert_id column, and this is exactly what we need to return from 

644 # this method, so selecting a star is fine here. 

645 query = f'SELECT * FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})' 

646 statement = self._prep_statement(query) 

647 

648 with Timer("DiaObject history", self.config.timer): 

649 result = self._session.execute(statement, insert_ids, execution_profile="read_raw") 

650 table_data = cast(ApdbCassandraTableData, result._current_rows) 

651 return table_data 

652 

653 def _storeInsertId(self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime) -> None: 

654 # Cassandra timestamp uses milliseconds since epoch 

655 timestamp = visit_time.nsecs() // 1_000_000 

656 

657 # everything goes into a single partition 

658 partition = 0 

659 

660 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

661 query = ( 

662 f'INSERT INTO "{self._keyspace}"."{table_name}" (partition, insert_id, insert_time) ' 

663 "VALUES (?, ?, ?)" 

664 ) 

665 

666 self._session.execute( 

667 self._prep_statement(query), 

668 (partition, insert_id.id, timestamp), 

669 timeout=self.config.write_timeout, 

670 execution_profile="write", 

671 ) 

672 

673 def _storeDiaObjects( 

674 self, objs: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None 

675 ) -> None: 

676 """Store catalog of DiaObjects from current visit. 

677 

678 Parameters 

679 ---------- 

680 objs : `pandas.DataFrame` 

681 Catalog with DiaObject records 

682 visit_time : `lsst.daf.base.DateTime` 

683 Time of the current visit. 

684 """ 

685 visit_time_dt = visit_time.toPython() 

686 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

687 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

688 

689 extra_columns["validityStart"] = visit_time_dt 

690 time_part: Optional[int] = self._time_partition(visit_time) 

691 if not self.config.time_partition_tables: 

692 extra_columns["apdb_time_part"] = time_part 

693 time_part = None 

694 

695 self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part) 

696 

697 if insert_id is not None: 

698 extra_columns = dict(insert_id=insert_id.id, validityStart=visit_time_dt) 

699 self._storeObjectsPandas(objs, ExtraTables.DiaObjectInsertId, extra_columns=extra_columns) 

700 

701 def _storeDiaSources( 

702 self, 

703 table_name: ApdbTables, 

704 sources: pandas.DataFrame, 

705 visit_time: dafBase.DateTime, 

706 insert_id: ApdbInsertId | None, 

707 ) -> None: 

708 """Store catalog of DIASources or DIAForcedSources from current visit. 

709 

710 Parameters 

711 ---------- 

712 sources : `pandas.DataFrame` 

713 Catalog containing DiaSource records 

714 visit_time : `lsst.daf.base.DateTime` 

715 Time of the current visit. 

716 """ 

717 time_part: Optional[int] = self._time_partition(visit_time) 

718 extra_columns: dict[str, Any] = {} 

719 if not self.config.time_partition_tables: 

720 extra_columns["apdb_time_part"] = time_part 

721 time_part = None 

722 

723 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part) 

724 

725 if insert_id is not None: 

726 extra_columns = dict(insert_id=insert_id.id) 

727 if table_name is ApdbTables.DiaSource: 

728 extra_table = ExtraTables.DiaSourceInsertId 

729 else: 

730 extra_table = ExtraTables.DiaForcedSourceInsertId 

731 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns) 

732 

733 def _storeDiaSourcesPartitions( 

734 self, sources: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None 

735 ) -> None: 

736 """Store mapping of diaSourceId to its partitioning values. 

737 

738 Parameters 

739 ---------- 

740 sources : `pandas.DataFrame` 

741 Catalog containing DiaSource records 

742 visit_time : `lsst.daf.base.DateTime` 

743 Time of the current visit. 

744 """ 

745 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

746 extra_columns = { 

747 "apdb_time_part": self._time_partition(visit_time), 

748 "insert_id": insert_id.id if insert_id is not None else None, 

749 } 

750 

751 self._storeObjectsPandas( 

752 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

753 ) 

754 

755 def _storeObjectsPandas( 

756 self, 

757 records: pandas.DataFrame, 

758 table_name: Union[ApdbTables, ExtraTables], 

759 extra_columns: Optional[Mapping] = None, 

760 time_part: Optional[int] = None, 

761 ) -> None: 

762 """Generic store method. 

763 

764 Takes Pandas catalog and stores a bunch of records in a table. 

765 

766 Parameters 

767 ---------- 

768 records : `pandas.DataFrame` 

769 Catalog containing object records 

770 table_name : `ApdbTables` or `ExtraTables` 

771 Name of the table as defined in APDB schema. 

772 extra_columns : `dict`, optional 

773 Mapping (column_name, column_value) which gives fixed values for 

774 columns in each row, overrides values in ``records`` if matching 

775 columns exist there. 

776 time_part : `int`, optional 

777 If not `None` then insert into a per-partition table. 

778 

779 Notes 

780 ----- 

781 If Pandas catalog contains additional columns not defined in table 

782 schema they are ignored. Catalog does not have to contain all columns 

783 defined in a table, but partition and clustering keys must be present 

784 in a catalog or ``extra_columns``. 

785 """ 

786 # use extra columns if specified 

787 if extra_columns is None: 

788 extra_columns = {} 

789 extra_fields = list(extra_columns.keys()) 

790 

791 # Fields that will come from dataframe. 

792 df_fields = [column for column in records.columns if column not in extra_fields] 

793 

794 column_map = self._schema.getColumnMap(table_name) 

795 # list of columns (as in felis schema) 

796 fields = [column_map[field].name for field in df_fields if field in column_map] 

797 fields += extra_fields 

798 

799 # check that all partitioning and clustering columns are defined 

800 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns( 

801 table_name 

802 ) 

803 missing_columns = [column for column in required_columns if column not in fields] 

804 if missing_columns: 

805 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

806 

807 qfields = [quote_id(field) for field in fields] 

808 qfields_str = ",".join(qfields) 

809 

810 with Timer(table_name.name + " query build", self.config.timer): 

811 table = self._schema.tableName(table_name) 

812 if time_part is not None: 

813 table = f"{table}_{time_part}" 

814 

815 holders = ",".join(["?"] * len(qfields)) 

816 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})' 

817 statement = self._prep_statement(query) 

818 queries = cassandra.query.BatchStatement() 

819 for rec in records.itertuples(index=False): 

820 values = [] 

821 for field in df_fields: 

822 if field not in column_map: 

823 continue 

824 value = getattr(rec, field) 

825 if column_map[field].datatype is felis.types.Timestamp: 

826 if isinstance(value, pandas.Timestamp): 

827 value = literal(value.to_pydatetime()) 

828 else: 

829 # Assume it's seconds since epoch, Cassandra 

830 # datetime is in milliseconds 

831 value = int(value * 1000) 

832 values.append(literal(value)) 

833 for field in extra_fields: 

834 value = extra_columns[field] 

835 values.append(literal(value)) 

836 queries.add(statement, values) 

837 

838 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0]) 

839 with Timer(table_name.name + " insert", self.config.timer): 

840 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write") 

841 

842 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

843 """Calculate spatial partition for each record and add it to a 

844 DataFrame. 

845 

846 Notes 

847 ----- 

848 This overrides any existing column in a DataFrame with the same name 

849 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

850 returned. 

851 """ 

852 # calculate spatial partition index for every DiaObject 

853 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

854 ra_col, dec_col = self.config.ra_dec_columns 

855 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

856 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

857 idx = self._pixelization.pixel(uv3d) 

858 apdb_part[i] = idx 

859 df = df.copy() 

860 df["apdb_part"] = apdb_part 

861 return df 

862 

863 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

864 """Add apdb_part column to DiaSource catalog. 

865 

866 Notes 

867 ----- 

868 This method copies apdb_part value from a matching DiaObject record. 

869 DiaObject catalog needs to have an apdb_part column filled by 

870 ``_add_obj_part`` method and DiaSource records need to be 

871 associated to DiaObjects via ``diaObjectId`` column. 

872 

873 This overrides any existing column in a DataFrame with the same name 

874 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

875 returned. 

876 """ 

877 pixel_id_map: Dict[int, int] = { 

878 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

879 } 

880 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

881 ra_col, dec_col = self.config.ra_dec_columns 

882 for i, (diaObjId, ra, dec) in enumerate( 

883 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col]) 

884 ): 

885 if diaObjId == 0: 

886 # DiaSources associated with SolarSystemObjects do not have an 

887 # associated DiaObject, so we set their partition based on their 

888 # own ra/dec. 

889 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

890 idx = self._pixelization.pixel(uv3d) 

891 apdb_part[i] = idx 

892 else: 

893 apdb_part[i] = pixel_id_map[diaObjId] 

894 sources = sources.copy() 

895 sources["apdb_part"] = apdb_part 

896 return sources 

897 

898 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

899 """Add apdb_part column to DiaForcedSource catalog. 

900 

901 Notes 

902 ----- 

903 This method copies apdb_part value from a matching DiaObject record. 

904 DiaObject catalog needs to have an apdb_part column filled by 

905 ``_add_obj_part`` method and DiaForcedSource records need to be 

906 associated to DiaObjects via ``diaObjectId`` column. 

907 

908 This overrides any existing column in a DataFrame with the same name 

909 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

910 returned. 

911 """ 

912 pixel_id_map: Dict[int, int] = { 

913 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

914 } 

915 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

916 for i, diaObjId in enumerate(sources["diaObjectId"]): 

917 apdb_part[i] = pixel_id_map[diaObjId] 

918 sources = sources.copy() 

919 sources["apdb_part"] = apdb_part 

920 return sources 

921 

922 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int: 

923 """Calculate time partiton number for a given time. 

924 

925 Parameters 

926 ---------- 

927 time : `float` or `lsst.daf.base.DateTime` 

928 Time for which to calculate the partition number. Can be a `float` 

929 meaning MJD, or `lsst.daf.base.DateTime`. 

930 

931 Returns 

932 ------- 

933 partition : `int` 

934 Partition number for a given time. 

935 """ 

936 if isinstance(time, dafBase.DateTime): 

937 mjd = time.get(system=dafBase.DateTime.MJD) 

938 else: 

939 mjd = time 

940 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

941 partition = int(days_since_epoch) // self.config.time_partition_days 

942 return partition 

943 
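    # Editorial worked example, not part of the covered source: with the
    # default time_partition_days = 30 and partition_zero_epoch = 1970-01-01
    # (MJD 40587), a visit at 2020-01-01T00:00:00 (MJD 58849) gives
    # days_since_epoch = 58849 - 40587 = 18262 and partition 18262 // 30 = 608.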

944 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

945 """Make an empty catalog for a table with a given name. 

946 

947 Parameters 

948 ---------- 

949 table_name : `ApdbTables` 

950 Name of the table. 

951 

952 Returns 

953 ------- 

954 catalog : `pandas.DataFrame` 

955 An empty catalog. 

956 """ 

957 table = self._schema.tableSchemas[table_name] 

958 

959 data = { 

960 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype)) 

961 for columnDef in table.columns 

962 } 

963 return pandas.DataFrame(data) 

964 

965 def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement: 

966 """Convert query string into prepared statement.""" 

967 stmt = self._prepared_statements.get(query) 

968 if stmt is None: 

969 stmt = self._session.prepare(query) 

970 self._prepared_statements[query] = stmt 

971 return stmt 

972 

973 def _combine_where( 

974 self, 

975 prefix: str, 

976 where1: List[Tuple[str, Tuple]], 

977 where2: List[Tuple[str, Tuple]], 

978 suffix: Optional[str] = None, 

979 ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]: 

980 """Make cartesian product of two parts of WHERE clause into a series 

981 of statements to execute. 

982 

983 Parameters 

984 ---------- 

985 prefix : `str` 

986 Initial statement prefix that comes before WHERE clause, e.g. 

987 "SELECT * from Table" 

988 """ 

989 # If lists are empty use special sentinels. 

990 if not where1: 

991 where1 = [("", ())] 

992 if not where2: 

993 where2 = [("", ())] 

994 

995 for expr1, params1 in where1: 

996 for expr2, params2 in where2: 

997 full_query = prefix 

998 wheres = [] 

999 if expr1: 

1000 wheres.append(expr1) 

1001 if expr2: 

1002 wheres.append(expr2) 

1003 if wheres: 

1004 full_query += " WHERE " + " AND ".join(wheres) 

1005 if suffix: 

1006 full_query += " " + suffix 

1007 params = params1 + params2 

1008 if params: 

1009 statement = self._prep_statement(full_query) 

1010 else: 

1011 # If there are no params then it is likely that query 

1012 # has a bunch of literals rendered already, no point 

1013 # trying to prepare it. 

1014 statement = cassandra.query.SimpleStatement(full_query) 

1015 yield (statement, params) 

1016 
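    # Editorial illustration, not part of the covered source: given
    #   prefix = 'SELECT * from "t"'
    #   where1 = [('"apdb_part" = ?', (100,)), ('"apdb_part" = ?', (101,))]
    #   where2 = [('"apdb_time_part" = ?', (608,))]
    # the generator above yields two prepared statements of the form
    #   'SELECT * from "t" WHERE "apdb_part" = ? AND "apdb_time_part" = ?'
    # with parameter tuples (100, 608) and (101, 608) respectively.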

1017 def _spatial_where( 

1018 self, region: Optional[sphgeom.Region], use_ranges: bool = False 

1019 ) -> List[Tuple[str, Tuple]]: 

1020 """Generate expressions for spatial part of WHERE clause. 

1021 

1022 Parameters 

1023 ---------- 

1024 region : `sphgeom.Region` 

1025 Spatial region for query results. 

1026 use_ranges : `bool` 

1027 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <= 

1028 p2") instead of exact list of pixels. Should be set to True for 

1029 large regions covering very many pixels. 

1030 

1031 Returns 

1032 ------- 

1033 expressions : `list` [ `tuple` ] 

1034 Empty list is returned if ``region`` is `None`, otherwise a list 

1035 of one or more (expression, parameters) tuples 

1036 """ 

1037 if region is None: 

1038 return [] 

1039 if use_ranges: 

1040 pixel_ranges = self._pixelization.envelope(region) 

1041 expressions: List[Tuple[str, Tuple]] = [] 

1042 for lower, upper in pixel_ranges: 

1043 upper -= 1 

1044 if lower == upper: 

1045 expressions.append(('"apdb_part" = ?', (lower,))) 

1046 else: 

1047 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper))) 

1048 return expressions 

1049 else: 

1050 pixels = self._pixelization.pixels(region) 

1051 if self.config.query_per_spatial_part: 

1052 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels] 

1053 else: 

1054 pixels_str = ",".join([str(pix) for pix in pixels]) 

1055 return [(f'"apdb_part" IN ({pixels_str})', ())] 

1056 
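    # Editorial illustration, not part of the covered source: with
    # query_per_spatial_part=False and pixels [100, 101, 102] the method above
    # returns a single expression
    #     [('"apdb_part" IN (100,101,102)', ())]
    # while with query_per_spatial_part=True it returns one expression per
    # pixel, e.g. [('"apdb_part" = ?', (100,)), ('"apdb_part" = ?', (101,)), ...].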

1057 def _temporal_where( 

1058 self, 

1059 table: ApdbTables, 

1060 start_time: Union[float, dafBase.DateTime], 

1061 end_time: Union[float, dafBase.DateTime], 

1062 query_per_time_part: Optional[bool] = None, 

1063 ) -> Tuple[List[str], List[Tuple[str, Tuple]]]: 

1064 """Generate table names and expressions for temporal part of WHERE 

1065 clauses. 

1066 

1067 Parameters 

1068 ---------- 

1069 table : `ApdbTables` 

1070 Table to select from. 

1071 start_time : `dafBase.DateTime` or `float` 

1072 Starting Datetime or MJD value of the time range. 

1073 end_time : `dafBase.DateTime` or `float` 

1074 Ending Datetime or MJD value of the time range. 

1075 query_per_time_part : `bool`, optional 

1076 If None then use ``query_per_time_part`` from configuration. 

1077 

1078 Returns 

1079 ------- 

1080 tables : `list` [ `str` ] 

1081 List of the table names to query. 

1082 expressions : `list` [ `tuple` ] 

1083 A list of zero or more (expression, parameters) tuples. 

1084 """ 

1085 tables: List[str] 

1086 temporal_where: List[Tuple[str, Tuple]] = [] 

1087 table_name = self._schema.tableName(table) 

1088 time_part_start = self._time_partition(start_time) 

1089 time_part_end = self._time_partition(end_time) 

1090 time_parts = list(range(time_part_start, time_part_end + 1)) 

1091 if self.config.time_partition_tables: 

1092 tables = [f"{table_name}_{part}" for part in time_parts] 

1093 else: 

1094 tables = [table_name] 

1095 if query_per_time_part is None: 

1096 query_per_time_part = self.config.query_per_time_part 

1097 if query_per_time_part: 

1098 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts] 

1099 else: 

1100 time_part_list = ",".join([str(part) for part in time_parts]) 

1101 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())] 

1102 

1103 return tables, temporal_where
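
Editorial illustration, not part of the covered module: a typical round trip through the public `Apdb` interface implemented above. Host, keyspace, coordinates, and the input catalogs are placeholders.

import lsst.daf.base as dafBase
from lsst import sphgeom
from lsst.dax.apdb.apdbCassandra import ApdbCassandra, ApdbCassandraConfig

config = ApdbCassandraConfig()
config.contact_points = ["127.0.0.1"]  # placeholder Cassandra host
config.keyspace = "apdb"

apdb = ApdbCassandra(config)
apdb.makeSchema(drop=False)  # create the APDB tables

# Query region (placeholder position and radius) and visit time.
region = sphgeom.Circle(
    sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(35.0, -4.0)),
    sphgeom.Angle.fromDegrees(1.0),
)
visit_time = dafBase.DateTime("2020-01-01T00:00:00", dafBase.DateTime.TAI)

# `objects`, `sources` and `forced_sources` would be pandas DataFrames built
# upstream with the APDB schema columns; store() partitions and writes them.
# apdb.store(visit_time, objects, sources, forced_sources)

last_objects = apdb.getDiaObjects(region)
recent_sources = apdb.getDiaSources(region, last_objects["diaObjectId"], visit_time)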