Coverage for python/lsst/dax/apdb/apdbSql.py: 13%

453 statements  

coverage.py v7.3.2, created at 2023-11-21 11:47 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining Apdb class and related methods. 

23""" 

24 

25from __future__ import annotations 

26 

27__all__ = ["ApdbSqlConfig", "ApdbSql"] 

28 

29import logging 

30from collections.abc import Callable, Iterable, Mapping, MutableMapping 

31from contextlib import closing 

32from typing import TYPE_CHECKING, Any, cast 

33 

34import lsst.daf.base as dafBase 

35import numpy as np 

36import pandas 

37import sqlalchemy 

38from felis.simple import Table 

39from lsst.pex.config import ChoiceField, Field, ListField 

40from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d 

41from lsst.utils.iteration import chunk_iterable 

42from sqlalchemy import func, inspection, sql 

43from sqlalchemy.engine import Inspector 

44from sqlalchemy.pool import NullPool 

45 

46from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData 

47from .apdbSchema import ApdbTables 

48from .apdbSqlSchema import ApdbSqlSchema, ExtraTables 

49from .timer import Timer 

50 

51if TYPE_CHECKING:  # coverage: line 51 didn't jump to line 52 because the condition was never true

52 import sqlite3 

53 

54_LOG = logging.getLogger(__name__) 

55 

56 

57if pandas.__version__.partition(".")[0] == "1":  # coverage: line 57 didn't jump to line 59 because the condition was never true

58 

59 class _ConnectionHackSA2(sqlalchemy.engine.Connectable): 

60 """Terrible hack to work around Pandas 1's incomplete support for 

61 SQLAlchemy 2. 

62 

63 We need to pass a Connection instance to a pandas method, but in SA 2 

64 the Connection class lost the ``connect`` method that pandas uses. 

65 """ 

66 

67 def __init__(self, connection: sqlalchemy.engine.Connection): 

68 self._connection = connection 

69 

70 def connect(self, **kwargs: Any) -> Any: 

71 return self 

72 

73 @property 

74 def execute(self) -> Callable: 

75 return self._connection.execute 

76 

77 @property 

78 def execution_options(self) -> Callable: 

79 return self._connection.execution_options 

80 

81 @property 

82 def connection(self) -> Any: 

83 return self._connection.connection 

84 

85 def __enter__(self) -> sqlalchemy.engine.Connection: 

86 return self._connection 

87 

88 def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: 

89 # Do not close connection here 

90 pass 

91 
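    # Register an inspection hook so that sqlalchemy.inspect() works on the
    # wrapper class by delegating to the wrapped Connection.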

92 @inspection._inspects(_ConnectionHackSA2) 

93 def _connection_insp(conn: _ConnectionHackSA2) -> Inspector: 

94 return Inspector._construct(Inspector._init_connection, conn._connection) 

95 

96else: 

97 # Pandas 2.0 supports SQLAlchemy 2 correctly. 

98 def _ConnectionHackSA2( # type: ignore[no-redef] 

99 conn: sqlalchemy.engine.Connectable, 

100 ) -> sqlalchemy.engine.Connectable: 

101 return conn 

102 

103 

104def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame: 

105 """Change the type of uint64 columns to int64, and return a copy of 

106 the data frame. 

107 """ 

108 names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64] 

109 return df.astype({name: np.int64 for name in names}) 

110 

111 

112def _make_midpointMjdTai_start(visit_time: dafBase.DateTime, months: int) -> float: 

113 """Calculate starting point for time-based source search. 

114 

115 Parameters 

116 ---------- 

117 visit_time : `lsst.daf.base.DateTime` 

118 Time of current visit. 

119 months : `int` 

120 Number of months in the sources history. 

121 

122 Returns 

123 ------- 

124 time : `float` 

125 A ``midpointMjdTai`` starting point, MJD time. 

126 """ 

127 # TODO: `system` must be consistent with the code in ap_association 

128 # (see DM-31996) 
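    # Months are approximated as 30-day intervals, e.g. months=12 moves the
    # starting point back by 360 days from the visit time.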

129 return visit_time.get(system=dafBase.DateTime.MJD) - months * 30 

130 

131 

132def _onSqlite3Connect( 

133 dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord 

134) -> None: 

135 # Enable foreign keys 

136 with closing(dbapiConnection.cursor()) as cursor: 

137 cursor.execute("PRAGMA foreign_keys=ON;") 

138 

139 

140class ApdbSqlConfig(ApdbConfig): 

141 """APDB configuration class for SQL implementation (ApdbSql).""" 

142 

143 db_url = Field[str](doc="SQLAlchemy database connection URI") 

144 isolation_level = ChoiceField[str]( 

145 doc=( 

146 "Transaction isolation level. If unset, the backend-default value " 

147 "is used, except for the SQLite backend where READ_UNCOMMITTED is used. " 

148 "Some backends may not support every allowed value." 

149 ), 

150 allowed={ 

151 "READ_COMMITTED": "Read committed", 

152 "READ_UNCOMMITTED": "Read uncommitted", 

153 "REPEATABLE_READ": "Repeatable read", 

154 "SERIALIZABLE": "Serializable", 

155 }, 

156 default=None, 

157 optional=True, 

158 ) 

159 connection_pool = Field[bool]( 

160 doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.", 

161 default=True, 

162 ) 

163 connection_timeout = Field[float]( 

164 doc=( 

165 "Maximum time to wait for a database lock to be released before exiting. " 

166 "Defaults to the SQLAlchemy default if not set." 

167 ), 

168 default=None, 

169 optional=True, 

170 ) 

171 sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False) 

172 dia_object_index = ChoiceField[str]( 

173 doc="Indexing mode for DiaObject table", 

174 allowed={ 

175 "baseline": "Index defined in baseline schema", 

176 "pix_id_iov": "(pixelId, objectId, iovStart) PK", 

177 "last_object_table": "Separate DiaObjectLast table", 

178 }, 

179 default="baseline", 

180 ) 

181 htm_level = Field[int](doc="HTM indexing level", default=20) 

182 htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64) 

183 htm_index_column = Field[str]( 

184 default="pixelId", doc="Name of an HTM index column for DiaObject and DiaSource tables" 

185 ) 

186 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table") 

187 dia_object_columns = ListField[str]( 

188 doc="List of columns to read from DiaObject; if empty (default) all columns are read", default=[] 

189 ) 

190 prefix = Field[str](doc="Prefix to add to table names and index names", default="") 

191 namespace = Field[str]( 

192 doc=( 

193 "Namespace or schema name for all tables in APDB database. " 

194 "Presently only works for PostgreSQL backend. " 

195 "If schema with this name does not exist it will be created when " 

196 "APDB tables are created." 

197 ), 

198 default=None, 

199 optional=True, 

200 ) 

201 timer = Field[bool](doc="If True then print/log timing information", default=False) 

202 

203 def validate(self) -> None: 

204 super().validate() 

205 if len(self.ra_dec_columns) != 2: 

206 raise ValueError("ra_dec_columns must have exactly two column names") 

207 

208 

209class ApdbSqlTableData(ApdbTableData): 

210 """Implementation of ApdbTableData that wraps sqlalchemy Result.""" 

211 

212 def __init__(self, result: sqlalchemy.engine.Result): 

213 self._keys = list(result.keys()) 

214 self._rows: list[tuple] = cast(list[tuple], list(result.fetchall())) 

215 

216 def column_names(self) -> list[str]: 

217 return self._keys 

218 

219 def rows(self) -> Iterable[tuple]: 

220 return self._rows 

221 

222 

223class ApdbSql(Apdb): 

224 """Implementation of APDB interface based on SQL database. 

225 

226 The implementation is configured via the standard ``pex_config`` mechanism 

227 using the `ApdbSqlConfig` configuration class. For examples of different 

228 configurations check the ``config/`` folder. 

229 

230 Parameters 

231 ---------- 

232 config : `ApdbSqlConfig` 

233 Configuration object. 
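
    Examples
    --------
    A minimal usage sketch; the SQLite URL below is illustrative, not a
    default::

        config = ApdbSqlConfig()
        config.db_url = "sqlite:///apdb.db"
        apdb = ApdbSql(config)
        apdb.makeSchema()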

234 """ 

235 

236 ConfigClass = ApdbSqlConfig 

237 

238 def __init__(self, config: ApdbSqlConfig): 

239 config.validate() 

240 self.config = config 

241 

242 _LOG.debug("APDB Configuration:") 

243 _LOG.debug(" dia_object_index: %s", self.config.dia_object_index) 

244 _LOG.debug(" read_sources_months: %s", self.config.read_sources_months) 

245 _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months) 

246 _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns) 

247 _LOG.debug(" schema_file: %s", self.config.schema_file) 

248 _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file) 

249 _LOG.debug(" schema prefix: %s", self.config.prefix) 

250 

251 # If the engine is reused between multiple processes, make sure that we 

252 # don't share connections by disabling the pool (by using the NullPool class). 

253 kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo) 

254 conn_args: dict[str, Any] = dict() 

255 if not self.config.connection_pool: 

256 kw.update(poolclass=NullPool) 

257 if self.config.isolation_level is not None: 

258 kw.update(isolation_level=self.config.isolation_level) 

259 elif self.config.db_url.startswith("sqlite"): # type: ignore 

260 # Use READ_UNCOMMITTED as default value for sqlite. 

261 kw.update(isolation_level="READ_UNCOMMITTED") 

262 if self.config.connection_timeout is not None: 

263 if self.config.db_url.startswith("sqlite"): 

264 conn_args.update(timeout=self.config.connection_timeout) 

265 elif self.config.db_url.startswith(("postgresql", "mysql")): 

266 conn_args.update(connect_timeout=self.config.connection_timeout) 

267 kw.update(connect_args=conn_args) 

268 self._engine = sqlalchemy.create_engine(self.config.db_url, **kw) 

269 

270 if self._engine.dialect.name == "sqlite": 

271 # Need to enable foreign keys on every new connection. 

272 sqlalchemy.event.listen(self._engine, "connect", _onSqlite3Connect) 

273 

274 self._schema = ApdbSqlSchema( 

275 engine=self._engine, 

276 dia_object_index=self.config.dia_object_index, 

277 schema_file=self.config.schema_file, 

278 schema_name=self.config.schema_name, 

279 prefix=self.config.prefix, 

280 namespace=self.config.namespace, 

281 htm_index_column=self.config.htm_index_column, 

282 use_insert_id=config.use_insert_id, 

283 ) 

284 

285 self.pixelator = HtmPixelization(self.config.htm_level) 

286 self.use_insert_id = self._schema.has_insert_id 

287 

288 def tableRowCount(self) -> dict[str, int]: 

289 """Return dictionary with the table names and row counts. 

290 

291 Used by ``ap_proto`` to keep track of the size of the database tables. 

292 Depending on database technology this could be an expensive operation. 

293 

294 Returns 

295 ------- 

296 row_counts : `dict` 

297 Dict where key is a table name and value is a row count. 

298 """ 

299 res = {} 

300 tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource] 

301 if self.config.dia_object_index == "last_object_table": 

302 tables.append(ApdbTables.DiaObjectLast) 

303 with self._engine.begin() as conn: 

304 for table in tables: 

305 sa_table = self._schema.get_table(table) 

306 stmt = sql.select(func.count()).select_from(sa_table) 

307 count: int = conn.execute(stmt).scalar_one() 

308 res[table.name] = count 

309 

310 return res 

311 

312 def tableDef(self, table: ApdbTables) -> Table | None: 

313 # docstring is inherited from a base class 

314 return self._schema.tableSchemas.get(table) 

315 

316 def makeSchema(self, drop: bool = False) -> None: 

317 # docstring is inherited from a base class 

318 self._schema.makeSchema(drop=drop) 

319 

320 def getDiaObjects(self, region: Region) -> pandas.DataFrame: 

321 # docstring is inherited from a base class 

322 

323 # decide what columns we need 

324 if self.config.dia_object_index == "last_object_table": 

325 table_enum = ApdbTables.DiaObjectLast 

326 else: 

327 table_enum = ApdbTables.DiaObject 

328 table = self._schema.get_table(table_enum) 

329 if not self.config.dia_object_columns: 

330 columns = self._schema.get_apdb_columns(table_enum) 

331 else: 

332 columns = [table.c[col] for col in self.config.dia_object_columns] 

333 query = sql.select(*columns) 

334 

335 # build selection 

336 query = query.where(self._filterRegion(table, region)) 

337 

338 # select latest version of objects 

339 if self.config.dia_object_index != "last_object_table": 

340 query = query.where(table.c.validityEnd == None) # noqa: E711 

341 

342 # _LOG.debug("query: %s", query) 

343 

344 # execute select 

345 with Timer("DiaObject select", self.config.timer): 

346 with self._engine.begin() as conn: 

347 objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn)) 

348 _LOG.debug("found %s DiaObjects", len(objects)) 

349 return objects 

350 

351 def getDiaSources( 

352 self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime 

353 ) -> pandas.DataFrame | None: 

354 # docstring is inherited from a base class 

355 if self.config.read_sources_months == 0: 

356 _LOG.debug("Skip DiaSources fetching") 

357 return None 

358 

359 if object_ids is None: 

360 # region-based select 

361 return self._getDiaSourcesInRegion(region, visit_time) 

362 else: 

363 return self._getDiaSourcesByIDs(list(object_ids), visit_time) 

364 

365 def getDiaForcedSources( 

366 self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime 

367 ) -> pandas.DataFrame | None: 

368 # docstring is inherited from a base class 

369 if self.config.read_forced_sources_months == 0: 

370 _LOG.debug("Skip DiaForcedSources fetching") 

371 return None 

372 

373 if object_ids is None: 

374 # This implementation does not support region-based selection. 

375 raise NotImplementedError("Region-based selection is not supported") 

376 

377 # TODO: DateTime.MJD must be consistent with code in ap_association, 

378 # alternatively we can fill midpointMjdTai ourselves in store() 

379 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months) 

380 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start) 

381 

382 with Timer("DiaForcedSource select", self.config.timer): 

383 sources = self._getSourcesByIDs( 

384 ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start 

385 ) 

386 

387 _LOG.debug("found %s DiaForcedSources", len(sources)) 

388 return sources 

389 

390 def containsVisitDetector(self, visit: int, detector: int) -> bool: 

391 # docstring is inherited from a base class 

392 raise NotImplementedError() 

393 

394 def containsCcdVisit(self, ccdVisitId: int) -> bool: 

395 """Test whether data for a given visit-detector is present in the APDB. 

396 

397 This method is a placeholder until `Apdb.containsVisitDetector` can 

398 be implemented. 

399 

400 Parameters 

401 ---------- 

402 ccdVisitId : `int` 

403 The packed ID of the visit-detector to search for. 

404 

405 Returns 

406 ------- 

407 present : `bool` 

408 `True` if some DiaSource records exist for the specified 

409 observation, `False` otherwise. 

410 """ 

411 # TODO: remove this method in favor of containsVisitDetector on either 

412 # DM-41671 or a ticket that removes ccdVisitId from these tables 

413 src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource) 

414 frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource) 

415 # Query should load only one leaf page of the index 

416 query1 = sql.select(src_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1) 

417 # Backup query in case an image was processed but had no diaSources 

418 query2 = sql.select(frcsrc_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1) 

419 

420 with self._engine.begin() as conn: 

421 result = conn.execute(query1).scalar_one_or_none() 

422 if result is not None: 

423 return True 

424 else: 

425 result = conn.execute(query2).scalar_one_or_none() 

426 return result is not None 

427 

428 def getInsertIds(self) -> list[ApdbInsertId] | None: 

429 # docstring is inherited from a base class 

430 if not self._schema.has_insert_id: 

431 return None 

432 

433 table = self._schema.get_table(ExtraTables.DiaInsertId) 

434 assert table is not None, "has_insert_id=True means it must be defined" 

435 query = sql.select(table.columns["insert_id"], table.columns["insert_time"]).order_by( 

436 table.columns["insert_time"] 

437 ) 

438 with Timer("DiaObject insert id select", self.config.timer): 

439 with self._engine.connect() as conn: 

440 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query) 

441 ids = [] 

442 for row in result: 
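                    # The stored insert_time comes back as a datetime; convert
                    # it to DateTime via nanoseconds since the Unix epoch.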

443 insert_time = dafBase.DateTime(int(row[1].timestamp() * 1e9)) 

444 ids.append(ApdbInsertId(id=row[0], insert_time=insert_time)) 

445 return ids 

446 

447 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None: 

448 # docstring is inherited from a base class 

449 if not self._schema.has_insert_id: 

450 raise ValueError("APDB is not configured for history storage") 

451 

452 table = self._schema.get_table(ExtraTables.DiaInsertId) 

453 

454 insert_ids = [id.id for id in ids] 

455 where_clause = table.columns["insert_id"].in_(insert_ids) 

456 stmt = table.delete().where(where_clause) 

457 with self._engine.begin() as conn: 

458 conn.execute(stmt) 

459 

460 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

461 # docstring is inherited from a base class 

462 return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId) 

463 

464 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

465 # docstring is inherited from a base class 

466 return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId) 

467 

468 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData: 

469 # docstring is inherited from a base class 

470 return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId) 

471 

472 def _get_history( 

473 self, 

474 ids: Iterable[ApdbInsertId], 

475 table_enum: ApdbTables, 

476 history_table_enum: ExtraTables, 

477 ) -> ApdbTableData: 

478 """Return catalog of records for given insert identifiers, common 

479 implementation for all DIA tables. 

480 """ 

481 if not self._schema.has_insert_id: 

482 raise ValueError("APDB is not configured for history retrieval") 

483 

484 table = self._schema.get_table(table_enum) 

485 history_table = self._schema.get_table(history_table_enum) 

486 

487 join = table.join(history_table) 
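        # With no explicit ON clause the join condition is inferred from the
        # foreign-key relationship between the data table and its insert-id table.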

488 insert_ids = [id.id for id in ids] 

489 history_id_column = history_table.columns["insert_id"] 

490 apdb_columns = self._schema.get_apdb_columns(table_enum) 

491 where_clause = history_id_column.in_(insert_ids) 

492 query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause) 

493 

494 # execute select 

495 with Timer(f"{table.name} history select", self.config.timer): 

496 with self._engine.begin() as conn: 

497 result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query) 

498 return ApdbSqlTableData(result) 

499 

500 def getSSObjects(self) -> pandas.DataFrame: 

501 # docstring is inherited from a base class 

502 

503 columns = self._schema.get_apdb_columns(ApdbTables.SSObject) 

504 query = sql.select(*columns) 

505 

506 # execute select 

507 with Timer("SSObject select", self.config.timer): 

508 with self._engine.begin() as conn: 

509 objects = pandas.read_sql_query(query, conn) 

510 _LOG.debug("found %s SSObjects", len(objects)) 

511 return objects 

512 

513 def store( 

514 self, 

515 visit_time: dafBase.DateTime, 

516 objects: pandas.DataFrame, 

517 sources: pandas.DataFrame | None = None, 

518 forced_sources: pandas.DataFrame | None = None, 

519 ) -> None: 

520 # docstring is inherited from a base class 

521 

522 # We want to run all inserts in one transaction. 

523 with self._engine.begin() as connection: 

524 insert_id: ApdbInsertId | None = None 

525 if self._schema.has_insert_id: 

526 insert_id = ApdbInsertId.new_insert_id(visit_time) 

527 self._storeInsertId(insert_id, visit_time, connection) 

528 

529 # fill pixelId column for DiaObjects 

530 objects = self._add_obj_htm_index(objects) 

531 self._storeDiaObjects(objects, visit_time, insert_id, connection) 

532 

533 if sources is not None: 

534 # copy pixelId column from DiaObjects to DiaSources 

535 sources = self._add_src_htm_index(sources, objects) 

536 self._storeDiaSources(sources, insert_id, connection) 

537 

538 if forced_sources is not None: 

539 self._storeDiaForcedSources(forced_sources, insert_id, connection) 

540 

541 def storeSSObjects(self, objects: pandas.DataFrame) -> None: 

542 # docstring is inherited from a base class 

543 

544 idColumn = "ssObjectId" 

545 table = self._schema.get_table(ApdbTables.SSObject) 

546 

547 # everything to be done in single transaction 

548 with self._engine.begin() as conn: 

549 # Find record IDs that already exist. Some types like np.int64 can 

550 # cause issues with sqlalchemy, convert them to int. 

551 ids = sorted(int(oid) for oid in objects[idColumn]) 

552 

553 query = sql.select(table.columns[idColumn], table.columns[idColumn].in_(ids)) 

554 result = conn.execute(query) 

555 knownIds = set(row.ssObjectId for row in result) 

556 

557 filter = objects[idColumn].isin(knownIds) 

558 toUpdate = cast(pandas.DataFrame, objects[filter]) 

559 toInsert = cast(pandas.DataFrame, objects[~filter]) 

560 

561 # insert new records 

562 if len(toInsert) > 0: 

563 toInsert.to_sql( 

564 table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema 

565 ) 

566 

567 # update existing records 

568 if len(toUpdate) > 0: 

569 whereKey = f"{idColumn}_param" 

570 update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey)) 

571 toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns") 

572 values = toUpdate.to_dict("records") 

573 result = conn.execute(update, values) 

574 

575 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

576 # docstring is inherited from a base class 

577 

578 table = self._schema.get_table(ApdbTables.DiaSource) 

579 query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId")) 

580 

581 with self._engine.begin() as conn: 

582 # Need to make sure that every ID exists in the database, but 

583 # executemany may not support rowcount, so iterate and check what 

584 # is missing. 

585 missing_ids: list[int] = [] 

586 for key, value in idMap.items(): 

587 params = dict(srcId=key, diaObjectId=0, ssObjectId=value) 

588 result = conn.execute(query, params) 

589 if result.rowcount == 0: 

590 missing_ids.append(key) 

591 if missing_ids: 

592 missing = ",".join(str(item) for item in missing_ids) 

593 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

594 

595 def dailyJob(self) -> None: 

596 # docstring is inherited from a base class 

597 pass 

598 

599 def countUnassociatedObjects(self) -> int: 

600 # docstring is inherited from a base class 

601 

602 # Retrieve the DiaObject table. 

603 table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject) 

604 

605 # Construct the sql statement. 

606 stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1) 

607 stmt = stmt.where(table.c.validityEnd == None) # noqa: E711 

608 

609 # Return the count. 

610 with self._engine.begin() as conn: 

611 count = conn.execute(stmt).scalar_one() 

612 

613 return count 

614 

615 def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame: 

616 """Return catalog of DiaSource instances from given region. 

617 

618 Parameters 

619 ---------- 

620 region : `lsst.sphgeom.Region` 

621 Region to search for DIASources. 

622 visit_time : `lsst.daf.base.DateTime` 

623 Time of the current visit. 

624 

625 Returns 

626 ------- 

627 catalog : `pandas.DataFrame` 

628 Catalog containing DiaSource records. 

629 """ 

630 # TODO: DateTime.MJD must be consistent with code in ap_association, 

631 # alternatively we can fill midpointMjdTai ourselves in store() 

632 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months) 

633 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start) 

634 

635 table = self._schema.get_table(ApdbTables.DiaSource) 

636 columns = self._schema.get_apdb_columns(ApdbTables.DiaSource) 

637 query = sql.select(*columns) 

638 

639 # build selection 

640 time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start 

641 where = sql.expression.and_(self._filterRegion(table, region), time_filter) 

642 query = query.where(where) 

643 

644 # execute select 

645 with Timer("DiaSource select", self.config.timer): 

646 with self._engine.begin() as conn: 

647 sources = pandas.read_sql_query(query, conn) 

648 _LOG.debug("found %s DiaSources", len(sources)) 

649 return sources 

650 

651 def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: dafBase.DateTime) -> pandas.DataFrame: 

652 """Return catalog of DiaSource instances given set of DiaObject IDs. 

653 

654 Parameters 

655 ---------- 

656 object_ids : `list` [`int`] 

657 Collection of DiaObject IDs. 

658 visit_time : `lsst.daf.base.DateTime` 

659 Time of the current visit. 

660 

661 Returns 

662 ------- 

663 catalog : `pandas.DataFrame` 

664 Catalog containing DiaSource records. 

665 """ 

666 # TODO: DateTime.MJD must be consistent with code in ap_association, 

667 # alternatively we can fill midpointMjdTai ourselves in store() 

668 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months) 

669 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start) 

670 

671 with Timer("DiaSource select", self.config.timer): 

672 sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start) 

673 

674 _LOG.debug("found %s DiaSources", len(sources)) 

675 return sources 

676 

677 def _getSourcesByIDs( 

678 self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float 

679 ) -> pandas.DataFrame: 

680 """Return catalog of DiaSource or DiaForcedSource instances given set 

681 of DiaObject IDs. 

682 

683 Parameters 

684 ---------- 

685 table_enum : `ApdbTables` 

686 Enum identifying the table to query. 

687 object_ids : `list` [`int`] 

688 Collection of DiaObject IDs. 

689 midpointMjdTai_start : `float` 

690 Earliest midpointMjdTai to retrieve. 

691 

692 Returns 

693 ------- 

694 catalog : `pandas.DataFrame` 

695 Catalog containing DiaSource or DiaForcedSource records. An empty 

696 catalog is returned if ``object_ids`` is empty or if no matching 

697 records are found. 

698 """ 

699 table = self._schema.get_table(table_enum) 

700 columns = self._schema.get_apdb_columns(table_enum) 

701 

702 sources: pandas.DataFrame | None = None 

703 if len(object_ids) <= 0: 

704 _LOG.debug("ID list is empty, just fetch empty result") 

705 query = sql.select(*columns).where(sql.literal(False)) 

706 with self._engine.begin() as conn: 

707 sources = pandas.read_sql_query(query, conn) 

708 else: 

709 data_frames: list[pandas.DataFrame] = [] 

710 for ids in chunk_iterable(sorted(object_ids), 1000): 

711 query = sql.select(*columns) 

712 

713 # Some types like np.int64 can cause issues with 

714 # sqlalchemy, convert them to int. 

715 int_ids = [int(oid) for oid in ids] 

716 

717 # select by object id 

718 query = query.where( 

719 sql.expression.and_( 

720 table.columns["diaObjectId"].in_(int_ids), 

721 table.columns["midpointMjdTai"] > midpointMjdTai_start, 

722 ) 

723 ) 

724 

725 # execute select 

726 with self._engine.begin() as conn: 

727 data_frames.append(pandas.read_sql_query(query, conn)) 

728 

729 if len(data_frames) == 1: 

730 sources = data_frames[0] 

731 else: 

732 sources = pandas.concat(data_frames) 

733 assert sources is not None, "Catalog cannot be None" 

734 return sources 

735 

736 def _storeInsertId( 

737 self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection 

738 ) -> None: 

739 dt = visit_time.toPython() 

740 

741 table = self._schema.get_table(ExtraTables.DiaInsertId) 

742 

743 stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt) 

744 connection.execute(stmt) 

745 

746 def _storeDiaObjects( 

747 self, 

748 objs: pandas.DataFrame, 

749 visit_time: dafBase.DateTime, 

750 insert_id: ApdbInsertId | None, 

751 connection: sqlalchemy.engine.Connection, 

752 ) -> None: 

753 """Store catalog of DiaObjects from the current visit. 

754 

755 Parameters 

756 ---------- 

757 objs : `pandas.DataFrame` 

758 Catalog with DiaObject records. 

759 visit_time : `lsst.daf.base.DateTime` 

760 Time of the visit. 

761 insert_id : `ApdbInsertId` or `None` 

762 Insert identifier, or `None` if insert IDs are not used. 

763 """ 

764 # Some types like np.int64 can cause issues with sqlalchemy, convert 

765 # them to int. 

766 ids = sorted(int(oid) for oid in objs["diaObjectId"]) 

767 _LOG.debug("first object ID: %d", ids[0]) 

768 

769 # TODO: Need to verify that we are using correct scale here for 

770 # DATETIME representation (see DM-31996). 

771 dt = visit_time.toPython() 

772 

773 # everything to be done in single transaction 

774 if self.config.dia_object_index == "last_object_table": 

775 # Insert and replace all records in LAST table. 

776 table = self._schema.get_table(ApdbTables.DiaObjectLast) 

777 

778 # Drop the previous objects (pandas cannot upsert). 

779 query = table.delete().where(table.columns["diaObjectId"].in_(ids)) 

780 

781 with Timer(table.name + " delete", self.config.timer): 

782 res = connection.execute(query) 

783 _LOG.debug("deleted %s objects", res.rowcount) 

784 

785 # DiaObjectLast is a subset of DiaObject, strip missing columns 

786 last_column_names = [column.name for column in table.columns] 

787 last_objs = objs[last_column_names] 

788 last_objs = _coerce_uint64(last_objs) 

789 

790 if "lastNonForcedSource" in last_objs.columns: 

791 # lastNonForcedSource is defined NOT NULL, fill it with visit 

792 # time just in case. 

793 last_objs["lastNonForcedSource"].fillna(dt, inplace=True) 

794 else: 

795 extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource") 
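                # Re-index last_objs to match the new column's index so the
                # column-wise concat below aligns rows positionally.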

796 last_objs.set_index(extra_column.index, inplace=True) 

797 last_objs = pandas.concat([last_objs, extra_column], axis="columns") 

798 

799 with Timer("DiaObjectLast insert", self.config.timer): 

800 last_objs.to_sql( 

801 table.name, 

802 _ConnectionHackSA2(connection), 

803 if_exists="append", 

804 index=False, 

805 schema=table.schema, 

806 ) 

807 else: 

808 # truncate existing validity intervals 

809 table = self._schema.get_table(ApdbTables.DiaObject) 

810 

811 update = ( 

812 table.update() 

813 .values(validityEnd=dt) 

814 .where( 

815 sql.expression.and_( 

816 table.columns["diaObjectId"].in_(ids), 

817 table.columns["validityEnd"].is_(None), 

818 ) 

819 ) 

820 ) 

821 

822 # _LOG.debug("query: %s", query) 

823 

824 with Timer(table.name + " truncate", self.config.timer): 

825 res = connection.execute(update) 

826 _LOG.debug("truncated %s intervals", res.rowcount) 

827 

828 objs = _coerce_uint64(objs) 

829 

830 # Fill additional columns 

831 extra_columns: list[pandas.Series] = [] 

832 if "validityStart" in objs.columns: 

833 objs["validityStart"] = dt 

834 else: 

835 extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart")) 

836 if "validityEnd" in objs.columns: 

837 objs["validityEnd"] = None 

838 else: 

839 extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd")) 

840 if "lastNonForcedSource" in objs.columns: 

841 # lastNonForcedSource is defined NOT NULL, fill it with visit time 

842 # just in case. 

843 objs["lastNonForcedSource"].fillna(dt, inplace=True) 

844 else: 

845 extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource")) 

846 if extra_columns: 

847 objs.set_index(extra_columns[0].index, inplace=True) 

848 objs = pandas.concat([objs] + extra_columns, axis="columns") 

849 

850 # Insert history data 

851 table = self._schema.get_table(ApdbTables.DiaObject) 

852 history_data: list[dict] = [] 

853 history_stmt: Any = None 

854 if insert_id is not None: 

855 pk_names = [column.name for column in table.primary_key] 

856 history_data = objs[pk_names].to_dict("records") 

857 for row in history_data: 

858 row["insert_id"] = insert_id.id 

859 history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId) 

860 history_stmt = history_table.insert() 

861 

862 # insert new versions 

863 with Timer("DiaObject insert", self.config.timer): 

864 objs.to_sql( 

865 table.name, 

866 _ConnectionHackSA2(connection), 

867 if_exists="append", 

868 index=False, 

869 schema=table.schema, 

870 ) 

871 if history_stmt is not None: 

872 connection.execute(history_stmt, history_data) 

873 

874 def _storeDiaSources( 

875 self, 

876 sources: pandas.DataFrame, 

877 insert_id: ApdbInsertId | None, 

878 connection: sqlalchemy.engine.Connection, 

879 ) -> None: 

880 """Store catalog of DiaSources from the current visit. 

881 

882 Parameters 

883 ---------- 

884 sources : `pandas.DataFrame` 

885 Catalog containing DiaSource records 

886 """ 

887 table = self._schema.get_table(ApdbTables.DiaSource) 

888 

889 # Insert history data 

890 history: list[dict] = [] 

891 history_stmt: Any = None 

892 if insert_id is not None: 

893 pk_names = [column.name for column in table.primary_key] 

894 history = sources[pk_names].to_dict("records") 

895 for row in history: 

896 row["insert_id"] = insert_id.id 

897 history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId) 

898 history_stmt = history_table.insert() 

899 

900 # everything to be done in single transaction 

901 with Timer("DiaSource insert", self.config.timer): 

902 sources = _coerce_uint64(sources) 

903 sources.to_sql( 

904 table.name, 

905 _ConnectionHackSA2(connection), 

906 if_exists="append", 

907 index=False, 

908 schema=table.schema, 

909 ) 

910 if history_stmt is not None: 

911 connection.execute(history_stmt, history) 

912 

913 def _storeDiaForcedSources( 

914 self, 

915 sources: pandas.DataFrame, 

916 insert_id: ApdbInsertId | None, 

917 connection: sqlalchemy.engine.Connection, 

918 ) -> None: 

919 """Store a set of DiaForcedSources from the current visit. 

920 

921 Parameters 

922 ---------- 

923 sources : `pandas.DataFrame` 

924 Catalog containing DiaForcedSource records 

925 """ 

926 table = self._schema.get_table(ApdbTables.DiaForcedSource) 

927 

928 # Insert history data 

929 history: list[dict] = [] 

930 history_stmt: Any = None 

931 if insert_id is not None: 

932 pk_names = [column.name for column in table.primary_key] 

933 history = sources[pk_names].to_dict("records") 

934 for row in history: 

935 row["insert_id"] = insert_id.id 

936 history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId) 

937 history_stmt = history_table.insert() 

938 

939 # everything to be done in single transaction 

940 with Timer("DiaForcedSource insert", self.config.timer): 

941 sources = _coerce_uint64(sources) 

942 sources.to_sql( 

943 table.name, 

944 _ConnectionHackSA2(connection), 

945 if_exists="append", 

946 index=False, 

947 schema=table.schema, 

948 ) 

949 if history_stmt is not None: 

950 connection.execute(history_stmt, history) 

951 

952 def _htm_indices(self, region: Region) -> list[tuple[int, int]]: 

953 """Generate a set of HTM indices covering specified region. 

954 

955 Parameters 

956 ---------- 

957 region : `lsst.sphgeom.Region` 

958 Region that needs to be indexed. 

959 

960 Returns 

961 ------- 

962 Sequence of ranges, each range is a tuple (minHtmID, maxHtmID). 

963 """ 

964 _LOG.debug("region: %s", region) 

965 indices = self.pixelator.envelope(region, self.config.htm_max_ranges) 

966 

967 return indices.ranges() 

968 

969 def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement: 

970 """Make SQLAlchemy expression for selecting records in a region.""" 

971 htm_index_column = table.columns[self.config.htm_index_column] 

972 exprlist = [] 

973 pixel_ranges = self._htm_indices(region) 
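        # Envelope ranges are half-open [low, upper); the upper bound is
        # decremented below so it can be used with an inclusive BETWEEN.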

974 for low, upper in pixel_ranges: 

975 upper -= 1 

976 if low == upper: 

977 exprlist.append(htm_index_column == low) 

978 else: 

979 exprlist.append(sql.expression.between(htm_index_column, low, upper)) 

980 

981 return sql.expression.or_(*exprlist) 

982 

983 def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame: 

984 """Calculate HTM index for each record and add it to a DataFrame. 

985 

986 Notes 

987 ----- 

988 This overrides any existing column in the DataFrame with the same name 

989 (pixelId). The original DataFrame is not changed; a copy of the DataFrame 

990 is returned. 

991 """ 

992 # calculate HTM index for every DiaObject 

993 htm_index = np.zeros(df.shape[0], dtype=np.int64) 

994 ra_col, dec_col = self.config.ra_dec_columns 

995 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

996 uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec)) 

997 idx = self.pixelator.index(uv3d) 

998 htm_index[i] = idx 

999 df = df.copy() 

1000 df[self.config.htm_index_column] = htm_index 

1001 return df 

1002 

1003 def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame: 

1004 """Add pixelId column to DiaSource catalog. 

1005 

1006 Notes 

1007 ----- 

1008 This method copies the pixelId value from a matching DiaObject record. 

1009 The DiaObject catalog needs to have its pixelId column filled by the 

1010 ``_add_obj_htm_index`` method, and DiaSource records need to be 

1011 associated with DiaObjects via the ``diaObjectId`` column. 

1012 

1013 This overrides any existing column in the DataFrame with the same name 

1014 (pixelId). The original DataFrame is not changed; a copy of the DataFrame 

1015 is returned. 

1016 """ 

1017 pixel_id_map: dict[int, int] = { 

1018 diaObjectId: pixelId 

1019 for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column]) 

1020 } 

1021 # DiaSources associated with SolarSystemObjects do not have an 

1022 # associated DiaObject hence we skip them and set their htmIndex 

1023 # value to 0. 

1024 pixel_id_map[0] = 0 

1025 htm_index = np.zeros(sources.shape[0], dtype=np.int64) 

1026 for i, diaObjId in enumerate(sources["diaObjectId"]): 

1027 htm_index[i] = pixel_id_map[diaObjId] 

1028 sources = sources.copy() 

1029 sources[self.config.htm_index_column] = htm_index 

1030 return sources