Coverage for python/lsst/analysis/ap/apdb.py: 25%

139 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-14 10:33 +0000

1# This file is part of analysis_ap. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""APDB connection management and data access tools. 

23""" 

24 

25__all__ = ["DbQuery", "ApdbSqliteQuery", "ApdbPostgresQuery"] 

26 

27import abc 

28import contextlib 

29 

30import pandas as pd 

31import sqlalchemy 

32 

33 

class DbQuery(abc.ABC):
    """Base class for APDB connection and query management.

    Subclasses must specify a ``connection`` property to use as a context-
    manager for queries.

    Parameters
    ----------
    instrument : `str`
        Short name (e.g. "DECam") of instrument to make a dataId unpacker
        and to add to the table columns; supports any gen3 instrument.
        To be deprecated once this information is in the database.
    """

    def __init__(self, instrument=None):
        if not instrument:
            raise RuntimeError("Instrument is required until DM-39502, "
                               "when it will be part of the APDB metadata.")
        self._instrument = instrument
        # Default flag exclusion list; callers can replace it via
        # `set_excluded_diaSource_flags`.
        self.set_excluded_diaSource_flags(['pixelFlags_bad',
                                           'pixelFlags_suspect',
                                           'pixelFlags_saturatedCenter',
                                           'pixelFlags_interpolated',
                                           'pixelFlags_interpolatedCenter',
                                           'pixelFlags_edge',
                                           ])

    @property
    @contextlib.contextmanager
    @abc.abstractmethod
    def connection(self):
        """Context manager for database connections.

        Yields
        ------
        connection : `sqlalchemy.engine.Connection`
            Connection to the database that will be queried. Whether the
            connection is closed after the context manager is closed is
            implementation dependent.
        """
        pass

    def set_excluded_diaSource_flags(self, flag_list):
        """Set flags of diaSources to exclude when loading diaSources.

        Any diaSources with configured flags are not returned
        when calling `load_sources_for_object` or `load_sources`
        with `exclude_flagged = True`.

        Parameters
        ----------
        flag_list : `list` [`str`]
            Flag names to exclude.

        Raises
        ------
        ValueError
            Raised if a flag is not a column of the DiaSource table.
        """
        for flag in flag_list:
            if flag not in self._tables["DiaSource"].columns:
                raise ValueError(f"flag {flag} not included in DiaSource flags")

        self.diaSource_flags_exclude = flag_list

    def _make_flag_exclusion_query(self, query, table, flag_list):
        """Return an SQL where query that excludes sources with chosen flags.

        Parameters
        ----------
        query : `sqlalchemy.sql.Query`
            Query to include the where statement in.
        table : `sqlalchemy.schema.Table`
            Table containing the column to be queried.
        flag_list : `list` [`str`]
            Flag names to exclude.

        Returns
        -------
        query : `sqlalchemy.sql.Query`
            Query that selects rows to exclude based on flags.
        """
        if not flag_list:
            # Nothing to exclude; also avoids calling `or_` with zero
            # clauses, which SQLAlchemy deprecates.
            return query
        # Build a query that selects any source with one or more chosen
        # flags, and return the opposite (`not_`) of that query. The clauses
        # are unpacked because passing a bare generator to `or_` is
        # deprecated in SQLAlchemy 1.4 and removed in 2.0.
        query = query.where(sqlalchemy.not_(
            sqlalchemy.or_(*(table.columns[flag_col] == 1
                             for flag_col in flag_list))))
        return query

    def load_sources_for_object(self, dia_object_id, exclude_flagged=False, limit=100000):
        """Load diaSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load sources for.
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`
            Maximum number of rows to return.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaSources for the specified diaObject.
        """
        table = self._tables["DiaSource"]
        query = table.select().where(table.columns["diaObjectId"] == dia_object_id)
        if exclude_flagged:
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaSourceId"])
        # Apply the documented row cap; previously `limit` was accepted but
        # silently ignored (unlike `load_sources`).
        if limit is not None:
            query = query.limit(limit)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)

        self._fill_from_instrument(result)
        return result

    def load_forced_sources_for_object(self, dia_object_id, exclude_flagged=False, limit=100000):
        """Load diaForcedSources for a single diaObject.

        Parameters
        ----------
        dia_object_id : `int`
            Id of object to load sources for.
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`
            Maximum number of rows to return.

        Returns
        -------
        data : `pandas.DataFrame`
            A data frame of diaSources for the specified diaObject.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select().where(table.columns["diaObjectId"] == dia_object_id)
        if exclude_flagged:
            # NOTE(review): this applies the DiaSource flag list to the
            # DiaForcedSource table; it will raise if a configured flag is
            # not a column of that table — confirm the schemas agree.
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaForcedSourceId"])
        # Apply the documented row cap; previously `limit` was accepted but
        # silently ignored (unlike `load_sources`).
        if limit is not None:
            query = query.limit(limit)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)

        self._fill_from_instrument(result)
        return result

    def load_source(self, id):
        """Load one diaSource.

        Parameters
        ----------
        id : `int`
            The diaSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested diaSource.

        Raises
        ------
        RuntimeError
            Raised if the id is not in the database.
        """
        table = self._tables["DiaSource"]
        query = table.select().where(table.columns["diaSourceId"] == id)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)
        if len(result) == 0:
            raise RuntimeError(f"diaSourceId={id} not found in DiaSource table")

        self._fill_from_instrument(result)
        return result.iloc[0]

    def load_sources(self, exclude_flagged=False, limit=100000):
        """Load diaSources.

        Parameters
        ----------
        exclude_flagged : `bool`, optional
            Exclude sources that have selected flags set.
            Use `set_excluded_diaSource_flags` to configure which flags
            are excluded.
        limit : `int`
            Maximum number of rows to return.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaSources.
        """
        table = self._tables["DiaSource"]
        query = table.select()
        if exclude_flagged:
            query = self._make_flag_exclusion_query(query, table, self.diaSource_flags_exclude)
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaSourceId"])
        if limit is not None:
            query = query.limit(limit)

        with self.connection as connection:
            result = pd.read_sql_query(query, connection)

        self._fill_from_instrument(result)
        return result

    def load_object(self, id):
        """Load the most-recently updated version of one diaObject.

        Parameters
        ----------
        id : `int`
            The diaObjectId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested object.

        Raises
        ------
        RuntimeError
            Raised if the id is not in the database.
        """
        table = self._tables["DiaObject"]
        # validityEnd NULL marks the currently-valid version of the object.
        query = table.select().where(table.columns["validityEnd"] == None)  # noqa: E711
        query = query.where(table.columns["diaObjectId"] == id)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)
        if len(result) == 0:
            raise RuntimeError(f"diaObjectId={id} not found in DiaObject table")

        return result.iloc[0]

    def load_objects(self, limit=100000, latest=True):
        """Load all diaObjects.

        Parameters
        ----------
        limit : `int`
            Maximum number of rows to return.
        latest : `bool`
            Only load diaObjects where validityEnd is None.
            These are the most-recently updated diaObjects.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaObjects.
        """
        table = self._tables["DiaObject"]
        if latest:
            query = table.select().where(table.columns["validityEnd"] == None)  # noqa: E711
        else:
            # Bug fix: `query` was previously undefined when latest=False,
            # raising NameError instead of returning all versions.
            query = table.select()
        query = query.order_by(table.columns["diaObjectId"])
        if limit is not None:
            query = query.limit(limit)

        with self.connection as connection:
            result = pd.read_sql_query(query, connection)

        return result

    def load_forced_source(self, id):
        """Load one diaForcedSource.

        Parameters
        ----------
        id : `int`
            The diaForcedSourceId to load data for.

        Returns
        -------
        data : `pandas.Series`
            The requested forced source.

        Raises
        ------
        RuntimeError
            Raised if the id is not in the database.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select().where(table.columns["diaForcedSourceId"] == id)
        with self.connection as connection:
            result = pd.read_sql_query(query, connection)
        if len(result) == 0:
            raise RuntimeError(f"diaForcedSourceId={id} not found in DiaForcedSource table")

        self._fill_from_instrument(result)
        return result.iloc[0]

    def load_forced_sources(self, limit=100000):
        """Load all diaForcedSources.

        Parameters
        ----------
        limit : `int`
            Maximum number of rows to return.

        Returns
        -------
        data : `pandas.DataFrame`
            All available diaForcedSources.
        """
        table = self._tables["DiaForcedSource"]
        query = table.select()
        query = query.order_by(table.columns["visit"],
                               table.columns["detector"],
                               table.columns["diaForcedSourceId"])
        if limit is not None:
            query = query.limit(limit)

        with self.connection as connection:
            result = pd.read_sql_query(query, connection)
        self._fill_from_instrument(result)
        return result

    def _fill_from_instrument(self, diaSources):
        """Add instrument to the database.
        This method is temporary, until APDB has instrument in its metadata.

        Parameters
        ----------
        diaSources : `pandas.core.frame.DataFrame`
            Pandas dataframe with diaSources from an APDB; modified in-place.
        """
        # do nothing for an empty series
        if len(diaSources) == 0:
            return

        diaSources['instrument'] = self._instrument

351 

352 

class ApdbSqliteQuery(DbQuery):
    """Open an sqlite3 APDB file to load data from it.

    This class keeps the sqlite connection open after initialization because
    our sqlite usage is to load a local file. Closing and re-opening would
    re-scan the whole file every time, and we don't need to worry about
    multiple users when working with local sqlite files.

    Parameters
    ----------
    filename : `str`
        Path to the sqlite3 file containing the APDB to load.
    instrument : `str`
        Short name (e.g. "DECam") of instrument to make a dataId unpacker
        and to add to the table columns; supports any gen3 instrument.
        To be deprecated once this information is in the database.
    """

    def __init__(self, filename, instrument=None, **kwargs):
        # For sqlite, use a larger pool and a faster timeout, to allow many
        # repeat transactions with the same connection, as transactions on
        # our sqlite DBs should be small and fast.
        # Bug fix: interpolate `filename` into the URL; previously the
        # argument was ignored and a literal placeholder string was used.
        self._engine = sqlalchemy.create_engine(f"sqlite:///{filename}",
                                                pool_timeout=5, pool_size=200)

        with self.connection as connection:
            metadata = sqlalchemy.MetaData()
            metadata.reflect(bind=connection)
            self._tables = metadata.tables
        super().__init__(instrument=instrument, **kwargs)

    @property
    @contextlib.contextmanager
    def connection(self):
        # Return the connection to the engine's pool on exit instead of
        # abandoning it (the old code never closed, eventually exhausting
        # the pool). The pool keeps the underlying sqlite file open, so
        # re-entering the context manager remains cheap.
        connection = self._engine.connect()
        try:
            yield connection
        finally:
            connection.close()

388 

389 

class ApdbPostgresQuery(DbQuery):
    """Connect to a running postgres APDB instance and load data from it.

    This class connects to the database only when the ``connection`` context
    manager is entered, and closes the connection after it exits.

    Parameters
    ----------
    namespace : `str`
        Database namespace to load from. Called "schema" in postgres docs.
    url : `str`
        Complete url to connect to postgres database, without prepended
        ``postgresql://``.
    instrument : `str`
        Short name (e.g. "DECam") of instrument to make a dataId unpacker
        and to add to the table columns; supports any gen3 instrument.
        To be deprecated once this information is in the database.
    """

    def __init__(self, namespace, url="rubin@usdf-prompt-processing-dev.slac.stanford.edu/lsst-devl",
                 instrument=None, **kwargs):
        self._connection_string = f"postgresql://{url}"
        self._namespace = namespace
        # NullPool: no connection caching, so each entry into `connection`
        # opens a fresh connection and exit closes it for real.
        self._engine = sqlalchemy.create_engine(self._connection_string, poolclass=sqlalchemy.pool.NullPool)

        with self.connection as connection:
            schema_metadata = sqlalchemy.MetaData(schema=namespace)
            schema_metadata.reflect(bind=connection)
            # Index tables by bare name so lookups don't need the
            # schema-qualified "namespace.table" keys reflect() produces.
            self._tables = {reflected.name: reflected
                            for reflected in schema_metadata.tables.values()}
        super().__init__(instrument=instrument, **kwargs)

    @property
    @contextlib.contextmanager
    def connection(self):
        open_connection = self._engine.connect()
        try:
            yield open_connection
        finally:
            # Always close, even if the body raised.
            open_connection.close()