Coverage for python/lsst/sims/maf/plots/plotHandler.py : 5%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from builtins import zip
2from builtins import range
3from builtins import object
4import os
5import numpy as np
6import warnings
7import matplotlib.pyplot as plt
8import lsst.sims.maf.utils as utils
# Public names exported by ``from ... plotHandler import *``.
__all__ = ["applyZPNorm", "PlotHandler", "BasePlotter"]
def applyZPNorm(metricValue, plotDict):
    """
    Shift and scale metric values according to a plot dictionary.

    Parameters
    ----------
    metricValue : number or array-like
        The value(s) to transform.
    plotDict : dict
        If it has a non-None 'zp', that value is subtracted first;
        if it has a non-None 'normVal', the result is then divided by it.

    Returns
    -------
    The transformed value(s); unchanged when neither key is set.
    """
    zeroPoint = plotDict.get('zp')
    if zeroPoint is not None:
        metricValue = metricValue - zeroPoint
    normalization = plotDict.get('normVal')
    if normalization is not None:
        metricValue = metricValue / normalization
    return metricValue
class BasePlotter(object):
    """
    Serve as the base type for MAF plotters and example of API.

    PlotHandler uses the following attributes of a plotter:
      plotType        -- string used in output file names and resultsDb entries
                         (and to choose automatic x/y labels).
      objectPlotter   -- whether the plotter can handle object-dtype metric values.
      defaultPlotDict -- defaults; non-None entries override auto-generated plot settings.
    and calls the plotter as ``fignum = plotter(metricValue, slicer, plotDict, fignum=fignum)``.
    """
    def __init__(self):
        self.plotType = None
        # PlotHandler.plot checks this before passing object-dtype metric values;
        # subclasses that can draw object data should set it to True.
        self.objectPlotter = False
        # This should be included in every subsequent defaultPlotDict (assumed to be present).
        self.defaultPlotDict = {'title': None, 'xlabel': None, 'label': None,
                                'labelsize': None, 'fontsize': None, 'figsize': None}

    def __call__(self, metricValue, slicer, userPlotDict, fignum=None):
        """
        Draw metricValue (evaluated on slicer) using userPlotDict.

        Subclasses should return the matplotlib figure number that was drawn on,
        since PlotHandler.plot passes it back in as `fignum` for the next bundle.
        The base implementation draws nothing and returns None.
        """
        pass
36class PlotHandler(object):
38 def __init__(self, outDir='.', resultsDb=None, savefig=True,
39 figformat='pdf', dpi=600, thumbnail=True, trimWhitespace=True):
40 self.outDir = outDir
41 self.resultsDb = resultsDb
42 self.savefig = savefig
43 self.figformat = figformat
44 self.dpi = dpi
45 self.trimWhitespace = trimWhitespace
46 self.thumbnail = thumbnail
47 self.filtercolors = {'u': 'cyan', 'g': 'g', 'r': 'y',
48 'i': 'r', 'z': 'm', 'y': 'k', ' ': None}
49 self.filterorder = {' ': -1, 'u': 0, 'g': 1, 'r': 2, 'i': 3, 'z': 4, 'y': 5}
51 def setMetricBundles(self, mBundles):
52 """
53 Set the metric bundle or bundles (list or dictionary).
54 Reuse the PlotHandler by resetting this reference.
55 The metric bundles have to have the same slicer.
56 """
57 self.mBundles = []
58 # Try to add the metricBundles in filter order.
59 if isinstance(mBundles, dict):
60 for mB in mBundles.values():
61 vals = mB.fileRoot.split('_')
62 forder = [self.filterorder.get(f, None) for f in vals if len(f) == 1]
63 forder = [o for o in forder if o is not None]
64 if len(forder) == 0:
65 forder = len(self.mBundles)
66 else:
67 forder = forder[-1]
68 self.mBundles.insert(forder, mB)
69 self.slicer = self.mBundles[0].slicer
70 else:
71 for mB in mBundles:
72 vals = mB.fileRoot.split('_')
73 forder = [self.filterorder.get(f, None) for f in vals if len(f) == 1]
74 forder = [o for o in forder if o is not None]
75 if len(forder) == 0:
76 forder = len(self.mBundles)
77 else:
78 forder = forder[-1]
79 self.mBundles.insert(forder, mB)
80 self.slicer = self.mBundles[0].slicer
81 for mB in self.mBundles:
82 if mB.slicer.slicerName != self.slicer.slicerName:
83 raise ValueError('MetricBundle items must have the same type of slicer')
84 self._combineMetricNames()
85 self._combineRunNames()
86 self._combineMetadata()
87 self._combineConstraints()
88 self.setPlotDicts(reset=True)
90 def setPlotDicts(self, plotDicts=None, plotFunc=None, reset=False):
91 """
92 Set or update (or 'reset') the plotDict for the (possibly joint) plots.
94 Resolution is:
95 auto-generated items (colors/labels/titles)
96 < anything previously set in the plotHandler
97 < defaults set by the plotter
98 < explicitly set items in the metricBundle plotDict
99 < explicitly set items in the plotDicts list passed to this method.
100 """
101 if reset:
102 # Have to explicitly set each dictionary to a (separate) blank dictionary.
103 self.plotDicts = [{} for b in self.mBundles]
105 if isinstance(plotDicts, dict):
106 # We were passed a single dictionary, not a list.
107 plotDicts = [plotDicts] * len(self.mBundles)
109 autoLabelList = self._buildLegendLabels()
110 autoColorList = self._buildColors()
111 autoCbar = self._buildCbarFormat()
112 autoTitle = self._buildTitle()
113 if plotFunc is not None:
114 autoXlabel, autoYlabel = self._buildXYlabels(plotFunc)
116 # Loop through each bundle and generate a plotDict for it.
117 for i, bundle in enumerate(self.mBundles):
118 # First use the auto-generated values.
119 tmpPlotDict = {}
120 tmpPlotDict['title'] = autoTitle
121 tmpPlotDict['label'] = autoLabelList[i]
122 tmpPlotDict['color'] = autoColorList[i]
123 tmpPlotDict['cbarFormat'] = autoCbar
124 # Then update that with anything previously set in the plotHandler.
125 tmpPlotDict.update(self.plotDicts[i])
126 # Then override with plotDict items set explicitly based on the plot type.
127 if plotFunc is not None:
128 tmpPlotDict['xlabel'] = autoXlabel
129 tmpPlotDict['ylabel'] = autoYlabel
130 # Replace auto-generated plot dict items with things
131 # set by the plotterDefaults, if they are not None.
132 plotterDefaults = plotFunc.defaultPlotDict
133 for k, v in plotterDefaults.items():
134 if v is not None:
135 tmpPlotDict[k] = v
136 # Then add/override based on the bundle plotDict parameters if they are set.
137 tmpPlotDict.update(bundle.plotDict)
138 # Finally, override with anything set explicitly by the user right now.
139 if plotDicts is not None:
140 tmpPlotDict.update(plotDicts[i])
141 # And save this new dictionary back in the class.
142 self.plotDicts[i] = tmpPlotDict
144 # Check that the plotDicts do not conflict.
145 self._checkPlotDicts()
147 def _combineMetricNames(self):
148 """
149 Combine metric names.
150 """
151 # Find the unique metric names.
152 self.metricNames = set()
153 for mB in self.mBundles:
154 self.metricNames.add(mB.metric.name)
155 # Find a pleasing combination of the metric names.
156 order = ['u', 'g', 'r', 'i', 'z', 'y']
157 if len(self.metricNames) == 1:
158 jointName = ' '.join(self.metricNames)
159 else:
160 # Split each unique name into a list to see if we can merge the names.
161 nameLengths = [len(x.split()) for x in self.metricNames]
162 nameLists = [x.split() for x in self.metricNames]
163 # If the metric names are all the same length, see if we can combine any parts.
164 if len(set(nameLengths)) == 1:
165 jointName = []
166 for i in range(nameLengths[0]):
167 tmp = set([x[i] for x in nameLists])
168 # Try to catch special case of filters and put them in order.
169 if tmp.intersection(order) == tmp:
170 filterlist = ''
171 for f in order:
172 if f in tmp:
173 filterlist += f
174 jointName.append(filterlist)
175 else:
176 # Otherwise, just join and put into jointName.
177 jointName.append(''.join(tmp))
178 jointName = ' '.join(jointName)
179 # If the metric names are not the same length, just join everything.
180 else:
181 jointName = ' '.join(self.metricNames)
182 self.jointMetricNames = jointName
184 def _combineRunNames(self):
185 """
186 Combine runNames.
187 """
188 self.runNames = set()
189 for mB in self.mBundles:
190 self.runNames.add(mB.runName)
191 self.jointRunNames = ' '.join(self.runNames)
193 def _combineMetadata(self):
194 """
195 Combine metadata.
196 """
197 metadata = set()
198 for mB in self.mBundles:
199 metadata.add(mB.metadata)
200 self.metadata = metadata
201 # Find a pleasing combination of the metadata.
202 if len(metadata) == 1:
203 self.jointMetadata = ' '.join(metadata)
204 else:
205 order = ['u', 'g', 'r', 'i', 'z', 'y']
206 # See if there are any subcomponents we can combine,
207 # splitting on some values we expect to separate metadata clauses.
208 splitmetas = []
209 for m in self.metadata:
210 # Try to split metadata into separate phrases (filter / proposal / constraint..).
211 if ' and ' in m:
212 m = m.split(' and ')
213 elif ', ' in m:
214 m = m.split(', ')
215 else:
216 m = [m, ]
217 # Strip white spaces from individual elements.
218 m = set([im.strip() for im in m])
219 splitmetas.append(m)
220 # Look for common elements and separate from the general metadata.
221 common = set.intersection(*splitmetas)
222 diff = [x.difference(common) for x in splitmetas]
223 # Now look within the 'diff' elements and see if there are any common words to split off.
224 diffsplit = []
225 for d in diff:
226 if len(d) > 0:
227 m = set([x.split() for x in d][0])
228 else:
229 m = set()
230 diffsplit.append(m)
231 diffcommon = set.intersection(*diffsplit)
232 diffdiff = [x.difference(diffcommon) for x in diffsplit]
233 # If the length of any of the 'differences' is 0, then we should stop and not try to subdivide.
234 lengths = [len(x) for x in diffdiff]
235 if min(lengths) == 0:
236 # Sort them in order of length (so it goes 'g', 'g dithered', etc.)
237 tmp = []
238 for d in diff:
239 tmp.append(list(d)[0])
240 diff = tmp
241 xlengths = [len(x) for x in diff]
242 idx = np.argsort(xlengths)
243 diffdiff = [diff[i] for i in idx]
244 diffcommon = []
245 else:
246 # diffdiff is the part where we might expect our filter values to appear;
247 # try to put this in order.
248 diffdiffOrdered = []
249 diffdiffEnd = []
250 for f in order:
251 for d in diffdiff:
252 if len(d) == 1:
253 if list(d)[0] == f:
254 diffdiffOrdered.append(d)
255 for d in diffdiff:
256 if d not in diffdiffOrdered:
257 diffdiffEnd.append(d)
258 diffdiff = diffdiffOrdered + diffdiffEnd
259 diffdiff = [' '.join(c) for c in diffdiff]
260 # And put it all back together.
261 combo = (', '.join([''.join(c) for c in diffdiff]) + ' ' +
262 ' '.join([''.join(d) for d in diffcommon]) + ' ' +
263 ' '.join([''.join(e) for e in common]))
264 self.jointMetadata = combo
266 def _combineConstraints(self):
267 """
268 Combine the constraints.
269 """
270 constraints = set()
271 for mB in self.mBundles:
272 if mB.constraint is not None:
273 constraints.add(mB.constraint)
274 self.constraints = '; '.join(constraints)
276 def _buildTitle(self):
277 """
278 Build a plot title from the metric names, runNames and metadata.
279 """
280 # Create a plot title from the unique parts of the metric/runName/metadata.
281 plotTitle = ''
282 if len(self.runNames) == 1:
283 plotTitle += list(self.runNames)[0]
284 if len(self.metadata) == 1:
285 plotTitle += ' ' + list(self.metadata)[0]
286 if len(self.metricNames) == 1:
287 plotTitle += ': ' + list(self.metricNames)[0]
288 if plotTitle == '':
289 # If there were more than one of everything above, use joint metadata and metricNames.
290 plotTitle = self.jointMetadata + ' ' + self.jointMetricNames
291 return plotTitle
293 def _buildXYlabels(self, plotFunc):
294 """
295 Build a plot x and y label.
296 """
297 if plotFunc.plotType == 'BinnedData':
298 if len(self.mBundles) == 1:
299 mB = self.mBundles[0]
300 xlabel = mB.slicer.sliceColName + ' (' + mB.slicer.sliceColUnits + ')'
301 ylabel = mB.metric.name + ' (' + mB.metric.units + ')'
302 else:
303 xlabel = set()
304 for mB in self.mBundles:
305 xlabel.add(mB.slicer.sliceColName)
306 xlabel = ', '.join(xlabel)
307 ylabel = self.jointMetricNames
308 elif plotFunc.plotType == 'MetricVsH':
309 if len(self.mBundles) == 1:
310 mB = self.mBundles[0]
311 ylabel = mB.metric.name + ' (' + mB.metric.units + ')'
312 else:
313 ylabel = self.jointMetricNames
314 xlabel = 'H (mag)'
315 else:
316 if len(self.mBundles) == 1:
317 mB = self.mBundles[0]
318 xlabel = mB.metric.name
319 if mB.metric.units is not None:
320 if len(mB.metric.units) > 0:
321 xlabel += ' (' + mB.metric.units + ')'
322 ylabel = None
323 else:
324 xlabel = self.jointMetricNames
325 ylabel = set()
326 for mB in self.mBundles:
327 if 'ylabel' in mB.plotDict:
328 ylabel.add(mB.plotDict['ylabel'])
329 if len(ylabel) == 1:
330 ylabel = list(ylabel)[0]
331 else:
332 ylabel = None
333 return xlabel, ylabel
335 def _buildLegendLabels(self):
336 """
337 Build a set of legend labels, using parts of the runName/metadata/metricNames that change.
338 """
339 if len(self.mBundles) == 1:
340 return [None]
341 labels = []
342 for mB in self.mBundles:
343 if 'label' in mB.plotDict:
344 label = mB.plotDict['label']
345 else:
346 label = ''
347 if len(self.runNames) > 1:
348 label += mB.runName
349 if len(self.metadata) > 1:
350 label += ' ' + mB.metadata
351 if len(self.metricNames) > 1:
352 label += ' ' + mB.metric.name
353 labels.append(label)
354 return labels
356 def _buildColors(self):
357 """
358 Try to set an appropriate range of colors for the metric Bundles.
359 """
360 if len(self.mBundles) == 1:
361 if 'color' in self.mBundles[0].plotDict:
362 return [self.mBundles[0].plotDict['color']]
363 else:
364 return ['b']
365 colors = []
366 for mB in self.mBundles:
367 color = 'b'
368 if 'color' in mB.plotDict:
369 color = mB.plotDict['color']
370 else:
371 if mB.constraint is not None:
372 # If the filter is part of the sql constraint, we'll
373 # try to use that first.
374 if 'filter' in mB.constraint:
375 vals = mB.constraint.split('"')
376 for v in vals:
377 if len(v) == 1:
378 # Guess that this is the filter value
379 if v in self.filtercolors:
380 color = self.filtercolors[v]
381 colors.append(color)
382 # If we happened to end up with the same color throughout
383 # (say, the metrics were all in the same filter)
384 # then go ahead and generate random colors.
385 if (len(self.mBundles) > 1) and (len(np.unique(colors)) == 1):
386 colors = [np.random.rand(3,) for mB in self.mBundles]
387 return colors
389 def _buildCbarFormat(self):
390 """
391 Set the color bar format.
392 """
393 cbarFormat = None
394 if len(self.mBundles) == 1:
395 if self.mBundles[0].metric.metricDtype == 'int':
396 cbarFormat = '%d'
397 else:
398 metricDtypes = set()
399 for mB in self.mBundles:
400 metricDtypes.add(mB.metric.metricDtype)
401 if len(metricDtypes) == 1:
402 if list(metricDtypes)[0] == 'int':
403 cbarFormat = '%d'
404 return cbarFormat
406 def _buildFileRoot(self, outfileSuffix=None):
407 """
408 Build a root filename for plot outputs.
409 If there is only one metricBundle, this is equal to the metricBundle fileRoot + outfileSuffix.
410 For multiple metricBundles, this is created from the runNames, metadata and metric names.
412 If you do not wish to use the automatic filenames, then you could set 'savefig' to False and
413 save the file manually to disk, using the plot figure numbers returned by 'plot'.
414 """
415 if len(self.mBundles) == 1:
416 outfile = self.mBundles[0].fileRoot
417 else:
418 outfile = '_'.join([self.jointRunNames, self.jointMetricNames, self.jointMetadata])
419 outfile += '_' + self.mBundles[0].slicer.slicerName[:4].upper()
420 if outfileSuffix is not None:
421 outfile += '_' + outfileSuffix
422 outfile = utils.nameSanitize(outfile)
423 return outfile
425 def _buildDisplayDict(self):
426 """
427 Generate a display dictionary.
428 This is most useful for when there are many metricBundles being combined into a single plot.
429 """
430 if len(self.mBundles) == 1:
431 return self.mBundles[0].displayDict
432 else:
433 displayDict = {}
434 group = set()
435 subgroup = set()
436 order = 0
437 for mB in self.mBundles:
438 group.add(mB.displayDict['group'])
439 subgroup.add(mB.displayDict['subgroup'])
440 if order < mB.displayDict['order']:
441 order = mB.displayDict['order'] + 1
442 displayDict['order'] = order
443 if len(group) > 1:
444 displayDict['group'] = 'Comparisons'
445 else:
446 displayDict['group'] = list(group)[0]
447 if len(subgroup) > 1:
448 displayDict['subgroup'] = 'Comparisons'
449 else:
450 displayDict['subgroup'] = list(subgroup)[0]
452 displayDict['caption'] = ('%s metric(s) calculated on a %s grid, '
453 'for opsim runs %s, for metadata values of %s.'
454 % (self.jointMetricNames,
455 self.mBundles[0].slicer.slicerName,
456 self.jointRunNames, self.jointMetadata))
458 return displayDict
460 def _checkPlotDicts(self):
461 """
462 Check to make sure there are no conflicts in the plotDicts that are being used in the same subplot.
463 """
464 # Check that the length is OK
465 if len(self.plotDicts) != len(self.mBundles):
466 raise ValueError('plotDicts (%i) must be same length as mBundles (%i)'
467 % (len(self.plotDicts), len(self.mBundles)))
469 # These are the keys that need to match (or be None)
470 keys2Check = ['xlim', 'ylim', 'colorMin', 'colorMax', 'title']
472 # Identify how many subplots there are. If there are more than one, just don't change anything.
473 # This assumes that if there are more than one, the plotDicts are actually all compatible.
474 subplots = set()
475 for pd in self.plotDicts:
476 if 'subplot' in pd:
477 subplots.add(pd['subplot'])
479 # Now check subplots are consistent.
480 if len(subplots) <= 1:
481 reset_keys = []
482 for key in keys2Check:
483 values = [pd[key] for pd in self.plotDicts if key in pd]
484 if len(np.unique(values)) > 1:
485 # We will reset some of the keys to the default, but for some we should do better.
486 if key.endswith('Max'):
487 for pd in self.plotDicts:
488 pd[key] = np.max(values)
489 elif key.endswith('Min'):
490 for pd in self.plotDicts:
491 pd[key] = np.min(values)
492 elif key == 'title':
493 title = self._buildTitle()
494 for pd in self.plotDicts:
495 pd['title'] = title
496 else:
497 warnings.warn('Found more than one value to be set for "%s" in the plotDicts.' % (key) +
498 ' Will reset to default value. (found values %s)' % values)
499 reset_keys.append(key)
500 # Reset the most of the keys to defaults; this can generally be done safely.
501 for key in reset_keys:
502 for pd in self.plotDicts:
503 pd[key] = None
505 def plot(self, plotFunc, plotDicts=None, displayDict=None, outfileRoot=None, outfileSuffix=None):
506 """
507 Create plot for mBundles, using plotFunc.
509 plotDicts: List of plotDicts if one wants to use a _new_ plotDict per MetricBundle.
510 """
511 if not plotFunc.objectPlotter:
512 # Check that metricValues type and plotter are compatible (most are float/float, but
513 # some plotters expect object data .. and some only do sometimes).
514 for mB in self.mBundles:
515 if mB.metric.metricDtype == 'object':
516 metricIsColor = mB.plotDict.get('metricIsColor', False)
517 if not metricIsColor:
518 warnings.warn('Cannot plot object metric values with this plotter.')
519 return
521 # Update x/y labels using plotType.
522 self.setPlotDicts(plotDicts=plotDicts, plotFunc=plotFunc, reset=False)
523 # Set outfile name.
524 if outfileRoot is None:
525 outfile = self._buildFileRoot(outfileSuffix)
526 else:
527 outfile = outfileRoot
528 plotType = plotFunc.plotType
529 if len(self.mBundles) > 1:
530 plotType = 'Combo' + plotType
531 # Make plot.
532 fignum = None
533 for mB, plotDict in zip(self.mBundles, self.plotDicts):
534 if mB.metricValues is None:
535 # Skip this metricBundle.
536 msg = 'MetricBundle (%s) has no attribute "metricValues".' % (mB.fileRoot)
537 msg += ' Either the values have not been calculated or they have been deleted.'
538 warnings.warn(msg)
539 else:
540 fignum = plotFunc(mB.metricValues, mB.slicer, plotDict, fignum=fignum)
541 # Add a legend if more than one metricValue is being plotted or if legendloc is specified.
542 legendloc = None
543 if 'legendloc' in self.plotDicts[0]:
544 legendloc = self.plotDicts[0]['legendloc']
545 if len(self.mBundles) > 1:
546 try:
547 legendloc = self.plotDicts[0]['legendloc']
548 except KeyError:
549 legendloc = 'upper right'
550 if legendloc is not None:
551 plt.figure(fignum)
552 plt.legend(loc=legendloc, fancybox=True, fontsize='smaller')
553 # Add the super title if provided.
554 if 'suptitle' in self.plotDicts[0]:
555 plt.suptitle(self.plotDicts[0]['suptitle'])
556 # Save to disk and file info to resultsDb if desired.
557 if self.savefig:
558 if displayDict is None:
559 displayDict = self._buildDisplayDict()
560 self.saveFig(fignum, outfile, plotType, self.jointMetricNames, self.slicer.slicerName,
561 self.jointRunNames, self.constraints, self.jointMetadata, displayDict)
562 return fignum
564 def saveFig(self, fignum, outfileRoot, plotType, metricName, slicerName,
565 runName, constraint, metadata, displayDict=None):
566 fig = plt.figure(fignum)
567 plotFile = outfileRoot + '_' + plotType + '.' + self.figformat
568 if self.trimWhitespace:
569 fig.savefig(os.path.join(self.outDir, plotFile), figformat=self.figformat, dpi=self.dpi,
570 bbox_inches='tight')
571 else:
572 fig.savefig(os.path.join(self.outDir, plotFile), figformat=self.figformat, dpi=self.dpi)
573 # Generate a png thumbnail.
574 if self.thumbnail:
575 thumbFile = 'thumb.' + outfileRoot + '_' + plotType + '.png'
576 plt.savefig(os.path.join(self.outDir, thumbFile), dpi=72, bbox_inches='tight')
577 # Save information about the file to resultsDb.
578 if self.resultsDb:
579 if displayDict is None:
580 displayDict = {}
581 metricId = self.resultsDb.updateMetric(metricName, slicerName, runName, constraint,
582 metadata, None)
583 self.resultsDb.updateDisplay(metricId=metricId, displayDict=displayDict, overwrite=False)
584 self.resultsDb.updatePlot(metricId=metricId, plotType=plotType, plotFile=plotFile)