Coverage for python/lsst/pipe/base/testUtils.py : 10%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public API of this module; everything else (helpers prefixed with ``_``)
# is an implementation detail.
__all__ = ["assertValidInitOutput",
           "assertValidOutput",
           "getInitInputs",
           "lintConnections",
           "makeQuantum",
           "runTestQuantum",
           ]
32from collections import defaultdict
33import collections.abc
34import itertools
35import unittest.mock
37from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, Quantum, StorageClassFactory, \
38 SkyPixDimension
39from lsst.pipe.base import ButlerQuantumContext
def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataIds``.

    Raises
    ------
    ValueError
        Raised if the data IDs are inconsistent with ``task``'s connections.
    """
    connections = task.config.ConnectionsClass(config=task.config)

    try:
        _checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    def _makeRefs(connectionNames):
        # Build the {datasetType: [refs]} mapping for a group of connections.
        # Factored out because the original duplicated this loop verbatim for
        # inputs and outputs.
        refs = defaultdict(list)
        for name in connectionNames:
            try:
                connection = getattr(connections, name)
                _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
                for refDataId in _normalizeDataIds(ioDataIds[name]):
                    ref = _refFromConnection(butler, connection, refDataId)
                    refs[ref.datasetType].append(ref)
            except (ValueError, KeyError) as e:
                raise ValueError(f"Error in connection {name}.") from e
        return refs

    inputs = _makeRefs(itertools.chain(connections.inputs, connections.prerequisiteInputs))
    outputs = _makeRefs(connections.outputs)
    return Quantum(taskClass=type(task),
                   dataId=dataId,
                   inputs=inputs,
                   outputs=outputs)
def _checkDimensionsMatch(universe, expected, actual):
    """Verify that two dimension sets are equivalent after normalization.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    # Compare in string form so that equivalent spellings (names vs.
    # Dimension objects, concrete skypix systems) are treated as equal.
    expectedNames = _simplify(universe, expected)
    actualNames = _simplify(universe, actual)
    if expectedNames != actualNames:
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")
120def _simplify(universe, dimensions):
121 """Reduce a set of dimensions to a string-only form.
123 Parameters
124 ----------
125 universe : `lsst.daf.butler.DimensionUniverse`
126 The set of all known dimensions.
127 dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
128 A set of dimensions to simplify.
130 Returns
131 -------
132 dimensions : `Set` [`str`]
133 A copy of ``dimensions`` reduced to string form, with all spatial
134 dimensions simplified to ``skypix``.
135 """
136 simplified = set()
137 for dimension in dimensions:
138 # skypix not a real Dimension, handle it first
139 if dimension == "skypix":
140 simplified.add(dimension)
141 else:
142 # Need a Dimension to test spatialness
143 fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
144 if isinstance(fullDimension, SkyPixDimension):
145 simplified.add("skypix")
146 else:
147 simplified.add(fullDimension.name)
148 return simplified
151def _checkDataIdMultiplicity(name, dataIds, multiple):
152 """Test whether data IDs are scalars for scalar connections and sequences
153 for multiple connections.
155 Parameters
156 ----------
157 name : `str`
158 The name of the connection being tested.
159 dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
160 The data ID(s) provided for the connection.
161 multiple : `bool`
162 The ``multiple`` field of the connection.
164 Raises
165 ------
166 ValueError
167 Raised if ``dataIds`` and ``multiple`` do not match.
168 """
169 if multiple:
170 if not isinstance(dataIds, collections.abc.Sequence):
171 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
172 else:
173 # DataCoordinate is a Mapping
174 if not isinstance(dataIds, collections.abc.Mapping):
175 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")
178def _normalizeDataIds(dataIds):
179 """Represent both single and multiple data IDs as a list.
181 Parameters
182 ----------
183 dataIds : any data ID type or `~collections.abc.Sequence` thereof
184 The data ID(s) provided for a particular input or output connection.
186 Returns
187 -------
188 normalizedIds : `~collections.abc.Sequence` [data ID]
189 A sequence equal to ``dataIds`` if it was already a sequence, or
190 ``[dataIds]`` if it was a single ID.
191 """
192 if isinstance(dataIds, collections.abc.Sequence):
193 return dataIds
194 else:
195 return [dataIds]
def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dataset type is not registered, or if ``dataId`` is
        incompatible with the connection's dimensions.
    """
    universe = butler.registry.dimensions
    # DatasetRef only tests if required dimension is missing, but not extras
    _checkDimensionsMatch(universe, connection.dimensions, dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # skypix is a PipelineTask alias for "some spatial index", Butler doesn't
    # understand it. Code copied from TaskDatasetTypes.fromTaskDef
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError as e:
        # Chain the registry lookup failure so its details stay visible;
        # the previous re-raise silently discarded the cause.
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") \
            from e
243def _resolveTestQuantumInputs(butler, quantum):
244 """Look up all input datasets a test quantum in the `Registry` to resolve
245 all `DatasetRef` objects (i.e. ensure they have not-`None` ``id`` and
246 ``run`` attributes).
248 Parameters
249 ----------
250 quantum : `~lsst.daf.butler.Quantum`
251 Single Quantum instance.
252 butler : `~lsst.daf.butler.Butler`
253 Data butler.
254 """
255 # TODO (DM-26819): This function is a direct copy of
256 # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
257 # `runTestQuantum` function that calls it is essentially duplicating logic
258 # in that class as well (albeit not verbatim). We should probably move
259 # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
260 # in test code instead of having these classes at all.
261 for refsForDatasetType in quantum.inputs.values():
262 newRefsForDatasetType = []
263 for ref in refsForDatasetType:
264 if ref.id is None:
265 resolvedRef = butler.registry.findDataset(ref.datasetType, ref.dataId,
266 collections=butler.collections)
267 if resolvedRef is None:
268 raise ValueError(
269 f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
270 f"in collections {butler.collections}."
271 )
272 newRefsForDatasetType.append(resolvedRef)
273 else:
274 newRefsForDatasetType.append(ref)
275 refsForDatasetType[:] = newRefsForDatasetType
def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    _resolveTestQuantumInputs(butler, quantum)
    butlerQc = ButlerQuantumContext(butler, quantum)
    connections = task.config.ConnectionsClass(config=task.config)
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)

    if not mockRun:
        # Run for real; the task is free to read and write datasets.
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None

    # Stub out ``run`` and silence ``put`` so no real outputs are written;
    # the returned mock records the arguments ``runQuantum`` passed to it.
    with unittest.mock.patch.object(task, "run") as mockedRun, \
            unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
        task.runQuantum(butlerQc, inputRefs, outputRefs)
    return mockedRun
314def _assertAttributeMatchesConnection(obj, attrName, connection):
315 """Test that an attribute on an object matches the specification given in
316 a connection.
318 Parameters
319 ----------
320 obj
321 An object expected to contain the attribute ``attrName``.
322 attrName : `str`
323 The name of the attribute to be tested.
324 connection : `lsst.pipe.base.connectionTypes.BaseConnection`
325 The connection, usually some type of output, specifying ``attrName``.
327 Raises
328 ------
329 AssertionError:
330 Raised if ``obj.attrName`` does not match what's expected
331 from ``connection``.
332 """
333 # name
334 try:
335 attrValue = obj.__getattribute__(attrName)
336 except AttributeError:
337 raise AssertionError(f"No such attribute on {obj!r}: {attrName}")
338 # multiple
339 if connection.multiple:
340 if not isinstance(attrValue, collections.abc.Sequence):
341 raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
342 else:
343 # use lazy evaluation to not use StorageClassFactory unless
344 # necessary
345 if isinstance(attrValue, collections.abc.Sequence) \
346 and not issubclass(
347 StorageClassFactory().getStorageClass(connection.storageClass).pytype,
348 collections.abc.Sequence):
349 raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
350 # no test for storageClass, as I'm not sure how much persistence
351 # depends on duck-typing
def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError:
        Raised if ``result`` does not match what's expected from ``task's``
        connections.
    """
    connections = task.config.ConnectionsClass(config=task.config)

    for name in connections.outputs:
        # getattr honors __getattr__ fallbacks, unlike a direct
        # __getattribute__ call.
        connection = getattr(connections, name)
        _assertAttributeMatchesConnection(result, name, connection)
def assertValidInitOutput(task):
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError:
        Raised if ``task`` does not have the state expected from ``task's``
        connections.
    """
    connections = task.config.ConnectionsClass(config=task.config)

    for name in connections.initOutputs:
        # getattr honors __getattr__ fallbacks, unlike a direct
        # __getattribute__ call.
        connection = getattr(connections, name)
        _assertAttributeMatchesConnection(task, name, connection)
def getInitInputs(butler, config):
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    # All initInputs have empty data IDs, so every dataset type uses the
    # same (empty) dimension set; compute it once up front.
    emptyDimensions = butler.registry.dimensions.extract(set())
    initInputs = {}
    for name in connections.initInputs:
        connection = getattr(connections, name)
        # Build the full dataset type so the butler can flag
        # consistency problems.
        datasetType = DatasetType(connection.name, emptyDimensions, connection.storageClass)
        initInputs[name] = butler.get(datasetType)
    return initInputs
def lintConnections(connections, *,
                    checkMissingMultiple=True,
                    checkUnnecessaryMultiple=True,
                    ):
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # Since all comparisons are inside the class, don't bother
    # normalizing skypix. Normalize to a plain set so the subset/superset
    # comparisons below are well-defined regardless of the container type.
    quantumDimensions = set(connections.dimensions)

    # Collect messages in a list and join once; repeated += on a string
    # is quadratic.
    errors = []
    # connectionTypes.DimensionedConnection is implementation detail,
    # don't use it.
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
        connection = connections.allConnections[name]
        connDimensions = set(connection.dimensions)
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            errors.append(f"Connection {name} may be called with multiple values of "
                          f"{connDimensions - quantumDimensions} but has multiple=False.")
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            errors.append(f"Connection {name} has multiple=True but can only be called with one "
                          f"value of {connDimensions} for each {quantumDimensions}.")
    if errors:
        raise AssertionError("\n".join(errors))