Coverage for python/lsst/sims/maf/db/dbObj.py : 16%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import numpy
2from lsst.utils import getPackageDir
3import os
4from sqlalchemy.orm import scoped_session, sessionmaker
5from sqlalchemy.engine import reflection, url
6from sqlalchemy import (create_engine, MetaData, event, inspect)
7import warnings
8from io import BytesIO
9str_cast = str
12__all__ = ['DBObject']
def valueOfPi():
    """
    Return the value of pi.

    sqlite does not define PI() itself, so this zero-argument callable
    is what gets registered under that name on sqlite connections.
    """
    pi_value = numpy.pi
    return pi_value
def declareTrigFunctions(conn, connection_rec, connection_proxy):
    """
    A database event listener which will define the math functions
    necessary for evaluating the Haversine function in sqlite databases
    (where they are not otherwise defined).

    see: http://docs.sqlalchemy.org/en/latest/core/events.html

    The functions registered are COS, SIN, ASIN, SQRT, POWER and PI.
    Each numeric function follows SQL NULL semantics (a NULL argument
    yields NULL) and always hands a plain Python float back to sqlite.

    @param [in] conn is the DBAPI connection on which create_function
    is called

    @param [in] connection_rec is unused; it is part of the listener signature

    @param [in] connection_proxy is unused; it is part of the listener signature
    """

    def _sql_function(func):
        """Adapt a numpy ufunc so that it is safe to call from sqlite."""
        def wrapper(*args):
            # SQL convention: any NULL input gives a NULL result instead
            # of an exception raised from inside numpy.
            if any(arg is None for arg in args):
                return None
            # Evaluate in floating point so that integer inputs cannot
            # yield numpy integer scalars (which sqlite3 cannot store as
            # numbers) and so that POWER(2, -1) is well defined.
            return float(func(*[float(arg) for arg in args]))
        return wrapper

    conn.create_function("COS", 1, _sql_function(numpy.cos))
    conn.create_function("SIN", 1, _sql_function(numpy.sin))
    conn.create_function("ASIN", 1, _sql_function(numpy.arcsin))
    conn.create_function("SQRT", 1, _sql_function(numpy.sqrt))
    conn.create_function("POWER", 2, _sql_function(numpy.power))
    conn.create_function("PI", 0, valueOfPi)
class ChunkIterator(object):
    """
    Iterator over the results of a query, delivered one chunk at a time.

    The query is executed on the session of dbobj's connection when the
    iterator is constructed.  Every chunk is post-processed by dbobj
    before it is handed to the caller.
    """

    def __init__(self, dbobj, query, chunk_size, arbitrarySQL=False):
        """
        @param [in] dbobj is the object whose connection runs the query and
        whose _postprocess_results / _postprocess_arbitrary_results methods
        are applied to each chunk

        @param [in] query is the query to execute

        @param [in] chunk_size is the number of rows per chunk; if None,
        all of the rows are returned in a single chunk

        @param [in] arbitrarySQL is a boolean.  It exists in case a
        CatalogDBObject calls get_arbitrary_chunk_iterator; if True, chunks
        are passed to _postprocess_arbitrary_results rather than to
        _postprocess_results
        """
        self.dbobj = dbobj
        self.exec_query = dbobj.connection.session.execute(query)
        self.chunk_size = chunk_size
        self.arbitrarySQL = arbitrarySQL

    def __iter__(self):
        return self

    def __next__(self):
        if self.chunk_size is not None:
            # Chunked mode: an empty fetch marks the end (see below)
            rows = self.exec_query.fetchmany(self.chunk_size)
        elif not self.exec_query.closed:
            # Unchunked mode: everything comes back in one go
            rows = self.exec_query.fetchall()
        else:
            raise StopIteration
        return self._postprocess_results(rows)

    def _postprocess_results(self, chunk):
        """Post-process one chunk; an empty chunk ends the iteration."""
        if len(chunk) == 0:
            raise StopIteration
        handler = (self.dbobj._postprocess_arbitrary_results
                   if self.arbitrarySQL
                   else self.dbobj._postprocess_results)
        return handler(chunk)
class DBConnection(object):
    """
    This is a class that will hold the engine, session, and metadata for a
    DBObject.  This will allow multiple DBObjects to share the same
    sqlalchemy connection, when appropriate.

    Two DBConnections compare (and hash) equal when the string forms of
    their database, driver, host and port all agree.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False):
        """
        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity
        """

        self._database = database
        self._driver = driver
        self._host = host
        self._port = port
        self._verbose = verbose

        self._validate_conn_params()
        self._connect_to_engine()

    def __del__(self):
        # Drop our references to the sqlalchemy objects.  Any of them may
        # be missing if __init__ raised before the connection was made.
        for attr_name in ('_metadata', '_engine', '_session'):
            try:
                delattr(self, attr_name)
            except AttributeError:
                pass

    def _connect_to_engine(self):
        """Create the sqlalchemy engine, session and metadata."""

        # Remove dbAuth things. Assume we are only connecting to a local database.
        # Newer versions of sqlalchemy build URLs with URL.create(); older
        # versions only offer the plain constructor.  Use whichever exists.
        make_url = getattr(url.URL, 'create', url.URL)
        dbUrl = make_url(self._driver, database=self._database)

        self._engine = create_engine(dbUrl, echo=self._verbose)

        if self._engine.dialect.name == 'sqlite':
            # sqlite lacks the math functions needed by the Haversine
            # formula; define them whenever a connection is checked out.
            event.listen(self._engine, 'checkout', declareTrigFunctions)

        self._session = scoped_session(sessionmaker(autoflush=True,
                                                    bind=self._engine))
        self._metadata = MetaData(bind=self._engine)

    def _validate_conn_params(self):
        """Validate connection parameters

        - Check if user passed dbAddress instead of a database. Convert and warn.
        - Check that required connection parameters are present
        - Replace default host/port if driver is 'sqlite'

        Raises AttributeError if the database or the driver is missing.
        """

        if self._database is None:
            raise AttributeError("Cannot instantiate DBConnection; database is 'None'")

        if '//' in self._database:
            warnings.warn("Database name '%s' is invalid but looks like a dbAddress. "
                          "Attempting to convert to database, driver, host, "
                          "and port parameters. Any usernames and passwords are ignored and must "
                          "be in the db-auth.paf policy file. " % (self.database), FutureWarning)

            dbUrl = url.make_url(self._database)
            dialect = dbUrl.get_dialect()
            self._driver = dialect.name + '+' + dialect.driver if dialect.driver else dialect.name
            # Copy every connection argument present in the address onto
            # the matching underscore-prefixed attribute.
            for key, value in dbUrl.translate_connect_args().items():
                if value is not None:
                    setattr(self, '_'+key, value)

        errMessage = "Please supply a 'driver' kwarg to the constructor or in class definition. "
        errMessage += "'driver' is formatted as dialect+driver, such as 'sqlite' or 'mssql+pymssql'."
        if not hasattr(self, '_driver'):
            raise AttributeError("%s has no attribute 'driver'. " % (self.__class__.__name__) + errMessage)
        elif self._driver is None:
            raise AttributeError("%s.driver is None. " % (self.__class__.__name__) + errMessage)

        errMessage = "Please supply a 'database' kwarg to the constructor or in class definition. "
        errMessage += " 'database' is the database name or the filename path if driver is 'sqlite'. "
        if not hasattr(self, '_database'):
            raise AttributeError("%s has no attribute 'database'. " % (self.__class__.__name__) + errMessage)
        elif self._database is None:
            raise AttributeError("%s.database is None. " % (self.__class__.__name__) + errMessage)

        if 'sqlite' in self._driver:
            # When passed sqlite database, override default host/port
            self._host = None
            self._port = None

    def _identity(self):
        """Return the tuple of strings that identifies this connection."""
        return (str(self._database), str(self._driver),
                str(self._host), str(self._port))

    def __eq__(self, other):
        try:
            other_identity = (str(other._database), str(other._driver),
                              str(other._host), str(other._port))
        except AttributeError:
            # Not a connection-like object: let Python fall back to its
            # default comparison instead of raising.
            return NotImplemented
        return self._identity() == other_identity

    def __hash__(self):
        # Defined alongside __eq__ so that equal connections hash equally
        # and instances remain usable in sets and as dict keys.
        return hash(self._identity())

    @property
    def engine(self):
        return self._engine

    @property
    def session(self):
        return self._session

    @property
    def metadata(self):
        return self._metadata

    @property
    def database(self):
        return self._database

    @property
    def driver(self):
        return self._driver

    @property
    def host(self):
        return self._host

    @property
    def port(self):
        return self._port

    @property
    def verbose(self):
        return self._verbose
class DBObject(object):
    """
    A thin wrapper around a DBConnection which runs read-only, arbitrary
    SQL queries and returns the results as numpy recarrays.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 connection=None, cache_connection=True):
        """
        Initialize DBObject.

        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity (default False)

        @param [in] connection is an optional instance of DBConnection, in the event that
        this DBObject can share a database connection with another DBObject.  This is only
        necessary or even possible in a few specialized cases and should be used carefully.

        @param [in] cache_connection is a boolean.  If True, DBObject will use a cache of
        DBConnections (if available) to get the connection to this database.
        """

        # this is a cache for the query, so that any one query does not have
        # to guess dtype multiple times
        self.dtype = None

        if connection is None:
            # Explicit constructor to DBObject preferred
            kwargDict = dict(database=database,
                             driver=driver,
                             host=host,
                             port=port,
                             verbose=verbose)

            # An explicit kwarg wins; a None kwarg only fills in attributes
            # that are not already defined (e.g. at class level).
            for key, value in kwargDict.items():
                if value is not None or not hasattr(self, key):
                    setattr(self, key, value)

            self.connection = self._get_connection(self.database, self.driver, self.host, self.port,
                                                   use_cache=cache_connection)

        else:
            self.connection = connection
            self.database = connection.database
            self.driver = connection.driver
            self.host = connection.host
            self.port = connection.port
            self.verbose = connection.verbose

    def _get_connection(self, database, driver, host, port, use_cache=True):
        """
        Search self._connection_cache (if it exists; it won't for DBObject, but
        will for CatalogDBObject) for a DBConnection matching the specified
        parameters.  If it exists, return it.  If not, open a connection to
        the specified database, add it to the cache, and return the connection.

        A newly opened connection takes its verbosity from self.verbose
        (False if that attribute is not set).

        Parameters
        ----------
        database is the name of the database file being connected to

        driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        host is the URL of the remote host, if appropriate

        port is the port on the remote host to connect to, if appropriate

        use_cache is a boolean specifying whether or not we try to use the
        cache of database connections (you don't want to if opening many
        connections in many threads).
        """

        cache = getattr(self, '_connection_cache', None) if use_cache else None

        if cache is not None:
            wanted = (str(database), str(driver), str(host), str(port))
            for conn in cache:
                found = (str(conn.database), str(conn.driver),
                         str(conn.host), str(conn.port))
                if found == wanted:
                    return conn

        conn = DBConnection(database=database, driver=driver, host=host, port=port,
                            verbose=bool(getattr(self, 'verbose', False)))

        if cache is not None:
            cache.append(conn)

        return conn

    def get_table_names(self):
        """Return a list of the names of the tables (and views) in the database"""
        inspector = inspect(self.connection.engine)
        return [str(xx) for xx in inspector.get_table_names()] + \
               [str(xx) for xx in inspector.get_view_names()]

    def get_column_names(self, tableName=None):
        """
        Return a list of the names of the columns in the specified table.
        If no table is specified, return a dict of lists.  The dict will be keyed
        to the table names.  The lists will be of the column names in that table.
        If the specified table does not exist, return an empty list.
        """
        tableNameList = self.get_table_names()
        inspector = inspect(self.connection.engine)
        if tableName is not None:
            if tableName not in tableNameList:
                return []
            return [str_cast(xx['name']) for xx in inspector.get_columns(tableName)]

        return {name: [str_cast(xx['name']) for xx in inspector.get_columns(name)]
                for name in tableNameList}

    def _final_pass(self, results):
        """ Make final modifications to a set of data before returning it to the user

        **Parameters**

            * results : a structured array constructed from the result set from a query

        **Returns**

            * results : a potentially modified structured array.  The default is to do nothing.

        """
        return results

    def _convert_results_to_numpy_recarray_dbobj(self, results):
        """
        Convert a sequence of result rows into a numpy recarray.

        If self.dtype is None, the dtype is detected from the first row and
        stored in self.dtype so that the detection is not repeated on
        every chunk.

        Raises IndexError if results is empty and no dtype is known, and
        RuntimeError if no usable delimiter can be found for the detection.
        """
        if len(results) == 0:
            if self.dtype is None:
                # There is no row to inspect, so the dtype cannot be guessed.
                raise IndexError("DBObject cannot detect a dtype from an empty result set\n"
                                 "Please specify a dtype with the 'dtype' kwarg.")
            return numpy.recarray((0,), dtype=self.dtype)

        if self.dtype is None:
            first_row = results[0]
            cell_strings = [str(xx) for xx in first_row]

            # We are going to detect the dtype by reading in a single row
            # of data with np.genfromtxt.  To do this, we must pass the
            # row as a string delimited by a specified character.  Here we
            # select a character that does not occur anywhere in the data.
            delimit_char = None
            for cc in [',', ';', '|', ':', '/', '\\']:
                if not any(cc in cell for cell in cell_strings):
                    delimit_char = cc
                    break

            if delimit_char is None:
                raise RuntimeError("DBObject could not detect the dtype of your return rows\n"
                                   "Please specify a dtype with the 'dtype' kwarg.")

            dataString = delimit_char.join(cell_strings)
            names = [str_cast(ww) for ww in first_row.keys()]
            dataArr = numpy.genfromtxt(BytesIO(dataString.encode()), dtype=None,
                                       names=names, delimiter=delimit_char,
                                       encoding='utf-8')

            # NOTE(review): column types (including any string width) are
            # inferred from the first row only -- presumably longer strings
            # in later rows get truncated; confirm before relying on it.
            dt_list = []
            for name in dataArr.dtype.names:
                type_name = str(dataArr.dtype[name])
                if type_name.startswith('S') or type_name.startswith('|S'):
                    width = int(type_name.replace('S', '').replace('|', ''))
                    dt_list.append((name, str_cast, width))
                else:
                    dt_list.append((name, dataArr.dtype[name]))

            self.dtype = numpy.dtype(dt_list)

        return numpy.rec.fromrecords([tuple(xx) for xx in results], dtype=self.dtype)

    def _postprocess_results(self, results):
        """
        This wrapper exists so that a ChunkIterator built from a DBObject
        can have the same API as a ChunkIterator built from a CatalogDBObject
        """
        return self._postprocess_arbitrary_results(results)

    def _postprocess_arbitrary_results(self, results):
        """
        Convert results to a recarray (unless it already is one) and apply
        _final_pass to it.
        """
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_dbobj(results)
        else:
            retresults = results

        return self._final_pass(retresults)

    def execute_arbitrary(self, query, dtype=None):
        """
        Executes an arbitrary query.  Returns a recarray of the results.

        dtype will be the dtype of the output recarray.  If it is None, then
        the code will guess the datatype from the first row returned.

        Raises RuntimeError if query is not a string, or if it contains one
        of the commands delete, drop, insert or update as a whole word
        (case-insensitive).
        """
        # Imported here so that the method does not depend on the
        # module-level import block.
        import re

        if not isinstance(query, str):
            raise RuntimeError("DBObject execute must be called with a string query")

        # Match whole words only, so that identifiers which merely contain
        # one of the commands (e.g. a column called 'last_update') are not
        # rejected.
        lowered_query = query.lower()
        unacceptableCommands = ["delete", "drop", "insert", "update"]
        for badCommand in unacceptableCommands:
            if re.search(r"\b%s\b" % re.escape(badCommand), lowered_query):
                raise RuntimeError("query made to DBObject execute contained %s " % badCommand)

        self.dtype = dtype
        rows = self.connection.session.execute(query).fetchall()
        return self._postprocess_arbitrary_results(rows)

    def get_arbitrary_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        This wrapper exists so that CatalogDBObjects can refer to
        get_arbitrary_chunk_iterator and DBObjects can refer to
        get_chunk_iterator
        """
        return self.get_chunk_iterator(query, chunk_size=chunk_size, dtype=dtype)

    def get_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        Take an arbitrary, user-specified query and return a ChunkIterator that
        executes that query

        dtype will tell the ChunkIterator what datatype to expect for this query.
        This information gets passed to _postprocess_results.

        If 'None', then _postprocess_results will just guess the datatype
        from the first row returned.
        """
        self.dtype = dtype
        return ChunkIterator(self, query, chunk_size, arbitrarySQL=True)