Coverage for python/lsst/dax/apdb/apdbSql.py: 13%

377 statements  

coverage.py v6.4.4, created at 2022-08-18 19:02 +0000

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

from contextlib import contextmanager
import logging
import numpy as np
import pandas
from typing import cast, Any, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple

import lsst.daf.base as dafBase
from lsst.pex.config import Field, ChoiceField, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
import sqlalchemy
from sqlalchemy import func, sql
from sqlalchemy.pool import NullPool
from .apdb import Apdb, ApdbConfig
from .apdbSchema import ApdbTables, TableDef
from .apdbSqlSchema import ApdbSqlSchema
from .timer import Timer


_LOG = logging.getLogger(__name__)


def _split(seq: Iterable, nItems: int) -> Iterator[List]:
    """Split a sequence into smaller sequences"""
    seq = list(seq)
    while seq:
        yield seq[:nItems]
        del seq[:nItems]
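
# Example (added for illustration): _split yields successive fixed-size
# chunks; this is what caps the size of the IN lists built in
# _getSourcesByIDs below.
#
#     >>> list(_split([1, 2, 3, 4, 5], 2))
#     [[1, 2], [3, 4], [5]]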


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, return a copy of the data
    frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
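
# Illustration (added; not in the original module): SQL backends generally
# lack an unsigned 64-bit integer type, so uint64 columns are converted:
#
#     >>> df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64)})
#     >>> _coerce_uint64(df)["id"].dtype
#     dtype('int64')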


def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
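
# Worked example (added): a "month" is approximated as 30 days above, so with
# months=12 the returned starting point is the visit MJD minus 360 days.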


@contextmanager
def _ansi_session(engine: sqlalchemy.engine.Engine) -> Iterator[sqlalchemy.engine.Connection]:
    """Return a connection, making sure that ANSI mode is set for MySQL.
    """
    with engine.begin() as conn:
        if engine.name == 'mysql':
            conn.execute(sql.text("SET SESSION SQL_MODE = 'ANSI'"))
        yield conn
    return
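
# Usage sketch (added; illustrative): the context manager delegates to
# engine.begin(), so the transaction commits on normal exit and rolls back
# on exception:
#
#     with _ansi_session(engine) as conn:
#         conn.execute(sql.text("SELECT 1"))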


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql).
    """
    db_url = Field(
        dtype=str,
        doc="SQLAlchemy database connection URI"
    )
    isolation_level = ChoiceField(
        dtype=str,
        doc="Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value.",
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable"
        },
        default=None,
        optional=True
    )
    connection_pool = Field(
        dtype=bool,
        doc="If False then disable SQLAlchemy connection pool. "
            "Do not use connection pool when forking.",
        default=True
    )
    connection_timeout = Field(
        dtype=float,
        doc="Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to SQLAlchemy defaults if not set.",
        default=None,
        optional=True
    )
    sql_echo = Field(
        dtype=bool,
        doc="If True then pass SQLAlchemy echo option.",
        default=False
    )
    dia_object_index = ChoiceField(
        dtype=str,
        doc="Indexing mode for DiaObject table",
        allowed={
            'baseline': "Index defined in baseline schema",
            'pix_id_iov': "(pixelId, objectId, iovStart) PK",
            'last_object_table': "Separate DiaObjectLast table"
        },
        default='baseline'
    )
    htm_level = Field(
        dtype=int,
        doc="HTM indexing level",
        default=20
    )
    htm_max_ranges = Field(
        dtype=int,
        doc="Max number of ranges in HTM envelope",
        default=64
    )
    htm_index_column = Field(
        dtype=str,
        default="pixelId",
        doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField(
        dtype=str,
        default=["ra", "decl"],
        doc="Names of ra/dec columns in DiaObject table"
    )
    dia_object_columns = ListField(
        dtype=str,
        doc="List of columns to read from DiaObject, by default read all columns",
        default=[]
    )
    object_last_replace = Field(
        dtype=bool,
        doc="If True (default) then use \"upsert\" for DiaObjectLast table",
        default=True
    )
    prefix = Field(
        dtype=str,
        doc="Prefix to add to table names and index names",
        default=""
    )
    explain = Field(
        dtype=bool,
        doc="If True then run EXPLAIN SQL command on each executed query",
        default=False
    )
    timer = Field(
        dtype=bool,
        doc="If True then print/log timing information",
        default=False
    )

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
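
# Configuration sketch (added; illustrative, with a hypothetical URL; the
# inherited fields such as schema_file keep their defaults here):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.dia_object_index = "last_object_table"
#     config.validate()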


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using `ApdbSqlConfig` configuration class. For examples of different
    configurations see the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):

        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug(" object_last_replace: %s", self.config.object_last_replace)
        _LOG.debug(" schema_file: %s", self.config.schema_file)
        _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug(" schema prefix: %s", self.config.prefix)

        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(engine=self._engine,
                                     dia_object_index=self.config.dia_object_index,
                                     schema_file=self.config.schema_file,
                                     schema_name=self.config.schema_name,
                                     prefix=self.config.prefix,
                                     htm_index_column=self.config.htm_index_column)

        self.pixelator = HtmPixelization(self.config.htm_level)

    def tableRowCount(self) -> Dict[str, int]:
        """Returns dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables: List[sqlalchemy.schema.Table] = [
            self._schema.objects, self._schema.sources, self._schema.forcedSources]
        if self.config.dia_object_index == 'last_object_table':
            tables.append(self._schema.objects_last)
        for table in tables:
            stmt = sql.select([func.count()]).select_from(table)
            count = self._engine.scalar(stmt)
            res[table.name] = count

        return res
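
    # Illustrative result shape (added; counts are hypothetical):
    #
    #     {"DiaObject": 1024, "DiaSource": 65536, "DiaForcedSource": 131072}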

    def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        table: sqlalchemy.schema.Table
        if self.config.dia_object_index == 'last_object_table':
            table = self._schema.objects_last
        else:
            table = self._schema.objects
        if not self.config.dia_object_columns:
            query = table.select()
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
            query = sql.select(columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != 'last_object_table':
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        _LOG.debug("query: %s", query)

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('DiaObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects
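
    # Illustrative statement for the 'baseline' index mode (added; the pixel
    # ranges depend on the region and htm_level):
    #
    #     SELECT * FROM "DiaObject"
    #     WHERE ("pixelId" BETWEEN :low_1 AND :high_1 OR "pixelId" = :val_1)
    #       AND "validityEnd" IS NULL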

    def getDiaSources(self, region: Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(self, region: Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If list is empty then empty catalog is returned with a
            correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned if
            ``read_forced_sources_months`` configuration parameter is set
            to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be not-`None`.
        `NotImplementedError` is raised if `None` is passed.

        This method returns the DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        source history is returned, limited by the
        ``read_forced_sources_months`` config parameter w.r.t. ``visit_time``.
        If ``object_ids`` is empty then an empty catalog is always returned
        with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.forcedSources
        with Timer('DiaForcedSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getDiaObjectsHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.objects
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["validityStart"] >= start_time.toPython(),
            table.columns["validityStart"] < end_time.toPython()
        )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaObject history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects history records", len(catalog))
        return catalog

    def getDiaSourcesHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
            table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
        )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSource history records", len(catalog))
        return catalog

    def getDiaForcedSourcesHistory(self,
                                   start_time: dafBase.DateTime,
                                   end_time: dafBase.DateTime,
                                   region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.forcedSources
        query = table.select()

        # build selection
        time_filter = sql.expression.and_(
            table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD),
            table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
        )
        # Forced sources have no pixel index, so no region filtering
        query = query.where(time_filter)

        # execute select
        with Timer('DiaForcedSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaForcedSource history records", len(catalog))
        return catalog

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.ssObjects
        query = table.select()

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('SSObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill pixelId column for DiaObjects
        objects = self._add_obj_htm_index(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy pixelId column from DiaObjects to DiaSources
            sources = self._add_src_htm_index(sources, objects)
            self._storeDiaSources(sources)

        if forced_sources is not None:
            self._storeDiaForcedSources(forced_sources)
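
    # Usage sketch (added; the data frames are hypothetical). The objects
    # catalog must carry the configured ra/dec columns so that
    # _add_obj_htm_index() can fill the pixelId column before insertion:
    #
    #     apdb = ApdbSql(config)
    #     apdb.store(visit_time, objects_df, sources_df, forced_sources_df)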

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.ssObjects

        # everything to be done in single transaction
        with self._engine.begin() as conn:

            # find record IDs that already exist
            ids = sorted(objects[idColumn])
            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row[idColumn] for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(ApdbTables.SSObject.table_name(), conn, if_exists='append', index=False)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                query = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(query, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
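
    # Usage sketch (added; IDs are hypothetical): move two DiaSources to
    # SSObject 42, clearing their DiaObject association:
    #
    #     apdb.reassignDiaSources({1234: 42, 1235: 42})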

    def dailyJob(self) -> None:
        # docstring is inherited from a base class

        if self._engine.name == 'postgresql':

            # do VACUUM on all tables
            _LOG.info("Running VACUUM on all tables")
            connection = self._engine.raw_connection()
            ISOLATION_LEVEL_AUTOCOMMIT = 0
            connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            cursor = connection.cursor()
            cursor.execute("VACUUM ANALYSE")

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.objects

        # Construct the sql statement.
        stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.scalar(stmt)

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime
                               ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer('DiaSource select', self.config.timer):
            with _ansi_session(self._engine) as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime
                            ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids :
            Collection of DiaObject IDs
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        with Timer('DiaSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(self, table: sqlalchemy.schema.Table,
                         object_ids: List[int],
                         midPointTai_start: float
                         ) -> pandas.DataFrame:
        """Returns catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table : `sqlalchemy.schema.Table`
            Database table.
        object_ids :
            Collection of DiaObject IDs
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        sources: Optional[pandas.DataFrame] = None
        with _ansi_session(self._engine) as conn:
            if len(object_ids) <= 0:
                _LOG.debug("ID list is empty, just fetch empty result")
                query = table.select().where(False)
                sources = pandas.read_sql_query(query, conn)
            else:
                for ids in _split(sorted(object_ids), 1000):
                    query = f'SELECT * FROM "{table.name}" WHERE '

                    # select by object id
                    ids_str = ",".join(str(id) for id in ids)
                    query += f'"diaObjectId" IN ({ids_str})'
                    query += f' AND "midPointTai" > {midPointTai_start}'

                    # execute select
                    df = pandas.read_sql_query(sql.text(query), conn)
                    if sources is None:
                        sources = df
                    else:
                        sources = sources.append(df)
        assert sources is not None, "Catalog cannot be None"
        return sources
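
    # Illustrative statement for one chunk of at most 1000 IDs (added; table
    # name and values vary):
    #
    #     SELECT * FROM "DiaSource"
    #         WHERE "diaObjectId" IN (1,2,3) AND "midPointTai" > 59810.5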

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        """

        ids = sorted(objs['diaObjectId'])
        _LOG.debug("first object ID: %d", ids[0])

        # NOTE: workaround for sqlite, need this here to avoid
        # "database is locked" error.
        table: sqlalchemy.schema.Table = self._schema.objects

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            ids_str = ",".join(str(id) for id in ids)

            if self.config.dia_object_index == 'last_object_table':

                # insert and replace all records in LAST table, mysql and
                # postgres have non-standard features
                table = self._schema.objects_last
                do_replace = self.config.object_last_replace
                # If the input data is of type Pandas, we drop the previous
                # objects regardless of the do_replace setting due to how
                # Pandas inserts objects.
                if not do_replace or isinstance(objs, pandas.DataFrame):
                    query = 'DELETE FROM "' + table.name + '" '
                    query += 'WHERE "diaObjectId" IN (' + ids_str + ') '

                    if self.config.explain:
                        # run the same query with explain
                        self._explain(query, conn)

                    with Timer(table.name + ' delete', self.config.timer):
                        res = conn.execute(sql.text(query))
                    _LOG.debug("deleted %s objects", res.rowcount)

                extra_columns: Dict[str, Any] = dict(lastNonForcedSource=dt)
                with Timer("DiaObjectLast insert", self.config.timer):
                    objs = _coerce_uint64(objs)
                    for col, data in extra_columns.items():
                        objs[col] = data
                    objs.to_sql("DiaObjectLast", conn, if_exists='append',
                                index=False)
            else:

                # truncate existing validity intervals
                table = self._schema.objects
                query = 'UPDATE "' + table.name + '" '
                query += "SET \"validityEnd\" = '" + str(dt) + "' "
                query += 'WHERE "diaObjectId" IN (' + ids_str + ') '
                query += 'AND "validityEnd" IS NULL'

                # _LOG.debug("query: %s", query)

                if self.config.explain:
                    # run the same query with explain
                    self._explain(query, conn)

                with Timer(table.name + ' truncate', self.config.timer):
                    res = conn.execute(sql.text(query))
                _LOG.debug("truncated %s intervals", res.rowcount)

            # insert new versions
            table = self._schema.objects
            extra_columns = dict(lastNonForcedSource=dt, validityStart=dt,
                                 validityEnd=None)
            with Timer("DiaObject insert", self.config.timer):
                objs = _coerce_uint64(objs)
                for col, data in extra_columns.items():
                    objs[col] = data
                objs.to_sql("DiaObject", conn, if_exists='append',
                            index=False)

    def _storeDiaSources(self, sources: pandas.DataFrame) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        """
        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                sources.to_sql("DiaSource", conn, if_exists='append', index=False)

    def _storeDiaForcedSources(self, sources: pandas.DataFrame) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        """

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaForcedSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                sources.to_sql("DiaForcedSource", conn, if_exists='append', index=False)

    def _explain(self, query: str, conn: sqlalchemy.engine.Connection) -> None:
        """Run the query with explain.
        """

        _LOG.info("explain for query: %s...", query[:64])

        if conn.engine.name == 'mysql':
            query = "EXPLAIN EXTENDED " + query
        else:
            query = "EXPLAIN " + query

        res = conn.execute(sql.text(query))
        if res.returns_rows:
            _LOG.info("explain: %s", res.keys())
            for row in res:
                _LOG.info("explain: %s", row)
        else:
            _LOG.info("EXPLAIN returned nothing")

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug('region: %s', region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        if _LOG.isEnabledFor(logging.DEBUG):
            for irange in indices.ranges():
                _LOG.debug('range: %s %s', self.pixelator.toString(irange[0]),
                           self.pixelator.toString(irange[1]))

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ClauseElement:
        """Make SQLAlchemy expression for selecting records in a region.
        """
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
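
    # Note (added): _htm_indices() returns half-open [begin, end) ID ranges,
    # hence the "upper -= 1" above to get inclusive BETWEEN bounds. A
    # two-range result produces a clause like (values hypothetical):
    #
    #     "pixelId" = 12345 OR "pixelId" BETWEEN 22222 AND 22229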

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df
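
    # Behavior sketch (added; index values depend on htm_level): every row of
    # the returned copy gets the HTM pixel of its sky position, e.g.
    #
    #     df = pandas.DataFrame({"ra": [10.0], "decl": [-5.0]})
    #     df = apdb._add_obj_htm_index(df)   # adds/overwrites "pixelId"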

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId for diaObjectId, pixelId
            in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources