Coverage for python/lsst/dax/apdb/sql/apdbSql.py: 15% (509 statements)

coverage.py v7.5.0, created at 2024-05-01 10:45 +0000

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods."""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Iterable, Mapping, MutableMapping
from contextlib import closing, suppress
from typing import TYPE_CHECKING, Any, cast

import astropy.time
import numpy as np
import pandas
import sqlalchemy
import sqlalchemy.dialects.postgresql
import sqlalchemy.dialects.sqlite
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, sql
from sqlalchemy.pool import NullPool

from ..apdb import Apdb, ApdbConfig
from ..apdbConfigFreezer import ApdbConfigFreezer
from ..apdbReplica import ReplicaChunk
from ..apdbSchema import ApdbTables
from ..monitor import MonAgent
from ..schema_model import Table
from ..timer import Timer
from ..versionTuple import IncompatibleVersionError, VersionTuple
from .apdbMetadataSql import ApdbMetadataSql
from .apdbSqlReplica import ApdbSqlReplica
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables

if TYPE_CHECKING:
    import sqlite3

    from ..apdbMetadata import ApdbMetadata

_LOG = logging.getLogger(__name__)

_MON = MonAgent(__name__)

VERSION = VersionTuple(0, 1, 0)
"""Version for the code controlling non-replication tables. This needs to be
updated following compatibility rules when the schema produced by this code
changes.
"""


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, and return a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
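
# A minimal usage sketch (hypothetical data, not part of this module):
#
#     df = pandas.DataFrame({"id": np.array([1, 2], dtype=np.uint64)})
#     df = _coerce_uint64(df)  # the "id" column is now int64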


def _make_midpointMjdTai_start(visit_time: astropy.time.Time, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `astropy.time.Time`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: Use of MJD must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.mjd - months * 30
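
# Worked example of the 30-day month approximation above: a visit at
# MJD 60000.0 with months=12 gives 60000.0 - 12 * 30 = 59640.0.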


def _onSqlite3Connect(
    dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord
) -> None:
    # Enable foreign keys
    with closing(dbapiConnection.cursor()) as cursor:
        cursor.execute("PRAGMA foreign_keys=ON;")
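
# Note: SQLite does not enforce foreign key constraints unless this pragma is
# issued, and the setting is per-connection, hence the "connect" event hook
# registered in _makeEngine().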


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for database lock to be released before exiting. "
            "If not set, the SQLAlchemy default is used."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of the HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject; by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If a schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
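
# A minimal configuration sketch (the SQLite URL is illustrative):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.validate()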


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using `ApdbSqlConfig` configuration class. For examples of different
    configurations check the ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    metadataSchemaVersionKey = "version:schema"
    """Name of the metadata key to store schema version number."""

    metadataCodeVersionKey = "version:ApdbSql"
    """Name of the metadata key to store code version number."""

    metadataReplicaVersionKey = "version:ApdbSqlReplica"
    """Name of the metadata key to store replica code version number."""

    metadataConfigKey = "config:apdb-sql.json"
    """Name of the metadata key to store frozen part of the configuration."""

    _frozen_parameters = (
        "use_insert_id",
        "dia_object_index",
        "htm_level",
        "htm_index_column",
        "ra_dec_columns",
    )
    """Names of the config parameters to be frozen in metadata table."""

    def __init__(self, config: ApdbSqlConfig):
        self._engine = self._makeEngine(config)

        sa_metadata = sqlalchemy.MetaData(schema=config.namespace)
        meta_table_name = ApdbTables.metadata.table_name(prefix=config.prefix)
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(sqlalchemy.exc.NoSuchTableError):
            meta_table = sqlalchemy.schema.Table(meta_table_name, sa_metadata, autoload_with=self._engine)

        self._metadata = ApdbMetadataSql(self._engine, meta_table)

        # Read frozen config from metadata.
        config_json = self._metadata.get(self.metadataConfigKey)
        if config_json is not None:
            # Update config from metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](self._frozen_parameters)
            self.config = freezer.update(config, config_json)
        else:
            self.config = config
        self.config.validate()

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            enable_replica=self.config.use_insert_id,
        )

        if self._metadata.table_exists():
            self._versionCheck(self._metadata)

        self.pixelator = HtmPixelization(self.config.htm_level)

        _LOG.debug("APDB Configuration:")
        _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug(" schema_file: %s", self.config.schema_file)
        _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug(" schema prefix: %s", self.config.prefix)

        self._timer_args: list[MonAgent | logging.Logger] = [_MON]
        if self.config.timer:
            self._timer_args.append(_LOG)

    def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer:
        """Create `Timer` instance given its name."""
        return Timer(name, *self._timer_args, tags=tags)

    @classmethod
    def _makeEngine(cls, config: ApdbSqlConfig) -> sqlalchemy.engine.Engine:
        """Make SQLAlchemy engine based on configured parameters.

        Parameters
        ----------
        config : `ApdbSqlConfig`
            Configuration object.
        """
        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw: MutableMapping[str, Any] = dict(echo=config.sql_echo)
        conn_args: dict[str, Any] = dict()
        if not config.connection_pool:
            kw.update(poolclass=NullPool)
        if config.isolation_level is not None:
            kw.update(isolation_level=config.isolation_level)
        elif config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if config.connection_timeout is not None:
            if config.db_url.startswith("sqlite"):
                conn_args.update(timeout=config.connection_timeout)
            elif config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=config.connection_timeout)
        kw.update(connect_args=conn_args)
        engine = sqlalchemy.create_engine(config.db_url, **kw)

        if engine.dialect.name == "sqlite":
            # Need to enable foreign keys on every new connection.
            sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect)

        return engine
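
    # For a default configuration with a SQLite URL, the engine created above
    # is roughly equivalent to this sketch (URL illustrative):
    #
    #     sqlalchemy.create_engine(
    #         "sqlite:///apdb.db",
    #         echo=False,
    #         isolation_level="READ_UNCOMMITTED",
    #         connect_args={},
    #     )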

    def _versionCheck(self, metadata: ApdbMetadataSql) -> None:
        """Check schema version compatibility."""

        def _get_version(key: str, default: VersionTuple) -> VersionTuple:
            """Retrieve version number from given metadata key."""
            if metadata.table_exists():
                version_str = metadata.get(key)
                if version_str is None:
                    # Should not happen with existing metadata table.
                    raise RuntimeError(f"Version key {key!r} does not exist in metadata table.")
                return VersionTuple.fromString(version_str)
            return default

        # For old databases where metadata table does not exist we assume that
        # version of both code and schema is 0.1.0.
        initial_version = VersionTuple(0, 1, 0)
        db_schema_version = _get_version(self.metadataSchemaVersionKey, initial_version)
        db_code_version = _get_version(self.metadataCodeVersionKey, initial_version)

        # For now there is no way to make read-only APDB instances, assume that
        # any access can do updates.
        if not self._schema.schemaVersion().checkCompatibility(db_schema_version, True):
            raise IncompatibleVersionError(
                f"Configured schema version {self._schema.schemaVersion()} "
                f"is not compatible with database version {db_schema_version}"
            )
        if not self.apdbImplementationVersion().checkCompatibility(db_code_version, True):
            raise IncompatibleVersionError(
                f"Current code version {self.apdbImplementationVersion()} "
                f"is not compatible with database version {db_code_version}"
            )

        # Check replica code version only if replica is enabled.
        if self._schema.has_replica_chunks:
            db_replica_version = _get_version(self.metadataReplicaVersionKey, initial_version)
            code_replica_version = ApdbSqlReplica.apdbReplicaImplementationVersion()
            if not code_replica_version.checkCompatibility(db_replica_version, True):
                raise IncompatibleVersionError(
                    f"Current replication code version {code_replica_version} "
                    f"is not compatible with database version {db_replica_version}"
                )

    @classmethod
    def apdbImplementationVersion(cls) -> VersionTuple:
        # Docstring inherited from base class.
        return VERSION

    @classmethod
    def init_database(
        cls,
        db_url: str,
        *,
        schema_file: str | None = None,
        schema_name: str | None = None,
        read_sources_months: int | None = None,
        read_forced_sources_months: int | None = None,
        use_insert_id: bool = False,
        connection_timeout: int | None = None,
        dia_object_index: str | None = None,
        htm_level: int | None = None,
        htm_index_column: str | None = None,
        ra_dec_columns: list[str] | None = None,
        prefix: str | None = None,
        namespace: str | None = None,
        drop: bool = False,
    ) -> ApdbSqlConfig:
        """Initialize new APDB instance and make configuration object for it.

        Parameters
        ----------
        db_url : `str`
            SQLAlchemy database URL.
        schema_file : `str`, optional
            Location of (YAML) configuration file with APDB schema. If not
            specified then default location will be used.
        schema_name : `str`, optional
            Name of the schema in YAML configuration file. If not specified
            then default name will be used.
        read_sources_months : `int`, optional
            Number of months of history to read from DiaSource.
        read_forced_sources_months : `int`, optional
            Number of months of history to read from DiaForcedSource.
        use_insert_id : `bool`
            If True, make additional tables used for replication to PPDB.
        connection_timeout : `int`, optional
            Database connection timeout in seconds.
        dia_object_index : `str`, optional
            Indexing mode for DiaObject table.
        htm_level : `int`, optional
            HTM indexing level.
        htm_index_column : `str`, optional
            Name of the HTM index column for DiaObject and DiaSource tables.
        ra_dec_columns : `list` [`str`], optional
            Names of ra/dec columns in DiaObject table.
        prefix : `str`, optional
            Optional prefix for all table names.
        namespace : `str`, optional
            Name of the database schema for all APDB tables. If not specified
            then default schema is used.
        drop : `bool`, optional
            If `True` then drop existing tables before re-creating the schema.

        Returns
        -------
        config : `ApdbSqlConfig`
            Resulting configuration object for a created APDB instance.
        """
        config = ApdbSqlConfig(db_url=db_url, use_insert_id=use_insert_id)
        if schema_file is not None:
            config.schema_file = schema_file
        if schema_name is not None:
            config.schema_name = schema_name
        if read_sources_months is not None:
            config.read_sources_months = read_sources_months
        if read_forced_sources_months is not None:
            config.read_forced_sources_months = read_forced_sources_months
        if connection_timeout is not None:
            config.connection_timeout = connection_timeout
        if dia_object_index is not None:
            config.dia_object_index = dia_object_index
        if htm_level is not None:
            config.htm_level = htm_level
        if htm_index_column is not None:
            config.htm_index_column = htm_index_column
        if ra_dec_columns is not None:
            config.ra_dec_columns = ra_dec_columns
        if prefix is not None:
            config.prefix = prefix
        if namespace is not None:
            config.namespace = namespace

        cls._makeSchema(config, drop=drop)

        return config
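
    # Example invocation sketch (the URL and schema name are illustrative):
    #
    #     config = ApdbSql.init_database(
    #         db_url="postgresql://user@host/db", namespace="apdb", use_insert_id=True
    #     )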

    def apdbSchemaVersion(self) -> VersionTuple:
        # Docstring inherited from base class.
        return self._schema.schemaVersion()

    def get_replica(self) -> ApdbSqlReplica:
        """Return `ApdbReplica` instance for this database."""
        return ApdbSqlReplica(self._schema, self._engine)

    def tableRowCount(self) -> dict[str, int]:
        """Return dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res
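
    # The result maps table names to row counts, e.g. (values illustrative):
    #
    #     {"DiaObject": 1000, "DiaSource": 5000, "DiaForcedSource": 8000}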

    def tableDef(self, table: ApdbTables) -> Table | None:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    @classmethod
    def _makeSchema(cls, config: ApdbConfig, drop: bool = False) -> None:
        # docstring is inherited from a base class

        if not isinstance(config, ApdbSqlConfig):
            raise TypeError(f"Unexpected type of configuration object: {type(config)}")

        engine = cls._makeEngine(config)

        # Ask schema class to create all tables.
        schema = ApdbSqlSchema(
            engine=engine,
            dia_object_index=config.dia_object_index,
            schema_file=config.schema_file,
            schema_name=config.schema_name,
            prefix=config.prefix,
            namespace=config.namespace,
            htm_index_column=config.htm_index_column,
            enable_replica=config.use_insert_id,
        )
        schema.makeSchema(drop=drop)

        # Need the metadata table to store a few items in it, if table exists.
        meta_table: sqlalchemy.schema.Table | None = None
        with suppress(ValueError):
            meta_table = schema.get_table(ApdbTables.metadata)

        apdb_meta = ApdbMetadataSql(engine, meta_table)
        if apdb_meta.table_exists():
            # Fill version numbers, overwrite if they are already there.
            apdb_meta.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True)
            apdb_meta.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True)
            if config.use_insert_id:
                # Only store replica code version if replica is enabled.
                apdb_meta.set(
                    cls.metadataReplicaVersionKey,
                    str(ApdbSqlReplica.apdbReplicaImplementationVersion()),
                    force=True,
                )

            # Store frozen part of a configuration in metadata.
            freezer = ApdbConfigFreezer[ApdbSqlConfig](cls._frozen_parameters)
            apdb_meta.set(cls.metadataConfigKey, freezer.to_json(config), force=True)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with self._timer("select_time", tags={"table": "DiaObject"}):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(
        self, region: Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time
    ) -> pandas.DataFrame | None:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

563 def getDiaForcedSources( 

564 self, region: Region, object_ids: Iterable[int] | None, visit_time: astropy.time.Time 

565 ) -> pandas.DataFrame | None: 

566 # docstring is inherited from a base class 

567 if self.config.read_forced_sources_months == 0: 

568 _LOG.debug("Skip DiaForceSources fetching") 

569 return None 

570 

571 if object_ids is None: 

572 # This implementation does not support region-based selection. 

573 raise NotImplementedError("Region-based selection is not supported") 

574 

575 # TODO: DateTime.MJD must be consistent with code in ap_association, 

576 # alternatively we can fill midpointMjdTai ourselves in store() 

577 midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months) 

578 _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start) 

579 

580 with self._timer("select_time", tags={"table": "DiaForcedSource"}): 

581 sources = self._getSourcesByIDs( 

582 ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start 

583 ) 

584 

585 _LOG.debug("found %s DiaForcedSources", len(sources)) 

586 return sources 

    def containsVisitDetector(self, visit: int, detector: int) -> bool:
        # docstring is inherited from a base class
        raise NotImplementedError()

592 def containsCcdVisit(self, ccdVisitId: int) -> bool: 

593 """Test whether data for a given visit-detector is present in the APDB. 

594 

595 This method is a placeholder until `Apdb.containsVisitDetector` can 

596 be implemented. 

597 

598 Parameters 

599 ---------- 

600 ccdVisitId : `int` 

601 The packed ID of the visit-detector to search for. 

602 

603 Returns 

604 ------- 

605 present : `bool` 

606 `True` if some DiaSource records exist for the specified 

607 observation, `False` otherwise. 

608 """ 

609 # TODO: remove this method in favor of containsVisitDetector on either 

610 # DM-41671 or a ticket that removes ccdVisitId from these tables 

611 src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource) 

612 frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource) 

613 # Query should load only one leaf page of the index 

614 query1 = sql.select(src_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1) 

615 # Backup query in case an image was processed but had no diaSources 

616 query2 = sql.select(frcsrc_table.c.ccdVisitId).filter_by(ccdVisitId=ccdVisitId).limit(1) 

617 

618 with self._engine.begin() as conn: 

619 result = conn.execute(query1).scalar_one_or_none() 

620 if result is not None: 

621 return True 

622 else: 

623 result = conn.execute(query2).scalar_one_or_none() 

624 return result is not None 

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with self._timer("SSObject_select_time", tags={"table": "SSObject"}):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: astropy.time.Time,
        objects: pandas.DataFrame,
        sources: pandas.DataFrame | None = None,
        forced_sources: pandas.DataFrame | None = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            replica_chunk: ReplicaChunk | None = None
            if self._schema.has_replica_chunks:
                replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, self.config.replica_chunk_seconds)
                self._storeReplicaChunk(replica_chunk, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, replica_chunk, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, replica_chunk, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, replica_chunk, connection)
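
    # Call sketch (the catalogs are pandas DataFrames whose columns follow the
    # APDB schema; names below are illustrative):
    #
    #     apdb.store(astropy.time.Time.now(), objects, sources, forced_sources)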

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(table.name, conn, if_exists="append", index=False, schema=table.schema)

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: list[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
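
    # Usage sketch: reassociate DiaSource 12345 with SSObject 67890 and detach
    # it from its DiaObject (IDs are illustrative):
    #
    #     apdb.reassignDiaSources({12345: 67890})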

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    @property
    def metadata(self) -> ApdbMetadata:
        # docstring is inherited from a base class
        if self._metadata is None:
            raise RuntimeError("Database schema was not initialized.")
        return self._metadata

    def _getDiaSourcesInRegion(self, region: Region, visit_time: astropy.time.Time) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `astropy.time.Time`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with self._timer("DiaSource_select_time", tags={"table": "DiaSource"}):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: list[int], visit_time: astropy.time.Time) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        visit_time : `astropy.time.Time`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with self._timer("select_time", tags={"table": "DiaSource"}):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Enum of the database table to query.
        object_ids : `list` [`int`]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog is returned
            when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: pandas.DataFrame | None = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
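            # IDs are queried in fixed-size chunks so that the IN() list in
            # each statement stays bounded (1000 IDs per query here).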

            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources

    def _storeReplicaChunk(
        self,
        replica_chunk: ReplicaChunk,
        visit_time: astropy.time.Time,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        dt = visit_time.datetime

        table = self._schema.get_table(ExtraTables.ApdbReplicaChunks)

        # We need UPSERT which is a dialect-specific construct
        values = {"last_update_time": dt, "unique_id": replica_chunk.unique_id}
        row = {"apdb_replica_chunk": replica_chunk.id} | values
        if connection.dialect.name == "sqlite":
            insert_sqlite = sqlalchemy.dialects.sqlite.insert(table)
            insert_sqlite = insert_sqlite.on_conflict_do_update(index_elements=table.primary_key, set_=values)
            connection.execute(insert_sqlite, row)
        elif connection.dialect.name == "postgresql":
            insert_pg = sqlalchemy.dialects.postgresql.dml.insert(table)
            insert_pg = insert_pg.on_conflict_do_update(constraint=table.primary_key, set_=values)
            connection.execute(insert_pg, row)
        else:
            raise TypeError(f"Unsupported dialect {connection.dialect.name} for upsert.")
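
    # On SQLite the statement above compiles to roughly the following upsert
    # (sketch; the PostgreSQL branch names the primary-key constraint instead):
    #
    #     INSERT INTO ApdbReplicaChunks (apdb_replica_chunk, last_update_time, unique_id)
    #     VALUES (?, ?, ?)
    #     ON CONFLICT (apdb_replica_chunk)
    #     DO UPDATE SET last_update_time = ?, unique_id = ?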

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: astropy.time.Time,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `astropy.time.Time`
            Time of the visit.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Database connection to use for inserts.
        """
        if len(objs) == 0:
            _LOG.debug("No objects to write to database.")
            return

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.datetime

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in LAST table.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with self._timer("delete_time", tags={"table": table.name}):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with self._timer("insert_time", tags={"table": "DiaObjectLast"}):
                last_objs.to_sql(
                    table.name,
                    connection,
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            with self._timer("truncate_time", tags={"table": table.name}):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: list[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert replica data
        table = self._schema.get_table(ApdbTables.DiaObject)
        replica_data: list[dict] = []
        replica_stmt: Any = None
        replica_table_name = ""
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = objs[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaObjectChunks)
            replica_table_name = replica_table.name
            replica_stmt = replica_table.insert()

        # insert new versions
        with self._timer("insert_time", tags={"table": table.name}):
            objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
        if replica_stmt is not None:
            with self._timer("insert_time", tags={"table": replica_table_name}):
                connection.execute(replica_stmt, replica_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Database connection to use for inserts.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert replica data
        replica_data: list[dict] = []
        replica_stmt: Any = None
        replica_table_name = ""
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = sources[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaSourceChunks)
            replica_table_name = replica_table.name
            replica_stmt = replica_table.insert()

        # everything to be done in single transaction
        with self._timer("insert_time", tags={"table": table.name}):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
        if replica_stmt is not None:
            with self._timer("replica_insert_time", tags={"table": replica_table_name}):
                connection.execute(replica_stmt, replica_data)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        replica_chunk: ReplicaChunk | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        replica_chunk : `ReplicaChunk` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Database connection to use for inserts.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert replica data
        replica_data: list[dict] = []
        replica_stmt: Any = None
        replica_table_name = ""
        if replica_chunk is not None:
            pk_names = [column.name for column in table.primary_key]
            replica_data = sources[pk_names].to_dict("records")
            for row in replica_data:
                row["apdb_replica_chunk"] = replica_chunk.id
            replica_table = self._schema.get_table(ExtraTables.DiaForcedSourceChunks)
            replica_table_name = replica_table.name
            replica_stmt = replica_table.insert()

        # everything to be done in single transaction
        with self._timer("insert_time", tags={"table": table.name}):
            sources = _coerce_uint64(sources)
            sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema)
        if replica_stmt is not None:
            with self._timer("insert_time", tags={"table": replica_table_name}):
                connection.execute(replica_stmt, replica_data)

    def _htm_indices(self, region: Region) -> list[tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` of `tuple`
            Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()

    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
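            # The pixelator envelope yields half-open [begin, end) intervals;
            # decrement the upper bound to make it inclusive for BETWEEN.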

            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies the pixelId value from a matching DiaObject record.
        The DiaObject catalog needs to have a pixelId column filled by the
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed, a copy of the
        DataFrame is returned.
        """
        pixel_id_map: dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources