Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import numpy 

2from lsst.utils import getPackageDir 

3import os 

4from sqlalchemy.orm import scoped_session, sessionmaker 

5from sqlalchemy.engine import reflection, url 

6from sqlalchemy import (create_engine, MetaData, event, inspect) 

7import warnings 

8from io import BytesIO 

# Alias for the built-in str type; used when turning names and type
# information read back from the database into text.
str_cast = str

10 

11 

# DBObject is the only name this module exports on `from ... import *`.
__all__ = ['DBObject']

13 

14 

def valueOfPi():
    """
    Return the value of pi.

    SQLite does not provide a PI() SQL function, so this callable is
    registered on sqlite connections to stand in for it.
    """
    pi_value = numpy.pi
    return pi_value

21 

22 

def _sqlite_safe(func):
    """
    Wrap a numpy function so it can be registered as a SQLite user function.

    The returned wrapper
      - follows SQL NULL semantics: if any argument is None (SQL NULL) the
        result is None instead of an exception raised inside numpy;
      - converts numpy scalar results (numpy.float64, numpy.int64, ...) to
        the equivalent built-in Python value via .item(), because sqlite3
        only handles built-in result types (an integer POWER(2, 3) would
        otherwise come back as numpy.int64).
    """
    def wrapper(*args):
        if any(arg is None for arg in args):
            return None
        result = func(*args)
        if isinstance(result, numpy.generic):
            return result.item()
        return result
    return wrapper


def declareTrigFunctions(conn, connection_rec, connection_proxy):
    """
    A database event listener
    which will define the math functions necessary for evaluating the
    Haversine function in sqlite databases (where they are not otherwise
    defined)

    see: http://docs.sqlalchemy.org/en/latest/core/events.html

    @param [in] conn is the DBAPI connection the functions are registered on
    (it must provide create_function, as sqlite3 connections do)

    @param [in] connection_rec and connection_proxy are supplied by the
    sqlalchemy event system and are not used here
    """
    # (SQL name, number of arguments, numpy implementation)
    numpy_functions = (("COS", 1, numpy.cos),
                       ("SIN", 1, numpy.sin),
                       ("ASIN", 1, numpy.arcsin),
                       ("SQRT", 1, numpy.sqrt),
                       ("POWER", 2, numpy.power))

    for sql_name, n_args, func in numpy_functions:
        conn.create_function(sql_name, n_args, _sqlite_safe(func))

    # valueOfPi takes no arguments and already returns a built-in float.
    conn.create_function("PI", 0, valueOfPi)

39 

40 

class ChunkIterator(object):
    """
    Iterate over the results of a query, one post-processed chunk at a time.

    Each step fetches up to chunk_size rows (or all remaining rows when
    chunk_size is None) and hands them to a post-processing hook of the
    owning object. Iteration ends as soon as a fetch returns no rows.
    """

    def __init__(self, dbobj, query, chunk_size, arbitrarySQL=False):
        """
        @param [in] dbobj is the object owning the database connection; it
        also supplies the post-processing hooks

        @param [in] query is executed immediately, at construction time

        @param [in] chunk_size is the number of rows per chunk, or None to
        return every row in a single chunk

        @param [in] arbitrarySQL chooses which post-processing hook is used
        """
        self.dbobj = dbobj
        self.exec_query = dbobj.connection.session.execute(query)
        self.chunk_size = chunk_size

        # Set when a CatalogDBObject builds this iterator through
        # get_arbitrary_chunk_iterator: chunks must then be routed to
        # _postprocess_arbitrary_results instead of _postprocess_results.
        self.arbitrarySQL = arbitrarySQL

    def __iter__(self):
        return self

    def __next__(self):
        if self.chunk_size is not None:
            rows = self.exec_query.fetchmany(self.chunk_size)
        elif not self.exec_query.closed:
            # No chunking requested: everything comes back in one chunk.
            rows = self.exec_query.fetchall()
        else:
            raise StopIteration
        return self._postprocess_results(rows)

    def _postprocess_results(self, chunk):
        # An empty fetch means the result set is exhausted.
        if len(chunk) == 0:
            raise StopIteration
        postprocess = (self.dbobj._postprocess_arbitrary_results if self.arbitrarySQL
                       else self.dbobj._postprocess_results)
        return postprocess(chunk)

74 

75 

class DBConnection(object):
    """
    This is a class that will hold the engine, session, and metadata for a
    DBObject. This will allow multiple DBObjects to share the same
    sqlalchemy connection, when appropriate.

    Two DBConnections compare (and hash) equal when their database, driver,
    host and port are equal after conversion to strings.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False):
        """
        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity
        """

        self._database = database
        self._driver = driver
        self._host = host
        self._port = port
        self._verbose = verbose

        self._validate_conn_params()
        self._connect_to_engine()

    def __del__(self):
        # Drop references to the sqlalchemy objects. Any of them may be
        # missing if __init__ failed part-way through, hence the guards.
        for attr_name in ('_metadata', '_engine', '_session'):
            try:
                delattr(self, attr_name)
            except AttributeError:
                pass

    def _connect_to_engine(self):
        """
        Create the sqlalchemy engine, scoped session and metadata.

        Only the driver and database name are put into the URL; host, port
        and credentials are not used, so this assumes a local database.
        """
        # dbAuth handling was removed: assume we are only connecting to a
        # local database.
        if hasattr(url.URL, 'create'):
            # sqlalchemy releases that provide URL.create expect URLs to be
            # built through it.
            dbUrl = url.URL.create(self._driver, database=self._database)
        else:
            # Older sqlalchemy releases: construct the URL directly.
            dbUrl = url.URL(self._driver, database=self._database)

        self._engine = create_engine(dbUrl, echo=self._verbose)

        if self._engine.dialect.name == 'sqlite':
            # sqlite lacks the math functions needed for the Haversine
            # formula; register them whenever a connection is checked out.
            event.listen(self._engine, 'checkout', declareTrigFunctions)

        self._session = scoped_session(sessionmaker(autoflush=True,
                                                    bind=self._engine))
        # NOTE(review): MetaData(bind=...) is deprecated in newer sqlalchemy
        # releases; revisit when sqlalchemy is upgraded.
        self._metadata = MetaData(bind=self._engine)

    def _validate_conn_params(self):
        """Validate connection parameters

        - Check if user passed dbAddress instead of an database. Convert and warn.
        - Check that required connection parameters are present
        - Replace default host/port if driver is 'sqlite'

        @raises AttributeError if the database or the driver is missing
        """

        if self._database is None:
            raise AttributeError("Cannot instantiate DBConnection; database is 'None'")

        if '//' in self._database:
            warnings.warn("Database name '%s' is invalid but looks like a dbAddress. "
                          "Attempting to convert to database, driver, host, "
                          "and port parameters. Any usernames and passwords are ignored and must "
                          "be in the db-auth.paf policy file. " % (self.database), FutureWarning)

            # Split the URL-like string into its parts; every non-None part
            # is stored as a private attribute (e.g. _database, _host, _port).
            dbUrl = url.make_url(self._database)
            dialect = dbUrl.get_dialect()
            self._driver = dialect.name + '+' + dialect.driver if dialect.driver else dialect.name
            for key, value in dbUrl.translate_connect_args().items():
                if value is not None:
                    setattr(self, '_'+key, value)

        errMessage = "Please supply a 'driver' kwarg to the constructor or in class definition. "
        errMessage += "'driver' is formatted as dialect+driver, such as 'sqlite' or 'mssql+pymssql'."
        if not hasattr(self, '_driver'):
            raise AttributeError("%s has no attribute 'driver'. " % (self.__class__.__name__) + errMessage)
        elif self._driver is None:
            raise AttributeError("%s.driver is None. " % (self.__class__.__name__) + errMessage)

        errMessage = "Please supply a 'database' kwarg to the constructor or in class definition. "
        errMessage += " 'database' is the database name or the filename path if driver is 'sqlite'. "
        if not hasattr(self, '_database'):
            raise AttributeError("%s has no attribute 'database'. " % (self.__class__.__name__) + errMessage)
        elif self._database is None:
            raise AttributeError("%s.database is None. " % (self.__class__.__name__) + errMessage)

        if 'sqlite' in self._driver:
            # When passed sqlite database, override default host/port
            self._host = None
            self._port = None

    @staticmethod
    def _identity_key(conn):
        """Return the (database, driver, host, port) strings that define equality."""
        return (str(conn._database), str(conn._driver), str(conn._host), str(conn._port))

    def __eq__(self, other):
        try:
            other_key = self._identity_key(other)
        except AttributeError:
            # Not connection-like: let Python fall back to its default
            # comparison instead of raising.
            return NotImplemented
        return self._identity_key(self) == other_key

    def __hash__(self):
        # Must agree with __eq__ so equal connections hash identically.
        return hash(self._identity_key(self))

    @property
    def engine(self):
        return self._engine

    @property
    def session(self):
        return self._session

    @property
    def metadata(self):
        return self._metadata

    @property
    def database(self):
        return self._database

    @property
    def driver(self):
        return self._driver

    @property
    def host(self):
        return self._host

    @property
    def port(self):
        return self._port

    @property
    def verbose(self):
        return self._verbose

219 

220 

class DBObject(object):
    """
    Run read-only SQL queries against a database and return numpy recarrays.

    The connection is held in self.connection (a DBConnection). When no
    dtype is supplied for a query, one is inferred from the first returned
    row and cached in self.dtype for the remainder of that query.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 connection=None, cache_connection=True):
        """
        Initialize DBObject.

        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity (default False)

        @param [in] connection is an optional instance of DBConnection, in the event that
        this DBObject can share a database connection with another DBObject. This is only
        necessary or even possible in a few specialized cases and should be used carefully.

        @param [in] cache_connection is a boolean. If True, DBObject will use a cache of
        DBConnections (if available) to get the connection to this database.
        """

        # Cache for the dtype of the current query, so that any one query
        # does not have to guess the dtype multiple times.
        self.dtype = None

        if connection is None:
            # Explicit constructor arguments are preferred. An argument left
            # as None does not overwrite an attribute of the same name that
            # already exists (e.g. one set in a class definition).
            kwargDict = dict(database=database,
                             driver=driver,
                             host=host,
                             port=port,
                             verbose=verbose)

            for key, value in kwargDict.items():
                if value is not None or not hasattr(self, key):
                    setattr(self, key, value)

            self.connection = self._get_connection(self.database, self.driver, self.host, self.port,
                                                   use_cache=cache_connection)

        else:
            self.connection = connection
            self.database = connection.database
            self.driver = connection.driver
            self.host = connection.host
            self.port = connection.port
            self.verbose = connection.verbose

    def _get_connection(self, database, driver, host, port, use_cache=True):
        """
        Search self._connection_cache (if it exists; it won't for DBObject, but
        will for CatalogDBObject) for a DBConnection matching the specified
        parameters. If it exists, return it. If not, open a connection to
        the specified database, add it to the cache, and return the connection.

        A newly opened connection uses this object's ``verbose`` attribute
        (False if it has none). Cache matching ignores verbosity.

        Parameters
        ----------
        database is the name of the database file being connected to

        driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        host is the URL of the remote host, if appropriate

        port is the port on the remote host to connect to, if appropriate

        use_cache is a boolean specifying whether or not we try to use the
        cache of database connections (you don't want to if opening many
        connections in many threads).
        """
        use_cache = use_cache and hasattr(self, '_connection_cache')

        if use_cache:
            for conn in self._connection_cache:
                if (str(conn.database) == str(database) and
                        str(conn.driver) == str(driver) and
                        str(conn.host) == str(host) and
                        str(conn.port) == str(port)):
                    return conn

        # Forward verbosity so the 'verbose' constructor argument takes effect.
        conn = DBConnection(database=database, driver=driver, host=host, port=port,
                            verbose=getattr(self, 'verbose', False))

        if use_cache:
            self._connection_cache.append(conn)

        return conn

    def get_table_names(self):
        """Return a list of the names of the tables (and views) in the database"""
        inspector = inspect(self.connection.engine)
        return [str(xx) for xx in inspector.get_table_names()] + \
               [str(xx) for xx in inspector.get_view_names()]

    def get_column_names(self, tableName=None):
        """
        Return a list of the names of the columns in the specified table.
        If no table is specified, return a dict of lists. The dict will be keyed
        to the table names. The lists will be of the column names in that table.
        An unknown tableName gives an empty list.
        """
        tableNameList = self.get_table_names()
        inspector = inspect(self.connection.engine)
        if tableName is not None:
            if tableName not in tableNameList:
                return []
            return [str(xx['name']) for xx in inspector.get_columns(tableName)]

        return {name: [str(xx['name']) for xx in inspector.get_columns(name)]
                for name in tableNameList}

    def _final_pass(self, results):
        """ Make final modifications to a set of data before returning it to the user

        **Parameters**

        * results : a structured array constructed from the result set from a query

        **Returns**

        * results : a potentially modified structured array. The default is to do nothing.

        """
        return results

    @staticmethod
    def _infer_dtype_from_row(row):
        """
        Guess a numpy dtype for a query's results from a single result row.

        ``row`` must iterate over its values and provide keys() returning
        the column names. Byte-string columns are mapped to str columns of
        the same width.

        @raises RuntimeError if no usable delimiter character can be found
        """
        # The dtype is detected by letting numpy.genfromtxt parse the row as
        # one line of delimited text, so pick a delimiter that occurs in
        # none of the row's values.
        values = [str(xx) for xx in row]
        delimit_char = next((cc for cc in (',', ';', '|', ':', '/', '\\')
                             if not any(cc in value for value in values)), None)

        if delimit_char is None:
            raise RuntimeError("DBObject could not detect the dtype of your return rows\n"
                               "Please specify a dtype with the 'dtype' kwarg.")

        # join() always places a delimiter between values, even when a value
        # stringifies to '' (the old concatenation dropped it in that case).
        dataString = delimit_char.join(values)
        names = [str(ww) for ww in row.keys()]
        dataArr = numpy.genfromtxt(BytesIO(dataString.encode()), dtype=None,
                                   names=names, delimiter=delimit_char,
                                   encoding='utf-8')

        dt_list = []
        for name in dataArr.dtype.names:
            type_name = str(dataArr.dtype[name])
            if type_name.startswith(('S', '|S')):
                dt_list.append((name, str, int(type_name.replace('S', '').replace('|', ''))))
            else:
                dt_list.append((name, dataArr.dtype[name]))

        return numpy.dtype(dt_list)

    def _convert_results_to_numpy_recarray_dbobj(self, results):
        """
        Convert a sequence of result rows into a numpy recarray.

        If self.dtype is None it is inferred from the first row and stored in
        self.dtype, so later chunks of the same query reuse it. An empty
        result with no known dtype gives an empty recarray with no fields.
        """
        if len(results) == 0:
            if self.dtype is None:
                # Nothing to infer a dtype from (indexing results[0] would fail).
                return numpy.recarray((0,), dtype=[])
            return numpy.recarray((0,), dtype=self.dtype)

        if self.dtype is None:
            self.dtype = self._infer_dtype_from_row(results[0])

        return numpy.rec.fromrecords([tuple(xx) for xx in results], dtype=self.dtype)

    def _postprocess_results(self, results):
        """
        This wrapper exists so that a ChunkIterator built from a DBObject
        can have the same API as a ChunkIterator built from a CatalogDBObject
        """
        return self._postprocess_arbitrary_results(results)

    def _postprocess_arbitrary_results(self, results):
        """Convert raw rows to a recarray (unless they already are one), then apply _final_pass."""
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_dbobj(results)
        else:
            retresults = results

        return self._final_pass(retresults)

    def execute_arbitrary(self, query, dtype=None):
        """
        Executes an arbitrary query. Returns a recarray of the results.

        dtype will be the dtype of the output recarray. If it is None, then
        the code will guess the datatype from the first returned row; column
        names are taken from the row's keys.

        @raises RuntimeError if query is not a string, or if it contains one of
        the whole words delete, drop, insert or update (case-insensitive)
        """
        import re

        if not isinstance(query, str):
            raise RuntimeError("DBObject execute must be called with a string query")

        unacceptableCommands = ["delete", "drop", "insert", "update"]
        for badCommand in unacceptableCommands:
            # Match whole words only, so identifiers that merely contain a
            # keyword (e.g. columns named 'deleted' or 'updated_at') are not
            # rejected.
            if re.search(r'\b%s\b' % badCommand, query, flags=re.IGNORECASE):
                raise RuntimeError("query made to DBObject execute contained %s " % badCommand)

        self.dtype = dtype
        rows = self.connection.session.execute(query).fetchall()
        return self._postprocess_arbitrary_results(rows)

    def get_arbitrary_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        This wrapper exists so that CatalogDBObjects can refer to
        get_arbitrary_chunk_iterator and DBObjects can refer to
        get_chunk_iterator
        """
        return self.get_chunk_iterator(query, chunk_size=chunk_size, dtype=dtype)

    def get_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        Take an arbitrary, user-specified query and return a ChunkIterator that
        executes that query

        dtype will tell the ChunkIterator what datatype to expect for this query.
        This information gets passed to _postprocess_results.

        If 'None', then the datatype is guessed from the first row of the
        first chunk and column names are taken from the row's keys.
        """
        self.dtype = dtype
        return ChunkIterator(self, query, chunk_size, arbitrarySQL=True)