Coverage for python / lsst / analysis / ap / apdbCassandra.py: 18%

118 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 09:33 +0000

1# This file is part of analysis_ap. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["DbCassandraQuery"] 

25 

26import warnings 

27from typing import cast 

28 

29import pandas as pd 

30 

31from lsst.ap.association import UnpackApdbFlags 

32from lsst.dax.apdb import Apdb, ApdbCassandra, ApdbTables 

33from lsst.dax.apdb.cassandra.cassandra_utils import quote_id 

34from lsst.pipe.base import Instrument 

35from lsst.resources import ResourcePath, ResourcePathExpression 

36from .apdb import DbQuery 

37 

38 

class DbCassandraQuery(DbQuery):
    """Implementation of `DbQuery` interface for Cassandra backend.

    Parameters
    ----------
    config_uri : `~lsst.resources.ResourcePathExpression`
        URI or local file path pointing to a file with serialized
        configuration, or a string with a "label:" prefix to locate
        configuration in APDB index.
    instrument : `str`
        Short name (e.g. "DECam") of instrument to make a dataId unpacker
        and to add to the table columns; supports any gen3 instrument.
        To be deprecated once this information is in the database.
    """

    timeout = 600
    """Timeout for queries in seconds. Regular timeout specified in APDB
    configuration could be too short for full-scan queries that this
    class executes.
    """

    def __init__(
        self,
        config_uri: ResourcePathExpression,
        *,
        instrument: Instrument | None = None,
    ):
        self._instrument = instrument

        flag_map = ResourcePath(
            "resource://lsst.ap.association/resources/data/association-flag-map.yaml"
        )
        self._unpacker = UnpackApdbFlags(flag_map, "DiaSource")

        self.set_excluded_diaSource_flags(
            [
                "base_PixelFlags_flag_bad",
                "base_PixelFlags_flag_suspect",
                "base_PixelFlags_flag_saturatedCenter",
                "base_PixelFlags_flag_interpolated",
                "base_PixelFlags_flag_interpolatedCenter",
                "base_PixelFlags_flag_edge",
            ]
        )

        # We depend on ApdbCassandra for many things which we do not want to
        # reimplement here.
        apdb = Apdb.from_uri(config_uri)
        if not isinstance(apdb, ApdbCassandra):
            raise TypeError(
                f"Configuration file {config_uri} was produced for non-Cassandra backend."
            )
        self._apdb = apdb

        # NOTE: not getting instrument here, as I don't know the interface for
        # it and we don't want to rely on Cassandra for any analysis tooling.

    def set_excluded_diaSource_flags(self, flag_list: list[str]) -> None:
        # Docstring is inherited from base class.
        # Validate every name up front so a bad flag fails loudly here
        # rather than producing a silently-wrong bitmask later.
        for flag in flag_list:
            if not self._unpacker.flagExists(flag, columnName="flags"):
                raise ValueError(f"flag {flag} not included in DiaSource flags")

        self.diaSource_flags_exclude = flag_list

    def _filter_flags(self, catalog: pd.DataFrame, column_name: str = "flags") -> None:
        """Filter catalog contents to exclude rows with any of the
        excluded DiaSource flags set.

        Parameters
        ----------
        catalog : `pandas.DataFrame`
            Catalog to filter, update happens in-place.
        column_name : `str`, optional
            Name of flag column to query.
        """
        bitmask = int(
            self._unpacker.makeFlagBitMask(
                self.diaSource_flags_exclude, columnName=column_name
            )
        )
        if bitmask == 0:
            # A zero mask would make the filter a silent no-op; warn so the
            # caller knows no rows can possibly be excluded.
            warnings.warn(
                f"Flag bitmask is zero. Supplied flags: {self.diaSource_flags_exclude}",
                RuntimeWarning,
            )
        mask = (catalog[column_name] & bitmask) != 0
        catalog.drop(catalog[mask].index, inplace=True)

    def _build_query(
        self,
        table: ApdbTables,
        *,
        columns: list[str] | None = None,
        where: str = "",
        limit: int = -1,
    ) -> str:
        """Build a CQL query string for a specific table and selection.

        Parameters
        ----------
        table : `ApdbTables`
            Table to select from.
        columns : `list` [`str`], optional
            Columns to return; all columns (``*``) if empty or `None`.
        where : `str`, optional
            WHERE clause body, possibly with ``?`` placeholders.
        limit : `int`, optional
            Maximum number of rows to return; no limit if non-positive.

        Returns
        -------
        query : `str`
            The query string, ending in ``ALLOW FILTERING``.
        """
        if columns:
            what = ",".join(quote_id(column) for column in columns)
        else:
            what = "*"

        query = f'SELECT {what} from "{self._apdb._keyspace}"."{table.name}"'
        if where:
            query += f" WHERE {where}"
        if limit > 0:
            query += f" LIMIT {limit}"
        # Full-scan selections on non-key columns require explicit opt-in.
        query += " ALLOW FILTERING"
        return query

    def _query_table(
        self,
        table: ApdbTables,
        *,
        where: str = "",
        params: tuple = (),
        limit: int = -1,
    ) -> pd.DataFrame:
        """Select all APDB columns of ``table`` matching ``where`` and
        return the result as a DataFrame (shared prepare/execute path for
        all the public loaders).
        """
        column_names = self._apdb._schema.apdbColumnNames(table)
        query = self._build_query(
            table, columns=column_names, where=where, limit=limit
        )
        statement = self._apdb._preparer.prepare(query)
        if params:
            result = self._apdb._session.execute(
                statement,
                params,
                timeout=self.timeout,
                execution_profile="read_pandas",
            )
        else:
            result = self._apdb._session.execute(
                statement, timeout=self.timeout, execution_profile="read_pandas"
            )
        return cast(pd.DataFrame, result._current_rows)

    def load_sources_for_object(
        self, dia_object_id: int, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._query_table(
            ApdbTables.DiaSource,
            where='"diaObjectId" = ?',
            params=(dia_object_id,),
            limit=int(limit),
        )
        catalog.sort_values(by=["visit", "detector", "diaSourceId"], inplace=True)

        if exclude_flagged:
            self._filter_flags(catalog)

        return catalog

    def load_forced_sources_for_object(
        self, dia_object_id: int, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._query_table(
            ApdbTables.DiaForcedSource,
            where='"diaObjectId" = ?',
            params=(dia_object_id,),
            limit=int(limit),
        )
        catalog.sort_values(by=["visit", "detector", "diaForcedSourceId"], inplace=True)

        if exclude_flagged:
            self._filter_flags(catalog)

        return catalog

    def load_source(self, id: int) -> pd.Series:
        # Docstring is inherited from base class.
        catalog = self._query_table(
            ApdbTables.DiaSource, where='"diaSourceId" = ?', params=(id,)
        )
        # Single-row lookup; raises IndexError if the id does not exist.
        return catalog.iloc[0]

    def load_sources(
        self, exclude_flagged: bool = False, limit: int = 100000
    ) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._query_table(ApdbTables.DiaSource, limit=int(limit))
        catalog.sort_values(by=["visit", "detector", "diaSourceId"], inplace=True)

        if exclude_flagged:
            self._filter_flags(catalog)

        return catalog

    def load_object(self, id: int) -> pd.Series:
        # Docstring is inherited from base class.
        catalog = self._query_table(
            ApdbTables.DiaObjectLast, where='"diaObjectId" = ?', params=(id,)
        )
        # Single-row lookup; raises IndexError if the id does not exist.
        return catalog.iloc[0]

    def load_objects(self, limit: int = 100000, latest: bool = True) -> pd.DataFrame:
        # Docstring is inherited from base class.
        if latest:
            table = ApdbTables.DiaObjectLast
        else:
            # when we do replication then we don't always generate DiaObject
            # contents.
            config = self._apdb.config
            if config.use_insert_id and config.use_insert_id_skips_diaobjects:
                raise ValueError("DiaObject history is not available for this database")
            table = ApdbTables.DiaObject

        catalog = self._query_table(table, limit=int(limit))
        catalog.sort_values(by=["diaObjectId"], inplace=True)
        return catalog

    def load_forced_source(self, id: int) -> pd.Series:
        # Docstring is inherited from base class.
        catalog = self._query_table(
            ApdbTables.DiaForcedSource, where='"diaForcedSourceId" = ?', params=(id,)
        )
        # Single-row lookup; raises IndexError if the id does not exist.
        return catalog.iloc[0]

    def load_forced_sources(self, limit: int = 100000) -> pd.DataFrame:
        # Docstring is inherited from base class.
        catalog = self._query_table(ApdbTables.DiaForcedSource, limit=int(limit))
        catalog.sort_values(by=["visit", "detector", "diaForcedSourceId"], inplace=True)

        return catalog