Coverage for python/lsst/dax/apdb/apdbSql.py: 15%

425 statements  

coverage.py v6.5.0, created at 2023-03-28 09:32 +0000

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods."""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Callable, Iterable, Mapping, MutableMapping
from typing import Any, Dict, List, Optional, Tuple, cast

import lsst.daf.base as dafBase
import numpy as np
import pandas
import sqlalchemy
from felis.simple import Table
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, inspection, sql
from sqlalchemy.engine import Inspector
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
from .apdbSchema import ApdbTables
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
from .timer import Timer

_LOG = logging.getLogger(__name__)

class _ConnectionHackSA2(sqlalchemy.engine.Connectable):
    """Terrible hack to work around Pandas' incomplete support for
    SQLAlchemy 2.

    We need to pass a Connection instance to Pandas methods, but in
    SQLAlchemy 2 the Connection class lost the ``connect`` method that
    Pandas uses.
    """

    def __init__(self, connection: sqlalchemy.engine.Connection):
        self._connection = connection

    def connect(self) -> Any:
        return self

    @property
    def execute(self) -> Callable:
        return self._connection.execute

    @property
    def execution_options(self) -> Callable:
        return self._connection.execution_options

    @property
    def connection(self) -> Any:
        return self._connection.connection

    def __enter__(self) -> sqlalchemy.engine.Connection:
        return self._connection

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        # Do not close the connection here; the caller owns its lifetime.
        pass


@inspection._inspects(_ConnectionHackSA2)
def _connection_insp(conn: _ConnectionHackSA2) -> Inspector:
    return Inspector._construct(Inspector._init_connection, conn._connection)
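
# Illustrative sketch (added here; not part of the original module): the
# wrapper above lets a live SQLAlchemy 2 Connection stand in for the
# "Connectable" that pandas.read_sql_query() expects. The trivial query is
# an assumption for demonstration.
def _example_connection_hack(engine: sqlalchemy.engine.Engine) -> pandas.DataFrame:
    with engine.begin() as conn:
        # Pandas calls connect()/execute() on the object it receives; both
        # are forwarded to the wrapped Connection by _ConnectionHackSA2.
        query = sql.select(sql.literal(1).label("one"))
        return pandas.read_sql_query(query, _ConnectionHackSA2(conn))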

def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64; return a copy of the
    data frame."""
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
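
# Example sketch (illustrative; not in the original module). Types like
# np.uint64 can cause issues with sqlalchemy, hence the coercion above.
def _example_coerce_uint64() -> None:
    df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
    assert _coerce_uint64(df)["diaObjectId"].dtype == np.int64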

def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
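
# Worked example (illustrative; not in the original module). Months are
# approximated as 30 days, so a visit at MJD 60000.0 with a 12-month window
# yields a cutoff of 60000.0 - 12 * 30 = 59640.0. The DateTime constructor
# arguments follow lsst.daf.base conventions and are an assumption here.
def _example_midPointTai_start() -> None:
    visit_time = dafBase.DateTime(60000.0, dafBase.DateTime.MJD, dafBase.DateTime.TAI)
    assert abs(_make_midPointTai_start(visit_time, 12) - 59640.0) < 1e-6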

class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to the SQLAlchemy default if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "decl"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
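
# Minimal configuration sketch (illustrative; not in the original module).
# The SQLite URL is an assumption and the remaining fields are left at their
# defaults, which are assumed to satisfy validate().
def _example_config() -> ApdbSqlConfig:
    config = ApdbSqlConfig()
    config.db_url = "sqlite:///apdb.db"
    config.dia_object_index = "last_object_table"
    config.validate()
    return config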

class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self._keys = list(result.keys())
        self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))

    def column_names(self) -> list[str]:
        return self._keys

    def rows(self) -> Iterable[tuple]:
        return self._rows

class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config``
    mechanism using the `ApdbSqlConfig` configuration class. For examples of
    different configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):
        config.validate()
        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug(" schema_file: %s", self.config.schema_file)
        _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug(" schema prefix: %s", self.config.prefix)

        # The engine is reused between multiple processes; make sure that we
        # don't share connections by disabling the pool (using NullPool).
        kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id
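
    # Illustrative construction sketch (added comment; not part of the
    # original class). The SQLite URL is an assumption:
    #
    #     config = ApdbSqlConfig()
    #     config.db_url = "sqlite:///apdb.db"
    #     apdb = ApdbSql(config)
    #     apdb.makeSchema(drop=False)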

    def tableRowCount(self) -> Dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database
        tables. Depending on database technology this could be an expensive
        operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[Table]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects
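
    # Illustrative region query (added comment; not part of the original
    # class): build a small cone on the sky and fetch matching DiaObjects.
    # Circle and Angle come from lsst.sphgeom:
    #
    #     from lsst.sphgeom import Angle, Circle
    #     center = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))
    #     region = Circle(center, Angle.fromDegrees(0.1))
    #     objects = apdb.getDiaObjects(region)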

    def getDiaSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If the list is empty then an empty catalog is returned
            with a correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaSource records. `None` is returned if the
            ``read_forced_sources_months`` configuration parameter is set
            to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be
        not-`None`. `NotImplementedError` is raised if `None` is passed.

        This method returns a DiaForcedSource catalog for a region, with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter w.r.t.
        ``visit_time``. If ``object_ids`` is empty then an empty catalog is
        always returned with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaForcedSource, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"]).order_by(table.columns["insert_time"])
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return [ApdbInsertId(row) for row in result.scalars()]

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Common implementation of the history methods."""
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            with self._engine.begin() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return ApdbSqlTableData(result)
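
    # Illustrative use of the insert-id machinery (added comment; assumes
    # ``use_insert_id`` was enabled in the configuration):
    #
    #     insert_ids = apdb.getInsertIds()
    #     if insert_ids:
    #         data = apdb.getDiaSourcesHistory(insert_ids)
    #         print(data.column_names())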

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: dafBase.DateTime,
        objects: pandas.DataFrame,
        sources: Optional[pandas.DataFrame] = None,
        forced_sources: Optional[pandas.DataFrame] = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id()
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)
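
    # Illustrative call for a single visit (added comment; the catalogs must
    # already follow the APDB schema conventions):
    #
    #     apdb.store(visit_time, objects, sources=sources,
    #                forced_sources=forced_sources)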

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64
            # can cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(
                    table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema
                )

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check
            # what is missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given a set of DiaObject
        IDs.

        Parameters
        ----------
        object_ids :
            Collection of DiaObject IDs
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: List[int], midPointTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given a
        set of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Source table to query.
        object_ids :
            Collection of DiaObject IDs
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: Optional[pandas.DataFrame] = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midPointTai"] > midPointTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources
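
    # Note on the chunking above (added comment): splitting the sorted ID
    # list into groups of 1000 keeps each IN clause within backend limits,
    # and the per-chunk frames are concatenated at the end. For example:
    #
    #     chunks = list(chunk_iterable(sorted(range(2500)), 1000))
    #     [len(c) for c in chunks]  # -> [1000, 1000, 500]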

    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
    ) -> None:
        dt = visit_time.toPython()

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: dafBase.DateTime,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection for the enclosing transaction.
        """

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in the LAST table; standard SQL
            # has no upsert, and mysql/postgres support it only via
            # non-standard features.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with
                # visit time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    _ConnectionHackSA2(connection),
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: List[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit
            # time just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection for the enclosing transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection for the enclosing transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
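
    # Note on the decrement above (added comment): the pixel ranges returned
    # by HtmPixelization.envelope() are half-open [low, upper), while SQL
    # BETWEEN is inclusive; e.g. the range (512, 516) becomes
    # "pixelId BETWEEN 512 AND 515".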

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by the
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources