Coverage for python/lsst/dax/apdb/apdbSql.py: 13%

408 statements  

coverage.py v6.5.0, created at 2023-01-19 02:07 -0800

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining Apdb class and related methods. 

23""" 

24 

25from __future__ import annotations 

26 

27__all__ = ["ApdbSqlConfig", "ApdbSql"] 

28 

29import logging 

30from collections.abc import Iterable, Mapping, MutableMapping 

31from typing import Any, Dict, List, Optional, Tuple, cast 

32 

33import lsst.daf.base as dafBase 

34import numpy as np 

35import pandas 

36import sqlalchemy 

37from felis.simple import Table 

38from lsst.pex.config import ChoiceField, Field, ListField 

39from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d 

40from lsst.utils.iteration import chunk_iterable 

41from sqlalchemy import func, sql 

42from sqlalchemy.pool import NullPool 

43 

44from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData 

45from .apdbSchema import ApdbTables 

46from .apdbSqlSchema import ApdbSqlSchema, ExtraTables 

47from .timer import Timer 

48 

49_LOG = logging.getLogger(__name__) 

50 

51 

52def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame: 

53 """Change type of the uint64 columns to int64, return copy of data frame.""" 

54 names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64] 

55 return df.astype({name: np.int64 for name in names}) 

56 

57 
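A quick illustration of the helper above (not part of the measured module; the column name is arbitrary): a frame holding a uint64 column comes back with that column converted to int64.

df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
df = _coerce_uint64(df)
assert df["diaObjectId"].dtype == np.int64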

58def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float: 

59 """Calculate starting point for time-based source search. 

60 

61 Parameters 

62 ---------- 

63 visit_time : `lsst.daf.base.DateTime` 

64 Time of current visit. 

65 months : `int` 

66 Number of months of source history to include.

67 

68 Returns 

69 ------- 

70 time : `float` 

71 A ``midPointTai`` starting point, MJD time. 

72 """ 

73 # TODO: `system` must be consistent with the code in ap_association 

74 # (see DM-31996) 

75 return visit_time.get(system=dafBase.DateTime.MJD) - months * 30 

76 

77 
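A worked example of the 30-day-month arithmetic above, assuming a hypothetical visit at MJD 60000.0 and a 12-month history window:

visit_time = dafBase.DateTime(60000.0, dafBase.DateTime.MJD, dafBase.DateTime.TAI)
start = _make_midPointTai_start(visit_time, 12)  # 60000.0 - 12 * 30 = 59640.0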

78class ApdbSqlConfig(ApdbConfig): 

79 """APDB configuration class for SQL implementation (ApdbSql).""" 

80 

81 db_url = Field[str](doc="SQLAlchemy database connection URI") 

82 isolation_level = ChoiceField[str]( 

83 doc=( 

84 "Transaction isolation level, if unset then backend-default value " 

85 "is used, except for SQLite backend where we use READ_UNCOMMITTED. " 

86 "Some backends may not support every allowed value." 

87 ), 

88 allowed={ 

89 "READ_COMMITTED": "Read committed", 

90 "READ_UNCOMMITTED": "Read uncommitted", 

91 "REPEATABLE_READ": "Repeatable read", 

92 "SERIALIZABLE": "Serializable", 

93 }, 

94 default=None, 

95 optional=True, 

96 ) 

97 connection_pool = Field[bool]( 

98 doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.", 

99 default=True, 

100 ) 

101 connection_timeout = Field[float]( 

102 doc=( 

103 "Maximum time to wait time for database lock to be released before exiting. " 

104 "Defaults to sqlalchemy defaults if not set." 

105 ), 

106 default=None, 

107 optional=True, 

108 ) 

109 sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False) 

110 dia_object_index = ChoiceField[str]( 

111 doc="Indexing mode for DiaObject table", 

112 allowed={ 

113 "baseline": "Index defined in baseline schema", 

114 "pix_id_iov": "(pixelId, objectId, iovStart) PK", 

115 "last_object_table": "Separate DiaObjectLast table", 

116 }, 

117 default="baseline", 

118 ) 

119 htm_level = Field[int](doc="HTM indexing level", default=20) 

120 htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64) 

121 htm_index_column = Field[str]( 

122 default="pixelId", doc="Name of the HTM index column for DiaObject and DiaSource tables"

123 ) 

124 ra_dec_columns = ListField[str](default=["ra", "decl"], doc="Names of ra/dec columns in DiaObject table")

125 dia_object_columns = ListField[str]( 

126 doc="List of columns to read from DiaObject, by default read all columns", default=[] 

127 ) 

128 prefix = Field[str](doc="Prefix to add to table names and index names", default="") 

129 namespace = Field[str]( 

130 doc=( 

131 "Namespace or schema name for all tables in APDB database. " 

132 "Presently only makes sense for PostgresQL backend. " 

133 "If schema with this name does not exist it will be created when " 

134 "APDB tables are created." 

135 ), 

136 default=None, 

137 optional=True, 

138 ) 

139 timer = Field[bool](doc="If True then print/log timing information", default=False) 

140 

141 def validate(self) -> None: 

142 super().validate() 

143 if len(self.ra_dec_columns) != 2: 

144 raise ValueError("ra_dec_columns must have exactly two column names") 

145 

146 
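A minimal configuration sketch, with an illustrative SQLite URL and all other fields left at their defaults (not a recommended production setup):

config = ApdbSqlConfig()
config.db_url = "sqlite:///apdb.db"  # hypothetical local database file
config.dia_object_index = "baseline"
config.validate()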

147class ApdbSqlTableData(ApdbTableData): 

148 """Implementation of ApdbTableData that wraps sqlalchemy Result.""" 

149 

150 def __init__(self, result: sqlalchemy.engine.Result): 

151 self.result = result 

152 

153 def column_names(self) -> list[str]: 

154 return list(self.result.keys())

155 

156 def rows(self) -> Iterable[tuple]: 

157 for row in self.result: 

158 yield tuple(row) 

159 

160 

161class ApdbSql(Apdb): 

162 """Implementation of APDB interface based on SQL database. 

163 

164 The implementation is configured via standard ``pex_config`` mechanism 

165 using the `ApdbSqlConfig` configuration class. For examples of different

166 configurations check the ``config/`` folder.

167 

168 Parameters 

169 ---------- 

170 config : `ApdbSqlConfig` 

171 Configuration object. 

172 """ 

173 

174 ConfigClass = ApdbSqlConfig 

175 

176 def __init__(self, config: ApdbSqlConfig): 

177 

178 config.validate() 

179 self.config = config 

180 

181 _LOG.debug("APDB Configuration:") 

182 _LOG.debug(" dia_object_index: %s", self.config.dia_object_index) 

183 _LOG.debug(" read_sources_months: %s", self.config.read_sources_months) 

184 _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months) 

185 _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns) 

186 _LOG.debug(" schema_file: %s", self.config.schema_file) 

187 _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file) 

188 _LOG.debug(" schema prefix: %s", self.config.prefix) 

189 

190 # engine is reused between multiple processes, make sure that we don't 

191 # share connections by disabling pool (by using NullPool class) 

192 kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo) 

193 conn_args: Dict[str, Any] = dict() 

194 if not self.config.connection_pool: 

195 kw.update(poolclass=NullPool) 

196 if self.config.isolation_level is not None: 

197 kw.update(isolation_level=self.config.isolation_level) 

198 elif self.config.db_url.startswith("sqlite"): # type: ignore 

199 # Use READ_UNCOMMITTED as default value for sqlite. 

200 kw.update(isolation_level="READ_UNCOMMITTED") 

201 if self.config.connection_timeout is not None: 

202 if self.config.db_url.startswith("sqlite"): 

203 conn_args.update(timeout=self.config.connection_timeout) 

204 elif self.config.db_url.startswith(("postgresql", "mysql")): 

205 conn_args.update(connect_timeout=self.config.connection_timeout) 

206 kw.update(connect_args=conn_args) 

207 self._engine = sqlalchemy.create_engine(self.config.db_url, **kw) 

208 

209 self._schema = ApdbSqlSchema( 

210 engine=self._engine, 

211 dia_object_index=self.config.dia_object_index, 

212 schema_file=self.config.schema_file, 

213 schema_name=self.config.schema_name, 

214 prefix=self.config.prefix, 

215 namespace=self.config.namespace, 

216 htm_index_column=self.config.htm_index_column, 

217 use_insert_id=config.use_insert_id, 

218 ) 

219 

220 self.pixelator = HtmPixelization(self.config.htm_level) 

221 self.use_insert_id = self._schema.has_insert_id 

222 

223 def tableRowCount(self) -> Dict[str, int]: 

224 """Returns dictionary with the table names and row counts. 

225 

226 Used by ``ap_proto`` to keep track of the size of the database tables. 

227 Depending on the database technology this could be an expensive operation.

228 

229 Returns 

230 ------- 

231 row_counts : `dict` 

232 Dict where key is a table name and value is a row count. 

233 """ 

234 res = {} 

235 tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource] 

236 if self.config.dia_object_index == "last_object_table": 

237 tables.append(ApdbTables.DiaObjectLast) 

238 for table in tables: 

239 sa_table = self._schema.get_table(table) 

240 stmt = sql.select([func.count()]).select_from(sa_table) 

241 count = self._engine.scalar(stmt) 

242 res[table.name] = count 

243 

244 return res 

245 

246 def tableDef(self, table: ApdbTables) -> Optional[Table]: 

247 # docstring is inherited from a base class 

248 return self._schema.tableSchemas.get(table) 

249 

250 def makeSchema(self, drop: bool = False) -> None: 

251 # docstring is inherited from a base class 

252 self._schema.makeSchema(drop=drop) 

253 

254 def getDiaObjects(self, region: Region) -> pandas.DataFrame: 

255 # docstring is inherited from a base class 

256 

257 # decide what columns we need 

258 if self.config.dia_object_index == "last_object_table": 

259 table_enum = ApdbTables.DiaObjectLast 

260 else: 

261 table_enum = ApdbTables.DiaObject 

262 table = self._schema.get_table(table_enum) 

263 if not self.config.dia_object_columns: 

264 columns = self._schema.get_apdb_columns(table_enum) 

265 else: 

266 columns = [table.c[col] for col in self.config.dia_object_columns] 

267 query = sql.select(*columns) 

268 

269 # build selection 

270 query = query.where(self._filterRegion(table, region)) 

271 

272 # select latest version of objects 

273 if self.config.dia_object_index != "last_object_table": 

274 query = query.where(table.c.validityEnd == None) # noqa: E711 

275 

276 # _LOG.debug("query: %s", query) 

277 

278 # execute select 

279 with Timer("DiaObject select", self.config.timer): 

280 with self._engine.begin() as conn: 

281 objects = pandas.read_sql_query(query, conn) 

282 _LOG.debug("found %s DiaObjects", len(objects)) 

283 return objects 

284 
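A sketch of a region-based read using the method above; the coordinates, radius and the ``apdb`` instance are hypothetical:

from lsst.sphgeom import Angle, Circle

center = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))
region = Circle(center, Angle.fromDegrees(0.05))  # small cone around the point
objects = apdb.getDiaObjects(region)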

285 def getDiaSources( 

286 self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime 

287 ) -> Optional[pandas.DataFrame]: 

288 # docstring is inherited from a base class 

289 if self.config.read_sources_months == 0: 

290 _LOG.debug("Skip DiaSources fetching") 

291 return None 

292 

293 if object_ids is None: 

294 # region-based select 

295 return self._getDiaSourcesInRegion(region, visit_time) 

296 else: 

297 return self._getDiaSourcesByIDs(list(object_ids), visit_time) 

298 

299 def getDiaForcedSources( 

300 self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime 

301 ) -> Optional[pandas.DataFrame]: 

302 """Return catalog of DiaForcedSource instances from a given region. 

303 

304 Parameters 

305 ---------- 

306 region : `lsst.sphgeom.Region` 

307 Region to search for DIAForcedSources.

308 object_ids : iterable [ `int` ], optional 

309 List of DiaObject IDs to further constrain the set of returned 

310 sources. If list is empty then empty catalog is returned with a 

311 correct schema. 

312 visit_time : `lsst.daf.base.DateTime` 

313 Time of the current visit. 

314 

315 Returns 

316 ------- 

317 catalog : `pandas.DataFrame`, or `None` 

318 Catalog containing DiaForcedSource records. `None` is returned if

319 ``read_forced_sources_months`` configuration parameter is set to 0.

320 

321 Raises 

322 ------ 

323 NotImplementedError 

324 Raised if ``object_ids`` is `None`. 

325 

326 Notes 

327 ----- 

328 Even though base class allows `None` to be passed for ``object_ids``, 

329 this class requires ``object_ids`` to be not-`None`. 

330 `NotImplementedError` is raised if `None` is passed. 

331 

332 This method returns a DiaForcedSource catalog for a region with additional

333 filtering based on DiaObject IDs. Only a subset of the DiaForcedSource

334 history is returned, limited by the ``read_forced_sources_months`` config

335 parameter relative to ``visit_time``. If ``object_ids`` is empty then an

336 empty catalog with the correct schema (columns/types) is always returned.

337 """ 

338 

339 if self.config.read_forced_sources_months == 0: 

340 _LOG.debug("Skip DiaForcedSources fetching")

341 return None 

342 

343 if object_ids is None: 

344 # This implementation does not support region-based selection. 

345 raise NotImplementedError("Region-based selection is not supported") 

346 

347 # TODO: DateTime.MJD must be consistent with code in ap_association, 

348 # alternatively we can fill midPointTai ourselves in store() 

349 midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months) 

350 _LOG.debug("midPointTai_start = %.6f", midPointTai_start) 

351 

352 with Timer("DiaForcedSource select", self.config.timer): 

353 sources = self._getSourcesByIDs(ApdbTables.DiaForcedSource, list(object_ids), midPointTai_start) 

354 

355 _LOG.debug("found %s DiaForcedSources", len(sources)) 

356 return sources 

357 

358 def getInsertIds(self) -> list[ApdbInsertId] | None: 

359 # docstring is inherited from a base class 

360 if not self._schema.has_insert_id: 

361 return None 

362 

363 table = self._schema.get_table(ExtraTables.DiaInsertId) 

364 assert table is not None, "has_insert_id=True means it must be defined" 

365 query = sql.select(table.columns["insert_id"]).order_by(table.columns["insert_time"]) 

366 with Timer("DiaObject insert id select", self.config.timer): 

367 with self._engine.connect() as conn: 

368 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query) 

369 return [ApdbInsertId(row) for row in result.scalars()] 

370 

371 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None: 

372 # docstring is inherited from a base class 

373 if not self._schema.has_insert_id: 

374 raise ValueError("APDB is not configured for history storage") 

375 

376 table = self._schema.get_table(ExtraTables.DiaInsertId) 

377 

378 insert_ids = [id.id for id in ids] 

379 where_clause = table.columns["insert_id"].in_(insert_ids) 

380 stmt = table.delete().where(where_clause) 

381 with self._engine.begin() as conn: 

382 conn.execute(stmt) 

383 

384 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

385 # docstring is inherited from a base class 

386 return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId) 

387 

388 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

389 # docstring is inherited from a base class 

390 return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId) 

391 

392 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

393 # docstring is inherited from a base class 

394 return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId) 

395 

396 def _get_history( 

397 self, 

398 ids: Iterable[ApdbInsertId], 

399 table_enum: ApdbTables, 

400 history_table_enum: ExtraTables, 

401 ) -> ApdbTableData: 

402 """Common implementation of the history methods.""" 

403 if not self._schema.has_insert_id: 

404 raise ValueError("APDB is not configured for history retrieval") 

405 

406 table = self._schema.get_table(table_enum) 

407 history_table = self._schema.get_table(history_table_enum) 

408 

409 join = table.join(history_table) 

410 insert_ids = [id.id for id in ids] 

411 history_id_column = history_table.columns["insert_id"] 

412 apdb_columns = self._schema.get_apdb_columns(table_enum) 

413 where_clause = history_id_column.in_(insert_ids) 

414 query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause) 

415 

416 # execute select 

417 with Timer(f"{table.name} history select", self.config.timer): 

418 connection = self._engine.connect(close_with_result=True) 

419 result = connection.execution_options(stream_results=True, max_row_buffer=10000).execute(query) 

420 return ApdbSqlTableData(result) 

421 

422 def getSSObjects(self) -> pandas.DataFrame: 

423 # docstring is inherited from a base class 

424 

425 columns = self._schema.get_apdb_columns(ApdbTables.SSObject) 

426 query = sql.select(*columns) 

427 

428 # execute select 

429 with Timer("DiaObject select", self.config.timer): 

430 with self._engine.begin() as conn: 

431 objects = pandas.read_sql_query(query, conn) 

432 _LOG.debug("found %s SSObjects", len(objects)) 

433 return objects 

434 

435 def store( 

436 self, 

437 visit_time: dafBase.DateTime, 

438 objects: pandas.DataFrame, 

439 sources: Optional[pandas.DataFrame] = None, 

440 forced_sources: Optional[pandas.DataFrame] = None, 

441 ) -> None: 

442 # docstring is inherited from a base class 

443 

444 # We want to run all inserts in one transaction. 

445 with self._engine.begin() as connection: 

446 

447 insert_id: ApdbInsertId | None = None 

448 if self._schema.has_insert_id: 

449 insert_id = ApdbInsertId.new_insert_id() 

450 self._storeInsertId(insert_id, visit_time, connection) 

451 

452 # fill pixelId column for DiaObjects 

453 objects = self._add_obj_htm_index(objects) 

454 self._storeDiaObjects(objects, visit_time, insert_id, connection) 

455 

456 if sources is not None: 

457 # copy pixelId column from DiaObjects to DiaSources 

458 sources = self._add_src_htm_index(sources, objects) 

459 self._storeDiaSources(sources, insert_id, connection) 

460 

461 if forced_sources is not None: 

462 self._storeDiaForcedSources(forced_sources, insert_id, connection) 

463 
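The expected call pattern for ``store``, heavily abridged; the DataFrames are assumed to already follow the APDB schema and their construction is outside the scope of this module:

visit_time = dafBase.DateTime.now()
apdb.store(visit_time, objects, sources=sources, forced_sources=forced_sources)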

464 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

465 # docstring is inherited from a base class 

466 

467 idColumn = "ssObjectId" 

468 table = self._schema.get_table(ApdbTables.SSObject) 

469 

470 # everything to be done in single transaction 

471 with self._engine.begin() as conn: 

472 

473 # Find record IDs that already exist. Some types like np.int64 can 

474 # cause issues with sqlalchemy, convert them to int. 

475 ids = sorted(int(oid) for oid in objects[idColumn]) 

476 

477 query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))

478 result = conn.execute(query) 

479 knownIds = set(row[idColumn] for row in result) 

480 

481 filter = objects[idColumn].isin(knownIds) 

482 toUpdate = cast(pandas.DataFrame, objects[filter]) 

483 toInsert = cast(pandas.DataFrame, objects[~filter]) 

484 

485 # insert new records 

486 if len(toInsert) > 0: 

487 toInsert.to_sql(table.name, conn, if_exists="append", index=False, schema=table.schema) 

488 

489 # update existing records 

490 if len(toUpdate) > 0: 

491 whereKey = f"{idColumn}_param" 

492 query = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey)) 

493 toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns") 

494 values = toUpdate.to_dict("records") 

495 result = conn.execute(query, values) 

496 

497 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

498 # docstring is inherited from a base class 

499 

500 table = self._schema.get_table(ApdbTables.DiaSource) 

501 query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId")) 

502 

503 with self._engine.begin() as conn: 

504 # Need to make sure that every ID exists in the database, but 

505 # executemany may not support rowcount, so iterate and check what is 

506 # missing. 

507 missing_ids: List[int] = [] 

508 for key, value in idMap.items(): 

509 params = dict(srcId=key, diaObjectId=0, ssObjectId=value) 

510 result = conn.execute(query, params) 

511 if result.rowcount == 0: 

512 missing_ids.append(key) 

513 if missing_ids: 

514 missing = ",".join(str(item) for item in missing_ids) 

515 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

516 
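A hypothetical reassignment using the method above: DiaSource 12345 becomes associated with SSObject 67890, and the update also resets the source's ``diaObjectId`` to 0:

apdb.reassignDiaSources({12345: 67890})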

517 def dailyJob(self) -> None: 

518 # docstring is inherited from a base class 

519 

520 if self._engine.name == "postgresql": 

521 

522 # do VACUUM on all tables 

523 _LOG.info("Running VACUUM on all tables") 

524 connection = self._engine.raw_connection() 

525 ISOLATION_LEVEL_AUTOCOMMIT = 0 

526 connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) 

527 cursor = connection.cursor() 

528 cursor.execute("VACUUM ANALYSE") 

529 

530 def countUnassociatedObjects(self) -> int: 

531 # docstring is inherited from a base class 

532 

533 # Retrieve the DiaObject table. 

534 table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject) 

535 

536 # Construct the sql statement. 

537 stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1) 

538 stmt = stmt.where(table.c.validityEnd == None) # noqa: E711 

539 

540 # Return the count. 

541 with self._engine.begin() as conn: 

542 count = conn.scalar(stmt) 

543 

544 return count 

545 

546 def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame: 

547 """Returns catalog of DiaSource instances from given region. 

548 

549 Parameters 

550 ---------- 

551 region : `lsst.sphgeom.Region` 

552 Region to search for DIASources. 

553 visit_time : `lsst.daf.base.DateTime` 

554 Time of the current visit. 

555 

556 Returns 

557 ------- 

558 catalog : `pandas.DataFrame` 

559 Catalog containing DiaSource records. 

560 """ 

561 # TODO: DateTime.MJD must be consistent with code in ap_association, 

562 # alternatively we can fill midPointTai ourselves in store() 

563 midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months) 

564 _LOG.debug("midPointTai_start = %.6f", midPointTai_start) 

565 

566 table = self._schema.get_table(ApdbTables.DiaSource) 

567 columns = self._schema.get_apdb_columns(ApdbTables.DiaSource) 

568 query = sql.select(*columns) 

569 

570 # build selection 

571 time_filter = table.columns["midPointTai"] > midPointTai_start 

572 where = sql.expression.and_(self._filterRegion(table, region), time_filter) 

573 query = query.where(where) 

574 

575 # execute select 

576 with Timer("DiaSource select", self.config.timer): 

577 with self._engine.begin() as conn: 

578 sources = pandas.read_sql_query(query, conn) 

579 _LOG.debug("found %s DiaSources", len(sources)) 

580 return sources 

581 

582 def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime) -> pandas.DataFrame: 

583 """Returns catalog of DiaSource instances given set of DiaObject IDs. 

584 

585 Parameters 

586 ---------- 

587 object_ids : `list` [ `int` ]

588 Collection of DiaObject IDs.

589 visit_time : `lsst.daf.base.DateTime` 

590 Time of the current visit. 

591 

592 Returns 

593 ------- 

594 catalog : `pandas.DataFrame` 

595 Catalog containing DiaSource records.

596 """ 

597 # TODO: DateTime.MJD must be consistent with code in ap_association, 

598 # alternatively we can fill midPointTai ourselves in store() 

599 midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months) 

600 _LOG.debug("midPointTai_start = %.6f", midPointTai_start) 

601 

602 with Timer("DiaSource select", self.config.timer): 

603 sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midPointTai_start) 

604 

605 _LOG.debug("found %s DiaSources", len(sources)) 

606 return sources 

607 

608 def _getSourcesByIDs( 

609 self, table_enum: ApdbTables, object_ids: List[int], midPointTai_start: float 

610 ) -> pandas.DataFrame: 

611 """Returns catalog of DiaSource or DiaForcedSource instances given set 

612 of DiaObject IDs. 

613 

614 Parameters 

615 ---------- 

616 table_enum : `ApdbTables`

617 APDB table to query.

618 object_ids : `list` [ `int` ]

619 Collection of DiaObject IDs.

620 midPointTai_start : `float` 

621 Earliest midPointTai to retrieve. 

622 

623 Returns 

624 ------- 

625 catalog : `pandas.DataFrame` 

626 Catalog containing DiaSource records. An empty catalog with the

627 correct schema (columns/types) is returned when ``object_ids`` is

628 empty.

629 """ 

630 table = self._schema.get_table(table_enum) 

631 columns = self._schema.get_apdb_columns(table_enum) 

632 

633 sources: Optional[pandas.DataFrame] = None 

634 if len(object_ids) <= 0: 

635 _LOG.debug("ID list is empty, just fetch empty result") 

636 query = sql.select(*columns).where(False) 

637 with self._engine.begin() as conn: 

638 sources = pandas.read_sql_query(query, conn) 

639 else: 

640 data_frames: list[pandas.DataFrame] = [] 

641 for ids in chunk_iterable(sorted(object_ids), 1000): 

642 query = sql.select(*columns) 

643 

644 # Some types like np.int64 can cause issues with 

645 # sqlalchemy, convert them to int. 

646 int_ids = [int(oid) for oid in ids] 

647 

648 # select by object id 

649 query = query.where( 

650 sql.expression.and_( 

651 table.columns["diaObjectId"].in_(int_ids), 

652 table.columns["midPointTai"] > midPointTai_start, 

653 ) 

654 ) 

655 

656 # execute select 

657 with self._engine.begin() as conn: 

658 data_frames.append(pandas.read_sql_query(query, conn)) 

659 

660 if len(data_frames) == 1: 

661 sources = data_frames[0] 

662 else: 

663 sources = pandas.concat(data_frames) 

664 assert sources is not None, "Catalog cannot be None" 

665 return sources 

666 

667 def _storeInsertId( 

668 self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection 

669 ) -> None: 

670 

671 dt = visit_time.toPython() 

672 

673 table = self._schema.get_table(ExtraTables.DiaInsertId) 

674 

675 stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt) 

676 connection.execute(stmt) 

677 

678 def _storeDiaObjects( 

679 self, 

680 objs: pandas.DataFrame, 

681 visit_time: dafBase.DateTime, 

682 insert_id: ApdbInsertId | None, 

683 connection: sqlalchemy.engine.Connection, 

684 ) -> None: 

685 """Store catalog of DiaObjects from current visit. 

686 

687 Parameters 

688 ---------- 

689 objs : `pandas.DataFrame` 

690 Catalog with DiaObject records. 

691 visit_time : `lsst.daf.base.DateTime` 

692 Time of the visit. 

693 insert_id : `ApdbInsertId` or `None`

694 Insert identifier, or `None` when history is not being stored.

695 """ 

696 

697 # Some types like np.int64 can cause issues with sqlalchemy, convert 

698 # them to int. 

699 ids = sorted(int(oid) for oid in objs["diaObjectId"]) 

700 _LOG.debug("first object ID: %d", ids[0]) 

701 

702 # TODO: Need to verify that we are using correct scale here for 

703 # DATETIME representation (see DM-31996). 

704 dt = visit_time.toPython() 

705 

706 # everything to be done in single transaction 

707 if self.config.dia_object_index == "last_object_table": 

708 

709 # Insert and replace all records in the LAST table; MySQL and Postgres

710 # only have non-standard upsert features, so delete then insert instead.

711 table = self._schema.get_table(ApdbTables.DiaObjectLast) 

712 

713 # Drop the previous objects (pandas cannot upsert). 

714 query = table.delete().where(table.columns["diaObjectId"].in_(ids)) 

715 

716 with Timer(table.name + " delete", self.config.timer): 

717 res = connection.execute(query) 

718 _LOG.debug("deleted %s objects", res.rowcount) 

719 

720 # DiaObjectLast is a subset of DiaObject, strip missing columns 

721 last_column_names = [column.name for column in table.columns] 

722 last_objs = objs[last_column_names] 

723 last_objs = _coerce_uint64(last_objs) 

724 

725 if "lastNonForcedSource" in last_objs.columns: 

726 # lastNonForcedSource is defined NOT NULL, fill it with visit time 

727 # just in case. 

728 last_objs["lastNonForcedSource"].fillna(dt, inplace=True) 

729 else: 

730 extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource") 

731 last_objs.set_index(extra_column.index, inplace=True) 

732 last_objs = pandas.concat([last_objs, extra_column], axis="columns") 

733 

734 with Timer("DiaObjectLast insert", self.config.timer): 

735 last_objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema) 

736 else: 

737 

738 # truncate existing validity intervals 

739 table = self._schema.get_table(ApdbTables.DiaObject) 

740 

741 query = ( 

742 table.update() 

743 .values(validityEnd=dt) 

744 .where( 

745 sql.expression.and_( 

746 table.columns["diaObjectId"].in_(ids), 

747 table.columns["validityEnd"].is_(None), 

748 ) 

749 ) 

750 ) 

751 

752 # _LOG.debug("query: %s", query) 

753 

754 with Timer(table.name + " truncate", self.config.timer): 

755 res = connection.execute(query) 

756 _LOG.debug("truncated %s intervals", res.rowcount) 

757 

758 objs = _coerce_uint64(objs) 

759 

760 # Fill additional columns 

761 extra_columns: List[pandas.Series] = [] 

762 if "validityStart" in objs.columns: 

763 objs["validityStart"] = dt 

764 else: 

765 extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart")) 

766 if "validityEnd" in objs.columns: 

767 objs["validityEnd"] = None 

768 else: 

769 extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd")) 

770 if "lastNonForcedSource" in objs.columns: 

771 # lastNonForcedSource is defined NOT NULL, fill it with visit time 

772 # just in case. 

773 objs["lastNonForcedSource"].fillna(dt, inplace=True) 

774 else: 

775 extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource")) 

776 if extra_columns: 

777 objs.set_index(extra_columns[0].index, inplace=True) 

778 objs = pandas.concat([objs] + extra_columns, axis="columns") 

779 

780 # Insert history data 

781 table = self._schema.get_table(ApdbTables.DiaObject) 

782 history_data: list[dict] = [] 

783 history_stmt: Any = None 

784 if insert_id is not None: 

785 pk_names = [column.name for column in table.primary_key] 

786 history_data = objs[pk_names].to_dict("records") 

787 for row in history_data: 

788 row["insert_id"] = insert_id.id 

789 history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId) 

790 history_stmt = history_table.insert() 

791 

792 # insert new versions 

793 with Timer("DiaObject insert", self.config.timer): 

794 objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema) 

795 if history_stmt is not None: 

796 connection.execute(history_stmt, *history_data) 

797 

798 def _storeDiaSources( 

799 self, 

800 sources: pandas.DataFrame, 

801 insert_id: ApdbInsertId | None, 

802 connection: sqlalchemy.engine.Connection, 

803 ) -> None: 

804 """Store catalog of DiaSources from current visit. 

805 

806 Parameters 

807 ---------- 

808 sources : `pandas.DataFrame` 

809 Catalog containing DiaSource records.

810 """ 

811 table = self._schema.get_table(ApdbTables.DiaSource) 

812 

813 # Insert history data 

814 history: list[dict] = [] 

815 history_stmt: Any = None 

816 if insert_id is not None: 

817 pk_names = [column.name for column in table.primary_key] 

818 history = sources[pk_names].to_dict("records") 

819 for row in history: 

820 row["insert_id"] = insert_id.id 

821 history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId) 

822 history_stmt = history_table.insert() 

823 

824 # everything to be done in single transaction 

825 with Timer("DiaSource insert", self.config.timer): 

826 sources = _coerce_uint64(sources) 

827 sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema) 

828 if history_stmt is not None: 

829 connection.execute(history_stmt, *history) 

830 

831 def _storeDiaForcedSources( 

832 self, 

833 sources: pandas.DataFrame, 

834 insert_id: ApdbInsertId | None, 

835 connection: sqlalchemy.engine.Connection, 

836 ) -> None: 

837 """Store a set of DiaForcedSources from current visit. 

838 

839 Parameters 

840 ---------- 

841 sources : `pandas.DataFrame` 

842 Catalog containing DiaForcedSource records.

843 """ 

844 table = self._schema.get_table(ApdbTables.DiaForcedSource) 

845 

846 # Insert history data 

847 history: list[dict] = [] 

848 history_stmt: Any = None 

849 if insert_id is not None: 

850 pk_names = [column.name for column in table.primary_key] 

851 history = sources[pk_names].to_dict("records") 

852 for row in history: 

853 row["insert_id"] = insert_id.id 

854 history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId) 

855 history_stmt = history_table.insert() 

856 

857 # everything to be done in single transaction 

858 with Timer("DiaForcedSource insert", self.config.timer): 

859 sources = _coerce_uint64(sources) 

860 sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema) 

861 if history_stmt is not None: 

862 connection.execute(history_stmt, *history) 

863 

864 def _htm_indices(self, region: Region) -> List[Tuple[int, int]]: 

865 """Generate a set of HTM indices covering specified region. 

866 

867 Parameters 

868 ---------- 

869 region : `lsst.sphgeom.Region`

870 Region that needs to be indexed. 

871 

872 Returns 

873 ------- 

874 Sequence of ranges; each range is a tuple ``(minHtmID, maxHtmID)``.

875 """ 

876 _LOG.debug("region: %s", region) 

877 indices = self.pixelator.envelope(region, self.config.htm_max_ranges) 

878 

879 return indices.ranges() 

880 

881 def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ClauseElement: 

882 """Make SQLAlchemy expression for selecting records in a region.""" 

883 htm_index_column = table.columns[self.config.htm_index_column] 

884 exprlist = [] 

885 pixel_ranges = self._htm_indices(region) 

886 for low, upper in pixel_ranges: 

887 upper -= 1 

888 if low == upper: 

889 exprlist.append(htm_index_column == low) 

890 else: 

891 exprlist.append(sql.expression.between(htm_index_column, low, upper)) 

892 

893 return sql.expression.or_(*exprlist) 

894 
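For illustration, two hypothetical pixel ranges returned by ``_htm_indices`` would be rendered by the method above roughly as follows (range upper bounds are exclusive, hence the ``upper -= 1``):

# (512, 516) -> pixelId BETWEEN 512 AND 515
# (520, 521) -> pixelId = 520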

895 def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame: 

896 """Calculate HTM index for each record and add it to a DataFrame. 

897 

898 Notes 

899 ----- 

900 This overrides any existing column in a DataFrame with the same name 

901 (pixelId). Original DataFrame is not changed, copy of a DataFrame is 

902 returned. 

903 """ 

904 # calculate HTM index for every DiaObject 

905 htm_index = np.zeros(df.shape[0], dtype=np.int64) 

906 ra_col, dec_col = self.config.ra_dec_columns 

907 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

908 uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec)) 

909 idx = self.pixelator.index(uv3d) 

910 htm_index[i] = idx 

911 df = df.copy() 

912 df[self.config.htm_index_column] = htm_index 

913 return df 

914 
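For a single hypothetical (ra, dec) pair the per-row computation above reduces to the following, using the default pixelization level from the config:

pixelator = HtmPixelization(20)
pixel_id = pixelator.index(UnitVector3d(LonLat.fromDegrees(45.0, -30.0)))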

915 def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

916 """Add pixelId column to DiaSource catalog. 

917 

918 Notes 

919 ----- 

920 This method copies pixelId value from a matching DiaObject record. 

921 DiaObject catalog needs to have a pixelId column filled by 

922 ``_add_obj_htm_index`` method and DiaSource records need to be 

923 associated to DiaObjects via ``diaObjectId`` column. 

924 

925 This overrides any existing column in a DataFrame with the same name 

926 (pixelId). Original DataFrame is not changed, copy of a DataFrame is 

927 returned. 

928 """ 

929 pixel_id_map: Dict[int, int] = { 

930 diaObjectId: pixelId 

931 for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column]) 

932 } 

933 # DiaSources associated with SolarSystemObjects do not have an 

934 # associated DiaObject hence we skip them and set their htmIndex 

935 # value to 0. 

936 pixel_id_map[0] = 0 

937 htm_index = np.zeros(sources.shape[0], dtype=np.int64) 

938 for i, diaObjId in enumerate(sources["diaObjectId"]): 

939 htm_index[i] = pixel_id_map[diaObjId] 

940 sources = sources.copy() 

941 sources[self.config.htm_index_column] = htm_index 

942 return sources