Coverage for python/lsst/dax/apdb/apdbSql.py: 15%

500 statements

coverage.py v7.4.4, created at 2024-03-20 11:36 +0000

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

22"""Module defining Apdb class and related methods. 

23""" 

24 

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Iterable, Mapping, MutableMapping
from contextlib import closing, suppress
from typing import TYPE_CHECKING, Any, cast

import lsst.daf.base as dafBase
import numpy as np
import pandas
import sqlalchemy
from felis.simple import Table
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, sql
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
from .apdbConfigFreezer import ApdbConfigFreezer
from .apdbMetadataSql import ApdbMetadataSql
from .apdbSchema import ApdbTables
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
from .timer import Timer
from .versionTuple import IncompatibleVersionError, VersionTuple

if TYPE_CHECKING:
    import sqlite3

    from .apdbMetadata import ApdbMetadata

_LOG = logging.getLogger(__name__)

VERSION = VersionTuple(0, 1, 0)
"""Version for the code defined in this module. This needs to be updated
(following compatibility rules) when schema produced by this code changes.
"""

def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})

def _make_midpointMjdTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
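# Worked example (illustrative, not part of the original module): for a visit
# at MJD 60000.0 with a 12-month history window, this returns
# 60000.0 - 12 * 30 = 59640.0, since a month is approximated as 30 days.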

def _onSqlite3Connect(
    dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
) -> None:
    """Enable foreign key support on each new SQLite connection."""
    # Enable foreign keys
    with closing(dbapiConnection.cursor()) as cursor:
        cursor.execute("PRAGMA foreign_keys=ON;")

class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for database lock to be released before exiting. "
            "Defaults to SQLAlchemy default if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
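# Minimal configuration sketch (illustrative; the SQLite URL below is a
# hypothetical example, not part of the module):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.validate()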

class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self._keys = list(result.keys())
        self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))

    def column_names(self) -> list[str]:
        return self._keys

    def rows(self) -> Iterable[tuple]:
        return self._rows
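# Usage sketch (illustrative; assumes ``result`` is a
# ``sqlalchemy.engine.Result`` from an already-executed query):
#
#     data = ApdbSqlTableData(result)
#     for row in data.rows():
#         record = dict(zip(data.column_names(), row))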

class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using `ApdbSqlConfig` configuration class. For examples of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    metadataSchemaVersionKey = "version:schema"
    """Name of the metadata key to store schema version number."""

    metadataCodeVersionKey = "version:ApdbSql"
    """Name of the metadata key to store code version number."""

    metadataConfigKey = "config:apdb-sql.json"
    """Name of the metadata key to store frozen part of the configuration."""

    _frozen_parameters = (
        "use_insert_id",
        "dia_object_index",
        "htm_level",
        "htm_index_column",
        "ra_dec_columns",
    )
    """Names of the config parameters to be frozen in metadata table."""

    def __init__(self, config: ApdbSqlConfig):
        self._engine = self._makeEngine(config)

        sa_metadata = sqlalchemy.MetaData(schema=config.namespace)
        meta_table_name = ApdbTables.metadata.table_name(prefix=config.prefix)
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(sqlalchemy.exc.NoSuchTableError):
            meta_table = sqlalchemy.schema.Table(meta_table_name, sa_metadata, autoload_with=self._engine)

        self._metadata = ApdbMetadataSql(self._engine, meta_table)

        # Read frozen config from metadata.
        config_json = self._metadata.get(self.metadataConfigKey)
        if config_json is not None:
            # Update config from metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](self._frozen_parameters)
            self.config = freezer.update(config, config_json)
        else:
            self.config = config
        self.config.validate()

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=self.config.use_insert_id,
        )

        if self._metadata.table_exists():
            self._versionCheck(self._metadata)

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

    @classmethod
    def _makeEngine(cls, config: ApdbSqlConfig) -> sqlalchemy.engine.Engine:
        """Make SQLAlchemy engine based on configured parameters.

        Parameters
        ----------
        config : `ApdbSqlConfig`
            Configuration object.
        """
        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw: MutableMapping[str, Any] = dict(echo=config.sql_echo)
        conn_args: dict[str, Any] = dict()
        if not config.connection_pool:
            kw.update(poolclass=NullPool)
        if config.isolation_level is not None:
            kw.update(isolation_level=config.isolation_level)
        elif config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if config.connection_timeout is not None:
            if config.db_url.startswith("sqlite"):
                conn_args.update(timeout=config.connection_timeout)
            elif config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=config.connection_timeout)
        kw.update(connect_args=conn_args)
        engine = sqlalchemy.create_engine(config.db_url, **kw)

        if engine.dialect.name == "sqlite":
            # Need to enable foreign keys on every new connection.
            sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect)

        return engine
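    # Illustrative mapping of the timeout option above: with
    # ``connection_timeout=60`` an SQLite URL yields
    # ``connect_args={"timeout": 60}`` while a PostgreSQL or MySQL URL yields
    # ``connect_args={"connect_timeout": 60}``.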

    def _versionCheck(self, metadata: ApdbMetadataSql) -> None:
        """Check schema version compatibility."""

        def _get_version(key: str, default: VersionTuple) -> VersionTuple:
            """Retrieve version number from given metadata key."""
            if metadata.table_exists():
                version_str = metadata.get(key)
                if version_str is None:
                    # Should not happen with existing metadata table.
                    raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
                return VersionTuple.fromString(version_str)
            return default

        # For old databases where metadata table does not exist we assume that
        # version of both code and schema is 0.1.0.
        initial_version = VersionTuple(0, 1, 0)
        db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version)
        db_code_version = _get_version(self.metadataCodeVersionKey, initial_version)

        # For now there is no way to make read-only APDB instances, assume that
        # any access can do updates.
        if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True):
            raise IncompatibleVersionError(
                f"Configured schema version {self._schema.schemaVersion()} "
                f"is not compatible with database version {db_schema_version}"
            )
        if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True):
            raise IncompatibleVersionError(
                f"Current code version {self.apdbImplementationVersion()} "
                f"is not compatible with database version {db_code_version}"
            )
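    # Failure mode example (illustrative): if the metadata table stores
    # "version:schema" = "0.1.0" and the configured schema reports a version
    # incompatible with it, the constructor raises IncompatibleVersionError.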

    @classmethod
    def apdbImplementationVersion(cls) -> VersionTuple:
        # Docstring inherited from base class.
        return VERSION

    def apdbSchemaVersion(self) -> VersionTuple:
        # Docstring inherited from base class.
        return self._schema.schemaVersion()

    def tableRowCount(self) -> dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res
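    # Result shape sketch (values are illustrative):
    #
    #     {"DiaObject": 1000, "DiaSource": 5000, "DiaForcedSource": 8000}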

    def tableDef(self, table: ApdbTables) -> Table | None:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    @classmethod
    def makeSchema(cls, config: ApdbConfig, drop: bool = False) -> None:
        # docstring is inherited from a base class

        if not isinstance(config, ApdbSqlConfig):
            raise TypeError(f"Unexpected type of configuration object: {type(config)}")

        engine = cls._makeEngine(config)

        # Ask schema class to create all tables.
        schema = ApdbSqlSchema(
            engine=engine,
            dia_object_index=config.dia_object_index,
            schema_file=config.schema_file,
            schema_name=config.schema_name,
            prefix=config.prefix,
            namespace=config.namespace,
            htm_index_column=config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )
        schema.makeSchema(drop=drop)

        # Need metadata table to store a few items in it, if the table exists.
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(ValueError):
            meta_table = schema.get_table(ApdbTables.metadata)

        apdb_meta = ApdbMetadataSql(engine, meta_table)
        if apdb_meta.table_exists():
            # Fill version numbers, overwrite if they are already there.
            apdb_meta.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True)
            apdb_meta.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True)

            # Store frozen part of a configuration in metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](cls._frozen_parameters)
            apdb_meta.set(cls.metadataConfigKey, freezer.to_json(config), force=True)
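    # Typical call sequence (illustrative): create the schema once, then
    # instantiate the APDB against the same configuration.
    #
    #     ApdbSql.makeSchema(config)
    #     apdb = ApdbSql(config)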

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects
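    # Usage sketch (illustrative; ``Circle`` and ``Angle`` come from
    # lsst.sphgeom and are not imported by this module):
    #
    #     center = UnitVector3d(LonLat.fromDegrees(45.0, 10.0))
    #     region = Circle(center, Angle.fromDegrees(0.1))
    #     objects = apdb.getDiaObjects(region)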

    def getDiaSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: dafBase.DateTime
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(
                ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
            )

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def containsVisitDetector(self, visit: int, detector: int) -> bool:
        # docstring is inherited from a base class
        raise NotImplementedError()

    def containsCcdVisit(self, ccdVisitId: int) -> bool:
        """Test whether data for a given visit-detector is present in the APDB.

        This method is a placeholder until `Apdb.containsVisitDetector` can
        be implemented.

        Parameters
        ----------
        ccdVisitId : `int`
            The packed ID of the visit-detector to search for.

        Returns
        -------
        present : `bool`
            `True` if some DiaSource records exist for the specified
            observation, `False` otherwise.
        """
        # TODO: remove this method in favor of containsVisitDetector on either
        # DM-41671 or a ticket that removes ccdVisitId from these tables
        src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource)
        frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource)
        # Query should load only one leaf page of the index
        query1 = sql.select(src_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)
        # Backup query in case an image was processed but had no diaSources
        query2 = sql.select(frcsrc_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)

        with self._engine.begin() as conn:
            result = conn.execute(query1).scalar_one_or_none()
            if result is not None:
                return True
            else:
                result = conn.execute(query2).scalar_one_or_none()
                return result is not None

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"], table.columns["insert_time"]).order_by(
            table.columns["insert_time"]
        )
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                ids = []
                for row in result:
                    insert_time = dafBase.DateTime(int(row[1].timestamp() * 1e9))
                    ids.append(ApdbInsertId(id=row[0], insert_time=insert_time))
                return ids

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Return catalog of records for given insert identifiers, common
        implementation for all DIA tables.
        """
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            with self._engine.begin() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return ApdbSqlTableData(result)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: dafBase.DateTime,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id(visit_time)
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)
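    # Usage sketch (illustrative; ``objects``, ``sources`` and
    # ``forced_sources`` are pandas DataFrames produced upstream, e.g. by
    # source association):
    #
    #     apdb.store(visit_time, objects, sources, forced_sources)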

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(table.name, conn, if_exists="append", index=False, schema=table.schema)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: list[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"The following DiaSource IDs do not exist in the database: {missing}")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    @property
    def metadata(self) -> ApdbMetadata:
        # docstring is inherited from a base class
        if self._metadata is None:
            raise RuntimeError("Database schema was not initialized.")
        return self._metadata

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Source table to query.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is returned
            when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: pandas.DataFrame | None = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
    ) -> None:
        """Insert a new row into the insert-id table for this visit."""
        dt = visit_time.toPython()

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: dafBase.DateTime,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not saved.
        connection : `sqlalchemy.engine.Connection`
            Connection for the current transaction.
        """
        if len(objs) == 0:
            _LOG.debug("No objects to write to database.")
            return

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    connection,
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: list[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if history_stmt is not None:
                connection.execute(history_stmt, history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not saved.
        connection : `sqlalchemy.engine.Connection`
            Connection for the current transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier, or `None` if history is not saved.
        connection : `sqlalchemy.engine.Connection`
            Connection for the current transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` of `tuple`
            Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
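    # Worked example (range values are illustrative): ranges such as
    # [(1024, 1028), (2048, 2049)] from `_htm_indices` (upper bounds are
    # exclusive, hence the ``upper -= 1``) translate into the predicate
    # "pixelId BETWEEN 1024 AND 1027 OR pixelId = 2048".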

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources