Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from builtins import str 

2from builtins import object 

3import os, warnings 

4from sqlalchemy import create_engine 

5from sqlalchemy.orm import sessionmaker 

6from sqlalchemy.engine import url 

7from sqlalchemy.ext.declarative import declarative_base 

8from sqlalchemy import Column, Integer, String, Float 

9from sqlalchemy import ForeignKey 

10from sqlalchemy.orm import relationship, backref 

11from sqlalchemy.exc import DatabaseError 

12 

13import numpy as np 

14 

# Declarative base class shared by every table definition in this module
# (MetricRow, DisplayRow, PlotRow and SummaryStatRow all subclass it).
Base = declarative_base()

# Public names exported by this module.
__all__ = [
    'MetricRow',
    'DisplayRow',
    'PlotRow',
    'SummaryStatRow',
    'ResultsDb',
]

18 

class MetricRow(Base):
    """
    Define contents and format of metric list table.

    (Table to list all metrics, their metadata, and their output data files).
    """
    __tablename__ = "metrics"
    # Define columns in metric list table.
    metricId = Column(Integer, primary_key=True)
    metricName = Column(String)
    slicerName = Column(String)
    simDataName = Column(String)
    sqlConstraint = Column(String)
    metricMetadata = Column(String)
    # The data file the metric data is stored in.
    metricDataFile = Column(String)

    def __repr__(self):
        # metricId is formatted with '%s' rather than '%d': the primary key is None on a
        # row that has been created but not yet flushed, and '%d' % None raises TypeError.
        # For integer ids the output is unchanged.
        return ("<Metric(metricId='%s', metricName='%s', slicerName='%s', simDataName='%s', "
                "sqlConstraint='%s', metadata='%s', metricDataFile='%s')>"
                % (self.metricId, self.metricName, self.slicerName, self.simDataName,
                   self.sqlConstraint, self.metricMetadata, self.metricDataFile))

39 

class DisplayRow(Base):
    """
    Define contents and format of the displays table.

    (Table to list the display properties for each metric.)
    """
    __tablename__ = "displays"
    displayId = Column(Integer, primary_key=True)
    # Matches metricId in the metrics table.
    metricId = Column(Integer, ForeignKey('metrics.metricId'))
    # Group for displaying metric (in webpages).
    displayGroup = Column(String)
    # Subgroup for displaying metric.
    displaySubgroup = Column(String)
    # Order to display metric (within subgroup).
    displayOrder = Column(Float)
    # The figure caption.
    displayCaption = Column(String)
    metric = relationship("MetricRow", backref=backref('displays', order_by=displayId))

    def __repr__(self):
        # Must be spelled __repr__ for repr() to use it (the old '__rep__' spelling was
        # never called, so rows fell back to the default object representation).
        try:
            displayOrder = '%.1f' % self.displayOrder
        except TypeError:
            # e.g. displayOrder is None on a row where it was never set.
            displayOrder = '%s' % (self.displayOrder,)
        return ("<Display(displayGroup='%s', displaySubgroup='%s', "
                "displayOrder='%s', displayCaption='%s')>"
                % (self.displayGroup, self.displaySubgroup, displayOrder, self.displayCaption))

    # Backward-compatible alias for any code that called the misspelled name explicitly.
    __rep__ = __repr__

62 

class PlotRow(Base):
    """
    Define contents and format of plot list table.

    (Table to list all plots, link them to relevant metrics in MetricList, and provide info on filename).
    """
    __tablename__ = "plots"
    # Define columns in plot list table.
    plotId = Column(Integer, primary_key=True)
    # Matches metricID in MetricList table.
    metricId = Column(Integer, ForeignKey('metrics.metricId'))
    plotType = Column(String)
    plotFile = Column(String)
    metric = relationship("MetricRow", backref=backref('plots', order_by=plotId))

    def __repr__(self):
        # '%s' instead of '%d' so repr() does not raise TypeError when metricId is None;
        # integer ids render identically.
        return ("<Plot(metricId='%s', plotType='%s', plotFile='%s')>"
                % (self.metricId, self.plotType, self.plotFile))

80 

class SummaryStatRow(Base):
    """
    Define contents and format of the summary statistics table.

    (Table to list and link summary stats to relevant metrics in MetricList, and provide summary stat name,
    value and potentially a comment).
    """
    __tablename__ = "summarystats"
    # Define columns in summary statistics table.
    statId = Column(Integer, primary_key=True)
    # Matches metricID in MetricList table.
    metricId = Column(Integer, ForeignKey('metrics.metricId'))
    summaryName = Column(String)
    summaryValue = Column(Float)
    metric = relationship("MetricRow", backref=backref('summarystats', order_by=statId))

    def __repr__(self):
        # Numeric values keep the original '%f' formatting; anything '%f' cannot format
        # (e.g. None on an unpopulated row) falls back to '%s' instead of raising TypeError.
        try:
            summaryValue = '%f' % self.summaryValue
        except TypeError:
            summaryValue = '%s' % (self.summaryValue,)
        return ("<SummaryStat(metricId='%s', summaryName='%s', summaryValue='%s')>"
                % (self.metricId, self.summaryName, summaryValue))

99 

class ResultsDb(object):
    """The ResultsDb is a sqlite database containing information on the metrics run via MAF,
    the plots created, the display information (such as captions), and any summary statistics output.
    """
    def __init__(self, outDir=None, database=None, verbose=False):
        """
        Instantiate the results database, creating metrics, displays, plots and summarystats tables.

        Parameters
        ----------
        outDir : str, optional
            Directory holding the database. When `database` is None this defaults to '.'
            and is created if it does not already exist.
        database : str, optional
            Filename of the sqlite database. Defaults to 'resultsDb_sqlite.db' inside `outDir`.
            If `outDir` is also given, the two are joined.
        verbose : bool, optional
            Passed to sqlalchemy's create_engine as `echo` (log the emitted SQL).

        Raises
        ------
        OSError
            If the default-database `outDir` does not exist and cannot be created.
        ValueError
            If the tables cannot be created in the database (e.g. its directory does not exist).
        """
        # We now require resultsDb to be a sqlite file (for simplicity). Leaving as attribute though.
        self.driver = 'sqlite'
        # Connect to database.
        # For sqlite, connecting to a non-existent database creates it automatically.
        if database is None:
            # Using default value for database name, should specify directory.
            if outDir is None:
                outDir = '.'
            # Check for output directory, make if needed.
            if not os.path.isdir(outDir):
                try:
                    os.makedirs(outDir)
                except OSError as msg:
                    raise OSError(msg, '\n (If this was the database file (not outDir), '
                                       'remember to use kwarg "database")')
            self.database = os.path.join(outDir, 'resultsDb_sqlite.db')
        else:
            # Using non-default database, but may also specify directory root.
            if outDir is not None:
                database = os.path.join(outDir, database)
            self.database = database

        # URL.create() is the constructor in SQLAlchemy >= 1.4; calling URL() directly is
        # deprecated there and not supported in 2.0. Fall back for older SQLAlchemy versions.
        if hasattr(url.URL, 'create'):
            dbAddress = url.URL.create(self.driver, database=self.database)
        else:
            dbAddress = url.URL(self.driver, database=self.database)

        # Keep a handle on the engine so close() can release its pooled connections.
        self.engine = create_engine(dbAddress, echo=verbose)
        self.Session = sessionmaker(bind=self.engine)
        self.session = self.Session()
        # Create the tables, if they don't already exist.
        try:
            Base.metadata.create_all(self.engine)
        except DatabaseError:
            raise ValueError("Cannot create a %s database at %s. Check directory exists." % (self.driver,
                                                                                            self.database))
        # Max string length for the string fields of the numpy arrays returned by the getters.
        self.slen = 1024

    def close(self):
        """
        Close connection to database.

        Closes the session and disposes of the engine's connection pool.
        """
        self.session.close()
        # getattr guard: an instance built without __init__ has no engine attribute.
        engine = getattr(self, 'engine', None)
        if engine is not None:
            engine.dispose()

    def updateMetric(self, metricName, slicerName, simDataName, sqlConstraint,
                     metricMetadata, metricDataFile):
        """
        Add a row to the metrics table, unless an identical metric is already present.

        - metricName: the name of the metric
        - slicerName: the name of the slicer
        - simDataName: the name used to identify the simData
        - sqlConstraint: the sql constraint used to select data from the simData
        - metricMetadata: the metadata associated with the metric
        - metricDataFile: the data file the metric data is stored in

        None values for simDataName, sqlConstraint, metricMetadata and metricDataFile are
        stored as the string 'NULL'.
        If the same metric (same metricName, slicerName, simDataName, sqlConstraint, metadata)
        already exists, nothing is written (its metricDataFile is not updated).

        Returns metricId: the Id number of this metric in the metrics table.
        """
        if simDataName is None:
            simDataName = 'NULL'
        if sqlConstraint is None:
            sqlConstraint = 'NULL'
        if metricMetadata is None:
            metricMetadata = 'NULL'
        if metricDataFile is None:
            metricDataFile = 'NULL'
        # Check if metric has already been added to database.
        prev = self.session.query(MetricRow).filter_by(metricName=metricName,
                                                       slicerName=slicerName,
                                                       simDataName=simDataName,
                                                       metricMetadata=metricMetadata,
                                                       sqlConstraint=sqlConstraint).all()
        if len(prev) == 0:
            metricinfo = MetricRow(metricName=metricName, slicerName=slicerName, simDataName=simDataName,
                                   sqlConstraint=sqlConstraint, metricMetadata=metricMetadata,
                                   metricDataFile=metricDataFile)
            self.session.add(metricinfo)
            self.session.commit()
        else:
            metricinfo = prev[0]
        return metricinfo.metricId

    @staticmethod
    def _normalizeDisplayDict(displayDict):
        """
        Return a copy of displayDict ready for the displays table (the input is not modified).

        None values become 'NULL'; missing 'group', 'subgroup', 'order', 'caption' keys are
        filled with 'NULL'; an 'order' of 'NULL' becomes 0; a trailing '(auto)' tag is
        removed from the caption.
        """
        display = dict(displayDict)
        for k in display:
            if display[k] is None:
                display[k] = 'NULL'
        for k in ('group', 'subgroup', 'order', 'caption'):
            if k not in display:
                display[k] = 'NULL'
        if display['order'] == 'NULL':
            display['order'] = 0
        caption = display['caption']
        autoTag = '(auto)'
        # Strip only the trailing tag; str.replace(..., 1) would remove the first occurrence,
        # which is the wrong one if '(auto)' also appears earlier in the caption.
        if caption.endswith(autoTag):
            caption = caption[:-len(autoTag)]
        display['caption'] = caption
        return display

    def updateDisplay(self, metricId, displayDict, overwrite=True):
        """
        Add a row to or update a row in the displays table.

        - metricId: the metric Id of this metric in the metrics table
        - displayDict: dictionary containing the display info ('group', 'subgroup',
          'order', 'caption'); it is not modified.
        - overwrite: if False and a display row already exists for metricId, do nothing.

        Replaces existing row(s) with same metricId (when overwrite is True).
        """
        # Because we want to maintain 1-1 relationship between metricId's and displayDict's:
        # First check if a display line is present with this metricID.
        displayinfo = self.session.query(DisplayRow).filter_by(metricId=metricId).all()
        if len(displayinfo) > 0:
            if not overwrite:
                return
            for d in displayinfo:
                self.session.delete(d)
        # Then go ahead and add new displayDict (normalized copy, so the caller's dict is untouched).
        display = self._normalizeDisplayDict(displayDict)
        displayinfo = DisplayRow(metricId=metricId,
                                 displayGroup=display['group'], displaySubgroup=display['subgroup'],
                                 displayOrder=display['order'], displayCaption=display['caption'])
        self.session.add(displayinfo)
        self.session.commit()

    def updatePlot(self, metricId, plotType, plotFile):
        """
        Add a row to or update a row in the plot table.

        - metricId: the metric Id of this metric in the metrics table
        - plotType: the 'type' of this plot
        - plotFile: the filename of this plot

        Remove older rows with the same metricId, plotType and plotFile.
        """
        plotinfo = self.session.query(PlotRow).filter_by(metricId=metricId, plotType=plotType,
                                                         plotFile=plotFile).all()
        for p in plotinfo:
            self.session.delete(p)
        plotinfo = PlotRow(metricId=metricId, plotType=plotType, plotFile=plotFile)
        self.session.add(plotinfo)
        self.session.commit()

    @staticmethod
    def _isScalarNumber(value):
        """Return True for python int/float (incl. bool) and numpy integer/floating scalars."""
        # np.int64 is not an int subclass and np.float32 is not a float subclass on python 3,
        # so they must be listed explicitly.
        return isinstance(value, (int, float, np.integer, np.floating))

    def updateSummaryStat(self, metricId, summaryName, summaryValue):
        """
        Add a row to or update a row in the summary statistic table.

        - metricId: the metric ID of this metric in the metrics table
        - summaryName: the name of this summary statistic
        - summaryValue: the value for this summary statistic

        Most summary statistics will be a simple name (string) + value (number) pair.
        For special summary statistics which must return multiple values, the base name
        can be provided as 'name', together with a np recarray as 'value', where the
        recarray also has 'name' and 'value' columns (and each name/value pair is then saved
        as a summary statistic associated with this same metricId).
        Anything else triggers a warning and is not saved.
        """
        # Allow for special summary statistics which return data in a np structured array with
        # 'name' and 'value' columns. (specificially needed for TableFraction summary statistic).
        if isinstance(summaryValue, np.ndarray):
            names = summaryValue.dtype.names
            # dtype.names is None for a plain (non-structured) array; guard before 'in'.
            if names is not None and 'name' in names and 'value' in names:
                for value in summaryValue:
                    sSuffix = value['name']
                    if isinstance(sSuffix, bytes):
                        sSuffix = sSuffix.decode('utf-8')
                    else:
                        sSuffix = str(sSuffix)
                    rowValue = value['value']
                    # Convert numpy scalars to python floats (the column is Float) so the
                    # sqlite driver can bind them.
                    if self._isScalarNumber(rowValue):
                        rowValue = float(rowValue)
                    summarystat = SummaryStatRow(metricId=metricId,
                                                 summaryName=summaryName + ' ' + sSuffix,
                                                 summaryValue=rowValue)
                    self.session.add(summarystat)
                self.session.commit()
            else:
                warnings.warn('Warning! Cannot save non-conforming summary statistic.')
        # Most summary statistics will be simple numbers.
        elif self._isScalarNumber(summaryValue):
            summarystat = SummaryStatRow(metricId=metricId, summaryName=summaryName,
                                         summaryValue=float(summaryValue))
            self.session.add(summarystat)
            self.session.commit()
        else:
            warnings.warn('Warning! Cannot save summary statistic that is not a simple float or int')

    def getMetricId(self, metricName, slicerName=None, metricMetadata=None, simDataName=None):
        """
        Given a metric name and optional slicerName/metricMetadata/simData information,
        return a list of the matching metricIds (exact matches, ordered by slicerName
        then metricMetadata).
        """
        query = self.session.query(MetricRow.metricId, MetricRow.metricName, MetricRow.slicerName,
                                   MetricRow.metricMetadata,
                                   MetricRow.simDataName).filter(MetricRow.metricName == metricName)
        if slicerName is not None:
            query = query.filter(MetricRow.slicerName == slicerName)
        if metricMetadata is not None:
            query = query.filter(MetricRow.metricMetadata == metricMetadata)
        if simDataName is not None:
            query = query.filter(MetricRow.simDataName == simDataName)
        query = query.order_by(MetricRow.slicerName, MetricRow.metricMetadata)
        return [m.metricId for m in query]

    def getMetricIdLike(self, metricNameLike=None, slicerNameLike=None,
                        metricMetadataLike=None, simDataName=None):
        """
        Return a list of metricIds whose metricName / slicerName / metricMetadata contain the
        given substrings (SQL LIKE '%...%'), optionally restricted to an exact simDataName.
        Arguments left as None are not used to filter.

        NOTE: '%' and '_' inside the given substrings are not escaped, so they act as LIKE wildcards.
        """
        query = self.session.query(MetricRow.metricId, MetricRow.metricName, MetricRow.slicerName,
                                   MetricRow.metricMetadata,
                                   MetricRow.simDataName)
        if metricNameLike is not None:
            query = query.filter(MetricRow.metricName.like('%' + str(metricNameLike) + '%'))
        if slicerNameLike is not None:
            query = query.filter(MetricRow.slicerName.like('%' + str(slicerNameLike) + '%'))
        if metricMetadataLike is not None:
            query = query.filter(MetricRow.metricMetadata.like('%' + str(metricMetadataLike) + '%'))
        if simDataName is not None:
            query = query.filter(MetricRow.simDataName == simDataName)
        return [m.metricId for m in query]

    def getAllMetricIds(self):
        """
        Return a list of all metricIds.
        """
        return [m.metricId for m in self.session.query(MetricRow.metricId).all()]

    def _asMetricIdList(self, metricId):
        """Normalize a metricId argument: None -> all metricIds, a scalar -> [scalar], else unchanged."""
        if metricId is None:
            return self.getAllMetricIds()
        if not hasattr(metricId, '__iter__'):
            return [metricId]
        return metricId

    def getSummaryStats(self, metricId=None, summaryName=None):
        """
        Get the summary stats (optionally for metricId list).
        Optionally, also specify the summary metric name.
        Returns a numpy array of the metric information + summary statistic information.
        """
        summarystats = []
        for mid in self._asMetricIdList(metricId):
            # Join the metric table and the summarystat table, based on the metricID (the second filter)
            query = (self.session.query(MetricRow, SummaryStatRow).filter(MetricRow.metricId == mid)
                     .filter(MetricRow.metricId == SummaryStatRow.metricId))
            if summaryName is not None:
                query = query.filter(SummaryStatRow.summaryName == summaryName)
            for m, s in query:
                summarystats.append((m.metricId, m.metricName, m.slicerName, m.metricMetadata,
                                     s.summaryName, s.summaryValue))
        # Convert to numpy array.
        dtype = np.dtype([('metricId', int), ('metricName', np.str_, self.slen),
                          ('slicerName', np.str_, self.slen), ('metricMetadata', np.str_, self.slen),
                          ('summaryName', np.str_, self.slen), ('summaryValue', float)])
        return np.array(summarystats, dtype)

    @staticmethod
    def _thumbFile(plotFile):
        """Return the thumbnail filename for plotFile: 'thumb.' + name without extension + '.png'."""
        # splitext removes only the final extension (the rest of the name may contain '.' or
        # '_'), and keeps the full name when there is no extension instead of yielding 'thumb..png'.
        return 'thumb.' + os.path.splitext(plotFile)[0] + '.png'

    def getPlotFiles(self, metricId=None):
        """
        Return the metricId, name, metadata, and all plot info (optionally for metricId list).
        Returns a numpy array of the metric information + plot file names.
        """
        plotFiles = []
        for mid in self._asMetricIdList(metricId):
            # Join the metric table and the plot table based on the metricID (the second filter does the join)
            query = (self.session.query(MetricRow, PlotRow).filter(MetricRow.metricId == mid)
                     .filter(MetricRow.metricId == PlotRow.metricId))
            for m, p in query:
                plotFiles.append((m.metricId, m.metricName, m.metricMetadata,
                                  p.plotType, p.plotFile, self._thumbFile(p.plotFile)))
        # Convert to numpy array.
        dtype = np.dtype([('metricId', int), ('metricName', np.str_, self.slen),
                          ('metricMetadata', np.str_, self.slen),
                          ('plotType', np.str_, self.slen), ('plotFile', np.str_, self.slen),
                          ('thumbFile', np.str_, self.slen)])
        return np.array(plotFiles, dtype)

    def getMetricDataFiles(self, metricId=None):
        """
        Get the metric data filenames for all or a single metric (or a list of metricIds).
        Returns a list.
        """
        dataFiles = []
        for mid in self._asMetricIdList(metricId):
            for m in self.session.query(MetricRow).filter(MetricRow.metricId == mid).all():
                dataFiles.append(m.metricDataFile)
        return dataFiles

    def getMetricInfo(self, metricId=None):
        """Get the simple metric info, without display information.

        Returns a numpy array; 'baseMetricNames' is the metricName up to its first '_'.
        """
        metricInfo = []
        for mId in self._asMetricIdList(metricId):
            # Query for all rows in metrics that match this metricId.
            query = self.session.query(MetricRow).filter(MetricRow.metricId == mId)
            for m in query:
                baseMetricName = m.metricName.split('_')[0]
                metricInfo.append((m.metricId, m.metricName, baseMetricName, m.slicerName,
                                   m.sqlConstraint, m.metricMetadata, m.metricDataFile))
        # Convert to numpy array.
        dtype = np.dtype([('metricId', int), ('metricName', np.str_, self.slen),
                          ('baseMetricNames', np.str_, self.slen),
                          ('slicerName', np.str_, self.slen),
                          ('sqlConstraint', np.str_, self.slen),
                          ('metricMetadata', np.str_, self.slen),
                          ('metricDataFile', np.str_, self.slen)])
        return np.array(metricInfo, dtype)

    def getMetricDisplayInfo(self, metricId=None):
        """
        Get the contents of the metrics and displays table, together with the 'basemetricname'
        (optionally, for metricId list).
        Returns a numpy array of the metric information + display information.

        One underlying assumption here is that all metrics have some display info.
        In newer batches, this may not be the case, as the display info gets auto-generated when the
        metric is plotted. Metrics without a displays row are omitted from the result (inner join).
        """
        metricInfo = []
        for mId in self._asMetricIdList(metricId):
            # Query for all rows in metrics and displays that match this metricId.
            query = (self.session.query(MetricRow, DisplayRow).filter(MetricRow.metricId == mId)
                     .filter(MetricRow.metricId == DisplayRow.metricId))
            for m, d in query:
                baseMetricName = m.metricName.split('_')[0]
                metricInfo.append((m.metricId, m.metricName, baseMetricName, m.slicerName,
                                   m.sqlConstraint, m.metricMetadata, m.metricDataFile,
                                   d.displayGroup, d.displaySubgroup, d.displayOrder, d.displayCaption))
        # Convert to numpy array.
        dtype = np.dtype([('metricId', int), ('metricName', np.str_, self.slen),
                          ('baseMetricNames', np.str_, self.slen),
                          ('slicerName', np.str_, self.slen),
                          ('sqlConstraint', np.str_, self.slen),
                          ('metricMetadata', np.str_, self.slen),
                          ('metricDataFile', np.str_, self.slen),
                          ('displayGroup', np.str_, self.slen),
                          ('displaySubgroup', np.str_, self.slen),
                          ('displayOrder', float),
                          ('displayCaption', np.str_, self.slen * 10)])
        return np.array(metricInfo, dtype)