# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


__all__ = ["makeQuantum", "runTestQuantum", "assertValidOutput"]


import collections.abc
import itertools
import unittest.mock

from lsst.daf.butler import DataCoordinate, DatasetRef, Quantum, StorageClassFactory
from lsst.pipe.base import ButlerQuantumContext


def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for a particular set of data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The butler providing the collections the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output connection name. Values must be data
        IDs for single connections and sequences of data IDs for multiple
        connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, populated with the given data IDs.
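
    Examples
    --------
    An illustrative sketch; ``PatchTask`` and the data ID keys are
    hypothetical and depend on the task and repository under test::

        task = PatchTask(config=PatchTask.ConfigClass())
        dataId = {"tract": 42, "patch": 0, "skymap": "map"}
        quantum = makeQuantum(
            task, butler, dataId,
            ioDataIds={"inputCoadd": dataId, "outputCatalog": dataId},
        )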

54 """ 

55 quantum = Quantum(taskClass=type(task), dataId=dataId) 

56 connections = task.config.ConnectionsClass(config=task.config) 

57 

58 try: 

59 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs): 

60 connection = connections.__getattribute__(name) 

61 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple) 

62 ids = _normalizeDataIds(ioDataIds[name]) 

63 for id in ids: 

64 quantum.addPredictedInput(_refFromConnection(butler, connection, id)) 

65 for name in connections.outputs: 

66 connection = connections.__getattribute__(name) 

67 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple) 

68 ids = _normalizeDataIds(ioDataIds[name]) 

69 for id in ids: 

70 quantum.addOutput(_refFromConnection(butler, connection, id)) 

71 return quantum 

72 except KeyError as e: 

73 raise ValueError("Mismatch in input data.") from e 

74 

75 

def _checkDataIdMultiplicity(name, dataIds, multiple):
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
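
    Examples
    --------
    Plain dicts stand in for data IDs here; real tests typically pass
    `~lsst.daf.butler.DataCoordinate` objects.

    >>> _checkDataIdMultiplicity("inputCoadd", {"visit": 1}, multiple=False)
    >>> _checkDataIdMultiplicity("inputCoadds", {"visit": 1}, multiple=True)
    Traceback (most recent call last):
        ...
    ValueError: Expected multiple data IDs for inputCoadds, got {'visit': 1}.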

93 """ 

94 if multiple: 

95 if not isinstance(dataIds, collections.abc.Sequence): 

96 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.") 

97 else: 

98 # DataCoordinate is a Mapping 

99 if not isinstance(dataIds, collections.abc.Mapping): 

100 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.") 

101 

102 

def _normalizeDataIds(dataIds):
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
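
    Examples
    --------
    Plain dicts stand in for data IDs:

    >>> _normalizeDataIds([{"visit": 1}, {"visit": 2}])
    [{'visit': 1}, {'visit': 2}]
    >>> _normalizeDataIds({"visit": 1})
    [{'visit': 1}]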

116 """ 

117 if isinstance(dataIds, collections.abc.Sequence): 

118 return dataIds 

119 else: 

120 return [dataIds] 

121 

122 

def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The butler providing the collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
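
    Examples
    --------
    An illustrative sketch; ``task``, ``butler``, and the ``outputCatalog``
    connection name are hypothetical::

        connections = task.config.ConnectionsClass(config=task.config)
        ref = _refFromConnection(
            butler, connections.outputCatalog,
            {"instrument": "demo", "visit": 101},
        )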

143 """ 

144 universe = butler.registry.dimensions 

145 dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe) 

146 

147 # skypix is a PipelineTask alias for "some spatial index", Butler doesn't 

148 # understand it. Code copied from TaskDatasetTypes.fromTaskDef 

149 if "skypix" in connection.dimensions: 

150 datasetType = butler.registry.getDatasetType(connection.name) 

151 else: 

152 datasetType = connection.makeDatasetType(universe) 

153 

154 try: 

155 butler.registry.getDatasetType(datasetType.name) 

156 except KeyError: 

157 raise ValueError(f"Invalid dataset type {connection.name}.") 

158 try: 

159 ref = DatasetRef(datasetType=datasetType, dataId=dataId) 

160 return ref 

161 except KeyError as e: 

162 raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") \ 

163 from e 

164 

165 

def _resolveTestQuantumInputs(butler, quantum):
    """Look up all input datasets for a test quantum in the `Registry` to
    resolve all `DatasetRef` objects (i.e. ensure they have not-`None` ``id``
    and ``run`` attributes).

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    """

    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refsForDatasetType in quantum.predictedInputs.values():
        newRefsForDatasetType = []
        for ref in refsForDatasetType:
            if ref.id is None:
                # Only unresolved refs need a registry lookup.
                resolvedRef = butler.registry.findDataset(
                    ref.datasetType, ref.dataId, collections=butler.collections
                )
                if resolvedRef is None:
                    raise ValueError(
                        f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                        f"in collections {butler.collections}."
                    )
                newRefsForDatasetType.append(resolvedRef)
            else:
                newRefsForDatasetType.append(ref)
        # Update the list in place so the quantum sees the resolved refs.
        refsForDatasetType[:] = newRefsForDatasetType


def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The butler providing the collections to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether to replace ``task``'s ``run`` method. The default of `True`
        is recommended unless ``run`` needs to do real work (e.g., because
        the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
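
    Examples
    --------
    An illustrative sketch; ``task``, ``butler``, and ``quantum`` are assumed
    to have been set up as in `makeQuantum`, and ``inputCoadd`` is a
    hypothetical argument to ``run``::

        run = runTestQuantum(task, butler, quantum)
        # The mock records the arguments that runQuantum passed to run.
        run.assert_called_once()
        coadd = run.call_args.kwargs["inputCoadd"]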

222 """ 

223 _resolveTestQuantumInputs(butler, quantum) 

224 butlerQc = ButlerQuantumContext(butler, quantum) 

225 connections = task.config.ConnectionsClass(config=task.config) 

226 inputRefs, outputRefs = connections.buildDatasetRefs(quantum) 

227 if mockRun: 

228 with unittest.mock.patch.object(task, "run") as mock, \ 

229 unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"): 

230 task.runQuantum(butlerQc, inputRefs, outputRefs) 

231 return mock 

232 else: 

233 task.runQuantum(butlerQc, inputRefs, outputRefs) 

234 return None 

235 

236 

def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to the task's own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what is expected from ``task``'s
        connections.
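
    Examples
    --------
    An illustrative sketch; the task and its ``run`` arguments are assumed to
    come from the calling test::

        result = task.run(inputCoadd=coadd)
        assertValidOutput(task, result)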

253 """ 

254 connections = task.config.ConnectionsClass(config=task.config) 

255 recoveredOutputs = result.getDict() 

256 

257 for name in connections.outputs: 

258 connection = connections.__getattribute__(name) 

259 # name 

260 try: 

261 output = recoveredOutputs[name] 

262 except KeyError: 

263 raise AssertionError(f"No such output: {name}") 

264 # multiple 

265 if connection.multiple: 

266 if not isinstance(output, collections.abc.Sequence): 

267 raise AssertionError(f"Expected {name} to be a sequence, got {output} instead.") 

268 else: 

269 # use lazy evaluation to not use StorageClassFactory unless necessary 

270 if isinstance(output, collections.abc.Sequence) \ 

271 and not issubclass( 

272 StorageClassFactory().getStorageClass(connection.storageClass).pytype, 

273 collections.abc.Sequence): 

274 raise AssertionError(f"Expected {name} to be a single value, got {output} instead.") 

275 # no test for storageClass, as I'm not sure how much persistence depends on duck-typing