Coverage for python/lsst/pipe/base/testUtils.py: 13%
131 statements
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]


import collections.abc
import itertools
import unittest.mock
from collections import defaultdict
from collections.abc import Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import (
    Butler,
    DataCoordinate,
    DataId,
    DatasetRef,
    DatasetType,
    Dimension,
    DimensionUniverse,
    Quantum,
    SkyPixDimension,
    StorageClassFactory,
)
from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection

from ._quantumContext import QuantumContext

if TYPE_CHECKING:
    from .config import PipelineTaskConfig
    from .connections import PipelineTaskConnections
    from .pipelineTask import PipelineTask
    from .struct import Struct


def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, DataId | Sequence[DataId]],
) -> Quantum:
    """Create a Quantum for a particular set of data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``ioDataIds``.
    """
    # This is a type ignore because `connections` is a dynamic class, but
    # it is guaranteed to have this attribute.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    dataId = DataCoordinate.standardize(dataId, universe=butler.dimensions)
    try:
        _checkDimensionsMatch(butler.dimensions, connections.dimensions, dataId.dimensions.required)
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    inputs = defaultdict(list)
    outputs = defaultdict(list)
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
        try:
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                inputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    for name in connections.outputs:
        try:
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                outputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    quantum = Quantum(
        taskClass=type(task),
        dataId=dataId,  # already standardized above
        inputs=inputs,
        outputs=outputs,
    )
    return quantum
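
# Example usage, as an illustrative sketch only: ``MyTask``, ``butler``, and
# the connection names below are hypothetical stand-ins, not part of this
# module.
#
#     task = MyTask()
#     dataId = {"instrument": "notACam", "visit": 42, "detector": 0}
#     quantum = makeQuantum(
#         task,
#         butler,
#         dataId,
#         ioDataIds={"inputCatalog": dataId, "outputCatalog": dataId},
#     )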


def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Set[str] | Set[Dimension],
    actual: Set[str] | Set[Dimension],
) -> None:
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    if _simplify(universe, expected) != _simplify(universe, actual):
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")


def _simplify(universe: DimensionUniverse, dimensions: Set[str] | Set[Dimension]) -> set[str]:
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
    """
    simplified: set[str] = set()
    for dimension in dimensions:
        # skypix is not a real Dimension; handle it first.
        if dimension == "skypix":
            simplified.add(dimension)  # type: ignore
        else:
            # Need a Dimension object to test whether it is spatial.
            fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
            if isinstance(fullDimension, SkyPixDimension):
                simplified.add("skypix")
            else:
                simplified.add(fullDimension.name)
    return simplified
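
# For example (an illustrative sketch, assuming ``universe`` is a standard
# DimensionUniverse in which ``htm7`` is a SkyPixDimension):
#
#     _simplify(universe, {"visit", "htm7"})  # -> {"visit", "skypix"}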


def _checkDataIdMultiplicity(name: str, dataIds: DataId | Sequence[DataId], multiple: bool) -> None:
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
    """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds: DataId | Sequence[DataId]) -> Sequence[DataId]:
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
    """
    if isinstance(dataIds, collections.abc.Sequence):
        return dataIds
    else:
        return [dataIds]


def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
    """
    universe = butler.dimensions
    # DatasetRef only checks whether required dimensions are missing,
    # not whether extras are present.
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.dimensions.required)

    try:
        datasetType = butler.get_dataset_type(connection.name)
    except KeyError:
        raise ValueError(f"Invalid dataset type {connection.name}.") from None
    if not butler.run:
        raise ValueError("Cannot create a resolved DatasetRef since the butler has no default run defined.")
    try:
        registry_ref = butler.find_dataset(datasetType, dataId, collections=[butler.run])
        if registry_ref:
            ref = registry_ref
        else:
            ref = DatasetRef(datasetType=datasetType, dataId=dataId, run=butler.run)
            butler.registry._importDatasets([ref])
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId} not compatible.") from e


def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> unittest.mock.Mock | None:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    butlerQc = QuantumContext(butler, quantum)
    # This is a type ignore because `connections` is a dynamic class, but
    # it is guaranteed to have this attribute.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if mockRun:
        with (
            unittest.mock.patch.object(task, "run") as mock,
            unittest.mock.patch("lsst.pipe.base.QuantumContext.put"),
        ):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock
    else:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None
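
# Example usage, as an illustrative sketch only (``task``, ``butler``, and
# ``quantum`` are assumed to come from a test fixture, e.g. via makeQuantum):
#
#     run = runTestQuantum(task, butler, quantum, mockRun=True)
#     # The returned mock records the arguments runQuantum passed to run.
#     run.assert_called_once()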


def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # name: the attribute must exist.
    try:
        attrValue = getattr(obj, attrName)
    except AttributeError:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}") from None
    # multiple: sequence vs. scalar.
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # Evaluate lazily so that StorageClassFactory is only consulted
        # when necessary.
        if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
            StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
        ):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # No test for storageClass, as it is unclear how much persistence
    # depends on duck-typing.


def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
    """
    # This is a type ignore because `connections` is a dynamic class, but
    # it is guaranteed to have this attribute.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for name in connections.outputs:
        connection = getattr(connections, name)
        _assertAttributeMatchesConnection(result, name, connection)
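
# Example usage, as an illustrative sketch only (``task`` and the
# ``inputCatalog`` argument are hypothetical stand-ins):
#
#     result = task.run(inputCatalog)
#     assertValidOutput(task, result)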


def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
    """
    # This is a type ignore because `connections` is a dynamic class, but
    # it is guaranteed to have this attribute.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for name in connections.initOutputs:
        connection = getattr(connections, name)
        _assertAttributeMatchesConnection(task, name, connection)
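
# Example usage (illustrative; ``MyTask`` and ``config`` are hypothetical
# stand-ins):
#
#     assertValidInitOutput(MyTask(config=config))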


def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    initInputs = {}
    for name in connections.initInputs:
        attribute = getattr(connections, name)
        # Get full dataset type to check for consistency problems
        dsType = DatasetType(attribute.name, butler.dimensions.empty, attribute.storageClass)
        # All initInputs have empty data IDs
        initInputs[name] = butler.get(dsType)

    return initInputs
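
# Example usage, as an illustrative sketch only (``butler`` and ``config``
# are assumed to come from a test fixture; ``MyTask`` is hypothetical):
#
#     initInputs = getInitInputs(butler, config)
#     task = MyTask(config=config, initInputs=initInputs)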


def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # Since all comparisons are inside the class, don't bother
    # normalizing skypix.
    quantumDimensions = connections.dimensions

    errors = ""
    # connectionTypes.DimensionedConnection is an implementation detail;
    # don't use it.
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDimensions = set(connection.dimensions)
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            errors += (
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            errors += (
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDimensions} for each {quantumDimensions}.\n"
            )
    if errors:
        raise AssertionError(errors)
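
# Example usage (illustrative; ``MyTaskConnections`` is a hypothetical
# connections class):
#
#     lintConnections(MyTaskConnections)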