Coverage for python/lsst/dax/apdb/sql/apdbSql.py: 15%

493 statements  

coverage.py v7.5.0, created at 2024-04-26 09:55 +0000

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods."""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Iterable, Mapping, MutableMapping
from contextlib import closing, suppress
from typing import TYPE_CHECKING, Any, cast

import astropy.time
import numpy as np
import pandas
import sqlalchemy
import sqlalchemy.dialects.postgresql
import sqlalchemy.dialects.sqlite
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, sql
from sqlalchemy.pool import NullPool

from ..apdb import Apdb, ApdbConfig
from ..apdbConfigFreezer import ApdbConfigFreezer
from ..apdbReplica import ReplicaChunk
from ..apdbSchema import ApdbTables
from ..schema_model import Table
from ..timer import Timer
from ..versionTuple import IncompatibleVersionError, VersionTuple
from .apdbMetadataSql import ApdbMetadataSql
from .apdbSqlReplica import ApdbSqlReplica
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables

if TYPE_CHECKING:
    import sqlite3

    from ..apdbMetadata import ApdbMetadata

_LOG = logging.getLogger(__name__)

VERSION = VersionTuple(0, 1, 0)
"""Version for the code controlling non-replication tables. This needs to be
updated following compatibility rules when schema produced by this code
changes.
"""


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return a copy of the
    data frame.
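
    Examples
    --------
    A minimal illustrative doctest (hypothetical data):

    >>> df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64)})
    >>> _coerce_uint64(df).dtypes["id"] == np.int64
    True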
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})


def _make_midpointMjdTai_start(visit_time: astropy.time.Time, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `astropy.time.Time`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
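
    Examples
    --------
    A minimal illustrative doctest (hypothetical values):

    >>> t = astropy.time.Time(60000.0, format="mjd", scale="tai")
    >>> _make_midpointMjdTai_start(t, 12)
    59640.0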
    """
    # TODO: Use of MJD must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.mjd - months * 30


def _onSqlite3Connect(
    dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
) -> None:
    # Enable foreign keys
    with closing(dbapiConnection.cursor()) as cursor:
        cursor.execute("PRAGMA foreign_keys=ON;")


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before "
            "exiting. If not set then the SQLAlchemy default is used."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
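
# Illustrative configuration sketch (comment only; the ``db_url`` value is
# hypothetical):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.validate()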


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using `ApdbSqlConfig` configuration class. For an example of different
    configurations check ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    metadataSchemaVersionKey = "version:schema"
    """Name of the metadata key to store schema version number."""

    metadataCodeVersionKey = "version:ApdbSql"
    """Name of the metadata key to store code version number."""

    metadataReplicaVersionKey = "version:ApdbSqlReplica"
    """Name of the metadata key to store replica code version number."""

    metadataConfigKey = "config:apdb-sql.json"
    """Name of the metadata key to store the frozen part of the configuration."""

    _frozen_parameters = (
        "use_insert_id",
        "dia_object_index",
        "htm_level",
        "htm_index_column",
        "ra_dec_columns",
    )
    """Names of the config parameters to be frozen in metadata table."""

    def __init__(self, config: ApdbSqlConfig):
        self._engine = self._makeEngine(config)

        sa_metadata = sqlalchemy.MetaData(schema=config.namespace)
        meta_table_name = ApdbTables.metadata.table_name(prefix=config.prefix)
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(sqlalchemy.exc.NoSuchTableError):
            meta_table = sqlalchemy.schema.Table(meta_table_name, sa_metadata, autoload_with=self._engine)

        self._metadata = ApdbMetadataSql(self._engine, meta_table)

        # Read frozen config from metadata.
        config_json = self._metadata.get(self.metadataConfigKey)
        if config_json is not None:
            # Update config from metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](self._frozen_parameters)
            self.config = freezer.update(config, config_json)
        else:
            self.config = config
        self.config.validate()

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            enable_replica=self.config.use_insert_id,
        )

        if self._metadata.table_exists():
            self._versionCheck(self._metadata)

        self.pixelator = HtmPixelization(self.config.htm_level)

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

    @classmethod
    def _makeEngine(cls, config: ApdbSqlConfig) -> sqlalchemy.engine.Engine:
        """Make SQLAlchemy engine based on configured parameters.

        Parameters
        ----------
        config : `ApdbSqlConfig`
            Configuration object.
        """
        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw: MutableMapping[str, Any] = dict(echo=config.sql_echo)
        conn_args: dict[str, Any] = dict()
        if not config.connection_pool:
            kw.update(poolclass=NullPool)
        if config.isolation_level is not None:
            kw.update(isolation_level=config.isolation_level)
        elif config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if config.connection_timeout is not None:
            if config.db_url.startswith("sqlite"):
                conn_args.update(timeout=config.connection_timeout)
            elif config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=config.connection_timeout)
        kw.update(connect_args=conn_args)
        engine = sqlalchemy.create_engine(config.db_url, **kw)
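        # Sketch of the effective call for a hypothetical setup with
        # connection_pool=False, a PostgreSQL URL and connection_timeout=60:
        #     sqlalchemy.create_engine(url, echo=False, poolclass=NullPool,
        #                              connect_args={"connect_timeout": 60})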

        if engine.dialect.name == "sqlite":
            # Need to enable foreign keys on every new connection.
            sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect)

        return engine

    def _versionCheck(self, metadata: ApdbMetadataSql) -> None:
        """Check schema version compatibility."""

        def _get_version(key: str, default: VersionTuple) -> VersionTuple:
            """Retrieve version number from given metadata key."""
            if metadata.table_exists():
                version_str = metadata.get(key)
                if version_str is None:
                    # Should not happen with existing metadata table.
                    raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
                return VersionTuple.fromString(version_str)
            return default

        # For old databases where metadata table does not exist we assume that
        # version of both code and schema is 0.1.0.
        initial_version = VersionTuple(0, 1, 0)
        db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version)
        db_code_version = _get_version(self.metadataCodeVersionKey, initial_version)

        # For now there is no way to make read-only APDB instances, assume that
        # any access can do updates.
        if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True):
            raise IncompatibleVersionError(
                f"Configured schema version {self._schema.schemaVersion()} "
                f"is not compatible with database version {db_schema_version}"
            )
        if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True):
            raise IncompatibleVersionError(
                f"Current code version {self.apdbImplementationVersion()} "
                f"is not compatible with database version {db_code_version}"
            )

        # Check replica code version only if replica is enabled.
        if self._schema.has_replica_chunks:
            db_replica_version = _get_version(self.metadataReplicaVersionKey, initial_version)
            code_replica_version = ApdbSqlReplica.apdbReplicaImplementationVersion()
            if not code_replica_version.checkCompatibility(db_replica_version, True):
                raise IncompatibleVersionError(
                    f"Current replication code version {code_replica_version} "
                    f"is not compatible with database version {db_replica_version}"
                )

    @classmethod
    def apdbImplementationVersion(cls) -> VersionTuple:
        # Docstring inherited from base class.
        return VERSION

    @classmethod
    def init_database(
        cls,
        db_url: str,
        *,
        schema_file: str | None = None,
        schema_name: str | None = None,
        read_sources_months: int | None = None,
        read_forced_sources_months: int | None = None,
        use_insert_id: bool = False,
        connection_timeout: int | None = None,
        dia_object_index: str | None = None,
        htm_level: int | None = None,
        htm_index_column: str | None = None,
        ra_dec_columns: list[str] | None = None,
        prefix: str | None = None,
        namespace: str | None = None,
        drop: bool = False,
    ) -> ApdbSqlConfig:
        """Initialize new APDB instance and make configuration object for it.

        Parameters
        ----------
        db_url : `str`
            SQLAlchemy database URL.
        schema_file : `str`, optional
            Location of (YAML) configuration file with APDB schema. If not
            specified then default location will be used.
        schema_name : `str`, optional
            Name of the schema in YAML configuration file. If not specified
            then default name will be used.
        read_sources_months : `int`, optional
            Number of months of history to read from DiaSource.
        read_forced_sources_months : `int`, optional
            Number of months of history to read from DiaForcedSource.
        use_insert_id : `bool`
            If True, make additional tables used for replication to PPDB.
        connection_timeout : `int`, optional
            Database connection timeout in seconds.
        dia_object_index : `str`, optional
            Indexing mode for DiaObject table.
        htm_level : `int`, optional
            HTM indexing level.
        htm_index_column : `str`, optional
            Name of a HTM index column for DiaObject and DiaSource tables.
        ra_dec_columns : `list` [`str`], optional
            Names of ra/dec columns in DiaObject table.
        prefix : `str`, optional
            Optional prefix for all table names.
        namespace : `str`, optional
            Name of the database schema for all APDB tables. If not specified
            then default schema is used.
        drop : `bool`, optional
            If `True` then drop existing tables before re-creating the schema.

        Returns
        -------
        config : `ApdbSqlConfig`
            Resulting configuration object for a created APDB instance.
        """
        config = ApdbSqlConfig(db_url=db_url, use_insert_id=use_insert_id)
        if schema_file is not None:
            config.schema_file = schema_file
        if schema_name is not None:
            config.schema_name = schema_name
        if read_sources_months is not None:
            config.read_sources_months = read_sources_months
        if read_forced_sources_months is not None:
            config.read_forced_sources_months = read_forced_sources_months
        if connection_timeout is not None:
            config.connection_timeout = connection_timeout
        if dia_object_index is not None:
            config.dia_object_index = dia_object_index
        if htm_level is not None:
            config.htm_level = htm_level
        if htm_index_column is not None:
            config.htm_index_column = htm_index_column
        if ra_dec_columns is not None:
            config.ra_dec_columns = ra_dec_columns
        if prefix is not None:
            config.prefix = prefix
        if namespace is not None:
            config.namespace = namespace

        cls._makeSchema(config, drop=drop)

        return config
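
    # Illustrative usage sketch (comment only; the URL is hypothetical):
    #     config = ApdbSql.init_database(db_url="sqlite:///apdb.db")
    #     apdb = ApdbSql(config)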

    def apdbSchemaVersion(self) -> VersionTuple:
        # Docstring inherited from base class.
        return self._schema.schemaVersion()

    def get_replica(self) -> ApdbSqlReplica:
        """Return `ApdbReplica` instance for this database."""
        return ApdbSqlReplica(self._schema, self._engine)

    def tableRowCount(self) -> dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res
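
    # The returned mapping uses table names as keys, e.g. (hypothetical
    # counts): {"DiaObject": 120, "DiaSource": 3400, "DiaForcedSource": 9800}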

    def tableDef(self, table: ApdbTables) -> Table | None:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    @classmethod
    def _makeSchema(cls, config: ApdbConfig, drop: bool = False) -> None:
        # docstring is inherited from a base class

        if not isinstance(config, ApdbSqlConfig):
            raise TypeError(f"Unexpected type of configuration object: {type(config)}")

        engine = cls._makeEngine(config)

        # Ask schema class to create all tables.
        schema = ApdbSqlSchema(
            engine=engine,
            dia_object_index=config.dia_object_index,
            schema_file=config.schema_file,
            schema_name=config.schema_name,
            prefix=config.prefix,
            namespace=config.namespace,
            htm_index_column=config.htm_index_column,
            enable_replica=config.use_insert_id,
        )
        schema.makeSchema(drop=drop)

        # Need metadata table to store a few items in it, if table exists.
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(ValueError):
            meta_table = schema.get_table(ApdbTables.metadata)

        apdb_meta = ApdbMetadataSql(engine, meta_table)
        if apdb_meta.table_exists():
            # Fill version numbers, overwrite if they are already there.
            apdb_meta.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True)
            apdb_meta.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True)
            if config.use_insert_id:
                # Only store replica code version if replica is enabled.
                apdb_meta.set(
                    cls.metadataReplicaVersionKey,
                    str(ApdbSqlReplica.apdbReplicaImplementationVersion()),
                    force=True,
                )

            # Store frozen part of a configuration in metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](cls._frozen_parameters)
            apdb_meta.set(cls.metadataConfigKey, freezer.to_json(config), force=True)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(
                ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
            )

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def containsVisitDetector(self, visit: int, detector: int) -> bool:
        # docstring is inherited from a base class
        raise NotImplementedError()

    def containsCcdVisit(self, ccdVisitId: int) -> bool:
        """Test whether data for a given visit-detector is present in the APDB.

        This method is a placeholder until `Apdb.containsVisitDetector` can
        be implemented.

        Parameters
        ----------
        ccdVisitId : `int`
            The packed ID of the visit-detector to search for.

        Returns
        -------
        present : `bool`
            `True` if some DiaSource records exist for the specified
            observation, `False` otherwise.
        """
        # TODO: remove this method in favor of containsVisitDetector on either
        # DM-41671 or a ticket that removes ccdVisitId from these tables
        src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource)
        frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource)
        # Query should load only one leaf page of the index
        query1 = sql.select(src_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)
        # Backup query in case an image was processed but had no diaSources
        query2 = sql.select(frcsrc_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1)

        with self._engine.begin() as conn:
            result = conn.execute(query1).scalar_one_or_none()
            if result is not None:
                return True
            else:
                result = conn.execute(query2).scalar_one_or_none()
                return result is not None

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: astropy.time.Time,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            replica_chunk: ReplicaChunk | None = None
            if self._schema.has_replica_chunks:
                replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, self.config.replica_chunk_seconds)
                self._storeReplicaChunk(replica_chunk, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, replica_chunk, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, replica_chunk, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, replica_chunk, connection)
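
    # Illustrative call sketch (comment only; the DataFrames are hypothetical,
    # prepared by the caller):
    #     apdb.store(visit_time, objects_df, sources=sources_df)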

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            # Restrict the query to the candidate IDs instead of scanning the
            # whole table.
            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(table.name, conn, if_exists="append", index=False, schema=table.schema)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: list[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"The following DiaSource IDs do not exist in the database: {missing}")
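
    # Illustrative sketch (comment only; IDs are hypothetical): reassign
    # DiaSource 12345 to SSObject 67890 with
    #     apdb.reassignDiaSources({12345: 67890})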

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    @property
    def metadata(self) -> ApdbMetadata:
        # docstring is inherited from a base class
        if self._metadata is None:
            raise RuntimeError("Database schema was not initialized.")
        return self._metadata

    def _getDiaSourcesInRegion(self, region: Region, visit_time: astropy.time.Time) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `astropy.time.Time`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: astropy.time.Time) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `astropy.time.Time`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Enum member identifying the table to query.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is returned
            when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: pandas.DataFrame | None = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources
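
    # Chunking sketch: ``chunk_iterable(sorted(object_ids), 1000)`` yields
    # successive batches of at most 1000 IDs, so each SELECT keeps its IN()
    # list bounded; e.g. 2500 IDs become three queries over 1000, 1000 and
    # 500 IDs.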

    def _storeReplicaChunk(
        self,
        replica_chunk: ReplicaChunk,
        visit_time: astropy.time.Time,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        dt = visit_time.datetime

        table = self._schema.get_table(ExtraTables.ApdbReplicaChunks)

        # We need UPSERT which is a dialect-specific construct
        values = {"last_update_time": dt, "unique_id": replica_chunk.unique_id}
        row = {"apdb_replica_chunk": replica_chunk.id} | values
        if connection.dialect.name == "sqlite":
            insert_sqlite = sqlalchemy.dialects.sqlite.insert(table)
            insert_sqlite = insert_sqlite.on_conflict_do_update(index_elements=table.primary_key, set_=values)
            connection.execute(insert_sqlite, row)
        elif connection.dialect.name == "postgresql":
            insert_pg = sqlalchemy.dialects.postgresql.dml.insert(table)
            insert_pg = insert_pg.on_conflict_do_update(constraint=table.primary_key, set_=values)
            connection.execute(insert_pg, row)
        else:
            raise TypeError(f"Unsupported dialect {connection.dialect.name} for upsert.")
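
    # For SQLite the statement above compiles roughly to (sketch; the exact
    # column list comes from the table definition):
    #     INSERT INTO ApdbReplicaChunks (apdb_replica_chunk, last_update_time, unique_id)
    #     VALUES (?, ?, ?)
    #     ON CONFLICT (apdb_replica_chunk) DO UPDATE
    #         SET last_update_time = ?, unique_id = ?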

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: astropy.time.Time,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `astropy.time.Time`
            Time of the visit.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection with an active transaction.
        """
        if len(objs) == 0:
            _LOG.debug("No objects to write to database.")
            return

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.datetime

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    connection,
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: list[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert replica data
        table = self._schema.get_table(ApdbTables.DiaObject)
        replica_data: list[dict] = []
        replica_stmt: Any = None
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = objs[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaObjectChunks)
            replica_stmt = replica_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if replica_stmt is not None:
                connection.execute(replica_stmt, replica_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection with an active transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert replica data
        replica_data: list[dict] = []
        replica_stmt: Any = None
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = sources[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaSourceChunks)
            replica_stmt = replica_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if replica_stmt is not None:
                connection.execute(replica_stmt, replica_data)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Connection with an active transaction.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert replica data
        replica_data: list[dict] = []
        replica_stmt: Any = None
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = sources[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaForcedSourceChunks)
            replica_stmt = replica_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
            if replica_stmt is not None:
                connection.execute(replica_stmt, replica_data)

    def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` of `tuple`
            Sequence of ranges, each range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)
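
    # Sketch: envelope ranges are half-open [low, upper), hence ``upper -= 1``
    # above. For hypothetical ranges [(100, 101), (200, 210)] the filter is
    #     (pixelId = 100) OR (pixelId BETWEEN 200 AND 209)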

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed; a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df
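
    # Per-row sketch of the same computation (hypothetical coordinates):
    #     pix = HtmPixelization(20)
    #     pix.index(UnitVector3d(LonLat.fromDegrees(45.0, -30.0)))  # HTM pixel ID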

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed; a copy of the
        DataFrame is returned.
        """
        pixel_id_map: dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources