
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


from __future__ import annotations

__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]


import collections.abc
import itertools
import unittest.mock
from collections import defaultdict
from collections.abc import Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import (
    Butler,
    DataCoordinate,
    DataId,
    DatasetRef,
    DatasetType,
    Dimension,
    DimensionUniverse,
    Quantum,
    SkyPixDimension,
    StorageClassFactory,
)
from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection

from ._quantumContext import QuantumContext

if TYPE_CHECKING:
    from .config import PipelineTaskConfig
    from .connections import PipelineTaskConnections
    from .pipelineTask import PipelineTask
    from .struct import Struct



def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, DataId | Sequence[DataId]],
) -> Quantum:
    """Create a Quantum for particular data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataId`` and ``ioDataIds``.
    """
    # `ConnectionsClass` is added to the config dynamically, so mypy cannot
    # see it, but it is guaranteed to exist.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    try:
        _checkDimensionsMatch(butler.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    inputs = defaultdict(list)
    outputs = defaultdict(list)
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                inputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    for name in connections.outputs:
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                outputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    quantum = Quantum(
        taskClass=type(task),
        dataId=DataCoordinate.standardize(dataId, universe=butler.dimensions),
        inputs=inputs,
        outputs=outputs,
    )
    return quantum
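

# Illustrative sketch, not part of the public API: one way a unit test might
# call `makeQuantum`. The connection names ("input", "output") and the data ID
# contents are assumptions about a hypothetical task under test.
def _example_makeQuantum(task: PipelineTask, butler: Butler) -> Quantum:
    # Assume the task's quantum and both connections share the dimensions
    # {"instrument", "visit"}, and that both connections have multiple=False.
    dataId = {"instrument": "demo", "visit": 42}
    return makeQuantum(task, butler, dataId, ioDataIds={"input": dataId, "output": dataId})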



def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Set[str] | Set[Dimension],
    actual: Set[str] | Set[Dimension],
) -> None:
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by the input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    if _simplify(universe, expected) != _simplify(universe, actual):
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")



def _simplify(universe: DimensionUniverse, dimensions: Set[str] | Set[Dimension]) -> set[str]:
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
    """
    simplified: set[str] = set()
    for dimension in dimensions:
        # "skypix" is not a real Dimension; handle it first.
        if dimension == "skypix":
            simplified.add(dimension)  # type: ignore
        else:
            # Need a Dimension object to test whether it is spatial.
            fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
            if isinstance(fullDimension, SkyPixDimension):
                simplified.add("skypix")
            else:
                simplified.add(fullDimension.name)
    return simplified
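

# For example, assuming a universe in which ``htm7`` is a skypix dimension:
# _simplify(universe, {"htm7", "visit"}) would return {"skypix", "visit"}.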



def _checkDataIdMultiplicity(name: str, dataIds: DataId | Sequence[DataId], multiple: bool) -> None:
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
    """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")



def _normalizeDataIds(dataIds: DataId | Sequence[DataId]) -> Sequence[DataId]:
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
    """
    if isinstance(dataIds, collections.abc.Sequence):
        return dataIds
    else:
        return [dataIds]



def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
    """
    universe = butler.dimensions
    # DatasetRef only tests whether required dimensions are missing, not
    # whether there are extras, so check explicitly.
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    try:
        datasetType = butler.registry.getDatasetType(connection.name)
    except KeyError:
        raise ValueError(f"Invalid dataset type {connection.name}.") from None
    if not butler.run:
        raise ValueError("Cannot create a resolved DatasetRef since the butler has no default run defined.")
    try:
        registry_ref = butler.registry.findDataset(datasetType, dataId, collections=[butler.run])
        if registry_ref:
            ref = registry_ref
        else:
            ref = DatasetRef(datasetType=datasetType, dataId=dataId, run=butler.run)
            butler.registry._importDatasets([ref])
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e



def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> unittest.mock.Mock | None:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    butlerQc = QuantumContext(butler, quantum)
    # `ConnectionsClass` is added to the config dynamically, so mypy cannot
    # see it, but it is guaranteed to exist.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if mockRun:
        with (
            unittest.mock.patch.object(task, "run") as mock,
            unittest.mock.patch("lsst.pipe.base.QuantumContext.put"),
        ):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock
    else:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None
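

# Illustrative sketch, not part of the public API: running a quantum with
# `run` mocked out, then inspecting the call. The argument name "exposure" is
# an assumption about the hypothetical task's `run` signature.
def _example_runTestQuantum(task: PipelineTask, butler: Butler, quantum: Quantum) -> None:
    run = runTestQuantum(task, butler, quantum, mockRun=True)
    assert run is not None  # a Mock is always returned when mockRun=True
    run.assert_called_once()
    assert "exposure" in run.call_args.kwargs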



def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # name
    try:
        attrValue = obj.__getattribute__(attrName)
    except AttributeError:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}") from None
    # multiple
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # Evaluate lazily so that StorageClassFactory is only used
        # when necessary.
        if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
            StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
        ):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # No test for storageClass, as it is unclear how much persistence
    # depends on duck-typing.



def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
    """
    # `ConnectionsClass` is added to the config dynamically, so mypy cannot
    # see it, but it is guaranteed to exist.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for name in connections.outputs:
        connection = connections.__getattribute__(name)
        _assertAttributeMatchesConnection(result, name, connection)
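

# Illustrative sketch, not part of the public API: validating the struct
# returned by `run` against the task's output connections. Calling `run` with
# no arguments is an assumption; most tasks require input datasets.
def _example_assertValidOutput(task: PipelineTask) -> None:
    result = task.run()
    assertValidOutput(task, result)  # raises AssertionError on mismatch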



def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
    """
    # `ConnectionsClass` is added to the config dynamically, so mypy cannot
    # see it, but it is guaranteed to exist.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for name in connections.initOutputs:
        connection = connections.__getattribute__(name)
        _assertAttributeMatchesConnection(task, name, connection)
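

# Illustrative sketch, not part of the public API: checking that a freshly
# constructed task exposes everything promised by its init-output connections.
def _example_assertValidInitOutput(taskClass: type[PipelineTask], config: PipelineTaskConfig) -> None:
    task = taskClass(config=config)
    assertValidInitOutput(task)  # raises AssertionError on mismatch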



def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    initInputs = {}
    for name in connections.initInputs:
        attribute = getattr(connections, name)
        # Get the full dataset type to check for consistency problems.
        dsType = DatasetType(attribute.name, butler.dimensions.extract(set()), attribute.storageClass)
        # All initInputs have empty data IDs.
        initInputs[name] = butler.get(dsType)

    return initInputs
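

# Illustrative sketch, not part of the public API: constructing a task the way
# the execution framework would, with init-inputs read from a repository.
def _example_getInitInputs(
    butler: Butler, taskClass: type[PipelineTask], config: PipelineTaskConfig
) -> PipelineTask:
    initInputs = getInitInputs(butler, config)
    return taskClass(config=config, initInputs=initInputs)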



def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # Since all comparisons are inside the class, don't bother
    # normalizing skypix.
    quantumDimensions = connections.dimensions

    errors = ""
    # connectionTypes.DimensionedConnection is an implementation detail;
    # don't use it directly.
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDimensions = set(connection.dimensions)
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            errors += (
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            errors += (
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDimensions} for each {quantumDimensions}.\n"
            )
    if errors:
        raise AssertionError(errors)
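

# Illustrative sketch, not part of the public API: linting a connections class
# from a unit test. The class itself (not an instance) is passed, per the
# docstring above; `lintConnections` raises AssertionError listing all
# problems if any enabled check fails.
def _example_lintConnections(connectionsClass: type[PipelineTaskConnections]) -> None:
    lintConnections(connectionsClass, checkUnnecessaryMultiple=False)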