Coverage for python/lsst/dax/apdb/apdbSql.py: 13%

367 statements  

coverage.py v6.5.0, created at 2022-10-12 02:38 -0700

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

from contextlib import contextmanager
import logging
import numpy as np
import pandas
from typing import cast, Any, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple

import lsst.daf.base as dafBase
from lsst.pex.config import Field, ChoiceField, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
import sqlalchemy
from sqlalchemy import (func, sql)
from sqlalchemy.pool import NullPool
from .apdb import Apdb, ApdbConfig
from .apdbSchema import ApdbTables, TableDef
from .apdbSqlSchema import ApdbSqlSchema
from .timer import Timer


_LOG = logging.getLogger(__name__)

def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change type of the uint64 columns to int64, return copy of data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})

def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30

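# Example of the 30-day month approximation used in _make_midPointTai_start:
# with a 12-month window and a visit at MJD 59580.0 the search would start at
# 59580.0 - 12 * 30 = 59220.0 (illustrative numbers only, not values used in
# this module).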

@contextmanager
def _ansi_session(engine: sqlalchemy.engine.Engine) -> Iterator[sqlalchemy.engine.Connection]:
    """Return a connection, making sure that ANSI mode is set for MySQL.
    """
    with engine.begin() as conn:
        if engine.name == 'mysql':
            conn.execute(sql.text("SET SESSION SQL_MODE = 'ANSI'"))
        yield conn
    return

class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql).
    """
    db_url = Field(
        dtype=str,
        doc="SQLAlchemy database connection URI"
    )
    isolation_level = ChoiceField(
        dtype=str,
        doc="Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value.",
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable"
        },
        default=None,
        optional=True
    )
    connection_pool = Field(
        dtype=bool,
        doc="If False then disable SQLAlchemy connection pool. "
            "Do not use connection pool when forking.",
        default=True
    )
    connection_timeout = Field(
        dtype=float,
        doc="Maximum time to wait for database lock to be released before "
            "exiting. Defaults to the SQLAlchemy default if not set.",
        default=None,
        optional=True
    )
    sql_echo = Field(
        dtype=bool,
        doc="If True then pass SQLAlchemy echo option.",
        default=False
    )
    dia_object_index = ChoiceField(
        dtype=str,
        doc="Indexing mode for DiaObject table",
        allowed={
            'baseline': "Index defined in baseline schema",
            'pix_id_iov': "(pixelId, objectId, iovStart) PK",
            'last_object_table': "Separate DiaObjectLast table"
        },
        default='baseline'
    )
    htm_level = Field(
        dtype=int,
        doc="HTM indexing level",
        default=20
    )
    htm_max_ranges = Field(
        dtype=int,
        doc="Max number of ranges in HTM envelope",
        default=64
    )
    htm_index_column = Field(
        dtype=str,
        default="pixelId",
        doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField(
        dtype=str,
        default=["ra", "decl"],
        doc="Names of ra/dec columns in DiaObject table"
    )
    dia_object_columns = ListField(
        dtype=str,
        doc="List of columns to read from DiaObject, by default read all columns",
        default=[]
    )
    object_last_replace = Field(
        dtype=bool,
        doc="If True (default) then use \"upsert\" for DiaObjectsLast table",
        default=True
    )
    prefix = Field(
        dtype=str,
        doc="Prefix to add to table names and index names",
        default=""
    )
    namespace = Field(
        dtype=str,
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only makes sense for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True
    )
    explain = Field(
        dtype=bool,
        doc="If True then run EXPLAIN SQL command on each executed query",
        default=False
    )
    timer = Field(
        dtype=bool,
        doc="If True then print/log timing information",
        default=False
    )

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")

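# A minimal configuration sketch (illustrative values, not defaults;
# ``db_url`` has no default and must always be set, the other fields fall
# back to the defaults declared above):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.dia_object_index = "last_object_table"
#     config.htm_level = 20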

class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using `ApdbSqlConfig` configuration class. For an example of different
    configurations check ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):

        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    object_last_replace: %s", self.config.object_last_replace)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(engine=self._engine,
                                     dia_object_index=self.config.dia_object_index,
                                     schema_file=self.config.schema_file,
                                     schema_name=self.config.schema_name,
                                     prefix=self.config.prefix,
                                     namespace=self.config.namespace,
                                     htm_index_column=self.config.htm_index_column)

        self.pixelator = HtmPixelization(self.config.htm_level)

    def tableRowCount(self) -> Dict[str, int]:
        """Returns dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables: List[sqlalchemy.schema.Table] = [
            self._schema.objects, self._schema.sources, self._schema.forcedSources]
        if self.config.dia_object_index == 'last_object_table':
            tables.append(self._schema.objects_last)
        for table in tables:
            stmt = sql.select([func.count()]).select_from(table)
            count = self._engine.scalar(stmt)
            res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        table: sqlalchemy.schema.Table
        if self.config.dia_object_index == 'last_object_table':
            table = self._schema.objects_last
        else:
            table = self._schema.objects
        if not self.config.dia_object_columns:
            query = table.select()
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
            query = sql.select(columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != 'last_object_table':
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        _LOG.debug("query: %s", query)

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('DiaObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(self, region: Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(self, region: Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If the list is empty then an empty catalog is returned
            with a correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned if
            ``read_forced_sources_months`` configuration parameter is set
            to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be not-`None`.
        `NotImplementedError` is raised if `None` is passed.

        This method returns the DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        source history is returned, limited by the
        ``read_forced_sources_months`` config parameter w.r.t. ``visit_time``.
        If ``object_ids`` is empty then an empty catalog is always returned
        with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.forcedSources
        with Timer('DiaForcedSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getDiaObjectsHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.objects
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["validityStart"] >= start_time.toPython(),
            table.columns["validityStart"] < end_time.toPython()
        )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaObject history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects history records", len(catalog))
        return catalog

    def getDiaSourcesHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
            table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
        )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSource history records", len(catalog))
        return catalog

    def getDiaForcedSourcesHistory(self,
                                   start_time: dafBase.DateTime,
                                   end_time: dafBase.DateTime,
                                   region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.forcedSources
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
            table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
        )
        # Forced sources have no pixel index, so no region filtering
        query = query.where(time_filter)

        # execute select
        with Timer('DiaForcedSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaForcedSource history records", len(catalog))
        return catalog

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.ssObjects
        query = table.select()

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('SSObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill pixelId column for DiaObjects
        objects = self._add_obj_htm_index(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy pixelId column from DiaObjects to DiaSources
            sources = self._add_src_htm_index(sources, objects)
            self._storeDiaSources(sources)

        if forced_sources is not None:
            self._storeDiaForcedSources(forced_sources)

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.ssObjects

        # everything to be done in single transaction
        with self._engine.begin() as conn:

            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row[idColumn] for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                query = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(query, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class

        if self._engine.name == 'postgresql':

            # do VACUUM on all tables
            _LOG.info("Running VACUUM on all tables")
            connection = self._engine.raw_connection()
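            # VACUUM cannot run inside a transaction block, so switch the raw
            # DBAPI connection to autocommit; the value 0 matches psycopg2's
            # ISOLATION_LEVEL_AUTOCOMMIT constant (assuming the default
            # psycopg2 driver for postgresql URLs).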

            ISOLATION_LEVEL_AUTOCOMMIT = 0
            connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            cursor = connection.cursor()
            cursor.execute("VACUUM ANALYSE")

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.objects

        # Construct the sql statement.
        stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.scalar(stmt)

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime
                               ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer('DiaSource select', self.config.timer):
            with _ansi_session(self._engine) as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime
                            ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        with Timer('DiaSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(self, table: sqlalchemy.schema.Table,
                         object_ids: List[int],
                         midPointTai_start: float
                         ) -> pandas.DataFrame:
        """Returns catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table : `sqlalchemy.schema.Table`
            Database table.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        sources: Optional[pandas.DataFrame] = None
        with _ansi_session(self._engine) as conn:
            if len(object_ids) <= 0:
                _LOG.debug("ID list is empty, just fetch empty result")
                query = table.select().where(False)
                sources = pandas.read_sql_query(query, conn)
            else:
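                # Break the ID list into chunks so the generated IN() clause
                # stays reasonably small (1000 IDs per query here).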

                for ids in chunk_iterable(sorted(object_ids), 1000):
                    query = table.select()

                    # Some types like np.int64 can cause issues with
                    # sqlalchemy, convert them to int.
                    int_ids = [int(oid) for oid in ids]

                    # select by object id
                    query = query.where(
                        sql.expression.and_(
                            table.columns["diaObjectId"].in_(int_ids),
                            table.columns["midPointTai"] > midPointTai_start,
                        )
                    )

                    # execute select
                    df = pandas.read_sql_query(query, conn)
                    if sources is None:
                        sources = df
                    else:
                        sources = sources.append(df)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        """

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs['diaObjectId'])
        _LOG.debug("first object ID: %d", ids[0])

        # NOTE: workaround for sqlite, need this here to avoid
        # "database is locked" error.
        table: sqlalchemy.schema.Table = self._schema.objects

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            if self.config.dia_object_index == 'last_object_table':

                # insert and replace all records in LAST table, mysql and
                # postgres have non-standard features
                table = self._schema.objects_last
                do_replace = self.config.object_last_replace
                # If the input data is of type Pandas, we drop the previous
                # objects regardless of the do_replace setting due to how
                # Pandas inserts objects.
                if not do_replace or isinstance(objs, pandas.DataFrame):
                    query = table.delete().where(
                        table.columns["diaObjectId"].in_(ids)
                    )

                    if self.config.explain:
                        # run the same query with explain
                        self._explain(query, conn)

                    with Timer(table.name + ' delete', self.config.timer):
                        res = conn.execute(query)
                    _LOG.debug("deleted %s objects", res.rowcount)

                extra_columns: Dict[str, Any] = dict(lastNonForcedSource=dt)
                with Timer("DiaObjectLast insert", self.config.timer):
                    objs = _coerce_uint64(objs)
                    for col, data in extra_columns.items():
                        objs[col] = data
                    objs.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)
            else:

                # truncate existing validity intervals
                table = self._schema.objects

                query = table.update().values(validityEnd=dt).where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )

                # _LOG.debug("query: %s", query)

                if self.config.explain:
                    # run the same query with explain
                    self._explain(query, conn)

                with Timer(table.name + ' truncate', self.config.timer):
                    res = conn.execute(query)
                _LOG.debug("truncated %s intervals", res.rowcount)

            # insert new versions
            table = self._schema.objects
            extra_columns = dict(lastNonForcedSource=dt, validityStart=dt,
                                 validityEnd=None)
            with Timer("DiaObject insert", self.config.timer):
                objs = _coerce_uint64(objs)
                for col, data in extra_columns.items():
                    objs[col] = data
                objs.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

    def _storeDiaSources(self, sources: pandas.DataFrame) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        """
        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                table = self._schema.sources
                sources.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

    def _storeDiaForcedSources(self, sources: pandas.DataFrame) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        """

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaForcedSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                table = self._schema.forcedSources
                sources.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

    def _explain(self, query: str, conn: sqlalchemy.engine.Connection) -> None:
        """Run the query with explain
        """

        _LOG.info("explain for query: %s...", query[:64])

        if conn.engine.name == 'mysql':
            query = "EXPLAIN EXTENDED " + query
        else:
            query = "EXPLAIN " + query

        res = conn.execute(sql.text(query))
        if res.returns_rows:
            _LOG.info("explain: %s", res.keys())
            for row in res:
                _LOG.info("explain: %s", row)
        else:
            _LOG.info("EXPLAIN returned nothing")

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug('region: %s', region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ClauseElement:
        """Make SQLAlchemy expression for selecting records in a region.
        """
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
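        # The ranges from envelope() are half-open [low, upper); convert each
        # to an inclusive BETWEEN, or a simple equality for a single pixel.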

        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId for diaObjectId, pixelId
            in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject, hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources
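
# A minimal usage sketch (not part of the module; the database URL and the
# search region below are illustrative assumptions, not defaults):
#
#     from lsst.sphgeom import Angle, Circle, LonLat, UnitVector3d
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     apdb = ApdbSql(config)
#     apdb.makeSchema()
#     region = Circle(UnitVector3d(LonLat.fromDegrees(10.0, -5.0)),
#                     Angle.fromDegrees(0.1))
#     objects = apdb.getDiaObjects(region)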