Coverage for python/lsst/pipe/base/testUtils.py: 13% (130 statements)

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]


import collections.abc
import itertools
import unittest.mock
from collections import defaultdict
from collections.abc import Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import (
    Butler,
    DataCoordinate,
    DataId,
    DatasetRef,
    DatasetType,
    Dimension,
    DimensionUniverse,
    Quantum,
    SkyPixDimension,
    StorageClassFactory,
)
from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection

from ._quantumContext import QuantumContext

if TYPE_CHECKING:
    from .config import PipelineTaskConfig
    from .connections import PipelineTaskConnections
    from .pipelineTask import PipelineTask
    from .struct import Struct


def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, DataId | Sequence[DataId]],
) -> Quantum:
    """Create a Quantum for a particular set of data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataId``.
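
    Examples
    --------
    A minimal sketch of typical use. ``butler`` is assumed to be a
    writeable test repository, and ``MyTask`` and its connection names
    (one input ``cat``, one output ``outCat``) are hypothetical::

        task = MyTask()
        dataId = {"instrument": "notACam", "visit": 42}
        quantum = makeQuantum(
            task,
            butler,
            dataId,
            ioDataIds={"cat": dataId, "outCat": dataId},
        )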

95 """ 

    # This is a type ignore because `connections` is a dynamic class, but
    # it is guaranteed to have this attribute.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    try:
        _checkDimensionsMatch(butler.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    inputs = defaultdict(list)
    outputs = defaultdict(list)
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                inputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    for name in connections.outputs:
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                outputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    quantum = Quantum(
        taskClass=type(task),
        dataId=DataCoordinate.standardize(dataId, universe=butler.dimensions),
        inputs=inputs,
        outputs=outputs,
    )
    return quantum


def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Set[str] | Set[Dimension],
    actual: Set[str] | Set[Dimension],
) -> None:
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    if _simplify(universe, expected) != _simplify(universe, actual):
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")


def _simplify(universe: DimensionUniverse, dimensions: Set[str] | Set[Dimension]) -> set[str]:
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
    """
    simplified: set[str] = set()
    for dimension in dimensions:
        # "skypix" is not a real Dimension; handle it first.
        if dimension == "skypix":
            simplified.add(dimension)  # type: ignore
        else:
            # Need a Dimension object to test whether it is spatial.
            fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
            if isinstance(fullDimension, SkyPixDimension):
                simplified.add("skypix")
            else:
                simplified.add(fullDimension.name)
    return simplified


def _checkDataIdMultiplicity(name: str, dataIds: DataId | Sequence[DataId], multiple: bool) -> None:
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
    """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping.
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds: DataId | Sequence[DataId]) -> Sequence[DataId]:
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
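
    Examples
    --------
    Plain dicts stand in for real data IDs here; a `dict` is a
    `~collections.abc.Mapping` but not a `~collections.abc.Sequence`,
    so it counts as a single ID:

    >>> _normalizeDataIds({"visit": 1})
    [{'visit': 1}]
    >>> _normalizeDataIds([{"visit": 1}, {"visit": 2}])
    [{'visit': 1}, {'visit': 2}]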

232 """ 

    if isinstance(dataIds, collections.abc.Sequence):
        return dataIds
    else:
        return [dataIds]


def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
    """
    universe = butler.dimensions
    # DatasetRef only tests whether a required dimension is missing, not
    # whether there are extras.
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    try:
        datasetType = butler.get_dataset_type(connection.name)
    except KeyError:
        raise ValueError(f"Invalid dataset type {connection.name}.") from None
    if not butler.run:
        raise ValueError("Can not create a resolved DatasetRef since the butler has no default run defined.")
    try:
        registry_ref = butler.find_dataset(datasetType, dataId, collections=[butler.run])
        if registry_ref:
            ref = registry_ref
        else:
            ref = DatasetRef(datasetType=datasetType, dataId=dataId, run=butler.run)
            butler.registry._importDatasets([ref])
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e


def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> unittest.mock.Mock | None:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
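
    Examples
    --------
    A sketch continuing the `makeQuantum` example; ``task``, ``butler``,
    and the data IDs are the same hypothetical objects as there::

        quantum = makeQuantum(task, butler, dataId, ioDataIds)
        run = runTestQuantum(task, butler, quantum)
        # With mockRun=True, the mock records the arguments that
        # runQuantum passed to run.
        run.assert_called_once()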

310 """ 

    butlerQc = QuantumContext(butler, quantum)
    # This is a type ignore because `connections` is a dynamic class, but
    # it is guaranteed to have this attribute.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if mockRun:
        with (
            unittest.mock.patch.object(task, "run") as mock,
            unittest.mock.patch("lsst.pipe.base.QuantumContext.put"),
        ):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock
    else:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None


def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # name
    try:
        attrValue = obj.__getattribute__(attrName)
    except AttributeError:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}") from None
    # multiple
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # Evaluate lazily so that StorageClassFactory is only used
        # when necessary.
        if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
            StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
        ):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # No test for storageClass, as it is unclear how much persistence
    # depends on duck-typing.


def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to the task's own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
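
    Examples
    --------
    A minimal sketch; ``task`` and the arguments to ``run`` are
    hypothetical::

        result = task.run(cat=inputCatalog)
        assertValidOutput(task, result)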

384 """ 

    # This is a type ignore because `connections` is a dynamic class, but
    # it is guaranteed to have this attribute.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for name in connections.outputs:
        connection = connections.__getattribute__(name)
        _assertAttributeMatchesConnection(result, name, connection)


def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
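
    Examples
    --------
    A minimal sketch; ``MyTask`` and ``config`` are hypothetical::

        task = MyTask(config=config)
        assertValidInitOutput(task)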

407 """ 

    # This is a type ignore because `connections` is a dynamic class, but
    # it is guaranteed to have this attribute.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for name in connections.initOutputs:
        connection = connections.__getattribute__(name)
        _assertAttributeMatchesConnection(task, name, connection)


def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> dict[str, Any]:
    """Return the ``initInputs`` object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
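
    Examples
    --------
    A minimal sketch; ``MyTask`` is hypothetical, and ``butler`` must
    already contain the task's init-input datasets::

        config = MyTask.ConfigClass()
        task = MyTask(config=config, initInputs=getInitInputs(butler, config))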

434 """ 

    connections = config.connections.ConnectionsClass(config=config)
    initInputs = {}
    for name in connections.initInputs:
        attribute = getattr(connections, name)
        # Get the full dataset type to check for consistency problems.
        dsType = DatasetType(attribute.name, butler.dimensions.extract(set()), attribute.storageClass)
        # All initInputs have empty data IDs.
        initInputs[name] = butler.get(dsType)

    return initInputs


def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
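
    Examples
    --------
    Typically called from a unit test, passing the connections class
    itself (``MyTaskConnections`` is hypothetical)::

        def test_connections(self):
            lintConnections(MyTaskConnections)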

475 """ 

    # Since all comparisons are inside the class, don't bother
    # normalizing skypix.
    quantumDimensions = connections.dimensions

    errors = ""
    # connectionTypes.DimensionedConnection is an implementation detail;
    # don't use it.
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDimensions = set(connection.dimensions)
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            errors += (
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            errors += (
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDimensions} for each {quantumDimensions}.\n"
            )
    if errors:
        raise AssertionError(errors)