Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from builtins import zip 

2from builtins import range 

3from builtins import object 

4import os 

5import numpy as np 

6import warnings 

7import matplotlib.pyplot as plt 

8import lsst.sims.maf.utils as utils 

9 

# Names exported by "from <module> import *".
__all__ = [
    'applyZPNorm',
    'PlotHandler',
    'BasePlotter',
]

11 

def applyZPNorm(metricValue, plotDict):
    """
    Apply the optional zeropoint offset and normalization from a plotDict.

    If plotDict has a non-None 'zp', it is subtracted from metricValue first.
    Then, if plotDict has a non-None 'normVal', the result is divided by it.

    Parameters
    ----------
    metricValue : number or numpy array
        Value(s) to adjust.
    plotDict : dict
        Plot dictionary that may contain 'zp' and/or 'normVal'.

    Returns
    -------
    The adjusted value(s). metricValue is returned as-is when neither key is set.
    """
    zeropoint = plotDict['zp'] if 'zp' in plotDict else None
    shifted = metricValue if zeropoint is None else metricValue - zeropoint
    normalization = plotDict['normVal'] if 'normVal' in plotDict else None
    return shifted if normalization is None else shifted / normalization

20 

21 

class BasePlotter(object):
    """
    Serve as the base type for MAF plotters and example of API.

    PlotHandler uses a plotter through these attributes:
      plotType : str or None
          Selects the automatic axis labels and is part of saved plot file names.
      objectPlotter : bool
          True if the plotter can handle metric values whose metricDtype is 'object'.
      defaultPlotDict : dict
          Default plot parameters. Non-None entries override PlotHandler's
          automatically generated values.
    and by calling it as plotter(metricValue, slicer, plotDict, fignum=fignum).
    """
    def __init__(self):
        self.plotType = None
        # PlotHandler.plot consults this before plotting 'object' dtype metric values.
        self.objectPlotter = False
        # This should be included in every subsequent defaultPlotDict (assumed to be present).
        self.defaultPlotDict = {'title': None, 'xlabel': None, 'label': None,
                                'labelsize': None, 'fontsize': None, 'figsize': None}

    def __call__(self, metricValue, slicer, userPlotDict, fignum=None):
        """
        Plot metricValue (computed with slicer) using userPlotDict.

        Subclasses should draw into figure `fignum` (or a new figure when it is None)
        and return the figure number. PlotHandler passes that number back in as fignum
        for the next bundle. The base implementation draws nothing and returns None.
        """
        pass

34 

35 

class PlotHandler(object):
    """
    Make (and optionally save) plots of one or more MetricBundles that share a slicer type.

    Usage: create a PlotHandler, give it bundles with setMetricBundles(), then call
    plot() with a plotter object. The plotter must expose `plotType` and
    `defaultPlotDict`, and optionally `objectPlotter`. It is called as
    plotter(metricValues, slicer, plotDict, fignum=...) and returns a figure number.
    Titles, legend labels, colors, axis labels and output file names are generated
    from the bundles' runNames, metadata and metric names. See setPlotDicts() for
    how these automatic values are merged with plotter and user settings.

    Parameters
    ----------
    outDir : str, optional
        Directory where figures and thumbnails are written. Default '.'.
    resultsDb : object, optional
        If set, each saved plot is recorded through its updateMetric,
        updateDisplay and updatePlot methods.
    savefig : bool, optional
        If True, plot() writes the figure to disk. Default True.
    figformat : str, optional
        File format (and extension) of saved figures. Default 'pdf'.
    dpi : int, optional
        Resolution of saved figures. Default 600.
    thumbnail : bool, optional
        Also write a 72 dpi png thumbnail named 'thumb.<file root>_<plotType>.png'. Default True.
    trimWhitespace : bool, optional
        Save figures with bbox_inches='tight'. Default True.
    """

    def __init__(self, outDir='.', resultsDb=None, savefig=True,
                 figformat='pdf', dpi=600, thumbnail=True, trimWhitespace=True):
        self.outDir = outDir
        self.resultsDb = resultsDb
        self.savefig = savefig
        self.figformat = figformat
        self.dpi = dpi
        self.trimWhitespace = trimWhitespace
        self.thumbnail = thumbnail
        # Default line color for each filter, used by _buildColors.
        self.filtercolors = {'u': 'cyan', 'g': 'g', 'r': 'y',
                             'i': 'r', 'z': 'm', 'y': 'k', ' ': None}
        # Insertion index for each filter, used to order bundles in setMetricBundles.
        self.filterorder = {' ': -1, 'u': 0, 'g': 1, 'r': 2, 'i': 3, 'z': 4, 'y': 5}

    def _filterIndex(self, mB, default):
        """Return the filterorder index of the last one-character filter token in mB.fileRoot, else default."""
        indices = [self.filterorder[token] for token in mB.fileRoot.split('_')
                   if len(token) == 1 and token in self.filterorder]
        return indices[-1] if indices else default

    def setMetricBundles(self, mBundles):
        """
        Set the metric bundle or bundles (list or dictionary).
        Reuse the PlotHandler by resetting this reference.
        The metric bundles have to have the same slicer.

        This also recomputes the joint metric names, run names, metadata and
        constraints, and resets the stored plotDicts.

        Raises
        ------
        ValueError
            If no bundles are given, or the bundles do not all use the same slicer type.
        """
        bundles = mBundles.values() if isinstance(mBundles, dict) else mBundles
        self.mBundles = []
        # Try to add the metricBundles in filter order. A bundle whose fileRoot names a
        # filter is inserted at that filter's index. Other bundles are appended.
        for mB in bundles:
            self.mBundles.insert(self._filterIndex(mB, len(self.mBundles)), mB)
        if len(self.mBundles) == 0:
            raise ValueError('setMetricBundles requires at least one MetricBundle')
        self.slicer = self.mBundles[0].slicer
        for mB in self.mBundles:
            if mB.slicer.slicerName != self.slicer.slicerName:
                raise ValueError('MetricBundle items must have the same type of slicer')
        self._combineMetricNames()
        self._combineRunNames()
        self._combineMetadata()
        self._combineConstraints()
        self.setPlotDicts(reset=True)

    def setPlotDicts(self, plotDicts=None, plotFunc=None, reset=False):
        """
        Set or update (or 'reset') the plotDict for the (possibly joint) plots.

        Resolution is:
        auto-generated items (colors/labels/titles)
        < anything previously set in the plotHandler
        < defaults set by the plotter
        < explicitly set items in the metricBundle plotDict
        < explicitly set items in the plotDicts list passed to this method.

        Exception: when plotFunc is given, the automatically generated 'xlabel' and
        'ylabel' are applied after the previously set values, so they replace them.

        Parameters
        ----------
        plotDicts : dict or list of dict, optional
            One dict, applied to every bundle, or one dict per bundle.
        plotFunc : plotter, optional
            Used to build x/y labels and to apply its defaultPlotDict.
        reset : bool, optional
            Discard previously stored plotDicts first.
        """
        if reset:
            # Each bundle needs its own (separate) blank dictionary.
            self.plotDicts = [{} for _ in self.mBundles]

        if isinstance(plotDicts, dict):
            # We were passed a single dictionary, not a list.
            plotDicts = [plotDicts] * len(self.mBundles)

        autoLabelList = self._buildLegendLabels()
        autoColorList = self._buildColors()
        autoCbar = self._buildCbarFormat()
        autoTitle = self._buildTitle()
        if plotFunc is not None:
            autoXlabel, autoYlabel = self._buildXYlabels(plotFunc)

        for i, bundle in enumerate(self.mBundles):
            # First use the auto-generated values.
            tmpPlotDict = {'title': autoTitle,
                           'label': autoLabelList[i],
                           'color': autoColorList[i],
                           'cbarFormat': autoCbar}
            # Then update that with anything previously set in the plotHandler.
            tmpPlotDict.update(self.plotDicts[i])
            if plotFunc is not None:
                # Labels depend on the plot type, so they replace stored values.
                tmpPlotDict['xlabel'] = autoXlabel
                tmpPlotDict['ylabel'] = autoYlabel
                # Non-None plotter defaults replace auto-generated values.
                for k, v in plotFunc.defaultPlotDict.items():
                    if v is not None:
                        tmpPlotDict[k] = v
            # Then add/override based on the bundle plotDict parameters.
            tmpPlotDict.update(bundle.plotDict)
            # Finally, override with anything set explicitly by the user right now.
            if plotDicts is not None:
                tmpPlotDict.update(plotDicts[i])
            self.plotDicts[i] = tmpPlotDict

        # Check that the plotDicts do not conflict.
        self._checkPlotDicts()

    def _combineMetricNames(self):
        """
        Set self.metricNames (set of names) and self.jointMetricNames (one string).

        If all names have the same number of words, the words at each position are
        merged. Filter letters are written in u, g, r, i, z, y order (e.g.
        'Nvisits u' + 'Nvisits g' -> 'Nvisits ug').
        """
        self.metricNames = set(mB.metric.name for mB in self.mBundles)
        order = ['u', 'g', 'r', 'i', 'z', 'y']
        # Sort so the joint name (used in titles and file names) does not depend on set order.
        names = sorted(self.metricNames)
        if len(names) == 1:
            jointName = names[0]
        else:
            nameLists = [x.split() for x in names]
            nameLengths = [len(x) for x in nameLists]
            if len(set(nameLengths)) == 1:
                parts = []
                for i in range(nameLengths[0]):
                    words = set(x[i] for x in nameLists)
                    if words.issubset(order):
                        # Special case: filter letters, put them in filter order.
                        parts.append(''.join(f for f in order if f in words))
                    else:
                        parts.append(''.join(sorted(words)))
                jointName = ' '.join(parts)
            else:
                # Names of different lengths: just join everything.
                jointName = ' '.join(names)
        self.jointMetricNames = jointName

    def _combineRunNames(self):
        """Set self.runNames (set) and self.jointRunNames (sorted, space-separated)."""
        self.runNames = set(mB.runName for mB in self.mBundles)
        self.jointRunNames = ' '.join(sorted(self.runNames))

    def _combineMetadata(self):
        """
        Set self.metadata (set of strings) and self.jointMetadata (one string).

        Each metadata string is split into phrases (on ' and ' or ', '). Phrases shared
        by all bundles are written once, after the phrases that differ.
        """
        self.metadata = set(mB.metadata for mB in self.mBundles)
        if len(self.metadata) == 1:
            self.jointMetadata = ' '.join(self.metadata)
            return
        order = ['u', 'g', 'r', 'i', 'z', 'y']
        # Split metadata into phrases (filter / proposal / constraint ...).
        # sorted() keeps the result (used in file names) deterministic.
        splitmetas = []
        for m in sorted(self.metadata):
            if ' and ' in m:
                phrases = m.split(' and ')
            elif ', ' in m:
                phrases = m.split(', ')
            else:
                phrases = [m]
            splitmetas.append(set(p.strip() for p in phrases))
        # Phrases common to all bundles, and what remains for each bundle.
        common = set.intersection(*splitmetas)
        diff = [x.difference(common) for x in splitmetas]
        # Within the remaining phrase of each bundle, look for common words.
        diffsplit = [set(sorted(d)[0].split()) if d else set() for d in diff]
        diffcommon = set.intersection(*diffsplit)
        diffdiff = [x.difference(diffcommon) for x in diffsplit]
        if min(len(x) for x in diffdiff) == 0:
            # Some bundle has nothing left after removing common words: do not subdivide.
            # Bundles whose metadata is entirely common contribute no phrase, which
            # avoids indexing an empty set.
            phrases = [sorted(d)[0] for d in diff if d]
            # Sort by length (so it goes 'g', 'g dithered', etc.).
            diffdiff = sorted(phrases, key=len)
            diffcommon = []
        else:
            # diffdiff is where filter values are expected; put single filters in order.
            ordered = []
            for f in order:
                ordered.extend(d for d in diffdiff if d == {f})
            ordered += [d for d in diffdiff if d not in ordered]
            diffdiff = [' '.join(sorted(c)) for c in ordered]
        # And put it all back together.
        self.jointMetadata = (', '.join(diffdiff) + ' ' +
                              ' '.join(sorted(diffcommon)) + ' ' +
                              ' '.join(sorted(common)))

    def _combineConstraints(self):
        """Set self.constraints: the distinct non-None constraints, sorted and joined with '; '."""
        constraints = set(mB.constraint for mB in self.mBundles if mB.constraint is not None)
        self.constraints = '; '.join(sorted(constraints))

    def _buildTitle(self):
        """
        Build a plot title from the runName, metadata and metric name when each is unique.
        If none of them is unique, use the joint metadata and joint metric names.
        """
        plotTitle = ''
        if len(self.runNames) == 1:
            plotTitle += list(self.runNames)[0]
        if len(self.metadata) == 1:
            plotTitle += ' ' + list(self.metadata)[0]
        if len(self.metricNames) == 1:
            plotTitle += ': ' + list(self.metricNames)[0]
        if plotTitle == '':
            plotTitle = self.jointMetadata + ' ' + self.jointMetricNames
        return plotTitle

    @staticmethod
    def _labelWithUnits(name, units):
        """Return 'name (units)', or just name when units is None or empty."""
        if units:
            return name + ' (' + units + ')'
        return name

    def _buildXYlabels(self, plotFunc):
        """
        Build the (xlabel, ylabel) pair for plotFunc.plotType.

        'BinnedData': slice column vs. metric. 'MetricVsH': H (mag) vs. metric.
        Anything else: metric on x. The y label is None, or the shared 'ylabel'
        from the bundles' plotDicts when several bundles agree on one.
        """
        if plotFunc.plotType == 'BinnedData':
            if len(self.mBundles) == 1:
                mB = self.mBundles[0]
                xlabel = self._labelWithUnits(mB.slicer.sliceColName, mB.slicer.sliceColUnits)
                ylabel = self._labelWithUnits(mB.metric.name, mB.metric.units)
            else:
                xlabel = ', '.join(sorted(set(mB.slicer.sliceColName for mB in self.mBundles)))
                ylabel = self.jointMetricNames
        elif plotFunc.plotType == 'MetricVsH':
            if len(self.mBundles) == 1:
                mB = self.mBundles[0]
                ylabel = self._labelWithUnits(mB.metric.name, mB.metric.units)
            else:
                ylabel = self.jointMetricNames
            xlabel = 'H (mag)'
        else:
            if len(self.mBundles) == 1:
                mB = self.mBundles[0]
                xlabel = self._labelWithUnits(mB.metric.name, mB.metric.units)
                ylabel = None
            else:
                xlabel = self.jointMetricNames
                ylabels = set(mB.plotDict['ylabel'] for mB in self.mBundles
                              if 'ylabel' in mB.plotDict)
                ylabel = list(ylabels)[0] if len(ylabels) == 1 else None
        return xlabel, ylabel

    def _buildLegendLabels(self):
        """
        Build legend labels from the parts of runName/metadata/metricName that change.
        A bundle's own plotDict['label'] is used when present. A single bundle gets [None].
        """
        if len(self.mBundles) == 1:
            return [None]
        labels = []
        for mB in self.mBundles:
            if 'label' in mB.plotDict:
                labels.append(mB.plotDict['label'])
                continue
            parts = []
            if len(self.runNames) > 1:
                parts.append(mB.runName)
            if len(self.metadata) > 1:
                parts.append(mB.metadata)
            if len(self.metricNames) > 1:
                parts.append(mB.metric.name)
            labels.append(' '.join(parts))
        return labels

    def _buildColors(self):
        """
        Choose one color per bundle.

        Uses, in order: the bundle's plotDict['color']; the color of a filter quoted
        in its constraint (e.g. filter = "r" or filter = 'r'); otherwise 'b'.
        If several bundles end up with the same color, random colors are used instead.
        """
        if len(self.mBundles) == 1:
            if 'color' in self.mBundles[0].plotDict:
                return [self.mBundles[0].plotDict['color']]
            return ['b']
        colors = []
        for mB in self.mBundles:
            if 'color' in mB.plotDict:
                colors.append(mB.plotDict['color'])
                continue
            color = 'b'
            if mB.constraint is not None and 'filter' in mB.constraint:
                # A quoted single character is guessed to be the filter value.
                # Both quoting styles are accepted.
                for v in mB.constraint.replace("'", '"').split('"'):
                    if len(v) == 1 and v in self.filtercolors:
                        color = self.filtercolors[v]
            colors.append(color)
        # If we ended up with the same color throughout (e.g. all in one filter),
        # generate random colors instead.
        if self._countDistinct(colors) == 1:
            colors = [np.random.rand(3,) for _ in self.mBundles]
        return colors

    def _buildCbarFormat(self):
        """Return '%d' when every bundle's metricDtype is 'int', else None."""
        metricDtypes = set(mB.metric.metricDtype for mB in self.mBundles)
        if metricDtypes == {'int'}:
            return '%d'
        return None

    def _buildFileRoot(self, outfileSuffix=None):
        """
        Build a root filename for plot outputs.
        If there is only one metricBundle, this is equal to the metricBundle fileRoot + outfileSuffix.
        For multiple metricBundles, this is created from the runNames, metadata and metric names.

        If you do not wish to use the automatic filenames, then you could set 'savefig' to False and
        save the file manually to disk, using the plot figure numbers returned by 'plot'.
        """
        if len(self.mBundles) == 1:
            outfile = self.mBundles[0].fileRoot
        else:
            outfile = '_'.join([self.jointRunNames, self.jointMetricNames, self.jointMetadata])
            outfile += '_' + self.mBundles[0].slicer.slicerName[:4].upper()
        if outfileSuffix is not None:
            outfile += '_' + outfileSuffix
        return utils.nameSanitize(outfile)

    def _buildDisplayDict(self):
        """
        Build the display dictionary for the plot.

        A single bundle returns its own displayDict. For several bundles, a new dict
        is built with:
          - the shared group/subgroup, or 'Comparisons' when they differ;
          - an order placed after the largest bundle order;
          - a generated caption.
        """
        if len(self.mBundles) == 1:
            return self.mBundles[0].displayDict
        displayDict = {}
        group = set(mB.displayDict['group'] for mB in self.mBundles)
        subgroup = set(mB.displayDict['subgroup'] for mB in self.mBundles)
        # Place the combined plot after all of its components. Order 0 is kept when no
        # component has a positive order.
        maxOrder = max(mB.displayDict['order'] for mB in self.mBundles)
        displayDict['order'] = maxOrder + 1 if maxOrder > 0 else 0
        displayDict['group'] = 'Comparisons' if len(group) > 1 else list(group)[0]
        displayDict['subgroup'] = 'Comparisons' if len(subgroup) > 1 else list(subgroup)[0]
        displayDict['caption'] = ('%s metric(s) calculated on a %s grid, '
                                  'for opsim runs %s, for metadata values of %s.'
                                  % (self.jointMetricNames,
                                     self.mBundles[0].slicer.slicerName,
                                     self.jointRunNames, self.jointMetadata))
        return displayDict

    @staticmethod
    def _sameValue(a, b):
        """Return True if a and b are equal, comparing list/tuple/array values as whole values."""
        if a is b:
            return True
        try:
            return bool(np.all(a == b)) and np.shape(a) == np.shape(b)
        except (TypeError, ValueError):
            return False

    @classmethod
    def _countDistinct(cls, values):
        """Count distinct entries of values using _sameValue (works for unhashable values)."""
        distinct = []
        for v in values:
            if not any(cls._sameValue(v, d) for d in distinct):
                distinct.append(v)
        return len(distinct)

    def _checkPlotDicts(self):
        """
        Check that plotDicts drawn on the same subplot do not conflict.

        Raises ValueError if there is not exactly one plotDict per bundle.
        If the plotDicts share at most one 'subplot' value, each of 'xlim', 'ylim',
        'colorMin', 'colorMax' and 'title' is checked for more than one distinct
        non-None value. When that happens:
          - '*Max' keys are set to the largest value;
          - '*Min' keys are set to the smallest value;
          - 'title' is set to the auto-generated title;
          - any other key is reset to None, with a warning.
        """
        if len(self.plotDicts) != len(self.mBundles):
            raise ValueError('plotDicts (%i) must be same length as mBundles (%i)'
                             % (len(self.plotDicts), len(self.mBundles)))

        # These are the keys that need to match (or be None).
        keys2Check = ['xlim', 'ylim', 'colorMin', 'colorMax', 'title']

        # If there is more than one subplot, assume the plotDicts are compatible.
        subplots = set(pd['subplot'] for pd in self.plotDicts if 'subplot' in pd)
        if len(subplots) > 1:
            return

        reset_keys = []
        for key in keys2Check:
            # None means 'not set', so it never conflicts with a real value.
            values = [pd[key] for pd in self.plotDicts if pd.get(key) is not None]
            # Compare whole values. np.unique would flatten e.g. [0, 10] limits into
            # separate numbers and report identical limits as a conflict.
            if self._countDistinct(values) <= 1:
                continue
            if key.endswith('Max'):
                for pd in self.plotDicts:
                    pd[key] = np.max(values)
            elif key.endswith('Min'):
                for pd in self.plotDicts:
                    pd[key] = np.min(values)
            elif key == 'title':
                title = self._buildTitle()
                for pd in self.plotDicts:
                    pd['title'] = title
            else:
                warnings.warn('Found more than one value to be set for "%s" in the plotDicts.' % (key) +
                              ' Will reset to default value. (found values %s)' % values)
                reset_keys.append(key)
        # Reset the remaining conflicting keys to defaults; this can generally be done safely.
        for key in reset_keys:
            for pd in self.plotDicts:
                pd[key] = None

    def plot(self, plotFunc, plotDicts=None, displayDict=None, outfileRoot=None, outfileSuffix=None):
        """
        Create plot for mBundles, using plotFunc.

        plotDicts: List of plotDicts if one wants to use a _new_ plotDict per MetricBundle.
        These are merged into (and stored in) self.plotDicts via setPlotDicts.

        Returns the figure number. Returns None if the plotter cannot handle 'object'
        metric values, or if no bundle had metricValues to plot.
        """
        # Plotters that do not declare objectPlotter are treated as non-object plotters.
        if not getattr(plotFunc, 'objectPlotter', False):
            # Most plotters expect float data; object data only works if it is a color.
            for mB in self.mBundles:
                if mB.metric.metricDtype == 'object':
                    if not mB.plotDict.get('metricIsColor', False):
                        warnings.warn('Cannot plot object metric values with this plotter.')
                        return

        # Update x/y labels using plotType.
        self.setPlotDicts(plotDicts=plotDicts, plotFunc=plotFunc, reset=False)
        if outfileRoot is None:
            outfile = self._buildFileRoot(outfileSuffix)
        else:
            outfile = outfileRoot
        plotType = plotFunc.plotType
        if len(self.mBundles) > 1:
            plotType = 'Combo' + plotType

        # Make plot; each bundle is drawn onto the figure returned for the previous one.
        fignum = None
        plotted = False
        for mB, plotDict in zip(self.mBundles, self.plotDicts):
            if mB.metricValues is None:
                msg = 'MetricBundle (%s) has no attribute "metricValues".' % (mB.fileRoot)
                msg += ' Either the values have not been calculated or they have been deleted.'
                warnings.warn(msg)
            else:
                fignum = plotFunc(mB.metricValues, mB.slicer, plotDict, fignum=fignum)
                plotted = True
        if not plotted:
            # Nothing was drawn. Do not add a legend to, or save, an empty figure.
            return None

        # Add a legend if more than one metricValue is plotted or if legendloc is specified.
        defaultLoc = 'upper right' if len(self.mBundles) > 1 else None
        legendloc = self.plotDicts[0].get('legendloc', defaultLoc)
        if legendloc is not None:
            plt.figure(fignum)
            plt.legend(loc=legendloc, fancybox=True, fontsize='smaller')
        # Add the super title if provided.
        if 'suptitle' in self.plotDicts[0]:
            plt.figure(fignum)
            plt.suptitle(self.plotDicts[0]['suptitle'])
        # Save to disk and file info to resultsDb if desired.
        if self.savefig:
            if displayDict is None:
                displayDict = self._buildDisplayDict()
            self.saveFig(fignum, outfile, plotType, self.jointMetricNames, self.slicer.slicerName,
                         self.jointRunNames, self.constraints, self.jointMetadata, displayDict)
        return fignum

    def saveFig(self, fignum, outfileRoot, plotType, metricName, slicerName,
                runName, constraint, metadata, displayDict=None):
        """
        Save figure `fignum` to outDir as '<outfileRoot>_<plotType>.<figformat>'.

        Optionally also writes a png thumbnail, and records the plot in resultsDb
        when one is configured.
        """
        fig = plt.figure(fignum)
        plotFile = outfileRoot + '_' + plotType + '.' + self.figformat
        # 'format' is the Figure.savefig keyword. The old 'figformat' keyword is not
        # accepted by current matplotlib.
        saveKwargs = {'format': self.figformat, 'dpi': self.dpi}
        if self.trimWhitespace:
            saveKwargs['bbox_inches'] = 'tight'
        fig.savefig(os.path.join(self.outDir, plotFile), **saveKwargs)
        # Generate a png thumbnail of this same figure.
        if self.thumbnail:
            thumbFile = 'thumb.' + outfileRoot + '_' + plotType + '.png'
            fig.savefig(os.path.join(self.outDir, thumbFile), dpi=72, bbox_inches='tight')
        # Save information about the file to resultsDb.
        if self.resultsDb:
            if displayDict is None:
                displayDict = {}
            metricId = self.resultsDb.updateMetric(metricName, slicerName, runName, constraint,
                                                   metadata, None)
            self.resultsDb.updateDisplay(metricId=metricId, displayDict=displayDict, overwrite=False)
            self.resultsDb.updatePlot(metricId=metricId, plotType=plotType, plotFile=plotFile)