# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["assertValidInitOutput",
           "assertValidOutput",
           "getInitInputs",
           "makeQuantum",
           "runTestQuantum",
           ]

from collections import defaultdict
import collections.abc
import itertools
import unittest.mock

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, Quantum, StorageClassFactory
from lsst.pipe.base import ButlerQuantumContext


def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``ioDataIds``.
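
    Examples
    --------
    A minimal sketch, assuming a test repository ``butler`` and a configured
    task ``task`` whose connections define one ``exposure`` input and one
    ``catalog`` output (the connection names and data ID values here are
    illustrative, not part of this API)::

        dataId = {"instrument": "notACam", "visit": 42, "detector": 0}
        quantum = makeQuantum(
            task, butler, dataId,
            ioDataIds={"exposure": dataId, "catalog": dataId})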
60 """
61 connections = task.config.ConnectionsClass(config=task.config)
63 try:
64 inputs = defaultdict(list)
65 outputs = defaultdict(list)
66 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
67 connection = connections.__getattribute__(name)
68 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
69 ids = _normalizeDataIds(ioDataIds[name])
70 for id in ids:
71 ref = _refFromConnection(butler, connection, id)
72 inputs[ref.datasetType].append(ref)
73 for name in connections.outputs:
74 connection = connections.__getattribute__(name)
75 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
76 ids = _normalizeDataIds(ioDataIds[name])
77 for id in ids:
78 ref = _refFromConnection(butler, connection, id)
79 outputs[ref.datasetType].append(ref)
80 quantum = Quantum(taskClass=type(task),
81 dataId=dataId,
82 inputs=inputs,
83 outputs=outputs)
84 return quantum
85 except KeyError as e:
86 raise ValueError("Mismatch in input data.") from e


def _checkDataIdMultiplicity(name, dataIds, multiple):
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
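
    Examples
    --------
    Hypothetical connection names, with plain `dict` data IDs standing in
    for real ones:

    >>> _checkDataIdMultiplicity("exposure", {"visit": 42}, multiple=False)
    >>> _checkDataIdMultiplicity("exposures", [{"visit": 42}], multiple=True)
    >>> _checkDataIdMultiplicity("exposure", [{"visit": 42}], multiple=False)
    Traceback (most recent call last):
        ...
    ValueError: Expected single data ID for exposure, got [{'visit': 42}].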
106 """
107 if multiple:
108 if not isinstance(dataIds, collections.abc.Sequence):
109 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
110 else:
111 # DataCoordinate is a Mapping
112 if not isinstance(dataIds, collections.abc.Mapping):
113 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds):
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
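
    Examples
    --------
    Plain `dict` data IDs standing in for real ones:

    >>> _normalizeDataIds({"visit": 42})
    [{'visit': 42}]
    >>> _normalizeDataIds([{"visit": 42}, {"visit": 43}])
    [{'visit': 42}, {'visit': 43}]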
129 """
130 if isinstance(dataIds, collections.abc.Sequence):
131 return dataIds
132 else:
133 return [dataIds]


def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
    """
    universe = butler.registry.dimensions
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # "skypix" is a PipelineTask alias for "some spatial index"; Butler does
    # not understand it. Code copied from TaskDatasetTypes.fromTaskDef.
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") \
            from e


def _resolveTestQuantumInputs(butler, quantum):
    """Look up all input datasets of a test quantum in the `Registry` to
    resolve all `DatasetRef` objects (i.e. ensure they have not-`None` ``id``
    and ``run`` attributes).

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    """
    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refsForDatasetType in quantum.inputs.values():
        newRefsForDatasetType = []
        for ref in refsForDatasetType:
            if ref.id is None:
                resolvedRef = butler.registry.findDataset(ref.datasetType, ref.dataId,
                                                          collections=butler.collections)
                if resolvedRef is None:
                    raise ValueError(
                        f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                        f"in collections {butler.collections}."
                    )
                newRefsForDatasetType.append(resolvedRef)
            else:
                newRefsForDatasetType.append(ref)
        refsForDatasetType[:] = newRefsForDatasetType


def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
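
    Examples
    --------
    A minimal sketch, continuing the illustrative setup from `makeQuantum`
    (the connection names are assumptions, not part of this API)::

        quantum = makeQuantum(
            task, butler, dataId,
            ioDataIds={"exposure": dataId, "catalog": dataId})
        run = runTestQuantum(task, butler, quantum)
        # With mockRun=True, the returned mock records the arguments
        # that runQuantum passed to run.
        run.assert_called_once()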
235 """
236 _resolveTestQuantumInputs(butler, quantum)
237 butlerQc = ButlerQuantumContext(butler, quantum)
238 connections = task.config.ConnectionsClass(config=task.config)
239 inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
240 if mockRun:
241 with unittest.mock.patch.object(task, "run") as mock, \
242 unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
243 task.runQuantum(butlerQc, inputRefs, outputRefs)
244 return mock
245 else:
246 task.runQuantum(butlerQc, inputRefs, outputRefs)
247 return None


def _assertAttributeMatchesConnection(obj, attrName, connection):
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # name: the attribute must exist on obj
    try:
        attrValue = getattr(obj, attrName)
    except AttributeError:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}")
    # multiple: sequence vs. scalar must match the connection
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # use lazy evaluation to not use StorageClassFactory unless
        # necessary
        if isinstance(attrValue, collections.abc.Sequence) \
                and not issubclass(
                    StorageClassFactory().getStorageClass(connection.storageClass).pytype,
                    collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # no test for storageClass, as I'm not sure how much persistence
    # depends on duck-typing


def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
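
    Examples
    --------
    A minimal sketch, where ``MyTask`` and the arguments to its ``run``
    method are illustrative::

        task = MyTask(config=config)
        result = task.run(exposure)
        assertValidOutput(task, result)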
307 """
308 connections = task.config.ConnectionsClass(config=task.config)
310 for name in connections.outputs:
311 connection = connections.__getattribute__(name)
312 _assertAttributeMatchesConnection(result, name, connection)


def assertValidInitOutput(task):
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
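
    Examples
    --------
    A minimal sketch, where ``MyTask`` is illustrative::

        task = MyTask(config=config)
        assertValidInitOutput(task)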
328 """
329 connections = task.config.ConnectionsClass(config=task.config)
331 for name in connections.initOutputs:
332 connection = connections.__getattribute__(name)
333 _assertAttributeMatchesConnection(task, name, connection)


def getInitInputs(butler, config):
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
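
    Examples
    --------
    A minimal sketch, assuming ``butler`` already points to the collections
    holding the init-input datasets (``MyTask`` is illustrative)::

        initInputs = getInitInputs(butler, config)
        task = MyTask(config=config, initInputs=initInputs)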
353 """
354 connections = config.connections.ConnectionsClass(config=config)
355 initInputs = {}
356 for name in connections.initInputs:
357 attribute = getattr(connections, name)
358 # Get full dataset type to check for consistency problems
359 dsType = DatasetType(attribute.name, butler.registry.dimensions.extract(set()),
360 attribute.storageClass)
361 # All initInputs have empty data IDs
362 initInputs[name] = butler.get(dsType)
364 return initInputs