Coverage for python/lsst/dax/apdb/apdbSql.py: 13% of 428 statements (coverage.py v7.3.2)

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining ApdbSql class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Callable, Iterable, Mapping, MutableMapping
from typing import Any, Dict, List, Optional, Tuple, cast

import lsst.daf.base as dafBase
import numpy as np
import pandas
import sqlalchemy
from felis.simple import Table
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, inspection, sql
from sqlalchemy.engine import Inspector
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
from .apdbSchema import ApdbTables
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
from .timer import Timer

_LOG = logging.getLogger(__name__)


if pandas.__version__.partition(".")[0] == "1":

    class _ConnectionHackSA2(sqlalchemy.engine.Connectable):
        """Terrible hack to work around Pandas 1's incomplete support for
        SQLAlchemy 2.

        We need to pass a Connection instance to pandas methods, but in
        SQLAlchemy 2 the Connection class lost the ``connect`` method that
        Pandas relies on.
        """

        def __init__(self, connection: sqlalchemy.engine.Connection):
            self._connection = connection

        def connect(self, **kwargs: Any) -> Any:
            return self

        @property
        def execute(self) -> Callable:
            return self._connection.execute

        @property
        def execution_options(self) -> Callable:
            return self._connection.execution_options

        @property
        def connection(self) -> Any:
            return self._connection.connection

        def __enter__(self) -> sqlalchemy.engine.Connection:
            return self._connection

        def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
            # Do not close the connection here.
            pass

    @inspection._inspects(_ConnectionHackSA2)
    def _connection_insp(conn: _ConnectionHackSA2) -> Inspector:
        return Inspector._construct(Inspector._init_connection, conn._connection)

else:
    # Pandas 2.0 supports SQLAlchemy 2 correctly.
    def _ConnectionHackSA2(  # type: ignore[no-redef]
        conn: sqlalchemy.engine.Connectable,
    ) -> sqlalchemy.engine.Connectable:
        return conn
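# Illustrative sketch (not executed here) of why the wrapper exists: pass a
# wrapped Connection to pandas so that both Pandas 1 and Pandas 2 accept it.
# The in-memory SQLite URL is a hypothetical example:
#
#     engine = sqlalchemy.create_engine("sqlite://")
#     with engine.begin() as conn:
#         df = pandas.read_sql_query(sql.text("SELECT 1 AS x"), _ConnectionHackSA2(conn))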


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
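# Minimal sketch of the coercion above (illustrative, not executed):
#
#     df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
#     df = _coerce_uint64(df)
#     # df["diaObjectId"].dtype is now int64, which SQL backends accept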


def _make_midpointMjdTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the source history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
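# Example (illustrative): with a 12-month window the starting point is the
# visit MJD minus 360 days, because a month is approximated as 30 days above:
#
#     visit_time = dafBase.DateTime.now()
#     start = _make_midpointMjdTai_start(visit_time, months=12)
#     # start == visit_time.get(system=dafBase.DateTime.MJD) - 360.0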


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to SQLAlchemy's default if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of the HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject; by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
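# Minimal configuration sketch (illustrative values; fields inherited from
# ApdbConfig, e.g. the schema file, may also need to be set):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.dia_object_index = "last_object_table"
#     config.validate()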


class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self._keys = list(result.keys())
        self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))

    def column_names(self) -> list[str]:
        return self._keys

    def rows(self) -> Iterable[tuple]:
        return self._rows
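# Sketch of consuming the wrapper above (illustrative; ``data`` would come
# from one of the history methods of ApdbSql below):
#
#     columns = data.column_names()
#     for row in data.rows():
#         record = dict(zip(columns, row))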


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config``
    mechanism using the `ApdbSqlConfig` configuration class. For examples of
    different configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):
        config.validate()
        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug(" schema_file: %s", self.config.schema_file)
        _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug(" schema prefix: %s", self.config.prefix)

        # The engine is reused between multiple processes; make sure that we
        # don't share connections by disabling the pool (using NullPool).
        kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id
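    # Sketch of constructing an instance and creating tables (illustrative;
    # assumes a validated ApdbSqlConfig as sketched earlier):
    #
    #     apdb = ApdbSql(config)
    #     apdb.makeSchema()
    #     print(apdb.tableRowCount())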

    def tableRowCount(self) -> Dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database
        tables. Depending on database technology this could be an expensive
        operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[Table]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects
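    # Sketch of a region query (illustrative; Circle and Angle come from
    # lsst.sphgeom, which is only partially imported in this module):
    #
    #     from lsst.sphgeom import Angle, Circle
    #     center = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))
    #     region = Circle(center, Angle.fromDegrees(0.1))
    #     objects = apdb.getDiaObjects(region)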

    def getDiaSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(
                ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
            )

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"]).order_by(table.columns["insert_time"])
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return [ApdbInsertId(row) for row in result.scalars()]

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)
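    # Sketch of the insert-id history workflow (illustrative; only meaningful
    # when use_insert_id is enabled in the schema):
    #
    #     ids = apdb.getInsertIds() or []
    #     chunk = ids[:10]
    #     data = apdb.getDiaSourcesHistory(chunk)
    #     apdb.deleteInsertIds(chunk)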

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Return catalog of records for given insert identifiers, common
        implementation for all DIA tables.
        """
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            with self._engine.begin() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return ApdbSqlTableData(result)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: dafBase.DateTime,
        objects: pandas.DataFrame,
        sources: Optional[pandas.DataFrame] = None,
        forced_sources: Optional[pandas.DataFrame] = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id()
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)
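    # Sketch of a typical store() call from AP pipeline code (illustrative;
    # the catalogs are pandas DataFrames matching the APDB schema):
    #
    #     visit_time = dafBase.DateTime.now()
    #     apdb.store(visit_time, objects, sources, forced_sources)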

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(
                    table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema
                )

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

515 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

516 # docstring is inherited from a base class 

517 

518 table = self._schema.get_table(ApdbTables.DiaSource) 

519 query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId")) 

520 

521 with self._engine.begin() as conn: 

522 # Need to make sure that every ID exists in the database, but 

523 # executemany may not support rowcount, so iterate and check what 

524 # is missing. 

525 missing_ids: List[int] = [] 

526 for key, value in idMap.items(): 

527 params = dict(srcId=key, diaObjectId=0, ssObjectId=value) 

528 result = conn.execute(query, params) 

529 if result.rowcount == 0: 

530 missing_ids.append(key) 

531 if missing_ids: 

532 missing = ",".join(str(item) for item in missing_ids) 

533 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 
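    # Sketch of reassigning a DiaSource to an SSObject (illustrative IDs):
    #
    #     apdb.reassignDiaSources({12345: 9876543210})
    #     # raises ValueError if DiaSource 12345 does not exist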

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: List[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Database table to query.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is
            returned when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: Optional[pandas.DataFrame] = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
    ) -> None:
        dt = visit_time.toPython()

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: dafBase.DateTime,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not kept.
        connection : `sqlalchemy.engine.Connection`
            Database connection to use for inserts.
        """
        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    _ConnectionHackSA2(connection),
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: List[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not kept.
        connection : `sqlalchemy.engine.Connection`
            Database connection to use for inserts.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not kept.
        connection : `sqlalchemy.engine.Connection`
            Database connection to use for inserts.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` of `tuple`
            Sequence of ranges, each range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()
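    # Sketch of what the pixelization produces (illustrative; Circle and
    # Angle come from lsst.sphgeom):
    #
    #     from lsst.sphgeom import Angle, Circle
    #     region = Circle(UnitVector3d(LonLat.fromDegrees(0.0, 0.0)), Angle.fromDegrees(0.05))
    #     ranges = self._htm_indices(region)
    #     # e.g. [(minHtmID, maxHtmID), ...] with at most htm_max_ranges entries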

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in the DataFrame with the same
        name (pixelId). The original DataFrame is not changed; a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df
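    # Illustrative sketch: a catalog with ra/dec columns gets a pixelId
    # column (column names follow the ra_dec_columns and htm_index_column
    # config defaults):
    #
    #     objs = pandas.DataFrame({"diaObjectId": [1], "ra": [45.0], "dec": [-30.0]})
    #     objs = self._add_obj_htm_index(objs)
    #     # objs now has a "pixelId" column with level-20 HTM indices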

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies the pixelId value from the matching DiaObject
        record. The DiaObject catalog needs to have a pixelId column filled
        by the ``_add_obj_htm_index`` method, and DiaSource records need to
        be associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in the DataFrame with the same
        name (pixelId). The original DataFrame is not changed; a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject, hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources