Coverage for python/lsst/dax/apdb/apdbCassandra.py: 14%

449 statements  

coverage.py v6.5.0, created at 2023-01-19 02:07 -0800

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26import logging 

27import uuid 

28from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union, cast 

29 

30import numpy as np 

31import pandas 

32 

33# If cassandra-driver is not there the module can still be imported 

34# but ApdbCassandra cannot be instantiated. 

35try: 

36 import cassandra 

37 import cassandra.query 

38 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile 

39 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

40 

41 CASSANDRA_IMPORTED = True 

42except ImportError: 

43 CASSANDRA_IMPORTED = False 
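# Note (descriptive comment, inferred from the code below): when cassandra-driver
# is missing, importing this module and using the config class still works, but
# ApdbCassandra.__init__ raises CassandraMissingError.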

44 

45import felis.types 

46import lsst.daf.base as dafBase 

47from felis.simple import Table 

48from lsst import sphgeom 

49from lsst.pex.config import ChoiceField, Field, ListField 

50from lsst.utils.iteration import chunk_iterable 

51 

52from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData 

53from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables 

54from .apdbSchema import ApdbTables 

55from .cassandra_utils import ( 

56 ApdbCassandraTableData, 

57 literal, 

58 pandas_dataframe_factory, 

59 quote_id, 

60 raw_data_factory, 

61 select_concurrent, 

62) 

63from .pixelization import Pixelization 

64from .timer import Timer 

65 

66_LOG = logging.getLogger(__name__) 

67 

68 

69class CassandraMissingError(Exception): 

70 def __init__(self) -> None: 

71 super().__init__("cassandra-driver module cannot be imported") 

72 

73 

74class ApdbCassandraConfig(ApdbConfig): 

75 

76 contact_points = ListField[str]( 

77 doc="The list of contact points to try when connecting for cluster discovery.", default=["127.0.0.1"] 

78 ) 

79 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[]) 

80 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb") 

81 read_consistency = Field[str]( 

82 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM" 

83 ) 

84 write_consistency = Field[str]( 

85 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM" 

86 ) 

87 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0) 

88 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=10.0) 

89 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500) 

90 protocol_version = Field[int]( 

91 doc="Cassandra protocol version to use, default is V4", 

92 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0, 

93 ) 

94 dia_object_columns = ListField[str]( 

95 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[] 

96 ) 

97 prefix = Field[str](doc="Prefix to add to table names", default="") 

98 part_pixelization = ChoiceField[str]( 

99 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

100 doc="Pixelization used for partitioning index.", 

101 default="mq3c", 

102 ) 

103 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10) 

104 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64) 

105 ra_dec_columns = ListField[str](default=["ra", "decl"], doc="Names of ra/dec columns in DiaObject table") 

106 timer = Field[bool](doc="If True then print/log timing information", default=False) 

107 time_partition_tables = Field[bool]( 

108 doc="Use per-partition tables for sources instead of partitioning by time", default=True 

109 ) 

110 time_partition_days = Field[int]( 

111 doc=( 

112 "Time partitioning granularity in days; this value must not be changed after the " 

113 "database is initialized." 

114 ), 

115 default=30, 

116 ) 

117 time_partition_start = Field[str]( 

118 doc=( 

119 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

120 "This is used only when time_partition_tables is True." 

121 ), 

122 default="2018-12-01T00:00:00", 

123 ) 

124 time_partition_end = Field[str]( 

125 doc=( 

126 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. " 

127 "This is used only when time_partition_tables is True." 

128 ), 

129 default="2030-01-01T00:00:00", 

130 ) 

131 query_per_time_part = Field[bool]( 

132 default=False, 

133 doc=( 

134 "If True then build a separate query for each time partition, otherwise build a single query. " 

135 "This is only used when time_partition_tables is False in schema config." 

136 ), 

137 ) 

138 query_per_spatial_part = Field[bool]( 

139 default=False, 

140 doc="If True then build one query per spatial partition, otherwise build a single query.", 

141 ) 

142 

143 

144if CASSANDRA_IMPORTED:  144 ↛ 146 (line 144 didn't jump to line 146, because the condition on line 144 was never true)

145 

146 class _AddressTranslator(AddressTranslator): 

147 """Translate internal IP address to external. 

148 

149 Only used for docker-based setups; not a viable long-term solution. 

150 """ 

151 

152 def __init__(self, public_ips: List[str], private_ips: List[str]): 

153 self._map = dict(zip(private_ips, public_ips)) 

154 

155 def translate(self, private_ip: str) -> str: 

156 return self._map.get(private_ip, private_ip) 

157 

158 

159def _quote_column(name: str) -> str: 

160 """Quote column name""" 

161 if name.islower(): 

162 return name 

163 else: 

164 return f'"{name}"' 
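# For example, _quote_column("ra") returns ra unchanged while
# _quote_column("diaObjectId") returns "diaObjectId" in double quotes;
# Cassandra folds unquoted identifiers to lower case, so mixed-case column
# names have to be quoted.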

165 

166 

167class ApdbCassandra(Apdb): 

168 """Implementation of the APDB database on top of Apache Cassandra. 

169 

170 The implementation is configured via standard ``pex_config`` mechanism 

171 using the `ApdbCassandraConfig` configuration class. For examples of 

172 different configurations check the config/ folder. 

173 

174 Parameters 

175 ---------- 

176 config : `ApdbCassandraConfig` 

177 Configuration object. 

178 """ 
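# A minimal usage sketch (not part of this module; it assumes a reachable
# Cassandra cluster and that base-class options such as schema_file are
# configured appropriately):
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["127.0.0.1"]
#     config.keyspace = "apdb"
#     apdb = ApdbCassandra(config)
#     apdb.makeSchema(drop=False)
#     objects = apdb.getDiaObjects(region)  # region is an lsst.sphgeom.Region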

179 

180 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI) 

181 """Start time for partition 0; this should never be changed.""" 

182 

183 def __init__(self, config: ApdbCassandraConfig): 

184 

185 if not CASSANDRA_IMPORTED: 

186 raise CassandraMissingError() 

187 

188 config.validate() 

189 self.config = config 

190 

191 _LOG.debug("ApdbCassandra Configuration:") 

192 for key, value in self.config.items(): 

193 _LOG.debug(" %s: %s", key, value) 

194 

195 self._pixelization = Pixelization( 

196 config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges 

197 ) 

198 

199 addressTranslator: Optional[AddressTranslator] = None 

200 if config.private_ips: 

201 addressTranslator = _AddressTranslator(config.contact_points, config.private_ips) 

202 

203 self._keyspace = config.keyspace 

204 

205 self._cluster = Cluster( 

206 execution_profiles=self._makeProfiles(config), 

207 contact_points=self.config.contact_points, 

208 address_translator=addressTranslator, 

209 protocol_version=self.config.protocol_version, 

210 ) 

211 self._session = self._cluster.connect() 

212 # Disable result paging 

213 self._session.default_fetch_size = None 
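# Setting default_fetch_size to None disables paging in the driver, so each
# statement returns its full result set at once.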

214 

215 self._schema = ApdbCassandraSchema( 

216 session=self._session, 

217 keyspace=self._keyspace, 

218 schema_file=self.config.schema_file, 

219 schema_name=self.config.schema_name, 

220 prefix=self.config.prefix, 

221 time_partition_tables=self.config.time_partition_tables, 

222 use_insert_id=self.config.use_insert_id, 

223 ) 

224 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD) 

225 

226 # Cache for prepared statements 

227 self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {} 

228 

229 def __del__(self) -> None: 

230 self._cluster.shutdown() 

231 

232 def tableDef(self, table: ApdbTables) -> Optional[Table]: 

233 # docstring is inherited from a base class 

234 return self._schema.tableSchemas.get(table) 

235 

236 def makeSchema(self, drop: bool = False) -> None: 

237 # docstring is inherited from a base class 

238 

239 if self.config.time_partition_tables: 

240 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI) 

241 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI) 
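# part_range is a [start, end) range of time partition indices; the partition
# containing the end time is included by adding 1 below.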

242 part_range = ( 

243 self._time_partition(time_partition_start), 

244 self._time_partition(time_partition_end) + 1, 

245 ) 

246 self._schema.makeSchema(drop=drop, part_range=part_range) 

247 else: 

248 self._schema.makeSchema(drop=drop) 

249 

250 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

251 # docstring is inherited from a base class 

252 

253 sp_where = self._spatial_where(region) 

254 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where)) 

255 

256 # We need to exclude extra partitioning columns from result. 

257 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast) 

258 what = ",".join(_quote_column(column) for column in column_names) 

259 

260 table_name = self._schema.tableName(ApdbTables.DiaObjectLast) 

261 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"' 

262 statements: List[Tuple] = [] 

263 for where, params in sp_where: 

264 full_query = f"{query} WHERE {where}" 

265 if params: 

266 statement = self._prep_statement(full_query) 

267 else: 

268 # If there are no params then it is likely that query has a 

269 # bunch of literals rendered already, no point trying to 

270 # prepare it because it's not reusable. 

271 statement = cassandra.query.SimpleStatement(full_query) 

272 statements.append((statement, params)) 

273 _LOG.debug("getDiaObjects: #queries: %s", len(statements)) 

274 

275 with Timer("DiaObject select", self.config.timer): 

276 objects = cast( 

277 pandas.DataFrame, 

278 select_concurrent( 

279 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

280 ), 

281 ) 

282 

283 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

284 return objects 

285 

286 def getDiaSources( 

287 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime 

288 ) -> Optional[pandas.DataFrame]: 

289 # docstring is inherited from a base class 

290 months = self.config.read_sources_months 

291 if months == 0: 

292 return None 

293 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

294 mjd_start = mjd_end - months * 30 

295 

296 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource) 

297 

298 def getDiaForcedSources( 

299 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime 

300 ) -> Optional[pandas.DataFrame]: 

301 # docstring is inherited from a base class 

302 months = self.config.read_forced_sources_months 

303 if months == 0: 

304 return None 

305 mjd_end = visit_time.get(system=dafBase.DateTime.MJD) 

306 mjd_start = mjd_end - months * 30 

307 

308 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource) 

309 

310 def getInsertIds(self) -> list[ApdbInsertId] | None: 

311 # docstring is inherited from a base class 

312 if not self._schema.has_insert_id: 

313 return None 

314 

315 # everything goes into a single partition 

316 partition = 0 

317 

318 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

319 query = f'SELECT insert_time, insert_id FROM "{self._keyspace}"."{table_name}" WHERE partition = ?' 

320 

321 result = self._session.execute( 

322 self._prep_statement(query), 

323 (partition,), 

324 timeout=self.config.read_timeout, 

325 execution_profile="read_tuples", 

326 ) 

327 # order by insert_time 

328 rows = sorted(result) 

329 return [ApdbInsertId(row[1]) for row in rows] 

330 

331 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None: 

332 # docstring is inherited from a base class 

333 if not self._schema.has_insert_id: 

334 raise ValueError("APDB is not configured for history storage") 

335 

336 insert_ids = [id.id for id in ids] 

337 params = ",".join("?" * len(insert_ids)) 

338 

339 # everything goes into a single partition 

340 partition = 0 

341 

342 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

343 query = ( 

344 f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE partition = ? and insert_id IN ({params})' 

345 ) 

346 

347 self._session.execute( 

348 self._prep_statement(query), 

349 [partition] + insert_ids, 

350 timeout=self.config.write_timeout, 

351 ) 

352 

353 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

354 # docstring is inherited from a base class 

355 return self._get_history(ExtraTables.DiaObjectInsertId, ids) 

356 

357 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

358 # docstring is inherited from a base class 

359 return self._get_history(ExtraTables.DiaSourceInsertId, ids) 

360 

361 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

362 # docstring is inherited from a base class 

363 return self._get_history(ExtraTables.DiaForcedSourceInsertId, ids) 

364 

365 def getSSObjects(self) -> pandas.DataFrame: 

366 # docstring is inherited from a base class 

367 tableName = self._schema.tableName(ApdbTables.SSObject) 

368 query = f'SELECT * from "{self._keyspace}"."{tableName}"' 

369 

370 objects = None 

371 with Timer("SSObject select", self.config.timer): 

372 result = self._session.execute(query, execution_profile="read_pandas") 

373 objects = result._current_rows 

374 

375 _LOG.debug("found %s SSObjects", objects.shape[0]) 

376 return objects 

377 

378 def store( 

379 self, 

380 visit_time: dafBase.DateTime, 

381 objects: pandas.DataFrame, 

382 sources: Optional[pandas.DataFrame] = None, 

383 forced_sources: Optional[pandas.DataFrame] = None, 

384 ) -> None: 

385 # docstring is inherited from a base class 

386 

387 insert_id: ApdbInsertId | None = None 

388 if self._schema.has_insert_id: 

389 insert_id = ApdbInsertId.new_insert_id() 

390 self._storeInsertId(insert_id, visit_time) 

391 

392 # fill region partition column for DiaObjects 

393 objects = self._add_obj_part(objects) 

394 self._storeDiaObjects(objects, visit_time, insert_id) 

395 

396 if sources is not None: 

397 # copy apdb_part column from DiaObjects to DiaSources 

398 sources = self._add_src_part(sources, objects) 

399 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, insert_id) 

400 self._storeDiaSourcesPartitions(sources, visit_time, insert_id) 

401 

402 if forced_sources is not None: 

403 forced_sources = self._add_fsrc_part(forced_sources, objects) 

404 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, insert_id) 

405 

406 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

407 # docstring is inherited from a base class 

408 self._storeObjectsPandas(objects, ApdbTables.SSObject) 

409 

410 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

411 # docstring is inherited from a base class 

412 

413 # To update a record we need to know its exact primary key (including 

414 # partition key) so we start by querying for diaSourceId to find the 

415 # primary keys. 

416 

417 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition) 

418 # split it into 1k IDs per query 

419 selects: List[Tuple] = [] 

420 for ids in chunk_iterable(idMap.keys(), 1_000): 

421 ids_str = ",".join(str(item) for item in ids) 

422 selects.append( 

423 ( 

424 ( 

425 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "insert_id" ' 

426 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})' 

427 ), 

428 {}, 

429 ) 

430 ) 

431 

432 # No need for DataFrame here, read data as tuples. 

433 result = cast( 

434 List[Tuple[int, int, int, uuid.UUID | None]], 

435 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency), 

436 ) 

437 

438 # Make mapping from source ID to its partition. 

439 id2partitions: Dict[int, Tuple[int, int]] = {} 

440 id2insert_id: Dict[int, ApdbInsertId] = {} 

441 for row in result: 

442 id2partitions[row[0]] = row[1:3] 

443 if row[3] is not None: 

444 id2insert_id[row[0]] = ApdbInsertId(row[3]) 

445 

446 # make sure we know partitions for each ID 

447 if set(id2partitions) != set(idMap): 

448 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions)) 

449 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

450 

451 # Reassign in standard tables 

452 queries = cassandra.query.BatchStatement() 

453 table_name = self._schema.tableName(ApdbTables.DiaSource) 

454 for diaSourceId, ssObjectId in idMap.items(): 

455 apdb_part, apdb_time_part = id2partitions[diaSourceId] 

456 values: Tuple 

457 if self.config.time_partition_tables: 

458 query = ( 

459 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"' 

460 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

461 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?' 

462 ) 

463 values = (ssObjectId, apdb_part, diaSourceId) 

464 else: 

465 query = ( 

466 f'UPDATE "{self._keyspace}"."{table_name}"' 

467 ' SET "ssObjectId" = ?, "diaObjectId" = NULL' 

468 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?' 

469 ) 

470 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId) 

471 queries.add(self._prep_statement(query), values) 

472 

473 # Reassign in history tables, only if history is enabled 

474 if id2insert_id: 

475 # Filter out insert ids that have been deleted already. There is a 

476 # potential race with concurrent removal of insert IDs, but it 

477 # should be handled by WHERE in UPDATE. 

478 known_ids = set() 

479 if insert_ids := self.getInsertIds(): 

480 known_ids = set(insert_ids) 

481 id2insert_id = {key: value for key, value in id2insert_id.items() if value in known_ids} 

482 if id2insert_id: 

483 table_name = self._schema.tableName(ExtraTables.DiaSourceInsertId) 

484 for diaSourceId, ssObjectId in idMap.items(): 

485 if insert_id := id2insert_id.get(diaSourceId): 

486 query = ( 

487 f'UPDATE "{self._keyspace}"."{table_name}" ' 

488 ' SET "ssObjectId" = ?, "diaObjectId" = NULL ' 

489 'WHERE "insert_id" = ? AND "diaSourceId" = ?' 

490 ) 

491 values = (ssObjectId, insert_id.id, diaSourceId) 

492 queries.add(self._prep_statement(query), values) 

493 

494 _LOG.debug("%s: will update %d records", table_name, len(idMap)) 

495 with Timer(table_name + " update", self.config.timer): 

496 self._session.execute(queries, execution_profile="write") 

497 

498 def dailyJob(self) -> None: 

499 # docstring is inherited from a base class 

500 pass 

501 

502 def countUnassociatedObjects(self) -> int: 

503 # docstring is inherited from a base class 

504 

505 # It's too inefficient to implement it for Cassandra in the current schema. 

506 raise NotImplementedError() 

507 

508 def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]: 

509 """Make all execution profiles used in the code.""" 

510 

511 if config.private_ips: 

512 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

513 else: 

514 loadBalancePolicy = RoundRobinPolicy() 

515 

516 read_tuples_profile = ExecutionProfile( 

517 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

518 request_timeout=config.read_timeout, 

519 row_factory=cassandra.query.tuple_factory, 

520 load_balancing_policy=loadBalancePolicy, 

521 ) 

522 read_pandas_profile = ExecutionProfile( 

523 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

524 request_timeout=config.read_timeout, 

525 row_factory=pandas_dataframe_factory, 

526 load_balancing_policy=loadBalancePolicy, 

527 ) 

528 read_raw_profile = ExecutionProfile( 

529 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

530 request_timeout=config.read_timeout, 

531 row_factory=raw_data_factory, 

532 load_balancing_policy=loadBalancePolicy, 

533 ) 

534 # Profile to use with select_concurrent to return pandas data frame 

535 read_pandas_multi_profile = ExecutionProfile( 

536 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

537 request_timeout=config.read_timeout, 

538 row_factory=pandas_dataframe_factory, 

539 load_balancing_policy=loadBalancePolicy, 

540 ) 

541 # Profile to use with select_concurrent to return raw data (columns and 

542 # rows) 

543 read_raw_multi_profile = ExecutionProfile( 

544 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency), 

545 request_timeout=config.read_timeout, 

546 row_factory=raw_data_factory, 

547 load_balancing_policy=loadBalancePolicy, 

548 ) 

549 write_profile = ExecutionProfile( 

550 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency), 

551 request_timeout=config.write_timeout, 

552 load_balancing_policy=loadBalancePolicy, 

553 ) 

554 # To replace default DCAwareRoundRobinPolicy 

555 default_profile = ExecutionProfile( 

556 load_balancing_policy=loadBalancePolicy, 

557 ) 

558 return { 

559 "read_tuples": read_tuples_profile, 

560 "read_pandas": read_pandas_profile, 

561 "read_raw": read_raw_profile, 

562 "read_pandas_multi": read_pandas_multi_profile, 

563 "read_raw_multi": read_raw_multi_profile, 

564 "write": write_profile, 

565 EXEC_PROFILE_DEFAULT: default_profile, 

566 } 

567 

568 def _getSources( 

569 self, 

570 region: sphgeom.Region, 

571 object_ids: Optional[Iterable[int]], 

572 mjd_start: float, 

573 mjd_end: float, 

574 table_name: ApdbTables, 

575 ) -> pandas.DataFrame: 

576 """Returns catalog of DiaSource instances given set of DiaObject IDs. 

577 

578 Parameters 

579 ---------- 

580 region : `lsst.sphgeom.Region` 

581 Spherical region. 

582 object_ids : `iterable` [`int`], or `None` 

583 Collection of DiaObject IDs. 

584 mjd_start : `float` 

585 Lower bound of time interval. 

586 mjd_end : `float` 

587 Upper bound of time interval. 

588 table_name : `ApdbTables` 

589 Name of the table. 

590 

591 Returns 

592 ------- 

593 catalog : `pandas.DataFrame` 

594 Catalog containing DiaSource records. An empty catalog is returned 

595 if ``object_ids`` is empty. 

596 """ 

597 object_id_set: Set[int] = set() 

598 if object_ids is not None: 

599 object_id_set = set(object_ids) 

600 if len(object_id_set) == 0: 

601 return self._make_empty_catalog(table_name) 

602 

603 sp_where = self._spatial_where(region) 

604 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end) 

605 

606 # We need to exclude extra partitioning columns from result. 

607 column_names = self._schema.apdbColumnNames(table_name) 

608 what = ",".join(_quote_column(column) for column in column_names) 

609 

610 # Build all queries 

611 statements: List[Tuple] = [] 

612 for table in tables: 

613 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"' 

614 statements += list(self._combine_where(prefix, sp_where, temporal_where)) 

615 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

616 

617 with Timer(table_name.name + " select", self.config.timer): 

618 catalog = cast( 

619 pandas.DataFrame, 

620 select_concurrent( 

621 self._session, statements, "read_pandas_multi", self.config.read_concurrency 

622 ), 

623 ) 

624 

625 # filter by given object IDs 

626 if len(object_id_set) > 0: 

627 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

628 

629 # precise filtering on midPointTai 
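# (the temporal WHERE clause only selects whole time partitions, so rows
# earlier than mjd_start can still come back and must be trimmed here)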

630 catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_start]) 

631 

632 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

633 return catalog 

634 

635 def _get_history(self, table: ExtraTables, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

636 """Return records from a particular table given set of insert IDs.""" 

637 if not self._schema.has_insert_id: 

638 raise ValueError("APDB is not configured for history retrieval") 

639 

640 insert_ids = [id.id for id in ids] 

641 params = ",".join("?" * len(insert_ids)) 

642 

643 table_name = self._schema.tableName(table) 

644 # I know that history table schema has only regular APDB columns plus 

645 # an insert_id column, and this is exactly what we need to return from 

646 # this method, so selecting a star is fine here. 

647 query = f'SELECT * FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})' 

648 statement = self._prep_statement(query) 

649 

650 with Timer(table_name + " history", self.config.timer): 

651 result = self._session.execute(statement, insert_ids, execution_profile="read_raw") 

652 table_data = cast(ApdbCassandraTableData, result._current_rows) 

653 return table_data 

654 

655 def _storeInsertId(self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime) -> None: 

656 

657 # Cassandra timestamp uses milliseconds since epoch 

658 timestamp = visit_time.nsecs() // 1_000_000 

659 

660 # everything goes into a single partition 

661 partition = 0 

662 

663 table_name = self._schema.tableName(ExtraTables.DiaInsertId) 

664 query = ( 

665 f'INSERT INTO "{self._keyspace}"."{table_name}" (partition, insert_id, insert_time) ' 

666 "VALUES (?, ?, ?)" 

667 ) 

668 

669 self._session.execute( 

670 self._prep_statement(query), 

671 (partition, insert_id.id, timestamp), 

672 timeout=self.config.write_timeout, 

673 execution_profile="write", 

674 ) 

675 

676 def _storeDiaObjects( 

677 self, objs: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None 

678 ) -> None: 

679 """Store catalog of DiaObjects from current visit. 

680 

681 Parameters 

682 ---------- 

683 objs : `pandas.DataFrame` 

684 Catalog with DiaObject records 

685 visit_time : `lsst.daf.base.DateTime` 

686 Time of the current visit. 

687 """ 

688 visit_time_dt = visit_time.toPython() 

689 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

690 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns) 

691 

692 extra_columns["validityStart"] = visit_time_dt 

693 time_part: Optional[int] = self._time_partition(visit_time) 

694 if not self.config.time_partition_tables: 

695 extra_columns["apdb_time_part"] = time_part 

696 time_part = None 

697 

698 self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part) 

699 

700 if insert_id is not None: 

701 extra_columns = dict(insert_id=insert_id.id, validityStart=visit_time_dt) 

702 self._storeObjectsPandas(objs, ExtraTables.DiaObjectInsertId, extra_columns=extra_columns) 

703 

704 def _storeDiaSources( 

705 self, 

706 table_name: ApdbTables, 

707 sources: pandas.DataFrame, 

708 visit_time: dafBase.DateTime, 

709 insert_id: ApdbInsertId | None, 

710 ) -> None: 

711 """Store catalog of DIASources or DIAForcedSources from current visit. 

712 

713 Parameters 

714 ---------- 

715 sources : `pandas.DataFrame` 

716 Catalog containing DiaSource records 

717 visit_time : `lsst.daf.base.DateTime` 

718 Time of the current visit. 

719 """ 

720 time_part: Optional[int] = self._time_partition(visit_time) 

721 extra_columns: dict[str, Any] = {} 

722 if not self.config.time_partition_tables: 

723 extra_columns["apdb_time_part"] = time_part 

724 time_part = None 

725 

726 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part) 

727 

728 if insert_id is not None: 

729 extra_columns = dict(insert_id=insert_id.id) 

730 if table_name is ApdbTables.DiaSource: 

731 extra_table = ExtraTables.DiaSourceInsertId 

732 else: 

733 extra_table = ExtraTables.DiaForcedSourceInsertId 

734 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns) 

735 

736 def _storeDiaSourcesPartitions( 

737 self, sources: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None 

738 ) -> None: 

739 """Store mapping of diaSourceId to its partitioning values. 

740 

741 Parameters 

742 ---------- 

743 sources : `pandas.DataFrame` 

744 Catalog containing DiaSource records 

745 visit_time : `lsst.daf.base.DateTime` 

746 Time of the current visit. 

747 """ 

748 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]]) 

749 extra_columns = { 

750 "apdb_time_part": self._time_partition(visit_time), 

751 "insert_id": insert_id.id if insert_id is not None else None, 

752 } 

753 

754 self._storeObjectsPandas( 

755 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None 

756 ) 

757 

758 def _storeObjectsPandas( 

759 self, 

760 records: pandas.DataFrame, 

761 table_name: Union[ApdbTables, ExtraTables], 

762 extra_columns: Optional[Mapping] = None, 

763 time_part: Optional[int] = None, 

764 ) -> None: 

765 """Generic store method. 

766 

767 Takes a Pandas catalog and stores its records in a table. 

768 

769 Parameters 

770 ---------- 

771 records : `pandas.DataFrame` 

772 Catalog containing object records 

773 table_name : `ApdbTables` 

774 Name of the table as defined in APDB schema. 

775 extra_columns : `dict`, optional 

776 Mapping (column_name, column_value) which gives fixed values for 

777 columns in each row, overrides values in ``records`` if matching 

778 columns exist there. 

779 time_part : `int`, optional 

780 If not `None` then insert into a per-partition table. 

781 

782 Notes 

783 ----- 

784 If the Pandas catalog contains additional columns not defined in the 

785 table schema, they are ignored. The catalog does not have to contain 

786 all columns defined in the table, but partition and clustering keys 

787 must be present either in the catalog or in ``extra_columns``. 

788 """ 

789 # use extra columns if specified 

790 if extra_columns is None: 

791 extra_columns = {} 

792 extra_fields = list(extra_columns.keys()) 

793 

794 # Fields that will come from dataframe. 

795 df_fields = [column for column in records.columns if column not in extra_fields] 

796 

797 column_map = self._schema.getColumnMap(table_name) 

798 # list of columns (as in felis schema) 

799 fields = [column_map[field].name for field in df_fields if field in column_map] 

800 fields += extra_fields 

801 

802 # check that all partitioning and clustering columns are defined 

803 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns( 

804 table_name 

805 ) 

806 missing_columns = [column for column in required_columns if column not in fields] 

807 if missing_columns: 

808 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

809 

810 qfields = [quote_id(field) for field in fields] 

811 qfields_str = ",".join(qfields) 

812 

813 with Timer(table_name.name + " query build", self.config.timer): 

814 

815 table = self._schema.tableName(table_name) 

816 if time_part is not None: 

817 table = f"{table}_{time_part}" 

818 

819 holders = ",".join(["?"] * len(qfields)) 

820 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})' 

821 statement = self._prep_statement(query) 

822 queries = cassandra.query.BatchStatement() 

823 for rec in records.itertuples(index=False): 

824 values = [] 

825 for field in df_fields: 

826 if field not in column_map: 

827 continue 

828 value = getattr(rec, field) 

829 if column_map[field].datatype is felis.types.Timestamp: 

830 if isinstance(value, pandas.Timestamp): 

831 value = literal(value.to_pydatetime()) 

832 else: 

833 # Assume it's seconds since epoch, Cassandra 

834 # datetime is in milliseconds 

835 value = int(value * 1000) 

836 values.append(literal(value)) 

837 for field in extra_fields: 

838 value = extra_columns[field] 

839 values.append(literal(value)) 

840 queries.add(statement, values) 

841 

842 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0]) 

843 with Timer(table_name.name + " insert", self.config.timer): 

844 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write") 

845 

846 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

847 """Calculate spatial partition for each record and add it to a 

848 DataFrame. 

849 

850 Notes 

851 ----- 

852 This overrides any existing column in a DataFrame with the same name 

853 (apdb_part). The original DataFrame is not changed; a copy of the 

854 DataFrame is returned. 

855 """ 

856 # calculate pixel index for every DiaObject using the configured pixelization 

857 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

858 ra_col, dec_col = self.config.ra_dec_columns 

859 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

860 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

861 idx = self._pixelization.pixel(uv3d) 

862 apdb_part[i] = idx 

863 df = df.copy() 

864 df["apdb_part"] = apdb_part 

865 return df 

866 

867 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

868 """Add apdb_part column to DiaSource catalog. 

869 

870 Notes 

871 ----- 

872 This method copies apdb_part value from a matching DiaObject record. 

873 DiaObject catalog needs to have an apdb_part column filled by the 

874 ``_add_obj_part`` method and DiaSource records need to be 

875 associated with DiaObjects via the ``diaObjectId`` column. 

876 

877 This overrides any existing column in a DataFrame with the same name 

878 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

879 returned. 

880 """ 

881 pixel_id_map: Dict[int, int] = { 

882 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

883 } 

884 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

885 ra_col, dec_col = self.config.ra_dec_columns 

886 for i, (diaObjId, ra, dec) in enumerate( 

887 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col]) 

888 ): 

889 if diaObjId == 0: 

890 # DiaSources associated with SolarSystemObjects do not have an 

891 # associated DiaObject hence we skip them and set partition 

892 # based on its own ra/dec 

893 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

894 idx = self._pixelization.pixel(uv3d) 

895 apdb_part[i] = idx 

896 else: 

897 apdb_part[i] = pixel_id_map[diaObjId] 

898 sources = sources.copy() 

899 sources["apdb_part"] = apdb_part 

900 return sources 

901 

902 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

903 """Add apdb_part column to DiaForcedSource catalog. 

904 

905 Notes 

906 ----- 

907 This method copies apdb_part value from a matching DiaObject record. 

908 DiaObject catalog needs to have an apdb_part column filled by the 

909 ``_add_obj_part`` method and DiaForcedSource records need to be 

910 associated with DiaObjects via the ``diaObjectId`` column. 

911 

912 This overrides any existing column in a DataFrame with the same name 

913 (apdb_part). The original DataFrame is not changed; a copy of the 

914 DataFrame is returned. 

915 """ 

916 pixel_id_map: Dict[int, int] = { 

917 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"]) 

918 } 

919 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

920 for i, diaObjId in enumerate(sources["diaObjectId"]): 

921 apdb_part[i] = pixel_id_map[diaObjId] 

922 sources = sources.copy() 

923 sources["apdb_part"] = apdb_part 

924 return sources 

925 

926 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int: 

927 """Calculate time partition number for a given time. 

928 

929 Parameters 

930 ---------- 

931 time : `float` or `lsst.daf.base.DateTime` 

932 Time for which to calculate partition number, either an MJD value 

933 (`float`) or `lsst.daf.base.DateTime`. 

934 

935 Returns 

936 ------- 

937 partition : `int` 

938 Partition number for a given time. 

939 """ 

940 if isinstance(time, dafBase.DateTime): 

941 mjd = time.get(system=dafBase.DateTime.MJD) 

942 else: 

943 mjd = time 

944 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

945 partition = int(days_since_epoch) // self.config.time_partition_days 

946 return partition 
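# Worked example, assuming the default time_partition_days=30: MJD 58453
# (2018-12-01) gives (58453 - 40587) // 30 = 595, i.e. partition 595
# (40587 is the MJD of the 1970-01-01 epoch used as partition zero).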

947 

948 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

949 """Make an empty catalog for a table with a given name. 

950 

951 Parameters 

952 ---------- 

953 table_name : `ApdbTables` 

954 Name of the table. 

955 

956 Returns 

957 ------- 

958 catalog : `pandas.DataFrame` 

959 An empty catalog. 

960 """ 

961 table = self._schema.tableSchemas[table_name] 

962 

963 data = { 

964 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype)) 

965 for columnDef in table.columns 

966 } 

967 return pandas.DataFrame(data) 

968 

969 def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement: 

970 """Convert query string into prepared statement.""" 

971 stmt = self._prepared_statements.get(query) 

972 if stmt is None: 

973 stmt = self._session.prepare(query) 

974 self._prepared_statements[query] = stmt 

975 return stmt 
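# Prepared statements are cached by query text, so repeated calls with the
# same query string reuse the statement instead of re-preparing it.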

976 

977 def _combine_where( 

978 self, 

979 prefix: str, 

980 where1: List[Tuple[str, Tuple]], 

981 where2: List[Tuple[str, Tuple]], 

982 suffix: Optional[str] = None, 

983 ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]: 

984 """Make the Cartesian product of two parts of a WHERE clause into a 

985 series of statements to execute. 

986 

987 Parameters 

988 ---------- 

989 prefix : `str` 

990 Initial statement prefix that comes before WHERE clause, e.g. 

991 "SELECT * from Table" 

992 """ 

993 # If lists are empty use special sentinels. 

994 if not where1: 

995 where1 = [("", ())] 

996 if not where2: 

997 where2 = [("", ())] 

998 

999 for expr1, params1 in where1: 

1000 for expr2, params2 in where2: 

1001 full_query = prefix 

1002 wheres = [] 

1003 if expr1: 

1004 wheres.append(expr1) 

1005 if expr2: 

1006 wheres.append(expr2) 

1007 if wheres: 

1008 full_query += " WHERE " + " AND ".join(wheres) 

1009 if suffix: 

1010 full_query += " " + suffix 

1011 params = params1 + params2 

1012 if params: 

1013 statement = self._prep_statement(full_query) 

1014 else: 

1015 # If there are no params then it is likely that query 

1016 # has a bunch of literals rendered already, no point 

1017 # trying to prepare it. 

1018 statement = cassandra.query.SimpleStatement(full_query) 

1019 yield (statement, params) 
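# For example, with where1=[('"apdb_part" = ?', (1,))] and
# where2=[('"apdb_time_part" = ?', (595,))] this yields one prepared statement
# '<prefix> WHERE "apdb_part" = ? AND "apdb_time_part" = ?' with params (1, 595).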

1020 

1021 def _spatial_where( 

1022 self, region: Optional[sphgeom.Region], use_ranges: bool = False 

1023 ) -> List[Tuple[str, Tuple]]: 

1024 """Generate expressions for spatial part of WHERE clause. 

1025 

1026 Parameters 

1027 ---------- 

1028 region : `sphgeom.Region` 

1029 Spatial region for query results. 

1030 use_ranges : `bool` 

1031 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <= 

1032 p2") instead of exact list of pixels. Should be set to True for 

1033 large regions covering very many pixels. 

1034 

1035 Returns 

1036 ------- 

1037 expressions : `list` [ `tuple` ] 

1038 Empty list is returned if ``region`` is `None`, otherwise a list 

1039 of one or more (expression, parameters) tuples 

1040 """ 

1041 if region is None: 

1042 return [] 

1043 if use_ranges: 

1044 pixel_ranges = self._pixelization.envelope(region) 

1045 expressions: List[Tuple[str, Tuple]] = [] 

1046 for lower, upper in pixel_ranges: 
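# the envelope ranges are half-open [lower, upper); convert the upper bound
# to an inclusive value for the "<=" comparison below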

1047 upper -= 1 

1048 if lower == upper: 

1049 expressions.append(('"apdb_part" = ?', (lower,))) 

1050 else: 

1051 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper))) 

1052 return expressions 

1053 else: 

1054 pixels = self._pixelization.pixels(region) 

1055 if self.config.query_per_spatial_part: 

1056 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels] 

1057 else: 

1058 pixels_str = ",".join([str(pix) for pix in pixels]) 

1059 return [(f'"apdb_part" IN ({pixels_str})', ())] 
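# Example result with query_per_spatial_part=False and pixels [1234, 1235]:
# [('"apdb_part" IN (1234,1235)', ())] -- a single expression with the pixel
# list rendered inline and no bound parameters.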

1060 

1061 def _temporal_where( 

1062 self, 

1063 table: ApdbTables, 

1064 start_time: Union[float, dafBase.DateTime], 

1065 end_time: Union[float, dafBase.DateTime], 

1066 query_per_time_part: Optional[bool] = None, 

1067 ) -> Tuple[List[str], List[Tuple[str, Tuple]]]: 

1068 """Generate table names and expressions for temporal part of WHERE 

1069 clauses. 

1070 

1071 Parameters 

1072 ---------- 

1073 table : `ApdbTables` 

1074 Table to select from. 

1075 start_time : `dafBase.DateTime` or `float` 

1076 Starting DateTime or MJD value of the time range. 

1077 end_time : `dafBase.DateTime` or `float` 

1078 Ending DateTime or MJD value of the time range. 

1079 query_per_time_part : `bool`, optional 

1080 If None then use ``query_per_time_part`` from configuration. 

1081 

1082 Returns 

1083 ------- 

1084 tables : `list` [ `str` ] 

1085 List of the table names to query. 

1086 expressions : `list` [ `tuple` ] 

1087 A list of zero or more (expression, parameters) tuples. 

1088 """ 

1089 tables: List[str] 

1090 temporal_where: List[Tuple[str, Tuple]] = [] 

1091 table_name = self._schema.tableName(table) 

1092 time_part_start = self._time_partition(start_time) 

1093 time_part_end = self._time_partition(end_time) 

1094 time_parts = list(range(time_part_start, time_part_end + 1)) 

1095 if self.config.time_partition_tables: 

1096 tables = [f"{table_name}_{part}" for part in time_parts] 

1097 else: 

1098 tables = [table_name] 

1099 if query_per_time_part is None: 

1100 query_per_time_part = self.config.query_per_time_part 

1101 if query_per_time_part: 

1102 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts] 

1103 else: 

1104 time_part_list = ",".join([str(part) for part in time_parts]) 

1105 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())] 

1106 

1107 return tables, temporal_where