Coverage for python / lsst / analysis / ap / apdb.py: 22%

164 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 08:46 +0000

1# This file is part of analysis_ap. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

"""APDB connection management and data access tools.

Provides query helpers that load the contents of an APDB (the DiaObject,
DiaSource and DiaForcedSource tables) from SQLite or PostgreSQL backends
into `pandas` objects for analysis.
"""

__all__ = ["DbQuery", "ApdbSqliteQuery", "ApdbPostgresQuery"]

26 

27import abc 

28import contextlib 

29import warnings 

30 

31import pandas as pd 

32import sqlalchemy 

33 

34 

class DbQuery(abc.ABC):
    """Abstract interface for APDB queries.

    Every method in this base class raises `NotImplementedError`; concrete
    backends override the ones they support.

    Notes
    -----
    APDB interface used by AP pipeline is defined by `lsst.dax.apdb.Apdb`
    class. Methods in this class are for non-pipeline tools that can analyse
    data produced by pipeline. APDB schema is not designed for analysis queries
    and performance of these methods can be non-optimal, especially for
    Cassandra backend. It is expected that these analysis queries should not be
    executed on production Cassandra service.
    """

    def set_excluded_diaSource_flags(self, flag_list: list[str]) -> None:
        """Set flags of diaSources to exclude when loading diaSources.

        Any diaSources with configured flags are not returned
        when calling `load_sources_for_object` or `load_sources`
        with `exclude_flagged = True`.

        Parameters
        ----------
        flag_list : `list` [`str`]
            Flag names to exclude.
        """
        raise NotImplementedError()

    def load_sources_for_object(
        self, dia_object_id: int, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        """Load diaSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load sources for.
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`, optional
            Maximum number of rows to return.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaSources for the specified diaObject.
        """
        raise NotImplementedError()

    def load_forced_sources_for_object(
        self, dia_object_id: int, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        """Load diaForcedSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load forced sources for.
        exclude_flagged : `bool`, optional
            Exclude forced sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`, optional
            Maximum number of rows to return.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaForcedSources for the specified diaObject.
        """
        raise NotImplementedError()

    def load_source(self, id: int) -> pd.Series:
        """Load one diaSource.

        Parameters
        ----------
        id : `int`
            The diaSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested diaSource.
        """
        raise NotImplementedError()

    def load_sources(self, exclude_flagged: bool = False, limit: int = 100000) -> pd.DataFrame:
        """Load diaSources.

        Parameters
        ----------
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`, optional
            Maximum number of rows to return.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaSources.
        """
        raise NotImplementedError()

    def load_object(self, id: int) -> pd.Series:
        """Load the most-recently updated version of one diaObject.

        Parameters
        ----------
        id : `int`
            The diaObjectId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested object.
        """
        raise NotImplementedError()

    def load_objects(self, limit: int = 100000, latest: bool = True) -> pd.DataFrame:
        """Load all diaObjects.

        Parameters
        ----------
        limit : `int`, optional
            Maximum number of rows to return.
        latest : `bool`, optional
            Only load diaObjects where validityEnd is None.
            These are the most-recently updated diaObjects.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaObjects.
        """
        raise NotImplementedError()

    def load_forced_source(self, id: int) -> pd.Series:
        """Load one diaForcedSource.

        Parameters
        ----------
        id : `int`
            The diaForcedSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested forced source.
        """
        raise NotImplementedError()

    def load_forced_sources(self, limit: int = 100000) -> pd.DataFrame:
        """Load all diaForcedSources.

        Parameters
        ----------
        limit : `int`, optional
            Maximum number of rows to return.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaForcedSources.
        """
        raise NotImplementedError()

204 

205 

class DbSqlQuery(DbQuery):
    """Base class for APDB connection and query management for SQL backends.

    Subclasses must set ``self._tables`` (a mapping of bare table name to
    `sqlalchemy.schema.Table`) before calling ``super().__init__``, and must
    implement the ``connection`` property as a context manager yielding a
    `sqlalchemy.engine.Connection`.

    Parameters
    ----------
    instrument : `str`, optional
        Deprecated and ignored: the instrument name is now read from the
        APDB ``metadata`` table. Passing any value emits a `FutureWarning`;
        the argument will be removed after v29.
    """

    def __init__(self, instrument=None):
        if instrument is not None:
            # stacklevel=3 attributes the warning to the user's call of the
            # concrete subclass, not to the subclass ``__init__`` that
            # forwards the argument here.
            warnings.warn("The instrument name is now pulled from the APDB; "
                          "this kwarg is ignored and will be removed after v29",
                          FutureWarning,
                          stacklevel=3)

        # Default flags used when a loader is called with exclude_flagged=True.
        self.set_excluded_diaSource_flags(['pixelFlags_bad',
                                           'pixelFlags_suspect',
                                           'pixelFlags_saturatedCenter',
                                           'pixelFlags_interpolated',
                                           'pixelFlags_interpolatedCenter',
                                           'pixelFlags_edge',
                                           ])

        # The instrument name is stored as a name/value row of the metadata
        # table; ``scalar()`` gives None when no such row exists.
        key = "instrument"
        table = self._tables["metadata"]
        sql = sqlalchemy.sql.select(table.columns.value).where(table.columns.name == key)
        with self.connection as conn:
            result = conn.execute(sql)
            self._instrument = result.scalar()

    @property
    @contextlib.contextmanager
    @abc.abstractmethod
    def connection(self):
        """Context manager for database connections.

        Yields
        ------
        connection : `sqlalchemy.engine.Connection`
            Connection to the database that will be queried. Whether the
            connection is closed after the context manager is closed is
            implementation dependent.
        """
        pass

    def set_excluded_diaSource_flags(self, flag_list):
        """Set flags of diaSources to exclude when loading diaSources.

        Any diaSources with configured flags are not returned
        when calling `load_sources_for_object` or `load_sources`
        with `exclude_flagged = True`.

        Parameters
        ----------
        flag_list : `list` [`str`]
            Flag names to exclude.

        Raises
        ------
        ValueError
            Raised if any flag is not a column of the DiaSource table.
        """
        # Materialize a private copy first: this validates and stores the same
        # items even for one-shot iterables, and later mutation of the caller's
        # list cannot silently change which sources are excluded.
        flags = list(flag_list)
        columns = self._tables["DiaSource"].columns
        for flag in flags:
            if flag not in columns:
                raise ValueError(f"flag {flag} not included in DiaSource flags")

        self.diaSource_flags_exclude = flags

    def _make_flag_exclusion_query(self, query, table, flag_list):
        """Restrict a query to rows that have none of the chosen flags set.

        Parameters
        ----------
        query : `sqlalchemy.sql.Select`
            Query to add the where clause to.
        table : `sqlalchemy.schema.Table`
            Table containing the flag columns.
        flag_list : `list` [`str`]
            Flag column names to exclude on.

        Returns
        -------
        query : `sqlalchemy.sql.Select`
            ``query`` restricted to rows where every listed flag column is
            False. Rows where a flag is NULL are dropped too, because
            ``NULL = false`` is not true in SQL.
        """
        if not flag_list:
            # Nothing to exclude; calling ``and_()`` with no clauses is
            # deprecated in SQLAlchemy.
            return query
        return query.where(sqlalchemy.and_(*(table.columns[flag_col] == False  # noqa: E712
                                             for flag_col in flag_list)))

    def _read_query(self, query):
        """Execute ``query`` inside ``connection`` and return a DataFrame."""
        with self.connection as connection:
            return pd.read_sql_query(query, connection)

    def load_sources_for_object(self, dia_object_id, exclude_flagged=False, limit=100000):
        """Load diaSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load sources for.
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`, optional
            Maximum number of rows to return; `None` for no limit.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaSources for the specified diaObject, ordered
            by visit, detector and diaSourceId.
        """
        table = self._tables["DiaSource"]
        query = table.select().where(table.columns["diaObjectId"] == dia_object_id)
        if exclude_flagged:
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaSourceId"])
        # ``limit`` was previously accepted but never applied here.
        if limit is not None:
            query = query.limit(limit)

        result = self._read_query(query)
        self._fill_from_instrument(result)
        return result

    def load_forced_sources_for_object(self, dia_object_id, exclude_flagged=False, limit=100000):
        """Load diaForcedSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load forced sources for.
        exclude_flagged : `bool`, optional
            Exclude forced sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`, optional
            Maximum number of rows to return; `None` for no limit.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaForcedSources for the specified diaObject,
            ordered by visit, detector and diaForcedSourceId.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select().where(table.columns["diaObjectId"] == dia_object_id)
        if exclude_flagged:
            # NOTE(review): the flags are validated against DiaSource only;
            # confirm DiaForcedSource has the same flag columns, otherwise
            # the column lookup raises KeyError.
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaForcedSourceId"])
        # ``limit`` was previously accepted but never applied here.
        if limit is not None:
            query = query.limit(limit)

        result = self._read_query(query)
        self._fill_from_instrument(result)
        return result

    def load_source(self, id):
        """Load one diaSource.

        Parameters
        ----------
        id : `int`
            The diaSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested diaSource.

        Raises
        ------
        RuntimeError
            Raised if no diaSource has this id.
        """
        table = self._tables["DiaSource"]
        query = table.select().where(table.columns["diaSourceId"] == id)
        result = self._read_query(query)
        if len(result) == 0:
            raise RuntimeError(f"diaSourceId={id} not found in DiaSource table")

        self._fill_from_instrument(result)
        return result.iloc[0]

    def load_sources(self, exclude_flagged=False, limit=100000):
        """Load diaSources.

        Parameters
        ----------
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`, optional
            Maximum number of rows to return; `None` for no limit.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaSources, ordered by visit, detector and
            diaSourceId.
        """
        table = self._tables["DiaSource"]
        query = table.select()
        if exclude_flagged:
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaSourceId"])
        if limit is not None:
            query = query.limit(limit)

        result = self._read_query(query)
        self._fill_from_instrument(result)
        return result

    def load_object(self, id):
        """Load the most-recently updated version of one diaObject.

        Parameters
        ----------
        id : `int`
            The diaObjectId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested object (the row whose validityEnd is NULL).

        Raises
        ------
        RuntimeError
            Raised if no current version of the diaObject exists.
        """
        table = self._tables["DiaObject"]
        query = (table.select()
                 .where(table.columns["validityEnd"].is_(None))
                 .where(table.columns["diaObjectId"] == id))
        result = self._read_query(query)
        if len(result) == 0:
            raise RuntimeError(f"diaObjectId={id} not found in DiaObject table")

        return result.iloc[0]

    def load_objects(self, limit=100000, latest=True):
        """Load all diaObjects.

        Parameters
        ----------
        limit : `int`, optional
            Maximum number of rows to return; `None` for no limit.
        latest : `bool`, optional
            Only load diaObjects where validityEnd is None.
            These are the most-recently updated diaObjects.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaObjects, ordered by diaObjectId.
        """
        table = self._tables["DiaObject"]
        query = table.select()
        if latest:
            query = query.where(table.columns["validityEnd"].is_(None))
        query = query.order_by(table.columns["diaObjectId"])
        if limit is not None:
            query = query.limit(limit)

        return self._read_query(query)

    def load_forced_source(self, id):
        """Load one diaForcedSource.

        Parameters
        ----------
        id : `int`
            The diaForcedSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested forced source.

        Raises
        ------
        RuntimeError
            Raised if no diaForcedSource has this id.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select().where(table.columns["diaForcedSourceId"] == id)
        result = self._read_query(query)
        if len(result) == 0:
            raise RuntimeError(f"diaForcedSourceId={id} not found in DiaForcedSource table")

        self._fill_from_instrument(result)
        return result.iloc[0]

    def load_forced_sources(self, limit=100000):
        """Load all diaForcedSources.

        Parameters
        ----------
        limit : `int`, optional
            Maximum number of rows to return; `None` for no limit.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaForcedSources, ordered by visit, detector and
            diaForcedSourceId.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select().order_by(table.columns["visit"],
                                        table.columns["detector"],
                                        table.columns["diaForcedSourceId"])
        if limit is not None:
            query = query.limit(limit)

        result = self._read_query(query)
        self._fill_from_instrument(result)
        return result

    def _fill_from_instrument(self, diaSources):
        """Add an ``instrument`` column holding the APDB's instrument name.

        The value is the name read from the APDB metadata table in
        ``__init__``.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            Sources loaded from the APDB; modified in place. An empty frame
            is left untouched.
        """
        if len(diaSources) == 0:
            return

        diaSources['instrument'] = self._instrument

534 

535 

class ApdbSqliteQuery(DbSqlQuery):
    """Open an sqlite3 APDB file to load data from it.

    The engine is created with a large connection pool. Connections are
    returned to that pool after every query, so the pool can reuse the open
    sqlite connections. Closing and re-opening would re-scan the whole file
    every time, and we don't need to worry about multiple users when working
    with local sqlite files.

    Parameters
    ----------
    filename : `str`
        Path to the sqlite3 file containing the APDB to load.
    instrument : `str`, optional
        Deprecated and ignored; the instrument name is read from the APDB.
        Passing a value emits a `FutureWarning`.
    """

    def __init__(self, filename, instrument=None, **kwargs):
        # For sqlite, use a larger pool and a faster timeout, to allow many
        # repeat transactions with the same connection, as transactions on
        # our sqlite DBs should be small and fast.
        self._engine = sqlalchemy.create_engine(f"sqlite:///{filename}",
                                                pool_timeout=5, pool_size=200)

        with self.connection as connection:
            metadata = sqlalchemy.MetaData()
            metadata.reflect(bind=connection)
            self._tables = metadata.tables
        # Forward ``instrument`` so the base class can emit its deprecation
        # warning; it used to be dropped silently here.
        super().__init__(instrument=instrument, **kwargs)

    @property
    @contextlib.contextmanager
    def connection(self):
        """Context manager yielding a pooled `sqlalchemy.engine.Connection`.

        The connection is explicitly returned to the engine's pool on exit.
        Otherwise it would be left for the garbage collector, which discards
        rather than reuses it.
        """
        with self._engine.connect() as connection:
            yield connection

571 

572 

class ApdbPostgresQuery(DbSqlQuery):
    """Connect to a running postgres APDB instance and load data from it.

    This class connects to the database only when the ``connection`` context
    manager is entered, and closes the connection after it exits.

    Parameters
    ----------
    namespace : `str`
        Database namespace to load from. Called "schema" in postgres docs.
    url : `str`, optional
        Complete url to connect to postgres database, without prepended
        ``postgresql://``.
    instrument : `str`, optional
        Deprecated and ignored; the instrument name is read from the APDB.
        Passing a value emits a `FutureWarning`.
    """

    def __init__(self, namespace, url="rubin@usdf-prompt-processing-dev.slac.stanford.edu/lsst-devl",
                 instrument=None, **kwargs):
        self._connection_string = f"postgresql://{url}"
        self._namespace = namespace
        # NullPool: nothing is kept open between ``connection`` uses.
        self._engine = sqlalchemy.create_engine(self._connection_string, poolclass=sqlalchemy.pool.NullPool)

        with self.connection as connection:
            metadata = sqlalchemy.MetaData(schema=namespace)
            metadata.reflect(bind=connection)
        # Reflected tables are keyed with the schema prepended; re-key them by
        # bare table name so lookups like ``self._tables["DiaSource"]`` work.
        self._tables = {table.name: table for table in metadata.tables.values()}
        super().__init__(instrument=instrument, **kwargs)

    @property
    @contextlib.contextmanager
    def connection(self):
        """Open a new connection and close it when the context exits."""
        with self._engine.connect() as connection:
            yield connection