Coverage for python/lsst/pipe/tasks/repositoryIterator.py: 17%

96 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-20 02:29 -0700

1# 

2# LSST Data Management System 

3# Copyright 2008, 2009, 2010, 2011, 2012 LSST Corporation. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <http://www.lsstcorp.org/LegalNotices/>. 

21# 

22"""Tools to help you iterate over a set of repositories. 

23 

24Helpful while creating them or harvesting data from them. 

25""" 

26import itertools 

27 

28import numpy 

29 

# Number of extra characters reserved, beyond the length of the first string value seen for a key,
# when _getDTypeList sizes the fixed-width string field that will hold values for that key.
STR_PADDING = 5


def _getDTypeList(keyTuple, valTuple):
    """Build a numpy structured-dtype description for a data ID or repository ID

    @param[in] keyTuple: ID key names, in order
    @param[in] valTuple: one tuple of example values, in the same order as keyTuple
    @return the dtype as a list of field specifications, one per (key, value) pair:
        - (name, str, numChars) for a str value, where numChars = len(value) + STR_PADDING
        - (name, numpy dtype) for any other value, using the dtype numpy infers for that value

    @warning: the width of string fields is a guess based on the single example value;
    longer strings will be truncated when inserted into numpy structured arrays
    """
    def fieldSpec(key, value):
        # Non-string values: let numpy infer the dtype from a one-element array
        if not isinstance(value, str):
            return (key, numpy.array([value]).dtype)
        # String values: fixed-width field sized from the example plus some slack
        return (key, str, len(value) + STR_PADDING)

    return [fieldSpec(key, value) for key, value in zip(keyTuple, valTuple)]

52 

53 

class SourceData:
    """Accumulate a set of measurements from a set of source tables

    To use:
    - specify the desired source measurements when constructing this object
    - call addSourceMetrics for each repository you harvest data from
    - call finalize to produce the final data

    Data available after calling finalize:
    - self.sourceArr: a numpy structured array of shape (num repositories, num sources)
        containing named columns for:
        - source ID
        - each data ID key
        - each item of data extracted from the source table
    - self.sourceIdDict: a dict of (source ID: index of axis 1 of self.sourceArr)
    - self.repoArr: a numpy structured array of shape (num repositories,)
        containing a named column for each repository key (see RepositoryIterator)

    @note: a source that was not seen in a given repository still gets an entry in that
    repository's row of self.sourceArr; apart from the source ID, that entry is zero-filled
    (as produced by numpy.zeros). Values are not filtered for finiteness by this class.
    """

    def __init__(self, datasetType, sourceKeyTuple):
        """
        @param[in] datasetType: dataset type for source
        @param[in] sourceKeyTuple: list of keys of data items to extract from the source tables

        @raise RuntimeError if sourceKeyTuple is empty
        """
        if len(sourceKeyTuple) < 1:
            raise RuntimeError("Must specify at least one key in sourceKeyTuple")
        self.datasetType = datasetType
        self._sourceKeyTuple = tuple(sourceKeyTuple)

        self._idKeyTuple = None  # tuple of data ID keys, in order; set by first call to _getSourceMetrics
        self._idKeyDTypeList = None  # numpy dtype for data ID tuple, as a list of (key, type);
        # set by the first call to _getSourceMetrics that supplies at least one data ID value tuple
        self._sourceDTypeList = None  # numpy dtype for source data, as a list of (key, type);
        # set by the first call to _getSourceMetrics that sees a non-empty source table
        self._repoKeyTuple = None  # tuple of repo ID keys, in order; set by first call to addSourceMetrics
        self._repoDTypeList = None  # numpy dtype for repoArr, as a list of (key, type);
        # set by first call to addSourceMetrics

        self._tempDataList = []  # list (one entry per repository)
        # of dict of source ID: tuple of data ID data concatenated with source metric data, where:
        # data ID data is in order self._idKeyTuple
        # source metric data is in order self._sourceKeyTuple
        # (set to None by finalize)
        self.repoInfoList = []  # list of repoInfo

    def _getSourceMetrics(self, idKeyTuple, idValList, sourceTableList):
        """Obtain the desired source measurements from a list of source tables

        Extracts a set of source measurements (specified by sourceKeyTuple) from a list of source tables
        (one per data ID) and saves them as a dict of source ID: tuple of data

        @param[in] idKeyTuple: a tuple of data ID keys; must be the same for each call
        @param[in] idValList: a list of data ID value tuples (any sequence is accepted for each entry);
            each tuple contains values in the order in idKeyTuple
        @param[in] sourceTableList: a list of source tables, one per entry in idValList;
            each table must support len() and get(key), with get returning an array with a dtype

        @return a dict of source id: data id tuple + source data tuple
            where source data tuple order matches sourceKeyTuple
            and data id tuple matches self._idKeyTuple (which is set from the first idKeyTuple);
            empty source tables contribute nothing, so the dict may be empty

        @raise RuntimeError if idKeyTuple is different than it was for the first call.

        Updates instance variables:
        - self._idKeyTuple if not already set.
        - self._idKeyDTypeList if not already set and idValList is not empty.
        - self._sourceDTypeList if not already set and a non-empty source table is seen.
        """
        if self._idKeyTuple is None:
            self._idKeyTuple = tuple(idKeyTuple)
        elif self._idKeyTuple != tuple(idKeyTuple):
            raise RuntimeError("idKeyTuple = %s != %s = first idKeyTuple; must be the same each time" %
                               (idKeyTuple, self._idKeyTuple))

        # The data ID dtype is inferred from an example value tuple, so defer it until one is available
        # rather than failing with IndexError when a repository has no data IDs at all.
        if self._idKeyDTypeList is None and len(idValList) > 0:
            self._idKeyDTypeList = _getDTypeList(keyTuple=self._idKeyTuple,
                                                 valTuple=idValList[0])

        dataDict = {}
        for idTuple, sourceTable in zip(idValList, sourceTableList):
            if len(sourceTable) == 0:
                continue

            idList = sourceTable.get("id")
            dataList = [sourceTable.get(key) for key in self._sourceKeyTuple]

            if self._sourceDTypeList is None:
                self._sourceDTypeList = [(key, arr.dtype)
                                         for key, arr in zip(self._sourceKeyTuple, dataList)]

            # Coerce so that data ID values supplied as lists can be concatenated with the source data
            idTuple = tuple(idTuple)
            # zip(*dataList) transposes the per-key columns into one row of values per source
            dataDict.update((srcId, idTuple + tuple(data))
                            for srcId, data in zip(idList, zip(*dataList)))
        return dataDict

    def addSourceMetrics(self, repoInfo, idKeyTuple, idValList, sourceTableList):
        """Accumulate source measurements from a list of source tables.

        Once you have accumulated all source measurements, call finalize to process the data.

        @param[in] repoInfo: a RepositoryInfo instance
        @param[in] idKeyTuple: a tuple of data ID keys; must be the same for each call
        @param[in] idValList: a list of data ID value tuples;
            each tuple contains values in the order in idKeyTuple
        @param[in] sourceTableList: a list of source tables, one per entry in idValList

        @raise RuntimeError if idKeyTuple is different than it was for the first call.

        Accumulates the data in temporary cache self._tempDataList.

        @return number of sources
        """
        if self._repoKeyTuple is None:
            self._repoKeyTuple = repoInfo.keyTuple
            self._repoDTypeList = repoInfo.dtype

        dataDict = self._getSourceMetrics(idKeyTuple, idValList, sourceTableList)

        self._tempDataList.append(dataDict)
        self.repoInfoList.append(repoInfo)
        return len(dataDict)

    def finalize(self):
        """Process the accumulated source measurements to create the final data products.

        Only call this after you have added all source metrics using addSourceMetrics.

        Reads temporary cache self._tempDataList and then discards it (sets it to None).

        @raise RuntimeError if no repositories were added, if none of the added repositories
            supplied any source data, or if finalize has already been called
        """
        # Covers both "nothing added yet" (empty list) and "already finalized" (None)
        if not self._tempDataList:
            raise RuntimeError("No data found")
        # Without at least one data ID value tuple and one non-empty source table
        # the record layout is unknown, so there is nothing that can be assembled
        if self._idKeyDTypeList is None or self._sourceDTypeList is None:
            raise RuntimeError("No data found")

        fullSrcIdSet = set()
        for dataIdDict in self._tempDataList:
            fullSrcIdSet.update(dataIdDict)
        # Freeze one iteration order; it defines axis 1 of sourceArr and the values in sourceIdDict
        srcIdList = list(fullSrcIdSet)

        # source data
        dataDTypeList = self._idKeyDTypeList + self._sourceDTypeList
        sourceArrDType = [("sourceId", int)] + dataDTypeList
        # data for missing sources (only for the data in the source data dict, so excludes srcId)
        nullSourceTuple = tuple(numpy.zeros(1, dtype=dataDTypeList)[0])

        sourceData = [[(srcId,) + srcDataDict.get(srcId, nullSourceTuple) for srcId in srcIdList]
                      for srcDataDict in self._tempDataList]

        self.sourceArr = numpy.array(sourceData, dtype=sourceArrDType)
        del sourceData

        self.sourceIdDict = {srcId: i for i, srcId in enumerate(srcIdList)}

        # repository data
        repoData = [repoInfo.valTuple for repoInfo in self.repoInfoList]
        self.repoArr = numpy.array(repoData, dtype=self._repoDTypeList)

        self._tempDataList = None

212 

213 

class RepositoryInfo:
    """Description of a single data repository

    Instances are created by RepositoryIterator and consumed by SourceData.

    Attributes:
    - keyTuple: tuple of repository key names
    - valTuple: tuple of values, one per entry in keyTuple
    - dtype: the dtype supplied by the creator (stored unchanged)
    - name: the repository name supplied by the creator (stored unchanged)
    """

    def __init__(self, keyTuple, valTuple, dtype, name):
        """
        @param[in] keyTuple: sequence of repository key names
        @param[in] valTuple: sequence of values, one per key
        @param[in] dtype: dtype describing the keys
        @param[in] name: name of the repository

        @raise RuntimeError if keyTuple and valTuple have different lengths
        """
        numKeys, numVals = len(keyTuple), len(valTuple)
        if numKeys != numVals:
            raise RuntimeError("lengths of keyTuple=%s and valTuple=%s do not match" % (keyTuple, valTuple))
        self.keyTuple, self.valTuple = tuple(keyTuple), tuple(valTuple)
        self.dtype, self.name = dtype, name

227 

228 

class RepositoryIterator:
    """Iterate over a family of data repositories whose names are generated from parameter values

    Every combination of the supplied parameter values yields one RepositoryInfo.
    """

    def __init__(self, formatStr, **dataDict):
        """Construct a repository iterator from name=valueList keyword arguments

        @param[in] formatStr: format string using dictionary notation, e.g.: "%(foo)s_%(bar)d"
        @param[in] **dataDict: name=valueList pairs

        Keys are kept in sorted order; each value list is converted to a numpy array,
        whose dtype is recorded for that key.
        """
        sortedKeys = tuple(sorted(dataDict))
        valueArrays = [numpy.array(dataDict[key]) for key in sortedKeys]

        self._formatStr = formatStr
        self._keyTuple = sortedKeys
        self._valListOfLists = valueArrays
        self._dtype = [(key, valArr.dtype) for key, valArr in zip(sortedKeys, valueArrays)]

    def __iter__(self):
        """Yield one RepositoryInfo per combination of parameter values
        """
        for valTuple in itertools.product(*self._valListOfLists):
            repoName = self.format(dict(zip(self._keyTuple, valTuple)))
            yield RepositoryInfo(keyTuple=self._keyTuple, valTuple=valTuple, dtype=self._dtype,
                                 name=repoName)

    def __len__(self):
        """Return the number of combinations the iterator produces"""
        numCombinations = 1
        for valArr in self._valListOfLists:
            numCombinations = numCombinations * len(valArr)
        return numCombinations

    def format(self, valDict):
        """Return the repository name for a specified value dictionary

        @param[in] valDict: a dict of key: value pairs that identify a repository
        """
        return self._formatStr % valDict

    def getKeyTuple(self):
        """Return a tuple of keys, in the same order as the items in value tuples
        """
        return self._keyTuple

    def _getDTypeList(self):
        """Return the dtype for a structured array of repository keys, as a list of (key, dtype)
        """
        return self._dtype