Coverage for python/lsst/pipe/base/testUtils.py: 10%


138 statements  

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]


import collections.abc
import itertools
import unittest.mock
from collections import defaultdict

from lsst.daf.butler import (
    DataCoordinate,
    DatasetRef,
    DatasetType,
    Quantum,
    SkyPixDimension,
    StorageClassFactory,
)
from lsst.pipe.base import ButlerQuantumContext


def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for a particular set of data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``ioDataIds``.
    """
    connections = task.config.ConnectionsClass(config=task.config)

    try:
        _checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    inputs = defaultdict(list)
    outputs = defaultdict(list)
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                inputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    for name in connections.outputs:
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                outputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    quantum = Quantum(taskClass=type(task), dataId=dataId, inputs=inputs, outputs=outputs)
    return quantum

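# A hedged usage sketch (editorial addition, not part of the original
# module): ``task``, ``butler``, and the connection names "exposure"
# (multiple=False) and "catalogs" (multiple=True) are hypothetical test
# fixtures; substitute your own.
def _exampleMakeQuantum(task, butler):
    dataId = {"instrument": "notACam", "visit": 42, "detector": 0}
    return makeQuantum(
        task,
        butler,
        dataId,
        ioDataIds={
            "exposure": dataId,    # single connection: one data ID
            "catalogs": [dataId],  # multiple connection: a sequence of IDs
        },
    )
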

def _checkDimensionsMatch(universe, expected, actual):
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    if _simplify(universe, expected) != _simplify(universe, actual):
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")


def _simplify(universe, dimensions):
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
    """
    simplified = set()
    for dimension in dimensions:
        # "skypix" is not a real Dimension, so handle it first.
        if dimension == "skypix":
            simplified.add(dimension)
        else:
            # Need a Dimension object to test whether it is spatial.
            fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
            if isinstance(fullDimension, SkyPixDimension):
                simplified.add("skypix")
            else:
                simplified.add(fullDimension.name)
    return simplified

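# A minimal sketch (editorial addition) of why spatial dimensions collapse
# to "skypix": it lets a concrete index such as ``htm7`` (assumed to exist
# in the repository's dimension universe) match the generic "skypix"
# placeholder used in connections.
def _exampleSimplify(butler):
    universe = butler.registry.dimensions
    # Both sets reduce to {"skypix"}, so the dimensions are compatible.
    assert _simplify(universe, {"htm7"}) == _simplify(universe, {"skypix"})
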

def _checkDataIdMultiplicity(name, dataIds, multiple):
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
    """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds):
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
    """
    if isinstance(dataIds, collections.abc.Sequence):
        return dataIds
    else:
        return [dataIds]

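# A small sketch (editorial addition) of the two helpers above: a `dict` is
# a Mapping (a single data ID), while a `list` of dicts is a Sequence
# (multiple data IDs).
def _exampleMultiplicity():
    single = {"visit": 42, "detector": 0}
    _checkDataIdMultiplicity("exposure", single, multiple=False)  # passes
    _checkDataIdMultiplicity("catalogs", [single], multiple=True)  # passes
    assert _normalizeDataIds(single) == [single]
    assert _normalizeDataIds([single]) == [single]
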

def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
    """
    universe = butler.registry.dimensions
    # DatasetRef only tests whether required dimensions are missing,
    # not whether there are extras.
    _checkDimensionsMatch(universe, connection.dimensions, dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # "skypix" is a PipelineTask alias for "some spatial index"; the Butler
    # doesn't understand it. Code copied from TaskDatasetTypes.fromTaskDef.
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(
            f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible."
        ) from e


def _resolveTestQuantumInputs(butler, quantum):
    """Look up all input datasets for a test quantum in the `Registry` to
    resolve all `DatasetRef` objects (i.e. ensure they have not-`None` ``id``
    and ``run`` attributes).

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    """
    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refsForDatasetType in quantum.inputs.values():
        newRefsForDatasetType = []
        for ref in refsForDatasetType:
            if ref.id is None:
                resolvedRef = butler.registry.findDataset(
                    ref.datasetType, ref.dataId, collections=butler.collections
                )
                if resolvedRef is None:
                    raise ValueError(
                        f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                        f"in collections {butler.collections}."
                    )
                newRefsForDatasetType.append(resolvedRef)
            else:
                newRefsForDatasetType.append(ref)
        refsForDatasetType[:] = newRefsForDatasetType


def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    _resolveTestQuantumInputs(butler, quantum)
    butlerQc = ButlerQuantumContext(butler, quantum)
    connections = task.config.ConnectionsClass(config=task.config)
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if mockRun:
        with unittest.mock.patch.object(task, "run") as mock, unittest.mock.patch(
            "lsst.pipe.base.ButlerQuantumContext.put"
        ):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock
    else:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None

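# A hedged sketch (editorial addition) of running a quantum with a mocked
# ``run``; ``quantum`` would come from `makeQuantum`, and ``task`` and
# ``butler`` from a hypothetical test fixture.
def _exampleRunTestQuantum(task, butler, quantum):
    run = runTestQuantum(task, butler, quantum, mockRun=True)
    # Nothing was written to the repository; instead, inspect how
    # runQuantum forwarded the loaded inputs to run.
    run.assert_called_once()
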

def _assertAttributeMatchesConnection(obj, attrName, connection):
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # name
    try:
        attrValue = obj.__getattribute__(attrName)
    except AttributeError:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}")
    # multiple
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # use lazy evaluation to not use StorageClassFactory unless
        # necessary
        if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
            StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
        ):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # no test for storageClass, as I'm not sure how much persistence
    # depends on duck-typing


def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
    """
    connections = task.config.ConnectionsClass(config=task.config)

    for name in connections.outputs:
        connection = connections.__getattribute__(name)
        _assertAttributeMatchesConnection(result, name, connection)


def assertValidInitOutput(task):
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
    """
    connections = task.config.ConnectionsClass(config=task.config)

    for name in connections.initOutputs:
        connection = connections.__getattribute__(name)
        _assertAttributeMatchesConnection(task, name, connection)

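# A minimal sketch (editorial addition) combining the two validity checks in
# a unit test; ``task`` and its ``run`` arguments are hypothetical.
def _exampleValidityChecks(task, inputs):
    assertValidInitOutput(task)      # init-outputs exist as task attributes
    result = task.run(**inputs)
    assertValidOutput(task, result)  # the result Struct matches connections
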

def getInitInputs(butler, config):
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    initInputs = {}
    for name in connections.initInputs:
        attribute = getattr(connections, name)
        # Get full dataset type to check for consistency problems
        dsType = DatasetType(
            attribute.name, butler.registry.dimensions.extract(set()), attribute.storageClass
        )
        # All initInputs have empty data IDs
        initInputs[name] = butler.get(dsType)

    return initInputs

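# A hedged construction sketch (editorial addition): ``MyTask`` is a
# hypothetical `PipelineTask` whose init-inputs (e.g., a schema catalog) are
# already present in the repository's collections.
def _exampleGetInitInputs(butler, MyTask):
    config = MyTask.ConfigClass()
    return MyTask(config=config, initInputs=getInitInputs(butler, config))
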

def lintConnections(
    connections,
    *,
    checkMissingMultiple=True,
    checkUnnecessaryMultiple=True,
):
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # Since all comparisons are inside the class, don't bother
    # normalizing skypix.
    quantumDimensions = connections.dimensions

    errors = ""
    # connectionTypes.DimensionedConnection is an implementation detail;
    # don't use it.
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
        connection = connections.allConnections[name]
        connDimensions = set(connection.dimensions)
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            errors += (
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            errors += (
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDimensions} for each {quantumDimensions}.\n"
            )
    if errors:
        raise AssertionError(errors)
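

# A minimal sketch (editorial addition): lint a hypothetical connections
# class in a unit test, so design errors fail fast without needing any data.
def _exampleLintConnections(MyTaskConnections):
    # Raises AssertionError describing every failed check; an intentionally
    # unusual design can opt out of individual checks by keyword.
    lintConnections(MyTaskConnections, checkUnnecessaryMultiple=False)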