Coverage for python / lsst / dax / apdb / cassandra / sessionFactory.py: 31%

93 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:46 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["SessionContext", "SessionFactory"] 

25 

26import logging 

27from collections.abc import Mapping 

28from contextlib import ExitStack 

29from typing import TYPE_CHECKING, Any 

30 

31# If cassandra-driver is not there the module can still be imported 

32# but ApdbCassandra cannot be instantiated. 

33try: 

34 import cassandra 

35 import cassandra.query 

36 from cassandra.auth import AuthProvider, PlainTextAuthProvider 

37 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, Session 

38 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

39 

40 CASSANDRA_IMPORTED = True 

41except ImportError: 

42 CASSANDRA_IMPORTED = False 

43 

44from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError 

45 

46from ..monitor import MonAgent 

47from ..timer import Timer 

48from .cassandra_utils import pandas_dataframe_factory, raw_data_factory 

49 

50if TYPE_CHECKING: 

51 from .config import ApdbCassandraConfig 

52 

53_LOG = logging.getLogger(__name__) 

54 

55_MON = MonAgent(__name__) 

56 

57 

def _dump_query(rf: Any) -> None:
    """Write the query text of a Cassandra request to the debug log.

    Parameters
    ----------
    rf : `typing.Any`
        Request object carrying a ``query`` attribute.
    """
    query_text = rf.query
    _LOG.debug("Cassandra query: %s", query_text)

61 

62 

if CASSANDRA_IMPORTED:

    class _AddressTranslator(AddressTranslator):
        """Translate internal IP address to external.

        Only used for docker-based setup, not a viable long-term solution.
        """

        def __init__(self, public_ips: tuple[str, ...], private_ips: tuple[str, ...]):
            # Private address is the lookup key; public address is the value
            # handed back to the driver.
            self._map = {private: public for private, public in zip(private_ips, public_ips)}

        def translate(self, private_ip: str) -> str:
            # Fall back to the original address when there is no mapping.
            return self._map.get(private_ip, private_ip)

76 

77 

class SessionFactory:
    """Implementation of SessionFactory that uses parameters from Apdb
    configuration.

    The cluster connection is created lazily on the first call to
    `session` and is shut down when this object is garbage-collected.

    Parameters
    ----------
    config : `ApdbCassandraConfig`
        Configuration object.
    """

    def __init__(self, config: ApdbCassandraConfig):
        self._config = config
        # Created lazily by session(); None until the first connection.
        self._cluster: Cluster | None = None
        self._session: Session | None = None

    def __del__(self) -> None:
        # Need to call Cluster.shutdown() to avoid warnings. hasattr guards
        # against __init__ having failed before _cluster was assigned.
        # NOTE(review): __del__ is best-effort only; interpreter shutdown
        # order is not guaranteed.
        if hasattr(self, "_cluster"):
            if self._cluster:
                self._cluster.shutdown()

    def session(self) -> Session:
        """Return Cassandra Session, making new connection if necessary.

        Returns
        -------
        session : `cassandra.cluster.Session`
            Cassandra session object.
        """
        if self._session is None:
            self._cluster, self._session = self._make_session()
        return self._session

    def _make_session(self) -> tuple[Cluster, Session]:
        """Make Cassandra session.

        Returns
        -------
        cluster : `cassandra.cluster.Cluster`
            Cassandra Cluster object.
        session : `cassandra.cluster.Session`
            Cassandra session object.
        """
        addressTranslator: AddressTranslator | None = None
        if self._config.connection_config.private_ips:
            # Docker-based setup: map private IPs reported by the cluster
            # back to the public contact points.
            addressTranslator = _AddressTranslator(
                self._config.contact_points, self._config.connection_config.private_ips
            )

        with Timer("cluster_connect", _MON):
            cluster = Cluster(
                execution_profiles=self._make_profiles(),
                contact_points=self._config.contact_points,
                port=self._config.connection_config.port,
                address_translator=addressTranslator,
                protocol_version=self._config.connection_config.protocol_version,
                auth_provider=self._make_auth_provider(),
                **self._config.connection_config.extra_parameters,
            )
            session = cluster.connect()

        # Dump queries if debug level is enabled.
        if _LOG.isEnabledFor(logging.DEBUG):
            session.add_request_init_listener(_dump_query)

        # Disable result paging
        session.default_fetch_size = None

        return cluster, session

    def _make_auth_provider(self) -> AuthProvider | None:
        """Make Cassandra authentication provider instance.

        Returns
        -------
        provider : `cassandra.auth.AuthProvider` or `None`
            Provider built from the db-auth credentials file, or `None`
            when anonymous login should be used.
        """
        try:
            dbauth = DbAuth()
        except DbAuthNotFoundError:
            # Credentials file doesn't exist, use anonymous login.
            return None

        # Fixed: this flag was initialized to True, which made the warning
        # below fire even when the credentials file contained no matching
        # entry for any contact point. It must only be set when an entry
        # with a password but no user name was actually found.
        empty_username = False
        # Try every contact point in turn.
        for hostname in self._config.contact_points:
            try:
                username, password = dbauth.getAuth(
                    "cassandra",
                    self._config.connection_config.username,
                    hostname,
                    self._config.connection_config.port,
                    self._config.keyspace,
                )
                if not username:
                    # Password without user name, try next hostname, but give
                    # warning later if no better match is found.
                    empty_username = True
                else:
                    return PlainTextAuthProvider(username=username, password=password)
            except DbAuthNotFoundError:
                pass

        if empty_username:
            # Lazy %-style args so the message is only formatted when the
            # warning is actually emitted.
            _LOG.warning(
                "Credentials file (%s) provided password but not "
                "user name, anonymous Cassandra logon will be attempted.",
                dbauth.db_auth_path,
            )

        return None

    def _make_profiles(self) -> Mapping[Any, ExecutionProfile]:
        """Make all execution profiles used in the code.

        Returns
        -------
        profiles : `collections.abc.Mapping`
            Mapping from profile name (or ``EXEC_PROFILE_DEFAULT``) to
            `cassandra.cluster.ExecutionProfile`.
        """
        config = self._config
        if config.connection_config.private_ips:
            loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
        else:
            loadBalancePolicy = RoundRobinPolicy()

        # All read profiles share consistency level and timeout; hoist them
        # out of the repeated ExecutionProfile constructions.
        read_consistency = getattr(
            cassandra.ConsistencyLevel, config.connection_config.read_consistency
        )
        read_timeout = config.connection_config.read_timeout

        def _read_profile(row_factory: Any) -> ExecutionProfile:
            # Read profiles differ only in their row factory.
            return ExecutionProfile(
                consistency_level=read_consistency,
                request_timeout=read_timeout,
                row_factory=row_factory,
                load_balancing_policy=loadBalancePolicy,
            )

        write_profile = ExecutionProfile(
            consistency_level=getattr(
                cassandra.ConsistencyLevel, config.connection_config.write_consistency
            ),
            request_timeout=config.connection_config.write_timeout,
            load_balancing_policy=loadBalancePolicy,
        )
        # To replace default DCAwareRoundRobinPolicy
        default_profile = ExecutionProfile(
            load_balancing_policy=loadBalancePolicy,
        )
        return {
            "read_tuples": _read_profile(cassandra.query.tuple_factory),
            "read_pandas": _read_profile(pandas_dataframe_factory),
            "read_raw": _read_profile(raw_data_factory),
            # Profiles to use with select_concurrent returning pandas data
            # frame or raw data (columns and rows).
            "read_pandas_multi": _read_profile(pandas_dataframe_factory),
            "read_raw_multi": _read_profile(raw_data_factory),
            "write": write_profile,
            EXEC_PROFILE_DEFAULT: default_profile,
        }

243 

244 

class SessionContext(ExitStack):
    """Context manager for creating short-lived Cassandra sessions.

    Both the cluster and the session created on entry are registered with
    the exit stack, so they are shut down when the context exits.

    Parameters
    ----------
    config : `ApdbCassandraConfig`
        Configuration object.
    """

    def __init__(self, config: ApdbCassandraConfig):
        super().__init__()
        self._session_factory = SessionFactory(config)

    def __enter__(self) -> Session:
        super().__enter__()
        new_cluster, new_session = self._session_factory._make_session()
        # Register both resources for cleanup on exit, cluster first.
        for resource in (new_cluster, new_session):
            self.enter_context(resource)
        return new_session