# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "assertValidInitOutput",
    "assertValidOutput",
    "getInitInputs",
    "lintConnections",
    "makeQuantum",
    "runTestQuantum",
]


import collections.abc
import itertools
import unittest.mock
from collections import defaultdict
from collections.abc import Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import (
    Butler,
    DataCoordinate,
    DataId,
    DatasetRef,
    DatasetType,
    Dimension,
    DimensionUniverse,
    Quantum,
    SkyPixDimension,
    StorageClassFactory,
)
from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection

from ._quantumContext import QuantumContext

if TYPE_CHECKING:
    from .config import PipelineTaskConfig
    from .connections import PipelineTaskConnections
    from .pipelineTask import PipelineTask
    from .struct import Struct


def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, DataId | Sequence[DataId]],
) -> Quantum:
    """Create a Quantum for a particular set of data IDs.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``ioDataIds``.
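
    Examples
    --------
    A minimal sketch of typical usage. The repository ``butler``, the
    configured ``task``, and the connection names ``calexp`` and ``catalog``
    are hypothetical test fixtures, not part of this API:

    >>> quantum = makeQuantum(
    ...     task,
    ...     butler,
    ...     dataId={"instrument": "notACam", "visit": 101, "detector": 42},
    ...     ioDataIds={
    ...         # a single data ID for a connection with multiple=False
    ...         "calexp": {"instrument": "notACam", "visit": 101, "detector": 42},
    ...         # a sequence of data IDs for a connection with multiple=True
    ...         "catalog": [{"instrument": "notACam", "visit": 101, "detector": 42}],
    ...     },
    ... )  # doctest: +SKIP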
95 """
96 # This is a type ignore, because `connections` is a dynamic class, but
97 # it for sure will have this property
98 connections = task.config.ConnectionsClass(config=task.config) # type: ignore
100 try:
101 _checkDimensionsMatch(butler.dimensions, connections.dimensions, dataId.keys())
102 except ValueError as e:
103 raise ValueError("Error in quantum dimensions.") from e
105 inputs = defaultdict(list)
106 outputs = defaultdict(list)
107 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
108 try:
109 connection = connections.__getattribute__(name)
110 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
111 ids = _normalizeDataIds(ioDataIds[name])
112 for id in ids:
113 ref = _refFromConnection(butler, connection, id)
114 inputs[ref.datasetType].append(ref)
115 except (ValueError, KeyError) as e:
116 raise ValueError(f"Error in connection {name}.") from e
117 for name in connections.outputs:
118 try:
119 connection = connections.__getattribute__(name)
120 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
121 ids = _normalizeDataIds(ioDataIds[name])
122 for id in ids:
123 ref = _refFromConnection(butler, connection, id)
124 outputs[ref.datasetType].append(ref)
125 except (ValueError, KeyError) as e:
126 raise ValueError(f"Error in connection {name}.") from e
127 quantum = Quantum(
128 taskClass=type(task),
129 dataId=DataCoordinate.standardize(dataId, universe=butler.dimensions),
130 inputs=inputs,
131 outputs=outputs,
132 )
133 return quantum


def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Set[str] | Set[Dimension],
    actual: Set[str] | Set[Dimension],
) -> None:
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    if _simplify(universe, expected) != _simplify(universe, actual):
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")


def _simplify(universe: DimensionUniverse, dimensions: Set[str] | Set[Dimension]) -> set[str]:
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
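
    Examples
    --------
    A sketch assuming the default dimension universe, in which ``htm7`` is a
    sky pixelization dimension:

    >>> from lsst.daf.butler import DimensionUniverse
    >>> sorted(_simplify(DimensionUniverse(), {"htm7", "visit"}))  # doctest: +SKIP
    ['skypix', 'visit']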
176 """
177 simplified: set[str] = set()
178 for dimension in dimensions:
179 # skypix not a real Dimension, handle it first
180 if dimension == "skypix":
181 simplified.add(dimension) # type: ignore
182 else:
183 # Need a Dimension to test spatialness
184 fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
185 if isinstance(fullDimension, SkyPixDimension):
186 simplified.add("skypix")
187 else:
188 simplified.add(fullDimension.name)
189 return simplified


def _checkDataIdMultiplicity(name: str, dataIds: DataId | Sequence[DataId], multiple: bool) -> None:
    """Test whether data IDs are scalars for scalar connections and sequences
    for multiple connections.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``dataIds`` and ``multiple`` do not match.
    """
    if multiple:
        if not isinstance(dataIds, collections.abc.Sequence):
            raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    else:
        # DataCoordinate is a Mapping
        if not isinstance(dataIds, collections.abc.Mapping):
            raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")


def _normalizeDataIds(dataIds: DataId | Sequence[DataId]) -> Sequence[DataId]:
    """Represent both single and multiple data IDs as a list.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        A sequence equal to ``dataIds`` if it was already a sequence, or
        ``[dataIds]`` if it was a single ID.
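
    Examples
    --------
    A single (mapping-like) data ID is wrapped in a list, while a sequence
    passes through unchanged:

    >>> _normalizeDataIds({"visit": 42})
    [{'visit': 42}]
    >>> _normalizeDataIds([{"visit": 42}])
    [{'visit': 42}]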
232 """
233 if isinstance(dataIds, collections.abc.Sequence):
234 return dataIds
235 else:
236 return [dataIds]


def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.
    """
    universe = butler.dimensions
    # DatasetRef only checks for missing required dimensions, not for extras.
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    try:
        datasetType = butler.registry.getDatasetType(connection.name)
    except KeyError:
        raise ValueError(f"Invalid dataset type {connection.name}.") from None
    if not butler.run:
        raise ValueError("Cannot create a resolved DatasetRef since the butler has no default run defined.")
    try:
        registry_ref = butler.registry.findDataset(datasetType, dataId, collections=[butler.run])
        if registry_ref:
            ref = registry_ref
        else:
            ref = DatasetRef(datasetType=datasetType, dataId=dataId, run=butler.run)
            butler.registry._importDatasets([ref])
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e


def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> unittest.mock.Mock | None:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
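
    Examples
    --------
    A hedged sketch of typical use; ``task`` and ``butler`` are assumed test
    fixtures, ``quantum`` would normally come from `makeQuantum`, and
    ``exposure`` is a hypothetical parameter of ``task.run``:

    >>> run = runTestQuantum(task, butler, quantum)  # doctest: +SKIP
    >>> run.assert_called_once()  # doctest: +SKIP
    >>> inputExposure = run.call_args.kwargs["exposure"]  # doctest: +SKIP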
310 """
311 butlerQc = QuantumContext(butler, quantum)
312 # This is a type ignore, because `connections` is a dynamic class, but
313 # it for sure will have this property
314 connections = task.config.ConnectionsClass(config=task.config) # type: ignore
315 inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
316 if mockRun:
317 with (
318 unittest.mock.patch.object(task, "run") as mock,
319 unittest.mock.patch("lsst.pipe.base.QuantumContext.put"),
320 ):
321 task.runQuantum(butlerQc, inputRefs, outputRefs)
322 return mock
323 else:
324 task.runQuantum(butlerQc, inputRefs, outputRefs)
325 return None


def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # name
    try:
        attrValue = getattr(obj, attrName)
    except AttributeError:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}") from None
    # multiple
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # Evaluate lazily so that StorageClassFactory is consulted only
        # when necessary.
        if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
            StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
        ):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # No check of the storage class itself, since persistence may depend
    # on duck typing.


def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task``'s
        connections.
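
    Examples
    --------
    A sketch of the intended pattern; ``task`` and the ``exposure`` argument
    to ``run`` are hypothetical test fixtures:

    >>> result = task.run(exposure)  # doctest: +SKIP
    >>> assertValidOutput(task, result)  # doctest: +SKIP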
384 """
385 # This is a type ignore, because `connections` is a dynamic class, but
386 # it for sure will have this property
387 connections = task.config.ConnectionsClass(config=task.config) # type: ignore
389 for name in connections.outputs:
390 connection = connections.__getattribute__(name)
391 _assertAttributeMatchesConnection(result, name, connection)


def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task``'s
        connections.
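
    Examples
    --------
    A sketch of typical use; ``MyTask``, ``config``, and ``initInputs`` are
    hypothetical test fixtures:

    >>> task = MyTask(config=config, initInputs=initInputs)  # doctest: +SKIP
    >>> assertValidInitOutput(task)  # doctest: +SKIP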
407 """
408 # This is a type ignore, because `connections` is a dynamic class, but
409 # it for sure will have this property
410 connections = task.config.ConnectionsClass(config=task.config) # type: ignore
412 for name in connections.initOutputs:
413 connection = connections.__getattribute__(name)
414 _assertAttributeMatchesConnection(task, name, connection)


def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
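
    Examples
    --------
    A sketch combining this helper with task construction; ``butler``,
    ``config``, and ``MyTask`` are hypothetical test fixtures:

    >>> initInputs = getInitInputs(butler, config)  # doctest: +SKIP
    >>> task = MyTask(config=config, initInputs=initInputs)  # doctest: +SKIP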
434 """
435 connections = config.connections.ConnectionsClass(config=config)
436 initInputs = {}
437 for name in connections.initInputs:
438 attribute = getattr(connections, name)
439 # Get full dataset type to check for consistency problems
440 dsType = DatasetType(attribute.name, butler.dimensions.extract(set()), attribute.storageClass)
441 # All initInputs have empty data IDs
442 initInputs[name] = butler.get(dsType)
444 return initInputs


def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
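
    Examples
    --------
    A typical use in a unit test; ``MyConnections`` is a hypothetical
    connections class (note that the class itself, not an instance,
    is passed):

    >>> lintConnections(MyConnections)  # doctest: +SKIP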
475 """
476 # Since all comparisons are inside the class, don't bother
477 # normalizing skypix.
478 quantumDimensions = connections.dimensions
480 errors = ""
481 # connectionTypes.DimensionedConnection is implementation detail,
482 # don't use it.
483 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs):
484 connection: DimensionedConnection = connections.allConnections[name] # type: ignore
485 connDimensions = set(connection.dimensions)
486 if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
487 errors += (
488 f"Connection {name} may be called with multiple values of "
489 f"{connDimensions - quantumDimensions} but has multiple=False.\n"
490 )
491 if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
492 errors += (
493 f"Connection {name} has multiple=True but can only be called with one "
494 f"value of {connDimensions} for each {quantumDimensions}.\n"
495 )
496 if errors:
497 raise AssertionError(errors)