lsst.pipe.tasks  13.0-66-gfbf2f2ce+5
repositoryIterator.py
Go to the documentation of this file.
1 #
2 # LSST Data Management System
3 # Copyright 2008, 2009, 2010, 2011, 2012 LSST Corporation.
4 #
5 # This product includes software developed by the
6 # LSST Project (http://www.lsst.org/).
7 #
8 # This program is free software: you can redistribute it and/or modify
9 # it under the terms of the GNU General Public License as published by
10 # the Free Software Foundation, either version 3 of the License, or
11 # (at your option) any later version.
12 #
13 # This program is distributed in the hope that it will be useful,
14 # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 # GNU General Public License for more details.
17 #
18 # You should have received a copy of the LSST License Statement and
19 # the GNU General Public License along with this program. If not,
20 # see <http://www.lsstcorp.org/LegalNotices/>.
21 #
22 """Tools to help you iterate over a set of repositories.
23 
24 Helpful while creating them or harvesting data from them.
25 """
26 from __future__ import absolute_import, division, print_function
27 from builtins import zip
28 from builtins import object
29 import itertools
30 
31 import numpy
32 
STR_PADDING = 5  # used by _getDTypeList; the number of characters to add to the first string value seen
# when estimating the number of characters needed to store values for a key


def _getDTypeList(keyTuple, valTuple):
    """Construct a numpy dtype for a data ID or repository ID

    @param[in] keyTuple: ID key names, in order
    @param[in] valTuple: a value tuple, with exactly one value per key in keyTuple
    @return numpy dtype as a list of (name, type) or (name, string type, length) tuples

    @raise RuntimeError if keyTuple and valTuple have different lengths

    @warning: this guesses at string length (STR_PADDING + length of string in valTuple);
        longer strings will be truncated when inserted into numpy structured arrays.
        The same padding is applied to text (str) and byte string (bytes) values.
    """
    keyTuple = tuple(keyTuple)
    valTuple = tuple(valTuple)
    # zip would silently drop unmatched entries, yielding a dtype with missing fields
    if len(keyTuple) != len(valTuple):
        raise RuntimeError("lengths of keyTuple=%s and valTuple=%s do not match" % (keyTuple, valTuple))

    typeList = []
    for name, val in zip(keyTuple, valTuple):
        if isinstance(val, str):
            typeList.append((name, str, len(val) + STR_PADDING))
        elif isinstance(val, bytes):
            # under Python 2 bytes is str and was handled above; under Python 3 byte strings
            # would otherwise get an exact-length dtype with no padding
            typeList.append((name, bytes, len(val) + STR_PADDING))
        else:
            typeList.append((name, numpy.array([val]).dtype))
    return typeList
55 
56 
class SourceData(object):
    """Accumulate a set of measurements from a set of source tables

    To use:
    - specify the desired source measurements when constructing this object
    - call addSourceMetrics for each repository you harvest data from
    - call finalize to produce the final data

    Data available after calling finalize:
    - self.sourceArr: a numpy structured array of shape (num repositories, num sources)
        containing named columns for:
        - source ID
        - each data ID key
        - each item of data extracted from the source table
    - self.sourceIdDict: a dict of (source ID: index of axis 1 of self.sourceArr)
    - self.repoArr: a numpy structured array of shape (num repositories,)
        containing a named column for each repository key (see RepositoryIterator)

    Sources are ordered along axis 1 of self.sourceArr by increasing source ID.
    A source that is absent from a given repository appears in that repository's row
    with its sourceId filled in and zeros in every data ID and source data column.

    @note: this class does not filter out non-finite (e.g. NaN) values; they are stored as-is.
    """

    def __init__(self, datasetType, sourceKeyTuple):
        """
        @param[in] datasetType: dataset type for source
        @param[in] sourceKeyTuple: list of keys of data items to extract from the source tables

        @raise RuntimeError if sourceKeyTuple is empty
        """
        if len(sourceKeyTuple) < 1:
            raise RuntimeError("Must specify at least one key in sourceKeyTuple")
        self.datasetType = datasetType
        self._sourceKeyTuple = tuple(sourceKeyTuple)

        self._idKeyTuple = None  # tuple of data ID keys, in order; set by first call to _getSourceMetrics
        self._idKeyDTypeList = None  # numpy dtype for data ID tuple, as a list of (key, type);
        # set by the first call to _getSourceMetrics that receives at least one data ID
        self._sourceDTypeList = None  # numpy dtype for source data, as a list of (key, type);
        # set by the first call to _getSourceMetrics that sees a non-empty source table
        self._repoKeyTuple = None  # tuple of repo ID keys, in order; set by first call to addSourceMetrics
        self._repoDTypeList = None  # numpy dtype for repoArr, as a list of (key, type);
        # set by first call to addSourceMetrics

        self._tempDataList = []  # list (one entry per repository)
        # of dict of source ID: tuple of data ID data concatenated with source metric data, where:
        # data ID data is in order self._idKeyTuple
        # source metric data is in order self._sourceKeyTuple
        # set to None by finalize
        self.repoInfoList = []  # list of repoInfo

    def _getSourceMetrics(self, idKeyTuple, idValList, sourceTableList):
        """Obtain the desired source measurements from a list of source tables

        Extracts a set of source measurements (specified by sourceKeyTuple) from a list of source tables
        (one per data ID) and saves them as a dict of source ID: list of data

        @param[in] idKeyTuple: a tuple of data ID keys; must be the same for each call
        @param[in] idValList: a list of data ID value sequences (tuples or lists);
            each contains values in the order in idKeyTuple
        @param[in] sourceTableList: a list of source tables, one per entry in idValList

        @return a dict of source id: data id tuple + source data tuple
            where source data tuple order matches sourceKeyTuple
            and data id tuple matches self._idKeyTuple (which is set from the first idKeyTuple)

        @raise RuntimeError if idKeyTuple is different than it was for the first call.

        GetRepositoryDataTask.run returns idKeyTuple and idValList; you can easily make
        a subclass of GetRepositoryDataTask that also returns sourceTableList.

        Updates instance variables:
        - self._idKeyTuple if not already set.
        - self._idKeyDTypeList if not already set and idValList is not empty.
        - self._sourceDTypeList if not already set and a non-empty source table is seen.
        """
        keyTuple = tuple(idKeyTuple)
        if self._idKeyTuple is None:
            self._idKeyTuple = keyTuple
        elif self._idKeyTuple != keyTuple:
            raise RuntimeError("idKeyTuple = %s != %s = first idKeyTuple; must be the same each time" %
                               (idKeyTuple, self._idKeyTuple))

        # the data ID dtype can only be guessed from an actual value; defer it until one is available
        # rather than failing with IndexError on an empty idValList
        if self._idKeyDTypeList is None and len(idValList) > 0:
            self._idKeyDTypeList = _getDTypeList(keyTuple=self._idKeyTuple,
                                                 valTuple=idValList[0])

        dataDict = {}
        for idVals, sourceTable in zip(idValList, sourceTableList):
            if len(sourceTable) == 0:
                continue

            idList = sourceTable.get("id")
            dataList = [sourceTable.get(key) for key in self._sourceKeyTuple]

            if self._sourceDTypeList is None:
                self._sourceDTypeList = [(key, arr.dtype)
                                         for key, arr in zip(self._sourceKeyTuple, dataList)]

            # accept any sequence of data ID values; list + tuple concatenation would raise TypeError
            idTuple = tuple(idVals)
            # zip(*dataList) yields one row of source data per source
            dataDict.update((srcId, idTuple + tuple(data))
                            for srcId, data in zip(idList, zip(*dataList)))
        return dataDict

    def addSourceMetrics(self, repoInfo, idKeyTuple, idValList, sourceTableList):
        """Accumulate source measurements from a list of source tables.

        Once you have accumulated all source measurements, call finalize to process the data.

        @param[in] repoInfo: a RepositoryInfo instance
        @param[in] idKeyTuple: a tuple of data ID keys; must be the same for each call
        @param[in] idValList: a list of data ID value tuples;
            each tuple contains values in the order in idKeyTuple
        @param[in] sourceTableList: a list of source tables, one per entry in idValList

        @raise RuntimeError if idKeyTuple or repoInfo.keyTuple is different than it was for the first call.

        Accumulates the data in temporary cache self._tempDataList.

        @return number of sources
        """
        repoKeyTuple = tuple(repoInfo.keyTuple)
        # repoArr is built with the first repository's dtype, so every repository must use the same keys
        if self._repoKeyTuple is not None and self._repoKeyTuple != repoKeyTuple:
            raise RuntimeError("repoInfo.keyTuple = %s != %s = first repoInfo.keyTuple; "
                               "must be the same each time" % (repoKeyTuple, self._repoKeyTuple))

        dataDict = self._getSourceMetrics(idKeyTuple, idValList, sourceTableList)

        if self._repoKeyTuple is None:
            self._repoKeyTuple = repoKeyTuple
            self._repoDTypeList = repoInfo.dtype
        self._tempDataList.append(dataDict)
        self.repoInfoList.append(repoInfo)
        return len(dataDict)

    def finalize(self):
        """Process the accumulated source measurements to create the final data products.

        Only call this after you have added all source metrics using addSourceMetrics.

        Reads temporary cache self._tempDataList and then deletes it.

        @raise RuntimeError if finalize was already called, or if no sources were accumulated
        """
        if self._tempDataList is None:
            raise RuntimeError("finalize has already been called")
        if len(self._tempDataList) == 0:
            raise RuntimeError("No data found")

        fullSrcIdSet = set()
        for srcDataDict in self._tempDataList:
            fullSrcIdSet.update(srcDataDict)
        if not fullSrcIdSet or self._sourceDTypeList is None:
            # every source table was empty, so the source data column types are unknown
            raise RuntimeError("No data found")

        # sort so the column order of sourceArr is deterministic
        srcIdList = sorted(fullSrcIdSet)

        # source data
        sourceArrDType = [("sourceId", int)] + self._idKeyDTypeList + self._sourceDTypeList
        # data for missing sources (only for the data in the source data dict, so excludes srcId)
        nullSourceTuple = tuple(numpy.zeros(1, dtype=self._idKeyDTypeList + self._sourceDTypeList)[0])

        sourceData = [[(srcId,) + srcDataDict.get(srcId, nullSourceTuple) for srcId in srcIdList]
                      for srcDataDict in self._tempDataList]

        self.sourceArr = numpy.array(sourceData, dtype=sourceArrDType)
        del sourceData

        self.sourceIdDict = {srcId: i for i, srcId in enumerate(srcIdList)}

        # repository data
        repoData = [repoInfo.valTuple for repoInfo in self.repoInfoList]
        self.repoArr = numpy.array(repoData, dtype=self._repoDTypeList)

        self._tempDataList = None
215 
216 
class RepositoryInfo(object):
    """Describe one data repository by its ID keys, ID values, dtype and name

    RepositoryIterator creates these objects; SourceData reads their keyTuple, valTuple and dtype.
    """

    def __init__(self, keyTuple, valTuple, dtype, name):
        """
        @param[in] keyTuple: sequence of repository ID key names; stored as a tuple
        @param[in] valTuple: sequence of repository ID values, one per key; stored as a tuple
        @param[in] dtype: dtype describing the values (RepositoryIterator passes a list of (key, dtype))
        @param[in] name: repository name (RepositoryIterator passes the formatted name)

        @raise RuntimeError if keyTuple and valTuple have different lengths
        """
        numKeys = len(keyTuple)
        numVals = len(valTuple)
        if numKeys != numVals:
            message = "lengths of keyTuple=%s and valTuple=%s do not match" % (keyTuple, valTuple)
            raise RuntimeError(message)

        self.keyTuple, self.valTuple = tuple(keyTuple), tuple(valTuple)
        self.dtype, self.name = dtype, name
230 
231 
class RepositoryIterator(object):
    """Iterate over a set of data repositories that use a naming convention based on parameter values

    Iteration yields one RepositoryInfo per combination (itertools.product) of the supplied
    parameter values; each repository name is the format string applied to that combination.
    """

    def __init__(self, formatStr, **dataDict):
        """Construct a repository iterator from a dict of name: valueList

        @param[in] formatStr: format string using dictionary notation, e.g.: "%(foo)s_%(bar)d"
        @param[in] **dataDict: name=valueList pairs; a scalar value (e.g. foo=5 or foo="a")
            is treated as a one-element list

        Keys are kept in sorted order, and value tuples use that same order.
        """
        self._formatStr = formatStr
        self._keyTuple = tuple(sorted(dataDict.keys()))
        # ndmin=1 turns a scalar into a one-element array; a 0-d array supports neither len() nor iteration
        self._valListOfLists = [numpy.array(dataDict[key], ndmin=1) for key in self._keyTuple]
        self._dtype = [(key, valArr.dtype)
                       for key, valArr in zip(self._keyTuple, self._valListOfLists)]

    def __iter__(self):
        """Yield a RepositoryInfo for each combination of parameter values, in itertools.product order
        """
        for valTuple in itertools.product(*self._valListOfLists):
            valDict = dict(zip(self._keyTuple, valTuple))
            name = self.format(valDict)
            yield RepositoryInfo(keyTuple=self._keyTuple, valTuple=valTuple, dtype=self._dtype, name=name)

    def __len__(self):
        """Return the number of items in the iterator (the product of the value list lengths)"""
        n = 1
        for valArr in self._valListOfLists:
            n *= len(valArr)
        return n

    def format(self, valDict):
        """Return formatted string for a specified value dictionary

        @param[in] valDict: a dict of key: value pairs that identify a repository
        """
        return self._formatStr % valDict

    def getKeyTuple(self):
        """Return a tuple of keys in the same order as items in value tuples
        """
        return self._keyTuple

    def _getDTypeList(self):
        """Get a dtype for a structured array of repository keys, as a list of (key, dtype)
        """
        return self._dtype
def _getSourceMetrics(self, idKeyTuple, idValList, sourceTableList)
def __init__(self, keyTuple, valTuple, dtype, name)
def addSourceMetrics(self, repoInfo, idKeyTuple, idValList, sourceTableList)
def __init__(self, datasetType, sourceKeyTuple)