Coverage for python/lsst/dax/apdb/apdbCassandra.py: 15%

406 statements  

coverage.py v6.5.0, created at 2022-12-30 01:30 -0800

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26import logging 

27from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union, cast 

28 

29import numpy as np 

30import pandas 

31 

32# If cassandra-driver is not there the module can still be imported 

33# but ApdbCassandra cannot be instantiated. 

34try: 

35 import cassandra 

36 import cassandra.query 

37 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile 

38 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

39 CASSANDRA_IMPORTED = True 

40except ImportError: 

41 CASSANDRA_IMPORTED = False 

42 

43import felis.types 

44import lsst.daf.base as dafBase 

45from felis.simple import Table 

46from lsst import sphgeom 

47from lsst.pex.config import ChoiceField, Field, ListField 

48from lsst.utils.iteration import chunk_iterable 

49 

50from .apdb import Apdb, ApdbConfig 

51from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables 

52from .apdbSchema import ApdbTables 

53from .cassandra_utils import literal, pandas_dataframe_factory, quote_id, raw_data_factory, select_concurrent 

54from .pixelization import Pixelization 

55from .timer import Timer 

56 

57_LOG = logging.getLogger(__name__) 

58 

59 

60class CassandraMissingError(Exception): 

61 def __init__(self) -> None: 

62 super().__init__("cassandra-driver module cannot be imported") 

63 

64 

65class ApdbCassandraConfig(ApdbConfig): 

66 

67 contact_points = ListField[str]( 

68 doc="The list of contact points to try connecting for cluster discovery.", 

69 default=["127.0.0.1"] 

70 ) 

71 private_ips = ListField[str]( 

72 doc="List of internal IP addresses for contact_points.", 

73 default=[] 

74 ) 

75 keyspace = Field[str]( 

76 doc="Default keyspace for operations.", 

77 default="apdb" 

78 ) 

79 read_consistency = Field[str]( 

80 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", 

81 default="QUORUM" 

82 ) 

83 write_consistency = Field[str]( 

84 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", 

85 default="QUORUM" 

86 ) 

87 read_timeout = Field[float]( 

88 doc="Timeout in seconds for read operations.", 

89 default=120. 

90 ) 

91 write_timeout = Field[float]( 

92 doc="Timeout in seconds for write operations.", 

93 default=10. 

94 ) 

95 read_concurrency = Field[int]( 

96 doc="Concurrency level for read operations.", 

97 default=500 

98 ) 

99 protocol_version = Field[int]( 

100 doc="Cassandra protocol version to use, default is V4", 

101 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0 

102 ) 

103 dia_object_columns = ListField[str]( 

104 doc="List of columns to read from DiaObject, by default read all columns", 

105 default=[] 

106 ) 

107 prefix = Field[str]( 

108 doc="Prefix to add to table names", 

109 default="" 

110 ) 

111 part_pixelization = ChoiceField[str]( 

112 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

113 doc="Pixelization used for partitioning index.", 

114 default="mq3c" 

115 ) 

116 part_pix_level = Field[int]( 

117 doc="Pixelization level used for partitioning index.", 

118 default=10 

119 ) 

120 part_pix_max_ranges = Field[int]( 

121 doc="Max number of ranges in pixelization envelope", 

122 default=64 

123 ) 

124 ra_dec_columns = ListField[str]( 

125 default=["ra", "decl"], 

126 doc="Names ra/dec columns in DiaObject table" 

127 ) 

128 timer = Field[bool]( 

129 doc="If True then print/log timing information", 

130 default=False 

131 ) 

132 time_partition_tables = Field[bool]( 

133 doc="Use per-partition tables for sources instead of partitioning by time", 

134 default=True 

135 ) 

136 time_partition_days = Field[int]( 

137 doc="Time partitioning granularity in days, this value must not be changed" 

138 " after database is initialized", 

139 default=30 

140 ) 

141 time_partition_start = Field[str]( 

142 doc="Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI." 

143 " This is used only when time_partition_tables is True.", 

144 default="2018-12-01T00:00:00" 

145 ) 

146 time_partition_end = Field[str]( 

147 doc="Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI" 

148 " This is used only when time_partition_tables is True.", 

149 default="2030-01-01T00:00:00" 

150 ) 

151 query_per_time_part = Field[bool]( 

152 default=False, 

153 doc="If True then build separate query for each time partition, otherwise build one single query. " 

154 "This is only used when time_partition_tables is False in schema config." 

155 ) 

156 query_per_spatial_part = Field[bool]( 

157 default=False, 

158 doc="If True then build one query per spacial partition, otherwise build single query. " 

159 ) 

160 pandas_delay_conv = Field[bool]( 

161 default=True, 

162 doc="If True then combine result rows before converting to pandas. " 

163 ) 
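# A minimal usage sketch (illustrative only; the host addresses and values
# below are hypothetical, not taken from this module):
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["10.0.0.1", "10.0.0.2"]
#     config.keyspace = "apdb"
#     config.time_partition_tables = True
#     apdb = ApdbCassandra(config)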

164 

165 

166if CASSANDRA_IMPORTED: 

167 

168 class _AddressTranslator(AddressTranslator): 

169 """Translate internal IP address to external. 

170 

171 Only used for docker-based setup; not a viable long-term solution. 

172 """ 

173 def __init__(self, public_ips: List[str], private_ips: List[str]): 

174 self._map = dict((k, v) for k, v in zip(private_ips, public_ips)) 

175 

176 def translate(self, private_ip: str) -> str: 

177 return self._map.get(private_ip, private_ip) 

178 

179 

180class ApdbCassandra(Apdb): 

181 """Implementation of APDB database on to of Apache Cassandra. 

182 

183 The implementation is configured via standard ``pex_config`` mechanism 

184 using `ApdbCassandraConfig` configuration class. For an example of 

185 different configurations check config/ folder. 

186 

187 Parameters 

188 ---------- 

189 config : `ApdbCassandraConfig` 

190 Configuration object. 

191 """ 

192 

193 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI) 

194 """Start time for partition 0, this should never be changed.""" 

195 

196 def __init__(self, config: ApdbCassandraConfig): 

197 

198 if not CASSANDRA_IMPORTED: 

199 raise CassandraMissingError() 

200 

201 config.validate() 

202 self.config = config 

203 

204 _LOG.debug("ApdbCassandra Configuration:") 

205 for key, value in self.config.items(): 

206 _LOG.debug(" %s: %s", key, value) 

207 

208 self._pixelization = Pixelization( 

209 config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges 

210 ) 

211 

212 addressTranslator: Optional[AddressTranslator] = None 

213 if config.private_ips: 

214 addressTranslator = _AddressTranslator(config.contact_points, config.private_ips) 

215 

216 self._keyspace = config.keyspace 

217 

218 self._cluster = Cluster(execution_profiles=self._makeProfiles(config), 

219 contact_points=self.config.contact_points, 

220 address_translator=addressTranslator, 

221 protocol_version=self.config.protocol_version) 

222 self._session = self._cluster.connect() 

223 # Disable result paging 

224 self._session.default_fetch_size = None 

225 

226 self._schema = ApdbCassandraSchema(session=self._session, 

227 keyspace=self._keyspace, 

228 schema_file=self.config.schema_file, 

229 schema_name=self.config.schema_name, 

230 prefix=self.config.prefix, 

231 time_partition_tables=self.config.time_partition_tables) 

232 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD) 

233 

234 # Cache for prepared statements 

235 self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {} 

236 

237 def __del__(self) -> None: 

238 self._cluster.shutdown() 

239 

240 def tableDef(self, table: ApdbTables) -> Optional[Table]: 

241 # docstring is inherited from a base class 

242 return self._schema.tableSchemas.get(table) 

243 

244 def makeSchema(self, drop: bool = False) -> None: 

245 # docstring is inherited from a base class 

246 

247 if self.config.time_partition_tables: 

248 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI) 

249 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI) 

250 part_range = ( 

251 self._time_partition(time_partition_start), 

252 self._time_partition(time_partition_end) + 1 

253 ) 

254 self._schema.makeSchema(drop=drop, part_range=part_range) 

255 else: 

256 self._schema.makeSchema(drop=drop) 

257 

258 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

259 # docstring is inherited from a base class 

260 

261 sp_where = self._spatial_where(region) 

262 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

263 

264 table_name = self._schema.tableName(ApdbTables.DiaObjectLast) 

265 query = f'SELECT * from "{self._keyspace}"."{table_name}"' 

266 statements: List[Tuple] = [] 

267 for where, params in sp_where: 

268 full_query = f"{query} WHERE {where}" 

269 if params: 

270 statement = self._prep_statement(full_query) 

271 else: 

272 # If there are no params then it is likely that query has a 

273 # bunch of literals rendered already, no point trying to 

274 # prepare it because it's not reusable. 

275 statement = cassandra.query.SimpleStatement(full_query) 

276 statements.append((statement, params)) 

277 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

278 

279 with Timer('DiaObject select', self.config.timer): 

280 objects = cast( 

281 pandas.DataFrame, 

282 select_concurrent( 

283 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

284 ) 

285 ) 

286 

287 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

288 return objects 

289 

290 def getDiaSources(self, region: sphgeom.Region, 

291 object_ids: Optional[Iterable[int]], 

292 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]: 

293 # docstring is inherited from a base class 

294 months = self.config.read_sources_months 

295 if months == 0: 

296 return None 

297 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

298 mjd_start = mjd_end - months*30 

299 

300 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

301 

302 def getDiaForcedSources(self, region: sphgeom.Region, 

303 object_ids: Optional[Iterable[int]], 

304 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]: 

305 # docstring is inherited from a base class 

306 months = self.config.read_forced_sources_months 

307 if months == 0: 

308 return None 

309 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

310 mjd_start = mjd_end - months*30 

311 

312 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

313 

314 def getDiaObjectsHistory(self, 

315 start_time: dafBase.DateTime, 

316 end_time: dafBase.DateTime, 

317 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame: 

318 # docstring is inherited from a base class 

319 

320 sp_where = self._spatial_where(region, use_ranges=True) 

321 tables, temporal_where = self._temporal_where(ApdbTables.DiaObject, start_time, end_time, True) 

322 

323 # Build all queries 

324 statements: List[Tuple] = [] 

325 for table in tables: 

326 prefix = f'SELECT * from "{self._keyspace}"."{table}"' 

327 statements += list(self._combine_where(prefix, sp_where, temporal_where, "ALLOW FILTERING")) 

328 _LOG.debug("getDiaObjectsHistory: #queries: %s", len(statements)) 

329 

330 # Run all selects in parallel 

331 with Timer("DiaObject history", self.config.timer): 

332 catalog = cast( 

333 pandas.DataFrame, 

334 select_concurrent( 

335 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

336 ) 

337 ) 

338 

339 # precise filtering on validityStart 

340 validity_start = start_time.toPython() 

341 validity_end = end_time.toPython() 

342 catalog = cast( 

343 pandas.DataFrame, 

344 catalog[(catalog["validityStart"] >= validity_start) & (catalog["validityStart"] < validity_end)] 

345 ) 

346 

347 _LOG.debug("found %d DiaObjects", catalog.shape[0]) 

348 return catalog 

349 

350 def getDiaSourcesHistory(self, 

351 start_time: dafBase.DateTime, 

352 end_time: dafBase.DateTime, 

353 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame: 

354 # docstring is inherited from a base class 

355 return self._getSourcesHistory(ApdbTables.DiaSource, start_time, end_time, region) 

356 

357 def getDiaForcedSourcesHistory(self, 

358 start_time: dafBase.DateTime, 

359 end_time: dafBase.DateTime, 

360 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame: 

361 # docstring is inherited from a base class 

362 return self._getSourcesHistory(ApdbTables.DiaForcedSource, start_time, end_time, region) 

363 

364 def getSSObjects(self) -> pandas.DataFrame: 

365 # docstring is inherited from a base class 

366 tableName = self._schema.tableName(ApdbTables.SSObject) 

367 query = f'SELECT * from "{self._keyspace}"."{tableName}"' 

368 

369 objects = None 

370 with Timer('SSObject select', self.config.timer): 

371 result = self._session.execute(query, execution_profile="read_pandas") 

372 objects = result._current_rows 

373 

374 _LOG.debug("found %s SSObjects", objects.shape[0]) 

375 return objects 

376 

377 def store(self, 

378 visit_time: dafBase.DateTime, 

379 objects: pandas.DataFrame, 

380 sources: Optional[pandas.DataFrame] = None, 

381 forced_sources: Optional[pandas.DataFrame] = None) -> None: 

382 # docstring is inherited from a base class 

383 

384 # fill region partition column for DiaObjects 

385 objects = self._add_obj_part(objects) 

386 self._storeDiaObjects(objects, visit_time) 

387 

388 if sources is not None: 

389 # copy apdb_part column from DiaObjects to DiaSources 

390 sources = self._add_src_part(sources, objects) 

391 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time) 

392 self._storeDiaSourcesPartitions(sources, visit_time) 

393 

394 if forced_sources is not None: 

395 forced_sources = self._add_fsrc_part(forced_sources, objects) 

396 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time) 

397 

398 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

399 # docstring is inherited from a base class 

400 self._storeObjectsPandas(objects, ApdbTables.SSObject) 

401 

402 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

403 # docstring is inherited from a base class 

404 

405 # To update a record we need to know its exact primary key (including 

406 # partition key) so we start by querying for diaSourceId to find the 

407 # primary keys. 

408 

409 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition) 

410 # split it into 1k IDs per query 

411 selects: List[Tuple] = [] 

412 for ids in chunk_iterable(idMap.keys(), 1_000): 

413 ids_str = ",".join(str(item) for item in ids) 

414 selects.append(( 

415 (f'SELECT "diaSourceId", "apdb_part", "apdb_time_part" FROM "{self._keyspace}"."{table_name}"' 

416 f' WHERE "diaSourceId" IN ({ids_str})'), 

417 {} 

418 )) 

419 

420 # No need for DataFrame here, read data as tuples. 

421 result = cast( 

422 List[Tuple[int, int, int]], 

423 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency) 

424 ) 

425 

426 # Make mapping from source ID to its partition. 

427 id2partitions: Dict[int, Tuple[int, int]] = {} 

428 for row in result: 

429 id2partitions[row[0]] = row[1:] 

430 

431 # make sure we know partitions for each ID 

432 if set(id2partitions) != set(idMap): 

433 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

434 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

435 

436 queries = cassandra.query.BatchStatement() 

437 table_name = self._schema.tableName(ApdbTables.DiaSource) 

438 for diaSourceId, ssObjectId in idMap.items(): 

439 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

440 values: Tuple 

441 if self.config.time_partition_tables: 

442 query = ( 

443 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"' 

444 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

445 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?' 

446 ) 

447 values = (ssObjectId, apdb_part, diaSourceId) 

448 else: 

449 query = ( 

450 f'UPDATE "{self._keyspace}"."{table_name}"' 

451 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

452 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?' 

453 ) 

454 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId) 

455 queries.add(self._prep_statement(query), values) 

456 

457 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

458 with Timer(table_name + ' update', self.config.timer): 

459 self._session.execute(queries, execution_profile="write") 

460 

461 def dailyJob(self) -> None: 

462 # docstring is inherited from a base class 

463 pass 

464 

465 def countUnassociatedObjects(self) -> int: 

466 # docstring is inherited from a base class 

467 

468 # It's too inefficient to implement it for Cassandra in current schema. 

469 raise NotImplementedError() 

470 

471 def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]: 

472 """Make all execution profiles used in the code.""" 

473 

474 if config.private_ips: 

475 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

476 else: 

477 loadBalancePolicy = RoundRobinPolicy() 

478 

479 pandas_row_factory: Callable 

480 if not config.pandas_delay_conv: 

481 pandas_row_factory = pandas_dataframe_factory 

482 else: 

483 pandas_row_factory = raw_data_factory 

484 

485 read_tuples_profile = ExecutionProfile( 

486 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

487 request_timeout=config.read_timeout, 

488 row_factory=cassandra.query.tuple_factory, 

489 load_balancing_policy=loadBalancePolicy, 

490 ) 

491 read_pandas_profile = ExecutionProfile( 

492 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

493 request_timeout=config.read_timeout, 

494 row_factory=pandas_dataframe_factory, 

495 load_balancing_policy=loadBalancePolicy, 

496 ) 

497 read_pandas_multi_profile = ExecutionProfile( 

498 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

499 request_timeout=config.read_timeout, 

500 row_factory=pandas_row_factory, 

501 load_balancing_policy=loadBalancePolicy, 

502 ) 

503 write_profile = ExecutionProfile( 

504 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency), 

505 request_timeout=config.write_timeout, 

506 load_balancing_policy=loadBalancePolicy, 

507 ) 

508 # To replace default DCAwareRoundRobinPolicy 

509 default_profile = ExecutionProfile( 

510 load_balancing_policy=loadBalancePolicy, 

511 ) 

512 return { 

513 "read_tuples": read_tuples_profile, 

514 "read_pandas": read_pandas_profile, 

515 "read_pandas_multi": read_pandas_multi_profile, 

516 "write": write_profile, 

517 EXEC_PROFILE_DEFAULT: default_profile, 

518 } 
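    # The profile names in the mapping above are referenced throughout this
    # class via ``execution_profile=...``: "read_pandas" in getSSObjects,
    # "read_tuples" and "read_pandas_multi" in the concurrent read paths
    # (select_concurrent), and "write" in the store/update methods.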

519 

520 def _getSources(self, region: sphgeom.Region, 

521 object_ids: Optional[Iterable[int]], 

522 mjd_start: float, 

523 mjd_end: float, 

524 table_name: ApdbTables) -> pandas.DataFrame: 

525 """Returns catalog of DiaSource instances given set of DiaObject IDs. 

526 

527 Parameters 

528 ---------- 

529 region : `lsst.sphgeom.Region` 

530 Spherical region. 

531 object_ids : iterable [`int`], or `None` 

532 Collection of DiaObject IDs. If `None` then sources are not filtered by object ID. 

533 mjd_start : `float` 

534 Lower bound of time interval. 

535 mjd_end : `float` 

536 Upper bound of time interval. 

537 table_name : `ApdbTables` 

538 Name of the table. 

539 

540 Returns 

541 ------- 

542 catalog : `pandas.DataFrame` 

543 Catalog containing DiaSource records. Empty catalog is returned if 

544 ``object_ids`` is empty. 

545 """ 

546 object_id_set: Set[int] = set() 

547 if object_ids is not None: 

548 object_id_set = set(object_ids) 

549 if len(object_id_set) == 0: 

550 return self._make_empty_catalog(table_name) 

551 

552 sp_where = self._spatial_where(region) 

553 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end) 

554 

555 # Build all queries 

556 statements: List[Tuple] = [] 

557 for table in tables: 

558 prefix = f'SELECT * from "{self._keyspace}"."{table}"' 

559 statements += list(self._combine_where(prefix, sp_where, temporal_where)) 

560 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

561 

562 with Timer(table_name.name + ' select', self.config.timer): 

563 catalog = cast( 

564 pandas.DataFrame, 

565 select_concurrent( 

566 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

567 ) 

568 ) 

569 

570 # filter by given object IDs 

571 if len(object_id_set) > 0: 

572 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

573 

574 # precise filtering on midPointTai 

575 catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_start]) 

576 

577 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

578 return catalog 

579 

580 def _getSourcesHistory( 

581 self, 

582 table: ApdbTables, 

583 start_time: dafBase.DateTime, 

584 end_time: dafBase.DateTime, 

585 region: Optional[sphgeom.Region] = None, 

586 ) -> pandas.DataFrame: 

587 """Returns catalog of DiaSource instances given set of DiaObject IDs. 

588 

589 Parameters 

590 ---------- 

591 table : `ApdbTables` 

592 Name of the table. 

593 start_time : `dafBase.DateTime` 

594 Starting time for DiaSource history search. DiaSource record is 

595 selected when its ``midPointTai`` falls into an interval between 

596 ``start_time`` (inclusive) and ``end_time`` (exclusive). 

597 end_time : `dafBase.DateTime` 

598 Upper limit on time for DiaSource history search. 

599 region : `lsst.sphgeom.Region` 

600 Spherical region. 

601 

602 Returns 

603 ------- 

604 catalog : `pandas.DataFrame` 

605 Catalog containing DiaSource records. 

606 """ 

607 sp_where = self._spatial_where(region, use_ranges=False) 

608 tables, temporal_where = self._temporal_where(table, start_time, end_time, True) 

609 

610 # Build all queries 

611 statements: List[Tuple] = [] 

612 for table_name in tables: 

613 prefix = f'SELECT * from "{self._keyspace}"."{table_name}"' 

614 statements += list(self._combine_where(prefix, sp_where, temporal_where, "ALLOW FILTERING")) 

615 _LOG.debug("_getSourcesHistory: #queries: %s", len(statements)) 

616 

617 # Run all selects in parallel 

618 with Timer(f"{table.name} history", self.config.timer): 

619 catalog = cast( 

620 pandas.DataFrame, 

621 select_concurrent( 

622 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

623 ) 

624 ) 

625 

626 # precise filtering on validityStart 

627 period_start = start_time.get(system=dafBase.DateTime.MJD) 

628 period_end = end_time.get(system=dafBase.DateTime.MJD) 

629 catalog = cast( 

630 pandas.DataFrame, 

631 catalog[(catalog["midPointTai"] >= period_start) & (catalog["midPointTai"] < period_end)] 

632 ) 

633 

634 _LOG.debug("found %d %ss", catalog.shape[0], table.name) 

635 return catalog 

636 

637 def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None: 

638 """Store catalog of DiaObjects from current visit. 

639 

640 Parameters 

641 ---------- 

642 objs : `pandas.DataFrame` 

643 Catalog with DiaObject records 

644 visit_time : `lsst.daf.base.DateTime` 

645 Time of the current visit. 

646 """ 

647 visit_time_dt = visit_time.toPython() 

648 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

649 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

650 

651 extra_columns["validityStart"] = visit_time_dt 

652 time_part: Optional[int] = self._time_partition(visit_time) 

653 if not self.config.time_partition_tables: 

654 extra_columns["apdb_time_part"] = time_part 

655 time_part = None 

656 

657 self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part) 

658 

659 def _storeDiaSources(self, table_name: ApdbTables, sources: pandas.DataFrame, 

660 visit_time: dafBase.DateTime) -> None: 

661 """Store catalog of DIASources or DIAForcedSources from current visit. 

662 

663 Parameters 

664 ---------- 

665 sources : `pandas.DataFrame` 

666 Catalog containing DiaSource records 

667 visit_time : `lsst.daf.base.DateTime` 

668 Time of the current visit. 

669 """ 

670 time_part: Optional[int] = self._time_partition(visit_time) 

671 extra_columns = {} 

672 if not self.config.time_partition_tables: 

673 extra_columns["apdb_time_part"] = time_part 

674 time_part = None 

675 

676 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part) 

677 

678 def _storeDiaSourcesPartitions(self, sources: pandas.DataFrame, visit_time: dafBase.DateTime) -> None: 

679 """Store mapping of diaSourceId to its partitioning values. 

680 

681 Parameters 

682 ---------- 

683 sources : `pandas.DataFrame` 

684 Catalog containing DiaSource records 

685 visit_time : `lsst.daf.base.DateTime` 

686 Time of the current visit. 

687 """ 

688 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

689 extra_columns = { 

690 "apdb_time_part": self._time_partition(visit_time), 

691 } 

692 

693 self._storeObjectsPandas( 

694 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

695 ) 

696 

697 def _storeObjectsPandas(self, objects: pandas.DataFrame, table_name: Union[ApdbTables, ExtraTables], 

698 extra_columns: Optional[Mapping] = None, 

699 time_part: Optional[int] = None) -> None: 

700 """Generic store method. 

701 

702 Takes a catalog of records and stores them in the specified table. 

703 

704 Parameters 

705 ---------- 

706 objects : `pandas.DataFrame` 

707 Catalog containing object records 

708 table_name : `ApdbTables` 

709 Name of the table as defined in APDB schema. 

710 extra_columns : `dict`, optional 

711 Mapping (column_name, column_value) which gives column values to add 

712 to every row, only if column is missing in catalog records. 

713 time_part : `int`, optional 

714 If not `None` then insert into a per-partition table. 

715 """ 

716 # use extra columns if specified 

717 if extra_columns is None: 

718 extra_columns = {} 

719 extra_fields = list(extra_columns.keys()) 

720 

721 df_fields = [ 

722 column for column in objects.columns if column not in extra_fields 

723 ] 

724 

725 column_map = self._schema.getColumnMap(table_name) 

726 # list of columns (as in cat schema) 

727 fields = [column_map[field].name for field in df_fields if field in column_map] 

728 fields += extra_fields 

729 

730 # check that all partitioning and clustering columns are defined 

731 required_columns = self._schema.partitionColumns(table_name) \ 

732 + self._schema.clusteringColumns(table_name) 

733 missing_columns = [column for column in required_columns if column not in fields] 

734 if missing_columns: 

735 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

736 

737 qfields = [quote_id(field) for field in fields] 

738 qfields_str = ','.join(qfields) 

739 

740 with Timer(table_name.name + ' query build', self.config.timer): 

741 

742 table = self._schema.tableName(table_name) 

743 if time_part is not None: 

744 table = f"{table}_{time_part}" 

745 

746 holders = ','.join(['?']*len(qfields)) 

747 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})' 

748 statement = self._prep_statement(query) 

749 queries = cassandra.query.BatchStatement() 

750 for rec in objects.itertuples(index=False): 

751 values = [] 

752 for field in df_fields: 

753 if field not in column_map: 

754 continue 

755 value = getattr(rec, field) 

756 if column_map[field].datatype is felis.types.Timestamp: 

757 if isinstance(value, pandas.Timestamp): 

758 value = literal(value.to_pydatetime()) 

759 else: 

760 # Assume it's seconds since epoch, Cassandra 

761 # datetime is in milliseconds 

762 value = int(value*1000) 

763 values.append(literal(value)) 

764 for field in extra_fields: 

765 value = extra_columns[field] 

766 values.append(literal(value)) 

767 queries.add(statement, values) 

768 

769 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), objects.shape[0]) 

770 with Timer(table_name.name + ' insert', self.config.timer): 

771 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write") 
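    # Illustrative shape of the batched statement built above (keyspace, table
    # and column names are examples only, not taken from a real schema):
    #   INSERT INTO "apdb"."DiaObjectLast" ("diaObjectId","ra","decl","apdb_part") VALUES (?,?,?,?)
    # with one list of per-row values added to the batch for each catalog row.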

772 

773 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

774 """Calculate spacial partition for each record and add it to a 

775 DataFrame. 

776 

777 Notes 

778 ----- 

779 This overrides any existing column in a DataFrame with the same name 

780 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

781 returned. 

782 """ 

783 # calculate spatial pixel index for every DiaObject 

784 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

785 ra_col, dec_col = self.config.ra_dec_columns 

786 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

787 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

788 idx = self._pixelization.pixel(uv3d) 

789 apdb_part[i] = idx 

790 df = df.copy() 

791 df["apdb_part"] = apdb_part 

792 return df 

793 

794 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

795 """Add apdb_part column to DiaSource catalog. 

796 

797 Notes 

798 ----- 

799 This method copies apdb_part value from a matching DiaObject record. 

800 DiaObject catalog needs to have an apdb_part column filled by 

801 ``_add_obj_part`` method and DiaSource records need to be 

802 associated to DiaObjects via ``diaObjectId`` column. 

803 

804 This overrides any existing column in a DataFrame with the same name 

805 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

806 returned. 

807 """ 

808 pixel_id_map: Dict[int, int] = { 

809 diaObjectId: apdb_part for diaObjectId, apdb_part 

810 in zip(objs["diaObjectId"], objs["apdb_part"]) 

811 } 

812 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

813 ra_col, dec_col = self.config.ra_dec_columns 

814 for i, (diaObjId, ra, dec) in enumerate(zip(sources["diaObjectId"], 

815 sources[ra_col], sources[dec_col])): 

816 if diaObjId == 0: 

817 # DiaSources associated with SolarSystemObjects do not have an 

818 # associated DiaObject hence we skip them and set partition 

819 # based on its own ra/dec 

820 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

821 idx = self._pixelization.pixel(uv3d) 

822 apdb_part[i] = idx 

823 else: 

824 apdb_part[i] = pixel_id_map[diaObjId] 

825 sources = sources.copy() 

826 sources["apdb_part"] = apdb_part 

827 return sources 

828 

829 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

830 """Add apdb_part column to DiaForcedSource catalog. 

831 

832 Notes 

833 ----- 

834 This method copies apdb_part value from a matching DiaObject record. 

835 DiaObject catalog needs to have an apdb_part column filled by 

836 ``_add_obj_part`` method and DiaForcedSource records need to be 

837 associated to DiaObjects via ``diaObjectId`` column. 

838 

839 This overrides any existing column in a DataFrame with the same name 

840 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

841 returned. 

842 """ 

843 pixel_id_map: Dict[int, int] = { 

844 diaObjectId: apdb_part for diaObjectId, apdb_part 

845 in zip(objs["diaObjectId"], objs["apdb_part"]) 

846 } 

847 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

848 for i, diaObjId in enumerate(sources["diaObjectId"]): 

849 apdb_part[i] = pixel_id_map[diaObjId] 

850 sources = sources.copy() 

851 sources["apdb_part"] = apdb_part 

852 return sources 

853 

854 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int: 

855 """Calculate time partiton number for a given time. 

856 

857 Parameters 

858 ---------- 

859 time : `float` or `lsst.daf.base.DateTime` 

860 Time for which to calculate partition number. Can be a float meaning 

861 MJD or `lsst.daf.base.DateTime`. 

862 

863 Returns 

864 ------- 

865 partition : `int` 

866 Partition number for a given time. 

867 """ 

868 if isinstance(time, dafBase.DateTime): 

869 mjd = time.get(system=dafBase.DateTime.MJD) 

870 else: 

871 mjd = time 

872 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

873 partition = int(days_since_epoch) // self.config.time_partition_days 

874 return partition 
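    # Worked example, assuming the default time_partition_days=30: the
    # partition-zero epoch 1970-01-01 TAI is MJD 40587, so a time of
    # MJD 59580 maps to partition int(59580 - 40587) // 30 == 633.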

875 

876 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

877 """Make an empty catalog for a table with a given name. 

878 

879 Parameters 

880 ---------- 

881 table_name : `ApdbTables` 

882 Name of the table. 

883 

884 Returns 

885 ------- 

886 catalog : `pandas.DataFrame` 

887 An empty catalog. 

888 """ 

889 table = self._schema.tableSchemas[table_name] 

890 

891 data = { 

892 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype)) 

893 for columnDef in table.columns 

894 } 

895 return pandas.DataFrame(data) 

896 

897 def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement: 

898 """Convert query string into prepared statement.""" 

899 stmt = self._prepared_statements.get(query) 

900 if stmt is None: 

901 stmt = self._session.prepare(query) 

902 self._prepared_statements[query] = stmt 

903 return stmt 

904 

905 def _combine_where( 

906 self, 

907 prefix: str, 

908 where1: List[Tuple[str, Tuple]], 

909 where2: List[Tuple[str, Tuple]], 

910 suffix: Optional[str] = None, 

911 ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]: 

912 """Make cartesian product of two parts of WHERE clause into a series 

913 of statements to execute. 

914 

915 Parameters 

916 ---------- 

917 prefix : `str` 

918 Initial statement prefix that comes before WHERE clause, e.g. 

919 "SELECT * from Table" 

920 """ 

921 # If lists are empty use special sentinels. 

922 if not where1: 

923 where1 = [("", ())] 

924 if not where2: 

925 where2 = [("", ())] 

926 

927 for expr1, params1 in where1: 

928 for expr2, params2 in where2: 

929 full_query = prefix 

930 wheres = [] 

931 if expr1: 

932 wheres.append(expr1) 

933 if expr2: 

934 wheres.append(expr2) 

935 if wheres: 

936 full_query += " WHERE " + " AND ".join(wheres) 

937 if suffix: 

938 full_query += " " + suffix 

939 params = params1 + params2 

940 if params: 

941 statement = self._prep_statement(full_query) 

942 else: 

943 # If there are no params then it is likely that query 

944 # has a bunch of literals rendered already, no point 

945 # trying to prepare it. 

946 statement = cassandra.query.SimpleStatement(full_query) 

947 yield (statement, params) 
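    # Example of the cartesian product above: two spatial expressions
    # ('"apdb_part" = ?' with params (1,) and (2,)) combined with one temporal
    # expression ('"apdb_time_part" IN (630,631)', no params) yield two
    # statements of the form
    #   SELECT * from "apdb"."DiaSource" WHERE "apdb_part" = ? AND "apdb_time_part" IN (630,631)
    # bound to (1,) and (2,) respectively; the table and partition values here
    # are hypothetical.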

948 

949 def _spatial_where( 

950 self, region: Optional[sphgeom.Region], use_ranges: bool = False 

951 ) -> List[Tuple[str, Tuple]]: 

952 """Generate expressions for spatial part of WHERE clause. 

953 

954 Parameters 

955 ---------- 

956 region : `sphgeom.Region` 

957 Spatial region for query results. 

958 use_ranges : `bool` 

959 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <= 

960 p2") instead of exact list of pixels. Should be set to True for 

961 large regions covering very many pixels. 

962 

963 Returns 

964 ------- 

965 expressions : `list` [ `tuple` ] 

966 Empty list is returned if ``region`` is `None`, otherwise a list 

967 of one or more (expression, parameters) tuples 

968 """ 

969 if region is None: 

970 return [] 

971 if use_ranges: 

972 pixel_ranges = self._pixelization.envelope(region) 

973 expressions: List[Tuple[str, Tuple]] = [] 

974 for lower, upper in pixel_ranges: 

975 upper -= 1 

976 if lower == upper: 

977 expressions.append(('"apdb_part" = ?', (lower, ))) 

978 else: 

979 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper))) 

980 return expressions 

981 else: 

982 pixels = self._pixelization.pixels(region) 

983 if self.config.query_per_spatial_part: 

984 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels] 

985 else: 

986 pixels_str = ",".join([str(pix) for pix in pixels]) 

987 return [(f'"apdb_part" IN ({pixels_str})', ())] 
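    # For example (hypothetical pixel indices): with query_per_spatial_part
    # False this returns [('"apdb_part" IN (12345,12346)', ())], and with it
    # True it returns [('"apdb_part" = ?', (12345,)), ('"apdb_part" = ?', (12346,))].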

988 

989 def _temporal_where( 

990 self, 

991 table: ApdbTables, 

992 start_time: Union[float, dafBase.DateTime], 

993 end_time: Union[float, dafBase.DateTime], 

994 query_per_time_part: Optional[bool] = None, 

995 ) -> Tuple[List[str], List[Tuple[str, Tuple]]]: 

996 """Generate table names and expressions for temporal part of WHERE 

997 clauses. 

998 

999 Parameters 

1000 ---------- 

1001 table : `ApdbTables` 

1002 Table to select from. 

1003 start_time : `dafBase.DateTime` or `float` 

1004 Starting DateTime or MJD value of the time range. 

1005 end_time : `dafBase.DateTime` or `float` 

1006 Ending DateTime or MJD value of the time range. 

1007 query_per_time_part : `bool`, optional 

1008 If None then use ``query_per_time_part`` from configuration. 

1009 

1010 Returns 

1011 ------- 

1012 tables : `list` [ `str` ] 

1013 List of the table names to query. 

1014 expressions : `list` [ `tuple` ] 

1015 A list of zero or more (expression, parameters) tuples. 

1016 """ 

1017 tables: List[str] 

1018 temporal_where: List[Tuple[str, Tuple]] = [] 

1019 table_name = self._schema.tableName(table) 

1020 time_part_start = self._time_partition(start_time) 

1021 time_part_end = self._time_partition(end_time) 

1022 time_parts = list(range(time_part_start, time_part_end + 1)) 

1023 if self.config.time_partition_tables: 

1024 tables = [f"{table_name}_{part}" for part in time_parts] 

1025 else: 

1026 tables = [table_name] 

1027 if query_per_time_part is None: 

1028 query_per_time_part = self.config.query_per_time_part 

1029 if query_per_time_part: 

1030 temporal_where = [ 

1031 ('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts 

1032 ] 

1033 else: 

1034 time_part_list = ",".join([str(part) for part in time_parts]) 

1035 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())] 

1036 

1037 return tables, temporal_where