Coverage for python/lsst/pipe/base/testUtils.py: 12%

146 statements

coverage.py v6.4.1, created at 2022-06-18 02:36 -0700

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]


import collections.abc
import itertools
import unittest.mock
from collections import defaultdict
from typing import TYPE_CHECKING, AbstractSet, Any, Dict, Mapping, Optional, Sequence, Set, Union

from lsst.daf.butler import (
    Butler,
    DataCoordinate,
    DataId,
    DatasetRef,
    DatasetType,
    Dimension,
    DimensionUniverse,
    Quantum,
    SkyPixDimension,
    StorageClassFactory,
)
from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection

from .butlerQuantumContext import ButlerQuantumContext

if TYPE_CHECKING:
    from .config import PipelineTaskConfig
    from .connections import PipelineTaskConnections
    from .pipelineTask import PipelineTask
    from .struct import Struct



def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, Union[DataId, Sequence[DataId]]],
) -> Quantum:
    """Create a quantum for a particular set of data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when run with ``ioDataIds``.
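
    Examples
    --------
    A hypothetical example; ``task``, ``butler``, and the connection names
    ``"input"`` and ``"output"`` are placeholders for a configured task and
    repository:

    >>> dataId = {"instrument": "notACam", "visit": 42, "detector": 0}  # doctest: +SKIP
    >>> quantum = makeQuantum(  # doctest: +SKIP
    ...     task, butler, dataId, ioDataIds={"input": dataId, "output": dataId}
    ... )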

    """
    connections = task.config.ConnectionsClass(config=task.config)

    try:
        _checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    inputs = defaultdict(list)
    outputs = defaultdict(list)
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
        try:
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                inputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    for name in connections.outputs:
        try:
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                outputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    quantum = Quantum(
        taskClass=type(task),
        dataId=DataCoordinate.standardize(dataId, universe=butler.registry.dimensions),
        inputs=inputs,
        outputs=outputs,
    )
    return quantum



def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Union[AbstractSet[str], AbstractSet[Dimension]],
    actual: Union[AbstractSet[str], AbstractSet[Dimension]],
) -> None:
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
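
    Examples
    --------
    An illustrative sketch, assuming the default dimension universe defines
    ``visit`` and ``tract``:

    >>> from lsst.daf.butler import DimensionUniverse
    >>> universe = DimensionUniverse()
    >>> _checkDimensionsMatch(universe, {"visit"}, {"visit"})
    >>> _checkDimensionsMatch(universe, {"visit"}, {"tract"})
    Traceback (most recent call last):
        ...
    ValueError: Mismatch in dimensions; expected {'visit'} but got {'tract'}.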

    """
    if _simplify(universe, expected) != _simplify(universe, actual):
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")



def _simplify(
    universe: DimensionUniverse, dimensions: Union[AbstractSet[str], AbstractSet[Dimension]]
) -> Set[str]:
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
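
    Examples
    --------
    An illustrative sketch, assuming the default dimension universe, in which
    ``htm7`` is a skypix dimension:

    >>> from lsst.daf.butler import DimensionUniverse
    >>> sorted(_simplify(DimensionUniverse(), {"visit", "htm7"}))
    ['skypix', 'visit']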

    """
    simplified: Set[str] = set()
    for dimension in dimensions:
        # "skypix" is not a real Dimension, so handle it first
        if dimension == "skypix":
            simplified.add(dimension)  # type: ignore
        else:
            # Need a Dimension object to test whether the dimension is spatial
            fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
            if isinstance(fullDimension, SkyPixDimension):
                simplified.add("skypix")
            else:
                simplified.add(fullDimension.name)
    return simplified



def _checkDataIdMultiplicity(name: str, dataIds: Union[DataId, Sequence[DataId]], multiple: bool) -> None:
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
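
    Examples
    --------
    A self-contained illustration using plain `dict` data IDs:

    >>> _checkDataIdMultiplicity("cat", [{"visit": 42}], multiple=True)
    >>> _checkDataIdMultiplicity("cat", {"visit": 42}, multiple=False)
    >>> _checkDataIdMultiplicity("cat", {"visit": 42}, multiple=True)
    Traceback (most recent call last):
        ...
    ValueError: Expected multiple data IDs for cat, got {'visit': 42}.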

    """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")



def _normalizeDataIds(dataIds: Union[DataId, Sequence[DataId]]) -> Sequence[DataId]:
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
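
    Examples
    --------
    A self-contained illustration using plain `dict` data IDs:

    >>> _normalizeDataIds({"visit": 42})
    [{'visit': 42}]
    >>> _normalizeDataIds([{"visit": 42}, {"visit": 43}])
    [{'visit': 42}, {'visit': 43}]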

    """
    if isinstance(dataIds, collections.abc.Sequence):
        return dataIds
    else:
        return [dataIds]



def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
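
    Examples
    --------
    A hypothetical example; ``task``, ``butler``, and the connection name
    ``calexp`` are placeholders for a configured task and repository:

    >>> connections = task.config.ConnectionsClass(config=task.config)  # doctest: +SKIP
    >>> ref = _refFromConnection(  # doctest: +SKIP
    ...     butler, connections.calexp, {"instrument": "notACam", "visit": 42, "detector": 0}
    ... )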

    """
    universe = butler.registry.dimensions
    # DatasetRef only tests whether required dimensions are missing,
    # not whether there are extras
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # skypix is a PipelineTask alias for "some spatial index", which Butler
    # does not understand. Code copied from TaskDatasetTypes.fromTaskDef
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError:
        raise ValueError(f"Invalid dataset type {connection.name}.")
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e



def _resolveTestQuantumInputs(butler: Butler, quantum: Quantum) -> None:
    """Look up all input datasets of a test quantum in the `Registry` to
    resolve all `DatasetRef` objects (i.e., ensure they have not-`None`
    ``id`` and ``run`` attributes).

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    """

    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refsForDatasetType in quantum.inputs.values():
        newRefsForDatasetType = []
        for ref in refsForDatasetType:
            if ref.id is None:
                resolvedRef = butler.registry.findDataset(
                    ref.datasetType, ref.dataId, collections=butler.collections
                )
                if resolvedRef is None:
                    raise ValueError(
                        f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                        f"in collections {butler.collections}."
                    )
                newRefsForDatasetType.append(resolvedRef)
            else:
                newRefsForDatasetType.append(ref)
        refsForDatasetType[:] = newRefsForDatasetType



def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> Optional[unittest.mock.Mock]:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
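
    Examples
    --------
    A hypothetical example; ``task``, ``butler``, and ``dataId`` are
    placeholders for a configured task and repository (see `makeQuantum`):

    >>> quantum = makeQuantum(  # doctest: +SKIP
    ...     task, butler, dataId, ioDataIds={"input": dataId, "output": dataId}
    ... )
    >>> run = runTestQuantum(task, butler, quantum)  # doctest: +SKIP
    >>> run.assert_called_once()  # doctest: +SKIP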

    """
    _resolveTestQuantumInputs(butler, quantum)
    butlerQc = ButlerQuantumContext(butler, quantum)
    connections = task.config.ConnectionsClass(config=task.config)
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if mockRun:
        with unittest.mock.patch.object(task, "run") as mock, unittest.mock.patch(
            "lsst.pipe.base.ButlerQuantumContext.put"
        ):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock
    else:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None



def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if the ``attrName`` attribute of ``obj`` does not match what's
        expected from ``connection``.
    """

    # name
    try:
        attrValue = getattr(obj, attrName)
    except AttributeError:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}")
    # multiple
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # Evaluate lazily so that StorageClassFactory is not constructed
        # unless necessary
        if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
            StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
        ):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # No test for storageClass, as it's unclear how much persistence
    # depends on duck-typing



def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
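
    Examples
    --------
    A hypothetical example; ``MyTask`` is a placeholder for a concrete
    `~lsst.pipe.base.PipelineTask` and ``data`` for its input:

    >>> task = MyTask()  # doctest: +SKIP
    >>> result = task.run(data)  # doctest: +SKIP
    >>> assertValidOutput(task, result)  # doctest: +SKIP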

    """
    connections = task.config.ConnectionsClass(config=task.config)

    for name in connections.outputs:
        connection = getattr(connections, name)
        _assertAttributeMatchesConnection(result, name, connection)



def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
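
    Examples
    --------
    A hypothetical example; ``MyTask`` is a placeholder for a concrete
    `~lsst.pipe.base.PipelineTask` with init-output connections:

    >>> task = MyTask()  # doctest: +SKIP
    >>> assertValidInitOutput(task)  # doctest: +SKIP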

    """
    connections = task.config.ConnectionsClass(config=task.config)

    for name in connections.initOutputs:
        connection = getattr(connections, name)
        _assertAttributeMatchesConnection(task, name, connection)



def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> Dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
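
    Examples
    --------
    A hypothetical example; ``butler`` and ``MyTask`` are placeholders for a
    configured repository and a concrete `~lsst.pipe.base.PipelineTask`:

    >>> config = MyTask.ConfigClass()  # doctest: +SKIP
    >>> task = MyTask(config=config, initInputs=getInitInputs(butler, config))  # doctest: +SKIP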

    """
    connections = config.connections.ConnectionsClass(config=config)
    initInputs = {}
    for name in connections.initInputs:
        attribute = getattr(connections, name)
        # Get full dataset type to check for consistency problems
        dsType = DatasetType(
            attribute.name, butler.registry.dimensions.extract(set()), attribute.storageClass
        )
        # All initInputs have empty data IDs
        initInputs[name] = butler.get(dsType)

    return initInputs



def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
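
    Examples
    --------
    A hypothetical example; ``MyConnections`` is a placeholder for a concrete
    `~lsst.pipe.base.PipelineTaskConnections` subclass:

    >>> lintConnections(MyConnections)  # doctest: +SKIP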

    """
    # Since all comparisons are inside the class, don't bother
    # normalizing skypix.
    quantumDimensions = connections.dimensions

    errors = ""
    # connectionTypes.DimensionedConnection is an implementation detail;
    # it is used here only as a typing aid.
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDimensions = set(connection.dimensions)
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            errors += (
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            errors += (
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDimensions} for each {quantumDimensions}.\n"
            )
    if errors:
        raise AssertionError(errors)