# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


__all__ = ["assertValidInitOutput",
           "assertValidOutput",
           "getInitInputs",
           "makeQuantum",
           "runTestQuantum",
           ]


from collections import defaultdict
import collections.abc
import itertools
import unittest.mock

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, Quantum, StorageClassFactory
from lsst.pipe.base import ButlerQuantumContext


def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for particular data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The butler holding the collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``ioDataIds``.
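
    Examples
    --------
    A minimal sketch of wiring up a quantum for a test. The task class and
    connection names ("calexp" as a single input, "catalogs" as a multiple
    output) are hypothetical, and ``butler`` is assumed to be a pre-existing
    test repository with the needed dataset types registered:

    >>> task = AveragingTask(config=AveragingTask.ConfigClass())
    >>> dataId = {"instrument": "notACam", "visit": 101, "detector": 42}
    >>> quantum = makeQuantum(
    ...     task, butler, dataId,
    ...     ioDataIds={"calexp": dataId,
    ...                "catalogs": [dataId]})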

60 """ 

61 connections = task.config.ConnectionsClass(config=task.config) 

62 

63 try: 

64 inputs = defaultdict(list) 

65 outputs = defaultdict(list) 

66 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs): 

67 connection = connections.__getattribute__(name) 

68 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple) 

69 ids = _normalizeDataIds(ioDataIds[name]) 

70 for id in ids: 

71 ref = _refFromConnection(butler, connection, id) 

72 inputs[ref.datasetType].append(ref) 

73 for name in connections.outputs: 

74 connection = connections.__getattribute__(name) 

75 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple) 

76 ids = _normalizeDataIds(ioDataIds[name]) 

77 for id in ids: 

78 ref = _refFromConnection(butler, connection, id) 

79 outputs[ref.datasetType].append(ref) 

80 quantum = Quantum(taskClass=type(task), 

81 dataId=dataId, 

82 inputs=inputs, 

83 outputs=outputs) 

84 return quantum 

85 except KeyError as e: 

86 raise ValueError("Mismatch in input data.") from e 

87 

88 

89def _checkDataIdMultiplicity(name, dataIds, multiple): 

90 """Test whether data IDs are scalars for scalar connections and sequences 

91 for multiple connections. 

92 

93 Parameters 

94 ---------- 

95 name : `str` 

96 The name of the connection being tested. 

97 dataIds : any data ID type or `~collections.abc.Sequence` [data ID] 

98 The data ID(s) provided for the connection. 

99 multiple : `bool` 

100 The ``multiple`` field of the connection. 

101 

102 Raises 

103 ------ 

104 ValueError 

105 Raised if ``dataIds`` and ``multiple`` do not match. 
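
    Examples
    --------
    Any `~collections.abc.Sequence` counts as multiple data IDs, while a
    single data ID is any `~collections.abc.Mapping` (including a plain
    `dict`):

    >>> _checkDataIdMultiplicity("cat", [{"visit": 42}], multiple=True)
    >>> _checkDataIdMultiplicity("cat", {"visit": 42}, multiple=False)
    >>> _checkDataIdMultiplicity("cat", {"visit": 42}, multiple=True)
    Traceback (most recent call last):
    ...
    ValueError: Expected multiple data IDs for cat, got {'visit': 42}.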

106 """ 

107 if multiple: 

108 if not isinstance(dataIds, collections.abc.Sequence): 

109 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.") 

110 else: 

111 # DataCoordinate is a Mapping 

112 if not isinstance(dataIds, collections.abc.Mapping): 

113 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.") 

114 

115 

116def _normalizeDataIds(dataIds): 

117 """Represent both single and multiple data IDs as a list. 

118 

119 Parameters 

120 ---------- 

121 dataIds : any data ID type or `~collections.abc.Sequence` thereof 

122 The data ID(s) provided for a particular input or output connection. 

123 

124 Returns 

125 ------- 

126 normalizedIds : `~collections.abc.Sequence` [data ID] 

127 A sequence equal to ``dataIds`` if it was already a sequence, or 

128 ``[dataIds]`` if it was a single ID. 
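
    Examples
    --------
    A single data ID (here a plain `dict`) is wrapped in a list, while a
    sequence is returned unchanged:

    >>> _normalizeDataIds({"visit": 42})
    [{'visit': 42}]
    >>> _normalizeDataIds([{"visit": 42}, {"visit": 43}])
    [{'visit': 42}, {'visit': 43}]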

129 """ 

130 if isinstance(dataIds, collections.abc.Sequence): 

131 return dataIds 

132 else: 

133 return [dataIds] 

134 

135 

136def _refFromConnection(butler, connection, dataId, **kwargs): 

137 """Create a DatasetRef for a connection in a collection. 

138 

139 Parameters 

140 ---------- 

141 butler : `lsst.daf.butler.Butler` 

142 The collection to point to. 

143 connection : `lsst.pipe.base.connectionTypes.DimensionedConnection` 

144 The connection defining the dataset type to point to. 

145 dataId 

146 The data ID for the dataset to point to. 

147 **kwargs 

148 Additional keyword arguments used to augment or construct 

149 a `~lsst.daf.butler.DataCoordinate`. 

150 

151 Returns 

152 ------- 

153 ref : `lsst.daf.butler.DatasetRef` 

154 A reference to a dataset compatible with ``connection``, with ID 

155 ``dataId``, in the collection pointed to by ``butler``. 
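
    Examples
    --------
    A sketch of filling in missing dimensions through ``**kwargs``; the
    ``calexp`` connection and ``butler`` are hypothetical stand-ins supplied
    by the calling test:

    >>> ref = _refFromConnection(butler, connections.calexp, {"visit": 42},
    ...                          instrument="notACam", detector=101)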

156 """ 

157 universe = butler.registry.dimensions 

158 dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe) 

159 

160 # skypix is a PipelineTask alias for "some spatial index", Butler doesn't 

161 # understand it. Code copied from TaskDatasetTypes.fromTaskDef 

162 if "skypix" in connection.dimensions: 

163 datasetType = butler.registry.getDatasetType(connection.name) 

164 else: 

165 datasetType = connection.makeDatasetType(universe) 

166 

167 try: 

168 butler.registry.getDatasetType(datasetType.name) 

169 except KeyError: 

170 raise ValueError(f"Invalid dataset type {connection.name}.") 

171 try: 

172 ref = DatasetRef(datasetType=datasetType, dataId=dataId) 

173 return ref 

174 except KeyError as e: 

175 raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") \ 

176 from e 

177 

178 

179def _resolveTestQuantumInputs(butler, quantum): 

180 """Look up all input datasets a test quantum in the `Registry` to resolve 

181 all `DatasetRef` objects (i.e. ensure they have not-`None` ``id`` and 

182 ``run`` attributes). 

183 

184 Parameters 

185 ---------- 

186 quantum : `~lsst.daf.butler.Quantum` 

187 Single Quantum instance. 

188 butler : `~lsst.daf.butler.Butler` 

189 Data butler. 

190 """ 

191 # TODO (DM-26819): This function is a direct copy of 

192 # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the 

193 # `runTestQuantum` function that calls it is essentially duplicating logic 

194 # in that class as well (albeit not verbatim). We should probably move 

195 # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable 

196 # in test code instead of having these classes at all. 

197 for refsForDatasetType in quantum.inputs.values(): 

198 newRefsForDatasetType = [] 

199 for ref in refsForDatasetType: 

200 if ref.id is None: 

201 resolvedRef = butler.registry.findDataset(ref.datasetType, ref.dataId, 

202 collections=butler.collections) 

203 if resolvedRef is None: 

204 raise ValueError( 

205 f"Cannot find {ref.datasetType.name} with id {ref.dataId} " 

206 f"in collections {butler.collections}." 

207 ) 

208 newRefsForDatasetType.append(resolvedRef) 

209 else: 

210 newRefsForDatasetType.append(ref) 

211 refsForDatasetType[:] = newRefsForDatasetType 

212 

213 

214def runTestQuantum(task, butler, quantum, mockRun=True): 

215 """Run a PipelineTask on a Quantum. 

216 

217 Parameters 

218 ---------- 

219 task : `lsst.pipe.base.PipelineTask` 

220 The task to run on the quantum. 

221 butler : `lsst.daf.butler.Butler` 

222 The collection to run on. 

223 quantum : `lsst.daf.butler.Quantum` 

224 The quantum to run. 

225 mockRun : `bool` 

226 Whether or not to replace ``task``'s ``run`` method. The default of 

227 `True` is recommended unless ``run`` needs to do real work (e.g., 

228 because the test needs real output datasets). 

229 

230 Returns 

231 ------- 

232 run : `unittest.mock.Mock` or `None` 

233 If ``mockRun`` is set, the mock that replaced ``run``. This object can 

234 be queried for the arguments ``runQuantum`` passed to ``run``. 
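
    Examples
    --------
    A typical use, with ``task``, ``butler``, and the quantum's data IDs
    standing in for objects built elsewhere in the test:

    >>> quantum = makeQuantum(task, butler, dataId, ioDataIds)
    >>> run = runTestQuantum(task, butler, quantum)
    >>> run.assert_called_once()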

235 """ 

236 _resolveTestQuantumInputs(butler, quantum) 

237 butlerQc = ButlerQuantumContext(butler, quantum) 

238 connections = task.config.ConnectionsClass(config=task.config) 

239 inputRefs, outputRefs = connections.buildDatasetRefs(quantum) 

240 if mockRun: 

241 with unittest.mock.patch.object(task, "run") as mock, \ 

242 unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"): 

243 task.runQuantum(butlerQc, inputRefs, outputRefs) 

244 return mock 

245 else: 

246 task.runQuantum(butlerQc, inputRefs, outputRefs) 

247 return None 

248 

249 

250def _assertAttributeMatchesConnection(obj, attrName, connection): 

251 """Test that an attribute on an object matches the specification given in 

252 a connection. 

253 

254 Parameters 

255 ---------- 

256 obj 

257 An object expected to contain the attribute ``attrName``. 

258 attrName : `str` 

259 The name of the attribute to be tested. 

260 connection : `lsst.pipe.base.connectionTypes.BaseConnection` 

261 The connection, usually some type of output, specifying ``attrName``. 

262 

263 Raises 

264 ------ 

265 AssertionError: 

266 Raised if ``obj.attrName`` does not match what's expected 

267 from ``connection``. 
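
    Examples
    --------
    A runnable sketch using stand-ins that provide only the fields this check
    actually reads (the named attribute and ``connection.multiple``); real
    callers pass task results and connection instances:

    >>> import types
    >>> connection = types.SimpleNamespace(multiple=True, storageClass="StructuredDataDict")
    >>> result = types.SimpleNamespace(catalogs=[{"visit": 42}])
    >>> _assertAttributeMatchesConnection(result, "catalogs", connection)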

268 """ 

269 # name 

270 try: 

271 attrValue = obj.__getattribute__(attrName) 

272 except AttributeError: 

273 raise AssertionError(f"No such attribute on {obj!r}: {attrName}") 

274 # multiple 

275 if connection.multiple: 

276 if not isinstance(attrValue, collections.abc.Sequence): 

277 raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.") 

278 else: 

279 # use lazy evaluation to not use StorageClassFactory unless 

280 # necessary 

281 if isinstance(attrValue, collections.abc.Sequence) \ 

282 and not issubclass( 

283 StorageClassFactory().getStorageClass(connection.storageClass).pytype, 

284 collections.abc.Sequence): 

285 raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.") 

286 # no test for storageClass, as I'm not sure how much persistence 

287 # depends on duck-typing 

288 

289 

290def assertValidOutput(task, result): 

291 """Test that the output of a call to ``run`` conforms to its own 

292 connections. 

293 

294 Parameters 

295 ---------- 

296 task : `lsst.pipe.base.PipelineTask` 

297 The task whose connections need validation. This is a fully-configured 

298 task object to support features such as optional outputs. 

299 result : `lsst.pipe.base.Struct` 

300 A result object produced by calling ``task.run``. 

301 

302 Raises 

303 ------ 

304 AssertionError: 

305 Raised if ``result`` does not match what's expected from ``task's`` 

306 connections. 
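
    Examples
    --------
    A typical use in a unit test; ``task`` and the arguments to ``run`` are
    hypothetical here:

    >>> result = task.run(exposure=exposure)
    >>> assertValidOutput(task, result)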

307 """ 

308 connections = task.config.ConnectionsClass(config=task.config) 

309 

310 for name in connections.outputs: 

311 connection = connections.__getattribute__(name) 

312 _assertAttributeMatchesConnection(result, name, connection) 

313 

314 

315def assertValidInitOutput(task): 

316 """Test that a constructed task conforms to its own init-connections. 

317 

318 Parameters 

319 ---------- 

320 task : `lsst.pipe.base.PipelineTask` 

321 The task whose connections need validation. 

322 

323 Raises 

324 ------ 

325 AssertionError: 

326 Raised if ``task`` does not have the state expected from ``task's`` 

327 connections. 
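
    Examples
    --------
    A typical use right after construction; ``MyTask`` and ``butler`` are
    hypothetical here:

    >>> config = MyTask.ConfigClass()
    >>> task = MyTask(config=config, initInputs=getInitInputs(butler, config))
    >>> assertValidInitOutput(task)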

328 """ 

329 connections = task.config.ConnectionsClass(config=task.config) 

330 

331 for name in connections.initOutputs: 

332 connection = connections.__getattribute__(name) 

333 _assertAttributeMatchesConnection(task, name, connection) 

334 

335 

336def getInitInputs(butler, config): 

337 """Return the initInputs object that would have been passed to a 

338 `~lsst.pipe.base.PipelineTask` constructor. 

339 

340 Parameters 

341 ---------- 

342 butler : `lsst.daf.butler.Butler` 

343 The repository to search for input datasets. Must have 

344 pre-configured collections. 

345 config : `lsst.pipe.base.PipelineTaskConfig` 

346 The config for the task to be constructed. 

347 

348 Returns 

349 ------- 

350 initInputs : `dict` [`str`] 

351 A dictionary of objects in the format of the ``initInputs`` parameter 

352 to `lsst.pipe.base.PipelineTask`. 
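
    Examples
    --------
    A sketch of constructing a task the way the execution framework would;
    ``MyTask`` and ``butler`` are hypothetical here:

    >>> config = MyTask.ConfigClass()
    >>> initInputs = getInitInputs(butler, config)
    >>> task = MyTask(config=config, initInputs=initInputs)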

353 """ 

354 connections = config.connections.ConnectionsClass(config=config) 

355 initInputs = {} 

356 for name in connections.initInputs: 

357 attribute = getattr(connections, name) 

358 # Get full dataset type to check for consistency problems 

359 dsType = DatasetType(attribute.name, butler.registry.dimensions.extract(set()), 

360 attribute.storageClass) 

361 # All initInputs have empty data IDs 

362 initInputs[name] = butler.get(dsType) 

363 

364 return initInputs