Coverage for python/lsst/dax/apdb/apdbSql.py: 15%


376 statements  

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

from contextlib import contextmanager
import logging
import numpy as np
import pandas
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple

import lsst.daf.base as dafBase
from lsst.pex.config import Field, ChoiceField, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
import sqlalchemy
from sqlalchemy import (func, sql)
from sqlalchemy.pool import NullPool
from .apdb import Apdb, ApdbConfig
from .apdbSchema import ApdbTables, TableDef
from .apdbSqlSchema import ApdbSqlSchema
from .timer import Timer


_LOG = logging.getLogger(__name__)

def _split(seq: Iterable, nItems: int) -> Iterator[List]:
    """Split a sequence into smaller sequences"""
    seq = list(seq)
    while seq:
        yield seq[:nItems]
        del seq[:nItems]
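
# A usage sketch for ``_split`` (illustration only, not part of the original
# module), shown doctest-style:
#
#     >>> list(_split(range(7), 3))
#     [[0, 1, 2], [3, 4, 5], [6]]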

def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change type of the uint64 columns to int64, return copy of data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
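
# Illustration (assuming standard numpy/pandas semantics): a uint64 column is
# converted to int64 while other columns are left untouched.
#
#     >>> df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64)})
#     >>> _coerce_uint64(df).dtypes["id"]
#     dtype('int64')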

def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
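
# Worked example (plain arithmetic; the function approximates a month as 30
# days): for a visit at MJD 59580.0 and a 12-month history window the
# starting point is 59580.0 - 12 * 30 = 59220.0 (MJD).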

@contextmanager
def _ansi_session(engine: sqlalchemy.engine.Engine) -> Iterator[sqlalchemy.engine.Connection]:
    """Return a connection, making sure that ANSI mode is set for MySQL.
    """
    with engine.begin() as conn:
        if engine.name == 'mysql':
            conn.execute(sql.text("SET SESSION SQL_MODE = 'ANSI'"))
        yield conn
    return
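
# A minimal usage sketch (the engine URL is hypothetical):
#
#     engine = sqlalchemy.create_engine("mysql://user@host/apdb")
#     with _ansi_session(engine) as conn:
#         conn.execute(sql.text("SELECT 1"))  # runs with SQL_MODE='ANSI' on MySQL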

class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql).
    """
    db_url = Field(
        dtype=str,
        doc="SQLAlchemy database connection URI"
    )
    isolation_level = ChoiceField(
        dtype=str,
        doc="Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value.",
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable"
        },
        default=None,
        optional=True
    )
    connection_pool = Field(
        dtype=bool,
        doc="If False then disable SQLAlchemy connection pool. "
            "Do not use connection pool when forking.",
        default=True
    )
    connection_timeout = Field(
        dtype=float,
        doc="Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to SQLAlchemy defaults if not set.",
        default=None,
        optional=True
    )
    sql_echo = Field(
        dtype=bool,
        doc="If True then pass SQLAlchemy echo option.",
        default=False
    )
    dia_object_index = ChoiceField(
        dtype=str,
        doc="Indexing mode for DiaObject table",
        allowed={
            'baseline': "Index defined in baseline schema",
            'pix_id_iov': "(pixelId, objectId, iovStart) PK",
            'last_object_table': "Separate DiaObjectLast table"
        },
        default='baseline'
    )
    htm_level = Field(
        dtype=int,
        doc="HTM indexing level",
        default=20
    )
    htm_max_ranges = Field(
        dtype=int,
        doc="Max number of ranges in HTM envelope",
        default=64
    )
    htm_index_column = Field(
        dtype=str,
        default="pixelId",
        doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField(
        dtype=str,
        default=["ra", "decl"],
        doc="Names of ra/dec columns in DiaObject table"
    )
    dia_object_columns = ListField(
        dtype=str,
        doc="List of columns to read from DiaObject, by default read all columns",
        default=[]
    )
    object_last_replace = Field(
        dtype=bool,
        doc="If True (default) then use \"upsert\" for DiaObjectsLast table",
        default=True
    )
    prefix = Field(
        dtype=str,
        doc="Prefix to add to table names and index names",
        default=""
    )
    explain = Field(
        dtype=bool,
        doc="If True then run EXPLAIN SQL command on each executed query",
        default=False
    )
    timer = Field(
        dtype=bool,
        doc="If True then print/log timing information",
        default=False
    )

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
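
# A minimal configuration sketch (assumes an in-memory SQLite database and the
# field defaults inherited from ApdbConfig; illustration only, not a
# recommended production setup):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite://"
#     config.validate()
#     apdb = ApdbSql(config)
#     apdb.makeSchema(drop=True)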

class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config`` mechanism
    using the `ApdbSqlConfig` configuration class. For examples of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):

        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug(" object_last_replace: %s", self.config.object_last_replace)
        _LOG.debug(" schema_file: %s", self.config.schema_file)
        _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug(" schema prefix: %s", self.config.prefix)

        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(engine=self._engine,
                                     dia_object_index=self.config.dia_object_index,
                                     schema_file=self.config.schema_file,
                                     extra_schema_file=self.config.extra_schema_file,
                                     prefix=self.config.prefix,
                                     htm_index_column=self.config.htm_index_column)

        self.pixelator = HtmPixelization(self.config.htm_level)

    def tableRowCount(self) -> Dict[str, int]:
        """Return a dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables: List[sqlalchemy.schema.Table] = [
            self._schema.objects, self._schema.sources, self._schema.forcedSources]
        if self.config.dia_object_index == 'last_object_table':
            tables.append(self._schema.objects_last)
        for table in tables:
            stmt = sql.select([func.count()]).select_from(table)
            count = self._engine.scalar(stmt)
            res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        table: sqlalchemy.schema.Table
        if self.config.dia_object_index == 'last_object_table':
            table = self._schema.objects_last
        else:
            table = self._schema.objects
        if not self.config.dia_object_columns:
            query = table.select()
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
            query = sql.select(columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != 'last_object_table':
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        _LOG.debug("query: %s", query)

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('DiaObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(self, region: Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(self, region: Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If the list is empty then an empty catalog is returned
            with a correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned if
            the ``read_forced_sources_months`` configuration parameter is set
            to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be not-`None`.
        `NotImplementedError` is raised if `None` is passed.

        This method returns a DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaForcedSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter w.r.t. ``visit_time``.
        If ``object_ids`` is empty then an empty catalog is always returned
        with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.forcedSources
        with Timer('DiaForcedSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getDiaObjectsHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: Optional[dafBase.DateTime] = None,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.objects
        query = table.select()

        # build selection
        time_filter = table.columns["validityStart"] >= start_time.toPython()
        if end_time:
            time_filter = sql.expression.and_(
                time_filter,
                table.columns["validityStart"] < end_time.toPython()
            )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaObject history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects history records", len(catalog))
        return catalog

    def getDiaSourcesHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: Optional[dafBase.DateTime] = None,
                             region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD)
        if end_time:
            time_filter = sql.expression.and_(
                time_filter,
                table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
            )

        if region:
            where = sql.expression.and_(self._filterRegion(table, region), time_filter)
            query = query.where(where)
        else:
            query = query.where(time_filter)

        # execute select
        with Timer('DiaSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSource history records", len(catalog))
        return catalog

    def getDiaForcedSourcesHistory(self,
                                   start_time: dafBase.DateTime,
                                   end_time: Optional[dafBase.DateTime] = None,
                                   region: Optional[Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.forcedSources
        query = table.select()

        # build selection
        time_filter = table.columns["midPointTai"] >= start_time.get(system=dafBase.DateTime.MJD)
        if end_time:
            time_filter = sql.expression.and_(
                time_filter,
                table.columns["midPointTai"] < end_time.get(system=dafBase.DateTime.MJD)
            )
        # Forced sources have no pixel index, so no region filtering
        query = query.where(time_filter)

        # execute select
        with Timer('DiaForcedSource history select', self.config.timer):
            with self._engine.begin() as conn:
                catalog = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaForcedSource history records", len(catalog))
        return catalog

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        table = self._schema.ssObjects
        query = table.select()

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('SSObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill pixelId column for DiaObjects
        objects = self._add_obj_htm_index(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy pixelId column from DiaObjects to DiaSources
            sources = self._add_src_htm_index(sources, objects)
            self._storeDiaSources(sources)

        if forced_sources is not None:
            self._storeDiaForcedSources(forced_sources)

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.ssObjects

        # everything to be done in single transaction
        with self._engine.begin() as conn:

            # find record IDs that already exist
            ids = sorted(objects[idColumn])
            query = sql.select([table.columns[idColumn]]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row[idColumn] for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = objects[filter]
            toInsert = objects[~filter]

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(ApdbTables.SSObject.table_name(), conn, if_exists='append', index=False)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                query = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(query, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.sources
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # TODO: diaObjectId should probably be None but in our current
            # schema it is defined NOT NULL, may need to update for the future
            # schema.
            params = [dict(srcId=key, diaObjectId=0, ssObjectId=value) for key, value in idMap.items()]
            conn.execute(query, params)

    def dailyJob(self) -> None:
        # docstring is inherited from a base class

        if self._engine.name == 'postgresql':

            # do VACUUM on all tables
            _LOG.info("Running VACUUM on all tables")
            connection = self._engine.raw_connection()
            ISOLATION_LEVEL_AUTOCOMMIT = 0
            connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            cursor = connection.cursor()
            cursor.execute("VACUUM ANALYSE")

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.objects

        # Construct the sql statement.
        stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.scalar(stmt)

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime
                               ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        query = table.select()

        # build selection
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer('DiaSource select', self.config.timer):
            with _ansi_session(self._engine) as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime
                            ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids :
            Collection of DiaObject IDs
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        with Timer('DiaSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(self, table: sqlalchemy.schema.Table,
                         object_ids: List[int],
                         midPointTai_start: float
                         ) -> pandas.DataFrame:
        """Returns catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table : `sqlalchemy.schema.Table`
            Database table.
        object_ids :
            Collection of DiaObject IDs
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        sources: Optional[pandas.DataFrame] = None
        with _ansi_session(self._engine) as conn:
            if len(object_ids) <= 0:
                _LOG.debug("ID list is empty, just fetch empty result")
                query = table.select().where(False)
                sources = pandas.read_sql_query(query, conn)
            else:
                for ids in _split(sorted(object_ids), 1000):
                    query = f'SELECT * FROM "{table.name}" WHERE '

                    # select by object id
                    ids_str = ",".join(str(id) for id in ids)
                    query += f'"diaObjectId" IN ({ids_str})'
                    query += f' AND "midPointTai" > {midPointTai_start}'

                    # execute select
                    df = pandas.read_sql_query(sql.text(query), conn)
                    if sources is None:
                        sources = df
                    else:
                        sources = sources.append(df)
        assert sources is not None, "Catalog cannot be None"
        return sources
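
    # For illustration, a single chunk produced by the loop above renders as
    # (hypothetical IDs and MJD cutoff):
    #
    #     SELECT * FROM "DiaSource" WHERE "diaObjectId" IN (1,2,3) AND "midPointTai" > 59220.0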

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        """

        ids = sorted(objs['diaObjectId'])
        _LOG.debug("first object ID: %d", ids[0])

        # NOTE: workaround for sqlite, need this here to avoid
        # "database is locked" error.
        table: sqlalchemy.schema.Table = self._schema.objects

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            ids_str = ",".join(str(id) for id in ids)

            if self.config.dia_object_index == 'last_object_table':

                # insert and replace all records in LAST table, mysql and postgres have
                # non-standard features
                table = self._schema.objects_last
                do_replace = self.config.object_last_replace
                # If the input data is of type Pandas, we drop the previous
                # objects regardless of the do_replace setting due to how
                # Pandas inserts objects.
                if not do_replace or isinstance(objs, pandas.DataFrame):
                    query = 'DELETE FROM "' + table.name + '" '
                    query += 'WHERE "diaObjectId" IN (' + ids_str + ') '

                    if self.config.explain:
                        # run the same query with explain
                        self._explain(query, conn)

                    with Timer(table.name + ' delete', self.config.timer):
                        res = conn.execute(sql.text(query))
                    _LOG.debug("deleted %s objects", res.rowcount)

                extra_columns: Dict[str, Any] = dict(lastNonForcedSource=dt)
                with Timer("DiaObjectLast insert", self.config.timer):
                    objs = _coerce_uint64(objs)
                    for col, data in extra_columns.items():
                        objs[col] = data
                    objs.to_sql("DiaObjectLast", conn, if_exists='append',
                                index=False)
            else:

                # truncate existing validity intervals
                table = self._schema.objects
                query = 'UPDATE "' + table.name + '" '
                query += "SET \"validityEnd\" = '" + str(dt) + "' "
                query += 'WHERE "diaObjectId" IN (' + ids_str + ') '
                query += 'AND "validityEnd" IS NULL'

                # _LOG.debug("query: %s", query)

                if self.config.explain:
                    # run the same query with explain
                    self._explain(query, conn)

                with Timer(table.name + ' truncate', self.config.timer):
                    res = conn.execute(sql.text(query))
                _LOG.debug("truncated %s intervals", res.rowcount)

            # insert new versions
            table = self._schema.objects
            extra_columns = dict(lastNonForcedSource=dt, validityStart=dt,
                                 validityEnd=None)
            with Timer("DiaObject insert", self.config.timer):
                objs = _coerce_uint64(objs)
                for col, data in extra_columns.items():
                    objs[col] = data
                objs.to_sql("DiaObject", conn, if_exists='append',
                            index=False)

    def _storeDiaSources(self, sources: pandas.DataFrame) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        """
        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                sources.to_sql("DiaSource", conn, if_exists='append', index=False)

    def _storeDiaForcedSources(self, sources: pandas.DataFrame) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        """

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaForcedSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                sources.to_sql("DiaForcedSource", conn, if_exists='append', index=False)

    def _explain(self, query: str, conn: sqlalchemy.engine.Connection) -> None:
        """Run the query with explain
        """

        _LOG.info("explain for query: %s...", query[:64])

        if conn.engine.name == 'mysql':
            query = "EXPLAIN EXTENDED " + query
        else:
            query = "EXPLAIN " + query

        res = conn.execute(sql.text(query))
        if res.returns_rows:
            _LOG.info("explain: %s", res.keys())
            for row in res:
                _LOG.info("explain: %s", row)
        else:
            _LOG.info("EXPLAIN returned nothing")

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region: `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        Sequence of ranges, each range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug('region: %s', region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        if _LOG.isEnabledFor(logging.DEBUG):
            for irange in indices.ranges():
                _LOG.debug('range: %s %s', self.pixelator.toString(irange[0]),
                           self.pixelator.toString(irange[1]))

        return indices.ranges()
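
    # A hedged sketch of calling this method with a small circular region
    # (assumes ``from lsst.sphgeom import Angle, Circle``):
    #
    #     region = Circle(UnitVector3d(LonLat.fromDegrees(10., 20.)), Angle.fromDegrees(0.1))
    #     ranges = apdb._htm_indices(region)   # e.g. [(minId, maxId), ...]
    #
    # Each tuple is a half-open [minHtmID, maxHtmID) interval of level
    # ``htm_level`` HTM IDs; ``_filterRegion`` below converts the bounds to an
    # inclusive SQL expression.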

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ClauseElement:
        """Make SQLAlchemy expression for selecting records in a region.
        """
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
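
    # Example of the clause produced for hypothetical pixel ranges
    # [(100, 101), (200, 300)]: the expression renders roughly as
    #
    #     "pixelId" = 100 OR "pixelId" BETWEEN 200 AND 299
    #
    # because each half-open upper bound is decremented to make it inclusive.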

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId for diaObjectId, pixelId
            in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources