Coverage for python/lsst/dax/apdb/apdbCassandra.py: 16%

401 statements  

coverage.py v6.4.2, created at 2022-07-17 01:57 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26import logging 

27import numpy as np 

28import pandas 

29from typing import Any, cast, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union 

30 

31# If cassandra-driver is not there the module can still be imported 

32# but ApdbCassandra cannot be instantiated. 

33try: 

34 import cassandra 

35 from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT 

36 from cassandra.policies import RoundRobinPolicy, WhiteListRoundRobinPolicy, AddressTranslator 

37 import cassandra.query 

38 CASSANDRA_IMPORTED = True 

39except ImportError: 

40 CASSANDRA_IMPORTED = False 

41 

42import lsst.daf.base as dafBase 

43from lsst import sphgeom 

44from lsst.pex.config import ChoiceField, Field, ListField 

45from lsst.utils.iteration import chunk_iterable 

46from .timer import Timer 

47from .apdb import Apdb, ApdbConfig 

48from .apdbSchema import ApdbTables, TableDef 

49from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables 

50from .cassandra_utils import ( 

51 literal, 

52 pandas_dataframe_factory, 

53 quote_id, 

54 raw_data_factory, 

55 select_concurrent, 

56) 

57from .pixelization import Pixelization 

58 

59_LOG = logging.getLogger(__name__) 

60 

61 

62class CassandraMissingError(Exception): 

63 def __init__(self) -> None: 

64 super().__init__("cassandra-driver module cannot be imported") 

65 

66 

67class ApdbCassandraConfig(ApdbConfig): 

68 

69 contact_points = ListField( 

70 dtype=str, 

71 doc="The list of contact points to try connecting for cluster discovery.", 

72 default=["127.0.0.1"] 

73 ) 

74 private_ips = ListField( 

75 dtype=str, 

76 doc="List of internal IP addresses for contact_points.", 

77 default=[] 

78 ) 

79 keyspace = Field( 

80 dtype=str, 

81 doc="Default keyspace for operations.", 

82 default="apdb" 

83 ) 

84 read_consistency = Field( 

85 dtype=str, 

86 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", 

87 default="QUORUM" 

88 ) 

89 write_consistency = Field( 

90 dtype=str, 

91 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", 

92 default="QUORUM" 

93 ) 

94 read_timeout = Field( 

95 dtype=float, 

96 doc="Timeout in seconds for read operations.", 

97 default=120. 

98 ) 

99 write_timeout = Field( 

100 dtype=float, 

101 doc="Timeout in seconds for write operations.", 

102 default=10. 

103 ) 

104 read_concurrency = Field( 

105 dtype=int, 

106 doc="Concurrency level for read operations.", 

107 default=500 

108 ) 

109 protocol_version = Field( 

110 dtype=int, 

111 doc="Cassandra protocol version to use, default is V4", 

112 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0 

113 ) 

114 dia_object_columns = ListField( 

115 dtype=str, 

116 doc="List of columns to read from DiaObject, by default read all columns", 

117 default=[] 

118 ) 

119 prefix = Field( 

120 dtype=str, 

121 doc="Prefix to add to table names", 

122 default="" 

123 ) 

124 part_pixelization = ChoiceField( 

125 dtype=str, 

126 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

127 doc="Pixelization used for partitioning index.", 

128 default="mq3c" 

129 ) 

130 part_pix_level = Field( 

131 dtype=int, 

132 doc="Pixelization level used for partitioning index.", 

133 default=10 

134 ) 

135 part_pix_max_ranges = Field( 

136 dtype=int, 

137 doc="Max number of ranges in pixelization envelope", 

138 default=64 

139 ) 

140 ra_dec_columns = ListField( 

141 dtype=str, 

142 default=["ra", "decl"], 

143 doc="Names ra/dec columns in DiaObject table" 

144 ) 

145 timer = Field( 

146 dtype=bool, 

147 doc="If True then print/log timing information", 

148 default=False 

149 ) 

150 time_partition_tables = Field( 

151 dtype=bool, 

152 doc="Use per-partition tables for sources instead of partitioning by time", 

153 default=True 

154 ) 

155 time_partition_days = Field( 

156 dtype=int, 

157 doc="Time partitoning granularity in days, this value must not be changed" 

158 " after database is initialized", 

159 default=30 

160 ) 

161 time_partition_start = Field( 

162 dtype=str, 

163 doc="Starting time for per-partion tables, in yyyy-mm-ddThh:mm:ss format, in TAI." 

164 " This is used only when time_partition_tables is True.", 

165 default="2018-12-01T00:00:00" 

166 ) 

167 time_partition_end = Field( 

168 dtype=str, 

169 doc="Ending time for per-partion tables, in yyyy-mm-ddThh:mm:ss format, in TAI" 

170 " This is used only when time_partition_tables is True.", 

171 default="2030-01-01T00:00:00" 

172 ) 

173 query_per_time_part = Field( 

174 dtype=bool, 

175 default=False, 

176 doc="If True then build separate query for each time partition, otherwise build one single query. " 

177 "This is only used when time_partition_tables is False in schema config." 

178 ) 

179 query_per_spatial_part = Field( 

180 dtype=bool, 

181 default=False, 

182 doc="If True then build one query per spacial partition, otherwise build single query. " 

183 ) 

184 pandas_delay_conv = Field( 

185 dtype=bool, 

186 default=True, 

187 doc="If True then combine result rows before converting to pandas. " 

188 ) 

189 
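As an illustrative aside (not part of the listed module), a configuration like the one defined above might be populated as follows; the host addresses and the overridden values are placeholders.

    from lsst.dax.apdb.apdbCassandra import ApdbCassandraConfig

    config = ApdbCassandraConfig()
    config.contact_points = ["10.0.0.1", "10.0.0.2"]   # placeholder cluster hosts
    config.keyspace = "apdb"
    config.read_consistency = "ONE"                    # default is QUORUM
    config.time_partition_tables = True                # per-partition source tables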

190 

191if CASSANDRA_IMPORTED: 

192 

193 class _AddressTranslator(AddressTranslator): 

194 """Translate internal IP address to external. 

195 

196 Only used for docker-based setup, not a viable long-term solution. 

197 """ 

198 def __init__(self, public_ips: List[str], private_ips: List[str]): 

199 self._map = dict(zip(private_ips, public_ips)) 

200 

201 def translate(self, private_ip: str) -> str: 

202 return self._map.get(private_ip, private_ip) 

203 

204 

205class ApdbCassandra(Apdb): 

206 """Implementation of APDB database on to of Apache Cassandra. 

207 

208 The implementation is configured via standard ``pex_config`` mechanism 

209 using the `ApdbCassandraConfig` configuration class. For examples of 

210 different configurations check the config/ folder. 

211 

212 Parameters 

213 ---------- 

214 config : `ApdbCassandraConfig` 

215 Configuration object. 

216 """ 

217 

218 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI) 

219 """Start time for partition 0, this should never be changed.""" 

220 

221 def __init__(self, config: ApdbCassandraConfig): 

222 

223 if not CASSANDRA_IMPORTED: 

224 raise CassandraMissingError() 

225 

226 self.config = config 

227 

228 _LOG.debug("ApdbCassandra Configuration:") 

229 for key, value in self.config.items(): 

230 _LOG.debug(" %s: %s", key, value) 

231 

232 self._pixelization = Pixelization( 

233 config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges 

234 ) 

235 

236 addressTranslator: Optional[AddressTranslator] = None 

237 if config.private_ips: 

238 addressTranslator = _AddressTranslator(config.contact_points, config.private_ips) 

239 

240 self._keyspace = config.keyspace 

241 

242 self._cluster = Cluster(execution_profiles=self._makeProfiles(config), 

243 contact_points=self.config.contact_points, 

244 address_translator=addressTranslator, 

245 protocol_version=self.config.protocol_version) 

246 self._session = self._cluster.connect() 

247 # Disable result paging 

248 self._session.default_fetch_size = None 

249 

250 self._schema = ApdbCassandraSchema(session=self._session, 

251 keyspace=self._keyspace, 

252 schema_file=self.config.schema_file, 

253 schema_name=self.config.schema_name, 

254 prefix=self.config.prefix, 

255 time_partition_tables=self.config.time_partition_tables) 

256 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD) 

257 

258 # Cache for prepared statements 

259 self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {} 

260 
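A hypothetical end-to-end usage sketch (not part of the original file), assuming a reachable Cassandra cluster with default connection settings; the region coordinates and radius are arbitrary.

    from lsst import sphgeom
    from lsst.dax.apdb.apdbCassandra import ApdbCassandra, ApdbCassandraConfig

    config = ApdbCassandraConfig()      # defaults: 127.0.0.1, keyspace "apdb"
    apdb = ApdbCassandra(config)
    apdb.makeSchema(drop=False)         # create APDB tables on first use

    # Query DiaObjects in a small circular region around an arbitrary point.
    center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
    region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
    objects = apdb.getDiaObjects(region)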

261 def tableDef(self, table: ApdbTables) -> Optional[TableDef]: 

262 # docstring is inherited from a base class 

263 return self._schema.tableSchemas.get(table) 

264 

265 def makeSchema(self, drop: bool = False) -> None: 

266 # docstring is inherited from a base class 

267 

268 if self.config.time_partition_tables: 

269 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI) 

270 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI) 

271 part_range = ( 

272 self._time_partition(time_partition_start), 

273 self._time_partition(time_partition_end) + 1 

274 ) 

275 self._schema.makeSchema(drop=drop, part_range=part_range) 

276 else: 

277 self._schema.makeSchema(drop=drop) 

278 
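An illustrative calculation (not part of the module) of the ``part_range`` built above, using the default 30-day granularity, the default 2018-12-01 to 2030-01-01 window, and the 1970-01-01 partition-zero epoch (MJD 40587); the MJD values are rounded to whole days.

    start_mjd, end_mjd = 58453, 62502          # ~MJD of 2018-12-01 and 2030-01-01
    start_part = (start_mjd - 40587) // 30     # -> 595
    end_part = (end_mjd - 40587) // 30         # -> 730
    part_range = (start_part, end_part + 1)    # -> (595, 731)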

279 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

280 # docstring is inherited from a base class 

281 

282 sp_where = self._spatial_where(region) 

283 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

284 

285 table_name = self._schema.tableName(ApdbTables.DiaObjectLast) 

286 query = f'SELECT * from "{self._keyspace}"."{table_name}"' 

287 statements: List[Tuple] = [] 

288 for where, params in sp_where: 

289 full_query = f"{query} WHERE {where}" 

290 if params: 

291 statement = self._prep_statement(full_query) 

292 else: 

293 # If there are no params then it is likely that the query has a 

294 # bunch of literals rendered already, no point trying to 

295 # prepare it because it's not reusable. 

296 statement = cassandra.query.SimpleStatement(full_query) 

297 statements.append((statement, params)) 

298 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

299 

300 with Timer('DiaObject select', self.config.timer): 

301 objects = cast( 

302 pandas.DataFrame, 

303 select_concurrent( 

304 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

305 ) 

306 ) 

307 

308 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

309 return objects 

310 
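For reference, one of the statements built here has roughly this shape; the keyspace, table name, and pixel indices below are assumed values, not taken from the module.

    # One spatial-partition expression expands to a query like this:
    example_query = (
        'SELECT * from "apdb"."DiaObjectLast" '
        'WHERE "apdb_part" IN (152520,152521,152522)'
    )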

311 def getDiaSources(self, region: sphgeom.Region, 

312 object_ids: Optional[Iterable[int]], 

313 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]: 

314 # docstring is inherited from a base class 

315 months = self.config.read_sources_months 

316 if months == 0: 

317 return None 

318 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

319 mjd_start = mjd_end - months*30 

320 

321 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

322 

323 def getDiaForcedSources(self, region: sphgeom.Region, 

324 object_ids: Optional[Iterable[int]], 

325 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]: 

326 # docstring is inherited from a base class 

327 months = self.config.read_forced_sources_months 

328 if months == 0: 

329 return None 

330 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

331 mjd_start = mjd_end - months*30 

332 

333 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

334 

335 def getDiaObjectsHistory(self, 

336 start_time: dafBase.DateTime, 

337 end_time: dafBase.DateTime, 

338 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame: 

339 # docstring is inherited from a base class 

340 

341 sp_where = self._spatial_where(region, use_ranges=True) 

342 tables, temporal_where = self._temporal_where(ApdbTables.DiaObject, start_time, end_time, True) 

343 

344 # Build all queries 

345 statements: List[Tuple] = [] 

346 for table in tables: 

347 prefix = f'SELECT * from "{self._keyspace}"."{table}"' 

348 statements += list(self._combine_where(prefix, sp_where, temporal_where, "ALLOW FILTERING")) 

349 _LOG.debug("getDiaObjectsHistory: #queries: %s", len(statements)) 

350 

351 # Run all selects in parallel 

352 with Timer("DiaObject history", self.config.timer): 

353 catalog = cast( 

354 pandas.DataFrame, 

355 select_concurrent( 

356 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

357 ) 

358 ) 

359 

360 # precise filtering on validityStart 

361 validity_start = start_time.toPython() 

362 validity_end = end_time.toPython() 

363 catalog = cast( 

364 pandas.DataFrame, 

365 catalog[(catalog["validityStart"] >= validity_start) & (catalog["validityStart"] < validity_end)] 

366 ) 

367 

368 _LOG.debug("found %d DiaObjects", catalog.shape[0]) 

369 return catalog 

370 

371 def getDiaSourcesHistory(self, 

372 start_time: dafBase.DateTime, 

373 end_time: dafBase.DateTime, 

374 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame: 

375 # docstring is inherited from a base class 

376 return self._getSourcesHistory(ApdbTables.DiaSource, start_time, end_time, region) 

377 

378 def getDiaForcedSourcesHistory(self, 

379 start_time: dafBase.DateTime, 

380 end_time: dafBase.DateTime, 

381 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame: 

382 # docstring is inherited from a base class 

383 return self._getSourcesHistory(ApdbTables.DiaForcedSource, start_time, end_time, region) 

384 

385 def getSSObjects(self) -> pandas.DataFrame: 

386 # docstring is inherited from a base class 

387 tableName = self._schema.tableName(ApdbTables.SSObject) 

388 query = f'SELECT * from "{self._keyspace}"."{tableName}"' 

389 

390 objects = None 

391 with Timer('SSObject select', self.config.timer): 

392 result = self._session.execute(query, execution_profile="read_pandas") 

393 objects = result._current_rows 

394 

395 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

396 return objects 

397 

398 def store(self, 

399 visit_time: dafBase.DateTime, 

400 objects: pandas.DataFrame, 

401 sources: Optional[pandas.DataFrame] = None, 

402 forced_sources: Optional[pandas.DataFrame] = None) -> None: 

403 # docstring is inherited from a base class 

404 

405 # fill region partition column for DiaObjects 

406 objects = self._add_obj_part(objects) 

407 self._storeDiaObjects(objects, visit_time) 

408 

409 if sources is not None: 

410 # copy apdb_part column from DiaObjects to DiaSources 

411 sources = self._add_src_part(sources, objects) 

412 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time) 

413 self._storeDiaSourcesPartitions(sources, visit_time) 

414 

415 if forced_sources is not None: 

416 forced_sources = self._add_fsrc_part(forced_sources, objects) 

417 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time) 

418 

419 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

420 # docstring is inherited from a base class 

421 self._storeObjectsPandas(objects, ApdbTables.SSObject) 

422 

423 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

424 # docstring is inherited from a base class 

425 

426 # To update a record we need to know its exact primary key (including 

427 # partition key) so we start by querying for diaSourceId to find the 

428 # primary keys. 

429 

430 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition) 

431 # split it into 1k IDs per query 

432 selects: List[Tuple] = [] 

433 for ids in chunk_iterable(idMap.keys(), 1_000): 

434 ids_str = ",".join(str(item) for item in ids) 

435 selects.append(( 

436 (f'SELECT "diaSourceId", "apdb_part", "apdb_time_part" FROM "{self._keyspace}"."{table_name}"' 

437 f' WHERE "diaSourceId" IN ({ids_str})'), 

438 {} 

439 )) 

440 

441 # No need for DataFrame here, read data as tuples. 

442 result = cast( 

443 List[Tuple[int, int, int]], 

444 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency) 

445 ) 

446 

447 # Make mapping from source ID to its partition. 

448 id2partitions: Dict[int, Tuple[int, int]] = {} 

449 for row in result: 

450 id2partitions[row[0]] = row[1:] 

451 

452 # make sure we know partitions for each ID 

453 if set(id2partitions) != set(idMap): 

454 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

455 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

456 

457 queries = cassandra.query.BatchStatement() 

458 table_name = self._schema.tableName(ApdbTables.DiaSource) 

459 for diaSourceId, ssObjectId in idMap.items(): 

460 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

461 values: Tuple 

462 if self.config.time_partition_tables: 

463 query = ( 

464 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"' 

465 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

466 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?' 

467 ) 

468 values = (ssObjectId, apdb_part, diaSourceId) 

469 else: 

470 query = ( 

471 f'UPDATE "{self._keyspace}"."{table_name}"' 

472 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

473 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?' 

474 ) 

475 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId) 

476 queries.add(self._prep_statement(query), values) 

477 

478 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

479 with Timer(table_name + ' update', self.config.timer): 

480 self._session.execute(queries, execution_profile="write") 

481 

482 def dailyJob(self) -> None: 

483 # docstring is inherited from a base class 

484 pass 

485 

486 def countUnassociatedObjects(self) -> int: 

487 # docstring is inherited from a base class 

488 

489 # It's too inefficient to implement it for Cassandra in the current schema. 

490 raise NotImplementedError() 

491 

492 def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]: 

493 """Make all execution profiles used in the code.""" 

494 

495 if config.private_ips: 

496 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

497 else: 

498 loadBalancePolicy = RoundRobinPolicy() 

499 

500 pandas_row_factory: Callable 

501 if not config.pandas_delay_conv: 

502 pandas_row_factory = pandas_dataframe_factory 

503 else: 

504 pandas_row_factory = raw_data_factory 

505 

506 read_tuples_profile = ExecutionProfile( 

507 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

508 request_timeout=config.read_timeout, 

509 row_factory=cassandra.query.tuple_factory, 

510 load_balancing_policy=loadBalancePolicy, 

511 ) 

512 read_pandas_profile = ExecutionProfile( 

513 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

514 request_timeout=config.read_timeout, 

515 row_factory=pandas_dataframe_factory, 

516 load_balancing_policy=loadBalancePolicy, 

517 ) 

518 read_pandas_multi_profile = ExecutionProfile( 

519 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

520 request_timeout=config.read_timeout, 

521 row_factory=pandas_row_factory, 

522 load_balancing_policy=loadBalancePolicy, 

523 ) 

524 write_profile = ExecutionProfile( 

525 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency), 

526 request_timeout=config.write_timeout, 

527 load_balancing_policy=loadBalancePolicy, 

528 ) 

529 # To replace default DCAwareRoundRobinPolicy 

530 default_profile = ExecutionProfile( 

531 load_balancing_policy=loadBalancePolicy, 

532 ) 

533 return { 

534 "read_tuples": read_tuples_profile, 

535 "read_pandas": read_pandas_profile, 

536 "read_pandas_multi": read_pandas_multi_profile, 

537 "write": write_profile, 

538 EXEC_PROFILE_DEFAULT: default_profile, 

539 } 

540 

541 def _getSources(self, region: sphgeom.Region, 

542 object_ids: Optional[Iterable[int]], 

543 mjd_start: float, 

544 mjd_end: float, 

545 table_name: ApdbTables) -> pandas.DataFrame: 

546 """Returns catalog of DiaSource instances given set of DiaObject IDs. 

547 

548 Parameters 

549 ---------- 

550 region : `lsst.sphgeom.Region` 

551 Spherical region. 

552 object_ids : iterable of `int`, or `None` 

553 Collection of DiaObject IDs. 

554 mjd_start : `float` 

555 Lower bound of time interval. 

556 mjd_end : `float` 

557 Upper bound of time interval. 

558 table_name : `ApdbTables` 

559 Name of the table. 

560 

561 Returns 

562 ------- 

563 catalog : `pandas.DataFrame` 

564 Catalog containing DiaSource records. An empty catalog is returned if 

565 ``object_ids`` is empty. 

566 """ 

567 object_id_set: Set[int] = set() 

568 if object_ids is not None: 

569 object_id_set = set(object_ids) 

570 if len(object_id_set) == 0: 

571 return self._make_empty_catalog(table_name) 

572 

573 sp_where = self._spatial_where(region) 

574 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end) 

575 

576 # Build all queries 

577 statements: List[Tuple] = [] 

578 for table in tables: 

579 prefix = f'SELECT * from "{self._keyspace}"."{table}"' 

580 statements += list(self._combine_where(prefix, sp_where, temporal_where)) 

581 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

582 

583 with Timer(table_name.name + ' select', self.config.timer): 

584 catalog = cast( 

585 pandas.DataFrame, 

586 select_concurrent( 

587 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

588 ) 

589 ) 

590 

591 # filter by given object IDs 

592 if len(object_id_set) > 0: 

593 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

594 

595 # precise filtering on midPointTai 

596 catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_start]) 

597 

598 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

599 return catalog 

600 

601 def _getSourcesHistory( 

602 self, 

603 table: ApdbTables, 

604 start_time: dafBase.DateTime, 

605 end_time: dafBase.DateTime, 

606 region: Optional[sphgeom.Region] = None, 

607 ) -> pandas.DataFrame: 

608 """Returns catalog of DiaSource instances given set of DiaObject IDs. 

609 

610 Parameters 

611 ---------- 

612 table : `ApdbTables` 

613 Name of the table. 

614 start_time : `dafBase.DateTime` 

615 Starting time for DiaSource history search. A DiaSource record is 

616 selected when its ``midPointTai`` falls into an interval between 

617 ``start_time`` (inclusive) and ``end_time`` (exclusive). 

618 end_time : `dafBase.DateTime` 

619 Upper limit on time for DiaSource history search. 

620 region : `lsst.sphgeom.Region` 

621 Spherical region. 

622 

623 Returns 

624 ------- 

625 catalog : `pandas.DataFrame` 

626 Catalog containing DiaSource records. 

627 """ 

628 sp_where = self._spatial_where(region, use_ranges=False) 

629 tables, temporal_where = self._temporal_where(table, start_time, end_time, True) 

630 

631 # Build all queries 

632 statements: List[Tuple] = [] 

633 for table_name in tables: 

634 prefix = f'SELECT * from "{self._keyspace}"."{table_name}"' 

635 statements += list(self._combine_where(prefix, sp_where, temporal_where, "ALLOW FILTERING")) 

636 _LOG.debug("getDiaObjectsHistory: #queries: %s", len(statements)) 

637 

638 # Run all selects in parallel 

639 with Timer(f"{table.name} history", self.config.timer): 

640 catalog = cast( 

641 pandas.DataFrame, 

642 select_concurrent( 

643 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

644 ) 

645 ) 

646 

647 # precise filtering on midPointTai 

648 period_start = start_time.get(system=dafBase.DateTime.MJD) 

649 period_end = end_time.get(system=dafBase.DateTime.MJD) 

650 catalog = cast( 

651 pandas.DataFrame, 

652 catalog[(catalog["midPointTai"] >= period_start) & (catalog["midPointTai"] < period_end)] 

653 ) 

654 

655 _LOG.debug("found %d %ss", catalog.shape[0], table.name) 

656 return catalog 

657 

658 def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None: 

659 """Store catalog of DiaObjects from current visit. 

660 

661 Parameters 

662 ---------- 

663 objs : `pandas.DataFrame` 

664 Catalog with DiaObject records 

665 visit_time : `lsst.daf.base.DateTime` 

666 Time of the current visit. 

667 """ 

668 visit_time_dt = visit_time.toPython() 

669 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

670 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

671 

672 extra_columns["validityStart"] = visit_time_dt 

673 time_part: Optional[int] = self._time_partition(visit_time) 

674 if not self.config.time_partition_tables: 

675 extra_columns["apdb_time_part"] = time_part 

676 time_part = None 

677 

678 self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part) 

679 

680 def _storeDiaSources(self, table_name: ApdbTables, sources: pandas.DataFrame, 

681 visit_time: dafBase.DateTime) -> None: 

682 """Store catalog of DIASources or DIAForcedSources from current visit. 

683 

684 Parameters 

685 ---------- 

686 sources : `pandas.DataFrame` 

687 Catalog containing DiaSource records 

688 visit_time : `lsst.daf.base.DateTime` 

689 Time of the current visit. 

690 """ 

691 time_part: Optional[int] = self._time_partition(visit_time) 

692 extra_columns = {} 

693 if not self.config.time_partition_tables: 

694 extra_columns["apdb_time_part"] = time_part 

695 time_part = None 

696 

697 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part) 

698 

699 def _storeDiaSourcesPartitions(self, sources: pandas.DataFrame, visit_time: dafBase.DateTime) -> None: 

700 """Store mapping of diaSourceId to its partitioning values. 

701 

702 Parameters 

703 ---------- 

704 sources : `pandas.DataFrame` 

705 Catalog containing DiaSource records 

706 visit_time : `lsst.daf.base.DateTime` 

707 Time of the current visit. 

708 """ 

709 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

710 extra_columns = { 

711 "apdb_time_part": self._time_partition(visit_time), 

712 } 

713 

714 self._storeObjectsPandas( 

715 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

716 ) 

717 

718 def _storeObjectsPandas(self, objects: pandas.DataFrame, table_name: Union[ApdbTables, ExtraTables], 

719 extra_columns: Optional[Mapping] = None, 

720 time_part: Optional[int] = None) -> None: 

721 """Generic store method. 

722 

723 Takes catalog of records and stores a bunch of objects in a table. 

724 

725 Parameters 

726 ---------- 

727 objects : `pandas.DataFrame` 

728 Catalog containing object records 

729 table_name : `ApdbTables` or `ExtraTables` 

730 Name of the table as defined in APDB schema. 

731 extra_columns : `dict`, optional 

732 Mapping (column_name, column_value) which gives column values to add 

733 to every row, only if column is missing in catalog records. 

734 time_part : `int`, optional 

735 If not `None` then insert into a per-partition table. 

736 """ 

737 # use extra columns if specified 

738 if extra_columns is None: 

739 extra_columns = {} 

740 extra_fields = list(extra_columns.keys()) 

741 

742 df_fields = [ 

743 column for column in objects.columns if column not in extra_fields 

744 ] 

745 

746 column_map = self._schema.getColumnMap(table_name) 

747 # list of columns (as in cat schema) 

748 fields = [column_map[field].name for field in df_fields if field in column_map] 

749 fields += extra_fields 

750 

751 # check that all partitioning and clustering columns are defined 

752 required_columns = self._schema.partitionColumns(table_name) \ 

753 + self._schema.clusteringColumns(table_name) 

754 missing_columns = [column for column in required_columns if column not in fields] 

755 if missing_columns: 

756 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

757 

758 qfields = [quote_id(field) for field in fields] 

759 qfields_str = ','.join(qfields) 

760 

761 with Timer(table_name.name + ' query build', self.config.timer): 

762 

763 table = self._schema.tableName(table_name) 

764 if time_part is not None: 

765 table = f"{table}_{time_part}" 

766 

767 holders = ','.join(['?']*len(qfields)) 

768 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})' 

769 statement = self._prep_statement(query) 

770 queries = cassandra.query.BatchStatement() 

771 for rec in objects.itertuples(index=False): 

772 values = [] 

773 for field in df_fields: 

774 if field not in column_map: 

775 continue 

776 value = getattr(rec, field) 

777 if column_map[field].type == "DATETIME": 

778 if isinstance(value, pandas.Timestamp): 

779 value = literal(value.to_pydatetime()) 

780 else: 

781 # Assume it's seconds since epoch, Cassandra 

782 # datetime is in milliseconds 

783 value = int(value*1000) 

784 values.append(literal(value)) 

785 for field in extra_fields: 

786 value = extra_columns[field] 

787 values.append(literal(value)) 

788 queries.add(statement, values) 

789 

790 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), objects.shape[0]) 

791 with Timer(table_name.name + ' insert', self.config.timer): 

792 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write") 

793 

794 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

795 """Calculate spacial partition for each record and add it to a 

796 DataFrame. 

797 

798 Notes 

799 ----- 

800 This overrides any existing column in a DataFrame with the same name 

801 (apdb_part). The original DataFrame is not changed; a copy of the DataFrame is 

802 returned. 

803 """ 

804 # calculate pixelization index for every DiaObject 

805 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

806 ra_col, dec_col = self.config.ra_dec_columns 

807 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

808 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

809 idx = self._pixelization.pixel(uv3d) 

810 apdb_part[i] = idx 

811 df = df.copy() 

812 df["apdb_part"] = apdb_part 

813 return df 

814 

815 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

816 """Add apdb_part column to DiaSource catalog. 

817 

818 Notes 

819 ----- 

820 This method copies apdb_part value from a matching DiaObject record. 

821 DiaObject catalog needs to have an apdb_part column filled by the 

822 ``_add_obj_part`` method and DiaSource records need to be 

823 associated with DiaObjects via the ``diaObjectId`` column. 

824 

825 This overrides any existing column in a DataFrame with the same name 

826 (apdb_part). The original DataFrame is not changed; a copy of the DataFrame is 

827 returned. 

828 """ 

829 pixel_id_map: Dict[int, int] = { 

830 diaObjectId: apdb_part for diaObjectId, apdb_part 

831 in zip(objs["diaObjectId"], objs["apdb_part"]) 

832 } 

833 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

834 ra_col, dec_col = self.config.ra_dec_columns 

835 for i, (diaObjId, ra, dec) in enumerate(zip(sources["diaObjectId"], 

836 sources[ra_col], sources[dec_col])): 

837 if diaObjId == 0: 

838 # DiaSources associated with SolarSystemObjects do not have an 

839 # associated DiaObject hence we skip them and set partition 

840 # based on their own ra/dec 

841 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

842 idx = self._pixelization.pixel(uv3d) 

843 apdb_part[i] = idx 

844 else: 

845 apdb_part[i] = pixel_id_map[diaObjId] 

846 sources = sources.copy() 

847 sources["apdb_part"] = apdb_part 

848 return sources 

849 

850 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

851 """Add apdb_part column to DiaForcedSource catalog. 

852 

853 Notes 

854 ----- 

855 This method copies apdb_part value from a matching DiaObject record. 

856 DiaObject catalog needs to have an apdb_part column filled by the 

857 ``_add_obj_part`` method and DiaForcedSource records need to be 

858 associated with DiaObjects via the ``diaObjectId`` column. 

859 

860 This overrides any existing column in a DataFrame with the same name 

861 (apdb_part). The original DataFrame is not changed; a copy of the DataFrame is 

862 returned. 

863 """ 

864 pixel_id_map: Dict[int, int] = { 

865 diaObjectId: apdb_part for diaObjectId, apdb_part 

866 in zip(objs["diaObjectId"], objs["apdb_part"]) 

867 } 

868 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

869 for i, diaObjId in enumerate(sources["diaObjectId"]): 

870 apdb_part[i] = pixel_id_map[diaObjId] 

871 sources = sources.copy() 

872 sources["apdb_part"] = apdb_part 

873 return sources 

874 

875 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int: 

876 """Calculate time partiton number for a given time. 

877 

878 Parameters 

879 ---------- 

880 time : `float` or `lsst.daf.base.DateTime` 

881 Time for which to calculate partition number. Can be a float meaning 

882 MJD, or `lsst.daf.base.DateTime`. 

883 

884 Returns 

885 ------- 

886 partition : `int` 

887 Partition number for a given time. 

888 """ 

889 if isinstance(time, dafBase.DateTime): 

890 mjd = time.get(system=dafBase.DateTime.MJD) 

891 else: 

892 mjd = time 

893 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

894 partition = int(days_since_epoch) // self.config.time_partition_days 

895 return partition 

896 
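A worked example of the partition arithmetic above, using the default 30-day granularity; the visit MJD is illustrative and the partition-zero epoch of 1970-01-01 corresponds to MJD 40587.

    partition_zero_mjd = 40587.0            # 1970-01-01 expressed as MJD
    visit_mjd = 59580.0                     # some visit time (illustrative)
    partition = int(visit_mjd - partition_zero_mjd) // 30
    assert partition == 633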

897 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

898 """Make an empty catalog for a table with a given name. 

899 

900 Parameters 

901 ---------- 

902 table_name : `ApdbTables` 

903 Name of the table. 

904 

905 Returns 

906 ------- 

907 catalog : `pandas.DataFrame` 

908 An empty catalog. 

909 """ 

910 table = self._schema.tableSchemas[table_name] 

911 

912 data = {columnDef.name: pandas.Series(dtype=columnDef.dtype) for columnDef in table.columns} 

913 return pandas.DataFrame(data) 

914 

915 def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement: 

916 """Convert query string into prepared statement.""" 

917 stmt = self._prepared_statements.get(query) 

918 if stmt is None: 

919 stmt = self._session.prepare(query) 

920 self._prepared_statements[query] = stmt 

921 return stmt 

922 

923 def _combine_where( 

924 self, 

925 prefix: str, 

926 where1: List[Tuple[str, Tuple]], 

927 where2: List[Tuple[str, Tuple]], 

928 suffix: Optional[str] = None, 

929 ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]: 

930 """Make cartesian product of two parts of WHERE clause into a series 

931 of statements to execute. 

932 

933 Parameters 

934 ---------- 

935 prefix : `str` 

936 Initial statement prefix that comes before WHERE clause, e.g. 

937 "SELECT * from Table" 

938 """ 

939 # If lists are empty use special sentinels. 

940 if not where1: 

941 where1 = [("", ())] 

942 if not where2: 

943 where2 = [("", ())] 

944 

945 for expr1, params1 in where1: 

946 for expr2, params2 in where2: 

947 full_query = prefix 

948 wheres = [] 

949 if expr1: 

950 wheres.append(expr1) 

951 if expr2: 

952 wheres.append(expr2) 

953 if wheres: 

954 full_query += " WHERE " + " AND ".join(wheres) 

955 if suffix: 

956 full_query += " " + suffix 

957 params = params1 + params2 

958 if params: 

959 statement = self._prep_statement(full_query) 

960 else: 

961 # If there are no params then it is likely that the query 

962 # has a bunch of literals rendered already, no point 

963 # trying to prepare it. 

964 statement = cassandra.query.SimpleStatement(full_query) 

965 yield (statement, params) 

966 
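An illustrative expansion of this cartesian product (the table name, pixel indices, and time partitions below are assumed): two spatial expressions combined with one temporal expression yield two statements.

    prefix = 'SELECT * from "apdb"."DiaSource"'
    where1 = [('"apdb_part" = ?', (152520,)), ('"apdb_part" = ?', (152521,))]
    where2 = [('"apdb_time_part" IN (650,651)', ())]
    # _combine_where(prefix, where1, where2) would yield two (statement, params)
    # pairs, each of the form
    #   SELECT * from "apdb"."DiaSource"
    #     WHERE "apdb_part" = ? AND "apdb_time_part" IN (650,651)
    # with params (152520,) and (152521,) respectively.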

967 def _spatial_where( 

968 self, region: Optional[sphgeom.Region], use_ranges: bool = False 

969 ) -> List[Tuple[str, Tuple]]: 

970 """Generate expressions for spatial part of WHERE clause. 

971 

972 Parameters 

973 ---------- 

974 region : `sphgeom.Region` 

975 Spatial region for query results. 

976 use_ranges : `bool` 

977 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <= 

978 p2") instead of exact list of pixels. Should be set to True for 

979 large regions covering very many pixels. 

980 

981 Returns 

982 ------- 

983 expressions : `list` [ `tuple` ] 

984 Empty list is returned if ``region`` is `None`, otherwise a list 

985 of one or more (expression, parameters) tuples 

986 """ 

987 if region is None: 

988 return [] 

989 if use_ranges: 

990 pixel_ranges = self._pixelization.envelope(region) 

991 expressions: List[Tuple[str, Tuple]] = [] 

992 for lower, upper in pixel_ranges: 

993 upper -= 1 

994 if lower == upper: 

995 expressions.append(('"apdb_part" = ?', (lower, ))) 

996 else: 

997 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper))) 

998 return expressions 

999 else: 

1000 pixels = self._pixelization.pixels(region) 

1001 if self.config.query_per_spatial_part: 

1002 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels] 

1003 else: 

1004 pixels_str = ",".join([str(pix) for pix in pixels]) 

1005 return [(f'"apdb_part" IN ({pixels_str})', ())] 

1006 
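Illustrative return values for the two modes of this method; the pixel indices are assumed.

    # use_ranges=False, query_per_spatial_part=False: one IN expression, no params
    example_in = [('"apdb_part" IN (152520,152521,152525)', ())]
    # use_ranges=True: one range expression per envelope range, with bound params
    example_ranges = [('"apdb_part" >= ? AND "apdb_part" <= ?', (152520, 152525))]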

1007 def _temporal_where( 

1008 self, 

1009 table: ApdbTables, 

1010 start_time: Union[float, dafBase.DateTime], 

1011 end_time: Union[float, dafBase.DateTime], 

1012 query_per_time_part: Optional[bool] = None, 

1013 ) -> Tuple[List[str], List[Tuple[str, Tuple]]]: 

1014 """Generate table names and expressions for temporal part of WHERE 

1015 clauses. 

1016 

1017 Parameters 

1018 ---------- 

1019 table : `ApdbTables` 

1020 Table to select from. 

1021 start_time : `dafBase.DateTime` or `float` 

1022 Starting DateTime or MJD value of the time range. 

1023 end_time : `dafBase.DateTime` or `float` 

1024 Ending DateTime or MJD value of the time range. 

1025 query_per_time_part : `bool`, optional 

1026 If None then use ``query_per_time_part`` from configuration. 

1027 

1028 Returns 

1029 ------- 

1030 tables : `list` [ `str` ] 

1031 List of the table names to query. 

1032 expressions : `list` [ `tuple` ] 

1033 A list of zero or more (expression, parameters) tuples. 

1034 """ 

1035 tables: List[str] 

1036 temporal_where: List[Tuple[str, Tuple]] = [] 

1037 table_name = self._schema.tableName(table) 

1038 time_part_start = self._time_partition(start_time) 

1039 time_part_end = self._time_partition(end_time) 

1040 time_parts = list(range(time_part_start, time_part_end + 1)) 

1041 if self.config.time_partition_tables: 

1042 tables = [f"{table_name}_{part}" for part in time_parts] 

1043 else: 

1044 tables = [table_name] 

1045 if query_per_time_part is None: 

1046 query_per_time_part = self.config.query_per_time_part 

1047 if query_per_time_part: 

1048 temporal_where = [ 

1049 ('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts 

1050 ] 

1051 else: 

1052 time_part_list = ",".join([str(part) for part in time_parts]) 

1053 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())] 

1054 

1055 return tables, temporal_where
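And an illustrative result for this method (the table name and partition numbers are assumed): a time range spanning partitions 650 and 651, with time_partition_tables=False and query_per_time_part=False, gives one table and one IN expression.

    tables = ["DiaSource"]
    temporal_where = [('"apdb_time_part" IN (650,651)', ())]
    # With time_partition_tables=True the per-partition table names carry the
    # temporal constraint instead: tables == ["DiaSource_650", "DiaSource_651"].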