Coverage for python/lsst/sims/maf/db/database.py : 22%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import print_function
2from future.utils import with_metaclass
3import os
4import inspect
5import numpy as np
6from sqlalchemy import func, text
7from sqlalchemy import Table
8from sqlalchemy.engine import reflection
9import warnings
10with warnings.catch_warnings():
11 warnings.simplefilter("ignore", UserWarning)
12 from lsst.sims.catalogs.db import DBObject
15__all__ = ['DatabaseRegistry', 'Database']
class DatabaseRegistry(type):
    """
    Meta class for databases, to build a registry of database classes.

    Every class created with this metaclass is recorded in a ``registry`` dict
    (created on the first class and inherited by its subclasses), keyed by the
    class name. Classes defined outside of ``lsst.sims.maf.db`` are keyed by
    their module name plus class name (``'<module>.<ClassName>'``).
    """
    def __init__(cls, name, bases, dict):
        # NOTE: 'dict' shadows the builtin inside this method; the name is kept
        # so the signature stays identical for any existing callers.
        super(DatabaseRegistry, cls).__init__(name, bases, dict)
        if not hasattr(cls, 'registry'):
            cls.registry = {}
        # inspect.getmodule returns None when the defining module is not present in
        # sys.modules (e.g. dynamically created classes); fall back to the recorded
        # module name instead of failing with an AttributeError.
        module = inspect.getmodule(cls)
        modname = (module.__name__ if module is not None else cls.__module__) + '.'
        # Classes that live in the maf.db package are registered under their bare name.
        if modname.startswith('lsst.sims.maf.db'):
            modname = ''
        databasename = modname + name
        if databasename in cls.registry:
            raise Exception('Redefining databases %s! (there are >1 database classes with the same name)'
                            % (databasename))
        # 'BaseDatabase' is deliberately kept out of the registry.
        if databasename not in ['BaseDatabase']:
            cls.registry[databasename] = cls

    def getClass(cls, databasename):
        """Return the registered database class stored under 'databasename'.

        Raises
        ------
        KeyError
            If no class has been registered under that name.
        """
        return cls.registry[databasename]

    def help(cls, doc=False):
        """Print the names of all registered database classes, in sorted order.

        Parameters
        ----------
        doc : bool, opt
            If True, also print the docstring of each registered class. Default False.
        """
        for databasename in sorted(cls.registry):
            if doc:
                print('---- ', databasename, ' ----')
                print(inspect.getdoc(cls.registry[databasename]))
            else:
                print(databasename)
class Database(with_metaclass(DatabaseRegistry, DBObject)):
    """Base class for database access. Implements some basic query functionality and demonstrates API.

    Parameters
    ----------
    database : str
        Name of the database (or full path + filename for sqlite db).
    driver : str, opt
        Dialect+driver for sqlalchemy. Default 'sqlite'. (other examples, 'pymssql+mssql').
    host : str, opt
        Hostname for database. Default None (for sqlite).
    port : int, opt
        Port for database. Default None.
    defaultTable : str, opt
        Default table in the database to query for metric data.
    longstrings : bool, opt
        Flag to convert strings in database to long (1024) or short (256) characters in numpy recarray.
        Default False (convert to 256 character strings).
    verbose : bool, opt
        Flag for additional output. Default False.
    """

    def __init__(self, database, driver='sqlite', host=None, port=None, defaultTable=None,
                 longstrings=False, verbose=False):
        # If it's a sqlite file, check that the filename exists.
        # This gives a more understandable error message than trying to connect to non-existent file later.
        if driver == 'sqlite':
            if not os.path.isfile(database):
                raise IOError('Sqlite database file "%s" not found.' % (database))

        # Connect to database using DBObject init.
        super(Database, self).__init__(database=database, driver=driver,
                                       host=host, port=port, verbose=verbose, connection=None)

        # Map of database column type names to the (type[, length]) tail of a numpy dtype entry.
        # The builtin 'str' is used rather than the 'np.str' alias, which newer numpy no longer provides.
        self.dbTypeMap = {'BIGINT': (int,), 'BOOLEAN': (bool,), 'FLOAT': (float,), 'INTEGER': (int,),
                          'NUMERIC': (float,), 'SMALLINT': (int,), 'TINYINT': (int,),
                          'VARCHAR': (str, 256), 'TEXT': (str, 256), 'CLOB': (str, 256),
                          'NVARCHAR': (str, 256), 'NCLOB': (str, 256), 'NTEXT': (str, 256),
                          'CHAR': (str, 1), 'INT': (int,), 'REAL': (float,), 'DOUBLE': (float,),
                          'STRING': (str, 256), 'DOUBLE_PRECISION': (float,), 'DECIMAL': (float,),
                          'DATETIME': (str, 50)}
        if longstrings:
            # NOTE(review): NCLOB and NTEXT are not widened here - confirm whether that is intended.
            typeOverRide = {'VARCHAR': (str, 1024), 'NVARCHAR': (str, 1024),
                            'TEXT': (str, 1024), 'CLOB': (str, 1024),
                            'STRING': (str, 1024)}
            self.dbTypeMap.update(typeOverRide)

        # Get a dict (keyed by the table names) of all the columns in each table and view.
        # A single inspector is reused for every reflection call; the name lists are copied so that
        # concatenating them never modifies a list held by the inspector itself.
        inspector = reflection.Inspector.from_engine(self.connection.engine)
        self.tableNames = list(inspector.get_table_names()) + list(inspector.get_view_names())
        self.columnNames = {}
        for t in self.tableNames:
            self.columnNames[t] = [col['name'] for col in inspector.get_columns(t)]
        # Create all the sqlalchemy table objects. This lets us see the schema and query it with types.
        self.tables = {}
        for tablename in self.tableNames:
            self.tables[tablename] = Table(tablename, self.connection.metadata, autoload=True)
        self.defaultTable = defaultTable
        # If there is only one table and we haven't said otherwise, set defaultTable automatically.
        if self.defaultTable is None and len(self.tableNames) == 1:
            self.defaultTable = self.tableNames[0]

    def close(self):
        """Close the database session and dispose of the engine."""
        self.connection.session.close()
        self.connection.engine.dispose()

    def fetchMetricData(self, colnames, sqlconstraint=None, groupBy=None, tableName=None):
        """Fetch 'colnames' from 'tableName'.

        This is basically a thin wrapper around query_columns, but uses the default table.
        It's mostly still here for backward compatibility.

        Parameters
        ----------
        colnames : list
            The columns to fetch from the table.
        sqlconstraint : str or None, opt
            The sql constraint to apply to the data (minus "WHERE"). Default None.
            Examples: to fetch data for the r band filter only, set sqlconstraint to 'filter = "r"'.
        groupBy : str or None, opt
            The column to group the returned data by. Default None.
            The special value 'default' is treated as None, as this class has no default grouping column.
        tableName : str or None, opt
            The table to query. The default (None) will use the table set by self.defaultTable.

        Returns
        -------
        np.recarray
            A structured array containing the data queried from the database.

        Raises
        ------
        ValueError
            If the table to query is not one of the tables or views found in the database.
        """
        if tableName is None:
            tableName = self.defaultTable

        # For a basic Database object, there is no default column to group by. So reset to None.
        if groupBy == 'default':
            groupBy = None

        if tableName not in self.tableNames:
            raise ValueError('Table %s not recognized; not in list of database tables.' % (tableName))

        metricdata = self.query_columns(tableName, colnames=colnames, sqlconstraint=sqlconstraint,
                                        groupBy=groupBy)
        return metricdata

    def fetchConfig(self, *args, **kwargs):
        """Get config (metadata) info on source of data for metric calculation.

        Returns
        -------
        (dict, dict)
            The config summary and config details. Both are empty for this class.
        """
        # Demo API (for interface with driver).
        configSummary = {}
        configDetails = {}
        return configSummary, configDetails

    def query_arbitrary(self, sqlQuery, dtype=None):
        """Simple wrapper around execute_arbitrary for backwards compatibility.

        Parameters
        -----------
        sqlQuery : str
            SQL query.
        dtype: opt, numpy dtype.
            Numpy recarray dtype. If None, then an attempt to determine the dtype will be made.
            This attempt will fail if there are commas in the data you query.

        Returns
        -------
        numpy.recarray
        """
        return self.execute_arbitrary(sqlQuery, dtype=dtype)

    def query_columns(self, tablename, colnames=None, sqlconstraint=None,
                      groupBy=None, numLimit=None, chunksize=1000000):
        """Query a table in the database and return data from colnames in recarray.

        Parameters
        ----------
        tablename : str
            Name of table to query.
        colnames : list of str or None, opt
            Columns from the table to query for. If None, all columns are selected.
        sqlconstraint : str or None, opt
            Constraint to apply to to the query. Default None.
        groupBy : str or None, opt
            Name of column to group by. Default None.
        numLimit : int or None, opt
            Number of records to return. Default no limit.
        chunksize : int, opt
            Query database and convert to recarray in series of chunks of chunksize.
            If None or 0, all results are fetched and converted in a single step.

        Returns
        -------
        numpy.recarray

        Raises
        ------
        ValueError
            If the table, a requested column or the groupBy column is not available.
        """
        # Build the sqlalchemy query from a single table, with various columns/constraints/etc.
        # Does NOT use a mapping between column names and database names - assumes the database names
        # are what the user will specify.

        # Build the query (this also validates the table, column and groupBy names).
        query = self._build_query(tablename, colnames=colnames, sqlconstraint=sqlconstraint,
                                  groupBy=groupBy, numLimit=numLimit)

        # Resolve "all columns" the same way _build_query does, so the dtype matches the query.
        if colnames is None:
            colnames = self.columnNames[tablename]

        # Determine dtype for numpy recarray.
        dtype = []
        for col in colnames:
            ty = self.tables[tablename].c[col].type
            dt = self.dbTypeMap[ty.__visit_name__]
            # Override the default length, if the type has it (for example, if it is VARCHAR(1)).
            # Only entries that carry a length slot are touched, so the type itself is never replaced.
            length = getattr(ty, 'length', None)
            if length is not None and len(dt) > 1:
                dt = dt[:-1] + (length,)
            dtype.append((col,) + dt)

        # Execute query on database.
        exec_query = self.connection.session.execute(query)

        if chunksize is None or chunksize == 0:
            # Fetch all results and convert to numpy recarray.
            results = exec_query.fetchall()
            data = self._convert_results(results, dtype)
        else:
            chunks = []
            # Loop through results, converting in steps of chunksize.
            results = exec_query.fetchmany(chunksize)
            while len(results) > 0:
                chunks.append(self._convert_results(results, dtype))
                results = exec_query.fetchmany(chunksize)
            if len(chunks) == 0:
                data = np.recarray((0,), dtype=dtype)
            else:
                data = np.hstack(chunks)
        return data

    def _build_query(self, tablename, colnames, sqlconstraint=None, groupBy=None, numLimit=None):
        """Validate the request and build the sqlalchemy query object for a single table.

        Parameters
        ----------
        tablename : str
            Name of table to query.
        colnames : list of str or None
            Columns to select. If None, all columns of the table are selected.
        sqlconstraint : str or None, opt
            Constraint to apply to the query. None or an empty string applies no constraint.
        groupBy : str or None, opt
            Name of column to group by. Default None.
        numLimit : int or None, opt
            Number of records to return. Default no limit.

        Returns
        -------
        The sqlalchemy query object.

        Raises
        ------
        ValueError
            If the table, a requested column or the groupBy column is not available,
            or if an empty list of columns is requested.
        """
        if tablename not in self.tables:
            raise ValueError('Tablename %s not in list of available tables (%s).'
                             % (tablename, self.tables.keys()))
        if colnames is None:
            colnames = self.columnNames[tablename]
        else:
            for col in colnames:
                if col not in self.columnNames[tablename]:
                    raise ValueError("Requested column %s not available in table %s" % (col, tablename))
        if len(colnames) == 0:
            raise ValueError("No columns requested from table %s" % (tablename))
        if groupBy is not None:
            if groupBy not in self.columnNames[tablename]:
                raise ValueError("GroupBy column %s is not available in table %s" % (groupBy, tablename))
        # Put together sqlalchemy query object.
        # The first column starts the query by position (not by comparing names), so a column name
        # that repeats the first one is added again instead of restarting the query.
        query = self.connection.session.query(colnames[0])
        for col in colnames[1:]:
            query = query.add_columns(col)
        query = query.select_from(self.tables[tablename])
        if sqlconstraint is not None:
            if len(sqlconstraint) > 0:
                query = query.filter(text(sqlconstraint))
        if groupBy is not None:
            query = query.group_by(groupBy)
        if numLimit is not None:
            query = query.limit(numLimit)
        return query

    def _convert_results(self, results, dtype):
        """Convert a list of result rows into a numpy recarray with the given dtype.

        An empty list of results gives an empty (length 0) recarray.
        """
        if len(results) == 0:
            data = np.recarray((0,), dtype=dtype)
        else:
            # Have to do the tuple(xx) for py2 string objects. With py3 is okay to just pass results.
            data = np.rec.fromrecords([tuple(xx) for xx in results], dtype=dtype)
        return data