Coverage for python/lsst/sims/maf/db/database.py : 20%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import print_function
2from future.utils import with_metaclass
3import os
4import inspect
5import numpy as np
6from sqlalchemy import func, text, column
7from sqlalchemy import Table
8import sqlalchemy
9from .dbObj import DBObject
# Public API of this module: the metaclass registry and the base Database class.
__all__ = [
    'DatabaseRegistry',
    'Database',
]
class DatabaseRegistry(type):
    """
    Meta class for databases, to build a registry of database classes.

    Every class created with this metaclass is stored in a ``registry`` dict.
    The dict is created on the first class in a hierarchy and inherited (so shared)
    by its subclasses. Classes defined in modules whose name starts with
    'lsst.sims.maf.db' are keyed by their bare class name; every other class is
    keyed by '<module name>.<class name>'. A class keyed 'BaseDatabase' is not
    registered.
    """
    def __init__(cls, name, bases, dict):
        super(DatabaseRegistry, cls).__init__(name, bases, dict)
        if not hasattr(cls, 'registry'):
            cls.registry = {}
        # Use the module name recorded on the class itself. inspect.getmodule(cls)
        # returns None (and this raised AttributeError) when the defining module is
        # not present in sys.modules; otherwise both give the same name.
        modname = cls.__module__ + '.'
        if modname.startswith('lsst.sims.maf.db'):
            modname = ''
        databasename = modname + name
        if databasename in cls.registry:
            raise Exception('Redefining databases %s! (there are >1 database classes with the same name)'
                            % (databasename))
        if databasename not in ['BaseDatabase']:
            cls.registry[databasename] = cls

    def getClass(cls, databasename):
        """Return the database class registered under ``databasename``.

        Raises KeyError if no class was registered with that name.
        """
        return cls.registry[databasename]

    def help(cls, doc=False):
        """Print the registered database names in sorted order.

        Parameters
        ----------
        doc : bool, opt
            If True, print each name as a header followed by the class docstring
            (via inspect.getdoc) instead of only the name. Default False.
        """
        for databasename in sorted(cls.registry):
            if doc:
                print('---- ', databasename, ' ----')
                print(inspect.getdoc(cls.registry[databasename]))
            else:
                print(databasename)
class Database(with_metaclass(DatabaseRegistry, DBObject)):
    """Base class for database access. Implements some basic query functionality and demonstrates API.

    Parameters
    ----------
    database : str
        Name of the database (or full path + filename for sqlite db).
    driver : str, opt
        Dialect+driver for sqlalchemy. Default 'sqlite'. (other examples, 'mssql+pymssql').
    host : str, opt
        Hostname for database. Default None (for sqlite).
    port : int, opt
        Port for database. Default None.
    defaultTable : str, opt
        Default table in the database to query for metric data.
        If None and the database contains exactly one table or view, that one is used.
    longstrings : bool, opt
        Flag to convert strings in database to long (1024) or short (256) characters in numpy recarray.
        Default False (convert to 256 character strings).
    verbose : bool, opt
        Flag for additional output. Default False.

    Raises
    ------
    IOError
        If driver is 'sqlite' and ``database`` is not an existing file.
    """

    def __init__(self, database, driver='sqlite', host=None, port=None, defaultTable=None,
                 longstrings=False, verbose=False):
        # If it's a sqlite file, check that the filename exists.
        # This gives a more understandable error message than trying to connect to a non-existent file later.
        if driver == 'sqlite':
            if not os.path.isfile(database):
                raise IOError('Sqlite database file "%s" not found.' % (database))

        # Connect to database using DBObject init.
        super(Database, self).__init__(database=database, driver=driver,
                                       host=host, port=port, verbose=verbose, connection=None)

        # Keyed by sqlalchemy type name (type.__visit_name__); each value is the tail of a
        # numpy dtype field tuple, appended after the column name in _build_dtype.
        self.dbTypeMap = {'BIGINT': (int,), 'BOOLEAN': (bool,), 'FLOAT': (float,), 'INTEGER': (int,),
                          'NUMERIC': (float,), 'SMALLINT': (int,), 'TINYINT': (int,),
                          'VARCHAR': (np.str_, 256), 'TEXT': (np.str_, 256), 'CLOB': (np.str_, 256),
                          'NVARCHAR': (np.str_, 256), 'NCLOB': (np.str_, 256), 'NTEXT': (np.str_, 256),
                          'CHAR': (np.str_, 1), 'INT': (int,), 'REAL': (float,), 'DOUBLE': (float,),
                          'STRING': (np.str_, 256), 'DOUBLE_PRECISION': (float,), 'DECIMAL': (float,),
                          'DATETIME': (np.str_, 50)}
        if longstrings:
            typeOverRide = {'VARCHAR': (np.str_, 1024), 'NVARCHAR': (np.str_, 1024),
                            'TEXT': (np.str_, 1024), 'CLOB': (np.str_, 1024),
                            'STRING': (np.str_, 1024)}
            self.dbTypeMap.update(typeOverRide)

        # Reflect the schema with a single inspector instead of creating one per call.
        inspector = sqlalchemy.inspect(self.connection.engine)
        # Build a new list rather than extending the returned one in place, since
        # reflection results may be cached objects.
        self.tableNames = list(inspector.get_table_names()) + list(inspector.get_view_names())
        # Dict (keyed by table/view name) of the column names in each table and view.
        self.columnNames = {}
        for t in self.tableNames:
            self.columnNames[t] = [col['name'] for col in inspector.get_columns(t)]
        # Create all the sqlalchemy table objects. This lets us see the schema and query it with types.
        self.tables = {}
        for tablename in self.tableNames:
            self.tables[tablename] = Table(tablename, self.connection.metadata, autoload=True)
        self.defaultTable = defaultTable
        # If there is only one table and we haven't said otherwise, set defaultTable automatically.
        if self.defaultTable is None and len(self.tableNames) == 1:
            self.defaultTable = self.tableNames[0]

    def close(self):
        """Close the session and dispose of the engine held by this connection."""
        self.connection.session.close()
        self.connection.engine.dispose()

    def fetchMetricData(self, colnames, sqlconstraint=None, groupBy=None, tableName=None):
        """Fetch 'colnames' from 'tableName'.

        This is basically a thin wrapper around query_columns, but uses the default table.
        It's mostly still here for backward compatibility.

        Parameters
        ----------
        colnames : list
            The columns to fetch from the table.
        sqlconstraint : str or None, opt
            The sql constraint to apply to the data (minus "WHERE"). Default None.
            Examples: to fetch data for the r band filter only, set sqlconstraint to 'filter = "r"'.
        groupBy : str or None, opt
            The column to group the returned data by. Default None.
            For this base class the value 'default' is treated as None (no grouping).
        tableName : str or None, opt
            The table to query. The default (None) will use the summary table, set by self.defaultTable.

        Returns
        -------
        np.recarray
            A structured array containing the data queried from the database.

        Raises
        ------
        ValueError
            If the table is not one of the database's tables or views.
        """
        if tableName is None:
            tableName = self.defaultTable

        # For a basic Database object, there is no default column to group by. So reset to None.
        if groupBy == 'default':
            groupBy = None

        if tableName not in self.tableNames:
            raise ValueError('Table %s not recognized; not in list of database tables.' % (tableName))

        metricdata = self.query_columns(tableName, colnames=colnames, sqlconstraint=sqlconstraint,
                                        groupBy=groupBy)
        return metricdata

    def fetchConfig(self, *args, **kwargs):
        """Get config (metadata) info on source of data for metric calculation.

        Demo API (for interface with driver): this base class ignores its arguments
        and always returns two empty dicts (configSummary, configDetails).
        """
        configSummary = {}
        configDetails = {}
        return configSummary, configDetails

    def query_arbitrary(self, sqlQuery, dtype=None):
        """Simple wrapper around execute_arbitrary for backwards compatibility.

        Parameters
        ----------
        sqlQuery : str
            SQL query.
        dtype : numpy dtype, opt
            Numpy recarray dtype. If None, then an attempt to determine the dtype will be made.
            This attempt will fail if there are commas in the data you query.

        Returns
        -------
        numpy.recarray
        """
        return self.execute_arbitrary(sqlQuery, dtype=dtype)

    def query_columns(self, tablename, colnames=None, sqlconstraint=None,
                      groupBy=None, numLimit=None, chunksize=1000000):
        """Query a table in the database and return data from colnames in recarray.

        Builds the sqlalchemy query from a single table. Does NOT use a mapping between
        column names and database names - assumes the database names are what the user
        will specify.

        Parameters
        ----------
        tablename : str
            Name of table to query.
        colnames : list of str or None, opt
            Columns from the table to query for. If None, all columns are selected.
        sqlconstraint : str or None, opt
            Constraint to apply to the query. Default None.
        groupBy : str or None, opt
            Name of column to group by. Default None.
        numLimit : int or None, opt
            Number of records to return. Default no limit.
        chunksize : int, opt
            Query database and convert to recarray in series of chunks of chunksize.
            None or 0 fetches and converts all rows at once.

        Returns
        -------
        numpy.recarray

        Raises
        ------
        ValueError
            If the table, a requested column, or the groupBy column is unknown, or no columns are selected.
        """
        tablename_str = str(tablename).replace('"', '')
        # _build_query validates the table and column names before anything else is done.
        query = self._build_query(tablename, colnames=colnames, sqlconstraint=sqlconstraint,
                                  groupBy=groupBy, numLimit=numLimit)
        # colnames=None means "all columns": resolve it the same way _build_query does so the
        # dtype matches the selected columns (iterating None previously raised TypeError).
        if colnames is None:
            colnames = self.columnNames[tablename_str]
        dtype = self._build_dtype(tablename_str, colnames)

        # Execute query on database.
        exec_query = self.connection.session.execute(query)

        if chunksize is None or chunksize == 0:
            # Fetch all results and convert to numpy recarray.
            results = exec_query.fetchall()
            return self._convert_results(results, dtype)

        # Loop through results, converting in steps of chunksize.
        chunks = []
        results = exec_query.fetchmany(chunksize)
        while len(results) > 0:
            chunks.append(self._convert_results(results, dtype))
            results = exec_query.fetchmany(chunksize)
        if len(chunks) == 0:
            return np.recarray((0,), dtype=dtype)
        return np.hstack(chunks)

    def _build_dtype(self, tablename_str, colnames):
        """Return the numpy dtype (list of field tuples) for ``colnames`` of table ``tablename_str``.

        Double quotes are stripped from column names. Types are mapped through self.dbTypeMap;
        a type that carries a non-None ``length`` (e.g. VARCHAR(1)) overrides the default width.
        """
        dtype = []
        for col in colnames:
            colname = str(col).replace('"', '')
            coltype = self.tables[tablename_str].c[colname].type
            dt = self.dbTypeMap[coltype.__visit_name__]
            length = getattr(coltype, 'length', None)
            if length is not None:
                dt = dt[:-1] + (length,)
            dtype.append((colname,) + dt)
        return dtype

    def _build_query(self, tablename, colnames, sqlconstraint=None, groupBy=None, numLimit=None):
        """Validate the arguments and build a sqlalchemy query selecting ``colnames`` from ``tablename``.

        Double quotes are stripped from the table, column and groupBy names before validation.
        If ``colnames`` is None, all columns of the table are selected.

        Raises
        ------
        ValueError
            If the table, a requested column, or the groupBy column is unknown, or no columns are selected.
        """
        tablename_str = str(tablename).replace('"', '')
        if tablename_str not in self.tables:
            raise ValueError('Tablename %s not in list of available tables (%s).'
                             % (tablename, self.tables.keys()))
        # Always look up columns with the de-quoted name (a quoted tablename used to raise KeyError).
        available = self.columnNames[tablename_str]
        if colnames is None:
            colnames = available
        else:
            for col in colnames:
                if str(col).replace('"', '') not in available:
                    raise ValueError("Requested column %s not available in table %s" % (col, tablename_str))
        if len(colnames) == 0:
            raise ValueError('No columns requested from table %s' % (tablename_str))
        if groupBy is not None:
            if str(groupBy).replace('"', '') not in available:
                raise ValueError("GroupBy column %s is not available in table %s" % (groupBy, tablename_str))
        # Select all columns in one call; the previous first-column comparison restarted the
        # query (dropping earlier columns) whenever the first column name appeared again.
        query = self.connection.session.query(*[column(col) for col in colnames])
        query = query.select_from(self.tables[tablename_str])
        if sqlconstraint is not None and len(sqlconstraint) > 0:
            query = query.filter(text(sqlconstraint))
        if groupBy is not None:
            query = query.group_by(groupBy)
        if numLimit is not None:
            query = query.limit(numLimit)
        return query

    def _convert_results(self, results, dtype):
        """Convert a sequence of result rows into a numpy recarray with the given dtype.

        An empty ``results`` gives an empty recarray of that dtype.
        """
        if len(results) == 0:
            data = np.recarray((0,), dtype=dtype)
        else:
            # Have to do the tuple(xx) for py2 string objects. With py3 is okay to just pass results.
            data = np.rec.fromrecords([tuple(xx) for xx in results], dtype=dtype)
        return data