Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22 

# Public API of this test-utility module; the underscore-prefixed helpers
# below are internal.
__all__ = ["makeQuantum", "runTestQuantum", "assertValidOutput"]

24 

25 

26import collections.abc 

27import unittest.mock 

28 

29from lsst.daf.butler import DataCoordinate, DatasetRef, Quantum, StorageClassFactory 

30from lsst.pipe.base import ButlerQuantumContext 

31 

32 

def makeQuantum(task, butler, dataIds):
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataIds``.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` is missing a connection name, or if a value
        has the wrong multiplicity for its connection.
    """
    quantum = Quantum(taskClass=type(task))
    connections = task.config.ConnectionsClass(config=task.config)

    try:
        # Inputs and outputs are processed identically except for which
        # Quantum method registers the resulting refs.
        for name in connections.inputs:
            _addRefsForConnection(butler, connections, name, dataIds,
                                  quantum.addPredictedInput)
        for name in connections.outputs:
            _addRefsForConnection(butler, connections, name, dataIds,
                                  quantum.addOutput)
        return quantum
    except KeyError as e:
        # A missing connection name in ``dataIds`` surfaces as KeyError.
        raise ValueError("Mismatch in input data.") from e


def _addRefsForConnection(butler, connections, name, dataIds, addRef):
    """Validate and register the dataset ref(s) for one connection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection the refs point to.
    connections : `lsst.pipe.base.PipelineTaskConnections`
        The connections object holding the connection named ``name``.
    name : `str`
        The name of the input/output connection to process.
    dataIds : `collections.abc.Mapping` [`str`]
        The user-provided mapping of connection name to data ID(s).
    addRef : callable
        The bound Quantum method (``addPredictedInput`` or ``addOutput``)
        used to register each ref.

    Raises
    ------
    KeyError
        Raised if ``name`` is not a key of ``dataIds``.
    ValueError
        Raised if the data ID(s) have the wrong multiplicity.
    """
    connection = getattr(connections, name)
    _checkDataIdMultiplicity(name, dataIds[name], connection.multiple)
    for dataId in _normalizeDataIds(dataIds[name]):
        addRef(_refFromConnection(butler, connection, dataId))

70 

71 

72def _checkDataIdMultiplicity(name, dataIds, multiple): 

73 """Test whether data IDs are scalars for scalar connections and sequences 

74 for multiple connections. 

75 

76 Parameters 

77 ---------- 

78 name : `str` 

79 The name of the connection being tested. 

80 dataIds : any data ID type or `~collections.abc.Sequence` [data ID] 

81 The data ID(s) provided for the connection. 

82 multiple : `bool` 

83 The ``multiple`` field of the connection. 

84 

85 Raises 

86 ------ 

87 ValueError 

88 Raised if ``dataIds`` and ``multiple`` do not match. 

89 """ 

90 if multiple: 

91 if not isinstance(dataIds, collections.abc.Sequence): 

92 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.") 

93 else: 

94 # DataCoordinate is a Mapping 

95 if not isinstance(dataIds, collections.abc.Mapping): 

96 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.") 

97 

98 

99def _normalizeDataIds(dataIds): 

100 """Represent both single and multiple data IDs as a list. 

101 

102 Parameters 

103 ---------- 

104 dataIds : any data ID type or `~collections.abc.Sequence` thereof 

105 The data ID(s) provided for a particular input or output connection. 

106 

107 Returns 

108 ------- 

109 normalizedIds : `~collections.abc.Sequence` [data ID] 

110 A sequence equal to ``dataIds`` if it was already a sequence, or 

111 ``[dataIds]`` if it was a single ID. 

112 """ 

113 if isinstance(dataIds, collections.abc.Sequence): 

114 return dataIds 

115 else: 

116 return [dataIds] 

117 

118 

def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dataset type is not registered with ``butler``, or
        if ``dataId`` is incompatible with the dataset type's dimensions.
    """
    universe = butler.registry.dimensions
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)
    datasetType = connection.makeDatasetType(universe)
    try:
        # Verify the dataset type is registered; the lookup result itself
        # is not needed.
        butler.registry.getDatasetType(datasetType.name)
    except KeyError as e:
        # Chain the cause so debugging shows the original registry error
        # (previously the KeyError was silently dropped).
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") \
            from e

153 

154 

def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    butlerQc = ButlerQuantumContext(butler, quantum)
    connections = task.config.ConnectionsClass(config=task.config)
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if not mockRun:
        return task.runQuantum(butlerQc, inputRefs, outputRefs) and None
    # Replace ``run`` and suppress butler writes so the quantum executes
    # without doing real work or producing real datasets.
    with unittest.mock.patch.object(task, "run") as mock:
        with unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock

188 

189 

def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task's``
        connections.
    """
    connections = task.config.ConnectionsClass(config=task.config)
    recoveredOutputs = result.getDict()

    for name in connections.outputs:
        connection = getattr(connections, name)
        # Check that the named output is present at all.
        try:
            output = recoveredOutputs[name]
        except KeyError as e:
            raise AssertionError(f"No such output: {name}") from e
        # Check that the output's multiplicity matches the connection's.
        if connection.multiple:
            if not isinstance(output, collections.abc.Sequence):
                raise AssertionError(f"Expected {name} to be a sequence, got {output} instead.")
        else:
            # Use lazy evaluation so StorageClassFactory is consulted only
            # in the ambiguous case (a scalar connection whose value is a
            # sequence may still be valid if its storage class is itself a
            # sequence type).
            if isinstance(output, collections.abc.Sequence) \
                    and not issubclass(
                        StorageClassFactory().getStorageClass(connection.storageClass).pytype,
                        collections.abc.Sequence):
                raise AssertionError(f"Expected {name} to be a single value, got {output} instead.")
    # No test for storageClass, as it's unclear how much persistence
    # depends on duck-typing.