Coverage for python/lsst/pipe/base/testUtils.py: 10%


138 statements  

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


__all__ = ["assertValidInitOutput",
           "assertValidOutput",
           "getInitInputs",
           "lintConnections",
           "makeQuantum",
           "runTestQuantum",
           ]


from collections import defaultdict
import collections.abc
import itertools
import unittest.mock

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, Quantum, StorageClassFactory, \
    SkyPixDimension
from lsst.pipe.base import ButlerQuantumContext


def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for a particular set of data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataId`` and ``ioDataIds``.
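
    Examples
    --------
    A minimal sketch, assuming a ``butler`` pointing to a suitably populated
    repository; ``PatchTask`` and the ``"deepCoadd"`` and ``"measCat"``
    connection names are hypothetical stand-ins for a real task::

        task = PatchTask(config=PatchTask.ConfigClass())
        dataId = {"skymap": "elvis", "tract": 42, "patch": 7}
        quantum = makeQuantum(
            task, butler, dataId,
            ioDataIds={"deepCoadd": dataId, "measCat": dataId})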

    """
    connections = task.config.ConnectionsClass(config=task.config)

    try:
        _checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    inputs = defaultdict(list)
    outputs = defaultdict(list)
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                inputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    for name in connections.outputs:
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                outputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    quantum = Quantum(taskClass=type(task),
                      dataId=dataId,
                      inputs=inputs,
                      outputs=outputs)
    return quantum


def _checkDimensionsMatch(universe, expected, actual):
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by the input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    if _simplify(universe, expected) != _simplify(universe, actual):
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")


def _simplify(universe, dimensions):
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all sky
        pixelization dimensions simplified to ``skypix``.
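
    Examples
    --------
    A sketch of the intended behavior; ``htm7`` is one of the sky
    pixelization dimensions in the default dimension universe, so it is
    collapsed into the generic ``skypix`` label::

        _simplify(universe, {"visit", "htm7"})  # == {"visit", "skypix"}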

    """
    simplified = set()
    for dimension in dimensions:
        # skypix is not a real Dimension; handle it first
        if dimension == "skypix":
            simplified.add(dimension)
        else:
            # Need a Dimension to test spatialness
            fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
            if isinstance(fullDimension, SkyPixDimension):
                simplified.add("skypix")
            else:
                simplified.add(fullDimension.name)
    return simplified


def _checkDataIdMultiplicity(name, dataIds, multiple):
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
    """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds):
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
    """
    if isinstance(dataIds, collections.abc.Sequence):
        return dataIds
    else:
        return [dataIds]


def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
    """
    universe = butler.registry.dimensions
    # DatasetRef only tests whether required dimensions are missing, not
    # whether there are extras.
    _checkDimensionsMatch(universe, connection.dimensions, dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # skypix is a PipelineTask alias for "some spatial index"; the Butler
    # doesn't understand it. Code copied from TaskDatasetTypes.fromTaskDef.
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") \
            from e


def _resolveTestQuantumInputs(butler, quantum):
    """Look up all input datasets of a test quantum in the `Registry`, to
    resolve all `DatasetRef` objects (i.e. ensure they have not-`None` ``id``
    and ``run`` attributes).

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    """
    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refsForDatasetType in quantum.inputs.values():
        newRefsForDatasetType = []
        for ref in refsForDatasetType:
            if ref.id is None:
                resolvedRef = butler.registry.findDataset(ref.datasetType, ref.dataId,
                                                          collections=butler.collections)
                if resolvedRef is None:
                    raise ValueError(
                        f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                        f"in collections {butler.collections}."
                    )
                newRefsForDatasetType.append(resolvedRef)
            else:
                newRefsForDatasetType.append(ref)
        refsForDatasetType[:] = newRefsForDatasetType


def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
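
    Examples
    --------
    A minimal sketch, reusing a quantum built with `makeQuantum`; by default
    ``run`` is mocked, so the returned mock can be queried instead of
    inspecting real outputs::

        quantum = makeQuantum(task, butler, dataId, ioDataIds)
        run = runTestQuantum(task, butler, quantum)
        # run is the unittest.mock.Mock that replaced task.run
        run.assert_called_once()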

    """
    _resolveTestQuantumInputs(butler, quantum)
    butlerQc = ButlerQuantumContext(butler, quantum)
    connections = task.config.ConnectionsClass(config=task.config)
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if mockRun:
        with unittest.mock.patch.object(task, "run") as mock, \
                unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock
    else:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None


def _assertAttributeMatchesConnection(obj, attrName, connection):
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # name
    try:
        attrValue = obj.__getattribute__(attrName)
    except AttributeError:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}")
    # multiple
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # Use lazy evaluation to avoid constructing a StorageClassFactory
        # unless necessary.
        if isinstance(attrValue, collections.abc.Sequence) \
                and not issubclass(
                    StorageClassFactory().getStorageClass(connection.storageClass).pytype,
                    collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # No test for storageClass, as I'm not sure how much persistence
    # depends on duck-typing.


def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
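
    Examples
    --------
    A minimal sketch; the ``run`` signature shown here is hypothetical::

        result = task.run(exposure=exposure)
        assertValidOutput(task, result)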

    """
    connections = task.config.ConnectionsClass(config=task.config)

    for name in connections.outputs:
        connection = connections.__getattribute__(name)
        _assertAttributeMatchesConnection(result, name, connection)


def assertValidInitOutput(task):
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
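
    Examples
    --------
    A minimal sketch, assuming a hypothetical task class ``MyTask``::

        task = MyTask(config=config)
        assertValidInitOutput(task)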

    """
    connections = task.config.ConnectionsClass(config=task.config)

    for name in connections.initOutputs:
        connection = connections.__getattribute__(name)
        _assertAttributeMatchesConnection(task, name, connection)


def getInitInputs(butler, config):
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
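
    Examples
    --------
    A minimal sketch, assuming a hypothetical task class ``MyTask`` whose
    init-inputs exist in the butler's pre-configured collections::

        config = MyTask.ConfigClass()
        task = MyTask(config=config, initInputs=getInitInputs(butler, config))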

    """
    connections = config.connections.ConnectionsClass(config=config)
    initInputs = {}
    for name in connections.initInputs:
        attribute = getattr(connections, name)
        # Get full dataset type to check for consistency problems
        dsType = DatasetType(attribute.name, butler.registry.dimensions.extract(set()),
                             attribute.storageClass)
        # All initInputs have empty data IDs
        initInputs[name] = butler.get(dsType)

    return initInputs


def lintConnections(connections, *,
                    checkMissingMultiple=True,
                    checkUnnecessaryMultiple=True,
                    ):
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
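
    Examples
    --------
    A minimal sketch, assuming a hypothetical task class ``MyTask``; the
    connections class is reachable through the task's config class, and the
    call is typically made from a unit test::

        lintConnections(MyTask.ConfigClass.ConnectionsClass)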

    """
    # Since all comparisons are inside the class, don't bother
    # normalizing skypix.
    quantumDimensions = connections.dimensions

    errors = ""
    # connectionTypes.DimensionedConnection is an implementation detail;
    # don't use it.
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
        connection = connections.allConnections[name]
        connDimensions = set(connection.dimensions)
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            errors += f"Connection {name} may be called with multiple values of " \
                      f"{connDimensions - quantumDimensions} but has multiple=False.\n"
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            errors += f"Connection {name} has multiple=True but can only be called with one " \
                      f"value of {connDimensions} for each {quantumDimensions}.\n"
    if errors:
        raise AssertionError(errors)