Coverage for python/lsst/dax/apdb/cassandra_utils.py: 22%

117 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-02 11:13 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = [ 

25 "ApdbCassandraTableData", 

26 "PreparedStatementCache", 

27 "literal", 

28 "pandas_dataframe_factory", 

29 "quote_id", 

30 "raw_data_factory", 

31 "select_concurrent", 

32] 

33 

34import logging 

35from collections.abc import Iterable, Iterator 

36from datetime import datetime, timedelta 

37from typing import Any 

38from uuid import UUID 

39 

40import numpy as np 

41import pandas 

42 

43# If cassandra-driver is not there the module can still be imported 

44# but things will not work. 

45try: 

46 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Session 

47 from cassandra.concurrent import execute_concurrent 

48 from cassandra.query import PreparedStatement 

49 

50 CASSANDRA_IMPORTED = True 

51except ImportError: 

52 CASSANDRA_IMPORTED = False 

53 

54from .apdb import ApdbTableData 

55 

56_LOG = logging.getLogger(__name__) 

57 

58 

if CASSANDRA_IMPORTED:

    class SessionWrapper:
        """Special wrapper class to workaround ``execute_concurrent()`` issue
        which does not allow non-default execution profile.

        Instance of this class can be passed to execute_concurrent() instead
        of `Session` instance. This class implements a small set of methods
        that are needed by ``execute_concurrent()``. When
        ``execute_concurrent()`` is fixed to accept execution profiles, this
        wrapper can be dropped.

        Parameters
        ----------
        session : `cassandra.cluster.Session`
            Session instance to wrap.
        execution_profile : `object`, optional
            Execution profile used for requests that do not specify one
            explicitly; default is the driver's default profile marker.
        """

        def __init__(self, session: Session, execution_profile: Any = EXEC_PROFILE_DEFAULT):
            self._session = session
            self._execution_profile = execution_profile

        def execute_async(
            self,
            *args: Any,
            execution_profile: Any = EXEC_PROFILE_DEFAULT,
            **kwargs: Any,
        ) -> Any:
            """Forward to ``Session.execute_async()``, substituting the
            wrapper's execution profile when the caller does not pass one.
            """
            # An explicit parameter can override our settings.
            if execution_profile is EXEC_PROFILE_DEFAULT:
                execution_profile = self._execution_profile
            return self._session.execute_async(*args, execution_profile=execution_profile, **kwargs)

        def submit(self, *args: Any, **kwargs: Any) -> Any:
            # Internal method called by ``execute_concurrent()``; forward
            # unchanged to the wrapped session.
            return self._session.submit(*args, **kwargs)

90 

91 

class ApdbCassandraTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps Cassandra raw data."""

    def __init__(self, columns: list[str], rows: list[tuple]):
        # Column names and row tuples are stored as-is, without copying.
        self._cols = columns
        self._row_data = rows

    def column_names(self) -> list[str]:
        # docstring inherited
        return self._cols

    def rows(self) -> Iterable[tuple]:
        # docstring inherited
        return self._row_data

    def append(self, other: ApdbCassandraTableData) -> None:
        """Extend rows in this table with rows in other table"""
        if self._cols != other._cols:
            raise ValueError(f"Different columns returned by queries: {self._cols} and {other._cols}")
        self._row_data.extend(other._row_data)

    def __iter__(self) -> Iterator[tuple]:
        """Make it look like a row iterator, needed for some odd logic."""
        return iter(self._row_data)

116 

117 

class PreparedStatementCache:
    """Cache for prepared Cassandra statements"""

    def __init__(self, session: Session) -> None:
        self._session = session
        self._prepared_statements: dict[str, PreparedStatement] = {}

    def prepare(self, query: str) -> PreparedStatement:
        """Convert query string into prepared statement."""
        try:
            # Fast path: statement was prepared earlier for this query text.
            return self._prepared_statements[query]
        except KeyError:
            statement = self._session.prepare(query)
            self._prepared_statements[query] = statement
            return statement

132 

133 

def pandas_dataframe_factory(colnames: list[str], rows: list[tuple]) -> pandas.DataFrame:
    """Build a pandas DataFrame out of a Cassandra result set.

    Parameters
    ----------
    colnames : `list` [ `str` ]
        Names of the columns.
    rows : `list` of `tuple`
        Result rows.

    Returns
    -------
    catalog : `pandas.DataFrame`
        DataFrame holding the result set.

    Notes
    -----
    When this function is installed as the Cassandra row factory, the
    resulting DataFrame has to be retrieved in a non-standard way via the
    `ResultSet._current_rows` attribute.
    """
    frame = pandas.DataFrame.from_records(rows, columns=colnames)
    return frame

156 

157 

def raw_data_factory(colnames: list[str], rows: list[tuple]) -> ApdbCassandraTableData:
    """Wrap unmodified query results — the column name list and the row
    list — into an `ApdbCassandraTableData` instance.

    Parameters
    ----------
    colnames : `list` [ `str` ]
        Names of the columns.
    rows : `list` of `tuple`
        Result rows.

    Returns
    -------
    data : `ApdbCassandraTableData`
        Input data wrapped into ApdbCassandraTableData.

    Notes
    -----
    When this function is installed as the Cassandra row factory, the
    resulting object has to be retrieved in a non-standard way via the
    `ResultSet._current_rows` attribute.
    """
    return ApdbCassandraTableData(colnames, rows)

181 

182 

def select_concurrent(
    session: Session, statements: list[tuple], execution_profile: str, concurrency: int
) -> pandas.DataFrame | ApdbCassandraTableData | list:
    """Execute bunch of queries concurrently and merge their results into
    a single result.

    Parameters
    ----------
    session : `cassandra.cluster.Session`
        Session to use for executing the statements.
    statements : `list` [ `tuple` ]
        List of statements and their parameters, passed directly to
        ``execute_concurrent()``.
    execution_profile : `str`
        Execution profile name.
    concurrency : `int`
        Maximum number of statements executed concurrently.

    Returns
    -------
    result
        Combined result of multiple statements, type of the result depends on
        specific row factory defined in execution profile. If row factory is
        `pandas_dataframe_factory` then pandas DataFrame is created from a
        combined result. If row factory is `raw_data_factory` then
        `ApdbCassandraTableData` is built from all records. Otherwise a list of
        rows is returned, type of each row is determined by the row factory.

    Notes
    -----
    This method can raise any exception that is raised by one of the provided
    statements.
    """
    # Wrapper is needed because execute_concurrent() itself does not accept
    # a non-default execution profile (see SessionWrapper).
    session_wrap = SessionWrapper(session, execution_profile)
    results = execute_concurrent(
        session_wrap,
        statements,
        results_generator=True,
        raise_on_first_error=False,
        concurrency=concurrency,
    )

    ep = session.get_execution_profile(execution_profile)
    if ep.row_factory is raw_data_factory:
        # Merge the per-query ApdbCassandraTableData objects into one.
        _LOG.debug("making raw data out of rows/columns")
        table_data: ApdbCassandraTableData | None = None
        for success, result in results:
            if success:
                data = result._current_rows
                assert isinstance(data, ApdbCassandraTableData)
                if table_data is None:
                    table_data = data
                else:
                    table_data.append(data)
            else:
                _LOG.error("error returned by query: %s", result)
                raise result
        if table_data is None:
            table_data = ApdbCassandraTableData([], [])
        return table_data

    elif ep.row_factory is pandas_dataframe_factory:
        # Merge multiple DataFrames into one
        _LOG.debug("making pandas data frame out of set of data frames")
        dataframes = []
        for success, result in results:
            if success:
                dataframes.append(result._current_rows)
            else:
                _LOG.error("error returned by query: %s", result)
                raise result
        # Concatenate all frames, but skip empty ones.
        non_empty = [df for df in dataframes if not df.empty]
        if not dataframes:
            # No statements were executed at all; nothing to merge.
            catalog = pandas.DataFrame()
        elif not non_empty:
            # If all frames are empty, return the first one.
            catalog = dataframes[0]
        elif len(non_empty) == 1:
            catalog = non_empty[0]
        else:
            catalog = pandas.concat(non_empty)
        _LOG.debug("pandas catalog shape: %s", catalog.shape)
        return catalog

    else:
        # Just concatenate all rows into a single collection.
        rows = []
        for success, result in results:
            if success:
                rows.extend(result)
            else:
                _LOG.error("error returned by query: %s", result)
                raise result
        _LOG.debug("number of rows: %s", len(rows))
        return rows

274 

275 

276def literal(v: Any) -> Any: 

277 """Transform object into a value for the query.""" 

278 if v is None: 

279 pass 

280 elif isinstance(v, datetime): 

281 v = int((v - datetime(1970, 1, 1)) / timedelta(seconds=1)) * 1000 

282 elif isinstance(v, (bytes, str, UUID, int)): 

283 pass 

284 else: 

285 try: 

286 if not np.isfinite(v): 

287 v = None 

288 except TypeError: 

289 pass 

290 return v 

291 

292 

def quote_id(columnName: str) -> str:
    """Smart quoting for column names. Lower-case names are not quoted."""
    return columnName if columnName.islower() else '"' + columnName + '"'