Coverage for python/lsst/dax/apdb/sql/apdbSql.py: 13%

545 statements  

coverage.py v7.13.5, created at 2026-04-22 08:49 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining Apdb class and related methods.""" 

23 

24from __future__ import annotations 

25 

26__all__ = ["ApdbSql"] 

27 

28import datetime 

29import logging 

30import urllib.parse 

31import uuid 

32import warnings 

33from collections.abc import Iterable, Mapping, MutableMapping 

34from contextlib import closing 

35from typing import TYPE_CHECKING, Any 

36 

37import astropy.time 

38import numpy as np 

39import pandas 

40import sqlalchemy 

41import sqlalchemy.dialects.postgresql 

42import sqlalchemy.dialects.sqlite 

43from sqlalchemy import func, sql 

44from sqlalchemy.pool import NullPool 

45 

46from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d 

47from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError 

48from lsst.utils.iteration import chunk_iterable 

49 

50from ..apdb import Apdb 

51from ..apdbConfigFreezer import ApdbConfigFreezer 

52from ..apdbReplica import ReplicaChunk 

53from ..apdbSchema import ApdbTables 

54from ..config import ApdbConfig 

55from ..monitor import MonAgent 

56from ..schema_model import Table 

57from ..timer import Timer 

58from ..versionTuple import IncompatibleVersionError, VersionTuple 

59from .apdbMetadataSql import ApdbMetadataSql 

60from .apdbSqlAdmin import ApdbSqlAdmin 

61from .apdbSqlReplica import ApdbSqlReplica 

62from .apdbSqlSchema import ApdbSqlSchema, ExtraTables 

63from .config import ApdbSqlConfig 

64 

65if TYPE_CHECKING: 

66 import sqlite3 

67 

68 from ..apdbMetadata import ApdbMetadata 

69 from ..apdbUpdateRecord import ApdbUpdateRecord 

70 

71_LOG = logging.getLogger(__name__) 

72 

73_MON = MonAgent(__name__) 

74 

75VERSION = VersionTuple(1, 2, 1) 

76"""Version for the code controlling non-replication tables. This needs to be 

77updated following compatibility rules when the schema produced by this code

78changes. 

79""" 

80 

81 

82def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame: 

83 """Change the type of uint64 columns to int64, and return a copy of the

84 data frame.

85 """ 

86 names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64] 

87 return df.astype(dict.fromkeys(names, np.int64)) 

88 
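A quick sketch of what this helper does, run in the module context with a throwaway frame (the column name is arbitrary):

    import numpy as np
    import pandas

    df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
    print(df.dtypes)                  # diaObjectId    uint64
    print(_coerce_uint64(df).dtypes)  # diaObjectId    int64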

89 

90def _make_midpointMjdTai_start(visit_time: astropy.time.Time, months: int) -> float: 

91 """Calculate starting point for time-based source search. 

92 

93 Parameters 

94 ---------- 

95 visit_time : `astropy.time.Time` 

96 Time of current visit. 

97 months : `int` 

98 Number of months in the sources history. 

99 

100 Returns 

101 ------- 

102 time : `float` 

103 A ``midpointMjdTai`` starting point, MJD time. 

104 """ 

105 # TODO: Use of MJD must be consistent with the code in ap_association 

106 # (see DM-31996) 

107 return float(visit_time.tai.mjd) - months * 30 

108 
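As a worked example of the 30-day month approximation used above (the visit time is arbitrary): a visit at MJD 60000.0 (TAI) with a 12-month window gives 60000.0 - 12 * 30 = 59640.0.

    import astropy.time

    visit_time = astropy.time.Time(60000.0, format="mjd", scale="tai")
    print(_make_midpointMjdTai_start(visit_time, 12))  # 59640.0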

109 

110def _onSqlite3Connect( 

111 dbapiConnection: sqlite3.Connection, connectionRecord: sqlalchemy.pool._ConnectionRecord 

112) -> None: 

113 # Enable foreign keys 

114 with closing(dbapiConnection.cursor()) as cursor: 

115 cursor.execute("PRAGMA foreign_keys=ON;") 

116 

117 

118class ApdbSql(Apdb): 

119 """Implementation of APDB interface based on SQL database. 

120 

121 The implementation is configured via standard ``pex_config`` mechanism 

122 using the `ApdbSqlConfig` configuration class. For examples of different

123 configurations check the ``config/`` folder.

124 

125 Parameters 

126 ---------- 

127 config : `ApdbSqlConfig` 

128 Configuration object. 

129 """ 

130 

131 metadataSchemaVersionKey = "version:schema" 

132 """Name of the metadata key to store schema version number.""" 

133 

134 metadataCodeVersionKey = "version:ApdbSql" 

135 """Name of the metadata key to store code version number.""" 

136 

137 metadataReplicaVersionKey = "version:ApdbSqlReplica" 

138 """Name of the metadata key to store replica code version number.""" 

139 

140 metadataConfigKey = "config:apdb-sql.json" 

141 """Name of the metadata key to store the frozen configuration."""

142 

143 _frozen_parameters = ( 

144 "enable_replica", 

145 "dia_object_index", 

146 "pixelization.htm_level", 

147 "pixelization.htm_index_column", 

148 "ra_dec_columns", 

149 ) 

150 """Names of the config parameters to be frozen in metadata table.""" 

151 

152 def __init__(self, config: ApdbSqlConfig): 

153 self._engine = self._makeEngine(config, create=False) 

154 

155 sa_metadata = sqlalchemy.MetaData(schema=config.namespace) 

156 meta_table_name = ApdbTables.metadata.table_name(prefix=config.prefix) 

157 meta_table = sqlalchemy.schema.Table(meta_table_name, sa_metadata, autoload_with=self._engine) 

158 self._metadata = ApdbMetadataSql(self._engine, meta_table) 

159 

160 # Read frozen config from metadata. 

161 config_json = self._metadata.get(self.metadataConfigKey) 

162 if config_json is not None: 

163 # Update config from metadata. 

164 freezer = ApdbConfigFreezer[ApdbSqlConfig](self._frozen_parameters) 

165 self.config = freezer.update(config, config_json) 

166 else: 

167 self.config = config 

168 

169 self._schema = ApdbSqlSchema( 

170 engine=self._engine, 

171 dia_object_index=self.config.dia_object_index, 

172 schema_file=self.config.schema_file, 

173 ss_schema_file=self.config.ss_schema_file, 

174 prefix=self.config.prefix, 

175 namespace=self.config.namespace, 

176 htm_index_column=self.config.pixelization.htm_index_column, 

177 enable_replica=self.config.enable_replica, 

178 ) 

179 

180 self._db_schema_version = self._versionCheck(self._metadata) 

181 

182 self.pixelator = HtmPixelization(self.config.pixelization.htm_level) 

183 

184 if _LOG.isEnabledFor(logging.DEBUG): 

185 _LOG.debug("ApdbSql Configuration: %s", self.config.model_dump()) 

186 

187 def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer: 

188 """Create `Timer` instance given its name.""" 

189 return Timer(name, _MON, tags=tags) 

190 

191 @classmethod 

192 def _makeEngine(cls, config: ApdbSqlConfig, *, create: bool) -> sqlalchemy.engine.Engine: 

193 """Make SQLAlchemy engine based on configured parameters.

194 

195 Parameters 

196 ---------- 

197 config : `ApdbSqlConfig` 

198 Configuration object. 

199 create : `bool` 

200 Whether to try to create a new database file; only relevant for the

201 SQLite backend, which always creates new files by default.

202 """ 

203 # The engine is reused between multiple processes; make sure that we don't

204 # share connections by disabling pooling (using the NullPool class).

205 kw: MutableMapping[str, Any] = dict(config.connection_config.extra_parameters) 

206 conn_args: dict[str, Any] = {} 

207 if not config.connection_config.connection_pool: 

208 kw.update(poolclass=NullPool) 

209 if config.connection_config.isolation_level is not None: 

210 kw.update(isolation_level=config.connection_config.isolation_level) 

211 elif config.db_url.startswith("sqlite"): 

212 # Use READ_UNCOMMITTED as default value for sqlite. 

213 kw.update(isolation_level="READ_UNCOMMITTED") 

214 if config.connection_config.connection_timeout is not None: 

215 if config.db_url.startswith("sqlite"): 

216 conn_args.update(timeout=config.connection_config.connection_timeout) 

217 elif config.db_url.startswith(("postgresql", "mysql")): 

218 conn_args.update(connect_timeout=int(config.connection_config.connection_timeout)) 

219 kw.update(connect_args=conn_args) 

220 engine = sqlalchemy.create_engine(cls._connection_url(config.db_url, create=create), **kw) 

221 

222 if engine.dialect.name == "sqlite": 

223 # Need to enable foreign keys on every new connection. 

224 sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect) 

225 

226 return engine 

227 

228 @classmethod 

229 def _connection_url(cls, config_url: str, *, create: bool) -> sqlalchemy.engine.URL | str: 

230 """Generate a complete URL for database with proper credentials. 

231 

232 Parameters 

233 ---------- 

234 config_url : `str` 

235 Database URL as specified in configuration. 

236 create : `bool` 

237 Whether to try to create a new database file; only relevant for the

238 SQLite backend, which always creates new files by default.

239 

240 Returns 

241 ------- 

242 connection_url : `sqlalchemy.engine.URL` or `str` 

243 Connection URL including credentials. 

244 """ 

245 # Allow 3rd party authentication mechanisms by assuming connection 

246 # string is correct when we cannot recognize (dialect, host, database)

247 # matching keys. 

248 components = urllib.parse.urlparse(config_url) 

249 if all((components.scheme is not None, components.hostname is not None, components.path is not None)): 

250 try: 

251 db_auth = DbAuth() 

252 config_url = db_auth.getUrl(config_url) 

253 except DbAuthNotFoundError: 

254 # Credentials file doesn't exist or no matching credentials, 

255 # use default auth. 

256 pass 

257 

258 # SQLite has a nasty habit of creating empty databases when they do not

259 # exist; tell it not to do that unless we do need to create it.

260 if not create: 

261 config_url = cls._update_sqlite_url(config_url) 

262 

263 return config_url 

264 

265 @classmethod 

266 def _update_sqlite_url(cls, url_string: str) -> str: 

267 """If the URL refers to the sqlite dialect, update it so that the backend does

268 not try to create a database file if it does not exist already.

269 

270 Parameters 

271 ---------- 

272 url_string : `str` 

273 Connection string. 

274 

275 Returns 

276 ------- 

277 url_string : `str` 

278 Possibly updated connection string. 

279 """ 

280 try: 

281 url = sqlalchemy.make_url(url_string) 

282 except sqlalchemy.exc.SQLAlchemyError: 

283 # If parsing fails it means some special format, likely not 

284 # sqlite so we just return it unchanged. 

285 return url_string 

286 

287 if url.get_backend_name() == "sqlite": 

288 # Massage url so that database name starts with "file:" and 

289 # option string has "mode=rw&uri=true". Database name 

290 # should look like a path (:memory: is not supported by 

291 # Apdb, but someone could still try to use it). 

292 database = url.database 

293 if database and not database.startswith((":", "file:")): 

294 query = dict(url.query, mode="rw", uri="true") 

295 # If ``database`` is an absolute path then original URL should 

296 # include four slashes after "sqlite:". Humans are bad at 

297 # counting things beyond four and sometimes an extra slash gets 

298 # added unintentionally, which causes sqlite to treat initial 

299 # element as "authority" and to complain. Strip extra slashes 

300 # at the start of the path to avoid that (DM-46077). 

301 if database.startswith("//"): 

302 warnings.warn( 

303 f"Database URL contains extra leading slashes which will be removed: {url}", 

304 stacklevel=3, 

305 ) 

306 database = "/" + database.lstrip("/") 

307 url = url.set(database=f"file:{database}", query=query) 

308 url_string = url.render_as_string() 

309 

310 return url_string 

311 
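For illustration (the paths are arbitrary): a sqlite URL with an absolute path is rewritten into a "file:" URI that refuses to create a missing database, while non-sqlite URLs pass through unchanged.

    print(ApdbSql._update_sqlite_url("sqlite:////tmp/apdb.db"))
    # sqlite:///file:/tmp/apdb.db?mode=rw&uri=true
    print(ApdbSql._update_sqlite_url("postgresql://server/apdb"))
    # postgresql://server/apdb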

312 def _versionCheck(self, metadata: ApdbMetadataSql) -> VersionTuple: 

313 """Check schema version compatibility and return the database schema 

314 version. 

315 """ 

316 

317 def _get_version(key: str) -> VersionTuple: 

318 """Retrieve version number from given metadata key.""" 

319 version_str = metadata.get(key) 

320 if version_str is None: 

321 # Should not happen with existing metadata table. 

322 raise RuntimeError(f"Version key {key!r} does not exist in metadata table.") 

323 return VersionTuple.fromString(version_str) 

324 

325 db_schema_version = _get_version(self.metadataSchemaVersionKey) 

326 db_code_version = _get_version(self.metadataCodeVersionKey) 

327 

328 # For now there is no way to make read-only APDB instances, assume that 

329 # any access can do updates. 

330 if not self._schema.schemaVersion().checkCompatibility(db_schema_version): 

331 raise IncompatibleVersionError( 

332 f"Configured schema version {self._schema.schemaVersion()} " 

333 f"is not compatible with database version {db_schema_version}" 

334 ) 

335 if not self.apdbImplementationVersion().checkCompatibility(db_code_version): 

336 raise IncompatibleVersionError( 

337 f"Current code version {self.apdbImplementationVersion()} " 

338 f"is not compatible with database version {db_code_version}" 

339 ) 

340 

341 # Check replica code version only if replica is enabled. 

342 if self._schema.replication_enabled: 

343 db_replica_version = _get_version(self.metadataReplicaVersionKey) 

344 code_replica_version = ApdbSqlReplica.apdbReplicaImplementationVersion() 

345 if not code_replica_version.checkCompatibility(db_replica_version): 

346 raise IncompatibleVersionError( 

347 f"Current replication code version {code_replica_version} " 

348 f"is not compatible with database version {db_replica_version}" 

349 ) 

350 

351 return db_schema_version 

352 

353 @classmethod 

354 def apdbImplementationVersion(cls) -> VersionTuple: 

355 """Return version number for current APDB implementation. 

356 

357 Returns 

358 ------- 

359 version : `VersionTuple` 

360 Version of the code defined in implementation class. 

361 """ 

362 return VERSION 

363 

364 @classmethod 

365 def init_database( 

366 cls, 

367 db_url: str, 

368 *, 

369 schema_file: str | None = None, 

370 ss_schema_file: str | None = None, 

371 read_sources_months: int | None = None, 

372 read_forced_sources_months: int | None = None, 

373 enable_replica: bool = False, 

374 connection_timeout: int | None = None, 

375 dia_object_index: str | None = None, 

376 htm_level: int | None = None, 

377 htm_index_column: str | None = None, 

378 ra_dec_columns: tuple[str, str] | None = None, 

379 prefix: str | None = None, 

380 namespace: str | None = None, 

381 drop: bool = False, 

382 ) -> ApdbSqlConfig: 

383 """Initialize new APDB instance and make configuration object for it. 

384 

385 Parameters 

386 ---------- 

387 db_url : `str` 

388 SQLAlchemy database URL. 

389 schema_file : `str`, optional 

390 Location of (YAML) configuration file with APDB schema. If not 

391 specified then the default location will be used.

392 ss_schema_file : `str`, optional 

393 Location of (YAML) configuration file with SSO schema. If not 

394 specified then the default location will be used.

395 read_sources_months : `int`, optional 

396 Number of months of history to read from DiaSource. 

397 read_forced_sources_months : `int`, optional 

398 Number of months of history to read from DiaForcedSource. 

399 enable_replica : `bool`, optional 

400 If True, make additional tables used for replication to PPDB. 

401 connection_timeout : `int`, optional 

402 Database connection timeout in seconds. 

403 dia_object_index : `str`, optional 

404 Indexing mode for DiaObject table. 

405 htm_level : `int`, optional 

406 HTM indexing level. 

407 htm_index_column : `str`, optional 

408 Name of a HTM index column for DiaObject and DiaSource tables. 

409 ra_dec_columns : `tuple` [`str`, `str`], optional 

410 Names of ra/dec columns in DiaObject table. 

411 prefix : `str`, optional 

412 Optional prefix for all table names. 

413 namespace : `str`, optional 

414 Name of the database schema for all APDB tables. If not specified 

415 then the default schema is used.

416 drop : `bool`, optional 

417 If `True` then drop existing tables before re-creating the schema. 

418 

419 Returns 

420 ------- 

421 config : `ApdbSqlConfig` 

422 Resulting configuration object for a created APDB instance. 

423 """ 

424 config = ApdbSqlConfig(db_url=db_url, enable_replica=enable_replica) 

425 if schema_file is not None: 

426 config.schema_file = schema_file 

427 if ss_schema_file is not None: 

428 config.ss_schema_file = ss_schema_file 

429 if read_sources_months is not None: 

430 config.read_sources_months = read_sources_months 

431 if read_forced_sources_months is not None: 

432 config.read_forced_sources_months = read_forced_sources_months 

433 if connection_timeout is not None: 

434 config.connection_config.connection_timeout = connection_timeout 

435 if dia_object_index is not None: 

436 config.dia_object_index = dia_object_index 

437 if htm_level is not None: 

438 config.pixelization.htm_level = htm_level 

439 if htm_index_column is not None: 

440 config.pixelization.htm_index_column = htm_index_column 

441 if ra_dec_columns is not None: 

442 config.ra_dec_columns = ra_dec_columns 

443 if prefix is not None: 

444 config.prefix = prefix 

445 if namespace is not None: 

446 config.namespace = namespace 

447 

448 cls._makeSchema(config, drop=drop) 

449 

450 # SQLite has a nasty habit of creating an empty database by default;

451 # update the URL in the config to disable that behavior.

452 config.db_url = cls._update_sqlite_url(config.db_url) 

453 

454 return config 

455 
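A minimal end-to-end sketch of creating and then opening an APDB instance with this class (the SQLite URL and option values are illustrative):

    config = ApdbSql.init_database("sqlite:////tmp/apdb.db", htm_level=20, enable_replica=False)
    apdb = ApdbSql(config)
    print(apdb.tableRowCount())  # all counts are zero for a freshly created database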

456 def get_replica(self) -> ApdbSqlReplica: 

457 """Return `ApdbReplica` instance for this database.""" 

458 return ApdbSqlReplica(self._schema, self._engine, self._db_schema_version) 

459 

460 def tableRowCount(self) -> dict[str, int]: 

461 """Return dictionary with the table names and row counts. 

462 

463 Used by ``ap_proto`` to keep track of the size of the database tables. 

464 Depending on database technology this could be an expensive operation.

465 

466 Returns 

467 ------- 

468 row_counts : `dict` 

469 Dict where key is a table name and value is a row count. 

470 """ 

471 res = {} 

472 tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource] 

473 if self.config.dia_object_index == "last_object_table": 

474 tables.append(ApdbTables.DiaObjectLast) 

475 with self._engine.begin() as conn: 

476 for table in tables: 

477 sa_table = self._schema.get_table(table) 

478 stmt = sql.select(func.count()).select_from(sa_table) 

479 count: int = conn.execute(stmt).scalar_one() 

480 res[table.name] = count 

481 

482 return res 

483 

484 def getConfig(self) -> ApdbSqlConfig: 

485 # docstring is inherited from a base class 

486 return self.config 

487 

488 def tableDef(self, table: ApdbTables) -> Table | None: 

489 # docstring is inherited from a base class 

490 return self._schema.tableSchemas.get(table) 

491 

492 @classmethod 

493 def _makeSchema(cls, config: ApdbConfig, drop: bool = False) -> None: 

494 # docstring is inherited from a base class 

495 

496 if not isinstance(config, ApdbSqlConfig): 

497 raise TypeError(f"Unexpected type of configuration object: {type(config)}") 

498 

499 engine = cls._makeEngine(config, create=True) 

500 

501 # Ask schema class to create all tables. 

502 schema = ApdbSqlSchema( 

503 engine=engine, 

504 dia_object_index=config.dia_object_index, 

505 schema_file=config.schema_file, 

506 ss_schema_file=config.ss_schema_file, 

507 prefix=config.prefix, 

508 namespace=config.namespace, 

509 htm_index_column=config.pixelization.htm_index_column, 

510 enable_replica=config.enable_replica, 

511 ) 

512 schema.makeSchema(drop=drop) 

513 

514 # Need metadata table to store a few items in it.

515 meta_table = schema.get_table(ApdbTables.metadata) 

516 apdb_meta = ApdbMetadataSql(engine, meta_table) 

517 

518 # Fill version numbers, overwrite if they are already there. 

519 apdb_meta.set(cls.metadataSchemaVersionKey, str(schema.schemaVersion()), force=True) 

520 apdb_meta.set(cls.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True) 

521 if config.enable_replica: 

522 # Only store replica code version if replica is enabled. 

523 apdb_meta.set( 

524 cls.metadataReplicaVersionKey, 

525 str(ApdbSqlReplica.apdbReplicaImplementationVersion()), 

526 force=True, 

527 ) 

528 

529 # Store frozen part of a configuration in metadata. 

530 freezer = ApdbConfigFreezer[ApdbSqlConfig](cls._frozen_parameters) 

531 apdb_meta.set(cls.metadataConfigKey, freezer.to_json(config), force=True) 

532 

533 def getDiaObjects(self, region: Region) -> pandas.DataFrame: 

534 # docstring is inherited from a base class 

535 

536 # decide what columns we need 

537 if self.config.dia_object_index == "last_object_table": 

538 table_enum = ApdbTables.DiaObjectLast 

539 else: 

540 table_enum = ApdbTables.DiaObject 

541 table = self._schema.get_table(table_enum) 

542 if not self.config.dia_object_columns: 

543 columns = self._schema.get_apdb_columns(table_enum) 

544 else: 

545 columns = [table.c[col] for col in self.config.dia_object_columns] 

546 query = sql.select(*columns) 

547 

548 # build selection 

549 query = query.where(self._filterRegion(table, region)) 

550 

551 if self._schema.has_mjd_timestamps: 

552 validity_end_column = "validityEndMjdTai" 

553 else: 

554 validity_end_column = "validityEnd" 

555 

556 # select latest version of objects 

557 if self.config.dia_object_index != "last_object_table": 

558 query = query.where(table.columns[validity_end_column] == None) # noqa: E711 

559 

560 # _LOG.debug("query: %s", query) 

561 

562 # execute select 

563 with self._timer("select_time", tags={"table": "DiaObject"}) as timer: 

564 with self._engine.begin() as conn: 

565 objects = pandas.read_sql_query(query, conn) 

566 timer.add_values(row_count=len(objects)) 

567 _LOG.debug("found %s DiaObjects", len(objects)) 

568 return self._fix_result_timestamps(objects) 

569 

570 def getDiaSources( 

571 self, 

572 region: Region, 

573 object_ids: Iterable[int] | None, 

574 visit_time: astropy.time.Time, 

575 start_time: astropy.time.Time | None = None, 

576 ) -> pandas.DataFrame | None: 

577 # docstring is inherited from a base class 

578 if start_time is None and self.config.read_sources_months == 0: 

579 _LOG.debug("Skip DiaSources fetching") 

580 return None 

581 

582 if start_time is None: 

583 start_time_mjdTai = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months) 

584 else: 

585 start_time_mjdTai = float(start_time.tai.mjd) 

586 _LOG.debug("start_time_mjdTai = %.6f", start_time_mjdTai) 

587 

588 if object_ids is None: 

589 # region-based select 

590 return self._getDiaSourcesInRegion(region, start_time_mjdTai) 

591 else: 

592 return self._getDiaSourcesByIDs(list(object_ids), start_time_mjdTai) 

593 

594 def getDiaForcedSources( 

595 self, 

596 region: Region, 

597 object_ids: Iterable[int] | None, 

598 visit_time: astropy.time.Time, 

599 start_time: astropy.time.Time | None = None, 

600 ) -> pandas.DataFrame | None: 

601 # docstring is inherited from a base class 

602 if start_time is None and self.config.read_forced_sources_months == 0: 

603 _LOG.debug("Skip DiaForcedSources fetching")

604 return None 

605 

606 if object_ids is None: 

607 # This implementation does not support region-based selection. In

608 # the past the DiaForcedSource schema did not have ra/dec columns (it

609 # had x/y columns). ra/dec were added at some point, so we could

610 # add a pixelId column to this table if/when needed.

611 raise NotImplementedError("Region-based selection is not supported") 

612 

613 # TODO: DateTime.MJD must be consistent with code in ap_association, 

614 # alternatively we can fill midpointMjdTai ourselves in store() 

615 if start_time is None: 

616 start_time_mjdTai = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months) 

617 else: 

618 start_time_mjdTai = float(start_time.tai.mjd) 

619 _LOG.debug("start_time_mjdTai = %.6f", start_time_mjdTai) 

620 

621 with self._timer("select_time", tags={"table": "DiaForcedSource"}) as timer: 

622 sources = self._getSourcesByIDs(ApdbTables.DiaForcedSource, list(object_ids), start_time_mjdTai) 

623 timer.add_values(row_count=len(sources)) 

624 

625 _LOG.debug("found %s DiaForcedSources", len(sources)) 

626 return sources 

627 

628 def containsVisitDetector( 

629 self, 

630 visit: int, 

631 detector: int, 

632 region: Region, 

633 visit_time: astropy.time.Time, 

634 ) -> bool: 

635 # docstring is inherited from a base class 

636 src_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaSource) 

637 frcsrc_table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaForcedSource) 

638 # Query should load only one leaf page of the index 

639 query1 = sql.select(src_table.c.visit).filter_by(visit=visit, detector=detector).limit(1) 

640 

641 with self._engine.begin() as conn: 

642 result = conn.execute(query1).scalar_one_or_none() 

643 if result is not None: 

644 return True 

645 else: 

646 # Backup query if an image was processed but had no diaSources 

647 query2 = sql.select(frcsrc_table.c.visit).filter_by(visit=visit, detector=detector).limit(1) 

648 result = conn.execute(query2).scalar_one_or_none() 

649 return result is not None 

650 

651 def store( 

652 self, 

653 visit_time: astropy.time.Time, 

654 objects: pandas.DataFrame, 

655 sources: pandas.DataFrame | None = None, 

656 forced_sources: pandas.DataFrame | None = None, 

657 ) -> None: 

658 # docstring is inherited from a base class 

659 objects = self._fix_input_timestamps(objects) 

660 if sources is not None: 

661 sources = self._fix_input_timestamps(sources) 

662 if forced_sources is not None: 

663 forced_sources = self._fix_input_timestamps(forced_sources) 

664 

665 # We want to run all inserts in one transaction. 

666 with self._engine.begin() as connection: 

667 replica_chunk: ReplicaChunk | None = None 

668 if self._schema.replication_enabled: 

669 replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, self.config.replica_chunk_seconds) 

670 self._storeReplicaChunk(replica_chunk, connection) 

671 

672 # fill pixelId column for DiaObjects 

673 objects = self._add_spatial_index(objects) 

674 self._storeDiaObjects(objects, visit_time, replica_chunk, connection) 

675 

676 if sources is not None: 

677 # fill pixelId column for DiaSources 

678 sources = self._add_spatial_index(sources) 

679 self._storeDiaSources(sources, replica_chunk, connection) 

680 

681 if forced_sources is not None: 

682 self._storeDiaForcedSources(forced_sources, replica_chunk, connection) 

683 

684 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None: 

685 # docstring is inherited from a base class 

686 

687 timestamp: float | datetime.datetime 

688 if self._schema.has_mjd_timestamps: 

689 timestamp_column = "ssObjectReassocTimeMjdTai" 

690 timestamp = float(astropy.time.Time.now().tai.mjd) 

691 else: 

692 timestamp_column = "ssObjectReassocTime" 

693 timestamp = datetime.datetime.now(tz=datetime.UTC) 

694 

695 table = self._schema.get_table(ApdbTables.DiaSource) 

696 query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId")) 

697 

698 with self._engine.begin() as conn: 

699 # Need to make sure that every ID exists in the database, but 

700 # executemany may not support rowcount, so iterate and check what 

701 # is missing. 

702 missing_ids: list[int] = [] 

703 for key, value in idMap.items(): 

704 params = { 

705 "srcId": key, 

706 "diaObjectId": 0, 

707 "ssObjectId": value, 

708 timestamp_column: timestamp, 

709 } 

710 result = conn.execute(query, params) 

711 if result.rowcount == 0: 

712 missing_ids.append(key) 

713 if missing_ids: 

714 missing = ",".join(str(item) for item in missing_ids) 

715 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}") 

716 

717 def dailyJob(self) -> None: 

718 # docstring is inherited from a base class 

719 pass 

720 

721 def countUnassociatedObjects(self) -> int: 

722 # docstring is inherited from a base class 

723 

724 # Retrieve the DiaObject table. 

725 table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject) 

726 

727 if self._schema.has_mjd_timestamps: 

728 validity_end_column = "validityEndMjdTai" 

729 else: 

730 validity_end_column = "validityEnd" 

731 

732 # Construct the sql statement. 

733 stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1) 

734 stmt = stmt.where(table.columns[validity_end_column] == None) # noqa: E711 

735 

736 # Return the count. 

737 with self._engine.begin() as conn: 

738 count = conn.execute(stmt).scalar_one() 

739 

740 return count 

741 

742 @property 

743 def metadata(self) -> ApdbMetadata: 

744 # docstring is inherited from a base class 

745 return self._metadata 

746 

747 @property 

748 def admin(self) -> ApdbSqlAdmin: 

749 # docstring is inherited from a base class 

750 return ApdbSqlAdmin(self.pixelator) 

751 

752 def _getDiaSourcesInRegion(self, region: Region, start_time_mjdTai: float) -> pandas.DataFrame: 

753 """Return catalog of DiaSource instances from given region. 

754 

755 Parameters 

756 ---------- 

757 region : `lsst.sphgeom.Region` 

758 Region to search for DIASources. 

759 start_time_mjdTai : `float` 

760 Lower bound of time window for the query. 

761 

762 Returns 

763 ------- 

764 catalog : `pandas.DataFrame` 

765 Catalog containing DiaSource records. 

766 """ 

767 table = self._schema.get_table(ApdbTables.DiaSource) 

768 columns = self._schema.get_apdb_columns(ApdbTables.DiaSource) 

769 query = sql.select(*columns) 

770 

771 # build selection 

772 time_filter = table.columns["midpointMjdTai"] > start_time_mjdTai 

773 where = sql.expression.and_(self._filterRegion(table, region), time_filter) 

774 query = query.where(where) 

775 

776 # execute select 

777 with self._timer("DiaSource_select_time", tags={"table": "DiaSource"}) as timer: 

778 with self._engine.begin() as conn: 

779 sources = pandas.read_sql_query(query, conn) 

780 timer.add_values(row_count=len(sources))

781 _LOG.debug("found %s DiaSources", len(sources)) 

782 return self._fix_result_timestamps(sources) 

783 

784 def _getDiaSourcesByIDs(self, object_ids: list[int], start_time_mjdTai: float) -> pandas.DataFrame: 

785 """Return catalog of DiaSource instances given set of DiaObject IDs. 

786 

787 Parameters 

788 ---------- 

789 object_ids : `list` [`int`]

790 Collection of DiaObject IDs.

791 start_time_mjdTai : `float` 

792 Lower bound of time window for the query. 

793 

794 Returns 

795 ------- 

796 catalog : `pandas.DataFrame` 

797 Catalog containing DiaSource records. 

798 """ 

799 with self._timer("select_time", tags={"table": "DiaSource"}) as timer: 

800 sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, start_time_mjdTai) 

801 timer.add_values(row_count=len(sources)) 

802 

803 _LOG.debug("found %s DiaSources", len(sources)) 

804 return sources 

805 

806 def _getSourcesByIDs( 

807 self, table_enum: ApdbTables, object_ids: list[int], midpointMjdTai_start: float 

808 ) -> pandas.DataFrame: 

809 """Return catalog of DiaSource or DiaForcedSource instances given set 

810 of DiaObject IDs. 

811 

812 Parameters 

813 ---------- 

814 table_enum : `ApdbTables`

815 Enum identifying the database table to query.

816 object_ids : `list` [`int`]

817 Collection of DiaObject IDs.

818 midpointMjdTai_start : `float` 

819 Earliest midpointMjdTai to retrieve. 

820 

821 Returns 

822 ------- 

823 catalog : `pandas.DataFrame` 

824 Catalog containing DiaSource or DiaForcedSource records. An empty

825 catalog is returned when ``object_ids`` is empty or when no

826 matching records exist.

827 """ 

828 table = self._schema.get_table(table_enum) 

829 columns = self._schema.get_apdb_columns(table_enum) 

830 

831 sources: pandas.DataFrame | None = None 

832 if len(object_ids) <= 0: 

833 _LOG.debug("ID list is empty, just fetch empty result") 

834 query = sql.select(*columns).where(sql.literal(False)) 

835 with self._engine.begin() as conn: 

836 sources = pandas.read_sql_query(query, conn) 

837 else: 

838 data_frames: list[pandas.DataFrame] = [] 

839 for ids in chunk_iterable(sorted(object_ids), 1000): 

840 query = sql.select(*columns) 

841 

842 # Some types like np.int64 can cause issues with 

843 # sqlalchemy, convert them to int. 

844 int_ids = [int(oid) for oid in ids] 

845 

846 # select by object id 

847 query = query.where( 

848 sql.expression.and_( 

849 table.columns["diaObjectId"].in_(int_ids), 

850 table.columns["midpointMjdTai"] > midpointMjdTai_start, 

851 ) 

852 ) 

853 

854 # execute select 

855 with self._engine.begin() as conn: 

856 data_frames.append(pandas.read_sql_query(query, conn)) 

857 

858 if len(data_frames) == 1: 

859 sources = data_frames[0] 

860 else: 

861 sources = pandas.concat(data_frames) 

862 assert sources is not None, "Catalog cannot be None" 

863 return self._fix_result_timestamps(sources) 

864 

865 def _storeReplicaChunk( 

866 self, 

867 replica_chunk: ReplicaChunk, 

868 connection: sqlalchemy.engine.Connection, 

869 ) -> None: 

870 # `visit_time.datetime` returns naive datetime, even though all astropy 

871 # times are in UTC. Add UTC timezone to timestamp so that database 

872 # can store a correct value. 

873 dt = datetime.datetime.fromtimestamp(replica_chunk.last_update_time.unix_tai, tz=datetime.UTC) 

874 

875 table = self._schema.get_table(ExtraTables.ApdbReplicaChunks) 

876 

877 # We need UPSERT, which is a dialect-specific construct.

878 values = {"last_update_time": dt, "unique_id": replica_chunk.unique_id} 

879 row = {"apdb_replica_chunk": replica_chunk.id} | values 

880 if connection.dialect.name == "sqlite": 

881 insert_sqlite = sqlalchemy.dialects.sqlite.insert(table) 

882 insert_sqlite = insert_sqlite.on_conflict_do_update(index_elements=table.primary_key, set_=values) 

883 connection.execute(insert_sqlite, row) 

884 elif connection.dialect.name == "postgresql": 

885 insert_pg = sqlalchemy.dialects.postgresql.dml.insert(table) 

886 insert_pg = insert_pg.on_conflict_do_update(constraint=table.primary_key, set_=values) 

887 connection.execute(insert_pg, row) 

888 else: 

889 raise TypeError(f"Unsupported dialect {connection.dialect.name} for upsert.") 

890 
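The dialect-specific upsert used above can be sketched in isolation with a throwaway in-memory SQLite table (the table and column names are illustrative, not APDB tables):

    import sqlalchemy
    import sqlalchemy.dialects.sqlite

    engine = sqlalchemy.create_engine("sqlite://")
    metadata = sqlalchemy.MetaData()
    chunks = sqlalchemy.Table(
        "chunks",
        metadata,
        sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
        sqlalchemy.Column("payload", sqlalchemy.String),
    )
    metadata.create_all(engine)

    insert = sqlalchemy.dialects.sqlite.insert(chunks)
    upsert = insert.on_conflict_do_update(
        index_elements=["id"], set_={"payload": insert.excluded.payload}
    )
    with engine.begin() as conn:
        conn.execute(upsert, {"id": 1, "payload": "first"})
        conn.execute(upsert, {"id": 1, "payload": "second"})  # updates the existing row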

891 def _storeDiaObjects( 

892 self, 

893 objs: pandas.DataFrame, 

894 visit_time: astropy.time.Time, 

895 replica_chunk: ReplicaChunk | None, 

896 connection: sqlalchemy.engine.Connection, 

897 ) -> None: 

898 """Store catalog of DiaObjects from current visit. 

899 

900 Parameters 

901 ---------- 

902 objs : `pandas.DataFrame` 

903 Catalog with DiaObject records. 

904 visit_time : `astropy.time.Time` 

905 Time of the visit. 

906 replica_chunk : `ReplicaChunk` 

907 Insert identifier. 

908 """ 

909 if len(objs) == 0: 

910 _LOG.debug("No objects to write to database.") 

911 return 

912 

913 # Some types like np.int64 can cause issues with sqlalchemy, convert 

914 # them to int. 

915 ids = sorted(int(oid) for oid in objs["diaObjectId"]) 

916 _LOG.debug("first object ID: %d", ids[0]) 

917 

918 if self._schema.has_mjd_timestamps: 

919 validity_start_column = "validityStartMjdTai" 

920 validity_end_column = "validityEndMjdTai" 

921 timestamp = float(visit_time.tai.mjd) 

922 else: 

923 validity_start_column = "validityStart" 

924 validity_end_column = "validityEnd" 

925 timestamp = visit_time.datetime 

926 

927 # everything to be done in single transaction 

928 if self.config.dia_object_index == "last_object_table": 

929 # Insert and replace all records in LAST table. 

930 table = self._schema.get_table(ApdbTables.DiaObjectLast) 

931 

932 # DiaObjectLast did not have this column in the past. 

933 use_validity_start = self._schema.check_column(ApdbTables.DiaObjectLast, validity_start_column) 

934 

935 # Drop the previous objects (pandas cannot upsert). 

936 query = table.delete().where(table.columns["diaObjectId"].in_(ids)) 

937 

938 with self._timer("delete_time", tags={"table": table.name}) as timer: 

939 res = connection.execute(query) 

940 timer.add_values(row_count=res.rowcount) 

941 _LOG.debug("deleted %s objects", res.rowcount) 

942 

943 # DiaObjectLast is a subset of DiaObject, strip missing columns 

944 last_column_names = [column.name for column in table.columns] 

945 if validity_start_column in last_column_names and validity_start_column not in objs.columns: 

946 last_column_names.remove(validity_start_column) 

947 last_objs = objs[last_column_names] 

948 last_objs = _coerce_uint64(last_objs) 

949 

950 # Fill validityStart, only when it is in the schema. 

951 if use_validity_start: 

952 if validity_start_column in last_objs: 

953 last_objs[validity_start_column] = timestamp 

954 else: 

955 extra_column = pandas.Series([timestamp] * len(last_objs), name=validity_start_column) 

956 last_objs.set_index(extra_column.index, inplace=True) 

957 last_objs = pandas.concat([last_objs, extra_column], axis="columns") 

958 

959 with self._timer("insert_time", tags={"table": "DiaObjectLast"}) as timer: 

960 last_objs.to_sql( 

961 table.name, 

962 connection, 

963 if_exists="append", 

964 index=False, 

965 schema=table.schema, 

966 ) 

967 timer.add_values(row_count=len(last_objs)) 

968 else: 

969 # truncate existing validity intervals 

970 table = self._schema.get_table(ApdbTables.DiaObject) 

971 

972 update = ( 

973 table.update() 

974 .values(**{validity_end_column: timestamp}) 

975 .where( 

976 sql.expression.and_( 

977 table.columns["diaObjectId"].in_(ids), 

978 table.columns[validity_end_column].is_(None), 

979 ) 

980 ) 

981 ) 

982 

983 with self._timer("truncate_time", tags={"table": table.name}) as timer: 

984 res = connection.execute(update) 

985 timer.add_values(row_count=res.rowcount) 

986 _LOG.debug("truncated %s intervals", res.rowcount) 

987 

988 objs = _coerce_uint64(objs) 

989 

990 # Fill additional columns 

991 extra_columns: list[pandas.Series] = [] 

992 if validity_start_column in objs.columns: 

993 objs[validity_start_column] = timestamp 

994 else: 

995 extra_columns.append(pandas.Series([timestamp] * len(objs), name=validity_start_column)) 

996 if validity_end_column in objs.columns: 

997 objs[validity_end_column] = None 

998 else: 

999 extra_columns.append(pandas.Series([None] * len(objs), name=validity_end_column)) 

1000 if extra_columns: 

1001 objs.set_index(extra_columns[0].index, inplace=True) 

1002 objs = pandas.concat([objs] + extra_columns, axis="columns") 

1003 

1004 # Insert replica data 

1005 table = self._schema.get_table(ApdbTables.DiaObject) 

1006 replica_data: list[dict] = [] 

1007 replica_stmt: Any = None 

1008 replica_table_name = "" 

1009 if replica_chunk is not None: 

1010 pk_names = [column.name for column in table.primary_key] 

1011 replica_data = objs[pk_names].to_dict("records") 

1012 if replica_data: 

1013 for row in replica_data: 

1014 row["apdb_replica_chunk"] = replica_chunk.id 

1015 replica_table = self._schema.get_table(ExtraTables.DiaObjectChunks) 

1016 replica_table_name = replica_table.name 

1017 replica_stmt = replica_table.insert() 

1018 

1019 # insert new versions 

1020 with self._timer("insert_time", tags={"table": table.name}) as timer: 

1021 objs.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema) 

1022 timer.add_values(row_count=len(objs)) 

1023 if replica_stmt is not None: 

1024 with self._timer("insert_time", tags={"table": replica_table_name}) as timer: 

1025 connection.execute(replica_stmt, replica_data) 

1026 timer.add_values(row_count=len(replica_data)) 

1027 

1028 def _storeDiaSources( 

1029 self, 

1030 sources: pandas.DataFrame, 

1031 replica_chunk: ReplicaChunk | None, 

1032 connection: sqlalchemy.engine.Connection, 

1033 ) -> None: 

1034 """Store catalog of DiaSources from current visit. 

1035 

1036 Parameters 

1037 ---------- 

1038 sources : `pandas.DataFrame` 

1039 Catalog containing DiaSource records 

1040 """ 

1041 table = self._schema.get_table(ApdbTables.DiaSource) 

1042 

1043 # Insert replica data 

1044 replica_data: list[dict] = [] 

1045 replica_stmt: Any = None 

1046 replica_table_name = "" 

1047 if replica_chunk is not None: 

1048 pk_names = [column.name for column in table.primary_key] 

1049 replica_data = sources[pk_names].to_dict("records") 

1050 if replica_data: 

1051 for row in replica_data: 

1052 row["apdb_replica_chunk"] = replica_chunk.id 

1053 replica_table = self._schema.get_table(ExtraTables.DiaSourceChunks) 

1054 replica_table_name = replica_table.name 

1055 replica_stmt = replica_table.insert() 

1056 

1057 # everything to be done in single transaction 

1058 with self._timer("insert_time", tags={"table": table.name}) as timer: 

1059 sources = _coerce_uint64(sources) 

1060 sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema) 

1061 timer.add_values(row_count=len(sources)) 

1062 if replica_stmt is not None: 

1063 with self._timer("replica_insert_time", tags={"table": replica_table_name}) as timer: 

1064 connection.execute(replica_stmt, replica_data) 

1065 timer.add_values(row_count=len(replica_data)) 

1066 

1067 def _storeDiaForcedSources( 

1068 self, 

1069 sources: pandas.DataFrame, 

1070 replica_chunk: ReplicaChunk | None, 

1071 connection: sqlalchemy.engine.Connection, 

1072 ) -> None: 

1073 """Store a set of DiaForcedSources from current visit. 

1074 

1075 Parameters 

1076 ---------- 

1077 sources : `pandas.DataFrame` 

1078 Catalog containing DiaForcedSource records 

1079 """ 

1080 table = self._schema.get_table(ApdbTables.DiaForcedSource) 

1081 

1082 # Insert replica data 

1083 replica_data: list[dict] = [] 

1084 replica_stmt: Any = None 

1085 replica_table_name = "" 

1086 if replica_chunk is not None: 

1087 pk_names = [column.name for column in table.primary_key] 

1088 replica_data = sources[pk_names].to_dict("records") 

1089 if replica_data: 

1090 for row in replica_data: 

1091 row["apdb_replica_chunk"] = replica_chunk.id 

1092 replica_table = self._schema.get_table(ExtraTables.DiaForcedSourceChunks) 

1093 replica_table_name = replica_table.name 

1094 replica_stmt = replica_table.insert() 

1095 

1096 # everything to be done in single transaction 

1097 with self._timer("insert_time", tags={"table": table.name}) as timer: 

1098 sources = _coerce_uint64(sources) 

1099 sources.to_sql(table.name, connection, if_exists="append", index=False, schema=table.schema) 

1100 timer.add_values(row_count=len(sources)) 

1101 if replica_stmt is not None: 

1102 with self._timer("insert_time", tags={"table": replica_table_name}) as timer: 

1103 connection.execute(replica_stmt, replica_data) 

1104 timer.add_values(row_count=len(replica_data)) 

1105 

1106 def _storeUpdateRecords( 

1107 self, 

1108 records: Iterable[ApdbUpdateRecord], 

1109 chunk: ReplicaChunk, 

1110 *, 

1111 store_chunk: bool = False, 

1112 connection: sqlalchemy.engine.Connection | None = None, 

1113 ) -> None: 

1114 """Store ApdbUpdateRecords in the replica table for those records. 

1115 

1116 Parameters 

1117 ---------- 

1118 records : `list` [`ApdbUpdateRecord`] 

1119 Records to store. 

1120 chunk : `ReplicaChunk` 

1121 Replica chunk for these records. 

1122 store_chunk : `bool` 

1123 If True then also store replica chunk. 

1124 connection : `sqlalchemy.engine.Connection` 

1125 SQLAlchemy connection to use; if `None` a new connection will be

1126 made. `None` is useful for tests only; regular use will call this

1127 method in the same transaction that saves other types of records. 

1128 

1129 Raises 

1130 ------ 

1131 TypeError 

1132 Raised if replication is not enabled for this instance. 

1133 """ 

1134 if not self._schema.replication_enabled: 

1135 raise TypeError("Replication is not enabled for this APDB instance.") 

1136 

1137 apdb_replica_chunk = chunk.id 

1138 # Do not use unique_id from ReplicaChunk as it could be reused in

1139 # multiple calls to this method. 

1140 update_unique_id = uuid.uuid4() 

1141 

1142 record_dicts = [] 

1143 for record in records: 

1144 record_dicts.append( 

1145 { 

1146 "apdb_replica_chunk": apdb_replica_chunk, 

1147 "update_time_ns": record.update_time_ns, 

1148 "update_order": record.update_order, 

1149 "update_unique_id": update_unique_id, 

1150 "update_payload": record.to_json(), 

1151 } 

1152 ) 

1153 

1154 if not record_dicts: 

1155 return 

1156 

1157 # TODO: Need to check that table exists. 

1158 table = self._schema.get_table(ExtraTables.ApdbUpdateRecordChunks) 

1159 

1160 def _do_store(connection: sqlalchemy.engine.Connection) -> None: 

1161 if store_chunk: 

1162 self._storeReplicaChunk(chunk, connection) 

1163 with self._timer("insert_time", tags={"table": table.name}) as timer: 

1164 connection.execute(table.insert(), record_dicts) 

1165 timer.add_values(row_count=len(record_dicts)) 

1166 

1167 if connection is None: 

1168 with self._engine.begin() as connection: 

1169 _do_store(connection) 

1170 else: 

1171 _do_store(connection) 

1172 

1173 def _htm_indices(self, region: Region) -> list[tuple[int, int]]: 

1174 """Generate a set of HTM indices covering the specified region.

1175 

1176 Parameters 

1177 ---------- 

1178 region : `sphgeom.Region`

1179 Region that needs to be indexed. 

1180 

1181 Returns 

1182 ------- 

1183 Sequence of ranges; each range is a tuple (minHtmID, maxHtmID).

1184 """ 

1185 _LOG.debug("region: %s", region) 

1186 indices = self.pixelator.envelope(region, self.config.pixelization.htm_max_ranges) 

1187 

1188 return indices.ranges() 

1189 

1190 def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement: 

1191 """Make SQLAlchemy expression for selecting records in a region.""" 

1192 htm_index_column = table.columns[self.config.pixelization.htm_index_column] 

1193 exprlist = [] 

1194 pixel_ranges = self._htm_indices(region) 

1195 for low, upper in pixel_ranges: 

1196 upper -= 1 

1197 if low == upper: 

1198 exprlist.append(htm_index_column == low) 

1199 else: 

1200 exprlist.append(sql.expression.between(htm_index_column, low, upper)) 

1201 

1202 return sql.expression.or_(*exprlist) 

1203 
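A sketch of the pixelization that feeds this filter, assuming HTM level 20 and an arbitrary circular region; it prints the inclusive ID ranges that end up in the BETWEEN clauses:

    from lsst.sphgeom import Angle, Circle, HtmPixelization, LonLat, UnitVector3d

    pixelator = HtmPixelization(20)
    region = Circle(UnitVector3d(LonLat.fromDegrees(35.0, -4.5)), Angle.fromDegrees(0.05))
    for low, upper in pixelator.envelope(region, 64).ranges():
        # Ranges are half-open, so the inclusive upper bound is upper - 1.
        print(low, upper - 1)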

1204 def _add_spatial_index(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1205 """Calculate spatial index for each record and add it to a DataFrame. 

1206 

1207 Parameters 

1208 ---------- 

1209 df : `pandas.DataFrame` 

1210 DataFrame which has to contain ra/dec columns; the names of these

1211 columns are defined by the ``ra_dec_columns`` configuration field.

1212 

1213 Returns 

1214 ------- 

1215 df : `pandas.DataFrame` 

1216 DataFrame with ``pixelId`` column which contains pixel index 

1217 for ra/dec coordinates. 

1218 

1219 Notes 

1220 ----- 

1221 This overrides any existing column in the DataFrame with the same name

1222 (pixelId). The original DataFrame is not changed; a modified copy is

1223 returned.

1224 """ 

1225 # calculate HTM index for every DiaObject 

1226 htm_index = np.zeros(df.shape[0], dtype=np.int64) 

1227 ra_col, dec_col = self.config.ra_dec_columns 

1228 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])): 

1229 uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec)) 

1230 idx = self.pixelator.index(uv3d) 

1231 htm_index[i] = idx 

1232 df = df.copy() 

1233 df[self.config.pixelization.htm_index_column] = htm_index 

1234 return df 

1235 

1236 def _fix_input_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1237 """Update timestamp columns in the input DataFrame to be timezone-aware

1238 datetime type in UTC.

1239 

1240 The AP pipeline generates naive datetime instances; we want them to be

1241 aware before they go to the database. All naive timestamps are assumed

1242 to be in the UTC timezone (they should be TAI).

1243 """ 

1244 # Find all columns with aware non-UTC timestamps and convert to UTC. 

1245 columns = [ 

1246 column 

1247 for column, dtype in df.dtypes.items() 

1248 if isinstance(dtype, pandas.DatetimeTZDtype) and dtype.tz is not datetime.UTC 

1249 ] 

1250 for column in columns: 

1251 df[column] = df[column].dt.tz_convert(datetime.UTC) 

1252 # Find all columns with naive timestamps and add UTC timezone. 

1253 columns = [ 

1254 column for column, dtype in df.dtypes.items() if pandas.api.types.is_datetime64_dtype(dtype) 

1255 ] 

1256 for column in columns: 

1257 df[column] = df[column].dt.tz_localize(datetime.UTC) 

1258 return df 

1259 
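A minimal sketch of the naive-to-UTC localization done here, using a throwaway frame (the column name is arbitrary):

    import datetime
    import pandas

    df = pandas.DataFrame({"time_processed": pandas.to_datetime(["2025-01-01T00:00:00"])})
    print(df.dtypes)  # datetime64[ns], i.e. naive
    df["time_processed"] = df["time_processed"].dt.tz_localize(datetime.UTC)
    print(df.dtypes)  # datetime64[ns, UTC], i.e. timezone-aware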

1260 def _fix_result_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame: 

1261 """Update timestamp columns to be naive datetime type in returned 

1262 DataFrame. 

1263 

1264 AP pipeline code expects DataFrames to contain naive datetime columns, 

1265 while Postgres queries return a timezone-aware type. This method converts

1266 those columns to naive datetime in the UTC timezone.

1267 """ 

1268 # Find all columns with aware timestamps. 

1269 columns = [column for column, dtype in df.dtypes.items() if isinstance(dtype, pandas.DatetimeTZDtype)] 

1270 for column in columns: 

1271 # tz_convert(None) will convert to UTC and drop timezone. 

1272 df[column] = df[column].dt.tz_convert(None) 

1273 return df