Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from __future__ import print_function 

2from future.utils import with_metaclass 

3import os 

4import inspect 

5import numpy as np 

6from sqlalchemy import func, text, column 

7from sqlalchemy import Table 

8import sqlalchemy 

9from .dbObj import DBObject 

10 

11 

12__all__ = ['DatabaseRegistry', 'Database'] 

13 

14 

class DatabaseRegistry(type):
    """
    Meta class for databases, to build a registry of database classes.

    Every class created with this metaclass is recorded in a ``registry`` dict.
    The dict is created on the first such class and, because subclasses find it
    through attribute inheritance, the same dict is shared by all its subclasses.
    Classes whose module name starts with ``lsst.sims.maf.db`` are keyed by their
    bare class name; all other classes are keyed by ``<module>.<ClassName>``.
    """
    def __init__(cls, name, bases, namespace):
        super(DatabaseRegistry, cls).__init__(name, bases, namespace)
        if not hasattr(cls, 'registry'):
            cls.registry = {}
        # Use the class's own __module__ string: inspect.getmodule() returns None for
        # classes whose module is not present in sys.modules (e.g. classes built
        # dynamically), and the old `.__name__` lookup then crashed.
        modname = cls.__module__ + '.'
        if modname.startswith('lsst.sims.maf.db'):
            # Databases shipped with MAF itself are registered under their bare name.
            modname = ''
        databasename = modname + name
        if databasename in cls.registry:
            raise Exception('Redefining databases %s! (there are >1 database classes with the same name)'
                            % (databasename))
        if databasename not in ['BaseDatabase']:
            cls.registry[databasename] = cls

    def getClass(cls, databasename):
        """Return the registered database class stored under ``databasename``.

        Raises
        ------
        KeyError
            If no class is registered under that name.
        """
        return cls.registry[databasename]

    def help(cls, doc=False):
        """Print the registered database names in sorted order.

        Parameters
        ----------
        doc : bool, opt
            If True, print a header line plus the class docstring for each entry
            instead of just the name. Default False.
        """
        for databasename in sorted(cls.registry):
            if doc:
                print('---- ', databasename, ' ----')
                print(inspect.getdoc(cls.registry[databasename]))
            else:
                print(databasename)

48 

49 

class Database(with_metaclass(DatabaseRegistry, DBObject)):
    """Base class for database access. Implements some basic query functionality and demonstrates API.

    Parameters
    ----------
    database : str
        Name of the database (or full path + filename for sqlite db).
    driver : str, opt
        Dialect+driver for sqlalchemy. Default 'sqlite'. (other examples, 'mssql+pymssql').
    host : str, opt
        Hostname for database. Default None (for sqlite).
    port : int, opt
        Port for database. Default None.
    defaultTable : str, opt
        Default table in the database to query for metric data.
        If None and the database contains exactly one table/view, that one is used.
    longstrings : bool, opt
        Flag to convert strings in database to long (1024) or short (256) characters in numpy recarray.
        Default False (convert to 256 character strings).
    verbose : bool, opt
        Flag for additional output. Default False.

    Raises
    ------
    IOError
        If ``driver`` is 'sqlite' and ``database`` is not an existing file.
    """

    def __init__(self, database, driver='sqlite', host=None, port=None, defaultTable=None,
                 longstrings=False, verbose=False):
        # If it's a sqlite file, check that the filename exists.
        # This gives a more understandable error message than trying to connect to non-existent file later.
        if driver == 'sqlite':
            if not os.path.isfile(database):
                raise IOError('Sqlite database file "%s" not found.' % (database))

        # Connect to database using DBObject init.
        super(Database, self).__init__(database=database, driver=driver,
                                       host=host, port=port, verbose=verbose, connection=None)

        # Map from sqlalchemy type name (a column type's ``__visit_name__``) to the
        # numpy dtype pieces used when building result recarrays.
        self.dbTypeMap = {'BIGINT': (int,), 'BOOLEAN': (bool,), 'FLOAT': (float,), 'INTEGER': (int,),
                          'NUMERIC': (float,), 'SMALLINT': (int,), 'TINYINT': (int,),
                          'VARCHAR': (np.str_, 256), 'TEXT': (np.str_, 256), 'CLOB': (np.str_, 256),
                          'NVARCHAR': (np.str_, 256), 'NCLOB': (np.str_, 256), 'NTEXT': (np.str_, 256),
                          'CHAR': (np.str_, 1), 'INT': (int,), 'REAL': (float,), 'DOUBLE': (float,),
                          'STRING': (np.str_, 256), 'DOUBLE_PRECISION': (float,), 'DECIMAL': (float,),
                          'DATETIME': (np.str_, 50)}
        if longstrings:
            # Widen every 256-character string type above to 1024 characters
            # (NCLOB/NTEXT included so all of them are treated consistently).
            typeOverRide = {'VARCHAR': (np.str_, 1024), 'NVARCHAR': (np.str_, 1024),
                            'TEXT': (np.str_, 1024), 'CLOB': (np.str_, 1024),
                            'NCLOB': (np.str_, 1024), 'NTEXT': (np.str_, 1024),
                            'STRING': (np.str_, 1024)}
            self.dbTypeMap.update(typeOverRide)

        # Get a dict (keyed by the table names) of all the columns in each table and view.
        # One inspector is reused for all reflection calls.
        inspector = sqlalchemy.inspect(self.connection.engine)
        self.tableNames = inspector.get_table_names()
        self.tableNames += inspector.get_view_names()
        self.columnNames = {}
        for t in self.tableNames:
            self.columnNames[t] = [colinfo['name'] for colinfo in inspector.get_columns(t)]
        # Create all the sqlalchemy table objects. This lets us see the schema and query it with types.
        # autoload_with reflects against the engine explicitly instead of relying on bound metadata.
        self.tables = {}
        for tablename in self.tableNames:
            self.tables[tablename] = Table(tablename, self.connection.metadata,
                                           autoload_with=self.connection.engine)
        self.defaultTable = defaultTable
        # if there is only one table and we haven't said otherwise, set defaultTable automatically.
        if self.defaultTable is None and len(self.tableNames) == 1:
            self.defaultTable = self.tableNames[0]

    def close(self):
        """Close the sqlalchemy session and dispose of the engine."""
        self.connection.session.close()
        self.connection.engine.dispose()

    def fetchMetricData(self, colnames, sqlconstraint=None, groupBy=None, tableName=None):
        """Fetch 'colnames' from 'tableName'.

        This is basically a thin wrapper around query_columns, but uses the default table.
        It's mostly still here for backward compatibility.

        Parameters
        ----------
        colnames : list
            The columns to fetch from the table.
        sqlconstraint : str or None, opt
            The sql constraint to apply to the data (minus "WHERE"). Default None.
            Examples: to fetch data for the r band filter only, set sqlconstraint to 'filter = "r"'.
        groupBy : str or None, opt
            The column to group the returned data by.
            The value 'default' is treated as None for this base class (no default grouping column).
        tableName : str or None, opt
            The table to query. The default (None) will use the summary table, set by self.defaultTable.

        Returns
        -------
        np.recarray
            A structured array containing the data queried from the database.

        Raises
        ------
        ValueError
            If the table is not one of the database's tables/views.
        """
        if tableName is None:
            tableName = self.defaultTable

        # For a basic Database object, there is no default column to group by. So reset to None.
        if groupBy == 'default':
            groupBy = None

        if tableName not in self.tableNames:
            raise ValueError('Table %s not recognized; not in list of database tables.' % (tableName))

        metricdata = self.query_columns(tableName, colnames=colnames, sqlconstraint=sqlconstraint,
                                        groupBy=groupBy)
        return metricdata

    def fetchConfig(self, *args, **kwargs):
        """Get config (metadata) info on source of data for metric calculation.

        This base implementation ignores its arguments and returns two empty dicts
        (configSummary, configDetails); it exists to demonstrate the API used by the driver.
        """
        # Demo API (for interface with driver).
        configSummary = {}
        configDetails = {}
        return configSummary, configDetails

    def query_arbitrary(self, sqlQuery, dtype=None):
        """Simple wrapper around execute_arbitrary for backwards compatibility.

        Parameters
        ----------
        sqlQuery : str
            SQL query.
        dtype: opt, numpy dtype.
            Numpy recarray dtype. If None, then an attempt to determine the dtype will be made.
            This attempt will fail if there are commas in the data you query.

        Returns
        -------
        numpy.recarray
        """
        return self.execute_arbitrary(sqlQuery, dtype=dtype)

    def query_columns(self, tablename, colnames=None, sqlconstraint=None,
                      groupBy=None, numLimit=None, chunksize=1000000):
        """Query a table in the database and return data from colnames in recarray.

        Parameters
        ----------
        tablename : str
            Name of table to query.
        colnames : list of str or None, opt
            Columns from the table to query for. If None, all columns are selected.
        sqlconstraint : str or None, opt
            Constraint to apply to to the query. Default None.
        groupBy : str or None, opt
            Name of column to group by. Default None.
        numLimit : int or None, opt
            Number of records to return. Default no limit.
        chunksize : int, opt
            Query database and convert to recarray in series of chunks of chunksize.
            None or 0 fetches and converts everything at once.

        Returns
        -------
        numpy.recarray

        Raises
        ------
        ValueError
            If the table, a requested column, or the groupBy column is not available,
            or if an empty list of columns is requested.
        """
        # Build the sqlalchemy query from a single table, with various columns/constraints/etc.
        # Does NOT use a mapping between column names and database names - assumes the database names
        # are what the user will specify.
        tablename_str = str(tablename).replace('"', '')
        # _build_query also validates the table / column / groupBy names.
        query = self._build_query(tablename, colnames=colnames, sqlconstraint=sqlconstraint,
                                  groupBy=groupBy, numLimit=numLimit)
        # Resolve "all columns" here as well, so the dtype matches the selected columns
        # (iterating over None used to raise TypeError).
        if colnames is None:
            colnames = self.columnNames[tablename_str]

        # Determine dtype for numpy recarray.
        dtype = [self._column_dtype(tablename_str, col) for col in colnames]

        # Execute query on database.
        exec_query = self.connection.session.execute(query)

        if chunksize is None or chunksize == 0:
            # Fetch all results and convert to numpy recarray.
            return self._convert_results(exec_query.fetchall(), dtype)
        # Loop through results, converting in steps of chunksize.
        chunks = []
        results = exec_query.fetchmany(chunksize)
        while len(results) > 0:
            chunks.append(self._convert_results(results, dtype))
            results = exec_query.fetchmany(chunksize)
        if len(chunks) == 0:
            return np.recarray((0,), dtype=dtype)
        return np.hstack(chunks)

    def _column_dtype(self, tablename_str, col):
        """Return the recarray dtype entry ``(name, type[, length])`` for one column of a table."""
        colname = str(col).replace('"', '')
        ty = self.tables[tablename_str].c[colname].type
        dt = self.dbTypeMap[ty.__visit_name__]
        # Override the default length, if the type has it (for example, if it is VARCHAR(1)).
        # Only sized entries (type, length) carry a length to replace.
        length = getattr(ty, 'length', None)
        if length is not None and len(dt) > 1:
            dt = dt[:-1] + (length,)
        return (colname,) + dt

    def _build_query(self, tablename, colnames, sqlconstraint=None, groupBy=None, numLimit=None):
        """Validate the request and build a sqlalchemy query selecting colnames from tablename.

        Quotes are stripped from the table/column names for validation; colnames=None
        selects every column of the table.

        Raises
        ------
        ValueError
            If the table, a requested column, or the groupBy column is not available,
            or if an empty list of columns is requested.
        """
        tablename_str = str(tablename).replace('"', '')
        if tablename_str not in self.tables:
            raise ValueError('Tablename %s not in list of available tables (%s).'
                             % (tablename, self.tables.keys()))
        # Always look columns up by the unquoted name (the keys of self.columnNames).
        available = self.columnNames[tablename_str]
        if colnames is None:
            colnames = available
        else:
            for col in colnames:
                if str(col).replace('"', '') not in available:
                    raise ValueError("Requested column %s not available in table %s" % (col, tablename_str))
        if len(colnames) == 0:
            raise ValueError('No columns requested from table %s' % (tablename_str))
        if groupBy is not None:
            if str(groupBy).replace('"', '') not in available:
                raise ValueError("GroupBy column %s is not available in table %s" % (groupBy, tablename_str))
        # Put together sqlalchemy query object, selecting every requested column
        # (repeated column names are kept, not used as a signal to restart the query).
        query = self.connection.session.query(*[column(col) for col in colnames])
        query = query.select_from(self.tables[tablename_str])
        if sqlconstraint:
            query = query.filter(text(sqlconstraint))
        if groupBy is not None:
            query = query.group_by(groupBy)
        if numLimit is not None:
            query = query.limit(numLimit)
        return query

    def _convert_results(self, results, dtype):
        """Convert a sequence of result rows into a numpy recarray with the given dtype."""
        if len(results) == 0:
            return np.recarray((0,), dtype=dtype)
        # Have to do the tuple(xx) for py2 string objects. With py3 is okay to just pass results.
        return np.rec.fromrecords([tuple(xx) for xx in results], dtype=dtype)