
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


__all__ = ["makeQuantum", "runTestQuantum", "assertValidOutput"]


from collections import defaultdict
import collections.abc
import itertools
import unittest.mock

from lsst.daf.butler import DataCoordinate, DatasetRef, Quantum, StorageClassFactory
from lsst.pipe.base import ButlerQuantumContext



def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for a particular data ID.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``ioDataIds``.
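
    Examples
    --------
    A minimal sketch of typical use in a unit test. The repository, the task
    class, the connection names (``"input"``, ``"output"``), and the ``visit``
    data ID are hypothetical stand-ins, assumed only for illustration; the
    keys of ``ioDataIds`` must match the names defined in the task's
    connections class::

        butler = lsst.daf.butler.Butler(repo, run="test")
        task = AwesomeTask(config=config)
        # List-valued entries are required for connections with multiple=True.
        quantum = makeQuantum(
            task, butler,
            dataId={"visit": 42},
            ioDataIds={"input": {"visit": 42}, "output": {"visit": 42}},
        )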

55 """ 

56 connections = task.config.ConnectionsClass(config=task.config) 

57 

58 try: 

59 inputs = defaultdict(list) 

60 outputs = defaultdict(list) 

61 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs): 

62 connection = connections.__getattribute__(name) 

63 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple) 

64 ids = _normalizeDataIds(ioDataIds[name]) 

65 for id in ids: 

66 ref = _refFromConnection(butler, connection, id) 

67 inputs[ref.datasetType].append(ref) 

68 for name in connections.outputs: 

69 connection = connections.__getattribute__(name) 

70 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple) 

71 ids = _normalizeDataIds(ioDataIds[name]) 

72 for id in ids: 

73 ref = _refFromConnection(butler, connection, id) 

74 outputs[ref.datasetType].append(ref) 

75 quantum = Quantum(taskClass=type(task), 

76 dataId=dataId, 

77 inputs=inputs, 

78 outputs=outputs) 

79 return quantum 

80 except KeyError as e: 

81 raise ValueError("Mismatch in input data.") from e 



def _checkDataIdMultiplicity(name, dataIds, multiple):
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
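
    Examples
    --------
    A runnable sketch using plain `dict` objects in place of real data IDs
    (`~lsst.daf.butler.DataCoordinate` is a `~collections.abc.Mapping`, so a
    `dict` is a close enough stand-in for illustration):

    >>> _checkDataIdMultiplicity("input", {"visit": 42}, multiple=False)
    >>> _checkDataIdMultiplicity("inputs", [{"visit": 42}], multiple=True)
    >>> _checkDataIdMultiplicity("input", [{"visit": 42}], multiple=False)
    Traceback (most recent call last):
        ...
    ValueError: Expected single data ID for input, got [{'visit': 42}].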

101 """ 

102 if multiple: 

103 if not isinstance(dataIds, collections.abc.Sequence): 

104 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.") 

105 else: 

106 # DataCoordinate is a Mapping 

107 if not isinstance(dataIds, collections.abc.Mapping): 

108 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.") 



def _normalizeDataIds(dataIds):
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
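
    Examples
    --------
    A runnable sketch with `dict` objects standing in for real data IDs:

    >>> _normalizeDataIds({"visit": 42})
    [{'visit': 42}]
    >>> _normalizeDataIds([{"visit": 42}, {"visit": 43}])
    [{'visit': 42}, {'visit': 43}]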

124 """ 

125 if isinstance(dataIds, collections.abc.Sequence): 

126 return dataIds 

127 else: 

128 return [dataIds] 



def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
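
    Examples
    --------
    A minimal sketch, not a runnable doctest: the connections class and the
    repository contents are hypothetical stand-ins assumed for illustration::

        connections = task.config.ConnectionsClass(config=task.config)
        # kwargs can supply dimension values missing from the mapping itself.
        ref = _refFromConnection(butler, connections.input, {"visit": 42},
                                 instrument="notACam")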

151 """ 

152 universe = butler.registry.dimensions 

153 dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe) 

154 

155 # skypix is a PipelineTask alias for "some spatial index", Butler doesn't 

156 # understand it. Code copied from TaskDatasetTypes.fromTaskDef 

157 if "skypix" in connection.dimensions: 

158 datasetType = butler.registry.getDatasetType(connection.name) 

159 else: 

160 datasetType = connection.makeDatasetType(universe) 

161 

162 try: 

163 butler.registry.getDatasetType(datasetType.name) 

164 except KeyError: 

165 raise ValueError(f"Invalid dataset type {connection.name}.") 

166 try: 

167 ref = DatasetRef(datasetType=datasetType, dataId=dataId) 

168 return ref 

169 except KeyError as e: 

170 raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") \ 

171 from e 



def _resolveTestQuantumInputs(butler, quantum):
    """Look up all input datasets of a test quantum in the `Registry` to
    resolve all `DatasetRef` objects (i.e. ensure they have not-`None` ``id``
    and ``run`` attributes).

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
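
    Examples
    --------
    A minimal sketch (``butler`` and ``quantum`` are assumed for
    illustration; ``quantum`` is typically built with `makeQuantum`)::

        _resolveTestQuantumInputs(butler, quantum)
        # All input refs in quantum now have non-None ``id`` and ``run``.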

185 """ 

186 # TODO (DM-26819): This function is a direct copy of 

187 # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the 

188 # `runTestQuantum` function that calls it is essentially duplicating logic 

189 # in that class as well (albeit not verbatim). We should probably move 

190 # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable 

191 # in test code instead of having these classes at all. 

192 for refsForDatasetType in quantum.inputs.values(): 

193 newRefsForDatasetType = [] 

194 for ref in refsForDatasetType: 

195 if ref.id is None: 

196 resolvedRef = butler.registry.findDataset(ref.datasetType, ref.dataId, 

197 collections=butler.collections) 

198 if resolvedRef is None: 

199 raise ValueError( 

200 f"Cannot find {ref.datasetType.name} with id {ref.dataId} " 

201 f"in collections {butler.collections}." 

202 ) 

203 newRefsForDatasetType.append(resolvedRef) 

204 else: 

205 newRefsForDatasetType.append(ref) 

206 refsForDatasetType[:] = newRefsForDatasetType 



def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`, optional
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
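
    Examples
    --------
    A minimal sketch of typical use, continuing the hypothetical setup shown
    for `makeQuantum` (``task``, ``butler``, and ``quantum`` are assumed for
    illustration)::

        run = runTestQuantum(task, butler, quantum)
        # With mockRun=True (the default) no real processing happens, but
        # the mock records the arguments that runQuantum passed to run.
        run.assert_called_once()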

230 """ 

231 _resolveTestQuantumInputs(butler, quantum) 

232 butlerQc = ButlerQuantumContext(butler, quantum) 

233 connections = task.config.ConnectionsClass(config=task.config) 

234 inputRefs, outputRefs = connections.buildDatasetRefs(quantum) 

235 if mockRun: 

236 with unittest.mock.patch.object(task, "run") as mock, \ 

237 unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"): 

238 task.runQuantum(butlerQc, inputRefs, outputRefs) 

239 return mock 

240 else: 

241 task.runQuantum(butlerQc, inputRefs, outputRefs) 

242 return None 



def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what is expected from ``task``'s
        connections.
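
    Examples
    --------
    A minimal sketch of typical use (``task`` and its ``exposure`` argument
    are hypothetical stand-ins assumed for illustration)::

        result = task.run(exposure=exposure)
        # Raises AssertionError if an output named in the connections is
        # missing or if a connection's multiplicity does not match.
        assertValidOutput(task, result)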

262 """ 

263 connections = task.config.ConnectionsClass(config=task.config) 

264 recoveredOutputs = result.getDict() 

265 

266 for name in connections.outputs: 

267 connection = connections.__getattribute__(name) 

268 # name 

269 try: 

270 output = recoveredOutputs[name] 

271 except KeyError: 

272 raise AssertionError(f"No such output: {name}") 

273 # multiple 

274 if connection.multiple: 

275 if not isinstance(output, collections.abc.Sequence): 

276 raise AssertionError(f"Expected {name} to be a sequence, got {output} instead.") 

277 else: 

278 # use lazy evaluation to not use StorageClassFactory unless 

279 # necessary 

280 if isinstance(output, collections.abc.Sequence) \ 

281 and not issubclass( 

282 StorageClassFactory().getStorageClass(connection.storageClass).pytype, 

283 collections.abc.Sequence): 

284 raise AssertionError(f"Expected {name} to be a single value, got {output} instead.") 

285 # no test for storageClass, as I'm not sure how much persistence 

286 # depends on duck-typing