Coverage for python / lsst / analysis / ap / apdbCassandra.py: 19%

120 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 00:24 +0000

1# This file is part of analysis_ap. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public API of this module.
__all__ = ["DbCassandraQuery"]

25 

26import os 

27import warnings 

28from typing import TYPE_CHECKING, cast 

29 

30import pandas as pd 

31 

32import lsst.utils 

33from lsst.ap.association import UnpackApdbFlags 

34from lsst.dax.apdb import Apdb, ApdbCassandra, ApdbTables 

35from lsst.dax.apdb.cassandra.cassandra_utils import quote_id 

36from lsst.pipe.base import Instrument 

37from lsst.resources import ResourcePathExpression 

38from .apdb import DbQuery 

39 

40if TYPE_CHECKING: 

41 import lsst.daf.butler 

42 

43 

class DbCassandraQuery(DbQuery):
    """Implementation of `DbQuery` interface for Cassandra backend.

    Parameters
    ----------
    config_uri : `~lsst.resources.ResourcePathExpression`
        URI or local file path pointing to a file with serialized
        configuration, or a string with a "label:" prefix to locate
        configuration in APDB index.
    instrument : `~lsst.pipe.base.Instrument` or `None`
        Instrument (e.g. DECam) to make a dataId unpacker and to add to the
        table columns; supports any gen3 instrument.
        To be deprecated once this information is in the database.
    """

    timeout = 600
    """Timeout for queries in seconds. Regular timeout specified in APDB
    configuration could be too short for full-scan queries that this
    class executes.
    """

    def __init__(
        self,
        config_uri: ResourcePathExpression,
        *,
        instrument: Instrument | None = None,
    ):
        self._instrument = instrument

        flag_map = os.path.join(
            lsst.utils.getPackageDir("ap_association"), "data/association-flag-map.yaml"
        )
        self._unpacker = UnpackApdbFlags(flag_map, "DiaSource")

        # Default set of pixel flags marking unreliable DiaSources; callers
        # can override via set_excluded_diaSource_flags().
        self.set_excluded_diaSource_flags(
            [
                "base_PixelFlags_flag_bad",
                "base_PixelFlags_flag_suspect",
                "base_PixelFlags_flag_saturatedCenter",
                "base_PixelFlags_flag_interpolated",
                "base_PixelFlags_flag_interpolatedCenter",
                "base_PixelFlags_flag_edge",
            ]
        )

        # We depend on ApdbCassandra for many things which we do not want to
        # reimplement here.
        apdb = Apdb.from_uri(config_uri)
        if not isinstance(apdb, ApdbCassandra):
            raise TypeError(
                f"Configuration file {config_uri} was produced for non-Cassandra backend."
            )
        self._apdb = apdb

        # NOTE: not getting instrument here, as I don't know the interface for
        # it and we don't want to rely on Cassandra for any analysis tooling.

    def set_excluded_diaSource_flags(self, flag_list: list[str]) -> None:
        # Docstring is inherited from base class.
        for flag in flag_list:
            if not self._unpacker.flagExists(flag, columnName="flags"):
                raise ValueError(f"flag {flag} not included in DiaSource flags")

        self.diaSource_flags_exclude = flag_list

    def _filter_flags(self, catalog: pd.DataFrame, column_name: str = "flags") -> None:
        """Drop rows that have any of the excluded DiaSource flags set.

        Parameters
        ----------
        catalog : `pandas.DataFrame`
            Catalog to filter, update happens in-place.
        column_name : `str`, optional
            Name of flag column to query.
        """
        bitmask = int(
            self._unpacker.makeFlagBitMask(
                self.diaSource_flags_exclude, columnName=column_name
            )
        )
        if bitmask == 0:
            # A zero mask would match nothing; warn so a misconfigured flag
            # list does not silently disable filtering.
            warnings.warn(
                f"Flag bitmask is zero. Supplied flags: {self.diaSource_flags_exclude}",
                RuntimeWarning,
            )
        mask = (catalog[column_name] & bitmask) != 0
        catalog.drop(catalog[mask].index, inplace=True)

    def _build_query(
        self,
        table: ApdbTables,
        *,
        columns: list[str] | None = None,
        where: str = "",
        limit: int = -1,
    ) -> str:
        """Build query for a specific table and selection.

        Parameters
        ----------
        table : `~lsst.dax.apdb.ApdbTables`
            Table to select from.
        columns : `list` [`str`] or `None`, optional
            Columns to return; all columns when `None` or empty.
        where : `str`, optional
            WHERE clause, possibly with "?" placeholders; omitted when empty.
        limit : `int`, optional
            Maximum number of rows to return; no limit when non-positive.

        Returns
        -------
        query : `str`
            CQL query string; always ends with ALLOW FILTERING because this
            class runs full-scan queries.
        """
        # NOTE: columns was a mutable default ([]); use None to avoid the
        # shared-default pitfall.  Falsy check keeps [] behaving as before.
        if columns:
            what = ",".join(quote_id(column) for column in columns)
        else:
            what = "*"

        query = f'SELECT {what} from "{self._apdb._keyspace}"."{table.name}"'
        if where:
            query += f" WHERE {where}"
        if limit > 0:
            query += f" LIMIT {limit}"
        query += " ALLOW FILTERING"
        return query

    def _select(
        self,
        table: ApdbTables,
        *,
        where: str = "",
        params: tuple | None = None,
        limit: int = -1,
    ) -> pd.DataFrame:
        """Build, prepare, and execute a SELECT over all APDB columns of a
        table, returning the result as a DataFrame.

        Parameters
        ----------
        table : `~lsst.dax.apdb.ApdbTables`
            Table to select from.
        where : `str`, optional
            WHERE clause with "?" placeholders; omitted when empty.
        params : `tuple` or `None`, optional
            Values bound to the placeholders in ``where``.
        limit : `int`, optional
            Maximum number of rows to return; no limit when non-positive.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Query result; unsorted and unfiltered.
        """
        column_names = self._apdb._schema.apdbColumnNames(table)
        query = self._build_query(table, columns=column_names, where=where, limit=limit)
        statement = self._apdb._preparer.prepare(query)
        result = self._apdb._session.execute(
            statement,
            params,
            timeout=self.timeout,
            execution_profile="read_pandas",
        )
        return cast(pd.DataFrame, result._current_rows)

    def load_sources_for_object(
        self, dia_object_id: int, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._select(
            ApdbTables.DiaSource,
            where='"diaObjectId" = ?',
            params=(dia_object_id,),
            limit=int(limit),
        )
        catalog.sort_values(by=["visit", "detector", "diaSourceId"], inplace=True)

        if exclude_flagged:
            self._filter_flags(catalog)

        return catalog

    def load_forced_sources_for_object(
        self, dia_object_id: int, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._select(
            ApdbTables.DiaForcedSource,
            where='"diaObjectId" = ?',
            params=(dia_object_id,),
            limit=int(limit),
        )
        catalog.sort_values(by=["visit", "detector", "diaForcedSourceId"], inplace=True)

        if exclude_flagged:
            self._filter_flags(catalog)

        return catalog

    def load_source(self, id: int) -> pd.Series:
        # Docstring is inherited from base class.
        catalog = self._select(
            ApdbTables.DiaSource, where='"diaSourceId" = ?', params=(id,)
        )
        return catalog.iloc[0]

    def load_sources(
        self, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._select(ApdbTables.DiaSource, limit=int(limit))
        catalog.sort_values(by=["visit", "detector", "diaSourceId"], inplace=True)

        if exclude_flagged:
            self._filter_flags(catalog)

        return catalog

    def load_object(self, id: int) -> pd.Series:
        # Docstring is inherited from base class.
        catalog = self._select(
            ApdbTables.DiaObjectLast, where='"diaObjectId" = ?', params=(id,)
        )
        return catalog.iloc[0]

    def load_objects(self, limit: int = 100000, latest: bool = True) -> pd.DataFrame:
        # Docstring is inherited from base class.
        if latest:
            table = ApdbTables.DiaObjectLast
        else:
            # when we do replication then we don't always generate DiaObject
            # contents.
            config = self._apdb.config
            if config.use_insert_id and config.use_insert_id_skips_diaobjects:
                raise ValueError("DiaObject history is not available for this database")
            table = ApdbTables.DiaObject

        catalog = self._select(table, limit=int(limit))
        catalog.sort_values(by=["diaObjectId"], inplace=True)
        return catalog

    def load_forced_source(self, id: int) -> pd.Series:
        # Docstring is inherited from base class.
        catalog = self._select(
            ApdbTables.DiaForcedSource, where='"diaForcedSourceId" = ?', params=(id,)
        )
        return catalog.iloc[0]

    def load_forced_sources(self, limit: int = 100000) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._select(ApdbTables.DiaForcedSource, limit=int(limit))
        catalog.sort_values(by=["visit", "detector", "diaForcedSourceId"], inplace=True)

        return catalog