Coverage for python/lsst/dax/apdb/sql/apdbSql.py: 12%

702 statements  

coverage.py v7.13.5, created at 2026-04-28 08:43 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining Apdb class and related methods.""" 

23 

24from __future__ import annotations 

25 

26__all__ = ["ApdbSql"] 

27 

28import datetime 

29import json 

30import logging 

31import urllib.parse 

32import uuid 

33import warnings 

34from collections import Counter 

35from collections.abc import Iterable, Mapping, MutableMapping 

36from contextlib import closing 

37from typing import TYPE_CHECKING, Any 

38 

39import astropy.time 

40import numpy as np 

41import pandas 

42import sqlalchemy 

43import sqlalchemy.dialects.postgresql 

44import sqlalchemy.dialects.sqlite 

45from sqlalchemy import func, sql 

46from sqlalchemy.pool import NullPool 

47 

48from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d 

49from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError 

50from lsst.utils.iteration import chunk_iterable 

51 

52from ..apdb import Apdb 

53from ..apdbConfigFreezer import ApdbConfigFreezer 

54from ..apdbReplica import ReplicaChunk 

55from ..apdbSchema import ApdbSchema, ApdbTables 

56from ..apdbUpdateRecord import ( 

57 ApdbCloseDiaObjectValidityRecord, 

58 ApdbReassignDiaSourceToDiaObjectRecord, 

59 ApdbUpdateNDiaSourcesRecord, 

60 ApdbUpdateRecord, 

61) 

62from ..config import ApdbConfig 

63from ..monitor import MonAgent 

64from ..recordIds import DiaObjectId, DiaSourceId 

65from ..schema_model import Table 

66from ..timer import Timer 

67from ..versionTuple import IncompatibleVersionError, VersionTuple 

68from .apdbMetadataSql import ApdbMetadataSql 

69from .apdbSqlAdmin import ApdbSqlAdmin 

70from .apdbSqlReplica import ApdbSqlReplica, ApdbSqlTableData 

71from .apdbSqlSchema import ApdbSqlSchema, ExtraTables 

72from .config import ApdbSqlConfig 

73 

74if TYPE_CHECKING: 

75 import sqlite3 

76 

77 from ..apdbMetadata import ApdbMetadata 

78 from ..apdbUpdateRecord import ApdbUpdateRecord 

79 

80_LOG = logging.getLogger(__name__) 

81 

82_MON = MonAgent(__name__) 

83 

84VERSION = VersionTuple(1, 2, 1) 

85"""Version for the code controlling non-replication tables. This needs to be 

86updated following compatibility rules when the schema produced by this code

87changes. 

88""" 

89 

90 

91def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame: 

92 """Change the type of uint64 columns to int64, and return copy of data 

93 frame. 

94 """ 

95 names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64] 

96 return df.astype(dict.fromkeys(names, np.int64)) 

97 
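# Illustrative example for the helper above (the dtypes are hypothetical):
# with df.dtypes == {"diaObjectId": uint64, "ra": float64} the call returns
# a copy where diaObjectId becomes int64 and ra is untouched; SQL backends
# generally have no unsigned 64-bit integer type, hence the coercion.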

98 

99def _make_midpointMjdTai_start(visit_time: astropy.time.Time, months: int) -> float: 

100 """Calculate starting point for time-based source search. 

101 

102 Parameters 

103 ---------- 

104 visit_time : `astropy.time.Time` 

105 Time of current visit. 

106 months : `int` 

107 Number of months in the sources history. 

108 

109 Returns 

110 ------- 

111 time : `float` 

112 A ``midpointMjdTai`` starting point, MJD time. 

113 """ 

114 # TODO: Use of MJD must be consistent with the code in ap_association 

115 # (see DM-31996) 

116 return float(visit_time.tai.mjd) - months * 30 

117 
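# Worked example of the cutoff above (numbers are illustrative): a visit at
# MJD(TAI) 60000.0 with months=12 yields 60000.0 - 12 * 30 = 59640.0 as the
# earliest midpointMjdTai included in the history search.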

118 

119def _onSqlite3Connect( 

120 dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord 

121) -> None: 

122 # Enable foreign keys 

123 with closing(dbapiConnection.cursor()) as cursor: 

124 cursor.execute("PRAGMA foreign_keys=ON;") 

125 

126 

127class ApdbSql(Apdb): 

128 """Implementation of APDB interface based on SQL database. 

129 

130 The implementation is configured via standard ``pex_config`` mechanism 

131 using `ApdbSqlConfig` configuration class. For an example of different 

132 configurations check ``config/`` folder. 

133 

134 Parameters 

135 ---------- 

136 config : `ApdbSqlConfig` 

137 Configuration object. 

138 """ 

139 

140 metadataSchemaVersionKey = "version:schema" 

141 """Name of the metadata key to store schema version number.""" 

142 

143 metadataCodeVersionKey = "version:ApdbSql" 

144 """Name of the metadata key to store code version number.""" 

145 

146 metadataReplicaVersionKey = "version:ApdbSqlReplica" 

147 """Name of the metadata key to store replica code version number.""" 

148 

149 metadataConfigKey = "config:apdb-sql.json" 

150 """Name of the metadata key to store code version number.""" 

151 

152 metadataDedupKey = "status:deduplication.json" 

153 """Name of the metadata key to store code version number.""" 

154 

155 _frozen_parameters = ( 

156 "enable_replica", 

157 "dia_object_index", 

158 "pixelization.htm_level", 

159 "pixelization.htm_index_column", 

160 "ra_dec_columns", 

161 ) 

162 """Names of the config parameters to be frozen in metadata table.""" 

163 

164 def __init__(self, config: ApdbSqlConfig): 

165 self._engine = self._makeEngine(config, create=False) 

166 

167 sa_metadata = sqlalchemy.MetaData(schema=config.namespace) 

168 meta_table_name = ApdbTables.metadata.table_name(prefix=config.prefix) 

169 meta_table = sqlalchemy.schema.Table(meta_table_name, sa_metadata, autoload_with=self._engine) 

170 self._metadata = ApdbMetadataSql(self._engine, meta_table) 

171 

172 # Get table schemas.

173 self._table_schema = ApdbSchema(config.schema_file, config.ss_schema_file) 

174 

175 # Check that versions are compatible; this must happen before

176 # reading the frozen config.

177 self._db_schema_version = self._versionCheck(self._metadata, self._table_schema.schemaVersion()) 

178 

179 # Read frozen config from metadata. 

180 config_json = self._metadata.get(self.metadataConfigKey) 

181 if config_json is not None: 

182 # Update config from metadata. 

183 freezer = ApdbConfigFreezer[ApdbSqlConfig](self._frozen_parameters) 

184 self.config = freezer.update(config, config_json) 

185 else: 

186 self.config = config 

187 

188 self._schema = ApdbSqlSchema( 

189 table_schema=self._table_schema, 

190 engine=self._engine, 

191 dia_object_index=self.config.dia_object_index, 

192 prefix=self.config.prefix, 

193 namespace=self.config.namespace, 

194 htm_index_column=self.config.pixelization.htm_index_column, 

195 enable_replica=self.config.enable_replica, 

196 ) 

197 

198 self.pixelator = HtmPixelization(self.config.pixelization.htm_level) 

199 

200 if _LOG.isEnabledFor(logging.DEBUG): 

201 _LOG.debug("ApdbSql Configuration: %s", self.config.model_dump()) 

202 

203 def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer: 

204 """Create `Timer` instance given its name.""" 

205 return Timer(name, _MON, tags=tags) 

206 

207 @classmethod 

208 def _makeEngine(cls, config: ApdbSqlConfig, *, create: bool) -> sqlalchemy.engine.Engine: 

209 """Make SQLALchemy engine based on configured parameters. 

210 

211 Parameters 

212 ---------- 

213 config : `ApdbSqlConfig` 

214 Configuration object. 

215 create : `bool` 

216 Whether to try to create a new database file; only relevant for

217 the SQLite backend, which always creates new files by default.

218 """ 

219 # The engine is reused between multiple processes; make sure that we

220 # don't share connections by disabling pooling (using the NullPool class).

221 kw: MutableMapping[str, Any] = dict(config.connection_config.extra_parameters) 

222 conn_args: dict[str, Any] = {} 

223 if not config.connection_config.connection_pool: 

224 kw.update(poolclass=NullPool) 

225 if config.connection_config.isolation_level is not None: 

226 kw.update(isolation_level=config.connection_config.isolation_level) 

227 elif config.db_url.startswith("sqlite"): 

228 # Use READ_UNCOMMITTED as default value for sqlite. 

229 kw.update(isolation_level="READ_UNCOMMITTED") 

230 if config.connection_config.connection_timeout is not None: 

231 if config.db_url.startswith("sqlite"): 

232 conn_args.update(timeout=config.connection_config.connection_timeout) 

233 elif config.db_url.startswith(("postgresql", "mysql")): 

234 conn_args.update(connect_timeout=int(config.connection_config.connection_timeout)) 

235 kw.update(connect_args=conn_args) 

236 engine = sqlalchemy.create_engine(cls._connection_url(config.db_url, create=create), **kw) 

237 

238 if engine.dialect.name == "sqlite": 

239 # Need to enable foreign keys on every new connection. 

240 sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect) 

241 

242 return engine 

243 

244 @classmethod 

245 def _connection_url(cls, config_url: str, *, create: bool) -> sqlalchemy.engine.URL | str: 

246 """Generate a complete URL for database with proper credentials. 

247 

248 Parameters 

249 ---------- 

250 config_url : `str` 

251 Database URL as specified in configuration. 

252 create : `bool` 

253 Whether to try to create a new database file; only relevant for

254 the SQLite backend, which always creates new files by default.

255 

256 Returns 

257 ------- 

258 connection_url : `sqlalchemy.engine.URL` or `str` 

259 Connection URL including credentials. 

260 """ 

261 # Allow third-party authentication mechanisms by assuming the

262 # connection string is correct when we cannot recognize matching

263 # (dialect, host, database) keys.

264 components = urllib.parse.urlparse(config_url) 

265 if all((components.scheme is not None, components.hostname is not None, components.path is not None)): 

266 try: 

267 db_auth = DbAuth() 

268 config_url = db_auth.getUrl(config_url) 

269 except DbAuthNotFoundError: 

270 # Credentials file doesn't exist or no matching credentials, 

271 # use default auth. 

272 pass 

273 

274 # SQLite has a nasty habit of creating empty databases when they do not

275 # exist; tell it not to do that unless we actually need to create one.

276 if not create: 

277 config_url = cls._update_sqlite_url(config_url) 

278 

279 return config_url 

280 

281 @classmethod 

282 def _update_sqlite_url(cls, url_string: str) -> str: 

283 """If URL refers to sqlite dialect, update it so that the backend does 

284 not try to create database file if it does not exist already. 

285 

286 Parameters 

287 ---------- 

288 url_string : `str` 

289 Connection string. 

290 

291 Returns 

292 ------- 

293 url_string : `str` 

294 Possibly updated connection string. 

295 """ 

296 try: 

297 url = sqlalchemy.make_url(url_string) 

298 except sqlalchemy.exc.SQLAlchemyError: 

299 # If parsing fails it means some special format, likely not 

300 # sqlite so we just return it unchanged. 

301 return url_string 

302 

303 if url.get_backend_name() == "sqlite": 

304 # Massage url so that database name starts with "file:" and 

305 # option string has "mode=rw&uri=true". Database name 

306 # should look like a path (:memory: is not supported by 

307 # Apdb, but someone could still try to use it). 

308 database = url.database 

309 if database and not database.startswith((":", "file:")): 

310 query = dict(url.query, mode="rw", uri="true") 

311 # If ``database`` is an absolute path then original URL should 

312 # include four slashes after "sqlite:". Humans are bad at 

313 # counting things beyond four and sometimes an extra slash gets 

314 # added unintentionally, which causes sqlite to treat initial 

315 # element as "authority" and to complain. Strip extra slashes 

316 # at the start of the path to avoid that (DM-46077). 

317 if database.startswith("//"): 

318 warnings.warn( 

319 f"Database URL contains extra leading slashes which will be removed: {url}", 

320 stacklevel=3, 

321 ) 

322 database = "/" + database.lstrip("/") 

323 url = url.set(database=f"file:{database}", query=query) 

324 url_string = url.render_as_string() 

325 

326 return url_string 

327 
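# For illustration, an absolute-path URL (the path is hypothetical) such as
#     sqlite:////tmp/apdb.db
# is rewritten by the method above to
#     sqlite:///file:/tmp/apdb.db?mode=rw&uri=true
# so the sqlite3 driver opens the file read-write without creating it when
# it is missing.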

328 @classmethod 

329 def _versionCheck(cls, metadata: ApdbMetadataSql, schema_version: VersionTuple) -> VersionTuple: 

330 """Check schema version compatibility and return the database schema 

331 version. 

332 """ 

333 

334 def _get_version(key: str) -> VersionTuple: 

335 """Retrieve version number from given metadata key.""" 

336 version_str = metadata.get(key) 

337 if version_str is None: 

338 # Should not happen with existing metadata table. 

339 raise RuntimeError(f"Version key {key!r} does not exist in metadata table.") 

340 return VersionTuple.fromString(version_str) 

341 

342 db_schema_version = _get_version(cls.metadataSchemaVersionKey) 

343 db_code_version = _get_version(cls.metadataCodeVersionKey) 

344 

345 # For now there is no way to make read-only APDB instances, assume that 

346 # any access can do updates. 

347 if not schema_version.checkCompatibility(db_schema_version): 

348 raise IncompatibleVersionError( 

349 f"Configured schema version {schema_version} " 

350 f"is not compatible with database version {db_schema_version}" 

351 ) 

352 if not cls.apdbImplementationVersion().checkCompatibility(db_code_version): 

353 raise IncompatibleVersionError( 

354 f"Current code version {cls.apdbImplementationVersion()} " 

355 f"is not compatible with database version {db_code_version}" 

356 ) 

357 

358 # Check replica code version only if replica is enabled. This is a

359 # chicken-and-egg problem - `enable_replica` is a part of the frozen

360 # configuration, but we cannot read frozen configuration until we 

361 # validate versions. Assume that if the replica version is present 

362 # then replication is enabled. 

363 if metadata.get(cls.metadataReplicaVersionKey) is not None: 

364 db_replica_version = _get_version(cls.metadataReplicaVersionKey) 

365 code_replica_version = ApdbSqlReplica.apdbReplicaImplementationVersion() 

366 if not code_replica_version.checkCompatibility(db_replica_version): 

367 raise IncompatibleVersionError( 

368 f"Current replication code version {code_replica_version} " 

369 f"is not compatible with database version {db_replica_version}" 

370 ) 

371 

372 return db_schema_version 

373 

374 @classmethod 

375 def apdbImplementationVersion(cls) -> VersionTuple: 

376 """Return version number for current APDB implementation. 

377 

378 Returns 

379 ------- 

380 version : `VersionTuple` 

381 Version of the code defined in implementation class. 

382 """ 

383 return VERSION 

384 

385 @classmethod 

386 def init_database( 

387 cls, 

388 db_url: str, 

389 *, 

390 schema_file: str | None = None, 

391 ss_schema_file: str | None = None, 

392 read_sources_months: int | None = None, 

393 read_forced_sources_months: int | None = None, 

394 enable_replica: bool = False, 

395 connection_timeout: int | None = None, 

396 dia_object_index: str | None = None, 

397 htm_level: int | None = None, 

398 htm_index_column: str | None = None, 

399 ra_dec_columns: tuple[str, str] | None = None, 

400 prefix: str | None = None, 

401 namespace: str | None = None, 

402 drop: bool = False, 

403 ) -> ApdbSqlConfig: 

404 """Initialize new APDB instance and make configuration object for it. 

405 

406 Parameters 

407 ---------- 

408 db_url : `str` 

409 SQLAlchemy database URL. 

410 schema_file : `str`, optional 

411 Location of (YAML) configuration file with APDB schema. If not 

412 specified then default location will be used. 

413 ss_schema_file : `str`, optional 

414 Location of (YAML) configuration file with SSO schema. If not 

415 specified then default location will be used. 

416 read_sources_months : `int`, optional 

417 Number of months of history to read from DiaSource. 

418 read_forced_sources_months : `int`, optional 

419 Number of months of history to read from DiaForcedSource. 

420 enable_replica : `bool`, optional 

421 If True, make additional tables used for replication to PPDB. 

422 connection_timeout : `int`, optional 

423 Database connection timeout in seconds. 

424 dia_object_index : `str`, optional 

425 Indexing mode for DiaObject table. 

426 htm_level : `int`, optional 

427 HTM indexing level. 

428 htm_index_column : `str`, optional 

429 Name of an HTM index column for DiaObject and DiaSource tables.

430 ra_dec_columns : `tuple` [`str`, `str`], optional 

431 Names of ra/dec columns in DiaObject table. 

432 prefix : `str`, optional 

433 Optional prefix for all table names. 

434 namespace : `str`, optional 

435 Name of the database schema for all APDB tables. If not specified 

436 then default schema is used. 

437 drop : `bool`, optional 

438 If `True` then drop existing tables before re-creating the schema. 

439 

440 Returns 

441 ------- 

442 config : `ApdbSqlConfig` 

443 Resulting configuration object for a created APDB instance. 

444 """ 

445 config = ApdbSqlConfig(db_url=db_url, enable_replica=enable_replica) 

446 if schema_file is not None: 

447 config.schema_file = schema_file 

448 if ss_schema_file is not None: 

449 config.ss_schema_file = ss_schema_file 

450 if read_sources_months is not None: 

451 config.read_sources_months = read_sources_months 

452 if read_forced_sources_months is not None: 

453 config.read_forced_sources_months = read_forced_sources_months 

454 if connection_timeout is not None: 

455 config.connection_config.connection_timeout = connection_timeout 

456 if dia_object_index is not None: 

457 config.dia_object_index = dia_object_index 

458 if htm_level is not None: 

459 config.pixelization.htm_level = htm_level 

460 if htm_index_column is not None: 

461 config.pixelization.htm_index_column = htm_index_column 

462 if ra_dec_columns is not None: 

463 config.ra_dec_columns = ra_dec_columns 

464 if prefix is not None: 

465 config.prefix = prefix 

466 if namespace is not None: 

467 config.namespace = namespace 

468 

469 cls._makeSchema(config, drop=drop) 

470 

471 # SQLite has a nasty habit of creating an empty database by default;

472 # update the URL in the config to disable that behavior.

473 config.db_url = cls._update_sqlite_url(config.db_url) 

474 

475 return config 

476 
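# A minimal usage sketch for the classmethod above (the URL and option
# values are hypothetical):
#     config = ApdbSql.init_database(
#         db_url="sqlite:////tmp/apdb.db",
#         enable_replica=False,
#         htm_level=20,
#     )
#     apdb = ApdbSql(config)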

477 def get_replica(self) -> ApdbSqlReplica: 

478 """Return `ApdbReplica` instance for this database.""" 

479 return ApdbSqlReplica(self._schema, self._engine, self._db_schema_version) 

480 

481 def tableRowCount(self) -> dict[str, int]: 

482 """Return dictionary with the table names and row counts. 

483 

484 Used by ``ap_proto`` to keep track of the size of the database tables. 

485 Depending on the database technology this could be an expensive operation.

486 

487 Returns 

488 ------- 

489 row_counts : `dict` 

490 Dict where key is a table name and value is a row count. 

491 """ 

492 res = {} 

493 tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource] 

494 if self.config.dia_object_index == "last_object_table": 

495 tables.append(ApdbTables.DiaObjectLast) 

496 with self._engine.begin() as conn: 

497 for table in tables: 

498 sa_table = self._schema.get_table(table) 

499 stmt = sql.select(func.count()).select_from(sa_table) 

500 count: int = conn.execute(stmt).scalar_one() 

501 res[table.name] = count 

502 

503 return res 

504 
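# Example return value (row counts are hypothetical):
#     {"DiaObject": 1200, "DiaSource": 53000, "DiaForcedSource": 240000}
# DiaObjectLast is included only in "last_object_table" index mode.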

505 def getConfig(self) -> ApdbSqlConfig: 

506 # docstring is inherited from a base class 

507 return self.config 

508 

509 def tableDef(self, table: ApdbTables) -> Table | None: 

510 # docstring is inherited from a base class 

511 return self._schema.tableSchemas.get(table) 

512 

513 @classmethod 

514 def _makeSchema(cls, config: ApdbConfig, drop: bool = False) -> None: 

515 # docstring is inherited from a base class 

516 

517 if not isinstance(config, ApdbSqlConfig): 

518 raise TypeError(f"Unexpected type of configuration object: {type(config)}") 

519 

520 engine = cls._makeEngine(config, create=True) 

521 

522 table_schema = ApdbSchema(config.schema_file, config.ss_schema_file) 

523 

524 # Ask schema class to create all tables. 

525 schema = ApdbSqlSchema( 

526 table_schema=table_schema, 

527 engine=engine, 

528 dia_object_index=config.dia_object_index, 

529 prefix=config.prefix, 

530 namespace=config.namespace, 

531 htm_index_column=config.pixelization.htm_index_column, 

532 enable_replica=config.enable_replica, 

533 ) 

534 schema.makeSchema(drop=drop) 

535 

536 # Need the metadata table to store a few items in it.

537 meta_table = schema.get_table(ApdbTables.metadata) 

538 apdb_meta = ApdbMetadataSql(engine, meta_table) 

539 

540 # Fill version numbers, overwrite if they are already there. 

541 apdb_meta.set(cls.metadataSchemaVersionKey, str(table_schema.schemaVersion()), force=True) 

542 apdb_meta.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True) 

543 if config.enable_replica: 

544 # Only store replica code version if replica is enabled. 

545 apdb_meta.set( 

546 cls.metadataReplicaVersionKey, 

547 str(ApdbSqlReplica.apdbReplicaImplementationVersion()), 

548 force=True, 

549 ) 

550 

551 # Store frozen part of a configuration in metadata. 

552 freezer = ApdbConfigFreezer[ApdbSqlConfig](cls._frozen_parameters) 

553 apdb_meta.set(cls.metadataConfigKey, freezer.to_json(config), force=True) 

554 

555 def getDiaObjects(self, region: Region) -> pandas.DataFrame: 

556 # docstring is inherited from a base class 

557 

558 # decide what columns we need 

559 if self.config.dia_object_index == "last_object_table": 

560 table_enum = ApdbTables.DiaObjectLast 

561 else: 

562 table_enum = ApdbTables.DiaObject 

563 table = self._schema.get_table(table_enum) 

564 if not self.config.dia_object_columns: 

565 columns = self._schema.get_apdb_columns(table_enum) 

566 else: 

567 columns = [table.c[col] for col in self.config.dia_object_columns] 

568 query = sql.select(*columns) 

569 

570 # build selection 

571 query = query.where(self._filterRegion(table, region)) 

572 

573 validity_end_column = self._timestamp_column_name("validityEnd") 

574 

575 # select latest version of objects 

576 if self.config.dia_object_index != "last_object_table": 

577 query = query.where(table.columns[validity_end_column].is_(None))

578 

579 # _LOG.debug("query: %s", query) 

580 

581 # execute select 

582 with self._timer("select_time", tags={"table": "DiaObject"}) as timer: 

583 with self._engine.begin() as conn: 

584 result = conn.execute(query) 

585 column_defs = self._table_schema.tableSchemas[table_enum].columns 

586 table_data = ApdbSqlTableData(result, column_defs) 

587 objects = table_data.to_pandas() 

588 timer.add_values(row_count=len(objects)) 

589 _LOG.debug("found %s DiaObjects", len(objects)) 

590 return self._fix_result_timestamps(objects) 

591 

592 def getDiaSources( 

593 self, 

594 region: Region, 

595 object_ids: Iterable[int] | None, 

596 visit_time: astropy.time.Time, 

597 start_time: astropy.time.Time | None = None, 

598 ) -> pandas.DataFrame | None: 

599 # docstring is inherited from a base class 

600 if start_time is None and self.config.read_sources_months == 0: 

601 _LOG.debug("Skip DiaSources fetching") 

602 return None 

603 

604 if start_time is None: 

605 start_time_mjdTai = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months) 

606 else: 

607 start_time_mjdTai = float(start_time.tai.mjd) 

608 _LOG.debug("start_time_mjdTai = %.6f", start_time_mjdTai) 

609 

610 if object_ids is None: 

611 # region-based select 

612 return self._getDiaSourcesInRegion(region, start_time_mjdTai) 

613 else: 

614 return self._getDiaSourcesByIDs(list(object_ids), start_time_mjdTai) 

615 

616 def getDiaForcedSources( 

617 self, 

618 region: Region, 

619 object_ids: Iterable[int] | None, 

620 visit_time: astropy.time.Time, 

621 start_time: astropy.time.Time | None = None, 

622 ) -> pandas.DataFrame | None: 

623 # docstring is inherited from a base class 

624 if start_time is None and self.config.read_forced_sources_months == 0: 

625 _LOG.debug("Skip DiaForceSources fetching") 

626 return None 

627 

628 if object_ids is None: 

629 # This implementation does not support region-based selection. In 

630 # the past DiaForcedSource schema did not have ra/dec columns (it 

631 # had x/y columns). ra/dec were added at some point, so we could 

632 # add a pixelId column to this table if/when needed.

633 raise NotImplementedError("Region-based selection is not supported") 

634 

635 # TODO: DateTime.MJD must be consistent with code in ap_association, 

636 # alternatively we can fill midpointMjdTai ourselves in store() 

637 if start_time is None: 

638 start_time_mjdTai = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months) 

639 else: 

640 start_time_mjdTai = float(start_time.tai.mjd) 

641 _LOG.debug("start_time_mjdTai = %.6f", start_time_mjdTai) 

642 

643 with self._timer("select_time", tags={"table": "DiaForcedSource"}) as timer: 

644 sources = self._getSourcesByIDs(ApdbTables.DiaForcedSource, list(object_ids), start_time_mjdTai) 

645 timer.add_values(row_count=len(sources)) 

646 

647 _LOG.debug("found %s DiaForcedSources", len(sources)) 

648 return sources 

649 

650 def getDiaObjectsForDedup(self, since: astropy.time.Time | None = None) -> pandas.DataFrame: 

651 # docstring is inherited from a base class 

652 

653 if since is None: 

654 # Read last deduplication time from metadata. 

655 dedup_str = self._metadata.get(self.metadataDedupKey) 

656 if dedup_str is not None: 

657 dedup_state = json.loads(dedup_str) 

658 dedup_time_str = dedup_state["dedup_time_iso_tai"] 

659 since = astropy.time.Time(dedup_time_str, format="iso", scale="tai") 

660 

661 validity_start_column = self._timestamp_column_name("validityStart") 

662 

663 # decide what columns we need 

664 if self.config.dia_object_index == "last_object_table": 

665 table_enum = ApdbTables.DiaObjectLast 

666 else: 

667 table_enum = ApdbTables.DiaObject 

668 table = self._schema.get_table(table_enum) 

669 

670 if not self.config.dia_object_columns_for_dedup: 

671 columns = self._schema.get_apdb_columns(table_enum) 

672 else: 

673 column_names = list(self.config.dia_object_columns_for_dedup) 

674 if validity_start_column not in column_names: 

675 column_names.insert(0, validity_start_column) 

676 if "diaObjectId" not in column_names: 

677 column_names.insert(0, "diaObjectId") 

678 columns = [table.columns[col] for col in column_names] 

679 

680 query = sql.select(*columns) 

681 

682 # build selection 

683 if since is not None: 

684 timestamp = self._timestamp_column_value(since) 

685 query = query.where(table.columns[validity_start_column] >= timestamp) 

686 

687 # execute select 

688 with self._timer("select_time", tags={"table": "DiaObject"}) as timer: 

689 with self._engine.begin() as conn: 

690 result = conn.execute(query) 

691 column_defs = self._table_schema.tableSchemas[table_enum].columns 

692 table_data = ApdbSqlTableData(result, column_defs) 

693 objects = table_data.to_pandas() 

694 timer.add_values(row_count=len(objects)) 

695 _LOG.debug("found %s DiaObjects", len(objects)) 

696 return self._fix_result_timestamps(objects) 

697 

698 def getDiaSourcesForDiaObjects( 

699 self, objects: list[DiaObjectId], start_time: astropy.time.Time, max_dist_arcsec: float = 1.0 

700 ) -> pandas.DataFrame: 

701 # docstring is inherited from a base class 

702 object_ids = {object_id.diaObjectId for object_id in objects} 

703 return self._getDiaSourcesByIDs(list(object_ids), float(start_time.tai.mjd)) 

704 

705 def containsVisitDetector( 

706 self, 

707 visit: int, 

708 detector: int, 

709 region: Region, 

710 visit_time: astropy.time.Time, 

711 ) -> bool: 

712 # docstring is inherited from a base class 

713 src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource) 

714 frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource) 

715 # Query should load only one leaf page of the index 

716 query1 = sql.select(src_table.c.visit).filter_by(visit=visit, detector=detector).limit(1) 

717 

718 with self._engine.begin() as conn: 

719 result = conn.execute(query1).scalar_one_or_none() 

720 if result is not None: 

721 return True 

722 else: 

723 # Backup query if an image was processed but had no diaSources 

724 query2 = sql.select(frcsrc_table.c.visit).filter_by(visit=visit, detector=detector).limit(1) 

725 result = conn.execute(query2).scalar_one_or_none() 

726 return result is not None 

727 

728 def store( 

729 self, 

730 visit_time: astropy.time.Time, 

731 objects: pandas.DataFrame, 

732 sources: pandas.DataFrame | None = None, 

733 forced_sources: pandas.DataFrame | None = None, 

734 ) -> None: 

735 # docstring is inherited from a base class 

736 objects = self._fix_input_timestamps(objects) 

737 if sources is not None: 

738 sources = self._fix_input_timestamps(sources) 

739 if forced_sources is not None: 

740 forced_sources = self._fix_input_timestamps(forced_sources) 

741 

742 # We want to run all inserts in one transaction. 

743 with self._engine.begin() as connection: 

744 replica_chunk: ReplicaChunk | None = None 

745 if self._schema.replication_enabled: 

746 replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, self.config.replica_chunk_seconds) 

747 self._storeReplicaChunk(replica_chunk, connection) 

748 

749 # fill pixelId column for DiaObjects 

750 objects = self._add_spatial_index(objects) 

751 self._storeDiaObjects(objects, visit_time, replica_chunk, connection) 

752 

753 if sources is not None: 

754 # fill pixelId column for DiaSources 

755 sources = self._add_spatial_index(sources) 

756 self._storeDiaSources(sources, replica_chunk, connection) 

757 

758 if forced_sources is not None: 

759 self._storeDiaForcedSources(forced_sources, replica_chunk, connection) 

760 

761 def reassignDiaSourcesToDiaObjects( 

762 self, 

763 idMap: Mapping[DiaSourceId, int], 

764 *, 

765 increment_nDiaSources: bool = True, 

766 decrement_nDiaSources: bool = True, 

767 ) -> None: 

768 # docstring is inherited from a base class 

769 

770 new_object_ids = set(idMap.values()) 

771 source_ids = {source.diaSourceId for source in idMap} 

772 

773 current_time = self._current_time() 

774 current_time_ns = int(current_time.unix_tai * 1e9) 

775 

776 with self._engine.begin() as conn: 

777 # Make sure that all DiaSources exist. 

778 found_sources = self._get_diasource_data(conn, source_ids, "diaObjectId") 

779 if missing_ids := (source_ids - {row.diaSourceId for row in found_sources}): 

780 raise LookupError(f"Some source IDs are missing from DiaSource table: {missing_ids}") 

781 original_object_ids = {row.diaSourceId: row.diaObjectId for row in found_sources} 

782 

783 # Make sure that all DiaObjects exist. We also want to know the

784 # nDiaSources count for current and new records because we want to

785 # send updated values to the replica.

786 all_object_ids = new_object_ids | set(original_object_ids.values()) 

787 found_objects = self._get_diaobject_data(conn, all_object_ids, "ra", "dec", "nDiaSources") 

788 if missing_ids := (new_object_ids - {row.diaObjectId for row in found_objects}): 

789 raise LookupError(f"Some object IDs are missing from DiaObject table: {missing_ids}") 

790 

791 found_objects_by_id = {row.diaObjectId: row for row in found_objects} 

792 

793 update_records: list[ApdbUpdateRecord] = [] 

794 update_order = 0 

795 

796 # Update DiaSources. 

797 source_table = self._schema.get_table(ApdbTables.DiaSource) 

798 for source, diaObjectId in idMap.items(): 

799 update = ( 

800 source_table.update() 

801 .where(source_table.columns["diaSourceId"] == source.diaSourceId) 

802 .values(diaObjectId=diaObjectId) 

803 ) 

804 conn.execute(update) 

805 

806 if self._schema.replication_enabled: 

807 update_records.append( 

808 ApdbReassignDiaSourceToDiaObjectRecord( 

809 diaSourceId=source.diaSourceId, 

810 ra=source.ra, 

811 dec=source.dec, 

812 midpointMjdTai=source.midpointMjdTai, 

813 diaObjectId=diaObjectId, 

814 update_time_ns=current_time_ns, 

815 update_order=update_order, 

816 ) 

817 ) 

818 update_order += 1 

819 

820 # DiaObject tables to update. 

821 object_tables = [self._schema.get_table(ApdbTables.DiaObject)] 

822 if self.config.dia_object_index == "last_object_table": 

823 object_tables.append(self._schema.get_table(ApdbTables.DiaObjectLast)) 

824 

825 # Things to increment/decrement. 

826 increments: Counter = Counter() 

827 if increment_nDiaSources: 

828 increments.update(idMap.values()) 

829 if decrement_nDiaSources: 

830 increments.subtract(original_object_ids[source_id.diaSourceId] for source_id in idMap) 

831 

832 if increments: 

833 for table in object_tables: 

834 for diaObjectId, increment in increments.items(): 

835 update = ( 

836 table.update() 

837 .where(table.columns["diaObjectId"] == diaObjectId) 

838 .values(nDiaSources=table.columns["nDiaSources"] + increment) 

839 ) 

840 conn.execute(update) 

841 

842 # Also send updated values to replica. 

843 if self._schema.replication_enabled: 

844 for diaObjectId, increment in increments.items(): 

845 dia_object = found_objects_by_id[diaObjectId] 

846 update_records.append( 

847 ApdbUpdateNDiaSourcesRecord( 

848 diaObjectId=diaObjectId, 

849 ra=dia_object.ra, 

850 dec=dia_object.dec, 

851 nDiaSources=dia_object.nDiaSources + increment, 

852 update_time_ns=current_time_ns, 

853 update_order=update_order, 

854 ) 

855 ) 

856 update_order += 1 

857 

858 if update_records: 

859 replica_chunk = ReplicaChunk.make_replica_chunk( 

860 current_time, self.config.replica_chunk_seconds 

861 ) 

862 self._storeUpdateRecords(update_records, replica_chunk, connection=conn, store_chunk=True) 

863 

864 def setValidityEnd( 

865 self, objects: list[DiaObjectId], validityEnd: astropy.time.Time, raise_on_missing_id: bool = False 

866 ) -> int: 

867 # docstring is inherited from a base class 

868 if not objects: 

869 return 0 

870 

871 requested_ids = {obj.diaObjectId for obj in objects} 

872 

873 validity_end_column = self._timestamp_column_name("validityEnd") 

874 validityEnd_value = self._timestamp_column_value(validityEnd) 

875 

876 # Find all matching DiaObjects with validityEnd = NULL. 

877 table = self._schema.get_table(ApdbTables.DiaObject) 

878 query = sql.select(table.columns["diaObjectId"]).where( 

879 sqlalchemy.and_( 

880 table.columns["diaObjectId"].in_(sorted(requested_ids)), 

881 table.columns[validity_end_column].is_(None), 

882 ) 

883 ) 

884 

885 with self._engine.begin() as conn: 

886 result = conn.execute(query) 

887 found_ids = set(result.scalars()) 

888 

889 # Check that we found all that is requested. 

890 if raise_on_missing_id: 

891 if missing_ids := (requested_ids - found_ids): 

892 raise LookupError(f"Some object IDs are missing from DiaObject table: {missing_ids}")

893 

894 # Filter existing records. 

895 if len(objects) != len(found_ids): 

896 objects = [obj for obj in objects if obj.diaObjectId in found_ids] 

897 

898 if not objects: 

899 return 0 

900 

901 values = {validity_end_column: validityEnd_value} 

902 update = ( 

903 table.update() 

904 .where( 

905 sqlalchemy.and_( 

906 table.columns["diaObjectId"].in_(sorted(found_ids)), 

907 table.columns[validity_end_column].is_(None), 

908 ) 

909 ) 

910 .values(**values) 

911 ) 

912 result = conn.execute(update) 

913 if result.rowcount != len(found_ids): 

914 raise RuntimeError( 

915 f"Unexpected mismatch in the number of records updated. Object IDs = {found_ids}" 

916 ) 

917 

918 # Also drop them from DiaObjectLast. 

919 if self.config.dia_object_index == "last_object_table": 

920 last_table = self._schema.get_table(ApdbTables.DiaObjectLast) 

921 delete = last_table.delete().where(last_table.columns["diaObjectId"].in_(sorted(found_ids))) 

922 result = conn.execute(delete) 

923 if result.rowcount != len(found_ids): 

924 raise RuntimeError( 

925 f"Unexpected mismatch in the number of records deleted. Object IDs = {found_ids}" 

926 ) 

927 

928 # If replication is enabled then send all updates. 

929 if self._schema.replication_enabled: 

930 current_time = self._current_time() 

931 current_time_ns = int(current_time.unix_tai * 1e9) 

932 replica_chunk = ReplicaChunk.make_replica_chunk(current_time, self.config.replica_chunk_seconds) 

933 

934 update_records = [ 

935 ApdbCloseDiaObjectValidityRecord( 

936 diaObjectId=obj.diaObjectId, 

937 ra=obj.ra, 

938 dec=obj.dec, 

939 update_time_ns=current_time_ns, 

940 update_order=index, 

941 validityEndMjdTai=float(validityEnd.tai.mjd), 

942 nDiaSources=None, 

943 ) 

944 for index, obj in enumerate(objects) 

945 ] 

946 

947 self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True) 

948 

949 return len(objects) 

950 

951 def resetDedup(self, dedup_time: astropy.time.Time | None = None) -> None: 

952 # docstring is inherited from a base class 

953 

954 # SQL backend does not have separate dedup tables so there is nothing

955 # to delete; only save the last dedup time in metadata.

956 if dedup_time is None: 

957 dedup_time = self._current_time() 

958 data = {"dedup_time_iso_tai": dedup_time.tai.to_value("iso")} 

959 data_json = json.dumps(data) 

960 self._metadata.set(self.metadataDedupKey, data_json, force=True) 

961 

962 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

963 # docstring is inherited from a base class 

964 

965 timestamp: float | datetime.datetime 

966 now = self._current_time() 

967 timestamp_column = self._timestamp_column_name("ssObjectReassocTime") 

968 timestamp = self._timestamp_column_value(now) 

969 

970 table = self._schema.get_table(ApdbTables.DiaSource) 

971 query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId")) 

972 

973 with self._engine.begin() as conn: 

974 # Need to make sure that every ID exists in the database, but 

975 # executemany may not support rowcount, so iterate and check what 

976 # is missing. 

977 missing_ids: list[int] = [] 

978 for key, value in idMap.items(): 

979 params = { 

980 "srcId": key, 

981 "diaObjectId": 0, 

982 "ssObjectId": value, 

983 timestamp_column: timestamp, 

984 } 

985 result = conn.execute(query, params) 

986 if result.rowcount == 0: 

987 missing_ids.append(key) 

988 if missing_ids: 

989 missing = ",".join(str(item) for item in missing_ids) 

990 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

991 
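# Usage sketch for the method above (IDs are hypothetical): associate two
# DiaSources with SSObjects, zeroing their diaObjectId and stamping the
# reassociation time:
#     apdb.reassignDiaSources({12345: 777, 12346: 778})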

992 def countUnassociatedObjects(self) -> int: 

993 # docstring is inherited from a base class 

994 

995 # Retrieve the DiaObject table. 

996 table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject) 

997 

998 # Construct the sql statement. 

999 validity_end_column = self._timestamp_column_name("validityEnd") 

1000 stmt = ( 

1001 sql.select(func.count()) 

1002 .select_from(table) 

1003 .where( 

1004 sqlalchemy.and_( 

1005 table.columns["nDiaSources"] == 1, 

1006 table.columns[validity_end_column].is_(None), 

1007 ) 

1008 ) 

1009 ) 

1010 

1011 # Return the count. 

1012 with self._engine.begin() as conn: 

1013 count = conn.execute(stmt).scalar_one() 

1014 

1015 return count 

1016 

1017 @property 

1018 def schema(self) -> ApdbSchema: 

1019 # docstring is inherited from a base class 

1020 return self._table_schema 

1021 

1022 @property 

1023 def metadata(self) -> ApdbMetadata: 

1024 # docstring is inherited from a base class 

1025 return self._metadata 

1026 

1027 @property 

1028 def admin(self) -> ApdbSqlAdmin: 

1029 # docstring is inherited from a base class 

1030 return ApdbSqlAdmin(self.pixelator) 

1031 

1032 def _getDiaSourcesInRegion(self, region: Region, start_time_mjdTai: float) -> pandas.DataFrame: 

1033 """Return catalog of DiaSource instances from given region. 

1034 

1035 Parameters 

1036 ---------- 

1037 region : `lsst.sphgeom.Region` 

1038 Region to search for DIASources. 

1039 start_time_mjdTai : `float` 

1040 Lower bound of time window for the query. 

1041 

1042 Returns 

1043 ------- 

1044 catalog : `pandas.DataFrame` 

1045 Catalog containing DiaSource records. 

1046 """ 

1047 table = self._schema.get_table(ApdbTables.DiaSource) 

1048 columns = self._schema.get_apdb_columns(ApdbTables.DiaSource) 

1049 query = sql.select(*columns) 

1050 

1051 # build selection 

1052 time_filter = table.columns["midpointMjdTai"] > start_time_mjdTai 

1053 where = sql.expression.and_(self._filterRegion(table, region), time_filter) 

1054 query = query.where(where) 

1055 

1056 # execute select 

1057 with self._timer("DiaSource_select_time", tags={"table": "DiaSource"}) as timer: 

1058 with self._engine.begin() as conn: 

1059 result = conn.execute(query) 

1060 column_defs = self._table_schema.tableSchemas[ApdbTables.DiaSource].columns 

1061 table_data = ApdbSqlTableData(result, column_defs) 

1062 sources = table_data.to_pandas() 

1063 timer.add_values(row_count=len(sources))

1064 _LOG.debug("found %s DiaSources", len(sources)) 

1065 return self._fix_result_timestamps(sources) 

1066 

1067 def _getDiaSourcesByIDs(self, object_ids: list[int], start_time_mjdTai: float) -> pandas.DataFrame: 

1068 """Return catalog of DiaSource instances given set of DiaObject IDs. 

1069 

1070 Parameters 

1071 ---------- 

1072 object_ids : `list` [`int`]

1073 Collection of DiaObject IDs.

1074 start_time_mjdTai : `float` 

1075 Lower bound of time window for the query. 

1076 

1077 Returns 

1078 ------- 

1079 catalog : `pandas.DataFrame` 

1080 Catalog containing DiaSource records. 

1081 """ 

1082 with self._timer("select_time", tags={"table": "DiaSource"}) as timer: 

1083 sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, start_time_mjdTai) 

1084 timer.add_values(row_count=len(sources)) 

1085 

1086 _LOG.debug("found %s DiaSources", len(sources)) 

1087 return sources 

1088 

1089 def _getSourcesByIDs( 

1090 self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float 

1091 ) -> pandas.DataFrame: 

1092 """Return catalog of DiaSource or DiaForcedSource instances given set 

1093 of DiaObject IDs. 

1094 

1095 Parameters 

1096 ---------- 

1097 table_enum : `ApdbTables`

1098 Table to query, either DiaSource or DiaForcedSource.

1099 object_ids : `list` [`int`]

1100 Collection of DiaObject IDs.

1101 midpointMjdTai_start : `float` 

1102 Earliest midpointMjdTai to retrieve. 

1103 

1104 Returns 

1105 ------- 

1106 catalog : `pandas.DataFrame` 

1107 Catalog containing DiaSource or DiaForcedSource records. An

1108 empty catalog (never `None`) is returned when ``object_ids``

1109 is empty.

1110 """ 

1111 table = self._schema.get_table(table_enum) 

1112 columns = self._schema.get_apdb_columns(table_enum) 

1113 column_defs = self._table_schema.tableSchemas[table_enum].columns 

1114 

1115 sources: pandas.DataFrame | None = None 

1116 if len(object_ids) <= 0: 

1117 _LOG.debug("ID list is empty, just fetch empty result") 

1118 query = sql.select(*columns).where(sql.literal(False)) 

1119 with self._engine.begin() as conn: 

1120 result = conn.execute(query) 

1121 table_data = ApdbSqlTableData(result, column_defs) 

1122 sources = table_data.to_pandas() 

1123 else: 

1124 data_frames: list[pandas.DataFrame] = [] 

1125 for ids in chunk_iterable(sorted(object_ids), 1000): 

1126 query = sql.select(*columns) 

1127 

1128 # Some types like np.int64 can cause issues with 

1129 # sqlalchemy, convert them to int. 

1130 int_ids = [int(oid) for oid in ids] 

1131 

1132 # select by object id 

1133 query = query.where( 

1134 sql.expression.and_( 

1135 table.columns["diaObjectId"].in_(int_ids), 

1136 table.columns["midpointMjdTai"] >= midpointMjdTai_start, 

1137 ) 

1138 ) 

1139 

1140 # execute select 

1141 with self._engine.begin() as conn: 

1142 result = conn.execute(query) 

1143 table_data = ApdbSqlTableData(result, column_defs) 

1144 data_frames.append(table_data.to_pandas()) 

1145 

1146 if len(data_frames) == 1: 

1147 sources = data_frames[0] 

1148 else: 

1149 sources = pandas.concat(data_frames) 

1150 assert sources is not None, "Catalog cannot be None" 

1151 return self._fix_result_timestamps(sources) 

1152 

1153 def _storeReplicaChunk( 

1154 self, 

1155 replica_chunk: ReplicaChunk, 

1156 connection: sqlalchemy.engine.Connection, 

1157 ) -> None: 

1158 # `Time.datetime` returns a naive datetime even though the times here

1159 # are in UTC. Build a timezone-aware UTC timestamp so that the

1160 # database can store a correct value.

1161 dt = datetime.datetime.fromtimestamp(replica_chunk.last_update_time.unix_tai, tz=datetime.UTC) 

1162 

1163 table = self._schema.get_table(ExtraTables.ApdbReplicaChunks) 

1164 

1165 # We need UPSERT which is dialect-specific construct 

1166 values = {"last_update_time": dt, "unique_id": replica_chunk.unique_id} 

1167 row = {"apdb_replica_chunk": replica_chunk.id} | values 

1168 if connection.dialect.name == "sqlite": 

1169 insert_sqlite = sqlalchemy.dialects.sqlite.insert(table) 

1170 insert_sqlite = insert_sqlite.on_conflict_do_update(index_elements=table.primary_key, set_=values) 

1171 connection.execute(insert_sqlite, row) 

1172 elif connection.dialect.name == "postgresql": 

1173 insert_pg = sqlalchemy.dialects.postgresql.dml.insert(table) 

1174 insert_pg = insert_pg.on_conflict_do_update(constraint=table.primary_key, set_=values) 

1175 connection.execute(insert_pg, row) 

1176 else: 

1177 raise TypeError(f"Unsupported dialect {connection.dialect.name} for upsert.") 

1178 
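# The statements built above are dialect-specific upserts; e.g. for
# PostgreSQL the emitted SQL is roughly (a sketch, names abbreviated):
#     INSERT INTO ... VALUES (...) ON CONFLICT ... DO UPDATE
#     SET last_update_time = ..., unique_id = ...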

1179 def _storeDiaObjects( 

1180 self, 

1181 objs: pandas.DataFrame, 

1182 visit_time: astropy.time.Time, 

1183 replica_chunk: ReplicaChunk | None, 

1184 connection: sqlalchemy.engine.Connection, 

1185 ) -> None: 

1186 """Store catalog of DiaObjects from current visit. 

1187 

1188 Parameters 

1189 ---------- 

1190 objs : `pandas.DataFrame` 

1191 Catalog with DiaObject records. 

1192 visit_time : `astropy.time.Time` 

1193 Time of the visit. 

1194 replica_chunk : `ReplicaChunk` or `None`

1195 Replica chunk identifier, `None` when replication is disabled.

1196 """ 

1197 if len(objs) == 0: 

1198 _LOG.debug("No objects to write to database.") 

1199 return 

1200 

1201 # Some types like np.int64 can cause issues with sqlalchemy, convert 

1202 # them to int. 

1203 ids = sorted(int(oid) for oid in objs["diaObjectId"]) 

1204 _LOG.debug("first object ID: %d", ids[0]) 

1205 

1206 validity_start_column = self._timestamp_column_name("validityStart") 

1207 validity_end_column = self._timestamp_column_name("validityEnd") 

1208 timestamp = self._timestamp_column_value(visit_time) 

1209 

1210 # everything to be done in single transaction 

1211 if self.config.dia_object_index == "last_object_table": 

1212 # Insert and replace all records in LAST table. 

1213 table = self._schema.get_table(ApdbTables.DiaObjectLast) 

1214 

1215 # DiaObjectLast did not have this column in the past. 

1216 use_validity_start = self._schema.check_column(ApdbTables.DiaObjectLast, validity_start_column) 

1217 

1218 # Drop the previous objects (pandas cannot upsert). 

1219 query = table.delete().where(table.columns["diaObjectId"].in_(ids)) 

1220 

1221 with self._timer("delete_time", tags={"table": table.name}) as timer: 

1222 res = connection.execute(query) 

1223 timer.add_values(row_count=res.rowcount) 

1224 _LOG.debug("deleted %s objects", res.rowcount) 

1225 

1226 # DiaObjectLast is a subset of DiaObject, strip missing columns 

1227 last_column_names = [column.name for column in table.columns] 

1228 if validity_start_column in last_column_names and validity_start_column not in objs.columns: 

1229 last_column_names.remove(validity_start_column) 

1230 last_objs = objs[last_column_names] 

1231 last_objs = _coerce_uint64(last_objs) 

1232 

1233 # Fill validityStart, only when it is in the schema. 

1234 if use_validity_start: 

1235 if validity_start_column in last_objs: 

1236 last_objs[validity_start_column] = timestamp 

1237 else: 

1238 extra_column = pandas.Series([timestamp] * len(last_objs), name=validity_start_column) 

1239 last_objs.set_index(extra_column.index, inplace=True) 

1240 last_objs = pandas.concat([last_objs, extra_column], axis="columns") 

1241 

1242 with self._timer("insert_time", tags={"table": "DiaObjectLast"}) as timer: 

1243 last_objs.to_sql( 

1244 table.name, 

1245 connection, 

1246 if_exists="append", 

1247 index=False, 

1248 schema=table.schema, 

1249 ) 

1250 timer.add_values(row_count=len(last_objs)) 

1251 else: 

1252 # truncate existing validity intervals 

1253 table = self._schema.get_table(ApdbTables.DiaObject) 

1254 

1255 update = ( 

1256 table.update() 

1257 .values(**{validity_end_column: timestamp}) 

1258 .where( 

1259 sql.expression.and_( 

1260 table.columns["diaObjectId"].in_(ids), 

1261 table.columns[validity_end_column].is_(None), 

1262 ) 

1263 ) 

1264 ) 

1265 

1266 with self._timer("truncate_time", tags={"table": table.name}) as timer: 

1267 res = connection.execute(update) 

1268 timer.add_values(row_count=res.rowcount) 

1269 _LOG.debug("truncated %s intervals", res.rowcount) 

1270 

1271 objs = _coerce_uint64(objs) 

1272 

1273 # Fill additional columns 

1274 extra_columns: list[pandas.Series] = [] 

1275 if validity_start_column in objs.columns: 

1276 objs[validity_start_column] = timestamp 

1277 else: 

1278 extra_columns.append(pandas.Series([timestamp] * len(objs), name=validity_start_column)) 

1279 if validity_end_column in objs.columns: 

1280 objs[validity_end_column] = None 

1281 else: 

1282 extra_columns.append(pandas.Series([None] * len(objs), name=validity_end_column)) 

1283 if extra_columns: 

1284 objs.set_index(extra_columns[0].index, inplace=True) 

1285 objs = pandas.concat([objs] + extra_columns, axis="columns") 

1286 

1287 # Insert replica data 

1288 table = self._schema.get_table(ApdbTables.DiaObject) 

1289 replica_data: list[dict] = [] 

1290 replica_stmt: Any = None 

1291 replica_table_name = "" 

1292 if replica_chunk is not None: 

1293 pk_names = [column.name for column in table.primary_key] 

1294 replica_data = objs[pk_names].to_dict("records") 

1295 if replica_data: 

1296 for row in replica_data: 

1297 row["apdb_replica_chunk"] = replica_chunk.id 

1298 replica_table = self._schema.get_table(ExtraTables.DiaObjectChunks) 

1299 replica_table_name = replica_table.name 

1300 replica_stmt = replica_table.insert() 

1301 

1302 # insert new versions 

1303 with self._timer("insert_time", tags={"table": table.name}) as timer: 

1304 objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema) 

1305 timer.add_values(row_count=len(objs)) 

1306 if replica_stmt is not None: 

1307 with self._timer("insert_time", tags={"table": replica_table_name}) as timer: 

1308 connection.execute(replica_stmt, replica_data) 

1309 timer.add_values(row_count=len(replica_data)) 

1310 

1311 def _storeDiaSources( 

1312 self, 

1313 sources: pandas.DataFrame, 

1314 replica_chunk: ReplicaChunk | None, 

1315 connection: sqlalchemy.engine.Connection, 

1316 ) -> None: 

1317 """Store catalog of DiaSources from current visit. 

1318 

1319 Parameters 

1320 ---------- 

1321 sources : `pandas.DataFrame` 

1322 Catalog containing DiaSource records 

1323 """ 

1324 table = self._schema.get_table(ApdbTables.DiaSource) 

1325 

1326 # Insert replica data 

1327 replica_data: list[dict] = [] 

1328 replica_stmt: Any = None 

1329 replica_table_name = "" 

1330 if replica_chunk is not None: 

1331 pk_names = [column.name for column in table.primary_key] 

1332 replica_data = sources[pk_names].to_dict("records") 

1333 if replica_data: 

1334 for row in replica_data: 

1335 row["apdb_replica_chunk"] = replica_chunk.id 

1336 replica_table = self._schema.get_table(ExtraTables.DiaSourceChunks) 

1337 replica_table_name = replica_table.name 

1338 replica_stmt = replica_table.insert() 

1339 

1340 # everything to be done in single transaction 

1341 with self._timer("insert_time", tags={"table": table.name}) as timer: 

1342 sources = _coerce_uint64(sources) 

1343 sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema) 

1344 timer.add_values(row_count=len(sources)) 

1345 if replica_stmt is not None: 

1346 with self._timer("replica_insert_time", tags={"table": replica_table_name}) as timer: 

1347 connection.execute(replica_stmt, replica_data) 

1348 timer.add_values(row_count=len(replica_data)) 

1349 

1350 def _storeDiaForcedSources( 

1351 self, 

1352 sources: pandas.DataFrame, 

1353 replica_chunk: ReplicaChunk | None, 

1354 connection: sqlalchemy.engine.Connection, 

1355 ) -> None: 

1356 """Store a set of DiaForcedSources from current visit. 

1357 

1358 Parameters 

1359 ---------- 

1360 sources : `pandas.DataFrame` 

1361 Catalog containing DiaForcedSource records 

1362 """ 

1363 table = self._schema.get_table(ApdbTables.DiaForcedSource) 

1364 

1365 # Insert replica data 

1366 replica_data: list[dict] = [] 

1367 replica_stmt: Any = None 

1368 replica_table_name = "" 

1369 if replica_chunk is not None: 

1370 pk_names = [column.name for column in table.primary_key] 

1371 replica_data = sources[pk_names].to_dict("records") 

1372 if replica_data: 

1373 for row in replica_data: 

1374 row["apdb_replica_chunk"] = replica_chunk.id 

1375 replica_table = self._schema.get_table(ExtraTables.DiaForcedSourceChunks) 

1376 replica_table_name = replica_table.name 

1377 replica_stmt = replica_table.insert() 

1378 

1379 # everything to be done in single transaction 

1380 with self._timer("insert_time", tags={"table": table.name}) as timer: 

1381 sources = _coerce_uint64(sources) 

1382 sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema) 

1383 timer.add_values(row_count=len(sources)) 

1384 if replica_stmt is not None: 

1385 with self._timer("insert_time", tags={"table": replica_table_name}) as timer: 

1386 connection.execute(replica_stmt, replica_data) 

1387 timer.add_values(row_count=len(replica_data)) 

1388 

1389 def _storeUpdateRecords( 

1390 self, 

1391 records: Iterable[ApdbUpdateRecord], 

1392 chunk: ReplicaChunk, 

1393 *, 

1394 store_chunk: bool = False, 

1395 connection: sqlalchemy.engine.Connection | None = None, 

1396 ) -> None: 

1397 """Store ApdbUpdateRecords in the replica table for those records. 

1398 

1399 Parameters 

1400 ---------- 

1401 records : `list` [`ApdbUpdateRecord`] 

1402 Records to store. 

1403 chunk : `ReplicaChunk` 

1404 Replica chunk for these records. 

1405 store_chunk : `bool` 

1406 If True then also store replica chunk. 

1407 connection : `sqlalchemy.engine.Connection` 

1408 SQLAlchemy connection to use; if `None`, a new connection will be

1409 made. `None` is useful for tests only; regular use will call this

1410 method in the same transaction that saves other types of records.

1411 

1412 Raises 

1413 ------ 

1414 TypeError 

1415 Raised if replication is not enabled for this instance. 

1416 """ 

1417 if not self._schema.replication_enabled: 

1418 raise TypeError("Replication is not enabled for this APDB instance.") 

1419 

1420 apdb_replica_chunk = chunk.id 

1421 # Do not use unique_id from ReplicaChunk as it could be reused in

1422 # multiple calls to this method. 

1423 update_unique_id = uuid.uuid4() 

1424 

1425 record_dicts = [] 

1426 for record in records: 

1427 record_dicts.append( 

1428 { 

1429 "apdb_replica_chunk": apdb_replica_chunk, 

1430 "update_time_ns": record.update_time_ns, 

1431 "update_order": record.update_order, 

1432 "update_unique_id": update_unique_id, 

1433 "update_payload": record.to_json(), 

1434 } 

1435 ) 

1436 

1437 if not record_dicts: 

1438 return 

1439 

1440 # TODO: Need to check that table exists. 

1441 table = self._schema.get_table(ExtraTables.ApdbUpdateRecordChunks) 

1442 

1443 def _do_store(connection: sqlalchemy.engine.Connection) -> None: 

1444 if store_chunk: 

1445 self._storeReplicaChunk(chunk, connection) 

1446 with self._timer("insert_time", tags={"table": table.name}) as timer: 

1447 connection.execute(table.insert(), record_dicts) 

1448 timer.add_values(row_count=len(record_dicts)) 

1449 

1450 if connection is None: 

1451 with self._engine.begin() as connection: 

1452 _do_store(connection) 

1453 else: 

1454 _do_store(connection) 

1455 

1456 def _htm_indices(self, region: Region) -> list[tuple[int, int]]: 

1457 """Generate a set of HTM indices covering specified region. 

1458 

1459 Parameters 

1460 ---------- 

1461 region : `lsst.sphgeom.Region`

1462 Region that needs to be indexed. 

1463 

1464 Returns 

1465 ------- 

1466 Sequence of ranges; each range is a tuple (minHtmID, maxHtmID).

1467 """ 

1468 _LOG.debug("region: %s", region) 

1469 indices = self.pixelator.envelope(region, self.config.pixelization.htm_max_ranges) 

1470 

1471 return indices.ranges() 

1472 

1473 def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement: 

1474 """Make SQLAlchemy expression for selecting records in a region.""" 

1475 htm_index_column = table.columns[self.config.pixelization.htm_index_column] 

1476 exprlist = [] 

1477 pixel_ranges = self._htm_indices(region) 

1478 for low, upper in pixel_ranges: 

1479 upper -= 1 

1480 if low == upper: 

1481 exprlist.append(htm_index_column == low) 

1482 else: 

1483 exprlist.append(sql.expression.between(htm_index_column, low, upper)) 

1484 

1485 return sql.expression.or_(*exprlist) 

1486 
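# Illustrative expansion of the expression built above: for pixel ranges
# [(512, 513), (520, 524)] (half-open, hence "upper -= 1") it is
# equivalent to
#     pixelId = 512 OR pixelId BETWEEN 520 AND 523
# assuming the default "pixelId" index column name.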

1487 def _add_spatial_index(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1488 """Calculate spatial index for each record and add it to a DataFrame. 

1489 

1490 Parameters 

1491 ---------- 

1492 df : `pandas.DataFrame` 

1493 DataFrame which has to contain ra/dec columns, names of these 

1494 columns are defined by configuration ``ra_dec_columns`` field. 

1495 

1496 Returns 

1497 ------- 

1498 df : `pandas.DataFrame` 

1499 DataFrame with ``pixelId`` column which contains pixel index 

1500 for ra/dec coordinates. 

1501 

1502 Notes 

1503 ----- 

1504 This overrides any existing column in a DataFrame with the same name 

1505 (pixelId). The original DataFrame is not changed; a copy is

1506 returned.

1507 """ 

1508 # calculate HTM index for every DiaObject 

1509 htm_index = np.zeros(df.shape[0], dtype=np.int64) 

1510 ra_col, dec_col = self.config.ra_dec_columns 

1511 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

1512 uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec)) 

1513 idx = self.pixelator.index(uv3d) 

1514 htm_index[i] = idx 

1515 df = df.copy() 

1516 df[self.config.pixelization.htm_index_column] = htm_index 

1517 return df 

1518 
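# Sketch of the effect, assuming the default column names: an input frame
# with "ra"/"dec" columns comes back as a copy with an added "pixelId"
# column holding the HTM pixel index of each position.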

1519 def _fix_input_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1520 """Update timestamp columns in input DataFrame to be aware datetime 

1521 type in UTC.

1522 

1523 AP pipeline generates naive datetime instances, we want them to be 

1524 aware before they go to database. All naive timestamps are assumed to 

1525 be in UTC timezone (they should be TAI). 

1526 """ 

1527 # Find all columns with aware non-UTC timestamps and convert to UTC. 

1528 columns = [ 

1529 column 

1530 for column, dtype in df.dtypes.items() 

1531 if isinstance(dtype, pandas.DatetimeTZDtype) and dtype.tz is not datetime.UTC 

1532 ] 

1533 for column in columns: 

1534 df[column] = df[column].dt.tz_convert(datetime.UTC) 

1535 # Find all columns with naive timestamps and add UTC timezone. 

1536 columns = [ 

1537 column for column, dtype in df.dtypes.items() if pandas.api.types.is_datetime64_dtype(dtype) 

1538 ] 

1539 for column in columns: 

1540 df[column] = df[column].dt.tz_localize(datetime.UTC) 

1541 return df 

1542 
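# For example, a naive pandas timestamp 2025-06-01 12:00:00 is localized
# to 2025-06-01 12:00:00+00:00, while an aware non-UTC value is converted
# to UTC (example values are illustrative).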

1543 def _fix_result_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1544 """Update timestamp columns to be naive datetime type in returned 

1545 DataFrame. 

1546 

1547 AP pipeline code expects DataFrames to contain naive datetime columns, 

1548 while Postgres queries return timezone-aware types. This method converts

1549 those columns to naive datetime in UTC timezone. 

1550 """ 

1551 # Find all columns with aware timestamps. 

1552 columns = [column for column, dtype in df.dtypes.items() if isinstance(dtype, pandas.DatetimeTZDtype)] 

1553 for column in columns: 

1554 # tz_convert(None) will convert to UTC and drop timezone. 

1555 df[column] = df[column].dt.tz_convert(None) 

1556 return df 

1557 

1558 def _timestamp_column_name(self, column: str) -> str: 

1559 """Return column name before/after schema migration to MJD TAI.""" 

1560 return self.schema.timestamp_column_name(column) 

1561 

1562 def _timestamp_column_value(self, time: astropy.time.Time) -> float | datetime.datetime: 

1563 """Return column value before/after schema migration to MJD TAI.""" 

1564 if self.schema.has_mjd_timestamps: 

1565 return float(time.tai.mjd) 

1566 else: 

1567 return time.datetime.astimezone(tz=datetime.UTC) 

1568 

1569 def _get_diaobject_data( 

1570 self, conn: sqlalchemy.engine.Connection, object_ids: Iterable[int], *columns: str 

1571 ) -> list: 

1572 """Select records from either DiaObject or DiaObjectLast and return 

1573 selected rows as named tuples.

1574 """ 

1575 where: sqlalchemy.ColumnElement[bool] 

1576 if self.config.dia_object_index == "last_object_table": 

1577 table = self._schema.get_table(ApdbTables.DiaObjectLast) 

1578 where = table.columns["diaObjectId"].in_(sorted(object_ids)) 

1579 else: 

1580 table = self._schema.get_table(ApdbTables.DiaObject) 

1581 validity_end_column = self._timestamp_column_name("validityEnd") 

1582 where = sqlalchemy.and_( 

1583 table.columns["diaObjectId"].in_(sorted(object_ids)), 

1584 table.columns[validity_end_column].is_(None), 

1585 ) 

1586 column_list = [table.columns["diaObjectId"]] + [table.columns[column] for column in columns] 

1587 query = sql.select(*column_list).where(where) 

1588 result = conn.execute(query) 

1589 

1590 return list(result) 

1591 

1592 def _get_diasource_data( 

1593 self, conn: sqlalchemy.engine.Connection, source_ids: Iterable[int], *columns: str 

1594 ) -> list: 

1595 """Select records from DiaSource table by diaSourceId and return 

1596 selected rows as named tuples. 

1597 """ 

1598 table = self._schema.get_table(ApdbTables.DiaSource) 

1599 where = table.columns["diaSourceId"].in_(sorted(source_ids)) 

1600 column_list = [table.columns["diaSourceId"]] + [table.columns[column] for column in columns] 

1601 query = sql.select(*column_list).where(where) 

1602 result = conn.execute(query) 

1603 

1604 return list(result)