Coverage for python/lsst/pipe/base/testUtils.py: 13%

147 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-04 05:01 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public API of this module; underscore-prefixed helpers below are internal.
__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]

32 

33 

34import collections.abc 

35import itertools 

36import unittest.mock 

37import warnings 

38from collections import defaultdict 

39from typing import TYPE_CHECKING, AbstractSet, Any, Dict, Mapping, Optional, Sequence, Set, Union 

40 

41from lsst.daf.butler import ( 

42 Butler, 

43 DataCoordinate, 

44 DataId, 

45 DatasetRef, 

46 DatasetType, 

47 Dimension, 

48 DimensionUniverse, 

49 Quantum, 

50 SkyPixDimension, 

51 StorageClassFactory, 

52 UnresolvedRefWarning, 

53) 

54from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection 

55 

56from .butlerQuantumContext import ButlerQuantumContext 

57 

58if TYPE_CHECKING: 58 ↛ 59line 58 didn't jump to line 59, because the condition on line 58 was never true

59 from .config import PipelineTaskConfig 

60 from .connections import PipelineTaskConnections 

61 from .pipelineTask import PipelineTask 

62 from .struct import Struct 

63 

64 

def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, Union[DataId, Sequence[DataId]]],
) -> Quantum:
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId: any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataIds``.

    Raises
    ------
    ValueError
        Raised if the quantum or connection data IDs are invalid or
        inconsistent with ``task``'s connections.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    try:
        _checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    def _collectRefs(connectionNames: collections.abc.Iterable) -> Mapping[DatasetType, list]:
        # Shared logic for inputs and outputs: validate multiplicity,
        # normalize the data IDs, and group refs by dataset type, wrapping
        # any problem in a ValueError naming the offending connection.
        refs = defaultdict(list)
        for name in connectionNames:
            try:
                connection = getattr(connections, name)
                _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
                for singleId in _normalizeDataIds(ioDataIds[name]):
                    ref = _refFromConnection(butler, connection, singleId)
                    refs[ref.datasetType].append(ref)
            except (ValueError, KeyError) as e:
                raise ValueError(f"Error in connection {name}.") from e
        return refs

    # Inputs and outputs differ only in which connection names they draw
    # from; share the collection logic instead of duplicating it.
    inputs = _collectRefs(itertools.chain(connections.inputs, connections.prerequisiteInputs))
    outputs = _collectRefs(connections.outputs)
    return Quantum(
        taskClass=type(task),
        dataId=DataCoordinate.standardize(dataId, universe=butler.registry.dimensions),
        inputs=inputs,
        outputs=outputs,
    )

129 

130 

def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Union[AbstractSet[str], AbstractSet[Dimension]],
    actual: Union[AbstractSet[str], AbstractSet[Dimension]],
) -> None:
    """Check that two dimension sets agree once normalized to string form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    expectedNames = _simplify(universe, expected)
    actualNames = _simplify(universe, actual)
    if expectedNames != actualNames:
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")

154 

155 

def _simplify(
    universe: DimensionUniverse, dimensions: Union[AbstractSet[str], AbstractSet[Dimension]]
) -> Set[str]:
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
    """

    def _canonical(dimension: Union[str, Dimension]) -> str:
        # "skypix" is not a real Dimension, so short-circuit it before
        # attempting any universe lookup.
        if dimension == "skypix":
            return dimension  # type: ignore
        # Resolve to a Dimension object so we can test for spatialness.
        full = universe[dimension] if isinstance(dimension, str) else dimension
        return "skypix" if isinstance(full, SkyPixDimension) else full.name

    return {_canonical(dimension) for dimension in dimensions}

187 

188 

def _checkDataIdMultiplicity(name: str, dataIds: Union[DataId, Sequence[DataId]], multiple: bool) -> None:
    """Check that the data ID(s) for a connection match its ``multiple`` flag.

    Scalar connections must receive a single data ID; multiple connections
    must receive a sequence of data IDs.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
    """
    if multiple and not isinstance(dataIds, collections.abc.Sequence):
        raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    # DataCoordinate is a Mapping, so a single data ID of any kind is one.
    if not multiple and not isinstance(dataIds, collections.abc.Mapping):
        raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")

214 

215 

def _normalizeDataIds(dataIds: Union[DataId, Sequence[DataId]]) -> Sequence[DataId]:
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
    """
    # A single data ID (e.g. DataCoordinate) is a Mapping, not a Sequence.
    return dataIds if isinstance(dataIds, collections.abc.Sequence) else [dataIds]

234 

235 

def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the connection's dataset type is not registered, or if
        its dimensions are incompatible with ``dataId``.
    """
    universe = butler.registry.dimensions
    # DatasetRef only tests if required dimension is missing, but not extras
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # Perform the registry lookup inside the try block so that a missing
    # dataset type is reported as ValueError rather than escaping as a raw
    # KeyError. (Previously the lookup happened before the try, leaving the
    # except clause unreachable.)
    try:
        datasetType = butler.registry.getDatasetType(connection.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        with warnings.catch_warnings():
            # Unresolved refs are expected here; the test quantum resolves
            # them later.
            warnings.simplefilter("ignore", category=UnresolvedRefWarning)
            return DatasetRef(datasetType=datasetType, dataId=dataId)
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e

277 

278 

def _resolveTestQuantumInputs(butler: Butler, quantum: Quantum) -> None:
    """Resolve all input `DatasetRef` objects of a test quantum against the
    registry, so that each has non-`None` ``id`` and ``run`` attributes.

    Parameters
    ----------
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    """
    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, and the
    # calling `runTestQuantum` also duplicates (non-verbatim) logic from
    # that class. Moving `SingleQuantumExecutor` into ``pipe_base`` and
    # reusing it directly in test code would remove the duplication.
    for refList in quantum.inputs.values():
        resolved = []
        for ref in refList:
            if ref.id is not None:
                # Already resolved; keep as-is.
                resolved.append(ref)
                continue
            found = butler.registry.findDataset(ref.datasetType, ref.dataId, collections=butler.collections)
            if found is None:
                raise ValueError(
                    f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                    f"in collections {butler.collections}."
                )
            resolved.append(found)
        # Replace the contents in place so the quantum keeps the same list
        # object.
        refList[:] = resolved

313 

314 

def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> Optional[unittest.mock.Mock]:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    _resolveTestQuantumInputs(butler, quantum)
    butlerQc = ButlerQuantumContext.from_full(butler, quantum)
    # `connections` is a dynamic class; type checkers cannot see its
    # ConnectionsClass attribute, hence the ignore.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if not mockRun:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None
    # Patch out both the task's own run() and dataset writes so that no
    # real outputs are produced.
    with unittest.mock.patch.object(task, "run") as mockedRun:
        with unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mockedRun

354 

355 

def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
    """Check that an attribute on an object conforms to the specification
    given by a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError:
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # The attribute must exist at all.
    try:
        attrValue = obj.__getattribute__(attrName)
    except AttributeError:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}")
    # Multiple connections must produce sequences.
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
        return
    # Single connections must not produce sequences -- unless the storage
    # class itself is a sequence type. Only consult StorageClassFactory when
    # the value actually looks like a sequence, to avoid needless lookups.
    if isinstance(attrValue, collections.abc.Sequence):
        pytype = StorageClassFactory().getStorageClass(connection.storageClass).pytype
        if not issubclass(pytype, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # No check of storageClass itself: unclear how much persistence relies
    # on duck-typing.

393 

394 

def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Check that the value returned by ``run`` conforms to the task's own
    output connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError:
        Raised if ``result`` does not match what's expected from ``task's``
        connections.
    """
    # `connections` is a dynamic class; type checkers cannot see its
    # ConnectionsClass attribute, hence the ignore.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    # Every declared output must appear on the result with the right shape.
    for outputName in connections.outputs:
        outputConnection = connections.__getattribute__(outputName)
        _assertAttributeMatchesConnection(result, outputName, outputConnection)

420 

421 

def assertValidInitOutput(task: PipelineTask) -> None:
    """Check that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError:
        Raised if ``task`` does not have the state expected from ``task's``
        connections.
    """
    # `connections` is a dynamic class; type checkers cannot see its
    # ConnectionsClass attribute, hence the ignore.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    # Every declared init-output must appear on the task itself.
    for initOutputName in connections.initOutputs:
        initOutputConnection = connections.__getattribute__(initOutputName)
        _assertAttributeMatchesConnection(task, initOutputName, initOutputConnection)

443 

444 

def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> Dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    initInputs = {}
    for name in connections.initInputs:
        connection = getattr(connections, name)
        # Build the full dataset type so that consistency problems surface
        # here rather than deep inside the butler call.
        datasetType = DatasetType(
            connection.name, butler.registry.dimensions.extract(set()), connection.storageClass
        )
        # initInputs datasets always have empty data IDs.
        initInputs[name] = butler.get(datasetType)

    return initInputs

475 

476 

def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # All comparisons stay within one class, so skypix normalization is
    # unnecessary here.
    taskDims = connections.dimensions

    problems = []
    # connectionTypes.DimensionedConnection is an implementation detail,
    # so go through allConnections rather than using it directly.
    allNames = itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs)
    for name in allNames:
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDims = set(connection.dimensions)
        if checkMissingMultiple and not connection.multiple and connDims > taskDims:
            problems.append(
                f"Connection {name} may be called with multiple values of "
                f"{connDims - taskDims} but has multiple=False.\n"
            )
        if checkUnnecessaryMultiple and connection.multiple and connDims <= taskDims:
            problems.append(
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDims} for each {taskDims}.\n"
            )
    if problems:
        raise AssertionError("".join(problems))