Coverage for python/lsst/pipe/base/testUtils.py: 13% (130 statements)
coverage.py v7.2.7, created at 2023-07-12 11:14 -0700

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]


import collections.abc
import itertools
import unittest.mock
from collections import defaultdict
from collections.abc import Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import (
    Butler,
    DataCoordinate,
    DataId,
    DatasetRef,
    DatasetType,
    Dimension,
    DimensionUniverse,
    Quantum,
    SkyPixDimension,
    StorageClassFactory,
)
from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection

from ._quantumContext import QuantumContext

if TYPE_CHECKING:
    from .config import PipelineTaskConfig
    from .connections import PipelineTaskConnections
    from .pipelineTask import PipelineTask
    from .struct import Struct


def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, DataId | Sequence[DataId]],
) -> Quantum:
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output connection name. Values must be data
        IDs for single connections and sequences of data IDs for multiple
        connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``ioDataIds``.
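
    Examples
    --------
    A minimal sketch of typical test usage; the connection names
    ``"calexp"`` and ``"catalog"`` are hypothetical and must match
    ``task``'s connections, and the dimension values must exist in
    ``butler``:

    >>> dataId = {"instrument": "notACam", "visit": 101, "detector": 42}
    >>> quantum = makeQuantum(
    ...     task,
    ...     butler,
    ...     dataId,
    ...     ioDataIds={"calexp": dataId, "catalog": dataId},
    ... )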

89 """ 

90 # This is a type ignore, because `connections` is a dynamic class, but 

91 # it for sure will have this property 

92 connections = task.config.ConnectionsClass(config=task.config) # type: ignore 

93 

94 try: 

95 _checkDimensionsMatch(butler.dimensions, connections.dimensions, dataId.keys()) 

96 except ValueError as e: 

97 raise ValueError("Error in quantum dimensions.") from e 

98 

99 inputs = defaultdict(list) 

100 outputs = defaultdict(list) 

101 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs): 

102 try: 

103 connection = connections.__getattribute__(name) 

104 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple) 

105 ids = _normalizeDataIds(ioDataIds[name]) 

106 for id in ids: 

107 ref = _refFromConnection(butler, connection, id) 

108 inputs[ref.datasetType].append(ref) 

109 except (ValueError, KeyError) as e: 

110 raise ValueError(f"Error in connection {name}.") from e 

111 for name in connections.outputs: 

112 try: 

113 connection = connections.__getattribute__(name) 

114 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple) 

115 ids = _normalizeDataIds(ioDataIds[name]) 

116 for id in ids: 

117 ref = _refFromConnection(butler, connection, id) 

118 outputs[ref.datasetType].append(ref) 

119 except (ValueError, KeyError) as e: 

120 raise ValueError(f"Error in connection {name}.") from e 

121 quantum = Quantum( 

122 taskClass=type(task), 

123 dataId=DataCoordinate.standardize(dataId, universe=butler.dimensions), 

124 inputs=inputs, 

125 outputs=outputs, 

126 ) 

127 return quantum 

128 

129 

def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Set[str] | Set[Dimension],
    actual: Set[str] | Set[Dimension],
) -> None:
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    if _simplify(universe, expected) != _simplify(universe, actual):
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")


def _simplify(universe: DimensionUniverse, dimensions: Set[str] | Set[Dimension]) -> set[str]:
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
    """
    simplified: set[str] = set()
    for dimension in dimensions:
        # "skypix" is not a real Dimension, so handle it first.
        if dimension == "skypix":
            simplified.add(dimension)  # type: ignore
        else:
            # Need a Dimension object to test whether it is spatial.
            fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
            if isinstance(fullDimension, SkyPixDimension):
                simplified.add("skypix")
            else:
                simplified.add(fullDimension.name)
    return simplified


def _checkDataIdMultiplicity(name: str, dataIds: DataId | Sequence[DataId], multiple: bool) -> None:
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
    """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds: DataId | Sequence[DataId]) -> Sequence[DataId]:
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
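
    Examples
    --------
    A doctest-style sketch, with plain `dict` objects standing in for
    data IDs:

    >>> _normalizeDataIds({"visit": 42})
    [{'visit': 42}]
    >>> _normalizeDataIds([{"visit": 42}, {"visit": 43}])
    [{'visit': 42}, {'visit': 43}]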

226 """ 

227 if isinstance(dataIds, collections.abc.Sequence): 

228 return dataIds 

229 else: 

230 return [dataIds] 

231 

232 

def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
    """
    universe = butler.dimensions
    # DatasetRef only checks that required dimensions are present; it does
    # not reject extra ones, so check the dimensions ourselves.
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    try:
        datasetType = butler.registry.getDatasetType(connection.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    if not butler.run:
        raise ValueError("Cannot create a resolved DatasetRef since the butler has no default run defined.")
    try:
        registry_ref = butler.registry.findDataset(datasetType, dataId, collections=[butler.run])
        if registry_ref:
            ref = registry_ref
        else:
            ref = DatasetRef(datasetType=datasetType, dataId=dataId, run=butler.run)
            butler.registry._importDatasets([ref])
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e


def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> unittest.mock.Mock | None:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
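
    Examples
    --------
    A sketch of checking the arguments passed to ``run``; ``task``,
    ``butler``, ``dataId``, and ``ioDataIds`` are assumed to come from the
    test fixture:

    >>> quantum = makeQuantum(task, butler, dataId, ioDataIds)
    >>> run = runTestQuantum(task, butler, quantum, mockRun=True)
    >>> run.assert_called_once()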

304 """ 

305 butlerQc = QuantumContext(butler, quantum) 

306 # This is a type ignore, because `connections` is a dynamic class, but 

307 # it for sure will have this property 

308 connections = task.config.ConnectionsClass(config=task.config) # type: ignore 

309 inputRefs, outputRefs = connections.buildDatasetRefs(quantum) 

310 if mockRun: 

311 with unittest.mock.patch.object(task, "run") as mock, unittest.mock.patch( 

312 "lsst.pipe.base.QuantumContext.put" 

313 ): 

314 task.runQuantum(butlerQc, inputRefs, outputRefs) 

315 return mock 

316 else: 

317 task.runQuantum(butlerQc, inputRefs, outputRefs) 

318 return None 

319 

320 

def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # name
    try:
        attrValue = getattr(obj, attrName)
    except AttributeError as e:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}") from e
    # multiple
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # Rely on lazy evaluation so that StorageClassFactory is only
        # consulted when necessary.
        if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
            StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
        ):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # No test for storageClass, as it's unclear how much persistence
    # depends on duck-typing.


def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
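
    Examples
    --------
    A sketch under the assumption that ``task.run`` takes a single
    ``exposure`` argument; the names are hypothetical:

    >>> result = task.run(exposure)
    >>> assertValidOutput(task, result)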

377 """ 

378 # This is a type ignore, because `connections` is a dynamic class, but 

379 # it for sure will have this property 

380 connections = task.config.ConnectionsClass(config=task.config) # type: ignore 

381 

382 for name in connections.outputs: 

383 connection = connections.__getattribute__(name) 

384 _assertAttributeMatchesConnection(result, name, connection) 

385 

386 

def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
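
    Examples
    --------
    A minimal sketch; ``MyTask`` and its config are hypothetical and assumed
    to come from the test module:

    >>> task = MyTask(config=config)
    >>> assertValidInitOutput(task)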

400 """ 

401 # This is a type ignore, because `connections` is a dynamic class, but 

402 # it for sure will have this property 

403 connections = task.config.ConnectionsClass(config=task.config) # type: ignore 

404 

405 for name in connections.initOutputs: 

406 connection = connections.__getattribute__(name) 

407 _assertAttributeMatchesConnection(task, name, connection) 

408 

409 

def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
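
    Examples
    --------
    A sketch of constructing a task with real init-inputs; ``MyTask`` is a
    hypothetical task class, and ``butler`` is assumed to come from the
    test fixture:

    >>> config = MyTask.ConfigClass()
    >>> task = MyTask(config=config, initInputs=getInitInputs(butler, config))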

427 """ 

428 connections = config.connections.ConnectionsClass(config=config) 

429 initInputs = {} 

430 for name in connections.initInputs: 

431 attribute = getattr(connections, name) 

432 # Get full dataset type to check for consistency problems 

433 dsType = DatasetType(attribute.name, butler.dimensions.extract(set()), attribute.storageClass) 

434 # All initInputs have empty data IDs 

435 initInputs[name] = butler.get(dsType) 

436 

437 return initInputs 

438 

439 

def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
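
    Examples
    --------
    A sketch of linting a connections class in a unit test;
    ``MyConnections`` is a hypothetical class defined by the test module:

    >>> lintConnections(MyConnections)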

468 """ 

469 # Since all comparisons are inside the class, don't bother 

470 # normalizing skypix. 

471 quantumDimensions = connections.dimensions 

472 

473 errors = "" 

474 # connectionTypes.DimensionedConnection is implementation detail, 

475 # don't use it. 

476 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs): 

477 connection: DimensionedConnection = connections.allConnections[name] # type: ignore 

478 connDimensions = set(connection.dimensions) 

479 if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions: 

480 errors += ( 

481 f"Connection {name} may be called with multiple values of " 

482 f"{connDimensions - quantumDimensions} but has multiple=False.\n" 

483 ) 

484 if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions: 

485 errors += ( 

486 f"Connection {name} has multiple=True but can only be called with one " 

487 f"value of {connDimensions} for each {quantumDimensions}.\n" 

488 ) 

489 if errors: 

490 raise AssertionError(errors)