Coverage for python/lsst/dax/apdb/apdbCassandra.py: 17%

415 statements  

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26from datetime import datetime, timedelta 

27import logging 

28import numpy as np 

29import pandas 

30from typing import cast, Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union 

31 

32try: 

33 import cbor 

34except ImportError: 

35 cbor = None 

36 

37# If cassandra-driver is not there the module can still be imported 

38# but ApdbCassandra cannot be instantiated. 

39try: 

40 import cassandra 

41 from cassandra.cluster import Cluster 

42 from cassandra.concurrent import execute_concurrent 

43 from cassandra.policies import RoundRobinPolicy, WhiteListRoundRobinPolicy, AddressTranslator 

44 import cassandra.query 

45 CASSANDRA_IMPORTED = True 

46except ImportError: 

47 CASSANDRA_IMPORTED = False 

48 

49import lsst.daf.base as dafBase 

50from lsst.pex.config import ChoiceField, Field, ListField 

51from lsst import sphgeom 

52from .timer import Timer 

53from .apdb import Apdb, ApdbConfig 

54from .apdbSchema import ApdbTables, ColumnDef, TableDef 

55from .apdbCassandraSchema import ApdbCassandraSchema 

56 

57 

58_LOG = logging.getLogger(__name__) 

59 

60 

61class CassandraMissingError(Exception): 

62 def __init__(self) -> None: 

63 super().__init__("cassandra-driver module cannot be imported") 

64 

65 

66class ApdbCassandraConfig(ApdbConfig): 

67 

68 contact_points = ListField( 

69 dtype=str, 

70 doc="The list of contact points to try connecting for cluster discovery.", 

71 default=["127.0.0.1"] 

72 ) 

73 private_ips = ListField( 

74 dtype=str, 

75 doc="List of internal IP addresses for contact_points.", 

76 default=[] 

77 ) 

78 keyspace = Field( 

79 dtype=str, 

80 doc="Default keyspace for operations.", 

81 default="apdb" 

82 ) 

83 read_consistency = Field( 

84 dtype=str, 

85 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", 

86 default="QUORUM" 

87 ) 

88 write_consistency = Field( 

89 dtype=str, 

90 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", 

91 default="QUORUM" 

92 ) 

93 read_timeout = Field( 

94 dtype=float, 

95 doc="Timeout in seconds for read operations.", 

96 default=120. 

97 ) 

98 write_timeout = Field( 

99 dtype=float, 

100 doc="Timeout in seconds for write operations.", 

101 default=10. 

102 ) 

103 read_concurrency = Field( 

104 dtype=int, 

105 doc="Concurrency level for read operations.", 

106 default=500 

107 ) 

108 protocol_version = Field( 

109 dtype=int, 

110 doc="Cassandra protocol version to use, default is V4", 

111 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0 

112 ) 

113 dia_object_columns = ListField( 

114 dtype=str, 

115 doc="List of columns to read from DiaObject, by default read all columns", 

116 default=[] 

117 ) 

118 prefix = Field( 

119 dtype=str, 

120 doc="Prefix to add to table names", 

121 default="" 

122 ) 

123 part_pixelization = ChoiceField( 

124 dtype=str, 

125 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

126 doc="Pixelization used for partitioning index.", 

127 default="mq3c" 

128 ) 

129 part_pix_level = Field( 

130 dtype=int, 

131 doc="Pixelization level used for partitioning index.", 

132 default=10 

133 ) 

134 ra_dec_columns = ListField( 

135 dtype=str, 

136 default=["ra", "decl"], 

137 doc="Names of ra/dec columns in DiaObject table" 

138 ) 

139 timer = Field( 

140 dtype=bool, 

141 doc="If True then print/log timing information", 

142 default=False 

143 ) 

144 time_partition_tables = Field( 

145 dtype=bool, 

146 doc="Use per-partition tables for sources instead of partitioning by time", 

147 default=True 

148 ) 

149 time_partition_days = Field( 

150 dtype=int, 

151 doc="Time partitioning granularity in days, this value must not be changed" 

152 " after the database is initialized", 

153 default=30 

154 ) 

155 time_partition_start = Field( 

156 dtype=str, 

157 doc="Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI." 

158 " This is used only when time_partition_tables is True.", 

159 default="2018-12-01T00:00:00" 

160 ) 

161 time_partition_end = Field( 

162 dtype=str, 

163 doc="Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI." 

164 " This is used only when time_partition_tables is True.", 

165 default="2030-01-01T00:00:00" 

166 ) 

167 query_per_time_part = Field( 

168 dtype=bool, 

169 default=False, 

170 doc="If True then build a separate query for each time partition, otherwise build a single query. " 

171 "This is only used when time_partition_tables is False in schema config." 

172 ) 

173 query_per_spatial_part = Field( 

174 dtype=bool, 

175 default=False, 

176 doc="If True then build one query per spatial partition, otherwise build a single query. " 

177 ) 

178 pandas_delay_conv = Field( 

179 dtype=bool, 

180 default=True, 

181 doc="If True then combine result rows before converting to pandas. " 

182 ) 

183 packing = ChoiceField( 

184 dtype=str, 

185 allowed=dict(none="No field packing", cbor="Pack using CBOR"), 

186 doc="Packing method for table records.", 

187 default="none" 

188 ) 

189 prepared_statements = Field( 

190 dtype=bool, 

191 default=True, 

192 doc="If True use Cassandra prepared statements." 

193 ) 

194 

195 
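The fields above are standard pex_config fields and are set by plain attribute assignment. A minimal, illustrative sketch of overriding a few of them follows; it is not part of the original module, the import path is inferred from the file path in the header, and the host addresses are placeholders.

from lsst.dax.apdb.apdbCassandra import ApdbCassandraConfig

config = ApdbCassandraConfig()
config.contact_points = ["10.0.0.1", "10.0.0.2"]  # Cassandra cluster entry points
config.keyspace = "apdb"                          # keyspace must already exist
config.part_pixelization = "mq3c"                 # spatial partitioning scheme
config.part_pix_level = 10                        # pixelization level
config.time_partition_tables = True               # one source table per time partition
config.time_partition_days = 30                   # fixed once the database is initialized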

196class Partitioner: 

197 """Class that calculates indices of the objects for partitioning. 

198 

199 Used internally by `ApdbCassandra` 

200 

201 Parameters 

202 ---------- 

203 config : `ApdbCassandraConfig` 

204 """ 

205 def __init__(self, config: ApdbCassandraConfig): 

206 pix = config.part_pixelization 

207 if pix == "htm": 

208 self.pixelator = sphgeom.HtmPixelization(config.part_pix_level) 

209 elif pix == "q3c": 

210 self.pixelator = sphgeom.Q3cPixelization(config.part_pix_level) 

211 elif pix == "mq3c": 

212 self.pixelator = sphgeom.Mq3cPixelization(config.part_pix_level) 

213 else: 

214 raise ValueError(f"unknown pixelization: {pix}") 

215 

216 def pixels(self, region: sphgeom.Region) -> List[int]: 

217 """Compute set of the pixel indices for given region. 

218 

219 Parameters 

220 ---------- 

221 region : `lsst.sphgeom.Region` 

222 """ 

223 # we want the finest set of pixels, so ask for as many pixels as possible 

224 ranges = self.pixelator.envelope(region, 1_000_000) 

225 indices = [] 

226 for lower, upper in ranges: 

227 indices += list(range(lower, upper)) 

228 return indices 

229 

230 def pixel(self, direction: sphgeom.UnitVector3d) -> int: 

231 """Compute the index of the pixel for given direction. 

232 

233 Parameters 

234 ---------- 

235 direction : `lsst.sphgeom.UnitVector3d` 

236 """ 

237 index = self.pixelator.index(direction) 

238 return index 

239 

240 
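For illustration, the same pixelization calls can be exercised directly with lsst.sphgeom. This sketch is not part of the original module; the coordinates, radius and level are arbitrary values chosen for the example.

from lsst import sphgeom

pixelator = sphgeom.Mq3cPixelization(10)
center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(1.0))

# envelope() returns ranges of pixel indices covering the region; expand them
# into individual indices, exactly as Partitioner.pixels() does above.
indices = []
for lower, upper in pixelator.envelope(region, 1_000_000):
    indices += list(range(lower, upper))

# A single direction maps to exactly one partition pixel (Partitioner.pixel()).
print(pixelator.index(center), len(indices))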

241if CASSANDRA_IMPORTED: 

242 

243 class _AddressTranslator(AddressTranslator): 

244 """Translate internal IP address to external. 

245 

246 Only used for docker-based setup, not viable long-term solution. 

247 """ 

248 def __init__(self, public_ips: List[str], private_ips: List[str]): 

249 self._map = dict((k, v) for k, v in zip(private_ips, public_ips)) 

250 

251 def translate(self, private_ip: str) -> str: 

252 return self._map.get(private_ip, private_ip) 

253 

254 

255def _rows_to_pandas(colnames: List[str], rows: List[Tuple], 

256 packedColumns: List[ColumnDef]) -> pandas.DataFrame: 

257 """Convert result rows to pandas. 

258 

259 Unpacks BLOBs that were packed on insert. 

260 

261 Parameters 

262 ---------- 

263 colnames : `list` [ `str` ] 

264 Names of the columns. 

265 rows : `list` of `tuple` 

266 Result rows. 

267 packedColumns : `list` [ `ColumnDef` ] 

268 Column definitions for packed columns. 

269 

270 Returns 

271 ------- 

272 catalog : `pandas.DataFrame` 

273 DataFrame with the result set. 

274 """ 

275 try: 

276 idx = colnames.index("apdb_packed") 

277 except ValueError: 

278 # no packed columns 

279 return pandas.DataFrame.from_records(rows, columns=colnames) 

280 

281 # make data frame for non-packed columns 

282 df = pandas.DataFrame.from_records(rows, columns=colnames, exclude=["apdb_packed"]) 

283 

284 # make records with packed data only as dicts 

285 packed_rows = [] 

286 for row in rows: 

287 blob = row[idx] 

288 if blob[:5] == b"cbor:": 

289 blob = cbor.loads(blob[5:]) 

290 else: 

291 raise ValueError(f"Unexpected BLOB format: {blob!r}") 

292 packed_rows.append(blob) 

293 

294 # make a data frame from the packed data 

295 packed = pandas.DataFrame.from_records(packed_rows, columns=[col.name for col in packedColumns]) 

296 

297 # convert timestamps which are integer milliseconds into datetime 

298 for col in packedColumns: 

299 if col.type == "DATETIME": 

300 packed[col.name] = pandas.to_datetime(packed[col.name], unit="ms", origin="unix") 

301 

302 return pandas.concat([df, packed], axis=1) 

303 

304 
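The packing convention handled here is a plain CBOR blob with a b"cbor:" prefix: _storeObjectsPandas() below packs non-key columns into a single "apdb_packed" BLOB, and _rows_to_pandas() strips the prefix and decodes the rest. A minimal round-trip sketch (not part of the original module), assuming the optional cbor package is installed; the column names are placeholders.

import cbor

packed = b"cbor:" + cbor.dumps({"psFlux": 1.5, "psFluxErr": 0.1})
assert packed[:5] == b"cbor:"
record = cbor.loads(packed[5:])   # -> {'psFlux': 1.5, 'psFluxErr': 0.1}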

305class _PandasRowFactory: 

306 """Create pandas DataFrame from Cassandra result set. 

307 

308 Parameters 

309 ---------- 

310 packedColumns : `list` [ `ColumnDef` ] 

311 Column definitions for packed columns. 

312 """ 

313 def __init__(self, packedColumns: Iterable[ColumnDef]): 

314 self.packedColumns = list(packedColumns) 

315 

316 def __call__(self, colnames: List[str], rows: List[Tuple]) -> pandas.DataFrame: 

317 """Convert result set into output catalog. 

318 

319 Parameters 

320 ---------- 

321 colnames : `list` [ `str` ] 

322 Names of the columns. 

323 rows : `list` of `tuple` 

324 Result rows 

325 

326 Returns 

327 ------- 

328 catalog : `pandas.DataFrame` 

329 DataFrame with the result set. 

330 """ 

331 return _rows_to_pandas(colnames, rows, self.packedColumns) 

332 

333 

334class _RawRowFactory: 

335 """Row factory that makes no conversions. 

336 

337 Parameters 

338 ---------- 

339 packedColumns : `list` [ `ColumnDef` ] 

340 Column definitions for packed columns. 

341 """ 

342 def __init__(self, packedColumns: Iterable[ColumnDef]): 

343 self.packedColumns = list(packedColumns) 

344 

345 def __call__(self, colnames: List[str], rows: List[Tuple]) -> Tuple[List[str], List[Tuple]]: 

346 """Return parameters without change. 

347 

348 Parameters 

349 ---------- 

350 colnames : `list` of `str` 

351 Names of the columns. 

352 rows : `list` of `tuple` 

353 Result rows 

354 

355 Returns 

356 ------- 

357 colnames : `list` of `str` 

358 Names of the columns. 

359 rows : `list` of `tuple` 

360 Result rows 

361 """ 

362 return (colnames, rows) 

363 

364 

365class ApdbCassandra(Apdb): 

366 """Implementation of APDB database on top of Apache Cassandra. 

367 

368 The implementation is configured via the standard ``pex_config`` mechanism 

369 using the `ApdbCassandraConfig` configuration class. For examples of 

370 different configurations check the config/ folder. 

371 

372 Parameters 

373 ---------- 

374 config : `ApdbCassandraConfig` 

375 Configuration object. 

376 """ 

377 

378 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI) 

379 """Start time for partition 0, this should never be changed.""" 

380 

381 def __init__(self, config: ApdbCassandraConfig): 

382 

383 if not CASSANDRA_IMPORTED: 

384 raise CassandraMissingError() 

385 

386 self.config = config 

387 

388 _LOG.debug("ApdbCassandra Configuration:") 

389 _LOG.debug(" read_consistency: %s", self.config.read_consistency) 

390 _LOG.debug(" write_consistency: %s", self.config.write_consistency) 

391 _LOG.debug(" read_sources_months: %s", self.config.read_sources_months) 

392 _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months) 

393 _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns) 

394 _LOG.debug(" schema_file: %s", self.config.schema_file) 

395 _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file) 

396 _LOG.debug(" schema prefix: %s", self.config.prefix) 

397 _LOG.debug(" part_pixelization: %s", self.config.part_pixelization) 

398 _LOG.debug(" part_pix_level: %s", self.config.part_pix_level) 

399 _LOG.debug(" query_per_time_part: %s", self.config.query_per_time_part) 

400 _LOG.debug(" query_per_spatial_part: %s", self.config.query_per_spatial_part) 

401 

402 self._partitioner = Partitioner(config) 

403 

404 addressTranslator: Optional[AddressTranslator] = None 

405 if config.private_ips: 

406 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

407 addressTranslator = _AddressTranslator(config.contact_points, config.private_ips) 

408 else: 

409 loadBalancePolicy = RoundRobinPolicy() 

410 

411 self._read_consistency = getattr(cassandra.ConsistencyLevel, config.read_consistency) 

412 self._write_consistency = getattr(cassandra.ConsistencyLevel, config.write_consistency) 

413 

414 self._cluster = Cluster(contact_points=self.config.contact_points, 

415 load_balancing_policy=loadBalancePolicy, 

416 address_translator=addressTranslator, 

417 protocol_version=self.config.protocol_version) 

418 self._session = self._cluster.connect(keyspace=config.keyspace) 

419 self._session.row_factory = cassandra.query.named_tuple_factory 

420 

421 self._schema = ApdbCassandraSchema(session=self._session, 

422 schema_file=self.config.schema_file, 

423 extra_schema_file=self.config.extra_schema_file, 

424 prefix=self.config.prefix, 

425 packing=self.config.packing, 

426 time_partition_tables=self.config.time_partition_tables) 

427 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD) 

428 

429 def tableDef(self, table: ApdbTables) -> Optional[TableDef]: 

430 # docstring is inherited from a base class 

431 return self._schema.tableSchemas.get(table) 

432 

433 def makeSchema(self, drop: bool = False) -> None: 

434 # docstring is inherited from a base class 

435 

436 if self.config.time_partition_tables: 

437 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI) 

438 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI) 

439 part_range = ( 

440 self._time_partition(time_partition_start), 

441 self._time_partition(time_partition_end) + 1 

442 ) 

443 self._schema.makeSchema(drop=drop, part_range=part_range) 

444 else: 

445 self._schema.makeSchema(drop=drop) 

446 

447 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

448 # docstring is inherited from a base class 

449 packedColumns = self._schema.packedColumns(ApdbTables.DiaObjectLast) 

450 self._session.row_factory = _PandasRowFactory(packedColumns) 

451 self._session.default_fetch_size = None 

452 

453 pixels = self._partitioner.pixels(region) 

454 _LOG.debug("getDiaObjects: #partitions: %s", len(pixels)) 

455 pixels_str = ",".join([str(pix) for pix in pixels]) 

456 

457 queries: List[Tuple] = [] 

458 query = f'SELECT * from "DiaObjectLast" WHERE "apdb_part" IN ({pixels_str})' 

459 queries += [(cassandra.query.SimpleStatement(query, consistency_level=self._read_consistency), {})] 

460 _LOG.debug("getDiaObjects: #queries: %s", len(queries)) 

461 # _LOG.debug("getDiaObjects: queries: %s", queries) 

462 

463 objects = None 

464 with Timer('DiaObject select', self.config.timer): 

465 # submit all queries 

466 futures = [self._session.execute_async(query, values, timeout=self.config.read_timeout) 

467 for query, values in queries] 

468 # TODO: This orders result processing which is not very efficient 

469 dataframes = [future.result()._current_rows for future in futures] 

470 # concatenate all frames 

471 if len(dataframes) == 1: 

472 objects = dataframes[0] 

473 else: 

474 objects = pandas.concat(dataframes) 

475 

476 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

477 return objects 

478 

479 def getDiaSources(self, region: sphgeom.Region, 

480 object_ids: Optional[Iterable[int]], 

481 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]: 

482 # docstring is inherited from a base class 

483 return self._getSources(region, object_ids, visit_time, ApdbTables.DiaSource, 

484 self.config.read_sources_months) 

485 

486 def getDiaForcedSources(self, region: sphgeom.Region, 

487 object_ids: Optional[Iterable[int]], 

488 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]: 

489 return self._getSources(region, object_ids, visit_time, ApdbTables.DiaForcedSource, 

490 self.config.read_forced_sources_months) 

491 

492 def _getSources(self, region: sphgeom.Region, 

493 object_ids: Optional[Iterable[int]], 

494 visit_time: dafBase.DateTime, 

495 table_name: ApdbTables, 

496 months: int) -> Optional[pandas.DataFrame]: 

497 """Returns catalog of DiaSource instances given set of DiaObject IDs. 

498 

499 Parameters 

500 ---------- 

501 region : `lsst.sphgeom.Region` 

502 Spherical region. 

503 object_ids : iterable of `int`, or `None` 

504 Collection of DiaObject IDs 

505 visit_time : `lsst.daf.base.DateTime` 

506 Time of the current visit 

507 table_name : `ApdbTables` 

508 Name of the table, either "DiaSource" or "DiaForcedSource" 

509 months : `int` 

510 Number of months of history to return; if negative, the whole 

511 history is returned (note: negative values do not work in the 

512 table-per-partition case) 

513 

514 Returns 

515 ------- 

516 catalog : `pandas.DataFrame`, or `None` 

517 Catalog containing DiaSource records. `None` is returned if 

518 ``months`` is 0 or when ``object_ids`` is empty. 

519 """ 

520 if months == 0: 

521 return None 

522 object_id_set: Set[int] = set() 

523 if object_ids is not None: 

524 object_id_set = set(object_ids) 

525 if len(object_id_set) == 0: 

526 return self._make_empty_catalog(table_name) 

527 

528 packedColumns = self._schema.packedColumns(table_name) 

529 if self.config.pandas_delay_conv: 

530 self._session.row_factory = _RawRowFactory(packedColumns) 

531 else: 

532 self._session.row_factory = _PandasRowFactory(packedColumns) 

533 self._session.default_fetch_size = None 

534 

535 # spatial pixels included in the query 

536 pixels = self._partitioner.pixels(region) 

537 _LOG.debug("_getSources: %s #partitions: %s", table_name.name, len(pixels)) 

538 

539 # spatial part of WHERE 

540 spatial_where = [] 

541 if self.config.query_per_spatial_part: 

542 spatial_where = [f'"apdb_part" = {pixel}' for pixel in pixels] 

543 else: 

544 pixels_str = ",".join([str(pix) for pix in pixels]) 

545 spatial_where = [f'"apdb_part" IN ({pixels_str})'] 

546 

547 # temporal part of WHERE, can be empty 

548 temporal_where = [] 

549 # time partitions and table names to query, there may be multiple 

550 # tables depending on configuration 

551 full_name = self._schema.tableName(table_name) 

552 tables = [full_name] 

553 mjd_now = visit_time.get(system=dafBase.DateTime.MJD) 

554 mjd_begin = mjd_now - months*30 

555 time_part_now = self._time_partition(mjd_now) 

556 time_part_begin = self._time_partition(mjd_begin) 

557 time_parts = list(range(time_part_begin, time_part_now + 1)) 

558 if self.config.time_partition_tables: 

559 tables = [f"{full_name}_{part}" for part in time_parts] 

560 else: 

561 if self.config.query_per_time_part: 

562 temporal_where = [f'"apdb_time_part" = {time_part}' for time_part in time_parts] 

563 else: 

564 time_part_list = ",".join([str(part) for part in time_parts]) 

565 temporal_where = [f'"apdb_time_part" IN ({time_part_list})'] 

566 

567 # Build all queries 

568 queries: List[str] = [] 

569 for table in tables: 

570 query = f'SELECT * from "{table}" WHERE ' 

571 for spacial in spatial_where: 

572 if temporal_where: 

573 for temporal in temporal_where: 

574 queries.append(query + spacial + " AND " + temporal) 

575 else: 

576 queries.append(query + spacial) 

577 # _LOG.debug("_getSources: queries: %s", queries) 

578 

579 statements: List[Tuple] = [ 

580 (cassandra.query.SimpleStatement(query, consistency_level=self._read_consistency), {}) 

581 for query in queries 

582 ] 

583 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

584 

585 with Timer(table_name.name + ' select', self.config.timer): 

586 # submit all queries 

587 results = execute_concurrent(self._session, statements, results_generator=True, 

588 concurrency=self.config.read_concurrency) 

589 if self.config.pandas_delay_conv: 

590 _LOG.debug("making pandas data frame out of rows/columns") 

591 columns: Any = None 

592 rows = [] 

593 for success, result in results: 

594 result = result._current_rows 

595 if success: 

596 if columns is None: 

597 columns = result[0] 

598 elif columns != result[0]: 

599 _LOG.error("different columns returned by queries: %s and %s", 

600 columns, result[0]) 

601 raise ValueError( 

602 f"different columns returned by queries: {columns} and {result[0]}" 

603 ) 

604 rows += result[1] 

605 else: 

606 _LOG.error("error returned by query: %s", result) 

607 raise result 

608 catalog = _rows_to_pandas(columns, rows, self._schema.packedColumns(table_name)) 

609 _LOG.debug("pandas catalog shape: %s", catalog.shape) 

610 # filter by given object IDs 

611 if len(object_id_set) > 0: 

612 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

613 else: 

614 _LOG.debug("making pandas data frame out of set of data frames") 

615 dataframes = [] 

616 for success, result in results: 

617 if success: 

618 dataframes.append(result._current_rows) 

619 else: 

620 _LOG.error("error returned by query: %s", result) 

621 raise result 

622 # concatenate all frames 

623 if len(dataframes) == 1: 

624 catalog = dataframes[0] 

625 else: 

626 catalog = pandas.concat(dataframes) 

627 _LOG.debug("pandas catalog shape: %s", catalog.shape) 

628 # filter by given object IDs 

629 if len(object_id_set) > 0: 

630 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

631 

632 # precise filtering on midPointTai 

633 catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_begin]) 

634 

635 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

636 return catalog 

637 

638 def getDiaObjectsHistory(self, 

639 start_time: dafBase.DateTime, 

640 end_time: Optional[dafBase.DateTime] = None, 

641 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame: 

642 # docstring is inherited from a base class 

643 raise NotImplementedError() 

644 

645 def getDiaSourcesHistory(self, 

646 start_time: dafBase.DateTime, 

647 end_time: Optional[dafBase.DateTime] = None, 

648 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame: 

649 # docstring is inherited from a base class 

650 raise NotImplementedError() 

651 

652 def getDiaForcedSourcesHistory(self, 

653 start_time: dafBase.DateTime, 

654 end_time: Optional[dafBase.DateTime] = None, 

655 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame: 

656 # docstring is inherited from a base class 

657 raise NotImplementedError() 

658 

659 def getSSObjects(self) -> pandas.DataFrame: 

660 # docstring is inherited from a base class 

661 raise NotImplementedError() 

662 

663 def store(self, 

664 visit_time: dafBase.DateTime, 

665 objects: pandas.DataFrame, 

666 sources: Optional[pandas.DataFrame] = None, 

667 forced_sources: Optional[pandas.DataFrame] = None) -> None: 

668 # docstring is inherited from a base class 

669 

670 # fill region partition column for DiaObjects 

671 objects = self._add_obj_part(objects) 

672 self._storeDiaObjects(objects, visit_time) 

673 

674 if sources is not None: 

675 # copy apdb_part column from DiaObjects to DiaSources 

676 sources = self._add_src_part(sources, objects) 

677 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time) 

678 

679 if forced_sources is not None: 

680 forced_sources = self._add_fsrc_part(forced_sources, objects) 

681 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time) 

682 

683 def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None: 

684 """Store catalog of DiaObjects from current visit. 

685 

686 Parameters 

687 ---------- 

688 objs : `pandas.DataFrame` 

689 Catalog with DiaObject records 

690 visit_time : `lsst.daf.base.DateTime` 

691 Time of the current visit. 

692 """ 

693 visit_time_dt = visit_time.toPython() 

694 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

695 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, visit_time, extra_columns=extra_columns) 

696 

697 extra_columns["validityStart"] = visit_time_dt 

698 time_part: Optional[int] = self._time_partition(visit_time) 

699 if not self.config.time_partition_tables: 

700 extra_columns["apdb_time_part"] = time_part 

701 time_part = None 

702 

703 self._storeObjectsPandas(objs, ApdbTables.DiaObject, visit_time, 

704 extra_columns=extra_columns, time_part=time_part) 

705 

706 def _storeDiaSources(self, table_name: ApdbTables, sources: pandas.DataFrame, 

707 visit_time: dafBase.DateTime) -> None: 

708 """Store catalog of DIASources or DIAForcedSources from current visit. 

709 

710 Parameters 

711 ---------- 

712 sources : `pandas.DataFrame` 

713 Catalog containing DiaSource records 

714 visit_time : `lsst.daf.base.DateTime` 

715 Time of the current visit. 

716 """ 

717 time_part: Optional[int] = self._time_partition(visit_time) 

718 extra_columns = {} 

719 if not self.config.time_partition_tables: 

720 extra_columns["apdb_time_part"] = time_part 

721 time_part = None 

722 

723 self._storeObjectsPandas(sources, table_name, visit_time, 

724 extra_columns=extra_columns, time_part=time_part) 

725 

726 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

727 # docstring is inherited from a base class 

728 raise NotImplementedError() 

729 

730 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

731 # docstring is inherited from a base class 

732 raise NotImplementedError() 

733 

734 def dailyJob(self) -> None: 

735 # docstring is inherited from a base class 

736 pass 

737 

738 def countUnassociatedObjects(self) -> int: 

739 # docstring is inherited from a base class 

740 raise NotImplementedError() 

741 

742 def _storeObjectsPandas(self, objects: pandas.DataFrame, table_name: ApdbTables, 

743 visit_time: dafBase.DateTime, extra_columns: Optional[Mapping] = None, 

744 time_part: Optional[int] = None) -> None: 

745 """Generic store method. 

746 

747 Takes catalog of records and stores a bunch of objects in a table. 

748 

749 Parameters 

750 ---------- 

751 objects : `pandas.DataFrame` 

752 Catalog containing object records 

753 table_name : `ApdbTables` 

754 Name of the table as defined in APDB schema. 

755 visit_time : `lsst.daf.base.DateTime` 

756 Time of the current visit. 

757 extra_columns : `dict`, optional 

758 Mapping (column_name, column_value) which gives column values to add 

759 to every row, only if column is missing in catalog records. 

760 time_part : `int`, optional 

761 If not `None` then insert into a per-partition table. 

762 """ 

763 

764 def qValue(v: Any) -> Any: 

765 """Transform object into a value for query""" 

766 if v is None: 

767 pass 

768 elif isinstance(v, datetime): 

769 v = int((v - datetime(1970, 1, 1)) / timedelta(seconds=1))*1000 

770 elif isinstance(v, (bytes, str)): 

771 pass 

772 else: 

773 try: 

774 if not np.isfinite(v): 

775 v = None 

776 except TypeError: 

777 pass 

778 return v 

779 

780 def quoteId(columnName: str) -> str: 

781 """Smart quoting for column names. 

782 Lower-case names are not quoted. 

783 """ 

784 if not columnName.islower(): 

785 columnName = '"' + columnName + '"' 

786 return columnName 

787 

788 # use extra columns if specified 

789 if extra_columns is None: 

790 extra_columns = {} 

791 extra_fields = list(extra_columns.keys()) 

792 

793 df_fields = [column for column in objects.columns 

794 if column not in extra_fields] 

795 

796 column_map = self._schema.getColumnMap(table_name) 

797 # list of columns (as in cat schema) 

798 fields = [column_map[field].name for field in df_fields if field in column_map] 

799 fields += extra_fields 

800 

801 # check that all partitioning and clustering columns are defined 

802 required_columns = self._schema.partitionColumns(table_name) \ 

803 + self._schema.clusteringColumns(table_name) 

804 missing_columns = [column for column in required_columns if column not in fields] 

805 if missing_columns: 

806 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

807 

808 blob_columns = set(col.name for col in self._schema.packedColumns(table_name)) 

809 # _LOG.debug("blob_columns: %s", blob_columns) 

810 

811 qfields = [quoteId(field) for field in fields if field not in blob_columns] 

812 if blob_columns: 

813 qfields += [quoteId("apdb_packed")] 

814 qfields_str = ','.join(qfields) 

815 

816 with Timer(table_name.name + ' query build', self.config.timer): 

817 

818 table = self._schema.tableName(table_name) 

819 if time_part is not None: 

820 table = f"{table}_{time_part}" 

821 

822 prepared: Optional[cassandra.query.PreparedStatement] = None 

823 if self.config.prepared_statements: 

824 holders = ','.join(['?']*len(qfields)) 

825 query = f'INSERT INTO "{table}" ({qfields_str}) VALUES ({holders})' 

826 prepared = self._session.prepare(query) 

827 queries = cassandra.query.BatchStatement(consistency_level=self._write_consistency) 

828 for rec in objects.itertuples(index=False): 

829 values = [] 

830 blob = {} 

831 for field in df_fields: 

832 if field not in column_map: 

833 continue 

834 value = getattr(rec, field) 

835 if column_map[field].type == "DATETIME": 

836 if isinstance(value, pandas.Timestamp): 

837 value = qValue(value.to_pydatetime()) 

838 else: 

839 # Assume it's seconds since epoch, Cassandra 

840 # datetime is in milliseconds 

841 value = int(value*1000) 

842 if field in blob_columns: 

843 blob[field] = qValue(value) 

844 else: 

845 values.append(qValue(value)) 

846 for field in extra_fields: 

847 value = extra_columns[field] 

848 if field in blob_columns: 

849 blob[field] = qValue(value) 

850 else: 

851 values.append(qValue(value)) 

852 if blob_columns: 

853 if self.config.packing == "cbor": 

854 blob = b"cbor:" + cbor.dumps(blob) 

855 values.append(blob) 

856 holders = ','.join(['%s']*len(values)) 

857 if prepared is not None: 

858 stmt = prepared 

859 else: 

860 query = f'INSERT INTO "{table}" ({qfields_str}) VALUES ({holders})' 

861 # _LOG.debug("query: %r", query) 

862 # _LOG.debug("values: %s", values) 

863 stmt = cassandra.query.SimpleStatement(query, consistency_level=self._write_consistency) 

864 queries.add(stmt, values) 

865 

866 # _LOG.debug("query: %s", query) 

867 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), objects.shape[0]) 

868 with Timer(table_name.name + ' insert', self.config.timer): 

869 self._session.execute(queries, timeout=self.config.write_timeout) 

870 

871 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

872 """Calculate spatial partition for each record and add it to a 

873 DataFrame. 

874 

875 Notes 

876 ----- 

877 This overrides any existing column in a DataFrame with the same name 

878 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

879 returned. 

880 """ 

881 # calculate partition pixel index for every DiaObject 

882 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

883 ra_col, dec_col = self.config.ra_dec_columns 

884 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

885 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

886 idx = self._partitioner.pixel(uv3d) 

887 apdb_part[i] = idx 

888 df = df.copy() 

889 df["apdb_part"] = apdb_part 

890 return df 

891 

892 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

893 """Add apdb_part column to DiaSource catalog. 

894 

895 Notes 

896 ----- 

897 This method copies apdb_part value from a matching DiaObject record. 

898 DiaObject catalog needs to have an apdb_part column filled by 

899 ``_add_obj_part`` method and DiaSource records need to be 

900 associated to DiaObjects via ``diaObjectId`` column. 

901 

902 This overrides any existing column in a DataFrame with the same name 

903 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

904 returned. 

905 """ 

906 pixel_id_map: Dict[int, int] = { 

907 diaObjectId: apdb_part for diaObjectId, apdb_part 

908 in zip(objs["diaObjectId"], objs["apdb_part"]) 

909 } 

910 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

911 ra_col, dec_col = self.config.ra_dec_columns 

912 for i, (diaObjId, ra, dec) in enumerate(zip(sources["diaObjectId"], 

913 sources[ra_col], sources[dec_col])): 

914 if diaObjId == 0: 

915 # DiaSources associated with SolarSystemObjects do not have an 

916 # associated DiaObject, hence we skip the lookup and set the partition 

917 # based on the source's own ra/dec 

918 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

919 idx = self._partitioner.pixel(uv3d) 

920 apdb_part[i] = idx 

921 else: 

922 apdb_part[i] = pixel_id_map[diaObjId] 

923 sources = sources.copy() 

924 sources["apdb_part"] = apdb_part 

925 return sources 

926 

927 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

928 """Add apdb_part column to DiaForcedSource catalog. 

929 

930 Notes 

931 ----- 

932 This method copies apdb_part value from a matching DiaObject record. 

933 DiaObject catalog needs to have an apdb_part column filled by 

934 ``_add_obj_part`` method and DiaForcedSource records need to be 

935 associated to DiaObjects via the ``diaObjectId`` column. 

936 

937 This overrides any existing column in a DataFrame with the same name 

938 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is 

939 returned. 

940 """ 

941 pixel_id_map: Dict[int, int] = { 

942 diaObjectId: apdb_part for diaObjectId, apdb_part 

943 in zip(objs["diaObjectId"], objs["apdb_part"]) 

944 } 

945 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

946 for i, diaObjId in enumerate(sources["diaObjectId"]): 

947 apdb_part[i] = pixel_id_map[diaObjId] 

948 sources = sources.copy() 

949 sources["apdb_part"] = apdb_part 

950 return sources 

951 

952 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int: 

953 """Calculate time partition number for a given time. 

954 

955 Parameters 

956 ---------- 

957 time : `float` or `lsst.daf.base.DateTime` 

958 Time for which to calculate partition number. Can be float to mean 

959 MJD or `lsst.daf.base.DateTime` 

960 

961 Returns 

962 ------- 

963 partition : `int` 

964 Partition number for a given time. 

965 """ 

966 if isinstance(time, dafBase.DateTime): 

967 mjd = time.get(system=dafBase.DateTime.MJD) 

968 else: 

969 mjd = time 

970 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

971 partition = int(days_since_epoch) // self.config.time_partition_days 

972 return partition 

973 
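    # Illustrative worked example (not part of the original file): with the class
    # attribute partition_zero_epoch = 1970-01-01 TAI (MJD 40587) and the default
    # time_partition_days = 30, a visit at MJD 59000 falls in partition
    #     int(59000 - 40587) // 30 = 18413 // 30 = 613
    # so its per-partition source table would be named e.g. "DiaSource_613".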

974 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

975 """Make an empty catalog for a table with a given name. 

976 

977 Parameters 

978 ---------- 

979 table_name : `ApdbTables` 

980 Name of the table. 

981 

982 Returns 

983 ------- 

984 catalog : `pandas.DataFrame` 

985 An empty catalog. 

986 """ 

987 table = self._schema.tableSchemas[table_name] 

988 

989 data = {columnDef.name: pandas.Series(dtype=columnDef.dtype) for columnDef in table.columns} 

990 return pandas.DataFrame(data)
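To round out the listing, a hedged end-to-end sketch of how this class is typically driven; it is not part of the original module, and it assumes a reachable Cassandra cluster whose "apdb" keyspace already exists and catalogs that follow the APDB schema.

if __name__ == "__main__":
    import lsst.daf.base as dafBase
    from lsst import sphgeom

    config = ApdbCassandraConfig()
    config.contact_points = ["127.0.0.1"]

    apdb = ApdbCassandra(config)
    apdb.makeSchema(drop=False)      # create tables on first use

    # Read DiaObjects overlapping a small circular region.
    center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(10.0, 20.0))
    region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.5))
    objects = apdb.getDiaObjects(region)

    # Fetch matching sources for this visit, then store the updated catalogs.
    visit_time = dafBase.DateTime.now()
    sources = apdb.getDiaSources(region, objects["diaObjectId"], visit_time)
    apdb.store(visit_time, objects, sources)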