Coverage for python/lsst/dax/apdb/apdbCassandra.py: 15%


403 statements  

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"] 

25 

26from datetime import datetime, timedelta 

27import logging 

28import numpy as np 

29import pandas 

30from typing import cast, Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union 

31 

32try: 

33 import cbor 

34except ImportError: 

35 cbor = None 

36 

37# If cassandra-driver is not installed the module can still be imported 

38# but ApdbCassandra cannot be instantiated. 

39try: 

40 import cassandra 

41 from cassandra.cluster import Cluster 

42 from cassandra.concurrent import execute_concurrent 

43 from cassandra.policies import RoundRobinPolicy, WhiteListRoundRobinPolicy, AddressTranslator 

44 import cassandra.query 

45 CASSANDRA_IMPORTED = True 

46except ImportError: 

47 CASSANDRA_IMPORTED = False 

48 

49import lsst.daf.base as dafBase 

50from lsst.pex.config import ChoiceField, Field, ListField 

51from lsst import sphgeom 

52from .timer import Timer 

53from .apdb import Apdb, ApdbConfig 

54from .apdbSchema import ApdbTables, ColumnDef, TableDef 

55from .apdbCassandraSchema import ApdbCassandraSchema 

56 

57 

58_LOG = logging.getLogger(__name__) 

59 

60 

61class CassandraMissingError(Exception): 

62 def __init__(self) -> None: 

63 super().__init__("cassandra-driver module cannot be imported") 

64 

65 

66class ApdbCassandraConfig(ApdbConfig): 

67 

68 contact_points = ListField( 

69 dtype=str, 

70 doc="The list of contact points to try connecting for cluster discovery.", 

71 default=["127.0.0.1"] 

72 ) 

73 private_ips = ListField( 

74 dtype=str, 

75 doc="List of internal IP addresses for contact_points.", 

76 default=[] 

77 ) 

78 keyspace = Field( 

79 dtype=str, 

80 doc="Default keyspace for operations.", 

81 default="apdb" 

82 ) 

83 read_consistency = Field( 

84 dtype=str, 

85 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", 

86 default="QUORUM" 

87 ) 

88 write_consistency = Field( 

89 dtype=str, 

90 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", 

91 default="QUORUM" 

92 ) 

93 read_timeout = Field( 

94 dtype=float, 

95 doc="Timeout in seconds for read operations.", 

96 default=120. 

97 ) 

98 write_timeout = Field( 

99 dtype=float, 

100 doc="Timeout in seconds for write operations.", 

101 default=10. 

102 ) 

103 read_concurrency = Field( 

104 dtype=int, 

105 doc="Concurrency level for read operations.", 

106 default=500 

107 ) 

108 protocol_version = Field( 

109 dtype=int, 

110 doc="Cassandra protocol version to use, default is V4", 

111 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0 

112 ) 

113 dia_object_columns = ListField( 

114 dtype=str, 

115 doc="List of columns to read from DiaObject, by default read all columns", 

116 default=[] 

117 ) 

118 prefix = Field( 

119 dtype=str, 

120 doc="Prefix to add to table names", 

121 default="" 

122 ) 

123 part_pixelization = ChoiceField( 

124 dtype=str, 

125 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"), 

126 doc="Pixelization used for partitioning index.", 

127 default="mq3c" 

128 ) 

129 part_pix_level = Field( 

130 dtype=int, 

131 doc="Pixelization level used for partitioning index.", 

132 default=10 

133 ) 

134 ra_dec_columns = ListField( 

135 dtype=str, 

136 default=["ra", "decl"], 

137 doc="Names ra/dec columns in DiaObject table" 

138 ) 

139 timer = Field( 

140 dtype=bool, 

141 doc="If True then print/log timing information", 

142 default=False 

143 ) 

144 time_partition_tables = Field( 

145 dtype=bool, 

146 doc="Use per-partition tables for sources instead of partitioning by time", 

147 default=True 

148 ) 

149 time_partition_days = Field( 

150 dtype=int, 

151 doc="Time partitoning granularity in days, this value must not be changed" 

152 " after database is initialized", 

153 default=30 

154 ) 

155 time_partition_start = Field( 

156 dtype=str, 

157 doc="Starting time for per-partion tables, in yyyy-mm-ddThh:mm:ss format, in TAI." 

158 " This is used only when time_partition_tables is True.", 

159 default="2018-12-01T00:00:00" 

160 ) 

161 time_partition_end = Field( 

162 dtype=str, 

163 doc="Ending time for per-partion tables, in yyyy-mm-ddThh:mm:ss format, in TAI" 

164 " This is used only when time_partition_tables is True.", 

165 default="2030-01-01T00:00:00" 

166 ) 

167 query_per_time_part = Field( 

168 dtype=bool, 

169 default=False, 

170 doc="If True then build separate query for each time partition, otherwise build one single query. " 

171 "This is only used when time_partition_tables is False in schema config." 

172 ) 

173 query_per_spatial_part = Field( 

174 dtype=bool, 

175 default=False, 

176 doc="If True then build one query per spacial partition, otherwise build single query. " 

177 ) 

178 pandas_delay_conv = Field( 

179 dtype=bool, 

180 default=True, 

181 doc="If True then combine result rows before converting to pandas. " 

182 ) 

183 packing = ChoiceField( 

184 dtype=str, 

185 allowed=dict(none="No field packing", cbor="Pack using CBOR"), 

186 doc="Packing method for table records.", 

187 default="none" 

188 ) 

189 prepared_statements = Field( 

190 dtype=bool, 

191 default=True, 

192 doc="If True use Cassandra prepared statements." 

193 ) 

194 
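# Illustrative sketch (not part of the original module): how this configuration
# class might be filled in before constructing `ApdbCassandra`. The host name
# below is a placeholder, not a value shipped with this package.
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["cassandra-host.example.org"]
#     config.keyspace = "apdb"
#     config.part_pixelization = "mq3c"
#     config.part_pix_level = 10
#     config.time_partition_tables = True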

195 

196class Partitioner: 

197 """Class that calculates indices of the objects for partitioning. 

198 

199 Used internally by `ApdbCassandra` 

200 

201 Parameters 

202 ---------- 

203 config : `ApdbCassandraConfig` 

204 """ 

205 def __init__(self, config: ApdbCassandraConfig): 

206 pix = config.part_pixelization 

207 if pix == "htm": 

208 self.pixelator = sphgeom.HtmPixelization(config.part_pix_level) 

209 elif pix == "q3c": 

210 self.pixelator = sphgeom.Q3cPixelization(config.part_pix_level) 

211 elif pix == "mq3c": 

212 self.pixelator = sphgeom.Mq3cPixelization(config.part_pix_level) 

213 else: 

214 raise ValueError(f"unknown pixelization: {pix}") 

215 

216 def pixels(self, region: sphgeom.Region) -> List[int]: 

217 """Compute set of the pixel indices for given region. 

218 

219 Parameters 

220 ---------- 

221 region : `lsst.sphgeom.Region` 

222 """ 

223 # we want the finest set of pixels, so ask for as many pixels as possible 

224 ranges = self.pixelator.envelope(region, 1_000_000) 

225 indices = [] 

226 for lower, upper in ranges: 

227 indices += list(range(lower, upper)) 

228 return indices 

229 

230 def pixel(self, direction: sphgeom.UnitVector3d) -> int: 

231 """Compute the index of the pixel for given direction. 

232 

233 Parameters 

234 ---------- 

235 direction : `lsst.sphgeom.UnitVector3d` 

236 """ 

237 index = self.pixelator.index(direction) 

238 return index 

239 
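# Illustrative sketch (assuming a default ApdbCassandraConfig): computing the
# spatial partition index for a single direction and for a small circular
# region with `Partitioner`. The coordinates are arbitrary example values.
#
#     part = Partitioner(ApdbCassandraConfig())
#     uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
#     pixel = part.pixel(uv3d)       # single pixel index for this direction
#     region = sphgeom.Circle(uv3d, sphgeom.Angle.fromDegrees(0.1))
#     pixels = part.pixels(region)   # list of pixel indices covering the region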

240 

241if CASSANDRA_IMPORTED:  # coverage: branch to line 243 never taken (condition was never true)

242 

243 class _AddressTranslator(AddressTranslator): 

244 """Translate internal IP address to external. 

245 

246 Only used for docker-based setup, not viable long-term solution. 

247 """ 

248 def __init__(self, public_ips: List[str], private_ips: List[str]): 

249 self._map = dict((k, v) for k, v in zip(private_ips, public_ips)) 

250 

251 def translate(self, private_ip: str) -> str: 

252 return self._map.get(private_ip, private_ip) 

253 

254 

255def _rows_to_pandas(colnames: List[str], rows: List[Tuple], 

256 packedColumns: List[ColumnDef]) -> pandas.DataFrame: 

257 """Convert result rows to pandas. 

258 

259 Unpacks BLOBs that were packed on insert. 

260 

261 Parameters 

262 ---------- 

263 colnames : `list` [ `str` ] 

264 Names of the columns. 

265 rows : `list` of `tuple` 

266 Result rows. 

267 packedColumns : `list` [ `ColumnDef` ] 

268 Column definitions for packed columns. 

269 

270 Returns 

271 ------- 

272 catalog : `pandas.DataFrame` 

273 DataFrame with the result set. 

274 """ 

275 try: 

276 idx = colnames.index("apdb_packed") 

277 except ValueError: 

278 # no packed columns 

279 return pandas.DataFrame.from_records(rows, columns=colnames) 

280 

281 # make data frame for non-packed columns 

282 df = pandas.DataFrame.from_records(rows, columns=colnames, exclude=["apdb_packed"]) 

283 

284 # make records with packed data only as dicts 

285 packed_rows = [] 

286 for row in rows: 

287 blob = row[idx] 

288 if blob[:5] == b"cbor:": 

289 blob = cbor.loads(blob[5:]) 

290 else: 

291 raise ValueError(f"Unexpected BLOB format: {blob!r}") 

292 packed_rows.append(blob) 

293 

294 # make data frame from packed data 

295 packed = pandas.DataFrame.from_records(packed_rows, columns=[col.name for col in packedColumns]) 

296 

297 # convert timestamps which are integer milliseconds into datetime 

298 for col in packedColumns: 

299 if col.type == "DATETIME": 

300 packed[col.name] = pandas.to_datetime(packed[col.name], unit="ms", origin="unix") 

301 

302 return pandas.concat([df, packed], axis=1) 

303 
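# Illustrative sketch of the packed-BLOB layout that _rows_to_pandas expects,
# assuming the optional `cbor` module is available; the column names are
# made-up examples. A packed cell is a CBOR-encoded {column: value} dict with
# a b"cbor:" prefix, and unpacking strips the prefix and decodes the rest.
#
#     packed = b"cbor:" + cbor.dumps({"exampleFlux": 1.5, "exampleFluxErr": 0.2})
#     assert packed[:5] == b"cbor:"
#     record = cbor.loads(packed[5:])   # -> {"exampleFlux": 1.5, "exampleFluxErr": 0.2}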

304 

305class _PandasRowFactory: 

306 """Create pandas DataFrame from Cassandra result set. 

307 

308 Parameters 

309 ---------- 

310 packedColumns : `list` [ `ColumnDef` ] 

311 Column definitions for packed columns. 

312 """ 

313 def __init__(self, packedColumns: Iterable[ColumnDef]): 

314 self.packedColumns = list(packedColumns) 

315 

316 def __call__(self, colnames: List[str], rows: List[Tuple]) -> pandas.DataFrame: 

317 """Convert result set into output catalog. 

318 

319 Parameters 

320 ---------- 

321 colnames : `list` [ `str` ] 

322 Names of the columns. 

323 rows : `list` of `tuple` 

324 Result rows 

325 

326 Returns 

327 ------- 

328 catalog : `pandas.DataFrame` 

329 DataFrame with the result set. 

330 """ 

331 return _rows_to_pandas(colnames, rows, self.packedColumns) 

332 

333 

334class _RawRowFactory: 

335 """Row factory that makes no conversions. 

336 

337 Parameters 

338 ---------- 

339 packedColumns : `list` [ `ColumnDef` ] 

340 Column definitions for packed columns. 

341 """ 

342 def __init__(self, packedColumns: Iterable[ColumnDef]): 

343 self.packedColumns = list(packedColumns) 

344 

345 def __call__(self, colnames: List[str], rows: List[Tuple]) -> Tuple[List[str], List[Tuple]]: 

346 """Return parameters without change. 

347 

348 Parameters 

349 ---------- 

350 colnames : `list` of `str` 

351 Names of the columns. 

352 rows : `list` of `tuple` 

353 Result rows 

354 

355 Returns 

356 ------- 

357 colnames : `list` of `str` 

358 Names of the columns. 

359 rows : `list` of `tuple` 

360 Result rows 

361 """ 

362 return (colnames, rows) 

363 

364 

365class ApdbCassandra(Apdb): 

366 """Implementation of APDB database on to of Apache Cassandra. 

367 

368 The implementation is configured via standard ``pex_config`` mechanism 

369 using the `ApdbCassandraConfig` configuration class. For examples of 

370 different configurations see the config/ folder. 

371 

372 Parameters 

373 ---------- 

374 config : `ApdbCassandraConfig` 

375 Configuration object. 

376 """ 

377 

378 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI) 

379 """Start time for partition 0, this should never be changed.""" 

380 

381 def __init__(self, config: ApdbCassandraConfig): 

382 

383 if not CASSANDRA_IMPORTED: 

384 raise CassandraMissingError() 

385 

386 self.config = config 

387 

388 _LOG.debug("ApdbCassandra Configuration:") 

389 _LOG.debug(" read_consistency: %s", self.config.read_consistency) 

390 _LOG.debug(" write_consistency: %s", self.config.write_consistency) 

391 _LOG.debug(" read_sources_months: %s", self.config.read_sources_months) 

392 _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months) 

393 _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns) 

394 _LOG.debug(" schema_file: %s", self.config.schema_file) 

395 _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file) 

396 _LOG.debug(" schema prefix: %s", self.config.prefix) 

397 _LOG.debug(" part_pixelization: %s", self.config.part_pixelization) 

398 _LOG.debug(" part_pix_level: %s", self.config.part_pix_level) 

399 _LOG.debug(" query_per_time_part: %s", self.config.query_per_time_part) 

400 _LOG.debug(" query_per_spatial_part: %s", self.config.query_per_spatial_part) 

401 

402 self._partitioner = Partitioner(config) 

403 

404 addressTranslator: Optional[AddressTranslator] = None 

405 if config.private_ips: 

406 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points) 

407 addressTranslator = _AddressTranslator(config.contact_points, config.private_ips) 

408 else: 

409 loadBalancePolicy = RoundRobinPolicy() 

410 

411 self._read_consistency = getattr(cassandra.ConsistencyLevel, config.read_consistency) 

412 self._write_consistency = getattr(cassandra.ConsistencyLevel, config.write_consistency) 

413 

414 self._cluster = Cluster(contact_points=self.config.contact_points, 

415 load_balancing_policy=loadBalancePolicy, 

416 address_translator=addressTranslator, 

417 protocol_version=self.config.protocol_version) 

418 self._session = self._cluster.connect(keyspace=config.keyspace) 

419 self._session.row_factory = cassandra.query.named_tuple_factory 

420 

421 self._schema = ApdbCassandraSchema(session=self._session, 

422 schema_file=self.config.schema_file, 

423 extra_schema_file=self.config.extra_schema_file, 

424 prefix=self.config.prefix, 

425 packing=self.config.packing, 

426 time_partition_tables=self.config.time_partition_tables) 

427 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD) 

428 
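# Illustrative usage sketch (the host name and region are placeholder values;
# a reachable Cassandra cluster with an existing keyspace is assumed): a
# minimal round trip through this class.
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["cassandra-host.example.org"]
#     apdb = ApdbCassandra(config)
#     apdb.makeSchema(drop=False)           # create tables if not present
#     region = sphgeom.Circle(
#         sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0)),
#         sphgeom.Angle.fromDegrees(0.1))
#     objects = apdb.getDiaObjects(region)  # pandas.DataFrame of DiaObjects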

429 def tableDef(self, table: ApdbTables) -> Optional[TableDef]: 

430 # docstring is inherited from a base class 

431 return self._schema.tableSchemas.get(table) 

432 

433 def makeSchema(self, drop: bool = False) -> None: 

434 # docstring is inherited from a base class 

435 

436 if self.config.time_partition_tables: 

437 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI) 

438 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI) 

439 part_range = ( 

440 self._time_partition(time_partition_start), 

441 self._time_partition(time_partition_end) + 1 

442 ) 

443 self._schema.makeSchema(drop=drop, part_range=part_range) 

444 else: 

445 self._schema.makeSchema(drop=drop) 

446 

447 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame: 

448 # docstring is inherited from a base class 

449 packedColumns = self._schema.packedColumns(ApdbTables.DiaObjectLast) 

450 self._session.row_factory = _PandasRowFactory(packedColumns) 

451 self._session.default_fetch_size = None 

452 

453 pixels = self._partitioner.pixels(region) 

454 _LOG.debug("getDiaObjects: #partitions: %s", len(pixels)) 

455 pixels_str = ",".join([str(pix) for pix in pixels]) 

456 

457 queries: List[Tuple] = [] 

458 query = f'SELECT * from "DiaObjectLast" WHERE "apdb_part" IN ({pixels_str})' 

459 queries += [(cassandra.query.SimpleStatement(query, consistency_level=self._read_consistency), {})] 

460 _LOG.debug("getDiaObjects: #queries: %s", len(queries)) 

461 # _LOG.debug("getDiaObjects: queries: %s", queries) 

462 

463 objects = None 

464 with Timer('DiaObject select', self.config.timer): 

465 # submit all queries 

466 futures = [self._session.execute_async(query, values, timeout=self.config.read_timeout) 

467 for query, values in queries] 

468 # TODO: This orders result processing which is not very efficient 

469 dataframes = [future.result()._current_rows for future in futures] 

470 # concatenate all frames 

471 if len(dataframes) == 1: 

472 objects = dataframes[0] 

473 else: 

474 objects = pandas.concat(dataframes) 

475 

476 _LOG.debug("found %s DiaObjects", objects.shape[0]) 

477 return objects 

478 
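# Illustrative sketch of the single CQL statement built above, for a region
# that maps to three made-up spatial pixels (assuming an empty table prefix):
#
#     SELECT * from "DiaObjectLast" WHERE "apdb_part" IN (12345,12346,12347)
#
# All spatial partitions go into one IN clause; the rows come back already
# converted to a DataFrame by the _PandasRowFactory installed on the session.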

479 def getDiaSources(self, region: sphgeom.Region, 

480 object_ids: Optional[Iterable[int]], 

481 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]: 

482 # docstring is inherited from a base class 

483 return self._getSources(region, object_ids, visit_time, ApdbTables.DiaSource, 

484 self.config.read_sources_months) 

485 

486 def getDiaForcedSources(self, region: sphgeom.Region, 

487 object_ids: Optional[Iterable[int]], 

488 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]: 

489 return self._getSources(region, object_ids, visit_time, ApdbTables.DiaForcedSource, 

490 self.config.read_forced_sources_months) 

491 

492 def _getSources(self, region: sphgeom.Region, 

493 object_ids: Optional[Iterable[int]], 

494 visit_time: dafBase.DateTime, 

495 table_name: ApdbTables, 

496 months: int) -> Optional[pandas.DataFrame]: 

497 """Returns catalog of DiaSource instances given set of DiaObject IDs. 

498 

499 Parameters 

500 ---------- 

501 region : `lsst.sphgeom.Region` 

502 Spherical region. 

503 object_ids : iterable [ `int` ], or `None` 

504 Collection of DiaObject IDs. 

505 visit_time : `lsst.daf.base.DateTime` 

506 Time of the current visit 

507 table_name : `ApdbTables` 

508 Name of the table, either "DiaSource" or "DiaForcedSource" 

509 months : `int` 

510 Number of months of history to return; if negative, return whole 

511 history (note: negative values do not work with the table-per-partition 

512 case). 

513 

514 Returns 

515 ------- 

516 catalog : `pandas.DataFrame`, or `None` 

517 Catalog containing DiaSource records. `None` is returned if 

518 ``months`` is 0; an empty catalog is returned if ``object_ids`` is empty. 

519 """ 

520 if months == 0: 

521 return None 

522 object_id_set: Set[int] = set() 

523 if object_ids is not None: 

524 object_id_set = set(object_ids) 

525 if len(object_id_set) == 0: 

526 return self._make_empty_catalog(table_name) 

527 

528 packedColumns = self._schema.packedColumns(table_name) 

529 if self.config.pandas_delay_conv: 

530 self._session.row_factory = _RawRowFactory(packedColumns) 

531 else: 

532 self._session.row_factory = _PandasRowFactory(packedColumns) 

533 self._session.default_fetch_size = None 

534 

535 # spatial pixels included into query 

536 pixels = self._partitioner.pixels(region) 

537 _LOG.debug("_getSources: %s #partitions: %s", table_name.name, len(pixels)) 

538 

539 # spatial part of WHERE 

540 spatial_where = [] 

541 if self.config.query_per_spatial_part: 

542 spatial_where = [f'"apdb_part" = {pixel}' for pixel in pixels] 

543 else: 

544 pixels_str = ",".join([str(pix) for pix in pixels]) 

545 spatial_where = [f'"apdb_part" IN ({pixels_str})'] 

546 

547 # temporal part of WHERE, can be empty 

548 temporal_where = [] 

549 # time partitions and table names to query, there may be multiple 

550 # tables depending on configuration 

551 full_name = self._schema.tableName(table_name) 

552 tables = [full_name] 

553 mjd_now = visit_time.get(system=dafBase.DateTime.MJD) 

554 mjd_begin = mjd_now - months*30 

555 time_part_now = self._time_partition(mjd_now) 

556 time_part_begin = self._time_partition(mjd_begin) 

557 time_parts = list(range(time_part_begin, time_part_now + 1)) 

558 if self.config.time_partition_tables: 

559 tables = [f"{full_name}_{part}" for part in time_parts] 

560 else: 

561 if self.config.query_per_time_part: 

562 temporal_where = [f'"apdb_time_part" = {time_part}' for time_part in time_parts] 

563 else: 

564 time_part_list = ",".join([str(part) for part in time_parts]) 

565 temporal_where = [f'"apdb_time_part" IN ({time_part_list})'] 

566 

567 # Build all queries 

568 queries: List[str] = [] 

569 for table in tables: 

570 query = f'SELECT * from "{table}" WHERE ' 

571 for spatial in spatial_where: 

572 if temporal_where: 

573 for temporal in temporal_where: 

574 queries.append(query + spatial + " AND " + temporal) 

575 else: 

576 queries.append(query + spatial) 

577 # _LOG.debug("_getSources: queries: %s", queries) 

578 

579 statements: List[Tuple] = [ 

580 (cassandra.query.SimpleStatement(query, consistency_level=self._read_consistency), {}) 

581 for query in queries 

582 ] 

583 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

584 

585 with Timer(table_name.name + ' select', self.config.timer): 

586 # submit all queries 

587 results = execute_concurrent(self._session, statements, results_generator=True, 

588 concurrency=self.config.read_concurrency) 

589 if self.config.pandas_delay_conv: 

590 _LOG.debug("making pandas data frame out of rows/columns") 

591 columns: Any = None 

592 rows = [] 

593 for success, result in results: 

594 result = result._current_rows 

595 if success: 

596 if columns is None: 

597 columns = result[0] 

598 elif columns != result[0]: 

599 _LOG.error("different columns returned by queries: %s and %s", 

600 columns, result[0]) 

601 raise ValueError( 

602 f"diferent columns returned by queries: {columns} and {result[0]}" 

603 ) 

604 rows += result[1] 

605 else: 

606 _LOG.error("error returned by query: %s", result) 

607 raise result 

608 catalog = _rows_to_pandas(columns, rows, self._schema.packedColumns(table_name)) 

609 _LOG.debug("pandas catalog shape: %s", catalog.shape) 

610 # filter by given object IDs 

611 if len(object_id_set) > 0: 

612 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

613 else: 

614 _LOG.debug("making pandas data frame out of set of data frames") 

615 dataframes = [] 

616 for success, result in results: 

617 if success: 

618 dataframes.append(result._current_rows) 

619 else: 

620 _LOG.error("error returned by query: %s", result) 

621 raise result 

622 # concatenate all frames 

623 if len(dataframes) == 1: 

624 catalog = dataframes[0] 

625 else: 

626 catalog = pandas.concat(dataframes) 

627 _LOG.debug("pandas catalog shape: %s", catalog.shape) 

628 # filter by given object IDs 

629 if len(object_id_set) > 0: 

630 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

631 

632 # precise filtering on midPointTai 

633 catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_begin]) 

634 

635 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

636 return catalog 

637 
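# Illustrative sketch of how the WHERE clauses above combine (made-up pixel
# and time-partition values, empty table prefix assumed). With spatial pixels
# 1 and 2, time partitions 600 and 601, time_partition_tables=False and both
# query_per_spatial_part and query_per_time_part set to True, four statements
# are issued, one per (pixel, time partition) pair:
#
#     SELECT * from "DiaSource" WHERE "apdb_part" = 1 AND "apdb_time_part" = 600
#     SELECT * from "DiaSource" WHERE "apdb_part" = 1 AND "apdb_time_part" = 601
#     SELECT * from "DiaSource" WHERE "apdb_part" = 2 AND "apdb_time_part" = 600
#     SELECT * from "DiaSource" WHERE "apdb_part" = 2 AND "apdb_time_part" = 601
#
# With both per-part flags left False the same selection collapses into a
# single statement using IN lists for the pixels and time partitions.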

638 def store(self, 

639 visit_time: dafBase.DateTime, 

640 objects: pandas.DataFrame, 

641 sources: Optional[pandas.DataFrame] = None, 

642 forced_sources: Optional[pandas.DataFrame] = None) -> None: 

643 # docstring is inherited from a base class 

644 

645 # fill region partition column for DiaObjects 

646 objects = self._add_obj_part(objects) 

647 self._storeDiaObjects(objects, visit_time) 

648 

649 if sources is not None: 

650 # copy apdb_part column from DiaObjects to DiaSources 

651 sources = self._add_src_part(sources, objects) 

652 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time) 

653 

654 if forced_sources is not None: 

655 forced_sources = self._add_fsrc_part(forced_sources, objects) 

656 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time) 

657 

658 def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None: 

659 """Store catalog of DiaObjects from current visit. 

660 

661 Parameters 

662 ---------- 

663 objs : `pandas.DataFrame` 

664 Catalog with DiaObject records 

665 visit_time : `lsst.daf.base.DateTime` 

666 Time of the current visit. 

667 """ 

668 visit_time_dt = visit_time.toPython() 

669 extra_columns = dict(lastNonForcedSource=visit_time_dt) 

670 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, visit_time, extra_columns=extra_columns) 

671 

672 extra_columns["validityStart"] = visit_time_dt 

673 time_part: Optional[int] = self._time_partition(visit_time) 

674 if not self.config.time_partition_tables: 

675 extra_columns["apdb_time_part"] = time_part 

676 time_part = None 

677 

678 self._storeObjectsPandas(objs, ApdbTables.DiaObject, visit_time, 

679 extra_columns=extra_columns, time_part=time_part) 

680 

681 def _storeDiaSources(self, table_name: ApdbTables, sources: pandas.DataFrame, 

682 visit_time: dafBase.DateTime) -> None: 

683 """Store catalog of DIASources or DIAForcedSources from current visit. 

684 

685 Parameters 

686 ---------- 

687 sources : `pandas.DataFrame` 

688 Catalog containing DiaSource records 

689 visit_time : `lsst.daf.base.DateTime` 

690 Time of the current visit. 

691 """ 

692 time_part: Optional[int] = self._time_partition(visit_time) 

693 extra_columns = {} 

694 if not self.config.time_partition_tables: 

695 extra_columns["apdb_time_part"] = time_part 

696 time_part = None 

697 

698 self._storeObjectsPandas(sources, table_name, visit_time, 

699 extra_columns=extra_columns, time_part=time_part) 

700 

701 def dailyJob(self) -> None: 

702 # docstring is inherited from a base class 

703 pass 

704 

705 def countUnassociatedObjects(self) -> int: 

706 # docstring is inherited from a base class 

707 raise NotImplementedError() 

708 

709 def _storeObjectsPandas(self, objects: pandas.DataFrame, table_name: ApdbTables, 

710 visit_time: dafBase.DateTime, extra_columns: Optional[Mapping] = None, 

711 time_part: Optional[int] = None) -> None: 

712 """Generic store method. 

713 

714 Takes a catalog of records and stores them in the given table. 

715 

716 Parameters 

717 ---------- 

718 objects : `pandas.DataFrame` 

719 Catalog containing object records 

720 table_name : `ApdbTables` 

721 Name of the table as defined in APDB schema. 

722 visit_time : `lsst.daf.base.DateTime` 

723 Time of the current visit. 

724 extra_columns : `dict`, optional 

725 Mapping (column_name, column_value) giving column values to add 

726 to every row; used only if the column is missing from catalog records. 

727 time_part : `int`, optional 

728 If not `None` then insert into a per-partition table. 

729 """ 

730 

731 def qValue(v: Any) -> Any: 

732 """Transform object into a value for query""" 

733 if v is None: 

734 pass 

735 elif isinstance(v, datetime): 

736 v = int((v - datetime(1970, 1, 1)) / timedelta(seconds=1))*1000 

737 elif isinstance(v, (bytes, str)): 

738 pass 

739 else: 

740 try: 

741 if not np.isfinite(v): 

742 v = None 

743 except TypeError: 

744 pass 

745 return v 

746 

747 def quoteId(columnName: str) -> str: 

748 """Smart quoting for column names. 

749 Lower-case names are not quoted. 

750 """ 

751 if not columnName.islower(): 

752 columnName = '"' + columnName + '"' 

753 return columnName 

754 

755 # use extra columns if specified 

756 if extra_columns is None: 

757 extra_columns = {} 

758 extra_fields = list(extra_columns.keys()) 

759 

760 df_fields = [column for column in objects.columns 

761 if column not in extra_fields] 

762 

763 column_map = self._schema.getColumnMap(table_name) 

764 # list of columns (as in cat schema) 

765 fields = [column_map[field].name for field in df_fields if field in column_map] 

766 fields += extra_fields 

767 

768 # check that all partitioning and clustering columns are defined 

769 required_columns = self._schema.partitionColumns(table_name) \ 

770 + self._schema.clusteringColumns(table_name) 

771 missing_columns = [column for column in required_columns if column not in fields] 

772 if missing_columns: 

773 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}") 

774 

775 blob_columns = set(col.name for col in self._schema.packedColumns(table_name)) 

776 # _LOG.debug("blob_columns: %s", blob_columns) 

777 

778 qfields = [quoteId(field) for field in fields if field not in blob_columns] 

779 if blob_columns: 

780 qfields += [quoteId("apdb_packed")] 

781 qfields_str = ','.join(qfields) 

782 

783 with Timer(table_name.name + ' query build', self.config.timer): 

784 

785 table = self._schema.tableName(table_name) 

786 if time_part is not None: 

787 table = f"{table}_{time_part}" 

788 

789 prepared: Optional[cassandra.query.PreparedStatement] = None 

790 if self.config.prepared_statements: 

791 holders = ','.join(['?']*len(qfields)) 

792 query = f'INSERT INTO "{table}" ({qfields_str}) VALUES ({holders})' 

793 prepared = self._session.prepare(query) 

794 queries = cassandra.query.BatchStatement(consistency_level=self._write_consistency) 

795 for rec in objects.itertuples(index=False): 

796 values = [] 

797 blob = {} 

798 for field in df_fields: 

799 if field not in column_map: 

800 continue 

801 value = getattr(rec, field) 

802 if column_map[field].type == "DATETIME": 

803 if isinstance(value, pandas.Timestamp): 

804 value = qValue(value.to_pydatetime()) 

805 else: 

806 # Assume it's seconds since epoch, Cassandra 

807 # datetime is in milliseconds 

808 value = int(value*1000) 

809 if field in blob_columns: 

810 blob[field] = qValue(value) 

811 else: 

812 values.append(qValue(value)) 

813 for field in extra_fields: 

814 value = extra_columns[field] 

815 if field in blob_columns: 

816 blob[field] = qValue(value) 

817 else: 

818 values.append(qValue(value)) 

819 if blob_columns: 

820 if self.config.packing == "cbor": 

821 blob = b"cbor:" + cbor.dumps(blob) 

822 values.append(blob) 

823 holders = ','.join(['%s']*len(values)) 

824 if prepared is not None: 

825 stmt = prepared 

826 else: 

827 query = f'INSERT INTO "{table}" ({qfields_str}) VALUES ({holders})' 

828 # _LOG.debug("query: %r", query) 

829 # _LOG.debug("values: %s", values) 

830 stmt = cassandra.query.SimpleStatement(query, consistency_level=self._write_consistency) 

831 queries.add(stmt, values) 

832 

833 # _LOG.debug("query: %s", query) 

834 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), objects.shape[0]) 

835 with Timer(table_name.name + ' insert', self.config.timer): 

836 self._session.execute(queries, timeout=self.config.write_timeout) 

837 
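# Illustrative sketch of one generated INSERT (column list abbreviated,
# prepared_statements=False, no packed columns, empty table prefix assumed),
# and of the value conversions performed by qValue above:
#
#     INSERT INTO "DiaObjectLast" (apdb_part,"diaObjectId",ra,decl) VALUES (%s,%s,%s,%s)
#
#     qValue(datetime(2020, 1, 1))  ->  1577836800000   (epoch milliseconds)
#     qValue(float("nan"))          ->  None             (stored as NULL)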

838 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame: 

839 """Calculate spacial partition for each record and add it to a 

840 DataFrame. 

841 

842 Notes 

843 ----- 

844 This overrides any existing column in a DataFrame with the same name 

845 (apdb_part). The original DataFrame is not changed; a copy is 

846 returned. 

847 """ 

848 # calculate pixel index for every DiaObject 

849 apdb_part = np.zeros(df.shape[0], dtype=np.int64) 

850 ra_col, dec_col = self.config.ra_dec_columns 

851 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

852 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

853 idx = self._partitioner.pixel(uv3d) 

854 apdb_part[i] = idx 

855 df = df.copy() 

856 df["apdb_part"] = apdb_part 

857 return df 

858 

859 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

860 """Add apdb_part column to DiaSource catalog. 

861 

862 Notes 

863 ----- 

864 This method copies apdb_part value from a matching DiaObject record. 

865 DiaObject catalog needs to have an apdb_part column filled by 

866 ``_add_obj_part`` method and DiaSource records need to be 

867 associated to DiaObjects via ``diaObjectId`` column. 

868 

869 This overrides any existing column in a DataFrame with the same name 

870 (apdb_part). The original DataFrame is not changed; a copy is 

871 returned. 

872 """ 

873 pixel_id_map: Dict[int, int] = { 

874 diaObjectId: apdb_part for diaObjectId, apdb_part 

875 in zip(objs["diaObjectId"], objs["apdb_part"]) 

876 } 

877 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

878 ra_col, dec_col = self.config.ra_dec_columns 

879 for i, (diaObjId, ra, dec) in enumerate(zip(sources["diaObjectId"], 

880 sources[ra_col], sources[dec_col])): 

881 if diaObjId == 0: 

882 # DiaSources associated with SolarSystemObjects do not have an 

883 # associated DiaObject, hence we skip the lookup and set the partition 

884 # based on the source's own ra/dec 

885 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec)) 

886 idx = self._partitioner.pixel(uv3d) 

887 apdb_part[i] = idx 

888 else: 

889 apdb_part[i] = pixel_id_map[diaObjId] 

890 sources = sources.copy() 

891 sources["apdb_part"] = apdb_part 

892 return sources 

893 
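# Illustrative sketch of the partition propagation above (made-up IDs and
# pixel values): each DiaSource inherits apdb_part from its matching
# DiaObject, while a solar-system source with diaObjectId == 0 gets a pixel
# computed from its own ra/dec.
#
#     objs["diaObjectId"]              -> [101, 102]
#     objs["apdb_part"]                -> [555, 556]
#     sources["diaObjectId"]           -> [101, 102, 0]
#     returned copy, "apdb_part"       -> [555, 556, <pixel from source ra/dec>]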

894 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

895 """Add apdb_part column to DiaForcedSource catalog. 

896 

897 Notes 

898 ----- 

899 This method copies apdb_part value from a matching DiaObject record. 

900 DiaObject catalog needs to have an apdb_part column filled by 

901 ``_add_obj_part`` method and DiaForcedSource records need to be 

902 associated to DiaObjects via ``diaObjectId`` column. 

903 

904 This overrides any existing column in a DataFrame with the same name 

905 (apdb_part). The original DataFrame is not changed; a copy is 

906 returned. 

907 """ 

908 pixel_id_map: Dict[int, int] = { 

909 diaObjectId: apdb_part for diaObjectId, apdb_part 

910 in zip(objs["diaObjectId"], objs["apdb_part"]) 

911 } 

912 apdb_part = np.zeros(sources.shape[0], dtype=np.int64) 

913 for i, diaObjId in enumerate(sources["diaObjectId"]): 

914 apdb_part[i] = pixel_id_map[diaObjId] 

915 sources = sources.copy() 

916 sources["apdb_part"] = apdb_part 

917 return sources 

918 

919 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int: 

920 """Calculate time partiton number for a given time. 

921 

922 Parameters 

923 ---------- 

924 time : `float` or `lsst.daf.base.DateTime` 

925 Time for which to calculate the partition number. A float is 

926 interpreted as MJD; otherwise an `lsst.daf.base.DateTime`. 

927 

928 Returns 

929 ------- 

930 partition : `int` 

931 Partition number for a given time. 

932 """ 

933 if isinstance(time, dafBase.DateTime): 

934 mjd = time.get(system=dafBase.DateTime.MJD) 

935 else: 

936 mjd = time 

937 days_since_epoch = mjd - self._partition_zero_epoch_mjd 

938 partition = int(days_since_epoch) // self.config.time_partition_days 

939 return partition 

940 
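# Worked example of the arithmetic above: the partition-zero epoch
# 1970-01-01 is MJD 40587, so with the default time_partition_days = 30 a
# visit at MJD 58453 (2018-12-01) falls into partition
# (58453 - 40587) // 30 = 17866 // 30 = 595.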

941 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame: 

942 """Make an empty catalog for a table with a given name. 

943 

944 Parameters 

945 ---------- 

946 table_name : `ApdbTables` 

947 Name of the table. 

948 

949 Returns 

950 ------- 

951 catalog : `pandas.DataFrame` 

952 An empty catalog. 

953 """ 

954 table = self._schema.tableSchemas[table_name] 

955 

956 data = {columnDef.name: pandas.Series(dtype=columnDef.dtype) for columnDef in table.columns} 

957 return pandas.DataFrame(data)