Coverage for python/lsst/pipe/base/testUtils.py: 12%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["makeQuantum", "runTestQuantum", "assertValidOutput"]

import collections.abc
import itertools
import unittest.mock

from lsst.daf.butler import DataCoordinate, DatasetRef, Quantum, StorageClassFactory
from lsst.pipe.base import ButlerQuantumContext


def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for a particular set of data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple
        connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``ioDataIds``.
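
    Examples
    --------
    A minimal sketch, not taken from this module's tests; ``MyTask``, the
    repository path, and the connection names ``"input"`` and ``"output"``
    are all hypothetical, and the data IDs must match the task's
    connections::

        import lsst.daf.butler

        butler = lsst.daf.butler.Butler("repo", run="test")
        task = MyTask()
        dataId = {"instrument": "notACam", "visit": 42, "detector": 0}
        quantum = makeQuantum(
            task, butler, dataId,
            ioDataIds={"input": dataId, "output": dataId},
        )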
54 """
    quantum = Quantum(taskClass=type(task), dataId=dataId)
    connections = task.config.ConnectionsClass(config=task.config)

    try:
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                quantum.addPredictedInput(_refFromConnection(butler, connection, id))
        for name in connections.outputs:
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                quantum.addOutput(_refFromConnection(butler, connection, id))
        return quantum
    except KeyError as e:
        raise ValueError("Mismatch in input data.") from e


def _checkDataIdMultiplicity(name, dataIds, multiple):
    """Test whether data IDs are scalars for scalar connections and
    sequences for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
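
    Examples
    --------
    Plain `dict` and `list` objects stand in for data IDs here; real
    callers would pass `~lsst.daf.butler.DataCoordinate` (a `Mapping`)
    or sequences thereof::

        >>> _checkDataIdMultiplicity("calexp", {"visit": 42}, multiple=False)
        >>> _checkDataIdMultiplicity("calexps", [{"visit": 42}], multiple=True)
        >>> _checkDataIdMultiplicity("calexp", [{"visit": 42}], multiple=False)
        Traceback (most recent call last):
            ...
        ValueError: Expected single data ID for calexp, got [{'visit': 42}].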
93 """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds):
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
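
    Examples
    --------
    Plain `dict` objects stand in for data IDs in this sketch::

        >>> _normalizeDataIds([{"visit": 42}, {"visit": 43}])
        [{'visit': 42}, {'visit': 43}]
        >>> _normalizeDataIds({"visit": 42})
        [{'visit': 42}]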
116 """
    if isinstance(dataIds, collections.abc.Sequence):
        return dataIds
    else:
        return [dataIds]


def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
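
    Examples
    --------
    A minimal sketch; ``task``, ``butler``, and the connection name
    ``calexp`` are hypothetical, and ``butler`` must already know the
    corresponding dataset type::

        connections = task.config.ConnectionsClass(config=task.config)
        ref = _refFromConnection(
            butler, connections.calexp, {"visit": 42, "detector": 0},
            instrument="notACam",
        )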
143 """
    universe = butler.registry.dimensions
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # skypix is a PipelineTask alias for "some spatial index"; Butler does
    # not understand it. Code copied from TaskDatasetTypes.fromTaskDef.
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError:
        raise ValueError(f"Invalid dataset type {connection.name}.")
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(
            f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible."
        ) from e


def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object
        can be queried for the arguments ``runQuantum`` passed to ``run``.
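
    Examples
    --------
    A hedged sketch building on `makeQuantum`; ``task``, ``butler``, and
    the data IDs are hypothetical::

        quantum = makeQuantum(task, butler, dataId, ioDataIds)
        run = runTestQuantum(task, butler, quantum)
        # The mock records the arguments that runQuantum passed to run.
        run.assert_called_once()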
187 """
    butlerQc = ButlerQuantumContext(butler, quantum)
    connections = task.config.ConnectionsClass(config=task.config)
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if mockRun:
        with unittest.mock.patch.object(task, "run") as mock, \
                unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock
    else:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None


def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a
        fully-configured task object to support features such as optional
        outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what is expected from
        ``task``'s connections.
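
    Examples
    --------
    A sketch assuming a hypothetical ``task`` whose ``run`` method takes
    an exposure and returns a `~lsst.pipe.base.Struct`::

        result = task.run(exposure)
        assertValidOutput(task, result)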
217 """
    connections = task.config.ConnectionsClass(config=task.config)
    recoveredOutputs = result.getDict()

    for name in connections.outputs:
        connection = getattr(connections, name)
        # Check that the expected output field is present.
        try:
            output = recoveredOutputs[name]
        except KeyError:
            raise AssertionError(f"No such output: {name}")
        # Check that the output's multiplicity matches the connection's.
        if connection.multiple:
            if not isinstance(output, collections.abc.Sequence):
                raise AssertionError(f"Expected {name} to be a sequence, got {output} instead.")
        else:
            # Evaluate lazily, to avoid touching StorageClassFactory unless
            # necessary.
            if isinstance(output, collections.abc.Sequence) and not issubclass(
                StorageClassFactory().getStorageClass(connection.storageClass).pytype,
                collections.abc.Sequence,
            ):
                raise AssertionError(f"Expected {name} to be a single value, got {output} instead.")
    # No test for storageClass, as I'm not sure how much persistence
    # depends on duck-typing.