Coverage for python/lsst/dax/apdb/apdbSql.py: 14%

372 statements  

coverage.py v7.2.1, created at 2023-03-13 03:06 +0000

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Iterable, Iterator, Mapping, MutableMapping
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, cast

import lsst.daf.base as dafBase
import numpy as np
import pandas
import sqlalchemy
from felis.simple import Table
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, sql
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig
from .apdbSchema import ApdbTables
from .apdbSqlSchema import ApdbSqlSchema
from .timer import Timer

_LOG = logging.getLogger(__name__)


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change type of the uint64 columns to int64, return copy of data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
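# Hypothetical illustration (not part of the original module): a frame with a
# uint64 column such as "diaObjectId" comes back with that column converted to
# int64, while all other columns keep their dtypes.
#
#     df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
#     _coerce_uint64(df).dtypes["diaObjectId"]    # dtype('int64')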


def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
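# Hypothetical worked example (values are illustrative only): for a visit at
# MJD 60000.0 and months=12 the returned starting point is
# 60000.0 - 12 * 30 = 59640.0 (MJD).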


@contextmanager
def _ansi_session(engine: sqlalchemy.engine.Engine) -> Iterator[sqlalchemy.engine.Connection]:
    """Return a connection, making sure that ANSI mode is set for MySQL.
    """
    with engine.begin() as conn:
        if engine.name == 'mysql':
            conn.execute(sql.text("SET SESSION SQL_MODE = 'ANSI'"))
        yield conn
    return
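# Minimal usage sketch (assumes an already-created `engine`; illustrative
# only): the context manager is used like `engine.begin()`.
#
#     with _ansi_session(engine) as conn:
#         conn.execute(sql.text("SELECT 1"))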


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql).
    """
    db_url = Field[str](
        doc="SQLAlchemy database connection URI"
    )
    isolation_level = ChoiceField[str](
        doc="Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value.",
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable"
        },
        default=None,
        optional=True
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. "
            "Do not use connection pool when forking.",
        default=True
    )
    connection_timeout = Field[float](
        doc="Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to sqlalchemy defaults if not set.",
        default=None,
        optional=True
    )
    sql_echo = Field[bool](
        doc="If True then pass SQLAlchemy echo option.",
        default=False
    )
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            'baseline': "Index defined in baseline schema",
            'pix_id_iov': "(pixelId, objectId, iovStart) PK",
            'last_object_table': "Separate DiaObjectLast table"
        },
        default='baseline'
    )
    htm_level = Field[int](
        doc="HTM indexing level",
        default=20
    )
    htm_max_ranges = Field[int](
        doc="Max number of ranges in HTM envelope",
        default=64
    )
    htm_index_column = Field[str](
        default="pixelId",
        doc="Name of the HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](
        default=["ra", "decl"],
        doc="Names of ra/dec columns in DiaObject table"
    )
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns",
        default=[]
    )
    object_last_replace = Field[bool](
        doc="If True (default) then use \"upsert\" for DiaObjectsLast table",
        default=True,
        deprecated="This field is not used and will be removed on 2022-21-31."
    )
    prefix = Field[str](
        doc="Prefix to add to table names and index names",
        default=""
    )
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only makes sense for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True
    )
    explain = Field[bool](
        doc="If True then run EXPLAIN SQL command on each executed query",
        default=False
    )
    timer = Field[bool](
        doc="If True then print/log timing information",
        default=False
    )

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
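    # Hypothetical configuration sketch (the SQLite URL below is illustrative,
    # not a default of this package):
    #
    #     config = ApdbSqlConfig()
    #     config.db_url = "sqlite:///apdb.db"
    #     config.htm_level = 20
    #     config.validate()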


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using `ApdbSqlConfig` configuration class. For an example of different
    configurations check ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):

        config.validate()
        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(engine=self._engine,
                                     dia_object_index=self.config.dia_object_index,
                                     schema_file=self.config.schema_file,
                                     schema_name=self.config.schema_name,
                                     prefix=self.config.prefix,
                                     namespace=self.config.namespace,
                                     htm_index_column=self.config.htm_index_column)

        self.pixelator = HtmPixelization(self.config.htm_level)
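    # Hypothetical construction sketch (the in-memory SQLite URL is for
    # illustration only):
    #
    #     config = ApdbSqlConfig()
    #     config.db_url = "sqlite://"
    #     apdb = ApdbSql(config)
    #     apdb.makeSchema()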

    def tableRowCount(self) -> Dict[str, int]:
        """Return a dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables: List[sqlalchemy.schema.Table] = [
            self._schema.objects, self._schema.sources, self._schema.forcedSources]
        if self.config.dia_object_index == 'last_object_table':
            tables.append(self._schema.objects_last)
        for table in tables:
            stmt = sql.select([func.count()]).select_from(table)
            count = self._engine.scalar(stmt)
            res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[Table]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        table: sqlalchemy.schema.Table
        if self.config.dia_object_index == 'last_object_table':
            table = self._schema.objects_last
        else:
            table = self._schema.objects
        if not self.config.dia_object_columns:
            query = table.select()
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
            query = sql.select(columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != 'last_object_table':
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('DiaObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects
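    # Hypothetical query sketch (the circular region below is illustrative;
    # any `lsst.sphgeom.Region` can be passed):
    #
    #     from lsst.sphgeom import Angle, Circle, LonLat, UnitVector3d
    #     center = UnitVector3d(LonLat.fromDegrees(10.0, -5.0))
    #     region = Circle(center, Angle.fromDegrees(0.1))
    #     objects = apdb.getDiaObjects(region)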

    def getDiaSources(self, region: Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(self, region: Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If the list is empty then an empty catalog is returned
            with a correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned if
            ``read_forced_sources_months`` configuration parameter is set to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be not-`None`.
        `NotImplementedError` is raised if `None` is passed.

        This method returns DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaForcedSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter, w.r.t. ``visit_time``.
        If ``object_ids`` is empty then an empty catalog is always returned
        with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.forcedSources
        with Timer('DiaForcedSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getDiaObjectsHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.objects
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["validityStart"] >= start_time.toPython(),
            table.columns["validityStart"] < end_time.toPython()
        )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaObject history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects history records", len(catalog))
        return catalog

    def getDiaSourcesHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
            table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
        )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSource history records", len(catalog))
        return catalog

    def getDiaForcedSourcesHistory(self,
                                   start_time: dafBase.DateTime,
                                   end_time: dafBase.DateTime,
                                   region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.forcedSources
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
            table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
        )
        # Forced sources have no pixel index, so no region filtering
        query = query.where(time_filter)

        # execute select
        with Timer('DiaForcedSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaForcedSource history records", len(catalog))
        return catalog

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.ssObjects
        query = table.select()

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('DiaObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill pixelId column for DiaObjects
        objects = self._add_obj_htm_index(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy pixelId column from DiaObjects to DiaSources
            sources = self._add_src_htm_index(sources, objects)
            self._storeDiaSources(sources)

        if forced_sources is not None:
            self._storeDiaForcedSources(forced_sources)
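    # Hypothetical call sketch (the catalog names are placeholders; real data
    # frames must follow the APDB schema for each table):
    #
    #     apdb.store(visit_time, objects=dia_objects_df,
    #                sources=dia_sources_df, forced_sources=forced_df)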

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.ssObjects

        # everything to be done in single transaction
        with self._engine.begin() as conn:

            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row[idColumn] for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                query = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(query, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what is
            # missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class

        if self._engine.name == 'postgresql':

            # do VACUUM on all tables
            _LOG.info("Running VACUUM on all tables")
            connection = self._engine.raw_connection()
            ISOLATION_LEVEL_AUTOCOMMIT = 0
            connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            cursor = connection.cursor()
            cursor.execute("VACUUM ANALYSE")

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.objects

        # Construct the sql statement.
        stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.scalar(stmt)

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime
                               ) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer('DiaSource select', self.config.timer):
            with _ansi_session(self._engine) as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime
                            ) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given a set of DiaObject IDs.

        Parameters
        ----------
        object_ids :
            Collection of DiaObject IDs
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        with Timer('DiaSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(self, table: sqlalchemy.schema.Table,
                         object_ids: List[int],
                         midPointTai_start: float
                         ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given a
        set of DiaObject IDs.

        Parameters
        ----------
        table : `sqlalchemy.schema.Table`
            Database table.
        object_ids :
            Collection of DiaObject IDs
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        sources: Optional[pandas.DataFrame] = None
        with _ansi_session(self._engine) as conn:
            if len(object_ids) <= 0:
                _LOG.debug("ID list is empty, just fetch empty result")
                query = table.select().where(False)
                sources = pandas.read_sql_query(query, conn)
            else:
                for ids in chunk_iterable(sorted(object_ids), 1000):
                    query = table.select()

                    # Some types like np.int64 can cause issues with
                    # sqlalchemy, convert them to int.
                    int_ids = [int(oid) for oid in ids]

                    # select by object id
                    query = query.where(
                        sql.expression.and_(
                            table.columns["diaObjectId"].in_(int_ids),
                            table.columns["midPointTai"] > midPointTai_start,
                        )
                    )

                    # execute select
                    df = pandas.read_sql_query(query, conn)
                    if sources is None:
                        sources = df
                    else:
                        sources = sources.append(df)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        """

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs['diaObjectId'])
        _LOG.debug("first object ID: %d", ids[0])

        # NOTE: workaround for sqlite, need this here to avoid
        # "database is locked" error.
        table: sqlalchemy.schema.Table = self._schema.objects

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            if self.config.dia_object_index == 'last_object_table':

                # insert and replace all records in LAST table, mysql and postgres have
                # non-standard features
                table = self._schema.objects_last

                # Drop the previous objects (pandas cannot upsert).
                query = table.delete().where(
                    table.columns["diaObjectId"].in_(ids)
                )

                if self.config.explain:
                    # run the same query with explain
                    self._explain(query, conn)

                with Timer(table.name + ' delete', self.config.timer):
                    res = conn.execute(query)
                _LOG.debug("deleted %s objects", res.rowcount)

                # DiaObjectLast is a subset of DiaObject, strip missing columns
                last_column_names = [column.name for column in table.columns]
                last_objs = objs[last_column_names]

                extra_columns: Dict[str, Any] = dict(lastNonForcedSource=dt)
                with Timer("DiaObjectLast insert", self.config.timer):
                    last_objs = _coerce_uint64(last_objs)
                    for col, data in extra_columns.items():
                        last_objs[col] = data
                    last_objs.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)
            else:

                # truncate existing validity intervals
                table = self._schema.objects

                query = table.update().values(validityEnd=dt).where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )

                # _LOG.debug("query: %s", query)

                if self.config.explain:
                    # run the same query with explain
                    self._explain(query, conn)

                with Timer(table.name + ' truncate', self.config.timer):
                    res = conn.execute(query)
                _LOG.debug("truncated %s intervals", res.rowcount)

            # insert new versions
            table = self._schema.objects
            extra_columns = dict(lastNonForcedSource=dt, validityStart=dt,
                                 validityEnd=None)
            with Timer("DiaObject insert", self.config.timer):
                objs = _coerce_uint64(objs)
                if extra_columns:
                    columns: List[pandas.Series] = []
                    for col, data in extra_columns.items():
                        columns.append(pandas.Series([data]*len(objs), name=col))
                    objs.set_index(columns[0].index, inplace=True)
                    objs = pandas.concat([objs] + columns, axis="columns")
                objs.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

    def _storeDiaSources(self, sources: pandas.DataFrame) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        """
        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                table = self._schema.sources
                sources.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

    def _storeDiaForcedSources(self, sources: pandas.DataFrame) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        """

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaForcedSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                table = self._schema.forcedSources
                sources.to_sql(table.name, conn, if_exists='append', index=False, schema=table.schema)

    def _explain(self, query: str, conn: sqlalchemy.engine.Connection) -> None:
        """Run the query with explain
        """

        _LOG.info("explain for query: %s...", query[:64])

        if conn.engine.name == 'mysql':
            query = "EXPLAIN EXTENDED " + query
        else:
            query = "EXPLAIN " + query

        res = conn.execute(sql.text(query))
        if res.returns_rows:
            _LOG.info("explain: %s", res.keys())
            for row in res:
                _LOG.info("explain: %s", row)
        else:
            _LOG.info("EXPLAIN returned nothing")

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region: `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug('region: %s', region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()
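    # Note (added for clarity, not part of the original module): the range set
    # returned by `envelope()` yields half-open (begin, end) index pairs, e.g.
    # [(12058624, 12058628), ...], which is why `_filterRegion` below subtracts
    # one from the upper bound before building an inclusive BETWEEN clause.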

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ClauseElement:
        """Make SQLAlchemy expression for selecting records in a region.
        """
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
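    # The generated clause has roughly this shape (column name and values are
    # illustrative):
    #
    #     pixelId = 12058624 OR pixelId BETWEEN 12058688 AND 12058751 OR ...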

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId for diaObjectId, pixelId
            in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources