Coverage for python/lsst/pipe/base/testUtils.py: 13%

131 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-30 10:01 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Public testing API of this module, listed alphabetically.
__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]

38 

39 

40import collections.abc 

41import itertools 

42import unittest.mock 

43from collections import defaultdict 

44from collections.abc import Mapping, Sequence, Set 

45from typing import TYPE_CHECKING, Any 

46 

47from lsst.daf.butler import ( 

48 Butler, 

49 DataCoordinate, 

50 DataId, 

51 DatasetRef, 

52 DatasetType, 

53 Dimension, 

54 DimensionUniverse, 

55 Quantum, 

56 SkyPixDimension, 

57 StorageClassFactory, 

58) 

59from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection 

60 

61from ._quantumContext import QuantumContext 

62 

63if TYPE_CHECKING: 

64 from .config import PipelineTaskConfig 

65 from .connections import PipelineTaskConnections 

66 from .pipelineTask import PipelineTask 

67 from .struct import Struct 

68 

69 

def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, DataId | Sequence[DataId]],
) -> Quantum:
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataIds``.

    Raises
    ------
    ValueError
        Raised if the quantum dimensions or any connection's data ID(s)
        do not match the task's connections class.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    dataId = DataCoordinate.standardize(dataId, universe=butler.dimensions)
    try:
        _checkDimensionsMatch(butler.dimensions, connections.dimensions, dataId.dimensions.required)
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    def _collectRefs(connectionNames):
        # Build {DatasetType: [DatasetRef, ...]} for the named connections;
        # shared by the input and output halves, which previously duplicated
        # this loop verbatim.
        refs = defaultdict(list)
        for name in connectionNames:
            try:
                connection = getattr(connections, name)
                _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
                for eachId in _normalizeDataIds(ioDataIds[name]):
                    ref = _refFromConnection(butler, connection, eachId)
                    refs[ref.datasetType].append(ref)
            except (ValueError, KeyError) as e:
                raise ValueError(f"Error in connection {name}.") from e
        return refs

    inputs = _collectRefs(itertools.chain(connections.inputs, connections.prerequisiteInputs))
    outputs = _collectRefs(connections.outputs)
    # ``dataId`` was already standardized above; no need to re-standardize.
    return Quantum(
        taskClass=type(task),
        dataId=dataId,
        inputs=inputs,
        outputs=outputs,
    )

135 

136 

def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Set[str] | Set[Dimension],
    actual: Set[str] | Set[Dimension],
) -> None:
    """Check that two sets of dimensions agree once normalized.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    # Normalize both sides to plain string form before comparing.
    normalizedExpected = _simplify(universe, expected)
    normalizedActual = _simplify(universe, actual)
    if normalizedExpected != normalizedActual:
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")

160 

161 

162def _simplify(universe: DimensionUniverse, dimensions: Set[str] | Set[Dimension]) -> set[str]: 

163 """Reduce a set of dimensions to a string-only form. 

164 

165 Parameters 

166 ---------- 

167 universe : `lsst.daf.butler.DimensionUniverse` 

168 The set of all known dimensions. 

169 dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`] 

170 A set of dimensions to simplify. 

171 

172 Returns 

173 ------- 

174 dimensions : `Set` [`str`] 

175 A copy of ``dimensions`` reduced to string form, with all spatial 

176 dimensions simplified to ``skypix``. 

177 """ 

178 simplified: set[str] = set() 

179 for dimension in dimensions: 

180 # skypix not a real Dimension, handle it first 

181 if dimension == "skypix": 

182 simplified.add(dimension) # type: ignore 

183 else: 

184 # Need a Dimension to test spatialness 

185 fullDimension = universe[dimension] if isinstance(dimension, str) else dimension 

186 if isinstance(fullDimension, SkyPixDimension): 

187 simplified.add("skypix") 

188 else: 

189 simplified.add(fullDimension.name) 

190 return simplified 

191 

192 

193def _checkDataIdMultiplicity(name: str, dataIds: DataId | Sequence[DataId], multiple: bool) -> None: 

194 """Test whether data IDs are scalars for scalar connections and sequences 

195 for multiple connections. 

196 

197 Parameters 

198 ---------- 

199 name : `str` 

200 The name of the connection being tested. 

201 dataIds : any data ID type or `~collections.abc.Sequence` [data ID] 

202 The data ID(s) provided for the connection. 

203 multiple : `bool` 

204 The ``multiple`` field of the connection. 

205 

206 Raises 

207 ------ 

208 ValueError 

209 Raised if ``dataIds`` and ``multiple`` do not match. 

210 """ 

211 if multiple: 

212 if not isinstance(dataIds, collections.abc.Sequence): 

213 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.") 

214 else: 

215 # DataCoordinate is a Mapping 

216 if not isinstance(dataIds, collections.abc.Mapping): 

217 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.") 

218 

219 

220def _normalizeDataIds(dataIds: DataId | Sequence[DataId]) -> Sequence[DataId]: 

221 """Represent both single and multiple data IDs as a list. 

222 

223 Parameters 

224 ---------- 

225 dataIds : any data ID type or `~collections.abc.Sequence` thereof 

226 The data ID(s) provided for a particular input or output connection. 

227 

228 Returns 

229 ------- 

230 normalizedIds : `~collections.abc.Sequence` [data ID] 

231 A sequence equal to ``dataIds`` if it was already a sequence, or 

232 ``[dataIds]`` if it was a single ID. 

233 """ 

234 if isinstance(dataIds, collections.abc.Sequence): 

235 return dataIds 

236 else: 

237 return [dataIds] 

238 

239 

def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dataset type is not registered, the butler has no
        default run, or the dataset type and data ID are incompatible.
    """
    universe = butler.dimensions
    # DatasetRef only tests if required dimension is missing, but not extras,
    # so check the dimensions explicitly here.
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.dimensions.required)

    # Guard the lookup itself. Previously the first (unguarded) lookup raised
    # KeyError before a later try/except could translate it to ValueError,
    # and the guarded second lookup of the same name was dead code.
    try:
        datasetType = butler.get_dataset_type(connection.name)
    except KeyError:
        raise ValueError(f"Invalid dataset type {connection.name}.") from None

    if not butler.run:
        raise ValueError("Can not create a resolved DatasetRef since the butler has no default run defined.")
    try:
        # Reuse an existing registry entry when one exists for this data ID.
        registry_ref = butler.find_dataset(datasetType, dataId, collections=[butler.run])
        if registry_ref:
            ref = registry_ref
        else:
            ref = DatasetRef(datasetType=datasetType, dataId=dataId, run=butler.run)
            # NOTE(review): uses a private registry API to register the ref;
            # confirm no public equivalent is available before changing.
            butler.registry._importDatasets([ref])
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId} not compatible.") from e

286 

287 

def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> unittest.mock.Mock | None:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    butlerQc = QuantumContext(butler, quantum)
    # ``connections`` is a dynamically-generated class, so the attribute is
    # invisible to type checkers even though it is guaranteed to exist.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if not mockRun:
        # Run for real; any outputs are actually written.
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None
    # Patch both ``run`` and dataset writing so nothing real happens.
    with (
        unittest.mock.patch.object(task, "run") as mock,
        unittest.mock.patch("lsst.pipe.base.QuantumContext.put"),
    ):
        task.runQuantum(butlerQc, inputRefs, outputRefs)
    return mock

327 

328 

329def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None: 

330 """Test that an attribute on an object matches the specification given in 

331 a connection. 

332 

333 Parameters 

334 ---------- 

335 obj 

336 An object expected to contain the attribute ``attrName``. 

337 attrName : `str` 

338 The name of the attribute to be tested. 

339 connection : `lsst.pipe.base.connectionTypes.BaseConnection` 

340 The connection, usually some type of output, specifying ``attrName``. 

341 

342 Raises 

343 ------ 

344 AssertionError: 

345 Raised if ``obj.attrName`` does not match what's expected 

346 from ``connection``. 

347 """ 

348 # name 

349 try: 

350 attrValue = obj.__getattribute__(attrName) 

351 except AttributeError: 

352 raise AssertionError(f"No such attribute on {obj!r}: {attrName}") from None 

353 # multiple 

354 if connection.multiple: 

355 if not isinstance(attrValue, collections.abc.Sequence): 

356 raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.") 

357 else: 

358 # use lazy evaluation to not use StorageClassFactory unless 

359 # necessary 

360 if isinstance(attrValue, collections.abc.Sequence) and not issubclass( 

361 StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence 

362 ): 

363 raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.") 

364 # no test for storageClass, as I'm not sure how much persistence 

365 # depends on duck-typing 

366 

367 

def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Check that the output of a call to ``run`` conforms to the task's
    own connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError:
        Raised if ``result`` does not match what's expected from ``task's``
        connections.
    """
    # ``connections`` is a dynamically-generated class, invisible to type
    # checkers, but the attribute is guaranteed to exist.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    # Every declared output must appear on the result with the right shape.
    for name in connections.outputs:
        _assertAttributeMatchesConnection(result, name, connections.__getattribute__(name))

393 

394 

def assertValidInitOutput(task: PipelineTask) -> None:
    """Check that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError:
        Raised if ``task`` does not have the state expected from ``task's``
        connections.
    """
    # ``connections`` is a dynamically-generated class, invisible to type
    # checkers, but the attribute is guaranteed to exist.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    # Each declared init-output must exist as an attribute on the task itself.
    for name in connections.initOutputs:
        _assertAttributeMatchesConnection(task, name, connections.__getattribute__(name))

416 

417 

def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    loaded: dict[str, Any] = {}
    for name in connections.initInputs:
        connection = getattr(connections, name)
        # Build the full dataset type so consistency problems surface here
        # rather than deep inside ``butler.get``.
        datasetType = DatasetType(connection.name, butler.dimensions.empty, connection.storageClass)
        # All initInputs have empty data IDs.
        loaded[name] = butler.get(datasetType)

    return loaded

446 

447 

def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # All comparisons stay within this one class, so there is no need to
    # normalize skypix dimensions first.
    taskDimensions = connections.dimensions

    problems: list[str] = []
    # connectionTypes.DimensionedConnection is an implementation detail;
    # rely only on the attributes shared by all dimensioned connections.
    everyName = itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs)
    for name in everyName:
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        dims = set(connection.dimensions)
        # Extra dimensions beyond the quantum's mean several datasets can
        # match per quantum; the connection must then declare multiple=True.
        if checkMissingMultiple and not connection.multiple and dims > taskDimensions:
            problems.append(
                f"Connection {name} may be called with multiple values of "
                f"{dims - taskDimensions} but has multiple=False.\n"
            )
        # A connection whose dimensions are a subset of the quantum's can
        # only ever match one dataset; multiple=True is then unnecessary.
        if checkUnnecessaryMultiple and connection.multiple and dims <= taskDimensions:
            problems.append(
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {dims} for each {taskDimensions}.\n"
            )
    if problems:
        raise AssertionError("".join(problems))