Coverage for python/lsst/pipe/tasks/repositoryIterator.py: 17%
96 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-20 09:49 +0000
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-20 09:49 +0000
1#
2# LSST Data Management System
3# Copyright 2008, 2009, 2010, 2011, 2012 LSST Corporation.
4#
5# This product includes software developed by the
6# LSST Project (http://www.lsst.org/).
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <http://www.lsstcorp.org/LegalNotices/>.
21#
22"""Tools to help you iterate over a set of repositories.
24Helpful while creating them or harvesting data from them.
25"""
26import itertools
28import numpy
# Number of characters added to the length of the first string value seen for a key when
# _getDTypeList estimates how many characters are needed to store values for that key.
STR_PADDING = 5


def _getDTypeList(keyTuple, valTuple):
    """Build the field list of a numpy structured dtype for a data ID or repository ID

    @param[in] keyTuple: ID key names, in order
    @param[in] valTuple: one tuple of values, in the same order as keyTuple
    @return the dtype as a list with one entry per (key, value) pair:
        (name, str, length) for a string value and (name, numpy dtype) for any other value

    @warning: the length of a string field is a guess (STR_PADDING + length of the string in valTuple);
        longer strings will be truncated when inserted into numpy structured arrays
    """
    def fieldSpec(key, value):
        # Non-string values: let numpy infer the dtype from a one-element array.
        if not isinstance(value, str):
            return (key, numpy.array([value]).dtype)
        # String values: reserve some head room beyond the example value.
        return (key, str, len(value) + STR_PADDING)

    return [fieldSpec(key, value) for key, value in zip(keyTuple, valTuple)]
class SourceData:
    """Accumulate a set of measurements from a set of source tables

    To use:
    - specify the desired source measurements when constructing this object
    - call addSourceMetrics for each repository you harvest data from
    - call finalize to produce the final data

    Data available after calling finalize:
    - self.sourceArr: a numpy structured array of shape (num repositories, num sources)
        containing named columns for:
        - source ID (column "sourceId")
        - each data ID key
        - each item of data extracted from the source table
    - self.sourceIdDict: a dict of (source ID: index of axis 1 of self.sourceArr)
    - self.repoArr: a numpy structured array of shape (num repositories,)
        containing a named column for each repository key (see RepositoryIterator)

    @note: a source that is absent from a repository gets an all-zero record
        (apart from its source ID) in that repository's row of self.sourceArr

    @note: NOTE(review): earlier documentation said that sources with non-finite data (e.g. NaN)
        for every extracted value are silently omitted; no such filtering is done in this class,
        so confirm whether the source tables are filtered before they get here
    """

    def __init__(self, datasetType, sourceKeyTuple):
        """
        @param[in] datasetType: dataset type for source
        @param[in] sourceKeyTuple: list of keys of data items to extract from the source tables

        @raise RuntimeError if sourceKeyTuple is empty
        """
        if len(sourceKeyTuple) < 1:
            raise RuntimeError("Must specify at least one key in sourceKeyTuple")
        self.datasetType = datasetType
        self._sourceKeyTuple = tuple(sourceKeyTuple)

        self._idKeyTuple = None  # tuple of data ID keys, in order; set by first call to _getSourceMetrics
        self._idKeyDTypeList = None  # numpy dtype for data ID tuple, as a list of (key, type);
        # set by the first call to _getSourceMetrics that supplies at least one data ID
        self._sourceDTypeList = None  # numpy dtype for source data, as a list of (key, type);
        # set by the first call to _getSourceMetrics that sees a non-empty source table
        self._repoKeyTuple = None  # tuple of repo ID keys, in order; set by first call to addSourceMetrics
        self._repoDTypeList = None  # numpy dtype for repoArr, as a list of (key, type);
        # set by first call to addSourceMetrics

        self._tempDataList = []  # list (one entry per repository)
        # of dict of source ID: tuple of data ID data concatenated with source metric data, where:
        # data ID data is in order self._idKeyTuple
        # source metric data is in order self._sourceKeyTuple
        self.repoInfoList = []  # list of repoInfo

    def _getSourceMetrics(self, idKeyTuple, idValList, sourceTableList):
        """Obtain the desired source measurements from a list of source tables

        Extracts a set of source measurements (specified by sourceKeyTuple) from a list of source tables
        (one per data ID) and returns them as a dict of source ID: tuple of data.
        Empty source tables are skipped.

        @param[in] idKeyTuple: a tuple of data ID keys; must be the same for each call
        @param[in] idValList: a list of data ID value tuples;
            each tuple contains values in the order in idKeyTuple
        @param[in] sourceTableList: a list of source tables, one per entry in idValList;
            each table must support len() and get(key), with get("id") giving the source IDs

        @return a dict of source id: data id tuple + source data tuple
            where source data tuple order matches sourceKeyTuple
            and data id tuple matches self._idKeyTuple (which is set from the first idKeyTuple)

        @raise RuntimeError if idKeyTuple is different than it was for the first call.

        Updates instance variables (each only if not already set):
        - self._idKeyTuple
        - self._idKeyDTypeList
        - self._sourceDTypeList
        """
        if self._idKeyTuple is None:
            self._idKeyTuple = tuple(idKeyTuple)
        elif self._idKeyTuple != tuple(idKeyTuple):
            raise RuntimeError("idKeyTuple = %s != %s = first idKeyTuple; must be the same each time" %
                               (idKeyTuple, self._idKeyTuple))

        # The data ID dtype is guessed from the first data ID value tuple seen; wait until one
        # exists rather than indexing into an empty list.
        if self._idKeyDTypeList is None and len(idValList) > 0:
            self._idKeyDTypeList = _getDTypeList(keyTuple=self._idKeyTuple,
                                                 valTuple=idValList[0])

        dataDict = {}
        for idTuple, sourceTable in zip(idValList, sourceTableList):
            if len(sourceTable) == 0:
                continue

            idList = sourceTable.get("id")
            # one column per requested key, in self._sourceKeyTuple order
            dataList = [sourceTable.get(key) for key in self._sourceKeyTuple]

            if self._sourceDTypeList is None:
                self._sourceDTypeList = [(key, arr.dtype)
                                         for key, arr in zip(self._sourceKeyTuple, dataList)]

            # zip(*dataList) turns the per-key columns into one row of values per source
            dataDict.update((srcId, idTuple + tuple(data))
                            for srcId, data in zip(idList, zip(*dataList)))
        return dataDict

    def addSourceMetrics(self, repoInfo, idKeyTuple, idValList, sourceTableList):
        """Accumulate source measurements from a list of source tables.

        Once you have accumulated all source measurements, call finalize to process the data.

        @param[in] repoInfo: a RepositoryInfo instance
        @param[in] idKeyTuple: a tuple of data ID keys; must be the same for each call
        @param[in] idValList: a list of data ID value tuples;
            each tuple contains values in the order in idKeyTuple
        @param[in] sourceTableList: a list of source tables, one per entry in idValList

        @raise RuntimeError if idKeyTuple is different than it was for the first call.

        Accumulates the data in temporary cache self._tempDataList
        and records repoInfo in self.repoInfoList.

        @return number of sources
        """
        if self._repoKeyTuple is None:
            self._repoKeyTuple = repoInfo.keyTuple
            self._repoDTypeList = repoInfo.dtype

        dataDict = self._getSourceMetrics(idKeyTuple, idValList, sourceTableList)

        self._tempDataList.append(dataDict)
        self.repoInfoList.append(repoInfo)
        return len(dataDict)

    def finalize(self):
        """Process the accumulated source measurements to create the final data products.

        Only call this after you have added all source metrics using addSourceMetrics.

        Sets self.sourceArr, self.sourceIdDict and self.repoArr.
        Reads temporary cache self._tempDataList and then discards it (sets it to None).

        @raise RuntimeError if addSourceMetrics was never called,
            or if no data ID or no non-empty source table was ever supplied
        """
        if len(self._tempDataList) == 0:
            raise RuntimeError("No data found")
        # The dtypes are only known once a data ID and a non-empty source table have been seen;
        # without them there is nothing to build the source array from.
        if self._idKeyDTypeList is None or self._sourceDTypeList is None:
            raise RuntimeError("No data found")

        fullSrcIdSet = set()
        for dataIdDict in self._tempDataList:
            fullSrcIdSet.update(dataIdDict)

        # source data
        recordDType = self._idKeyDTypeList + self._sourceDTypeList
        sourceArrDType = [("sourceId", int)] + recordDType
        # data for missing sources (only for the data in the source data dict, so excludes srcId)
        nullSourceTuple = tuple(numpy.zeros(1, dtype=recordDType)[0])

        # fullSrcIdSet is not modified below, so every pass over it yields the same order;
        # that keeps the columns of sourceArr and the indices in sourceIdDict in agreement
        sourceData = [[(srcId,) + srcDataDict.get(srcId, nullSourceTuple) for srcId in fullSrcIdSet]
                      for srcDataDict in self._tempDataList]

        self.sourceArr = numpy.array(sourceData, dtype=sourceArrDType)
        del sourceData

        self.sourceIdDict = {srcId: i for i, srcId in enumerate(fullSrcIdSet)}

        # repository data
        repoData = [repoInfo.valTuple for repoInfo in self.repoInfoList]
        self.repoArr = numpy.array(repoData, dtype=self._repoDTypeList)

        self._tempDataList = None
class RepositoryInfo:
    """Description of a single data repository

    Holds the repository's key names, the matching values, the numpy dtype
    (as a list) for those values, and the repository name.

    Constructed by RepositoryIterator and used by SourceData.
    """

    def __init__(self, keyTuple, valTuple, dtype, name):
        """
        @param[in] keyTuple: repository key names, in order; stored as a tuple
        @param[in] valTuple: repository values, in the same order as keyTuple; stored as a tuple
        @param[in] dtype: dtype for the values; stored unchanged
        @param[in] name: repository name; stored unchanged

        @raise RuntimeError if keyTuple and valTuple have different lengths
        """
        numKeys, numVals = len(keyTuple), len(valTuple)
        if numKeys != numVals:
            raise RuntimeError("lengths of keyTuple=%s and valTuple=%s do not match" % (keyTuple, valTuple))
        self.name = name
        self.dtype = dtype
        self.valTuple = tuple(valTuple)
        self.keyTuple = tuple(keyTuple)
class RepositoryIterator:
    """Iterate over a set of data repositories whose names follow a convention based on parameter values

    One RepositoryInfo is produced for every combination of the supplied parameter values.
    """

    def __init__(self, formatStr, **dataDict):
        """Construct a repository iterator from name=valueList pairs

        @param[in] formatStr: format string using dictionary notation, e.g.: "%(foo)s_%(bar)d"
        @param[in] **dataDict: name=valueList pairs; the names are used in sorted order
        """
        sortedKeys = tuple(sorted(dataDict))
        valueArrays = [numpy.array(dataDict[key]) for key in sortedKeys]

        self._formatStr = formatStr
        self._keyTuple = sortedKeys
        self._valListOfLists = valueArrays
        # one (key, dtype) pair per key, taken from the numpy array built from that key's values
        self._dtype = [(key, values.dtype) for key, values in zip(sortedKeys, valueArrays)]

    def __iter__(self):
        """Yield a RepositoryInfo for each combination of values, in itertools.product order
        """
        for combination in itertools.product(*self._valListOfLists):
            repoName = self.format(dict(zip(self._keyTuple, combination)))
            yield RepositoryInfo(keyTuple=self._keyTuple, valTuple=combination, dtype=self._dtype,
                                 name=repoName)

    def __len__(self):
        """Return the number of items in the iterator (the product of the value list lengths)"""
        count = 1
        for values in self._valListOfLists:
            count = count * len(values)
        return count

    def format(self, valDict):
        """Return the format string filled in from a specified value dictionary

        @param[in] valDict: a dict of key: value pairs that identify a repository
        """
        formatted = self._formatStr % valDict
        return formatted

    def getKeyTuple(self):
        """Return a tuple of keys in the same order as items in value tuples
        """
        return self._keyTuple

    def _getDTypeList(self):
        """Return the dtype, as a list of (key, dtype), for a structured array of repository keys
        """
        return self._dtype