lsst.pipe.base  20.0.0-14-g1ce627f+450400e286
testUtils.py
Go to the documentation of this file.
1 # This file is part of pipe_base.
2 #
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (https://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 
22 
23 __all__ = ["makeQuantum", "runTestQuantum", "assertValidOutput"]
24 
25 
26 import collections.abc
27 import itertools
28 import unittest.mock
29 
30 from lsst.daf.butler import DataCoordinate, DatasetRef, Quantum, StorageClassFactory
31 from lsst.pipe.base import ButlerQuantumContext
32 
33 
def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple
        connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataIds``.

    Raises
    ------
    ValueError
        Raised if ``ioDataIds`` is missing an entry for one of ``task``'s
        connections, or if a data ID's multiplicity does not match its
        connection's ``multiple`` field.
    """
    quantum = Quantum(taskClass=type(task), dataId=dataId)
    connections = task.config.ConnectionsClass(config=task.config)

    try:
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            # getattr is the idiomatic spelling; calling __getattribute__
            # directly bypasses nothing and only obscures intent.
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            # Avoid shadowing the builtin ``id``.
            for ioId in _normalizeDataIds(ioDataIds[name]):
                quantum.addPredictedInput(_refFromConnection(butler, connection, ioId))
        for name in connections.outputs:
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            for ioId in _normalizeDataIds(ioDataIds[name]):
                quantum.addOutput(_refFromConnection(butler, connection, ioId))
    except KeyError as e:
        # A missing ioDataIds entry means the caller's inputs don't match
        # the task's connections.
        raise ValueError("Mismatch in input data.") from e
    # Keep the return outside the try so only the lookup loops are guarded.
    return quantum
74 
75 
76 def _checkDataIdMultiplicity(name, dataIds, multiple):
77  """Test whether data IDs are scalars for scalar connections and sequences
78  for multiple connections.
79 
80  Parameters
81  ----------
82  name : `str`
83  The name of the connection being tested.
84  dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
85  The data ID(s) provided for the connection.
86  multiple : `bool`
87  The ``multiple`` field of the connection.
88 
89  Raises
90  ------
91  ValueError
92  Raised if ``dataIds`` and ``multiple`` do not match.
93  """
94  if multiple:
95  if not isinstance(dataIds, collections.abc.Sequence):
96  raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
97  else:
98  # DataCoordinate is a Mapping
99  if not isinstance(dataIds, collections.abc.Mapping):
100  raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")
101 
102 
103 def _normalizeDataIds(dataIds):
104  """Represent both single and multiple data IDs as a list.
105 
106  Parameters
107  ----------
108  dataIds : any data ID type or `~collections.abc.Sequence` thereof
109  The data ID(s) provided for a particular input or output connection.
110 
111  Returns
112  -------
113  normalizedIds : `~collections.abc.Sequence` [data ID]
114  A sequence equal to ``dataIds`` if it was already a sequence, or
115  ``[dataIds]`` if it was a single ID.
116  """
117  if isinstance(dataIds, collections.abc.Sequence):
118  return dataIds
119  else:
120  return [dataIds]
121 
122 
def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dataset type is not registered in ``butler``, or if
        ``dataId`` is not compatible with the dataset type's dimensions.
    """
    universe = butler.registry.dimensions
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # skypix is a PipelineTask alias for "some spatial index", Butler doesn't
    # understand it. Code copied from TaskDatasetTypes.fromTaskDef
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError as e:
        # Chain the registry error so the original failure is not lost.
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(
            f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible."
        ) from e
164 
165 
def _resolveTestQuantumInputs(butler, quantum):
    """Look up all input datasets a test quantum in the `Registry` to resolve
    all `DatasetRef` objects (i.e. ensure they have not-`None` ``id`` and
    ``run`` attributes).

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    """
    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refList in quantum.predictedInputs.values():
        resolved = []
        for ref in refList:
            if ref.id is not None:
                # Already resolved; keep as-is.
                resolved.append(ref)
                continue
            found = butler.registry.findDataset(ref.datasetType, ref.dataId,
                                                collections=butler.collections)
            if found is None:
                raise ValueError(
                    f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                    f"in collections {butler.collections}."
                )
            resolved.append(found)
        # Replace the contents in place so the quantum keeps the same lists.
        refList[:] = resolved
199 
200 
def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object
        can be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    _resolveTestQuantumInputs(butler, quantum)
    butlerQc = ButlerQuantumContext(butler, quantum)
    connections = task.config.ConnectionsClass(config=task.config)
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if not mockRun:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None
    # Patch run so the test can inspect its arguments, and patch put so no
    # datasets are actually written to the butler.
    with unittest.mock.patch.object(task, "run") as mockedRun:
        with unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mockedRun
235 
236 
def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a
        fully-configured task object to support features such as optional
        outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
    """
    connections = task.config.ConnectionsClass(config=task.config)
    recoveredOutputs = result.getDict()

    for name in connections.outputs:
        # getattr is the idiomatic spelling; calling __getattribute__
        # directly obscures intent.
        connection = getattr(connections, name)
        # Check that the output exists at all.
        try:
            output = recoveredOutputs[name]
        except KeyError as e:
            # Chain the KeyError so the original lookup failure is preserved.
            raise AssertionError(f"No such output: {name}") from e
        # Check multiplicity.
        if connection.multiple:
            if not isinstance(output, collections.abc.Sequence):
                raise AssertionError(f"Expected {name} to be a sequence, got {output} instead.")
        else:
            # Use lazy evaluation to not use StorageClassFactory unless
            # necessary: only consult the storage class when the output
            # looks like a sequence but the connection is scalar.
            if isinstance(output, collections.abc.Sequence) \
                    and not issubclass(
                        StorageClassFactory().getStorageClass(connection.storageClass).pytype,
                        collections.abc.Sequence):
                raise AssertionError(f"Expected {name} to be a single value, got {output} instead.")
    # no test for storageClass, as I'm not sure how much persistence depends on duck-typing
lsst::pipe::base.butlerQuantumContext.ButlerQuantumContext
Definition: butlerQuantumContext.py:35
lsst::pipe::base.testUtils.makeQuantum
def makeQuantum(task, butler, dataId, ioDataIds)
Definition: testUtils.py:34
lsst::pipe::base.testUtils.runTestQuantum
def runTestQuantum(task, butler, quantum, mockRun=True)
Definition: testUtils.py:201
lsst::pipe::base
Definition: main.dox:1
lsst::pipe::base.testUtils.assertValidOutput
def assertValidOutput(task, result)
Definition: testUtils.py:237