Coverage for python/lsst/pipe/base/testUtils.py: 10%
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]

import collections.abc
import itertools
import unittest.mock
from collections import defaultdict

from lsst.daf.butler import (
    DataCoordinate,
    DatasetRef,
    DatasetType,
    Quantum,
    SkyPixDimension,
    StorageClassFactory,
)
from lsst.pipe.base import ButlerQuantumContext


def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for a particular set of data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``ioDataIds``.
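
    Examples
    --------
    A sketch of typical use; ``task``, ``butler``, and the connection names
    ``inputCatalog`` and ``outputCatalog`` are stand-ins for test-specific
    objects:

    >>> dataId = {"instrument": "notACam", "visit": 42, "detector": 0}
    >>> quantum = makeQuantum(
    ...     task, butler, dataId,
    ...     ioDataIds={"inputCatalog": dataId, "outputCatalog": dataId},
    ... )  # doctest: +SKIP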
69 """
70 connections = task.config.ConnectionsClass(config=task.config)
72 try:
73 _checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
74 except ValueError as e:
75 raise ValueError("Error in quantum dimensions.") from e
77 inputs = defaultdict(list)
78 outputs = defaultdict(list)
79 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
80 try:
81 connection = connections.__getattribute__(name)
82 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
83 ids = _normalizeDataIds(ioDataIds[name])
84 for id in ids:
85 ref = _refFromConnection(butler, connection, id)
86 inputs[ref.datasetType].append(ref)
87 except (ValueError, KeyError) as e:
88 raise ValueError(f"Error in connection {name}.") from e
89 for name in connections.outputs:
90 try:
91 connection = connections.__getattribute__(name)
92 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
93 ids = _normalizeDataIds(ioDataIds[name])
94 for id in ids:
95 ref = _refFromConnection(butler, connection, id)
96 outputs[ref.datasetType].append(ref)
97 except (ValueError, KeyError) as e:
98 raise ValueError(f"Error in connection {name}.") from e
99 quantum = Quantum(taskClass=type(task), dataId=dataId, inputs=inputs, outputs=outputs)
100 return quantum


def _checkDimensionsMatch(universe, expected, actual):
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    if _simplify(universe, expected) != _simplify(universe, actual):
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")


def _simplify(universe, dimensions):
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
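
    Examples
    --------
    A sketch of the intended behavior; ``universe`` must be a real
    `~lsst.daf.butler.DimensionUniverse` in which ``htm7`` is a skypix
    dimension:

    >>> sorted(_simplify(universe, {"visit", "detector", "htm7"}))  # doctest: +SKIP
    ['detector', 'skypix', 'visit']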
139 """
140 simplified = set()
141 for dimension in dimensions:
142 # skypix not a real Dimension, handle it first
143 if dimension == "skypix":
144 simplified.add(dimension)
145 else:
146 # Need a Dimension to test spatialness
147 fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
148 if isinstance(fullDimension, SkyPixDimension):
149 simplified.add("skypix")
150 else:
151 simplified.add(fullDimension.name)
152 return simplified


def _checkDataIdMultiplicity(name, dataIds, multiple):
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
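
    Examples
    --------
    Plain `dict` and `list` objects satisfy the `~collections.abc.Mapping`
    and `~collections.abc.Sequence` tests, so these calls pass silently:

    >>> _checkDataIdMultiplicity("inputCatalogs", [{"visit": 42}], multiple=True)
    >>> _checkDataIdMultiplicity("inputCatalog", {"visit": 42}, multiple=False)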
172 """
173 if multiple:
174 if not isinstance(dataIds, collections.abc.Sequence):
175 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
176 else:
177 # DataCoordinate is a Mapping
178 if not isinstance(dataIds, collections.abc.Mapping):
179 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds):
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
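
    Examples
    --------
    Plain dicts stand in for data IDs here:

    >>> _normalizeDataIds([{"visit": 42}, {"visit": 43}])
    [{'visit': 42}, {'visit': 43}]
    >>> _normalizeDataIds({"visit": 42})
    [{'visit': 42}]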
195 """
196 if isinstance(dataIds, collections.abc.Sequence):
197 return dataIds
198 else:
199 return [dataIds]


def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
    """
    universe = butler.registry.dimensions
    # DatasetRef only tests whether required dimensions are missing,
    # not whether there are extras.
    _checkDimensionsMatch(universe, connection.dimensions, dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # skypix is a PipelineTask alias for "some spatial index"; Butler does
    # not understand it. Code copied from TaskDatasetTypes.fromTaskDef.
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e


def _resolveTestQuantumInputs(butler, quantum):
    """Look up all input datasets of a test quantum in the `Registry` to
    resolve all `DatasetRef` objects (i.e., ensure they have not-`None`
    ``id`` and ``run`` attributes).

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    """
    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refsForDatasetType in quantum.inputs.values():
        newRefsForDatasetType = []
        for ref in refsForDatasetType:
            if ref.id is None:
                resolvedRef = butler.registry.findDataset(
                    ref.datasetType, ref.dataId, collections=butler.collections
                )
                if resolvedRef is None:
                    raise ValueError(
                        f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                        f"in collections {butler.collections}."
                    )
                newRefsForDatasetType.append(resolvedRef)
            else:
                newRefsForDatasetType.append(ref)
        refsForDatasetType[:] = newRefsForDatasetType


def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
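
    Examples
    --------
    A sketch of the usual test pattern; ``task``, ``butler``, ``dataId``,
    and ``ioDataIds`` are assumed to be set up as in the `makeQuantum`
    example:

    >>> quantum = makeQuantum(task, butler, dataId, ioDataIds)  # doctest: +SKIP
    >>> run = runTestQuantum(task, butler, quantum)  # doctest: +SKIP
    >>> run.assert_called_once()  # doctest: +SKIP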
303 """
304 _resolveTestQuantumInputs(butler, quantum)
305 butlerQc = ButlerQuantumContext(butler, quantum)
306 connections = task.config.ConnectionsClass(config=task.config)
307 inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
308 if mockRun:
309 with unittest.mock.patch.object(task, "run") as mock, unittest.mock.patch(
310 "lsst.pipe.base.ButlerQuantumContext.put"
311 ):
312 task.runQuantum(butlerQc, inputRefs, outputRefs)
313 return mock
314 else:
315 task.runQuantum(butlerQc, inputRefs, outputRefs)
316 return None


def _assertAttributeMatchesConnection(obj, attrName, connection):
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
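
    Examples
    --------
    A runnable sketch using `types.SimpleNamespace` objects as stand-ins
    for a result struct and a ``multiple=True`` connection:

    >>> import types
    >>> conn = types.SimpleNamespace(multiple=True, storageClass="Catalog")
    >>> result = types.SimpleNamespace(catalog=[1, 2])
    >>> _assertAttributeMatchesConnection(result, "catalog", conn)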
337 """
338 # name
339 try:
340 attrValue = obj.__getattribute__(attrName)
341 except AttributeError:
342 raise AssertionError(f"No such attribute on {obj!r}: {attrName}")
343 # multiple
344 if connection.multiple:
345 if not isinstance(attrValue, collections.abc.Sequence):
346 raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
347 else:
348 # use lazy evaluation to not use StorageClassFactory unless
349 # necessary
350 if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
351 StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
352 ):
353 raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
354 # no test for storageClass, as I'm not sure how much persistence
355 # depends on duck-typing


def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
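
    Examples
    --------
    A sketch of typical use; ``task`` and the arguments to its ``run``
    method are assumed to exist:

    >>> result = task.run(inputCatalog=catalog)  # doctest: +SKIP
    >>> assertValidOutput(task, result)  # doctest: +SKIP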
375 """
376 connections = task.config.ConnectionsClass(config=task.config)
378 for name in connections.outputs:
379 connection = connections.__getattribute__(name)
380 _assertAttributeMatchesConnection(result, name, connection)


def assertValidInitOutput(task):
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
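
    Examples
    --------
    A sketch, assuming a hypothetical ``MyTask`` with a valid ``config``:

    >>> task = MyTask(config=config)  # doctest: +SKIP
    >>> assertValidInitOutput(task)  # doctest: +SKIP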
396 """
397 connections = task.config.ConnectionsClass(config=task.config)
399 for name in connections.initOutputs:
400 connection = connections.__getattribute__(name)
401 _assertAttributeMatchesConnection(task, name, connection)


def getInitInputs(butler, config):
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
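
    Examples
    --------
    A sketch of typical use in a test, assuming a repository that already
    contains the init-input datasets for a hypothetical ``MyTask``:

    >>> initInputs = getInitInputs(butler, config)  # doctest: +SKIP
    >>> task = MyTask(config=config, initInputs=initInputs)  # doctest: +SKIP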
421 """
422 connections = config.connections.ConnectionsClass(config=config)
423 initInputs = {}
424 for name in connections.initInputs:
425 attribute = getattr(connections, name)
426 # Get full dataset type to check for consistency problems
427 dsType = DatasetType(
428 attribute.name, butler.registry.dimensions.extract(set()), attribute.storageClass
429 )
430 # All initInputs have empty data IDs
431 initInputs[name] = butler.get(dsType)
433 return initInputs


def lintConnections(
    connections,
    *,
    checkMissingMultiple=True,
    checkUnnecessaryMultiple=True,
):
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
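
    Examples
    --------
    Typically called from a unit test on the connections class itself,
    here a hypothetical ``MyTaskConnections``:

    >>> lintConnections(MyTaskConnections)  # doctest: +SKIP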
464 """
465 # Since all comparisons are inside the class, don't bother
466 # normalizing skypix.
467 quantumDimensions = connections.dimensions
469 errors = ""
470 # connectionTypes.DimensionedConnection is implementation detail,
471 # don't use it.
472 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
473 connection = connections.allConnections[name]
474 connDimensions = set(connection.dimensions)
475 if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
476 errors += (
477 f"Connection {name} may be called with multiple values of "
478 f"{connDimensions - quantumDimensions} but has multiple=False.\n"
479 )
480 if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
481 errors += (
482 f"Connection {name} has multiple=True but can only be called with one "
483 f"value of {connDimensions} for each {quantumDimensions}.\n"
484 )
485 if errors:
486 raise AssertionError(errors)