Coverage for python/lsst/pipe/base/testUtils.py: 13%
130 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-23 10:31 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-23 10:31 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = [
25 "assertValidInitOutput",
26 "assertValidOutput",
27 "getInitInputs",
28 "lintConnections",
29 "makeQuantum",
30 "runTestQuantum",
31]
34import collections.abc
35import itertools
36import unittest.mock
37from collections import defaultdict
38from collections.abc import Mapping, Sequence, Set
39from typing import TYPE_CHECKING, Any
41from lsst.daf.butler import (
42 Butler,
43 DataCoordinate,
44 DataId,
45 DatasetRef,
46 DatasetType,
47 Dimension,
48 DimensionUniverse,
49 Quantum,
50 SkyPixDimension,
51 StorageClassFactory,
52)
53from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection
55from ._quantumContext import QuantumContext
57if TYPE_CHECKING:
58 from .config import PipelineTaskConfig
59 from .connections import PipelineTaskConnections
60 from .pipelineTask import PipelineTask
61 from .struct import Struct
def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, DataId | Sequence[DataId]],
) -> Quantum:
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataIds``.

    Raises
    ------
    ValueError
        Raised if the quantum or connection data IDs are invalid for
        ``task``'s connections.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    try:
        _checkDimensionsMatch(butler.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    def _resolveRefs(connectionNames: collections.abc.Iterable[str]) -> defaultdict:
        # Shared helper for the (previously duplicated) input and output
        # loops: map each named connection's data ID(s) to dataset refs,
        # grouped by dataset type.
        refs: defaultdict = defaultdict(list)
        for name in connectionNames:
            try:
                connection = connections.__getattribute__(name)
                _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
                for connDataId in _normalizeDataIds(ioDataIds[name]):
                    ref = _refFromConnection(butler, connection, connDataId)
                    refs[ref.datasetType].append(ref)
            except (ValueError, KeyError) as e:
                raise ValueError(f"Error in connection {name}.") from e
        return refs

    inputs = _resolveRefs(itertools.chain(connections.inputs, connections.prerequisiteInputs))
    outputs = _resolveRefs(connections.outputs)
    return Quantum(
        taskClass=type(task),
        dataId=DataCoordinate.standardize(dataId, universe=butler.dimensions),
        inputs=inputs,
        outputs=outputs,
    )
def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Set[str] | Set[Dimension],
    actual: Set[str] | Set[Dimension],
) -> None:
    """Check that two dimension sets agree once normalized.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    expectedNames = _simplify(universe, expected)
    actualNames = _simplify(universe, actual)
    if expectedNames != actualNames:
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")
155def _simplify(universe: DimensionUniverse, dimensions: Set[str] | Set[Dimension]) -> set[str]:
156 """Reduce a set of dimensions to a string-only form.
158 Parameters
159 ----------
160 universe : `lsst.daf.butler.DimensionUniverse`
161 The set of all known dimensions.
162 dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
163 A set of dimensions to simplify.
165 Returns
166 -------
167 dimensions : `Set` [`str`]
168 A copy of ``dimensions`` reduced to string form, with all spatial
169 dimensions simplified to ``skypix``.
170 """
171 simplified: set[str] = set()
172 for dimension in dimensions:
173 # skypix not a real Dimension, handle it first
174 if dimension == "skypix":
175 simplified.add(dimension) # type: ignore
176 else:
177 # Need a Dimension to test spatialness
178 fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
179 if isinstance(fullDimension, SkyPixDimension):
180 simplified.add("skypix")
181 else:
182 simplified.add(fullDimension.name)
183 return simplified
186def _checkDataIdMultiplicity(name: str, dataIds: DataId | Sequence[DataId], multiple: bool) -> None:
187 """Test whether data IDs are scalars for scalar connections and sequences
188 for multiple connections.
190 Parameters
191 ----------
192 name : `str`
193 The name of the connection being tested.
194 dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
195 The data ID(s) provided for the connection.
196 multiple : `bool`
197 The ``multiple`` field of the connection.
199 Raises
200 ------
201 ValueError
202 Raised if ``dataIds`` and ``multiple`` do not match.
203 """
204 if multiple:
205 if not isinstance(dataIds, collections.abc.Sequence):
206 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
207 else:
208 # DataCoordinate is a Mapping
209 if not isinstance(dataIds, collections.abc.Mapping):
210 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")
213def _normalizeDataIds(dataIds: DataId | Sequence[DataId]) -> Sequence[DataId]:
214 """Represent both single and multiple data IDs as a list.
216 Parameters
217 ----------
218 dataIds : any data ID type or `~collections.abc.Sequence` thereof
219 The data ID(s) provided for a particular input or output connection.
221 Returns
222 -------
223 normalizedIds : `~collections.abc.Sequence` [data ID]
224 A sequence equal to ``dataIds`` if it was already a sequence, or
225 ``[dataIds]`` if it was a single ID.
226 """
227 if isinstance(dataIds, collections.abc.Sequence):
228 return dataIds
229 else:
230 return [dataIds]
def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dataset type is not registered, the butler has no
        default run, or the data ID is incompatible with the dataset type.
    """
    universe = butler.dimensions
    # DatasetRef only tests if required dimension is missing, but not extras
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # Bug fix: the lookup that can raise KeyError must itself be inside the
    # try block. Previously the first, unguarded call raised before a
    # redundant second lookup inside the try could translate the error.
    try:
        datasetType = butler.registry.getDatasetType(connection.name)
    except KeyError:
        raise ValueError(f"Invalid dataset type {connection.name}.") from None
    if not butler.run:
        raise ValueError("Can not create a resolved DatasetRef since the butler has no default run defined.")
    try:
        # Reuse an existing registry dataset if present; otherwise create
        # and import a new resolved ref in the butler's default run.
        registry_ref = butler.registry.findDataset(datasetType, dataId, collections=[butler.run])
        if registry_ref:
            ref = registry_ref
        else:
            ref = DatasetRef(datasetType=datasetType, dataId=dataId, run=butler.run)
            butler.registry._importDatasets([ref])
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e
def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> unittest.mock.Mock | None:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    butlerQc = QuantumContext(butler, quantum)
    # `connections` is a dynamic class; type checkers cannot see the
    # attribute, but it is always present on a task config.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)

    # Real execution: just run the quantum and report nothing.
    if not mockRun:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None

    # Mocked execution: patch both the task's `run` and QuantumContext.put
    # so no outputs are written, while capturing the call arguments.
    with (
        unittest.mock.patch.object(task, "run") as mock,
        unittest.mock.patch("lsst.pipe.base.QuantumContext.put"),
    ):
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return mock
322def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
323 """Test that an attribute on an object matches the specification given in
324 a connection.
326 Parameters
327 ----------
328 obj
329 An object expected to contain the attribute ``attrName``.
330 attrName : `str`
331 The name of the attribute to be tested.
332 connection : `lsst.pipe.base.connectionTypes.BaseConnection`
333 The connection, usually some type of output, specifying ``attrName``.
335 Raises
336 ------
337 AssertionError:
338 Raised if ``obj.attrName`` does not match what's expected
339 from ``connection``.
340 """
341 # name
342 try:
343 attrValue = obj.__getattribute__(attrName)
344 except AttributeError:
345 raise AssertionError(f"No such attribute on {obj!r}: {attrName}") from None
346 # multiple
347 if connection.multiple:
348 if not isinstance(attrValue, collections.abc.Sequence):
349 raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
350 else:
351 # use lazy evaluation to not use StorageClassFactory unless
352 # necessary
353 if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
354 StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
355 ):
356 raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
357 # no test for storageClass, as I'm not sure how much persistence
358 # depends on duck-typing
def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Check that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError:
        Raised if ``result`` does not match what's expected from ``task's``
        connections.
    """
    # `connections` is a dynamic class; type checkers cannot see the
    # attribute, but it is always present on a task config.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    # Every declared output must appear on the result with the right shape.
    for name in connections.outputs:
        _assertAttributeMatchesConnection(result, name, connections.__getattribute__(name))
def assertValidInitOutput(task: PipelineTask) -> None:
    """Check that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError:
        Raised if ``task`` does not have the state expected from ``task's``
        connections.
    """
    # `connections` is a dynamic class; type checkers cannot see the
    # attribute, but it is always present on a task config.
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    # Every declared init-output must appear as state on the task itself.
    for name in connections.initOutputs:
        _assertAttributeMatchesConnection(task, name, connections.__getattribute__(name))
def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    # All initInputs have empty data IDs, so every dataset type shares the
    # same empty dimension set; build it once.
    emptyDimensions = butler.dimensions.extract(set())
    initInputs: dict[str, Any] = {}
    for name in connections.initInputs:
        connection = getattr(connections, name)
        # Build the full dataset type so the butler can flag any
        # consistency problems with its registered definition.
        dsType = DatasetType(connection.name, emptyDimensions, connection.storageClass)
        initInputs[name] = butler.get(dsType)
    return initInputs
def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # Since all comparisons are inside the class, don't bother
    # normalizing skypix.
    quantumDimensions = connections.dimensions

    problems: list[str] = []
    # connectionTypes.DimensionedConnection is implementation detail,
    # don't use it.
    for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDimensions = set(connection.dimensions)
        # A strict superset of the quantum dimensions means multiple datasets
        # can match one quantum.
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            problems.append(
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
        # A subset (or equal set) means exactly one dataset per quantum.
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            problems.append(
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDimensions} for each {quantumDimensions}.\n"
            )
    if problems:
        raise AssertionError("".join(problems))