Coverage for python/lsst/sims/maf/db/resultsDb.py : 19%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from builtins import str
2from builtins import object
3import os, warnings
4from sqlalchemy import create_engine
5from sqlalchemy.orm import sessionmaker
6from sqlalchemy.engine import url
7from sqlalchemy.ext.declarative import declarative_base
8from sqlalchemy import Column, Integer, String, Float
9from sqlalchemy import ForeignKey
10from sqlalchemy.orm import relationship, backref
11from sqlalchemy.exc import DatabaseError
12from lsst.daf.persistence import DbAuth
14import numpy as np
# Declarative base class; every table class defined in this module derives from it,
# and ResultsDb uses Base.metadata to create the tables.
Base = declarative_base()

# Public names exported by this module.
__all__ = ['MetricRow', 'DisplayRow', 'PlotRow', 'SummaryStatRow', 'ResultsDb']
class MetricRow(Base):
    """
    Define contents and format of the metric list table ('metrics').

    (Table to list all metrics, their metadata, and their output data files).
    Columns: metricId (integer primary key), metricName, slicerName, simDataName,
    sqlConstraint, metricMetadata and metricDataFile (all strings).
    """
    __tablename__ = "metrics"
    # Columns of the metric list table.
    metricId = Column(Integer, primary_key=True)
    metricName = Column(String)
    slicerName = Column(String)
    simDataName = Column(String)
    sqlConstraint = Column(String)
    metricMetadata = Column(String)
    metricDataFile = Column(String)

    def __repr__(self):
        """Return a one-line description of the row, listing every column value."""
        template = ("<Metric(metricId='%d', metricName='%s', slicerName='%s', simDataName='%s', "
                    "sqlConstraint='%s', metadata='%s', metricDataFile='%s')>")
        values = (self.metricId, self.metricName, self.slicerName, self.simDataName,
                  self.sqlConstraint, self.metricMetadata, self.metricDataFile)
        return template % values
class DisplayRow(Base):
    """
    Define contents and format of the displays table ('displays').

    (Table to list the display properties for each metric.)
    Each row references a row of the 'metrics' table through metricId; the
    'metric' relationship also adds a 'displays' backref (ordered by displayId)
    to MetricRow.
    """
    __tablename__ = "displays"
    displayId = Column(Integer, primary_key=True)
    # Matches metricId in the metrics table.
    metricId = Column(Integer, ForeignKey('metrics.metricId'))
    # Group for displaying metric (in webpages).
    displayGroup = Column(String)
    # Subgroup for displaying metric.
    displaySubgroup = Column(String)
    # Order to display metric (within subgroup).
    displayOrder = Column(Float)
    # The figure caption.
    displayCaption = Column(String)
    metric = relationship("MetricRow", backref=backref('displays', order_by=displayId))

    def __repr__(self):
        """Return a one-line description of the display settings held in this row."""
        # This method was previously misspelled '__rep__', so repr() never used it.
        return "<Display(displayGroup='%s', displaySubgroup='%s', displayOrder='%.1f', displayCaption='%s')>" \
            % (self.displayGroup, self.displaySubgroup, self.displayOrder, self.displayCaption)

    # Keep the old (misspelled) method name available for any code that called it directly.
    __rep__ = __repr__
class PlotRow(Base):
    """
    Define contents and format of the plot list table ('plots').

    (Table to list all plots, link them to relevant metrics in MetricList, and provide info on filename).
    The 'metric' relationship also adds a 'plots' backref (ordered by plotId) to MetricRow.
    """
    __tablename__ = "plots"
    # Columns of the plot list table.
    plotId = Column(Integer, primary_key=True)
    # Matches metricId in the metrics table.
    metricId = Column(Integer, ForeignKey('metrics.metricId'))
    plotType = Column(String)
    plotFile = Column(String)
    metric = relationship("MetricRow", backref=backref('plots', order_by=plotId))

    def __repr__(self):
        """Return a one-line description of the plot (metric id, plot type and file name)."""
        values = (self.metricId, self.plotType, self.plotFile)
        return "<Plot(metricId='%d', plotType='%s', plotFile='%s')>" % values
class SummaryStatRow(Base):
    """
    Define contents and format of the summary statistics table ('summarystats').

    (Table to list and link summary stats to relevant metrics in MetricList, and provide
    the summary stat name and value).
    The 'metric' relationship also adds a 'summarystats' backref (ordered by statId) to MetricRow.
    """
    __tablename__ = "summarystats"
    # Columns of the summary statistics table.
    statId = Column(Integer, primary_key=True)
    # Matches metricId in the metrics table.
    metricId = Column(Integer, ForeignKey('metrics.metricId'))
    summaryName = Column(String)
    summaryValue = Column(Float)
    metric = relationship("MetricRow", backref=backref('summarystats', order_by=statId))

    def __repr__(self):
        """Return a one-line description of the summary statistic (metric id, name and value)."""
        values = (self.metricId, self.summaryName, self.summaryValue)
        return "<SummaryStat(metricId='%d', summaryName='%s', summaryValue='%f')>" % values
class ResultsDb(object):
    """
    Access to the results database: the 'metrics', 'displays', 'plots' and
    'summarystats' tables (see MetricRow, DisplayRow, PlotRow and SummaryStatRow).

    The update* methods add (or replace) rows and commit immediately; the get* methods
    query the tables and return lists or numpy structured arrays.
    """

    def __init__(self, outDir=None, database=None, driver='sqlite',
                 host=None, port=None, verbose=False):
        """
        Instantiate the results database, creating metrics, plots and summarystats tables.

        - outDir: directory for the sqlite database file (created if needed; defaults to '.'
          when 'database' is not given)
        - database: database name; if None, 'resultsDb_sqlite.db' in outDir is used and
          the driver is forced to 'sqlite'
        - driver: the database driver name (only used when 'database' is given)
        - host, port: used for non-sqlite drivers, both in the database address and to
          look up the username and password through DbAuth
        - verbose: passed to the engine as its 'echo' setting

        Raises OSError if outDir cannot be created, and ValueError if the tables cannot
        be created in the database.
        """
        # Always define these, so the attributes exist whichever driver is used.
        self.host = host
        self.port = port
        # Connect to database
        # for sqlite, connecting to non-existent database creates it automatically
        if database is None:
            # Using default value for database name, should specify directory.
            if outDir is None:
                outDir = '.'
            # Check for output directory, make if needed.
            if not os.path.isdir(outDir):
                try:
                    os.makedirs(outDir)
                except OSError as msg:
                    raise OSError(msg, '\n (If this was the database file (not outDir), '
                                  'remember to use kwarg "database")')
            self.database = os.path.join(outDir, 'resultsDb_sqlite.db')
            self.driver = 'sqlite'
        else:
            # Using non-default database; for sqlite a directory root may also be specified.
            # (If not sqlite, then 'outDir' doesn't make much sense and is ignored.)
            if driver == 'sqlite' and outDir is not None:
                database = os.path.join(outDir, database)
            self.database = database
            self.driver = driver

        if self.driver == 'sqlite':
            dbAddress = url.URL(self.driver, database=self.database)
        else:
            dbAddress = url.URL(self.driver,
                                username=DbAuth.username(self.host, str(self.port)),
                                password=DbAuth.password(self.host, str(self.port)),
                                host=self.host,
                                port=self.port,
                                database=self.database)

        # Keep a reference to the engine so that close() can release its connections.
        self.engine = create_engine(dbAddress, echo=verbose)
        self.Session = sessionmaker(bind=self.engine)
        self.session = self.Session()
        # Create the tables, if they don't already exist.
        try:
            Base.metadata.create_all(self.engine)
        except DatabaseError:
            raise ValueError("Cannot create a %s database at %s. Check directory exists." %(self.driver, self.database))
        # Length of the string fields in the numpy arrays returned by the get* methods.
        self.slen = 1024

    def close(self):
        """
        Close connection to database.

        Closes the session and disposes of the engine's connection pool.
        """
        self.session.close()
        self.engine.dispose()

    @staticmethod
    def _stripAutoTag(caption):
        """Return caption without a trailing '(auto)' tag (unchanged if it does not end with one)."""
        tag = '(auto)'
        if caption.endswith(tag):
            # Slice off the suffix itself; an earlier '(auto)' inside the caption is kept.
            return caption[:-len(tag)]
        return caption

    def _asMetricIdList(self, metricId):
        """Return metricId as an iterable of ids: all metric ids if None, [metricId] if a single value."""
        if metricId is None:
            return self.getAllMetricIds()
        if not hasattr(metricId, '__iter__'):
            return [metricId]
        return metricId

    def updateMetric(self, metricName, slicerName, simDataName, sqlConstraint,
                     metricMetadata, metricDataFile):
        """
        Add a row to or update a row in the metrics table.

        - metricName: the name of the metric
        - slicerName: the name of the slicer
        - simDataName: the name used to identify the simData
        - sqlConstraint: the sql constraint used to select data from the simData
        - metricMetadata: the metadata associated with the metric
        - metricDataFile: the data file the metric data is stored in

        If same metric (same metricName, slicerName, simDataName, sqlConstraint, metadata)
        already exists, it does nothing (metricDataFile is not part of this comparison).
        A value of None for simDataName, sqlConstraint, metricMetadata or metricDataFile
        is stored as the string 'NULL'.

        Returns metricId: the Id number of this metric in the metrics table.
        """
        if simDataName is None:
            simDataName = 'NULL'
        if sqlConstraint is None:
            sqlConstraint = 'NULL'
        if metricMetadata is None:
            metricMetadata = 'NULL'
        if metricDataFile is None:
            metricDataFile = 'NULL'
        # Check if metric has already been added to database.
        prev = self.session.query(MetricRow).filter_by(metricName=metricName,
                                                       slicerName=slicerName,
                                                       simDataName=simDataName,
                                                       metricMetadata=metricMetadata,
                                                       sqlConstraint=sqlConstraint).all()
        if len(prev) > 0:
            return prev[0].metricId
        metricinfo = MetricRow(metricName=metricName, slicerName=slicerName, simDataName=simDataName,
                               sqlConstraint=sqlConstraint, metricMetadata=metricMetadata,
                               metricDataFile=metricDataFile)
        self.session.add(metricinfo)
        self.session.commit()
        return metricinfo.metricId

    def updateDisplay(self, metricId, displayDict, overwrite=True):
        """
        Add a row to or update a row in the displays table.

        - metricId: the metric Id of this metric in the metrics table
        - displayDict: dictionary containing the display info ('group', 'subgroup', 'order'
          and 'caption' keys are used; missing or None values are stored as 'NULL',
          except 'order' which then becomes 0). The dictionary is not modified.
        - overwrite: if True, replaces existing rows with the same metricId; if False and
          a row already exists for this metricId, nothing is changed.

        A trailing '(auto)' tag is removed from the caption before it is stored.
        """
        # Because we want to maintain 1-1 relationship between metricId's and displayDict's:
        # First check if a display line is present with this metricID.
        existing = self.session.query(DisplayRow).filter_by(metricId=metricId).all()
        if len(existing) > 0:
            if not overwrite:
                return
            for d in existing:
                self.session.delete(d)
        # Then go ahead and add new display info, working on local values so that
        # the caller's dictionary is left untouched.
        values = {}
        for k in ('group', 'subgroup', 'order', 'caption'):
            if k in displayDict and displayDict[k] is not None:
                values[k] = displayDict[k]
            else:
                values[k] = 'NULL'
        if values['order'] == 'NULL':
            values['order'] = 0
        displayinfo = DisplayRow(metricId=metricId,
                                 displayGroup=values['group'],
                                 displaySubgroup=values['subgroup'],
                                 displayOrder=values['order'],
                                 displayCaption=self._stripAutoTag(values['caption']))
        self.session.add(displayinfo)
        self.session.commit()

    def updatePlot(self, metricId, plotType, plotFile):
        """
        Add a row to or update a row in the plot table.

        - metricId: the metric Id of this metric in the metrics table
        - plotType: the 'type' of this plot
        - plotFile: the filename of this plot

        Remove older rows with the same metricId, plotType and plotFile.
        """
        older = self.session.query(PlotRow).filter_by(metricId=metricId, plotType=plotType,
                                                      plotFile=plotFile).all()
        for p in older:
            self.session.delete(p)
        self.session.add(PlotRow(metricId=metricId, plotType=plotType, plotFile=plotFile))
        self.session.commit()

    def updateSummaryStat(self, metricId, summaryName, summaryValue):
        """
        Add a row to or update a row in the summary statistic table.

        - metricId: the metric ID of this metric in the metrics table
        - summaryName: the name of this summary statistic
        - summaryValue: the value for this summary statistic

        Most summary statistics will be a simple name (string) + value (float) pair.
        For special summary statistics which must return multiple values, the base name
        can be provided as 'name', together with a np recarray as 'value', where the
        recarray also has 'name' and 'value' columns (and each name/value pair is then saved
        as a summary statistic associated with this same metricId, named
        summaryName + ' ' + name).

        Any other kind of value is not saved; a warning is issued instead.
        """
        # Allow for special summary statistics which return data in a np structured array with
        # 'name' and 'value' columns. (specifically needed for TableFraction summary statistic).
        if isinstance(summaryValue, np.ndarray):
            # dtype.names is None for a plain (non-structured) array; treat that as 'no fields'
            # so that such arrays produce the warning below rather than a TypeError.
            fieldNames = summaryValue.dtype.names or ()
            if 'name' in fieldNames and 'value' in fieldNames:
                for value in summaryValue:
                    sSuffix = value['name']
                    if isinstance(sSuffix, bytes):
                        sSuffix = sSuffix.decode('utf-8')
                    else:
                        sSuffix = str(sSuffix)
                    summarystat = SummaryStatRow(metricId=metricId,
                                                 summaryName=summaryName + ' ' + sSuffix,
                                                 summaryValue=value['value'])
                    self.session.add(summarystat)
                # Commit all of the name/value pairs together.
                self.session.commit()
            else:
                warnings.warn('Warning! Cannot save non-conforming summary statistic.')
        # Most summary statistics will be simple floats.
        elif isinstance(summaryValue, (float, int)):
            summarystat = SummaryStatRow(metricId=metricId, summaryName=summaryName,
                                         summaryValue=summaryValue)
            self.session.add(summarystat)
            self.session.commit()
        else:
            warnings.warn('Warning! Cannot save summary statistic that is not a simple float or int')

    def getMetricId(self, metricName, slicerName=None, metricMetadata=None, simDataName=None):
        """
        Given a metric name and optional slicerName/metricMetadata/simData information,
        Return a list of the matching metricIds (ordered by slicerName, then metricMetadata).
        """
        query = self.session.query(MetricRow.metricId, MetricRow.metricName, MetricRow.slicerName,
                                   MetricRow.metricMetadata,
                                   MetricRow.simDataName).filter(MetricRow.metricName == metricName)
        if slicerName is not None:
            query = query.filter(MetricRow.slicerName == slicerName)
        if metricMetadata is not None:
            query = query.filter(MetricRow.metricMetadata == metricMetadata)
        if simDataName is not None:
            query = query.filter(MetricRow.simDataName == simDataName)
        query = query.order_by(MetricRow.slicerName, MetricRow.metricMetadata)
        return [m.metricId for m in query]

    def getMetricIdLike(self, metricNameLike=None, slicerNameLike=None,
                        metricMetadataLike=None, simDataName=None):
        """
        Return a list of the metricIds whose metric name, slicer name and/or metadata
        contain the given strings (matched with SQL LIKE '%value%').
        simDataName, if given, must match exactly. Criteria left as None are not applied.
        """
        query = self.session.query(MetricRow.metricId, MetricRow.metricName, MetricRow.slicerName,
                                   MetricRow.metricMetadata,
                                   MetricRow.simDataName)
        if metricNameLike is not None:
            query = query.filter(MetricRow.metricName.like('%' + str(metricNameLike) + '%'))
        if slicerNameLike is not None:
            query = query.filter(MetricRow.slicerName.like('%' + str(slicerNameLike) + '%'))
        if metricMetadataLike is not None:
            query = query.filter(MetricRow.metricMetadata.like('%' + str(metricMetadataLike) + '%'))
        if simDataName is not None:
            query = query.filter(MetricRow.simDataName == simDataName)
        return [m.metricId for m in query]

    def getAllMetricIds(self):
        """
        Return a list of all metricIds.
        """
        return [m.metricId for m in self.session.query(MetricRow.metricId).all()]

    def getSummaryStats(self, metricId=None, summaryName=None):
        """
        Get the summary stats (optionally for metricId list).
        Optionally, also specify the summary metric name.
        Returns a numpy array of the metric information + summary statistic information,
        with fields metricId, metricName, slicerName, metricMetadata, summaryName, summaryValue.
        """
        summarystats = []
        for mid in self._asMetricIdList(metricId):
            # Join the metric table and the summarystat table, based on the metricID (the second filter)
            query = (self.session.query(MetricRow, SummaryStatRow).filter(MetricRow.metricId == mid)
                     .filter(MetricRow.metricId == SummaryStatRow.metricId))
            if summaryName is not None:
                query = query.filter(SummaryStatRow.summaryName == summaryName)
            for m, s in query:
                summarystats.append((m.metricId, m.metricName, m.slicerName, m.metricMetadata,
                                     s.summaryName, s.summaryValue))
        # Convert to numpy array.
        dtype = np.dtype([('metricId', int), ('metricName', np.str_, self.slen),
                          ('slicerName', np.str_, self.slen), ('metricMetadata', np.str_, self.slen),
                          ('summaryName', np.str_, self.slen), ('summaryValue', float)])
        return np.array(summarystats, dtype)

    def getPlotFiles(self, metricId=None):
        """
        Return the metricId, name, metadata, and all plot info (optionally for metricId list).
        Returns a numpy array of the metric information + plot file names,
        with fields metricId, metricName, metricMetadata, plotType, plotFile, thumbFile.
        """
        plotFiles = []
        for mid in self._asMetricIdList(metricId):
            # Join the metric table and the plot table based on the metricID (the second filter does the join)
            query = (self.session.query(MetricRow, PlotRow).filter(MetricRow.metricId == mid)
                     .filter(MetricRow.metricId == PlotRow.metricId))
            for m, p in query:
                # The plotFile typically ends with .pdf (but the rest of name can have '.' or '_')
                # The thumbnail name is 'thumb.' + plot file name with its last extension replaced by '.png'.
                thumbfile = 'thumb.' + '.'.join(p.plotFile.split('.')[:-1]) + '.png'
                plotFiles.append((m.metricId, m.metricName, m.metricMetadata,
                                  p.plotType, p.plotFile, thumbfile))
        # Convert to numpy array.
        dtype = np.dtype([('metricId', int), ('metricName', np.str_, self.slen),
                          ('metricMetadata', np.str_, self.slen),
                          ('plotType', np.str_, self.slen), ('plotFile', np.str_, self.slen),
                          ('thumbFile', np.str_, self.slen)])
        return np.array(plotFiles, dtype)

    def getMetricDataFiles(self, metricId=None):
        """
        Get the metric data filenames for all or a single metric (or a list of metricIds).
        Returns a list.
        """
        dataFiles = []
        for mid in self._asMetricIdList(metricId):
            rows = self.session.query(MetricRow).filter(MetricRow.metricId == mid).all()
            dataFiles.extend(m.metricDataFile for m in rows)
        return dataFiles

    def getMetricDisplayInfo(self, metricId=None):
        """
        Get the contents of the metrics and displays table, together with the 'basemetricname'
        (optionally, for metricId list).
        The base metric name is the part of the metric name before the first '_'.
        Returns a numpy array of the metric information + display information.
        """
        metricInfo = []
        for mId in self._asMetricIdList(metricId):
            # Query for all rows in metrics and displays that match this metricId.
            query = (self.session.query(MetricRow, DisplayRow).filter(MetricRow.metricId == mId)
                     .filter(MetricRow.metricId == DisplayRow.metricId))
            for m, d in query:
                baseMetricName = m.metricName.split('_')[0]
                metricInfo.append((m.metricId, m.metricName, baseMetricName, m.slicerName,
                                   m.sqlConstraint, m.metricMetadata, m.metricDataFile,
                                   d.displayGroup, d.displaySubgroup, d.displayOrder,
                                   d.displayCaption))
        # Convert to numpy array.
        dtype = np.dtype([('metricId', int), ('metricName', np.str_, self.slen),
                          ('baseMetricNames', np.str_, self.slen),
                          ('slicerName', np.str_, self.slen),
                          ('sqlConstraint', np.str_, self.slen),
                          ('metricMetadata', np.str_, self.slen),
                          ('metricDataFile', np.str_, self.slen),
                          ('displayGroup', np.str_, self.slen),
                          ('displaySubgroup', np.str_, self.slen),
                          ('displayOrder', float),
                          ('displayCaption', np.str_, self.slen * 10)])
        return np.array(metricInfo, dtype)