Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from __future__ import print_function 

2from future.utils import with_metaclass 

3import os 

4import inspect 

5import numpy as np 

6from sqlalchemy import func, text 

7from sqlalchemy import Table 

8from sqlalchemy.engine import reflection 

9import warnings 

10with warnings.catch_warnings(): 

11 warnings.simplefilter("ignore", UserWarning) 

12 from lsst.sims.catalogs.db import DBObject 

13 

14 

15__all__ = ['DatabaseRegistry', 'Database'] 

16 

class DatabaseRegistry(type):
    """
    Meta class for databases, to build a registry of database classes.

    Every class created with this metaclass is recorded in a ``registry`` dict
    (created on the first class that does not already have one, and then
    shared with its subclasses through normal attribute lookup).
    Classes defined in modules whose name starts with 'lsst.sims.maf.db' are
    registered under their bare class name; all other classes are registered
    as '<module name>.<class name>'.
    """
    def __init__(cls, name, bases, dict):
        super(DatabaseRegistry, cls).__init__(name, bases, dict)
        if not hasattr(cls, 'registry'):
            cls.registry = {}
        # inspect.getmodule returns None when the defining module is not in sys.modules;
        # fall back on the class's own __module__ rather than failing on None.__name__.
        module = inspect.getmodule(cls)
        modname = module.__name__ if module is not None else cls.__module__
        if modname.startswith('lsst.sims.maf.db'):
            prefix = ''
        else:
            prefix = modname + '.'
        databasename = prefix + name
        if databasename in cls.registry:
            raise Exception('Redefining databases %s! (there are >1 database classes with the same name)'
                            % (databasename))
        # 'BaseDatabase' is deliberately kept out of the registry.
        if databasename != 'BaseDatabase':
            cls.registry[databasename] = cls

    def getClass(cls, databasename):
        """Return the registered database class stored under 'databasename'.

        Raises KeyError if no class was registered under that name.
        """
        return cls.registry[databasename]

    def help(cls, doc=False):
        """Print the names of the registered database classes, in sorted order.

        If 'doc' is True, print a header line and the docstring of each class as well.
        """
        for databasename in sorted(cls.registry):
            if doc:
                print('---- ', databasename, ' ----')
                print(inspect.getdoc(cls.registry[databasename]))
            else:
                print(databasename)

48 

49 

class Database(with_metaclass(DatabaseRegistry, DBObject)):
    """Base class for database access. Implements some basic query functionality and demonstrates API.

    Parameters
    ----------
    database : str
        Name of the database (or full path + filename for sqlite db).
    driver : str, opt
        Dialect+driver for sqlalchemy. Default 'sqlite'. (other examples, 'pymssql+mssql').
    host : str, opt
        Hostname for database. Default None (for sqlite).
    port : int, opt
        Port for database. Default None.
    defaultTable : str, opt
        Default table in the database to query for metric data.
        If None and the database holds exactly one table/view, that one is used.
    longstrings : bool, opt
        Flag to convert strings in database to long (1024) or short (256) characters in numpy recarray.
        Default False (convert to 256 character strings).
    verbose : bool, opt
        Flag for additional output. Default False.

    Raises
    ------
    IOError
        If driver is 'sqlite' and 'database' is not an existing file.
    """

    def __init__(self, database, driver='sqlite', host=None, port=None, defaultTable=None,
                 longstrings=False, verbose=False):
        # If it's a sqlite file, check that the filename exists.
        # This gives a more understandable error message than trying to connect to non-existent file later.
        if driver == 'sqlite' and not os.path.isfile(database):
            raise IOError('Sqlite database file "%s" not found.' % (database))

        # Connect to database using DBObject init.
        super(Database, self).__init__(database=database, driver=driver,
                                       host=host, port=port, verbose=verbose, connection=None)

        # Map from sqlalchemy type names (__visit_name__) to numpy dtype tuples: (type,) or (type, length).
        self.dbTypeMap = {'BIGINT': (int,), 'BOOLEAN': (bool,), 'FLOAT': (float,), 'INTEGER': (int,),
                          'NUMERIC': (float,), 'SMALLINT': (int,), 'TINYINT': (int,),
                          'VARCHAR': (np.str_, 256), 'TEXT': (np.str_, 256), 'CLOB': (np.str_, 256),
                          'NVARCHAR': (np.str_, 256), 'NCLOB': (np.str_, 256), 'NTEXT': (np.str_, 256),
                          'CHAR': (np.str_, 1), 'INT': (int,), 'REAL': (float,), 'DOUBLE': (float,),
                          'STRING': (np.str_, 256), 'DOUBLE_PRECISION': (float,), 'DECIMAL': (float,),
                          'DATETIME': (np.str_, 50)}
        if longstrings:
            typeOverRide = {'VARCHAR': (np.str_, 1024), 'NVARCHAR': (np.str_, 1024),
                            'TEXT': (np.str_, 1024), 'CLOB': (np.str_, 1024),
                            'STRING': (np.str_, 1024)}
            self.dbTypeMap.update(typeOverRide)

        # Build the inspector once and reuse it for every reflection call below.
        inspector = reflection.Inspector.from_engine(self.connection.engine)
        # Tables first, then views. Copy into a new list so the inspector's own results are not modified.
        self.tableNames = list(inspector.get_table_names()) + list(inspector.get_view_names())
        # Get a dict (keyed by the table names) of all the columns in each table and view.
        self.columnNames = {}
        for t in self.tableNames:
            self.columnNames[t] = [col['name'] for col in inspector.get_columns(t)]
        # Create all the sqlalchemy table objects. This lets us see the schema and query it with types.
        self.tables = {}
        for tablename in self.tableNames:
            self.tables[tablename] = Table(tablename, self.connection.metadata, autoload=True)
        self.defaultTable = defaultTable
        # If there is only one table and we haven't said otherwise, set defaultTable automatically.
        if self.defaultTable is None and len(self.tableNames) == 1:
            self.defaultTable = self.tableNames[0]

    def close(self):
        """Close the database session, then dispose of the engine."""
        self.connection.session.close()
        self.connection.engine.dispose()

    def fetchMetricData(self, colnames, sqlconstraint=None, groupBy=None, tableName=None):
        """Fetch 'colnames' from 'tableName'.

        This is basically a thin wrapper around query_columns, but uses the default table.
        It's mostly still here for backward compatibility.

        Parameters
        ----------
        colnames : list
            The columns to fetch from the table.
        sqlconstraint : str or None, opt
            The sql constraint to apply to the data (minus "WHERE"). Default None.
            Examples: to fetch data for the r band filter only, set sqlconstraint to 'filter = "r"'.
        groupBy : str or None, opt
            The column to group the returned data by.
            The value 'default' is accepted and treated as None, as this base class has no
            default grouping column.
        tableName : str or None, opt
            The table to query. The default (None) will use the table set by self.defaultTable.

        Returns
        -------
        np.recarray
            A structured array containing the data queried from the database.

        Raises
        ------
        ValueError
            If the table (after applying the default) is not one of the database tables.
        """
        if tableName is None:
            tableName = self.defaultTable

        # For a basic Database object, there is no default column to group by. So reset to None.
        if groupBy == 'default':
            groupBy = None

        if tableName not in self.tableNames:
            raise ValueError('Table %s not recognized; not in list of database tables.' % (tableName))

        metricdata = self.query_columns(tableName, colnames=colnames, sqlconstraint=sqlconstraint,
                                        groupBy=groupBy)
        return metricdata

    def fetchConfig(self, *args, **kwargs):
        """Get config (metadata) info on source of data for metric calculation.

        This base implementation only demonstrates the API: it ignores its arguments
        and returns two empty dictionaries (configSummary, configDetails).
        """
        # Demo API (for interface with driver).
        configSummary = {}
        configDetails = {}
        return configSummary, configDetails

    def query_arbitrary(self, sqlQuery, dtype=None):
        """Simple wrapper around execute_arbitrary for backwards compatibility.

        Parameters
        -----------
        sqlQuery : str
            SQL query.
        dtype: opt, numpy dtype.
            Numpy recarray dtype. If None, then an attempt to determine the dtype will be made.
            This attempt will fail if there are commas in the data you query.

        Returns
        -------
        numpy.recarray
        """
        return self.execute_arbitrary(sqlQuery, dtype=dtype)

    def query_columns(self, tablename, colnames=None, sqlconstraint=None,
                      groupBy=None, numLimit=None, chunksize=1000000):
        """Query a table in the database and return data from colnames in recarray.

        Parameters
        ----------
        tablename : str
            Name of table to query.
        colnames : list of str or None, opt
            Columns from the table to query for. If None, all columns are selected.
        sqlconstraint : str or None, opt
            Constraint to apply to the query. Default None.
        groupBy : str or None, opt
            Name of column to group by. Default None.
        numLimit : int or None, opt
            Number of records to return. Default no limit.
        chunksize : int, opt
            Query database and convert to recarray in series of chunks of chunksize.
            If None or 0, all rows are fetched and converted in a single step.

        Returns
        -------
        numpy.recarray

        Raises
        ------
        ValueError
            If the table, a requested column or the groupBy column is unknown,
            or if an empty list of columns is requested.
        """
        # Build the sqlalchemy query from a single table, with various columns/constraints/etc.
        # Does NOT use a mapping between column names and database names - assumes the database names
        # are what the user will specify.

        # Build the query (this also validates the table and column names).
        query = self._build_query(tablename, colnames=colnames, sqlconstraint=sqlconstraint,
                                  groupBy=groupBy, numLimit=numLimit)

        # 'None' means all columns; resolve it here so the dtype matches what the query selects.
        if colnames is None:
            colnames = self.columnNames[tablename]

        # Determine dtype for numpy recarray.
        dtype = []
        table_columns = self.tables[tablename].c
        for col in colnames:
            ty = table_columns[col].type
            dt = self.dbTypeMap[ty.__visit_name__]
            # Override the default length, if the type has it (for example, if it is VARCHAR(1)).
            # Only dtypes that carry a length (two-element tuples) can be overridden.
            length = getattr(ty, 'length', None)
            if length is not None and len(dt) > 1:
                dt = dt[:-1] + (length,)
            dtype.append((col,) + dt)

        # Execute query on database.
        exec_query = self.connection.session.execute(query)

        if chunksize is None or chunksize == 0:
            # Fetch all results and convert to numpy recarray.
            results = exec_query.fetchall()
            data = self._convert_results(results, dtype)
        else:
            chunks = []
            # Loop through results, converting in steps of chunksize.
            results = exec_query.fetchmany(chunksize)
            while len(results) > 0:
                chunks.append(self._convert_results(results, dtype))
                results = exec_query.fetchmany(chunksize)
            if len(chunks) == 0:
                data = np.recarray((0,), dtype=dtype)
            else:
                data = np.hstack(chunks)
        return data

    def _build_query(self, tablename, colnames, sqlconstraint=None, groupBy=None, numLimit=None):
        """Validate the inputs and assemble the sqlalchemy query object for a single table.

        Parameters
        ----------
        tablename : str
            Name of table to query; must be a key of self.tables.
        colnames : list of str or None
            Columns to select. If None, all columns of the table are selected.
        sqlconstraint : str or None, opt
            Constraint passed (as text) to the query filter. None or '' means no filter.
        groupBy : str or None, opt
            Name of column to group by.
        numLimit : int or None, opt
            Maximum number of rows to return.

        Returns
        -------
        The sqlalchemy query object (not yet executed).

        Raises
        ------
        ValueError
            If the table, a requested column or the groupBy column is unknown,
            or if 'colnames' is empty.
        """
        if tablename not in self.tables:
            raise ValueError('Tablename %s not in list of available tables (%s).'
                             % (tablename, self.tables.keys()))
        available = self.columnNames[tablename]
        if colnames is None:
            colnames = available
        else:
            for col in colnames:
                if col not in available:
                    raise ValueError("Requested column %s not available in table %s" % (col, tablename))
        if len(colnames) == 0:
            # Without at least one column there is nothing to start the query from.
            raise ValueError("No columns requested from table %s" % (tablename))
        if groupBy is not None:
            if groupBy not in available:
                raise ValueError("GroupBy column %s is not available in table %s" % (groupBy, tablename))
        # Put together sqlalchemy query object.
        # Start from the first column by position (not by comparing names), so that a repeated
        # column name cannot restart the query and drop the columns added so far.
        query = self.connection.session.query(colnames[0])
        for col in colnames[1:]:
            query = query.add_columns(col)
        query = query.select_from(self.tables[tablename])
        if sqlconstraint is not None and len(sqlconstraint) > 0:
            query = query.filter(text(sqlconstraint))
        if groupBy is not None:
            query = query.group_by(groupBy)
        if numLimit is not None:
            query = query.limit(numLimit)
        return query

    def _convert_results(self, results, dtype):
        """Convert a sequence of result rows into a numpy recarray with the given dtype.

        An empty sequence gives a zero-length recarray with the same dtype.
        """
        if len(results) == 0:
            data = np.recarray((0,), dtype=dtype)
        else:
            # Have to do the tuple(xx) for py2 string objects. With py3 is okay to just pass results.
            data = np.rec.fromrecords([tuple(xx) for xx in results], dtype=dtype)
        return data

280 return data