
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Callable, Iterable, Mapping, MutableMapping
from contextlib import closing, suppress
from typing import TYPE_CHECKING, Any, cast

import lsst.daf.base as dafBase
import numpy as np
import pandas
import sqlalchemy
from felis.simple import Table
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, inspection, sql
from sqlalchemy.engine import Inspector
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
from .apdbMetadataSql import ApdbMetadataSql
from .apdbSchema import ApdbTables
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
from .timer import Timer
from .versionTuple import IncompatibleVersionError, VersionTuple

if TYPE_CHECKING:

    import sqlite3

    from .apdbMetadata import ApdbMetadata

_LOG = logging.getLogger(__name__)

VERSION = VersionTuple(0, 1, 0)
"""Version for the code defined in this module. This needs to be updated
(following compatibility rules) when the schema produced by this code changes.
"""

if pandas.__version__.partition(".")[0] == "1":

    class _ConnectionHackSA2(sqlalchemy.engine.Connectable):
        """Terrible hack to work around Pandas 1's incomplete support for
        SQLAlchemy 2.

        We need to pass a Connection instance to pandas methods, but in
        SQLAlchemy 2 the Connection class lost the ``connect`` method that
        pandas still uses.
        """

        def __init__(self, connection: sqlalchemy.engine.Connection):
            self._connection = connection

        def connect(self, **kwargs: Any) -> Any:
            return self

        @property
        def execute(self) -> Callable:
            return self._connection.execute

        @property
        def execution_options(self) -> Callable:
            return self._connection.execution_options

        @property
        def connection(self) -> Any:
            return self._connection.connection

        def __enter__(self) -> sqlalchemy.engine.Connection:
            return self._connection

        def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
            # Do not close the connection here.
            pass

    @inspection._inspects(_ConnectionHackSA2)
    def _connection_insp(conn: _ConnectionHackSA2) -> Inspector:
        return Inspector._construct(Inspector._init_connection, conn._connection)

else:
    # Pandas 2.0 supports SQLAlchemy 2 correctly.
    def _ConnectionHackSA2(  # type: ignore[no-redef]
        conn: sqlalchemy.engine.Connectable,
    ) -> sqlalchemy.engine.Connectable:
        return conn


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
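
# A minimal illustration of the coercion above (hypothetical data, not part
# of the original module):
#
#     df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64)})
#     assert _coerce_uint64(df)["id"].dtype == np.int64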


def _make_midpointMjdTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate the starting point for a time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
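
# For example, with ``months = 12`` the cutoff is the visit MJD minus 360
# days, reflecting the 30-day month approximation used above:
#
#     start = _make_midpointMjdTai_start(visit_time, 12)
#     # start == visit_time.get(system=dafBase.DateTime.MJD) - 360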


def _onSqlite3Connect(
    dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
) -> None:
    # Enable foreign keys
    with closing(dbapiConnection.cursor()) as cursor:
        cursor.execute("PRAGMA foreign_keys=ON;")


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level. If unset then the backend-default "
            "value is used, except for the SQLite backend where we use "
            "READ_UNCOMMITTED. Some backends may not support every allowed "
            "value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before exiting. "
            "Defaults to the SQLAlchemy default if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of the HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject; by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If a schema with this name does not exist, it will be created "
            "when APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
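
# A minimal configuration sketch; the SQLite URL below is a hypothetical
# example, not a recommended deployment:
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.dia_object_index = "last_object_table"
#     apdb = ApdbSql(config)
#     apdb.makeSchema()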


class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self._keys = list(result.keys())
        self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))

    def column_names(self) -> list[str]:
        return self._keys

    def rows(self) -> Iterable[tuple]:
        return self._rows
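
# Consumers can iterate the wrapped result without touching sqlalchemy;
# ``data`` below is a hypothetical value returned by one of the history
# methods of ApdbSql:
#
#     for row in data.rows():
#         record = dict(zip(data.column_names(), row))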


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via the standard ``pex_config``
    mechanism using the `ApdbSqlConfig` configuration class. For examples of
    different configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    metadataSchemaVersionKey = "version:schema"
    """Name of the metadata key to store schema version number."""

    metadataCodeVersionKey = "version:ApdbSql"
    """Name of the metadata key to store code version number."""

    def __init__(self, config: ApdbSqlConfig):
        config.validate()
        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # The engine is reused between multiple processes; make sure that we
        # don't share connections, by disabling the pool (using NullPool).
        kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
        conn_args: dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        if self._engine.dialect.name == "sqlite":
            # Need to enable foreign keys on every new connection.
            sqlalchemy.event.listen(self._engine, "connect", _onSqlite3Connect)

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )

        self._metadata: ApdbMetadataSql | None = None
        if not self._schema.empty():
            table: sqlalchemy.schema.Table | None = None
            with suppress(ValueError):
                table = self._schema.get_table(ApdbTables.metadata)
            self._metadata = ApdbMetadataSql(self._engine, table)
            self._versionCheck(self._metadata)

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id

    def _versionCheck(self, metadata: ApdbMetadataSql) -> None:
        """Check schema version compatibility."""

        def _get_version(key: str, default: VersionTuple) -> VersionTuple:
            """Retrieve version number from given metadata key."""
            if metadata.table_exists():
                version_str = metadata.get(key)
                if version_str is None:
                    # Should not happen with existing metadata table.
                    raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
                return VersionTuple.fromString(version_str)
            return default

        # For old databases where the metadata table does not exist we assume
        # that the version of both code and schema is 0.1.0.
        initial_version = VersionTuple(0, 1, 0)
        db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version)
        db_code_version = _get_version(self.metadataCodeVersionKey, initial_version)

        # For now there is no way to make read-only APDB instances; assume
        # that any access can do updates.
        if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True):
            raise IncompatibleVersionError(
                f"Configured schema version {self._schema.schemaVersion()} "
                f"is not compatible with database version {db_schema_version}"
            )
        if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True):
            raise IncompatibleVersionError(
                f"Current code version {self.apdbImplementationVersion()} "
                f"is not compatible with database version {db_code_version}"
            )

    @classmethod
    def apdbImplementationVersion(cls) -> VersionTuple:
        # Docstring inherited from base class.
        return VERSION

    def apdbSchemaVersion(self) -> VersionTuple:
        # Docstring inherited from base class.
        return self._schema.schemaVersion()

    def tableRowCount(self) -> dict[str, int]:
        """Return a dictionary with table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on the database technology this could be an expensive
        operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res
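
    # Hypothetical usage sketch, assuming ``apdb`` is an ApdbSql instance:
    #
    #     for name, count in apdb.tableRowCount().items():
    #         print(f"{name}: {count} rows")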

    def tableDef(self, table: ApdbTables) -> Table | None:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)
        # Need to reset metadata after the table was created.
        table: sqlalchemy.schema.Table | None = None
        with suppress(ValueError):
            table = self._schema.get_table(ApdbTables.metadata)
        self._metadata = ApdbMetadataSql(self._engine, table)

        if self._metadata.table_exists():
            # Fill version numbers, but only if they are not defined.
            if self._metadata.get(self.metadataSchemaVersionKey) is None:
                self._metadata.set(self.metadataSchemaVersionKey, str(self._schema.schemaVersion()))
            if self._metadata.get(self.metadataCodeVersionKey) is None:
                self._metadata.set(self.metadataCodeVersionKey, str(self.apdbImplementationVersion()))

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects
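
    # Hypothetical region-query sketch; ``Circle`` and ``Angle`` are imported
    # here only for the example and are not used elsewhere in this module:
    #
    #     from lsst.sphgeom import Angle, Circle
    #
    #     center = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))
    #     objects = apdb.getDiaObjects(Circle(center, Angle.fromDegrees(0.1)))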

    def getDiaSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association;
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(
                ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
            )

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def containsVisitDetector(self, visit: int, detector: int) -> bool:
        # docstring is inherited from a base class
        raise NotImplementedError()

    def containsCcdVisit(self, ccdVisitId: int) -> bool:
        """Test whether data for a given visit-detector is present in the APDB.

        This method is a placeholder until `Apdb.containsVisitDetector` can
        be implemented.

        Parameters
        ----------
        ccdVisitId : `int`
            The packed ID of the visit-detector to search for.

        Returns
        -------
        present : `bool`
            `True` if some DiaSource records exist for the specified
            observation, `False` otherwise.
        """
        # TODO: remove this method in favor of containsVisitDetector on either
        # DM-41671 or a ticket that removes ccdVisitId from these tables
        src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource)
        frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource)
        # Query should load only one leaf page of the index
        query1 = sql.select(src_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)
        # Backup query in case an image was processed but had no diaSources
        query2 = sql.select(frcsrc_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)

        with self._engine.begin() as conn:
            result = conn.execute(query1).scalar_one_or_none()
            if result is not None:
                return True
            else:
                result = conn.execute(query2).scalar_one_or_none()
                return result is not None

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"], table.columns["insert_time"]).order_by(
            table.columns["insert_time"]
        )
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                ids = []
                for row in result:
                    insert_time = dafBase.DateTime(int(row[1].timestamp() * 1e9))
                    ids.append(ApdbInsertId(id=row[0], insert_time=insert_time))
                return ids
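
    # Hypothetical sketch combining the insert-id methods below to retrieve
    # and then prune history:
    #
    #     ids = apdb.getInsertIds()
    #     if ids:
    #         data = apdb.getDiaSourcesHistory(ids[-1:])
    #         for row in data.rows():
    #             ...  # process one DiaSource record
    #         apdb.deleteInsertIds(ids[-1:])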

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Return catalog of records for given insert identifiers, common
        implementation for all DIA tables.
        """
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            with self._engine.begin() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return ApdbSqlTableData(result)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: dafBase.DateTime,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id(visit_time)
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn], table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(
                    table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema
                )

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: list[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
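
    # Hypothetical reassignment sketch: move DiaSources 123 and 124 to
    # SSObjects 5 and 6 respectively:
    #
    #     apdb.reassignDiaSources({123: 5, 124: 6})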

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    @property
    def metadata(self) -> ApdbMetadata:
        # docstring is inherited from a base class
        if self._metadata is None:
            raise RuntimeError("Database schema was not initialized.")
        return self._metadata

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association;
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association;
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Enum member identifying the table to query.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is returned
            when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: pandas.DataFrame | None = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources
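
    # ``chunk_iterable`` above splits the sorted ID list into batches of 1000
    # so that each IN clause stays at a manageable size; an illustrative
    # sketch:
    #
    #     for chunk in chunk_iterable(range(5), 2):
    #         print(list(chunk))  # [0, 1], then [2, 3], then [4]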

    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
    ) -> None:
        dt = visit_time.toPython()

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: dafBase.DateTime,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, `None` if insert IDs are not used.
        connection : `sqlalchemy.engine.Connection`
            Connection to use for database access.
        """
        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    _ConnectionHackSA2(connection),
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: list[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, `None` if insert IDs are not used.
        connection : `sqlalchemy.engine.Connection`
            Connection to use for database access.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, `None` if insert IDs are not used.
        connection : `sqlalchemy.engine.Connection`
            Connection to use for database access.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` of `tuple`
            Sequence of ranges, each range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()
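
    # Illustrative sketch of the envelope computed above (IDs are
    # hypothetical): for a small circular region the pixelator may return
    #
    #     [(12345678900, 12345678904), (12345678920, 12345678921)]
    #
    # where each tuple is a half-open [begin, end) range, hence the
    # ``upper -= 1`` adjustment in ``_filterRegion`` below.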

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed; a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df
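
    # For a single position the same computation is, with hypothetical
    # coordinates:
    #
    #     uv3d = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))
    #     pixel_id = HtmPixelization(20).index(uv3d)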

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies the pixelId value from a matching DiaObject record.
        The DiaObject catalog needs to have a pixelId column filled by the
        ``_add_obj_htm_index`` method, and DiaSource records need to be
        associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed; a copy of the
        DataFrame is returned.
        """
        pixel_id_map: dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject, hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources