Coverage for python/lsst/dax/apdb/apdbSql.py: 15%


# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Callable, Iterable, Mapping, MutableMapping
from typing import Any, Dict, List, Optional, Tuple, cast

import lsst.daf.base as dafBase
import numpy as np
import pandas
import sqlalchemy
from felis.simple import Table
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, inspection, sql
from sqlalchemy.engine import Inspector
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
from .apdbSchema import ApdbTables
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
from .timer import Timer

_LOG = logging.getLogger(__name__)


if pandas.__version__.partition(".")[0] == "1":

    class _ConnectionHackSA2(sqlalchemy.engine.Connectable):
        """Terrible hack to work around pandas 1's incomplete support for
        SQLAlchemy 2.

        We need to pass a Connection instance to pandas methods, but in
        SQLAlchemy 2 the Connection class lost the ``connect`` method that
        pandas uses.
        """

        def __init__(self, connection: sqlalchemy.engine.Connection):
            self._connection = connection

        def connect(self, **kwargs: Any) -> Any:
            return self

        @property
        def execute(self) -> Callable:
            return self._connection.execute

        @property
        def execution_options(self) -> Callable:
            return self._connection.execution_options

        @property
        def connection(self) -> Any:
            return self._connection.connection

        def __enter__(self) -> sqlalchemy.engine.Connection:
            return self._connection

        def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
            # Do not close the connection here.
            pass

    @inspection._inspects(_ConnectionHackSA2)
    def _connection_insp(conn: _ConnectionHackSA2) -> Inspector:
        return Inspector._construct(Inspector._init_connection, conn._connection)

else:
    # Pandas 2.0 supports SQLAlchemy 2 correctly.
    def _ConnectionHackSA2(  # type: ignore[no-redef]
        conn: sqlalchemy.engine.Connectable,
    ) -> sqlalchemy.engine.Connectable:
        return conn
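
# A minimal usage sketch (illustrative only, not part of the module): the
# wrapper lets pandas 1 treat a SQLAlchemy 2 ``Connection`` as a
# ``Connectable``. The in-memory SQLite engine is a hypothetical example.
#
#     engine = sqlalchemy.create_engine("sqlite://")
#     with engine.begin() as conn:
#         df = pandas.read_sql_query(
#             sql.select(sql.literal(1).label("x")), _ConnectionHackSA2(conn)
#         )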


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
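
# A small illustration (hypothetical values): uint64 columns are not portable
# to all database backends, so they are coerced to int64 before writing.
#
#     df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64)})
#     _coerce_uint64(df).dtypes["id"]  # -> dtype('int64')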


def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
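
# A worked example (hypothetical numbers, assuming the standard
# lsst.daf.base.DateTime constructor): a month is approximated as 30 days,
# so a visit at MJD 60000.0 with a 12-month window starts at
# 60000 - 12 * 30 = 59640.0.
#
#     t = dafBase.DateTime(60000.0, dafBase.DateTime.MJD, dafBase.DateTime.TAI)
#     _make_midPointTai_start(t, 12)  # -> 59640.0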


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level. If unset then the backend-default "
            "value is used, except for the SQLite backend where we use "
            "READ_UNCOMMITTED. Some backends may not support every allowed "
            "value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to the SQLAlchemy default if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of the HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "decl"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject; by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If a schema with this name does not exist it will be created "
            "when APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
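
# A configuration sketch (the PostgreSQL URL and schema name are hypothetical
# examples; see the ``config/`` folder for real configurations):
#
#     config = ApdbSqlConfig()
#     config.db_url = "postgresql://user@host/apdb"
#     config.namespace = "apdb_schema"
#     config.dia_object_index = "last_object_table"
#     config.validate()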


class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self._keys = list(result.keys())
        self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))

    def column_names(self) -> list[str]:
        return self._keys

    def rows(self) -> Iterable[tuple]:
        return self._rows
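
# Consumers can pair column names with row tuples; a sketch (``data`` would
# come from one of the history methods of ApdbSql below):
#
#     records = [dict(zip(data.column_names(), row)) for row in data.rows()]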


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config``
    mechanism using the `ApdbSqlConfig` configuration class. For examples of
    different configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):
        config.validate()
        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # The engine is reused between multiple processes; make sure that we
        # do not share connections by disabling the pool (using the NullPool
        # class).
        kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id
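
    # A minimal end-to-end sketch (the SQLite URL is a hypothetical example):
    #
    #     config = ApdbSqlConfig()
    #     config.db_url = "sqlite:///apdb.db"
    #     apdb = ApdbSql(config)
    #     apdb.makeSchema()
    #     objects = apdb.getDiaObjects(region)  # region: lsst.sphgeom.Region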

    def tableRowCount(self) -> Dict[str, int]:
        """Return a dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database
        tables. Depending on database technology this could be an expensive
        operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[Table]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects
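
    # A region-query sketch (hypothetical coordinates; ``Circle`` and
    # ``Angle`` come from lsst.sphgeom and are not imported by this module):
    #
    #     from lsst.sphgeom import Angle, Circle
    #
    #     center = UnitVector3d(LonLat.fromDegrees(10.0, -5.0))
    #     objects = apdb.getDiaObjects(Circle(center, Angle.fromDegrees(0.1)))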

    def getDiaSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If the list is empty then an empty catalog is returned
            with a correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned
            if the ``read_forced_sources_months`` configuration parameter is
            set to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be
        not-`None`. `NotImplementedError` is raised if `None` is passed.

        This method returns a DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter w.r.t.
        ``visit_time``. If ``object_ids`` is empty then an empty catalog is
        always returned with a correct schema (columns/types).
        """
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaForcedSource, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"]).order_by(table.columns["insert_time"])
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return [ApdbInsertId(row) for row in result.scalars()]

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Common implementation of the history methods."""
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            with self._engine.begin() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return ApdbSqlTableData(result)
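
    # A history-retrieval sketch: fetch all known insert identifiers, then
    # read the rows that were written under them.
    #
    #     insert_ids = apdb.getInsertIds()
    #     if insert_ids is not None:
    #         data = apdb.getDiaSourcesHistory(insert_ids)
    #         rows = list(data.rows())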

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: dafBase.DateTime,
        objects: pandas.DataFrame,
        sources: Optional[pandas.DataFrame] = None,
        forced_sources: Optional[pandas.DataFrame] = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id()
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)
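
    # A store sketch (hypothetical minimal catalogs; real catalogs must carry
    # the full set of APDB schema columns):
    #
    #     objects = pandas.DataFrame({"diaObjectId": [1], "ra": [10.0], "decl": [-5.0]})
    #     apdb.store(visit_time, objects, sources=None, forced_sources=None)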

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64
            # can cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(
                    table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema
                )

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check
            # what is missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given a set of DiaObject
        IDs.

        Parameters
        ----------
        object_ids :
            Collection of DiaObject IDs
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: List[int], midPointTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given a
        set of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Database table to query.
        object_ids :
            Collection of DiaObject IDs
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: Optional[pandas.DataFrame] = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midPointTai"] > midPointTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, _ConnectionHackSA2(conn)))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources
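
    # The chunking above keeps SQL ``IN`` lists bounded; a sketch of the
    # pattern (chunk size 1000, as used above):
    #
    #     for chunk in chunk_iterable(sorted(object_ids), 1000):
    #         ...  # each chunk holds at most 1000 IDs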

    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
    ) -> None:
        dt = visit_time.toPython()

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: dafBase.DateTime,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection to use for database access.
        """
        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in the LAST table; mysql and
            # postgres have non-standard features for this.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with
                # visit time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    _ConnectionHackSA2(connection),
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", update)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: List[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit
            # time just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection to use for database access.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection to use for database access.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` [`tuple`]
            Sequence of ranges, each range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
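
    # A pixelization sketch (hypothetical region; ``Circle`` and ``Angle``
    # come from lsst.sphgeom): ``envelope`` returns a RangeSet of half-open
    # HTM ID ranges, which is why ``upper -= 1`` is applied above.
    #
    #     pixelator = HtmPixelization(20)
    #     center = UnitVector3d(LonLat.fromDegrees(10.0, -5.0))
    #     ranges = pixelator.envelope(Circle(center, Angle.fromDegrees(0.1)), 64)
    #     for low, upper in ranges.ranges():
    #         ...  # upper is exclusive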

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in the DataFrame with the same
        name (``pixelId``). The original DataFrame is not changed; a copy
        is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies the pixelId value from a matching DiaObject
        record. The DiaObject catalog needs to have its pixelId column
        filled by the ``_add_obj_htm_index`` method, and DiaSource records
        need to be associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in the DataFrame with the same
        name (``pixelId``). The original DataFrame is not changed; a copy
        is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources