Coverage for python/lsst/dax/apdb/cassandra/cassandra_utils.py: 21%

145 statements  

coverage.py v7.13.5, created at 2026-05-07 08:19 +0000

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "ApdbCassandraTableData",
    "PreparedStatementCache",
    "StatementFactory",
    "literal",
    "quote_id",
    "raw_data_factory",
    "select_concurrent",
]

import logging
import warnings
from collections.abc import Collection, Iterable, Iterator, Sequence
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID

import felis.datamodel
import numpy as np
import pandas

# If cassandra-driver is not there the module can still be imported
# but things will not work.
try:
    import cassandra.concurrent
    from cassandra.cluster import EXEC_PROFILE_DEFAULT, Session
    from cassandra.query import PreparedStatement, SimpleStatement

    CASSANDRA_IMPORTED = True
except ImportError:
    CASSANDRA_IMPORTED = False
    EXEC_PROFILE_DEFAULT = object()

from .. import schema_model
from ..apdbReplica import ApdbTableData
from .queries import Query

_LOG = logging.getLogger(__name__)


class ApdbCassandraTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps Cassandra raw data."""

    def __init__(self, columns: list[str], rows: list[tuple]):
        self._columns = columns
        self._rows = rows
        self._column_types: dict[str, felis.datamodel.DataType] = {}

    def set_column_types(self, types: dict[str, felis.datamodel.DataType]) -> None:
        """Update column types.

        Parameters
        ----------
        types : `dict`[`str`, `felis.datamodel.DataType`]
            Mapping of column name to its type.

        Notes
        -----
        Due to the way instances of this class are constructed it is
        impossible to pass column types to the constructor; instead this
        method needs to be called after construction.
        """
        self._column_types = types

    def column_names(self) -> Sequence[str]:
        # docstring inherited
        return self._columns

    def column_defs(self) -> Sequence[tuple[str, felis.datamodel.DataType]]:
        return tuple((column, self._column_types[column]) for column in self._columns)

    def rows(self) -> Collection[tuple]:
        # docstring inherited
        return self._rows

    def append(self, other: ApdbCassandraTableData) -> None:
        """Extend rows in this table with rows in other table."""
        if self._columns != other._columns:
            raise ValueError(f"Different columns returned by queries: {self._columns} and {other._columns}")
        self._rows.extend(other._rows)

    def project(self, *, drop: Iterable[str] = set()) -> None:
        """Modify data in place by dropping some columns."""
        drop_set = set(drop)
        if not drop_set:
            return

        drop_idx = []
        for idx, col_name in enumerate(self._columns):
            if col_name in drop_set:
                drop_idx.append(idx)
        # Have to reverse it so deletion does not change index.
        drop_idx.reverse()

        for row_idx in range(len(self._rows)):
            row = list(self._rows[row_idx])
            for idx in drop_idx:
                del row[idx]
            self._rows[row_idx] = tuple(row)

        for idx in drop_idx:
            del self._columns[idx]

    def to_pandas(self, table: schema_model.Table) -> pandas.DataFrame:
        """Convert data to pandas DataFrame.

        Parameters
        ----------
        table : `schema_model.Table`
            Table schema matching the data in this instance.

        Returns
        -------
        dataframe : `pandas.DataFrame`
            Resulting DataFrame.
        """
        column_types = {column_def.name: column_def.pandas_type for column_def in table.columns}

        # In rare cases there could be columns that are not in the configured
        # schema, e.g. during schema migrations. Use object column type for
        # them but also produce a warning.
        extra_columns = [column for column in self._columns if column not in column_types]
        if extra_columns:
            warnings.warn(
                f"Query result includes column(s) that do not appear in schema for table {table.name}: "
                f"{', '.join(extra_columns)}",
                stacklevel=2,
            )

        if not self._rows:
            column_data = {}
            for column in self._columns:
                column_data[column] = pandas.Series(dtype=column_types.get(column, object))
            return pandas.DataFrame(column_data)

        # To avoid nested loops convert everything to ndarray.
        array = np.array(self._rows, dtype=object)
        array = array.T
        column_data = {}
        for i, column in enumerate(self._columns):
            column_data[column] = pandas.Series(array[i], dtype=column_types.get(column, object))
        return pandas.DataFrame(column_data)

    def __iter__(self) -> Iterator[tuple]:
        """Make it look like a row iterator, needed for some odd logic."""
        return iter(self._rows)
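
# Example (illustrative sketch, not part of the original module): the wrapper is
# normally produced by ``raw_data_factory`` below, but it can also be built and
# manipulated directly; the column names and values here are hypothetical.
#
#     data = ApdbCassandraTableData(["diaObjectId", "ra", "dec"], [(1, 10.0, -5.0)])
#     data.append(ApdbCassandraTableData(["diaObjectId", "ra", "dec"], [(2, 11.0, -6.0)]))
#     data.project(drop={"dec"})
#     assert list(data.column_names()) == ["diaObjectId", "ra"]
#     assert len(list(data.rows())) == 2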

class PreparedStatementCache:
    """Cache for prepared Cassandra statements."""

    def __init__(self, session: Session) -> None:
        self._session = session
        self._prepared_statements: dict[str, PreparedStatement] = {}

    def prepare(self, query: str) -> PreparedStatement:
        """Convert query string into prepared statement."""
        stmt = self._prepared_statements.get(query)
        if stmt is None:
            stmt = self._session.prepare(query)
            self._prepared_statements[query] = stmt
        return stmt

class StatementFactory:
    """Class that builds Cassandra statements from Query objects."""

    def __init__(self, session: Session, cache: PreparedStatementCache | None = None) -> None:
        self._session = session
        self._prepared_cache = cache

    def __call__(self, query: Query, prepare: bool = False) -> PreparedStatement | SimpleStatement:
        """Generate Cassandra statement from Query.

        Parameters
        ----------
        query : `Query`
            Query to convert to Cassandra statement.
        prepare : `bool`, optional
            If `True` then generate prepared statement (but only if
            ``query.can_prepare`` is `True`).

        Returns
        -------
        statement : `PreparedStatement` or `SimpleStatement`
            Statement to execute.
        """
        if prepare and query.can_prepare and self._prepared_cache is not None:
            stmt = self._prepared_cache.prepare(query.render("?"))
        else:
            stmt = SimpleStatement(query.render("%s"))
        return stmt

    def with_params(
        self, query: Query, prepare: bool = False
    ) -> tuple[PreparedStatement | SimpleStatement, tuple]:
        """Generate Cassandra statement and its parameters from Query.

        Parameters
        ----------
        query : `Query`
            Query to convert to Cassandra statement.
        prepare : `bool`, optional
            If `True` then generate prepared statement (but only if
            ``query.can_prepare`` is `True`).

        Returns
        -------
        statement : `PreparedStatement` or `SimpleStatement`
            Statement to execute.
        parameters : `tuple`
            Parameters for this statement.
        """
        stmt = self(query, prepare)
        return stmt, query.parameters
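
# Example (illustrative sketch, not part of the original module): building a
# statement from a query object. ``session`` is assumed to be a connected
# ``cassandra.cluster.Session`` and ``query`` an instance of the internal
# ``Query`` class, which exposes ``render()``, ``can_prepare`` and
# ``parameters`` as used above.
#
#     factory = StatementFactory(session, cache=PreparedStatementCache(session))
#     statement, params = factory.with_params(query, prepare=True)
#     session.execute(statement, params)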

def raw_data_factory(colnames: list[str], rows: list[tuple]) -> ApdbCassandraTableData:
    """Wrap unmodified query data, a list of column names and a list of rows,
    into `ApdbCassandraTableData`.

    Parameters
    ----------
    colnames : `list` [ `str` ]
        Names of the columns.
    rows : `list` of `tuple`
        Result rows.

    Returns
    -------
    data : `ApdbCassandraTableData`
        Input data wrapped into ApdbCassandraTableData.

    Notes
    -----
    When using this method as row factory for Cassandra, the resulting
    object should be accessed in a non-standard way using
    `ResultSet._current_rows` attribute.
    """
    return ApdbCassandraTableData(colnames, rows)
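
# Example (illustrative sketch, not part of the original module): installing
# ``raw_data_factory`` as the row factory of a cassandra-driver execution
# profile, so that result sets carry ``ApdbCassandraTableData``. The profile
# name and cluster setup are hypothetical.
#
#     from cassandra.cluster import Cluster, ExecutionProfile
#
#     profile = ExecutionProfile(row_factory=raw_data_factory)
#     cluster = Cluster(execution_profiles={"read_raw": profile})
#     session = cluster.connect()
#     result = session.execute("SELECT ...", execution_profile="read_raw")
#     table_data = result._current_rows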

def execute_concurrent(
    session: Session,
    statements: list[tuple],
    *,
    execution_profile: object = EXEC_PROFILE_DEFAULT,
    concurrency: int = 100,
) -> None:
    """Wrap the call to `cassandra.concurrent.execute_concurrent` to avoid
    importing cassandra in other modules.
    """
    cassandra.concurrent.execute_concurrent(
        session,
        statements,
        concurrency=concurrency,
        execution_profile=execution_profile,
    )
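
# Example (illustrative sketch, not part of the original module): each entry in
# ``statements`` is a ``(statement, parameters)`` pair, e.g. as produced by
# ``StatementFactory.with_params``. Table and column names are hypothetical.
#
#     statements = [
#         ("INSERT INTO tbl (id, val) VALUES (%s, %s)", (1, "a")),
#         ("INSERT INTO tbl (id, val) VALUES (%s, %s)", (2, "b")),
#     ]
#     execute_concurrent(session, statements, concurrency=50)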

def select_concurrent(
    session: Session, statements: list[tuple], execution_profile: str, concurrency: int
) -> pandas.DataFrame | ApdbCassandraTableData | list:
    """Execute a bunch of queries concurrently and merge their results into
    a single result.

    Parameters
    ----------
    session : `Session`
        Cassandra session used to execute the statements.
    statements : `list` [ `tuple` ]
        List of statements and their parameters, passed directly to
        ``execute_concurrent()``.
    execution_profile : `str`
        Execution profile name.
    concurrency : `int`
        Concurrency level passed to ``execute_concurrent()``.

    Returns
    -------
    result
        Combined result of multiple statements, type of the result depends on
        specific row factory defined in execution profile. If row factory is
        `pandas_dataframe_factory` then pandas DataFrame is created from a
        combined result. If row factory is `raw_data_factory` then
        `ApdbCassandraTableData` is built from all records. Otherwise a list of
        rows is returned, type of each row is determined by the row factory.

    Notes
    -----
    This method can raise any exception that is raised by one of the provided
    statements.
    """
    results = cassandra.concurrent.execute_concurrent(
        session,
        statements,
        results_generator=True,
        raise_on_first_error=False,
        concurrency=concurrency,
        execution_profile=execution_profile,
    )

    ep = session.get_execution_profile(execution_profile)
    if ep.row_factory is raw_data_factory:
        # Combine all returned ApdbCassandraTableData instances into one.
        _LOG.debug("making raw data out of rows/columns")
        table_data: ApdbCassandraTableData | None = None
        for success, result in results:
            if success:
                data = result._current_rows
                assert isinstance(data, ApdbCassandraTableData)
                if table_data is None:
                    table_data = data
                else:
                    table_data.append(data)
            else:
                _LOG.error("error returned by query: %s", result)
                raise result
        if table_data is None:
            table_data = ApdbCassandraTableData([], [])
        return table_data

    else:
        # Just concatenate all rows into a single collection.
        rows = []
        for success, result in results:
            if success:
                rows.extend(result)
            else:
                _LOG.error("error returned by query: %s", result)
                raise result
        _LOG.debug("number of rows: %s", len(rows))
        return rows
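
# Example (illustrative sketch, not part of the original module): running
# several per-partition selects and merging the results. The profile name
# "read_raw" is hypothetical and is assumed to use ``raw_data_factory`` as its
# row factory, so the merged result is an ``ApdbCassandraTableData``.
#
#     statements = [factory.with_params(query) for query in per_partition_queries]
#     data = select_concurrent(session, statements, "read_raw", concurrency=100)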

def literal(v: Any) -> Any:
    """Transform object into a value for the query."""
    if v is None or v is pandas.NA:
        v = None
    elif isinstance(v, datetime):
        # Convert datetime to milliseconds since the Unix epoch.
        v = int((v - datetime(1970, 1, 1)) / timedelta(seconds=1) * 1000)
    elif isinstance(v, bytes | str | UUID | int):
        pass
    elif isinstance(v, np.bool_):
        v = bool(v)
    else:
        try:
            # Replace non-finite floats (NaN, inf) with None.
            if not np.isfinite(v):
                v = None
        except TypeError:
            pass
    return v
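
# Example (illustrative sketch, not part of the original module): typical
# conversions performed by ``literal``.
#
#     literal(datetime(2026, 1, 1))   # -> 1767225600000 (milliseconds since epoch)
#     literal(float("nan"))           # -> None
#     literal(np.bool_(True))         # -> True (plain bool)
#     literal("unchanged")            # -> "unchanged"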

def quote_id(columnName: str) -> str:
    """Smart quoting for column names. Lower-case names are not quoted."""
    if not columnName.islower():
        columnName = '"' + columnName + '"'
    return columnName
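
# Example (illustrative sketch, not part of the original module):
#
#     quote_id("visit")        # -> 'visit' (all lower-case, left unquoted)
#     quote_id("diaObjectId")  # -> '"diaObjectId"' (mixed case, quoted)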