Coverage for python/lsst/dax/apdb/cassandra_utils.py: 23%

103 statements  

« prev     ^ index     » next       coverage.py v7.3.3, created at 2023-12-20 17:15 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = [ 

25 "literal", 

26 "pandas_dataframe_factory", 

27 "quote_id", 

28 "raw_data_factory", 

29 "select_concurrent", 

30] 

31 

32import logging 

33from collections.abc import Iterable, Iterator 

34from datetime import datetime, timedelta 

35from typing import Any 

36from uuid import UUID 

37 

38import numpy as np 

39import pandas 

40 

41# If cassandra-driver is not there the module can still be imported 

42# but things will not work. 

43try: 

44 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Session 

45 from cassandra.concurrent import execute_concurrent 

46 

47 CASSANDRA_IMPORTED = True 

48except ImportError: 

49 CASSANDRA_IMPORTED = False 

50 

51from .apdb import ApdbTableData 

52 

53_LOG = logging.getLogger(__name__) 

54 

55 

56if CASSANDRA_IMPORTED: 56 ↛ 89line 56 didn't jump to line 89, because the condition on line 56 was never false

57 

58 class SessionWrapper: 

59 """Special wrapper class to workaround ``execute_concurrent()`` issue 

60 which does not allow non-default execution profile. 

61 

62 Instance of this class can be passed to execute_concurrent() instead 

63 of `Session` instance. This class implements a small set of methods 

64 that are needed by ``execute_concurrent()``. When 

65 ``execute_concurrent()`` is fixed to accept exectution profiles, this 

66 wrapper can be dropped. 

67 """ 

68 

69 def __init__(self, session: Session, execution_profile: Any = EXEC_PROFILE_DEFAULT): 

70 self._session = session 

71 self._execution_profile = execution_profile 

72 

73 def execute_async( 

74 self, 

75 *args: Any, 

76 execution_profile: Any = EXEC_PROFILE_DEFAULT, 

77 **kwargs: Any, 

78 ) -> Any: 

79 # explicit parameter can override our settings 

80 if execution_profile is EXEC_PROFILE_DEFAULT: 

81 execution_profile = self._execution_profile 

82 return self._session.execute_async(*args, execution_profile=execution_profile, **kwargs) 

83 

84 def submit(self, *args: Any, **kwargs: Any) -> Any: 

85 # internal method 

86 return self._session.submit(*args, **kwargs) 

87 

88 

89class ApdbCassandraTableData(ApdbTableData): 

90 """Implementation of ApdbTableData that wraps Cassandra raw data.""" 

91 

92 def __init__(self, columns: list[str], rows: list[tuple]): 

93 self._columns = columns 

94 self._rows = rows 

95 

96 def column_names(self) -> list[str]: 

97 # docstring inherited 

98 return self._columns 

99 

100 def rows(self) -> Iterable[tuple]: 

101 # docstring inherited 

102 return self._rows 

103 

104 def append(self, other: ApdbCassandraTableData) -> None: 

105 """Extend rows in this table with rows in other table""" 

106 if self._columns != other._columns: 

107 raise ValueError(f"Different columns returned by queries: {self._columns} and {other._columns}") 

108 self._rows.extend(other._rows) 

109 

110 def __iter__(self) -> Iterator[tuple]: 

111 """Make it look like a row iterator, needed for some odd logic.""" 

112 return iter(self._rows) 

113 

114 

115def pandas_dataframe_factory(colnames: list[str], rows: list[tuple]) -> pandas.DataFrame: 

116 """Create pandas DataFrame from Cassandra result set. 

117 

118 Parameters 

119 ---------- 

120 colnames : `list` [ `str` ] 

121 Names of the columns. 

122 rows : `list` of `tuple` 

123 Result rows. 

124 

125 Returns 

126 ------- 

127 catalog : `pandas.DataFrame` 

128 DataFrame with the result set. 

129 

130 Notes 

131 ----- 

132 When using this method as row factory for Cassandra, the resulting 

133 DataFrame should be accessed in a non-standard way using 

134 `ResultSet._current_rows` attribute. 

135 """ 

136 return pandas.DataFrame.from_records(rows, columns=colnames) 

137 

138 

139def raw_data_factory(colnames: list[str], rows: list[tuple]) -> ApdbCassandraTableData: 

140 """Make 2-element tuple containing unmodified data: list of column names 

141 and list of rows. 

142 

143 Parameters 

144 ---------- 

145 colnames : `list` [ `str` ] 

146 Names of the columns. 

147 rows : `list` of `tuple` 

148 Result rows. 

149 

150 Returns 

151 ------- 

152 data : `ApdbCassandraTableData` 

153 Input data wrapped into ApdbCassandraTableData. 

154 

155 Notes 

156 ----- 

157 When using this method as row factory for Cassandra, the resulting 

158 object should be accessed in a non-standard way using 

159 `ResultSet._current_rows` attribute. 

160 """ 

161 return ApdbCassandraTableData(colnames, rows) 

162 

163 

164def select_concurrent( 

165 session: Session, statements: list[tuple], execution_profile: str, concurrency: int 

166) -> pandas.DataFrame | ApdbCassandraTableData | list: 

167 """Execute bunch of queries concurrently and merge their results into 

168 a single result. 

169 

170 Parameters 

171 ---------- 

172 statements : `list` [ `tuple` ] 

173 List of statements and their parameters, passed directly to 

174 ``execute_concurrent()``. 

175 execution_profile : `str` 

176 Execution profile name. 

177 

178 Returns 

179 ------- 

180 result 

181 Combined result of multiple statements, type of the result depends on 

182 specific row factory defined in execution profile. If row factory is 

183 `pandas_dataframe_factory` then pandas DataFrame is created from a 

184 combined result. If row factory is `raw_data_factory` then 

185 `ApdbCassandraTableData` is built from all records. Otherwise a list of 

186 rows is returned, type of each row is determined by the row factory. 

187 

188 Notes 

189 ----- 

190 This method can raise any exception that is raised by one of the provided 

191 statements. 

192 """ 

193 session_wrap = SessionWrapper(session, execution_profile) 

194 results = execute_concurrent( 

195 session_wrap, 

196 statements, 

197 results_generator=True, 

198 raise_on_first_error=False, 

199 concurrency=concurrency, 

200 ) 

201 

202 ep = session.get_execution_profile(execution_profile) 

203 if ep.row_factory is raw_data_factory: 

204 # Collect rows into a single list and build Dataframe out of that 

205 _LOG.debug("making pandas data frame out of rows/columns") 

206 table_data: ApdbCassandraTableData | None = None 

207 for success, result in results: 

208 if success: 

209 data = result._current_rows 

210 assert isinstance(data, ApdbCassandraTableData) 

211 if table_data is None: 

212 table_data = data 

213 else: 

214 table_data.append(data) 

215 else: 

216 _LOG.error("error returned by query: %s", result) 

217 raise result 

218 if table_data is None: 

219 table_data = ApdbCassandraTableData([], []) 

220 return table_data 

221 

222 elif ep.row_factory is pandas_dataframe_factory: 

223 # Merge multiple DataFrames into one 

224 _LOG.debug("making pandas data frame out of set of data frames") 

225 dataframes = [] 

226 for success, result in results: 

227 if success: 

228 dataframes.append(result._current_rows) 

229 else: 

230 _LOG.error("error returned by query: %s", result) 

231 raise result 

232 # concatenate all frames 

233 if len(dataframes) == 1: 

234 catalog = dataframes[0] 

235 else: 

236 catalog = pandas.concat(dataframes) 

237 _LOG.debug("pandas catalog shape: %s", catalog.shape) 

238 return catalog 

239 

240 else: 

241 # Just concatenate all rows into a single collection. 

242 rows = [] 

243 for success, result in results: 

244 if success: 

245 rows.extend(result) 

246 else: 

247 _LOG.error("error returned by query: %s", result) 

248 raise result 

249 _LOG.debug("number of rows: %s", len(rows)) 

250 return rows 

251 

252 

253def literal(v: Any) -> Any: 

254 """Transform object into a value for the query.""" 

255 if v is None: 

256 pass 

257 elif isinstance(v, datetime): 

258 v = int((v - datetime(1970, 1, 1)) / timedelta(seconds=1)) * 1000 

259 elif isinstance(v, (bytes, str, UUID, int)): 

260 pass 

261 else: 

262 try: 

263 if not np.isfinite(v): 

264 v = None 

265 except TypeError: 

266 pass 

267 return v 

268 

269 

270def quote_id(columnName: str) -> str: 

271 """Smart quoting for column names. Lower-case names are not quoted.""" 

272 if not columnName.islower(): 

273 columnName = '"' + columnName + '"' 

274 return columnName