Coverage for python/lsst/dax/apdb/cassandra_utils.py: 20%

103 statements  

coverage.py v6.5.0, created at 2023-03-31 02:59 -0700

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "literal",
    "pandas_dataframe_factory",
    "quote_id",
    "raw_data_factory",
    "select_concurrent",
]

import logging
from collections.abc import Iterable, Iterator
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID

import numpy as np
import pandas

# If cassandra-driver is not installed the module can still be imported,
# but things will not work.
try:
    from cassandra.cluster import EXEC_PROFILE_DEFAULT, Session
    from cassandra.concurrent import execute_concurrent

    CASSANDRA_IMPORTED = True
except ImportError:
    CASSANDRA_IMPORTED = False

from .apdb import ApdbTableData

_LOG = logging.getLogger(__name__)


if CASSANDRA_IMPORTED:


    class SessionWrapper:
        """Special wrapper class to work around an ``execute_concurrent()``
        limitation: it does not allow a non-default execution profile.

        An instance of this class can be passed to ``execute_concurrent()``
        instead of a `Session` instance. This class implements the small set
        of methods that are needed by ``execute_concurrent()``. When
        ``execute_concurrent()`` is fixed to accept execution profiles, this
        wrapper can be dropped.
        """


        def __init__(self, session: Session, execution_profile: Any = EXEC_PROFILE_DEFAULT):
            self._session = session
            self._execution_profile = execution_profile

        def execute_async(
            self,
            *args: Any,
            execution_profile: Any = EXEC_PROFILE_DEFAULT,
            **kwargs: Any,
        ) -> Any:
            # An explicit parameter can override our settings.
            if execution_profile is EXEC_PROFILE_DEFAULT:
                execution_profile = self._execution_profile
            return self._session.execute_async(*args, execution_profile=execution_profile, **kwargs)

        def submit(self, *args: Any, **kwargs: Any) -> Any:
            # Internal method needed by ``execute_concurrent()``.
            return self._session.submit(*args, **kwargs)

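    # Illustrative sketch (not part of the original module): the wrapper is
    # passed to ``execute_concurrent()`` in place of the session so that a
    # named profile is used; the "read_profile" name and ``statements`` are
    # assumptions for the example.
    #
    #     wrapped = SessionWrapper(session, "read_profile")
    #     results = execute_concurrent(wrapped, statements, concurrency=100)
    #
    # ``execute_concurrent()`` only uses ``execute_async()`` and ``submit()``
    # from the object it is given, which is why this thin wrapper suffices.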


class ApdbCassandraTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps Cassandra raw data."""

    def __init__(self, columns: list[str], rows: list[tuple]):
        self._columns = columns
        self._rows = rows

    def column_names(self) -> list[str]:
        # docstring inherited
        return self._columns

    def rows(self) -> Iterable[tuple]:
        # docstring inherited
        return self._rows

    def append(self, other: ApdbCassandraTableData) -> None:
        """Extend rows in this table with rows from another table."""
        if self._columns != other._columns:
            raise ValueError(f"Different columns returned by queries: {self._columns} and {other._columns}")
        self._rows.extend(other._rows)

    def __iter__(self) -> Iterator[tuple]:
        """Make it look like a row iterator, needed for some odd logic."""
        return iter(self._rows)

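# Illustrative sketch (not part of the original module): merging partial
# results the way ``select_concurrent()`` does for the raw-data case.
#
#     part1 = ApdbCassandraTableData(["id", "flux"], [(1, 0.5)])
#     part2 = ApdbCassandraTableData(["id", "flux"], [(2, 0.7)])
#     part1.append(part2)  # part1.rows() now yields both tuples
#     part1.append(ApdbCassandraTableData(["id"], []))  # raises ValueError
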

def pandas_dataframe_factory(colnames: list[str], rows: list[tuple]) -> pandas.DataFrame:
    """Special non-standard row factory that creates a pandas DataFrame
    from a Cassandra result set.

    Parameters
    ----------
    colnames : `list` [ `str` ]
        Names of the columns.
    rows : `list` of `tuple`
        Result rows.

    Returns
    -------
    catalog : `pandas.DataFrame`
        DataFrame with the result set.

    Notes
    -----
    When using this method as a row factory for Cassandra, the resulting
    DataFrame should be accessed in a non-standard way using the
    `ResultSet._current_rows` attribute.
    """
    return pandas.DataFrame.from_records(rows, columns=colnames)

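# Illustrative sketch (not part of the original module): a row factory is
# normally installed on an execution profile when the cluster is created;
# the "read_pandas" profile name is an assumption for the example.
#
#     from cassandra.cluster import Cluster, ExecutionProfile
#
#     profile = ExecutionProfile(row_factory=pandas_dataframe_factory)
#     cluster = Cluster(execution_profiles={"read_pandas": profile})
#     session = cluster.connect()
#     result = session.execute("SELECT ...", execution_profile="read_pandas")
#     catalog = result._current_rows  # pandas.DataFrame, see Notes above
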

def raw_data_factory(colnames: list[str], rows: list[tuple]) -> ApdbCassandraTableData:
    """Special non-standard row factory that wraps the unmodified data
    (a list of column names and a list of rows) into an
    `ApdbCassandraTableData` instance.

    Parameters
    ----------
    colnames : `list` [ `str` ]
        Names of the columns.
    rows : `list` of `tuple`
        Result rows.

    Returns
    -------
    data : `ApdbCassandraTableData`
        Input data wrapped into ApdbCassandraTableData.

    Notes
    -----
    When using this method as a row factory for Cassandra, the resulting
    object should be accessed in a non-standard way using the
    `ResultSet._current_rows` attribute.
    """
    return ApdbCassandraTableData(colnames, rows)

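# Illustrative sketch (not part of the original module): what the factory
# produces when the driver calls it with column names and result rows.
#
#     data = raw_data_factory(["id", "ra", "dec"], [(1, 10.0, -5.0)])
#     data.column_names()  # -> ["id", "ra", "dec"]
#     list(data.rows())    # -> [(1, 10.0, -5.0)]
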

def select_concurrent(
    session: Session, statements: list[tuple], execution_profile: str, concurrency: int
) -> pandas.DataFrame | ApdbCassandraTableData | list:
    """Execute a bunch of queries concurrently and merge their results into
    a single result.

    Parameters
    ----------
    session : `cassandra.cluster.Session`
        Cassandra session to use for queries.
    statements : `list` [ `tuple` ]
        List of statements and their parameters, passed directly to
        ``execute_concurrent()``.
    execution_profile : `str`
        Execution profile name.
    concurrency : `int`
        Maximum number of statements executed concurrently.

    Returns
    -------
    result
        Combined result of multiple statements; the type of the result
        depends on the specific row factory defined in the execution
        profile. If the row factory is `pandas_dataframe_factory` then a
        pandas DataFrame is created from the combined result. If the row
        factory is `raw_data_factory` then `ApdbCassandraTableData` is built
        from all records. Otherwise a list of rows is returned, with the
        type of each row determined by the row factory.

    Notes
    -----
    This method can raise any exception that is raised by one of the
    provided statements.
    """
    session_wrap = SessionWrapper(session, execution_profile)
    results = execute_concurrent(
        session_wrap,
        statements,
        results_generator=True,
        raise_on_first_error=False,
        concurrency=concurrency,
    )

    ep = session.get_execution_profile(execution_profile)
    if ep.row_factory is raw_data_factory:
        # Merge tables returned by all statements into a single table.
        _LOG.debug("making ApdbCassandraTableData out of rows/columns")
        table_data: ApdbCassandraTableData | None = None
        for success, result in results:
            if success:
                data = result._current_rows
                assert isinstance(data, ApdbCassandraTableData)
                if table_data is None:
                    table_data = data
                else:
                    table_data.append(data)
            else:
                _LOG.error("error returned by query: %s", result)
                raise result
        if table_data is None:
            table_data = ApdbCassandraTableData([], [])
        return table_data

    elif ep.row_factory is pandas_dataframe_factory:
        # Merge multiple DataFrames into one.
        _LOG.debug("making pandas data frame out of set of data frames")
        dataframes = []
        for success, result in results:
            if success:
                dataframes.append(result._current_rows)
            else:
                _LOG.error("error returned by query: %s", result)
                raise result
        # Concatenate all frames.
        if len(dataframes) == 1:
            catalog = dataframes[0]
        else:
            catalog = pandas.concat(dataframes)
        _LOG.debug("pandas catalog shape: %s", catalog.shape)
        return catalog

    else:
        # Just concatenate all rows into a single collection.
        rows = []
        for success, result in results:
            if success:
                rows.extend(result)
            else:
                _LOG.error("error returned by query: %s", result)
                raise result
        _LOG.debug("number of rows: %s", len(rows))
        return rows

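# Illustrative sketch (not part of the original module): querying many
# partitions concurrently.  The table name, query, and profile name are
# assumptions for the example.
#
#     prepared = session.prepare('SELECT * FROM "DiaSource" WHERE apdb_part = ?')
#     statements = [(prepared, (part,)) for part in range(64)]
#     catalog = select_concurrent(session, statements, "read_pandas", 100)
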

def literal(v: Any) -> Any:
    """Transform an object into a literal value for a query."""
    if v is None:
        pass
    elif isinstance(v, datetime):
        # Convert datetime to milliseconds since the Unix epoch
        # (truncated to whole seconds).
        v = int((v - datetime(1970, 1, 1)) / timedelta(seconds=1)) * 1000
    elif isinstance(v, (bytes, str, UUID, int)):
        pass
    else:
        try:
            # Non-finite numbers (NaN, +/-inf) are replaced with None.
            if not np.isfinite(v):
                v = None
        except TypeError:
            pass
    return v

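# Illustrative examples (not part of the original module) of ``literal()``:
#
#     literal(datetime(1970, 1, 1, 0, 0, 1))  # -> 1000 (milliseconds)
#     literal(float("nan"))                   # -> None (non-finite)
#     literal("text")                         # -> "text" (unchanged)
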

def quote_id(columnName: str) -> str:
    """Smart quoting for column names. Lower-case names are not quoted."""
    if not columnName.islower():
        columnName = '"' + columnName + '"'
    return columnName
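
# Illustrative examples (not part of the original module) of ``quote_id()``:
#
#     quote_id("ssobjectid")   # -> 'ssobjectid' (all lower case, not quoted)
#     quote_id("diaSourceId")  # -> '"diaSourceId"' (mixed case, quoted)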