# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]

import collections.abc
import itertools
import unittest.mock
from collections import defaultdict
from typing import TYPE_CHECKING, AbstractSet, Any, Dict, Mapping, Optional, Sequence, Set, Union

from lsst.daf.butler import (
    Butler,
    DataCoordinate,
    DataId,
    DatasetRef,
    DatasetType,
    Dimension,
    DimensionUniverse,
    Quantum,
    SkyPixDimension,
    StorageClassFactory,
)
from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection

from .butlerQuantumContext import ButlerQuantumContext

if TYPE_CHECKING:
    from .config import PipelineTaskConfig
    from .connections import PipelineTaskConnections
    from .pipelineTask import PipelineTask
    from .struct import Struct


def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, Union[DataId, Sequence[DataId]]],
) -> Quantum:
    """Create a Quantum for a particular data ID and its input/output
    data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when run with ``dataId`` and ``ioDataIds``.
    """
    connections = task.config.ConnectionsClass(config=task.config)

    try:
        _checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    inputs = defaultdict(list)
    outputs = defaultdict(list)
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                inputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    for name in connections.outputs:
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                outputs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    quantum = Quantum(
        taskClass=type(task),
        dataId=DataCoordinate.standardize(dataId, universe=butler.registry.dimensions),
        inputs=inputs,
        outputs=outputs,
    )
    return quantum
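
# A minimal usage sketch for `makeQuantum`. ``MyTask``, the data ID values,
# and the connection names "inputExposure" and "outputCatalog" are
# hypothetical; each connection here is assumed to have multiple=False, so a
# single data ID is passed per name:
#
#     task = MyTask(config=MyTask.ConfigClass())
#     dataId = {"instrument": "demoCam", "visit": 1, "detector": 0}
#     quantum = makeQuantum(
#         task,
#         butler,
#         dataId,
#         ioDataIds={name: dataId for name in ("inputExposure", "outputCatalog")},
#     )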


def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Union[AbstractSet[str], AbstractSet[Dimension]],
    actual: Union[AbstractSet[str], AbstractSet[Dimension]],
) -> None:
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    if _simplify(universe, expected) != _simplify(universe, actual):
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")


def _simplify(
    universe: DimensionUniverse, dimensions: Union[AbstractSet[str], AbstractSet[Dimension]]
) -> Set[str]:
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
    """
    simplified: Set[str] = set()
    for dimension in dimensions:
        # "skypix" is not a real Dimension, so handle it first
        if dimension == "skypix":
            simplified.add(dimension)  # type: ignore
        else:
            # Need a Dimension object to test whether it is spatial
            fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
            if isinstance(fullDimension, SkyPixDimension):
                simplified.add("skypix")
            else:
                simplified.add(fullDimension.name)
    return simplified
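
# Illustrative behavior, assuming a dimension universe in which "htm7" is a
# `SkyPixDimension`:
#
#     _simplify(universe, {"visit", "htm7"})    # -> {"visit", "skypix"}
#     _simplify(universe, {"visit", "skypix"})  # -> {"visit", "skypix"}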


def _checkDataIdMultiplicity(name: str, dataIds: Union[DataId, Sequence[DataId]], multiple: bool) -> None:
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
    """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds: Union[DataId, Sequence[DataId]]) -> Sequence[DataId]:
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
    """
    if isinstance(dataIds, collections.abc.Sequence):
        return dataIds
    else:
        return [dataIds]


def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
    """
    universe = butler.registry.dimensions
    # DatasetRef only tests whether a required dimension is missing, not extras
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # "skypix" is a PipelineTask alias for "some spatial index"; the Butler
    # doesn't understand it. Code copied from TaskDatasetTypes.fromTaskDef.
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError:
        raise ValueError(f"Invalid dataset type {connection.name}.")
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e


def _resolveTestQuantumInputs(butler: Butler, quantum: Quantum) -> None:
    """Look up all input datasets for a test quantum in the `Registry` to
    resolve all `DatasetRef` objects (i.e. ensure they have not-`None` ``id``
    and ``run`` attributes).

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    """
    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refsForDatasetType in quantum.inputs.values():
        newRefsForDatasetType = []
        for ref in refsForDatasetType:
            if ref.id is None:
                resolvedRef = butler.registry.findDataset(
                    ref.datasetType, ref.dataId, collections=butler.collections
                )
                if resolvedRef is None:
                    raise ValueError(
                        f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                        f"in collections {butler.collections}."
                    )
                newRefsForDatasetType.append(resolvedRef)
            else:
                newRefsForDatasetType.append(ref)
        refsForDatasetType[:] = newRefsForDatasetType


def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> Optional[unittest.mock.Mock]:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    _resolveTestQuantumInputs(butler, quantum)
    butlerQc = ButlerQuantumContext(butler, quantum)
    connections = task.config.ConnectionsClass(config=task.config)
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if mockRun:
        with unittest.mock.patch.object(task, "run") as mock, unittest.mock.patch(
            "lsst.pipe.base.ButlerQuantumContext.put"
        ):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock
    else:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None
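
# A sketch of the intended test flow, combining `makeQuantum` (above) with
# `runTestQuantum`. Only standard `unittest.mock.Mock` methods are used on the
# returned mock; ``task``, ``butler``, ``dataId``, and ``ioDataIds`` are
# assumed to come from the test fixture:
#
#     quantum = makeQuantum(task, butler, dataId, ioDataIds)
#     run = runTestQuantum(task, butler, quantum)  # mockRun=True by default
#     run.assert_called_once()  # runQuantum called task.run exactly once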


def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # name
    try:
        attrValue = obj.__getattribute__(attrName)
    except AttributeError:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}")
    # multiple
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # Use lazy evaluation to avoid constructing a StorageClassFactory
        # unless necessary.
        if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
            StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
        ):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # No test for storageClass, as it's unclear how much persistence
    # depends on duck-typing.


def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
    """
    connections = task.config.ConnectionsClass(config=task.config)

    for name in connections.outputs:
        connection = connections.__getattribute__(name)
        _assertAttributeMatchesConnection(result, name, connection)
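
# Typical use inside a unit test; ``exposure`` is a hypothetical input to the
# task's ``run`` method:
#
#     result = task.run(exposure)
#     assertValidOutput(task, result)  # AssertionError if result is malformed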


def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
    """
    connections = task.config.ConnectionsClass(config=task.config)

    for name in connections.initOutputs:
        connection = connections.__getattribute__(name)
        _assertAttributeMatchesConnection(task, name, connection)
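
# Typical use: construct the task as production code would, then check that
# its attributes match its init-output connections (``MyTask`` is a
# hypothetical task class):
#
#     task = MyTask(config=config)
#     assertValidInitOutput(task)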


def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> Dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    initInputs = {}
    for name in connections.initInputs:
        attribute = getattr(connections, name)
        # Get full dataset type to check for consistency problems
        dsType = DatasetType(
            attribute.name, butler.registry.dimensions.extract(set()), attribute.storageClass
        )
        # All initInputs have empty data IDs
        initInputs[name] = butler.get(dsType)

    return initInputs
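
# A sketch of building a task the way the execution framework would, using
# the init-inputs found in the repository (``MyTask`` is a hypothetical task
# class):
#
#     config = MyTask.ConfigClass()
#     task = MyTask(config=config, initInputs=getInitInputs(butler, config))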


def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # Since all comparisons are inside the class, don't bother
    # normalizing skypix.
    quantumDimensions = connections.dimensions

    errors = ""
    # connectionTypes.DimensionedConnection is an implementation detail;
    # don't use it.
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDimensions = set(connection.dimensions)
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            errors += (
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            errors += (
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDimensions} for each {quantumDimensions}.\n"
            )
    if errors:
        raise AssertionError(errors)
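
# A sketch of linting a connections class in a test suite. ``MyConnections``
# is a hypothetical `PipelineTaskConnections` subclass; the linter inspects
# its class-level connection definitions and accumulates all failures into a
# single AssertionError:
#
#     lintConnections(MyConnections)
#     lintConnections(MyConnections, checkUnnecessaryMultiple=False)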