Coverage for python/lsst/dax/apdb/apdbSql.py: 13%

428 statements  

coverage.py v7.2.7, created at 2023-07-14 19:51 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining Apdb class and related methods. 

23""" 

24 

25from __future__ import annotations 

26 

27__all__ = ["ApdbSqlConfig", "ApdbSql"] 

28 

29import logging 

30from collections.abc import Callable, Iterable, Mapping, MutableMapping 

31from typing import Any, Dict, List, Optional, Tuple, cast 

32 

33import lsst.daf.base as dafBase 

34import numpy as np 

35import pandas 

36import sqlalchemy 

37from felis.simple import Table 

38from lsst.pex.config import ChoiceField, Field, ListField 

39from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d 

40from lsst.utils.iteration import chunk_iterable 

41from sqlalchemy import func, inspection, sql 

42from sqlalchemy.engine import Inspector 

43from sqlalchemy.pool import NullPool 

44 

45from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData 

46from .apdbSchema import ApdbTables 

47from .apdbSqlSchema import ApdbSqlSchema, ExtraTables 

48from .timer import Timer 

49 

50_LOG = logging.getLogger(__name__) 

51 

52 

53if pandas.__version__.partition(".")[0] == "1":    [coverage: line 53 never jumped to line 55 because the condition was never true]

54 

55 class _ConnectionHackSA2(sqlalchemy.engine.Connectable): 

56 """Terrible hack to workaround Pandas 1 incomplete support for 

57 sqlalchemy 2. 

58 

59 We need to pass a Connection instance to Pandas methods, but in SQLAlchemy 2 

60 the Connection class lost the ``connect`` method that Pandas still calls. 

61 """ 

62 

63 def __init__(self, connection: sqlalchemy.engine.Connection): 

64 self._connection = connection 

65 

66 def connect(self, **kwargs: Any) -> Any: 

67 return self 

68 

69 @property 

70 def execute(self) -> Callable: 

71 return self._connection.execute 

72 

73 @property 

74 def execution_options(self) -> Callable: 

75 return self._connection.execution_options 

76 

77 @property 

78 def connection(self) -> Any: 

79 return self._connection.connection 

80 

81 def __enter__(self) -> sqlalchemy.engine.Connection: 

82 return self._connection 

83 

84 def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: 

85 # Do not close connection here 

86 pass 

87 

88 @inspection._inspects(_ConnectionHackSA2) 

89 def _connection_insp(conn: _ConnectionHackSA2) -> Inspector: 

90 return Inspector._construct(Inspector._init_connection, conn._connection) 

91 

92else: 

93 # Pandas 2.0 supports SQLAlchemy 2 correctly. 

94 def _ConnectionHackSA2( # type: ignore[no-redef] 

95 conn: sqlalchemy.engine.Connectable, 

96 ) -> sqlalchemy.engine.Connectable: 

97 return conn 

98 

99 
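For illustration, a minimal usage sketch (not part of this module), assuming an in-memory SQLite database: under Pandas 1 with SQLAlchemy 2 the wrapper supplies the ``connect`` method that Pandas looks for, while under Pandas 2 it passes the connection through unchanged.

import pandas
import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")  # hypothetical in-memory database
with engine.begin() as conn:
    conn.execute(sqlalchemy.text("CREATE TABLE demo (x INTEGER)"))
    conn.execute(sqlalchemy.text("INSERT INTO demo VALUES (1), (2)"))
    # Wrap the connection so that both Pandas 1 and Pandas 2 accept it.
    df = pandas.read_sql_query("SELECT x FROM demo", _ConnectionHackSA2(conn))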

100def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame: 

101 """Change type of the uint64 columns to int64, return copy of data frame.""" 

102 names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64] 

103 return df.astype({name: np.int64 for name in names}) 

104 
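A small sketch of the coercion, with made-up frame contents:

import numpy as np
import pandas

df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64), "flux": [0.5, 1.5]})
coerced = _coerce_uint64(df)
assert coerced["id"].dtype == np.int64      # uint64 column is downcast
assert coerced["flux"].dtype == np.float64  # other columns are untouched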

105 

106def _make_midpointMjdTai_start(visit_time: dafBase.DateTime, months: int) -> float: 

107 """Calculate starting point for time-based source search. 

108 

109 Parameters 

110 ---------- 

111 visit_time : `lsst.daf.base.DateTime` 

112 Time of current visit. 

113 months : `int` 

114 Number of months of history to include in the search. 

115 

116 Returns 

117 ------- 

118 time : `float` 

119 A ``midpointMjdTai`` starting point, MJD time. 

120 """ 

121 # TODO: `system` must be consistent with the code in ap_association 

122 # (see DM-31996) 

123 return visit_time.get(system=dafBase.DateTime.MJD) - months * 30 

124 
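A hedged example of the 30-day-month approximation, using the current time:

import lsst.daf.base as dafBase

visit_time = dafBase.DateTime.now()
start = _make_midpointMjdTai_start(visit_time, months=12)
# Twelve "months" are treated as 360 days here.
assert start == visit_time.get(system=dafBase.DateTime.MJD) - 360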

125 

126class ApdbSqlConfig(ApdbConfig): 

127 """APDB configuration class for SQL implementation (ApdbSql).""" 

128 

129 db_url = Field[str](doc="SQLAlchemy database connection URI") 

130 isolation_level = ChoiceField[str]( 

131 doc=( 

132 "Transaction isolation level, if unset then backend-default value " 

133 "is used, except for SQLite backend where we use READ_UNCOMMITTED. " 

134 "Some backends may not support every allowed value." 

135 ), 

136 allowed={ 

137 "READ_COMMITTED": "Read committed", 

138 "READ_UNCOMMITTED": "Read uncommitted", 

139 "REPEATABLE_READ": "Repeatable read", 

140 "SERIALIZABLE": "Serializable", 

141 }, 

142 default=None, 

143 optional=True, 

144 ) 

145 connection_pool = Field[bool]( 

146 doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.", 

147 default=True, 

148 ) 

149 connection_timeout = Field[float]( 

150 doc=( 

151 "Maximum time to wait time for database lock to be released before exiting. " 

152 "Defaults to sqlalchemy defaults if not set." 

153 ), 

154 default=None, 

155 optional=True, 

156 ) 

157 sql_echo = Field[bool](doc="If True then pass the SQLAlchemy echo option to the engine.", default=False) 

158 dia_object_index = ChoiceField[str]( 

159 doc="Indexing mode for DiaObject table", 

160 allowed={ 

161 "baseline": "Index defined in baseline schema", 

162 "pix_id_iov": "(pixelId, objectId, iovStart) PK", 

163 "last_object_table": "Separate DiaObjectLast table", 

164 }, 

165 default="baseline", 

166 ) 

167 htm_level = Field[int](doc="HTM indexing level", default=20) 

168 htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64) 

169 htm_index_column = Field[str]( 

170 default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables" 

171 ) 

172 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table") 

173 dia_object_columns = ListField[str]( 

174 doc="List of columns to read from DiaObject, by default read all columns", default=[] 

175 ) 

176 prefix = Field[str](doc="Prefix to add to table names and index names", default="") 

177 namespace = Field[str]( 

178 doc=( 

179 "Namespace or schema name for all tables in APDB database. " 

180 "Presently only works for PostgreSQL backend. " 

181 "If schema with this name does not exist it will be created when " 

182 "APDB tables are created." 

183 ), 

184 default=None, 

185 optional=True, 

186 ) 

187 timer = Field[bool](doc="If True then print/log timing information", default=False) 

188 

189 def validate(self) -> None: 

190 super().validate() 

191 if len(self.ra_dec_columns) != 2: 

192 raise ValueError("ra_dec_columns must have exactly two column names") 

193 
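A minimal configuration sketch, assuming the packaged schema defaults are usable; the database location is hypothetical:

config = ApdbSqlConfig()
config.db_url = "sqlite:///apdb.db"  # hypothetical database file
config.dia_object_index = "last_object_table"
config.validate()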

194 

195class ApdbSqlTableData(ApdbTableData): 

196 """Implementation of ApdbTableData that wraps sqlalchemy Result.""" 

197 

198 def __init__(self, result: sqlalchemy.engine.Result): 

199 self._keys = list(result.keys()) 

200 self._rows: list[tuple] = cast(list[tuple], list(result.fetchall())) 

201 

202 def column_names(self) -> list[str]: 

203 return self._keys 

204 

205 def rows(self) -> Iterable[tuple]: 

206 return self._rows 

207 
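Because the constructor materializes the whole result in memory, the data can be iterated repeatedly. A consumer sketch:

def dump_table(data: ApdbTableData) -> None:
    # Print a header row followed by one line per data row.
    print(*data.column_names(), sep="\t")
    for row in data.rows():
        print(*row, sep="\t")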

208 

209class ApdbSql(Apdb): 

210 """Implementation of APDB interface based on SQL database. 

211 

212 The implementation is configured via the standard ``pex_config`` mechanism 

213 using the `ApdbSqlConfig` configuration class. For examples of different 

214 configurations check the ``config/`` folder. 

215 

216 Parameters 

217 ---------- 

218 config : `ApdbSqlConfig` 

219 Configuration object. 

220 """ 

221 

222 ConfigClass = ApdbSqlConfig 

223 

224 def __init__(self, config: ApdbSqlConfig): 

225 config.validate() 

226 self.config = config 

227 

228 _LOG.debug("APDB Configuration:") 

229 _LOG.debug(" dia_object_index: %s", self.config.dia_object_index) 

230 _LOG.debug(" read_sources_months: %s", self.config.read_sources_months) 

231 _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months) 

232 _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns) 

233 _LOG.debug(" schema_file: %s", self.config.schema_file) 

234 _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file) 

235 _LOG.debug(" schema prefix: %s", self.config.prefix) 

236 

237 # The engine may be reused between multiple processes; make sure that we 

238 # don't share connections by disabling the pool (using the NullPool class). 

239 kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo) 

240 conn_args: Dict[str, Any] = dict() 

241 if not self.config.connection_pool: 

242 kw.update(poolclass=NullPool) 

243 if self.config.isolation_level is not None: 

244 kw.update(isolation_level=self.config.isolation_level) 

245 elif self.config.db_url.startswith("sqlite"): # type: ignore 

246 # Use READ_UNCOMMITTED as default value for sqlite. 

247 kw.update(isolation_level="READ_UNCOMMITTED") 

248 if self.config.connection_timeout is not None: 

249 if self.config.db_url.startswith("sqlite"): 

250 conn_args.update(timeout=self.config.connection_timeout) 

251 elif self.config.db_url.startswith(("postgresql", "mysql")): 

252 conn_args.update(connect_timeout=self.config.connection_timeout) 

253 kw.update(connect_args=conn_args) 

254 self._engine = sqlalchemy.create_engine(self.config.db_url, **kw) 

255 

256 self._schema = ApdbSqlSchema( 

257 engine=self._engine, 

258 dia_object_index=self.config.dia_object_index, 

259 schema_file=self.config.schema_file, 

260 schema_name=self.config.schema_name, 

261 prefix=self.config.prefix, 

262 namespace=self.config.namespace, 

263 htm_index_column=self.config.htm_index_column, 

264 use_insert_id=config.use_insert_id, 

265 ) 

266 

267 self.pixelator = HtmPixelization(self.config.htm_level) 

268 self.use_insert_id = self._schema.has_insert_id 

269 
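A construction sketch, assuming the default schema files shipped with the package resolve correctly; the in-memory URL is hypothetical:

config = ApdbSqlConfig()
config.db_url = "sqlite://"  # hypothetical in-memory database
apdb = ApdbSql(config)
apdb.makeSchema(drop=False)  # create the tables on first use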

270 def tableRowCount(self) -> Dict[str, int]: 

271 """Returns dictionary with the table names and row counts. 

272 

273 Used by ``ap_proto`` to keep track of the size of the database tables. 

274 Depending on the database technology this can be an expensive operation. 

275 

276 Returns 

277 ------- 

278 row_counts : `dict` 

279 Dict where key is a table name and value is a row count. 

280 """ 

281 res = {} 

282 tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource] 

283 if self.config.dia_object_index == "last_object_table": 

284 tables.append(ApdbTables.DiaObjectLast) 

285 with self._engine.begin() as conn: 

286 for table in tables: 

287 sa_table = self._schema.get_table(table) 

288 stmt = sql.select(func.count()).select_from(sa_table) 

289 count: int = conn.execute(stmt).scalar_one() 

290 res[table.name] = count 

291 

292 return res 

293 
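A usage sketch, reusing the ``apdb`` instance from the construction example above:

for name, count in apdb.tableRowCount().items():
    print(f"{name}: {count} rows")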

294 def tableDef(self, table: ApdbTables) -> Optional[Table]: 

295 # docstring is inherited from a base class 

296 return self._schema.tableSchemas.get(table) 

297 

298 def makeSchema(self, drop: bool = False) -> None: 

299 # docstring is inherited from a base class 

300 self._schema.makeSchema(drop=drop) 

301 

302 def getDiaObjects(self, region: Region) -> pandas.DataFrame: 

303 # docstring is inherited from a base class 

304 

305 # decide what columns we need 

306 if self.config.dia_object_index == "last_object_table": 

307 table_enum = ApdbTables.DiaObjectLast 

308 else: 

309 table_enum = ApdbTables.DiaObject 

310 table = self._schema.get_table(table_enum) 

311 if not self.config.dia_object_columns: 

312 columns = self._schema.get_apdb_columns(table_enum) 

313 else: 

314 columns = [table.c[col] for col in self.config.dia_object_columns] 

315 query = sql.select(*columns) 

316 

317 # build selection 

318 query = query.where(self._filterRegion(table, region)) 

319 

320 # select latest version of objects 

321 if self.config.dia_object_index != "last_object_table": 

322 query = query.where(table.c.validityEnd == None) # noqa: E711 

323 

324 # _LOG.debug("query: %s", query) 

325 

326 # execute select 

327 with Timer("DiaObject select", self.config.timer): 

328 with self._engine.begin() as conn: 

329 objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn)) 

330 _LOG.debug("found %s DiaObjects", len(objects)) 

331 return objects 

332 
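A region-based query sketch; the pointing and radius are made up:

from lsst.sphgeom import Angle, Circle, LonLat, UnitVector3d

center = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))  # hypothetical pointing
region = Circle(center, Angle.fromDegrees(1.0))
objects = apdb.getDiaObjects(region)  # ``apdb`` as constructed earlier
print(len(objects), "DiaObjects found")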

333 def getDiaSources( 

334 self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime 

335 ) -> Optional[pandas.DataFrame]: 

336 # docstring is inherited from a base class 

337 if self.config.read_sources_months == 0: 

338 _LOG.debug("Skip DiaSources fetching") 

339 return None 

340 

341 if object_ids is None: 

342 # region-based select 

343 return self._getDiaSourcesInRegion(region, visit_time) 

344 else: 

345 return self._getDiaSourcesByIDs(list(object_ids), visit_time) 

346 

347 def getDiaForcedSources( 

348 self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime 

349 ) -> Optional[pandas.DataFrame]: 

350 """Return catalog of DiaForcedSource instances from a given region. 

351 

352 Parameters 

353 ---------- 

354 region : `lsst.sphgeom.Region` 

355 Region to search for DiaForcedSources. 

356 object_ids : iterable [ `int` ], optional 

357 List of DiaObject IDs to further constrain the set of returned 

358 sources. If the list is empty then an empty catalog is returned 

359 with a correct schema. 

360 visit_time : `lsst.daf.base.DateTime` 

361 Time of the current visit. 

362 

363 Returns 

364 ------- 

365 catalog : `pandas.DataFrame`, or `None` 

366 Catalog containing DiaForcedSource records. `None` is returned if 

367 the ``read_forced_sources_months`` configuration parameter is set to 0. 

368 

369 Raises 

370 ------ 

371 NotImplementedError 

372 Raised if ``object_ids`` is `None`. 

373 

374 Notes 

375 ----- 

376 Even though the base class allows `None` to be passed for ``object_ids``, 

377 this class requires ``object_ids`` to not be `None`. 

378 `NotImplementedError` is raised if `None` is passed. 

379 

380 This method returns a DiaForcedSource catalog for a region with additional 

381 filtering based on DiaObject IDs. Only a subset of the DiaForcedSource 

382 history is returned, limited by the ``read_forced_sources_months`` config 

383 parameter, w.r.t. ``visit_time``. If ``object_ids`` is empty then an empty 

384 catalog is always returned with a correct schema (columns/types). 

385 """ 

386 

387 if self.config.read_forced_sources_months == 0: 

388 _LOG.debug("Skip DiaForcedSources fetching") 

389 return None 

390 

391 if object_ids is None: 

392 # This implementation does not support region-based selection. 

393 raise NotImplementedError("Region-based selection is not supported") 

394 

395 # TODO: DateTime.MJD must be consistent with code in ap_association, 

396 # alternatively we can fill midpointMjdTai ourselves in store() 

397 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months) 

398 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start) 

399 

400 with Timer("DiaForcedSource select", self.config.timer): 

401 sources = self._getSourcesByIDs( 

402 ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start 

403 ) 

404 

405 _LOG.debug("found %s DiaForcedSources", len(sources)) 

406 return sources 

407 

408 def getInsertIds(self) -> list[ApdbInsertId] | None: 

409 # docstring is inherited from a base class 

410 if not self._schema.has_insert_id: 

411 return None 

412 

413 table = self._schema.get_table(ExtraTables.DiaInsertId) 

414 assert table is not None, "has_insert_id=True means it must be defined" 

415 query = sql.select(table.columns["insert_id"]).order_by(table.columns["insert_time"]) 

416 with Timer("DiaObject insert id select", self.config.timer): 

417 with self._engine.connect() as conn: 

418 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query) 

419 return [ApdbInsertId(row) for row in result.scalars()] 

420 

421 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None: 

422 # docstring is inherited from a base class 

423 if not self._schema.has_insert_id: 

424 raise ValueError("APDB is not configured for history storage") 

425 

426 table = self._schema.get_table(ExtraTables.DiaInsertId) 

427 

428 insert_ids = [id.id for id in ids] 

429 where_clause = table.columns["insert_id"].in_(insert_ids) 

430 stmt = table.delete().where(where_clause) 

431 with self._engine.begin() as conn: 

432 conn.execute(stmt) 

433 

434 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

435 # docstring is inherited from a base class 

436 return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId) 

437 

438 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

439 # docstring is inherited from a base class 

440 return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId) 

441 

442 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

443 # docstring is inherited from a base class 

444 return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId) 

445 

446 def _get_history( 

447 self, 

448 ids: Iterable[ApdbInsertId], 

449 table_enum: ApdbTables, 

450 history_table_enum: ExtraTables, 

451 ) -> ApdbTableData: 

452 """Common implementation of the history methods.""" 

453 if not self._schema.has_insert_id: 

454 raise ValueError("APDB is not configured for history retrieval") 

455 

456 table = self._schema.get_table(table_enum) 

457 history_table = self._schema.get_table(history_table_enum) 

458 

459 join = table.join(history_table) 

460 insert_ids = [id.id for id in ids] 

461 history_id_column = history_table.columns["insert_id"] 

462 apdb_columns = self._schema.get_apdb_columns(table_enum) 

463 where_clause = history_id_column.in_(insert_ids) 

464 query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause) 

465 

466 # execute select 

467 with Timer(f"{table.name} history select", self.config.timer): 

468 with self._engine.begin() as conn: 

469 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query) 

470 return ApdbSqlTableData(result) 

471 
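A lifecycle sketch for the insert-id machinery, assuming the APDB was created with ``use_insert_id`` enabled and reusing the ``apdb`` instance from the earlier sketch:

insert_ids = apdb.getInsertIds()
if insert_ids is not None:
    # Replay the DiaSource history for those inserts, then prune them.
    history = apdb.getDiaSourcesHistory(insert_ids)
    print(history.column_names())
    apdb.deleteInsertIds(insert_ids)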

472 def getSSObjects(self) -> pandas.DataFrame: 

473 # docstring is inherited from a base class 

474 

475 columns = self._schema.get_apdb_columns(ApdbTables.SSObject) 

476 query = sql.select(*columns) 

477 

478 # execute select 

479 with Timer("SSObject select", self.config.timer): 

480 with self._engine.begin() as conn: 

481 objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn)) 

482 _LOG.debug("found %s SSObjects", len(objects)) 

483 return objects 

484 

485 def store( 

486 self, 

487 visit_time: dafBase.DateTime, 

488 objects: pandas.DataFrame, 

489 sources: Optional[pandas.DataFrame] = None, 

490 forced_sources: Optional[pandas.DataFrame] = None, 

491 ) -> None: 

492 # docstring is inherited from a base class 

493 

494 # We want to run all inserts in one transaction. 

495 with self._engine.begin() as connection: 

496 insert_id: ApdbInsertId | None = None 

497 if self._schema.has_insert_id: 

498 insert_id = ApdbInsertId.new_insert_id() 

499 self._storeInsertId(insert_id, visit_time, connection) 

500 

501 # fill pixelId column for DiaObjects 

502 objects = self._add_obj_htm_index(objects) 

503 self._storeDiaObjects(objects, visit_time, insert_id, connection) 

504 

505 if sources is not None: 

506 # copy pixelId column from DiaObjects to DiaSources 

507 sources = self._add_src_htm_index(sources, objects) 

508 self._storeDiaSources(sources, insert_id, connection) 

509 

510 if forced_sources is not None: 

511 self._storeDiaForcedSources(forced_sources, insert_id, connection) 

512 
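A storing sketch; ``dia_objects``, ``dia_sources`` and ``dia_forced`` stand for hypothetical catalogs matching the APDB schema:

visit_time = dafBase.DateTime.now()
apdb.store(
    visit_time,
    objects=dia_objects,        # hypothetical DiaObject DataFrame
    sources=dia_sources,        # hypothetical DiaSource DataFrame
    forced_sources=dia_forced,  # hypothetical DiaForcedSource DataFrame
)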

513 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

514 # docstring is inherited from a base class 

515 

516 idColumn = "ssObjectId" 

517 table = self._schema.get_table(ApdbTables.SSObject) 

518 

519 # everything to be done in single transaction 

520 with self._engine.begin() as conn: 

521 # Find record IDs that already exist. Some types like np.int64 can 

522 # cause issues with sqlalchemy, convert them to int. 

523 ids = sorted(int(oid) for oid in objects[idColumn]) 

524 

525 query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids)) 

526 result = conn.execute(query) 

527 knownIds = set(row.ssObjectId for row in result) 

528 

529 filter = objects[idColumn].isin(knownIds) 

530 toUpdate = cast(pandas.DataFrame, objects[filter]) 

531 toInsert = cast(pandas.DataFrame, objects[~filter]) 

532 

533 # insert new records 

534 if len(toInsert) > 0: 

535 toInsert.to_sql( 

536 table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema 

537 ) 

538 

539 # update existing records 

540 if len(toUpdate) > 0: 

541 whereKey = f"{idColumn}_param" 

542 update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey)) 

543 toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns") 

544 values = toUpdate.to_dict("records") 

545 result = conn.execute(update, values) 

546 
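An upsert sketch; the column set is illustrative only, real catalogs must match the SSObject schema:

import pandas

ssobjects = pandas.DataFrame(
    {
        "ssObjectId": [101, 102],  # hypothetical IDs
        "arc": [10.5, 3.2],        # hypothetical column values
    }
)
apdb.storeSSObjects(ssobjects)  # inserts new IDs, updates existing ones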

547 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

548 # docstring is inherited from a base class 

549 

550 table = self._schema.get_table(ApdbTables.DiaSource) 

551 query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId")) 

552 

553 with self._engine.begin() as conn: 

554 # Need to make sure that every ID exists in the database, but 

555 # executemany may not support rowcount, so iterate and check what is 

556 # missing. 

557 missing_ids: List[int] = [] 

558 for key, value in idMap.items(): 

559 params = dict(srcId=key, diaObjectId=0, ssObjectId=value) 

560 result = conn.execute(query, params) 

561 if result.rowcount == 0: 

562 missing_ids.append(key) 

563 if missing_ids: 

564 missing = ",".join(str(item) for item in missing_ids) 

565 raise ValueError(f"The following DiaSource IDs do not exist in the database: {missing}") 

566 
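A reassignment sketch with made-up IDs; each DiaSource is detached from its DiaObject (``diaObjectId`` set to 0) and attached to the given SSObject:

apdb.reassignDiaSources({111111: 424242, 111112: 424242})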

567 def dailyJob(self) -> None: 

568 # docstring is inherited from a base class 

569 pass 

570 

571 def countUnassociatedObjects(self) -> int: 

572 # docstring is inherited from a base class 

573 

574 # Retrieve the DiaObject table. 

575 table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject) 

576 

577 # Construct the sql statement. 

578 stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1) 

579 stmt = stmt.where(table.c.validityEnd == None) # noqa: E711 

580 

581 # Return the count. 

582 with self._engine.begin() as conn: 

583 count = conn.execute(stmt).scalar_one() 

584 

585 return count 

586 

587 def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame: 

588 """Returns catalog of DiaSource instances from given region. 

589 

590 Parameters 

591 ---------- 

592 region : `lsst.sphgeom.Region` 

593 Region to search for DIASources. 

594 visit_time : `lsst.daf.base.DateTime` 

595 Time of the current visit. 

596 

597 Returns 

598 ------- 

599 catalog : `pandas.DataFrame` 

600 Catalog containing DiaSource records. 

601 """ 

602 # TODO: DateTime.MJD must be consistent with code in ap_association, 

603 # alternatively we can fill midpointMjdTai ourselves in store() 

604 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months) 

605 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start) 

606 

607 table = self._schema.get_table(ApdbTables.DiaSource) 

608 columns = self._schema.get_apdb_columns(ApdbTables.DiaSource) 

609 query = sql.select(*columns) 

610 

611 # build selection 

612 time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start 

613 where = sql.expression.and_(self._filterRegion(table, region), time_filter) 

614 query = query.where(where) 

615 

616 # execute select 

617 with Timer("DiaSource select", self.config.timer): 

618 with self._engine.begin() as conn: 

619 sources = pandas.read_sql_query(query, _ConnectionHackSA2(conn)) 

620 _LOG.debug("found %s DiaSources", len(sources)) 

621 return sources 

622 

623 def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime) -> pandas.DataFrame: 

624 """Returns catalog of DiaSource instances given set of DiaObject IDs. 

625 

626 Parameters 

627 ---------- 

628 object_ids : `list` [`int`] 

629 Collection of DiaObject IDs. 

630 visit_time : `lsst.daf.base.DateTime` 

631 Time of the current visit. 

632 

633 Returns 

634 ------- 

635 catalog : `pandas.DataFrame` 

636 Catalog containing DiaSource records. 

637 """ 

638 # TODO: DateTime.MJD must be consistent with code in ap_association, 

639 # alternatively we can fill midpointMjdTai ourselves in store() 

640 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months) 

641 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start) 

642 

643 with Timer("DiaSource select", self.config.timer): 

644 sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start) 

645 

646 _LOG.debug("found %s DiaSources", len(sources)) 

647 return sources 

648 

649 def _getSourcesByIDs( 

650 self, table_enum: ApdbTables, object_ids: List[int], midpointMjdTai_start: float 

651 ) -> pandas.DataFrame: 

652 """Returns catalog of DiaSource or DiaForcedSource instances given set 

653 of DiaObject IDs. 

654 

655 Parameters 

656 ---------- 

657 table_enum : `ApdbTables` 

658 Enum member identifying the source table to query. 

659 object_ids : `list` [`int`] 

660 Collection of DiaObject IDs. 

661 midpointMjdTai_start : `float` 

662 Earliest midpointMjdTai to retrieve. 

663 

664 Returns 

665 ------- 

666 catalog : `pandas.DataFrame` 

667 Catalog contaning DiaSource records. `None` is returned if 

668 ``read_sources_months`` configuration parameter is set to 0 or 

669 when ``object_ids`` is empty. 

670 """ 

671 table = self._schema.get_table(table_enum) 

672 columns = self._schema.get_apdb_columns(table_enum) 

673 

674 sources: Optional[pandas.DataFrame] = None 

675 if len(object_ids) <= 0: 

676 _LOG.debug("ID list is empty, just fetch empty result") 

677 query = sql.select(*columns).where(sql.literal(False)) 

678 with self._engine.begin() as conn: 

679 sources = pandas.read_sql_query(query, _ConnectionHackSA2(conn)) 

680 else: 

681 data_frames: list[pandas.DataFrame] = [] 

682 for ids in chunk_iterable(sorted(object_ids), 1000): 

683 query = sql.select(*columns) 

684 

685 # Some types like np.int64 can cause issues with 

686 # sqlalchemy, convert them to int. 

687 int_ids = [int(oid) for oid in ids] 

688 

689 # select by object id 

690 query = query.where( 

691 sql.expression.and_( 

692 table.columns["diaObjectId"].in_(int_ids), 

693 table.columns["midpointMjdTai"] > midpointMjdTai_start, 

694 ) 

695 ) 

696 

697 # execute select 

698 with self._engine.begin() as conn: 

699 data_frames.append(pandas.read_sql_query(query, _ConnectionHackSA2(conn))) 

700 

701 if len(data_frames) == 1: 

702 sources = data_frames[0] 

703 else: 

704 sources = pandas.concat(data_frames) 

705 assert sources is not None, "Catalog cannot be None" 

706 return sources 

707 

708 def _storeInsertId( 

709 self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection 

710 ) -> None: 

711 dt = visit_time.toPython() 

712 

713 table = self._schema.get_table(ExtraTables.DiaInsertId) 

714 

715 stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt) 

716 connection.execute(stmt) 

717 

718 def _storeDiaObjects( 

719 self, 

720 objs: pandas.DataFrame, 

721 visit_time: dafBase.DateTime, 

722 insert_id: ApdbInsertId | None, 

723 connection: sqlalchemy.engine.Connection, 

724 ) -> None: 

725 """Store catalog of DiaObjects from current visit. 

726 

727 Parameters 

728 ---------- 

729 objs : `pandas.DataFrame` 

730 Catalog with DiaObject records. 

731 visit_time : `lsst.daf.base.DateTime` 

732 Time of the visit. 

733 insert_id : `ApdbInsertId` or `None` 

734 Insert identifier, or `None` if insert history is not being kept. 

735 """ 

736 

737 # Some types like np.int64 can cause issues with sqlalchemy, convert 

738 # them to int. 

739 ids = sorted(int(oid) for oid in objs["diaObjectId"]) 

740 _LOG.debug("first object ID: %d", ids[0]) 

741 

742 # TODO: Need to verify that we are using correct scale here for 

743 # DATETIME representation (see DM-31996). 

744 dt = visit_time.toPython() 

745 

746 # everything to be done in single transaction 

747 if self.config.dia_object_index == "last_object_table": 

748 # Insert and replace all records in the LAST table; MySQL and Postgres 

749 # have non-standard upsert features, so a portable delete+insert is used. 

750 table = self._schema.get_table(ApdbTables.DiaObjectLast) 

751 

752 # Drop the previous objects (pandas cannot upsert). 

753 query = table.delete().where(table.columns["diaObjectId"].in_(ids)) 

754 

755 with Timer(table.name + " delete", self.config.timer): 

756 res = connection.execute(query) 

757 _LOG.debug("deleted %s objects", res.rowcount) 

758 

759 # DiaObjectLast is a subset of DiaObject, strip missing columns 

760 last_column_names = [column.name for column in table.columns] 

761 last_objs = objs[last_column_names] 

762 last_objs = _coerce_uint64(last_objs) 

763 

764 if "lastNonForcedSource" in last_objs.columns: 

765 # lastNonForcedSource is defined NOT NULL, fill it with visit time 

766 # just in case. 

767 last_objs["lastNonForcedSource"].fillna(dt, inplace=True) 

768 else: 

769 extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource") 

770 last_objs.set_index(extra_column.index, inplace=True) 

771 last_objs = pandas.concat([last_objs, extra_column], axis="columns") 

772 

773 with Timer("DiaObjectLast insert", self.config.timer): 

774 last_objs.to_sql( 

775 table.name, 

776 _ConnectionHackSA2(connection), 

777 if_exists="append", 

778 index=False, 

779 schema=table.schema, 

780 ) 

781 else: 

782 # truncate existing validity intervals 

783 table = self._schema.get_table(ApdbTables.DiaObject) 

784 

785 update = ( 

786 table.update() 

787 .values(validityEnd=dt) 

788 .where( 

789 sql.expression.and_( 

790 table.columns["diaObjectId"].in_(ids), 

791 table.columns["validityEnd"].is_(None), 

792 ) 

793 ) 

794 ) 

795 

796 # _LOG.debug("query: %s", query) 

797 

798 with Timer(table.name + " truncate", self.config.timer): 

799 res = connection.execute(update) 

800 _LOG.debug("truncated %s intervals", res.rowcount) 

801 

802 objs = _coerce_uint64(objs) 

803 

804 # Fill additional columns 

805 extra_columns: List[pandas.Series] = [] 

806 if "validityStart" in objs.columns: 

807 objs["validityStart"] = dt 

808 else: 

809 extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart")) 

810 if "validityEnd" in objs.columns: 

811 objs["validityEnd"] = None 

812 else: 

813 extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd")) 

814 if "lastNonForcedSource" in objs.columns: 

815 # lastNonForcedSource is defined NOT NULL, fill it with visit time 

816 # just in case. 

817 objs["lastNonForcedSource"].fillna(dt, inplace=True) 

818 else: 

819 extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource")) 

820 if extra_columns: 

821 objs.set_index(extra_columns[0].index, inplace=True) 

822 objs = pandas.concat([objs] + extra_columns, axis="columns") 

823 

824 # Insert history data 

825 table = self._schema.get_table(ApdbTables.DiaObject) 

826 history_data: list[dict] = [] 

827 history_stmt: Any = None 

828 if insert_id is not None: 

829 pk_names = [column.name for column in table.primary_key] 

830 history_data = objs[pk_names].to_dict("records") 

831 for row in history_data: 

832 row["insert_id"] = insert_id.id 

833 history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId) 

834 history_stmt = history_table.insert() 

835 

836 # insert new versions 

837 with Timer("DiaObject insert", self.config.timer): 

838 objs.to_sql( 

839 table.name, 

840 _ConnectionHackSA2(connection), 

841 if_exists="append", 

842 index=False, 

843 schema=table.schema, 

844 ) 

845 if history_stmt is not None: 

846 connection.execute(history_stmt, history_data) 

847 

848 def _storeDiaSources( 

849 self, 

850 sources: pandas.DataFrame, 

851 insert_id: ApdbInsertId | None, 

852 connection: sqlalchemy.engine.Connection, 

853 ) -> None: 

854 """Store catalog of DiaSources from current visit. 

855 

856 Parameters 

857 ---------- 

858 sources : `pandas.DataFrame` 

859 Catalog containing DiaSource records. 

860 """ 

861 table = self._schema.get_table(ApdbTables.DiaSource) 

862 

863 # Insert history data 

864 history: list[dict] = [] 

865 history_stmt: Any = None 

866 if insert_id is not None: 

867 pk_names = [column.name for column in table.primary_key] 

868 history = sources[pk_names].to_dict("records") 

869 for row in history: 

870 row["insert_id"] = insert_id.id 

871 history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId) 

872 history_stmt = history_table.insert() 

873 

874 # everything to be done in single transaction 

875 with Timer("DiaSource insert", self.config.timer): 

876 sources = _coerce_uint64(sources) 

877 sources.to_sql( 

878 table.name, 

879 _ConnectionHackSA2(connection), 

880 if_exists="append", 

881 index=False, 

882 schema=table.schema, 

883 ) 

884 if history_stmt is not None: 

885 connection.execute(history_stmt, history) 

886 

887 def _storeDiaForcedSources( 

888 self, 

889 sources: pandas.DataFrame, 

890 insert_id: ApdbInsertId | None, 

891 connection: sqlalchemy.engine.Connection, 

892 ) -> None: 

893 """Store a set of DiaForcedSources from current visit. 

894 

895 Parameters 

896 ---------- 

897 sources : `pandas.DataFrame` 

898 Catalog containing DiaForcedSource records. 

899 """ 

900 table = self._schema.get_table(ApdbTables.DiaForcedSource) 

901 

902 # Insert history data 

903 history: list[dict] = [] 

904 history_stmt: Any = None 

905 if insert_id is not None: 

906 pk_names = [column.name for column in table.primary_key] 

907 history = sources[pk_names].to_dict("records") 

908 for row in history: 

909 row["insert_id"] = insert_id.id 

910 history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId) 

911 history_stmt = history_table.insert() 

912 

913 # everything to be done in single transaction 

914 with Timer("DiaForcedSource insert", self.config.timer): 

915 sources = _coerce_uint64(sources) 

916 sources.to_sql( 

917 table.name, 

918 _ConnectionHackSA2(connection), 

919 if_exists="append", 

920 index=False, 

921 schema=table.schema, 

922 ) 

923 if history_stmt is not None: 

924 connection.execute(history_stmt, history) 

925 

926 def _htm_indices(self, region: Region) -> List[Tuple[int, int]]: 

927 """Generate a set of HTM indices covering specified region. 

928 

929 Parameters 

930 ---------- 

931 region : `lsst.sphgeom.Region` 

932 Region that needs to be indexed. 

933 

934 Returns 

935 ------- 

936 Sequence of ranges; each range is a tuple ``(minHtmID, maxHtmID)``. 

937 """ 

938 _LOG.debug("region: %s", region) 

939 indices = self.pixelator.envelope(region, self.config.htm_max_ranges) 

940 

941 return indices.ranges() 

942 

943 def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement: 

944 """Make SQLAlchemy expression for selecting records in a region.""" 

945 htm_index_column = table.columns[self.config.htm_index_column] 

946 exprlist = [] 

947 pixel_ranges = self._htm_indices(region) 

948 for low, upper in pixel_ranges: 

949 upper -= 1 

950 if low == upper: 

951 exprlist.append(htm_index_column == low) 

952 else: 

953 exprlist.append(sql.expression.between(htm_index_column, low, upper)) 

954 

955 return sql.expression.or_(*exprlist) 

956 
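The envelope ranges returned by ``_htm_indices`` are half-open ``[low, upper)``, which is why ``upper`` is decremented before building an inclusive BETWEEN. A sketch of the resulting clause (pixel values are made up):

table = apdb._schema.get_table(ApdbTables.DiaObject)
clause = apdb._filterRegion(table, region)  # ``region`` as in the earlier sketch
# The clause ORs together one term per HTM range, e.g.
#   pixelId = 12288 OR pixelId BETWEEN 13312 AND 13567
print(clause)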

957 def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame: 

958 """Calculate HTM index for each record and add it to a DataFrame. 

959 

960 Notes 

961 ----- 

962 This overrides any existing column in the DataFrame with the same name 

963 (``pixelId``). The original DataFrame is not changed; a copy of the 

964 DataFrame is returned. 

965 """ 

966 # calculate HTM index for every DiaObject 

967 htm_index = np.zeros(df.shape[0], dtype=np.int64) 

968 ra_col, dec_col = self.config.ra_dec_columns 

969 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

970 uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec)) 

971 idx = self.pixelator.index(uv3d) 

972 htm_index[i] = idx 

973 df = df.copy() 

974 df[self.config.htm_index_column] = htm_index 

975 return df 

976 
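An indexing sketch with a made-up catalog; the ra/dec column names come from ``config.ra_dec_columns`` and the output column from ``config.htm_index_column``:

import pandas

df = pandas.DataFrame({"diaObjectId": [1], "ra": [45.0], "dec": [-30.0]})
indexed = apdb._add_obj_htm_index(df)
print(indexed[apdb.config.htm_index_column])  # the original ``df`` is unchanged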

977 def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

978 """Add pixelId column to DiaSource catalog. 

979 

980 Notes 

981 ----- 

982 This method copies the pixelId value from the matching DiaObject record. 

983 The DiaObject catalog needs to have a pixelId column filled by the 

984 ``_add_obj_htm_index`` method, and DiaSource records need to be 

985 associated to DiaObjects via the ``diaObjectId`` column. 

986 

987 This overrides any existing column in the DataFrame with the same name 

988 (``pixelId``). The original DataFrame is not changed; a copy of the 

989 DataFrame is returned. 

990 """ 

991 pixel_id_map: Dict[int, int] = { 

992 diaObjectId: pixelId 

993 for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column]) 

994 } 

995 # DiaSources associated with SolarSystemObjects do not have an 

996 # associated DiaObject, hence we skip them and set their htmIndex 

997 # value to 0. 

998 pixel_id_map[0] = 0 

999 htm_index = np.zeros(sources.shape[0], dtype=np.int64) 

1000 for i, diaObjId in enumerate(sources["diaObjectId"]): 

1001 htm_index[i] = pixel_id_map[diaObjId] 

1002 sources = sources.copy() 

1003 sources[self.config.htm_index_column] = htm_index 

1004 return sources