Coverage for python/lsst/pipe/base/testUtils.py: 12%

150 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-12 02:03 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = [ 

25 "assertValidInitOutput", 

26 "assertValidOutput", 

27 "getInitInputs", 

28 "lintConnections", 

29 "makeQuantum", 

30 "runTestQuantum", 

31] 

32 

33 

34import collections.abc 

35import itertools 

36import unittest.mock 

37from collections import defaultdict 

38from typing import TYPE_CHECKING, AbstractSet, Any, Dict, Mapping, Optional, Sequence, Set, Union 

39 

40from lsst.daf.butler import ( 

41 Butler, 

42 DataCoordinate, 

43 DataId, 

44 DatasetRef, 

45 DatasetType, 

46 Dimension, 

47 DimensionUniverse, 

48 Quantum, 

49 SkyPixDimension, 

50 StorageClassFactory, 

51) 

52from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection 

53 

54from .butlerQuantumContext import ButlerQuantumContext 

55 

56if TYPE_CHECKING: 56 ↛ 57line 56 didn't jump to line 57, because the condition on line 56 was never true

57 from .config import PipelineTaskConfig 

58 from .connections import PipelineTaskConnections 

59 from .pipelineTask import PipelineTask 

60 from .struct import Struct 

61 

62 

def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, Union[DataId, Sequence[DataId]]],
) -> Quantum:
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId: any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataIds``.

    Raises
    ------
    ValueError
        Raised if the quantum dimensions or any connection's data ID(s)
        are invalid for ``task``'s connections class.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    try:
        _checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    # Inputs (including prerequisites) and outputs need identical handling;
    # only the connection names differ.
    inputs = _refsByDatasetType(
        butler, connections, itertools.chain(connections.inputs, connections.prerequisiteInputs), ioDataIds
    )
    outputs = _refsByDatasetType(butler, connections, connections.outputs, ioDataIds)
    quantum = Quantum(
        taskClass=type(task),
        dataId=DataCoordinate.standardize(dataId, universe=butler.registry.dimensions),
        inputs=inputs,
        outputs=outputs,
    )
    return quantum


def _refsByDatasetType(
    butler: Butler,
    connections: PipelineTaskConnections,
    names: collections.abc.Iterable[str],
    ioDataIds: Mapping[str, Union[DataId, Sequence[DataId]]],
) -> Mapping[DatasetType, list]:
    """Build dataset references for named connections, grouped by
    dataset type.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection the references point to.
    connections : `lsst.pipe.base.PipelineTaskConnections`
        The connections object that declares the named connections.
    names : `~collections.abc.Iterable` [`str`]
        The connection names to process.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping from connection name to data ID(s); single data IDs for
        scalar connections, sequences for multiple connections.

    Returns
    -------
    refs : `~collections.abc.Mapping` [`~lsst.daf.butler.DatasetType`, `list`]
        The dataset references for ``names``, keyed by dataset type.

    Raises
    ------
    ValueError
        Raised if a connection's data ID(s) are missing or invalid.
    """
    refs = defaultdict(list)
    for name in names:
        try:
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            for dataId in _normalizeDataIds(ioDataIds[name]):
                ref = _refFromConnection(butler, connection, dataId)
                refs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            # KeyError covers both a missing ioDataIds entry and registry
            # lookup failures; report which connection was at fault.
            raise ValueError(f"Error in connection {name}.") from e
    return refs

127 

128 

def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Union[AbstractSet[str], AbstractSet[Dimension]],
    actual: Union[AbstractSet[str], AbstractSet[Dimension]],
) -> None:
    """Check that two dimension sets agree once reduced to canonical form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    expectedNames = _simplify(universe, expected)
    actualNames = _simplify(universe, actual)
    if expectedNames != actualNames:
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")

152 

153 

def _simplify(
    universe: DimensionUniverse, dimensions: Union[AbstractSet[str], AbstractSet[Dimension]]
) -> Set[str]:
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
    """

    def _canonicalName(dimension: Union[str, Dimension]) -> str:
        # "skypix" is a placeholder, not a real Dimension; pass it through
        # without consulting the universe.
        if dimension == "skypix":
            return dimension  # type: ignore
        # Need a concrete Dimension object to test for spatialness.
        resolved = universe[dimension] if isinstance(dimension, str) else dimension
        # Collapse every concrete skypix dimension to the generic "skypix".
        return "skypix" if isinstance(resolved, SkyPixDimension) else resolved.name

    return {_canonicalName(dimension) for dimension in dimensions}

185 

186 

187def _checkDataIdMultiplicity(name: str, dataIds: Union[DataId, Sequence[DataId]], multiple: bool) -> None: 

188 """Test whether data IDs are scalars for scalar connections and sequences 

189 for multiple connections. 

190 

191 Parameters 

192 ---------- 

193 name : `str` 

194 The name of the connection being tested. 

195 dataIds : any data ID type or `~collections.abc.Sequence` [data ID] 

196 The data ID(s) provided for the connection. 

197 multiple : `bool` 

198 The ``multiple`` field of the connection. 

199 

200 Raises 

201 ------ 

202 ValueError 

203 Raised if ``dataIds`` and ``multiple`` do not match. 

204 """ 

205 if multiple: 

206 if not isinstance(dataIds, collections.abc.Sequence): 

207 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.") 

208 else: 

209 # DataCoordinate is a Mapping 

210 if not isinstance(dataIds, collections.abc.Mapping): 

211 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.") 

212 

213 

214def _normalizeDataIds(dataIds: Union[DataId, Sequence[DataId]]) -> Sequence[DataId]: 

215 """Represent both single and multiple data IDs as a list. 

216 

217 Parameters 

218 ---------- 

219 dataIds : any data ID type or `~collections.abc.Sequence` thereof 

220 The data ID(s) provided for a particular input or output connection. 

221 

222 Returns 

223 ------- 

224 normalizedIds : `~collections.abc.Sequence` [data ID] 

225 A sequence equal to ``dataIds`` if it was already a sequence, or 

226 ``[dataIds]`` if it was a single ID. 

227 """ 

228 if isinstance(dataIds, collections.abc.Sequence): 

229 return dataIds 

230 else: 

231 return [dataIds] 

232 

233 

def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dataset type is not registered, the butler has no
        default run, or ``dataId`` is incompatible with the dataset type.
    """
    universe = butler.registry.dimensions
    # DatasetRef only tests if required dimension is missing, but not extras
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # An unregistered dataset type raises KeyError; convert it to the
    # documented ValueError instead of letting it escape.  (Previously the
    # first lookup was unguarded and a redundant second lookup was guarded.)
    try:
        datasetType = butler.registry.getDatasetType(connection.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e

    if not butler.run:
        raise ValueError("Can not create a resolved DatasetRef since the butler has no default run defined.")
    try:
        # Reuse an already-registered dataset when one exists; otherwise
        # create and import a new resolved ref in the default run.
        registry_ref = butler.registry.findDataset(datasetType, dataId, collections=[butler.run])
        if registry_ref:
            ref = registry_ref
        else:
            ref = DatasetRef(datasetType=datasetType, dataId=dataId, run=butler.run)
            butler.registry._importDatasets([ref])
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e

280 

281 

def _resolveTestQuantumInputs(butler: Butler, quantum: Quantum) -> None:
    """Resolve all input `DatasetRef` objects of a test quantum in place,
    ensuring each has not-`None` ``id`` and ``run`` attributes.

    Parameters
    ----------
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    """
    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refsForDatasetType in quantum.inputs.values():
        resolvedRefs = []
        for ref in refsForDatasetType:
            if ref.id is not None:
                # Already resolved; keep as-is.
                resolvedRefs.append(ref)
                continue
            found = butler.registry.findDataset(ref.datasetType, ref.dataId, collections=butler.collections)
            if found is None:
                raise ValueError(
                    f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                    f"in collections {butler.collections}."
                )
            resolvedRefs.append(found)
        # Replace the list contents in place so the quantum sees the
        # resolved refs.
        refsForDatasetType[:] = resolvedRefs

316 

317 

def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> Optional[unittest.mock.Mock]:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    _resolveTestQuantumInputs(butler, quantum)
    butlerQc = ButlerQuantumContext.from_full(butler, quantum)
    # `connections` is a dynamic class, but it is guaranteed to have this
    # attribute, hence the type ignore.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if not mockRun:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None
    # Mock both ``run`` and ``put`` so a mocked run never writes
    # real outputs.
    with unittest.mock.patch.object(task, "run") as mockedRun:
        with unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mockedRun

357 

358 

359def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None: 

360 """Test that an attribute on an object matches the specification given in 

361 a connection. 

362 

363 Parameters 

364 ---------- 

365 obj 

366 An object expected to contain the attribute ``attrName``. 

367 attrName : `str` 

368 The name of the attribute to be tested. 

369 connection : `lsst.pipe.base.connectionTypes.BaseConnection` 

370 The connection, usually some type of output, specifying ``attrName``. 

371 

372 Raises 

373 ------ 

374 AssertionError: 

375 Raised if ``obj.attrName`` does not match what's expected 

376 from ``connection``. 

377 """ 

378 # name 

379 try: 

380 attrValue = obj.__getattribute__(attrName) 

381 except AttributeError: 

382 raise AssertionError(f"No such attribute on {obj!r}: {attrName}") 

383 # multiple 

384 if connection.multiple: 

385 if not isinstance(attrValue, collections.abc.Sequence): 

386 raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.") 

387 else: 

388 # use lazy evaluation to not use StorageClassFactory unless 

389 # necessary 

390 if isinstance(attrValue, collections.abc.Sequence) and not issubclass( 

391 StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence 

392 ): 

393 raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.") 

394 # no test for storageClass, as I'm not sure how much persistence 

395 # depends on duck-typing 

396 

397 

def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError:
        Raised if ``result`` does not match what's expected from ``task's``
        connections.
    """
    # `connections` is a dynamic class, but it is guaranteed to have this
    # attribute, hence the type ignore.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for outputName in connections.outputs:
        _assertAttributeMatchesConnection(result, outputName, getattr(connections, outputName))

423 

424 

def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError:
        Raised if ``task`` does not have the state expected from ``task's``
        connections.
    """
    # `connections` is a dynamic class, but it is guaranteed to have this
    # attribute, hence the type ignore.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for initOutputName in connections.initOutputs:
        _assertAttributeMatchesConnection(task, initOutputName, getattr(connections, initOutputName))

446 

447 

def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> Dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    initInputs: Dict[str, Any] = {}
    for name in connections.initInputs:
        connection = getattr(connections, name)
        # Build the full dataset type so consistency problems surface here
        # rather than inside butler.get.
        datasetType = DatasetType(
            connection.name, butler.registry.dimensions.extract(set()), connection.storageClass
        )
        # All initInputs have empty data IDs
        initInputs[name] = butler.get(datasetType)

    return initInputs

478 

479 

def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # All comparisons stay inside this one class, so there is no need to
    # normalize skypix.
    quantumDimensions = connections.dimensions

    problems = []
    # connectionTypes.DimensionedConnection is an implementation detail;
    # don't use it.
    allNames = itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs)
    for name in allNames:
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDimensions = set(connection.dimensions)
        # A scalar connection with extra dimensions gets one dataset per
        # combination of the extras -- it should be multiple.
        undersized = not connection.multiple and connDimensions > quantumDimensions
        if checkMissingMultiple and undersized:
            problems.append(
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
        # A multiple connection whose dimensions are a subset of the
        # quantum's can only ever see one dataset.
        oversized = connection.multiple and connDimensions <= quantumDimensions
        if checkUnnecessaryMultiple and oversized:
            problems.append(
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDimensions} for each {quantumDimensions}.\n"
            )
    if problems:
        raise AssertionError("".join(problems))