# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


__all__ = ["makeQuantum", "runTestQuantum", "assertValidOutput"]


import collections.abc
import itertools
import unittest.mock

from lsst.daf.butler import DataCoordinate, DatasetRef, Quantum, StorageClassFactory
from lsst.pipe.base import ButlerQuantumContext


def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for particular data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, with inputs and outputs given by ``ioDataIds``.
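
    Examples
    --------
    A sketch of typical usage; the task, butler, and connection names
    ("calexp", "srcCat") are illustrative assumptions, not part of this
    module:

    >>> dataId = {"instrument": "notACam", "visit": 101, "detector": 42}
    >>> quantum = makeQuantum(
    ...     task, butler, dataId,
    ...     ioDataIds={"calexp": dataId, "srcCat": dataId},
    ... )  # doctest: +SKIP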

    """
    quantum = Quantum(taskClass=type(task), dataId=dataId)
    connections = task.config.ConnectionsClass(config=task.config)

    try:
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for ioId in ids:
                quantum.addPredictedInput(_refFromConnection(butler, connection, ioId))
        for name in connections.outputs:
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for ioId in ids:
                quantum.addOutput(_refFromConnection(butler, connection, ioId))
        return quantum
    except KeyError as e:
        raise ValueError(f"Mismatch in input data; no data ID provided for connection {e}.") from e


def _checkDataIdMultiplicity(name, dataIds, multiple):
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
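
    Examples
    --------
    An illustrative doctest using a plain `list` and `dict` in place of real
    data IDs (a `~lsst.daf.butler.DataCoordinate` is also a
    `~collections.abc.Mapping`, so it passes the same check):

    >>> _checkDataIdMultiplicity("cat", [{"visit": 42}], multiple=True)
    >>> _checkDataIdMultiplicity("cat", {"visit": 42}, multiple=False)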

    """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds):
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
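
    Examples
    --------
    An illustrative doctest using plain `dict` data IDs (real calls pass
    `~lsst.daf.butler.DataCoordinate` objects, which are likewise not
    sequences):

    >>> _normalizeDataIds({"visit": 42})
    [{'visit': 42}]
    >>> _normalizeDataIds([{"visit": 42}, {"visit": 43}])
    [{'visit': 42}, {'visit': 43}]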

    """
    if isinstance(dataIds, collections.abc.Sequence):
        return dataIds
    else:
        return [dataIds]


def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
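
    Examples
    --------
    A sketch only; ``butler``, ``connections``, and the data ID contents are
    illustrative assumptions:

    >>> ref = _refFromConnection(
    ...     butler, connections.calexp, {"visit": 42}, instrument="notACam"
    ... )  # doctest: +SKIP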

    """
    universe = butler.registry.dimensions
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # "skypix" is a PipelineTask alias for "some spatial index", which Butler
    # doesn't understand. Code copied from TaskDatasetTypes.fromTaskDef.
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(
            f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible."
        ) from e


def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
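
    Examples
    --------
    A sketch of checking what ``runQuantum`` passed to ``run``; ``task``,
    ``butler``, and the data IDs are assumed to be set up as for
    `makeQuantum`:

    >>> quantum = makeQuantum(task, butler, dataId, ioDataIds)  # doctest: +SKIP
    >>> run = runTestQuantum(task, butler, quantum)  # doctest: +SKIP
    >>> run.assert_called_once()  # doctest: +SKIP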

    """
    butlerQc = ButlerQuantumContext(butler, quantum)
    connections = task.config.ConnectionsClass(config=task.config)
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if mockRun:
        with unittest.mock.patch.object(task, "run") as mock, \
                unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
            # Patching ``put`` keeps runQuantum from trying to persist the
            # mock's return value as real datasets.
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock
    else:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None


def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
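
    Examples
    --------
    A sketch of use in a unit test; ``task`` and the arguments to ``run`` are
    assumed to be defined by the test:

    >>> result = task.run(exposure)  # doctest: +SKIP
    >>> assertValidOutput(task, result)  # doctest: +SKIP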

    """
    connections = task.config.ConnectionsClass(config=task.config)
    recoveredOutputs = result.getDict()

    for name in connections.outputs:
        connection = getattr(connections, name)
        # Check that the output exists under the connection's name.
        try:
            output = recoveredOutputs[name]
        except KeyError:
            raise AssertionError(f"No such output: {name}")
        # Check that multiplicity matches the connection's ``multiple`` field.
        if connection.multiple:
            if not isinstance(output, collections.abc.Sequence):
                raise AssertionError(f"Expected {name} to be a sequence, got {output} instead.")
        else:
            # Evaluate lazily so that StorageClassFactory is only consulted
            # when necessary.
            if (isinstance(output, collections.abc.Sequence)
                    and not issubclass(
                        StorageClassFactory().getStorageClass(connection.storageClass).pytype,
                        collections.abc.Sequence)):
                raise AssertionError(f"Expected {name} to be a single value, got {output} instead.")
    # No test for storageClass, as it is unclear how much persistence depends
    # on duck typing.