Coverage for python/lsst/pipe/base/testUtils.py: 13%
130 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-15 02:49 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = [
25 "assertValidInitOutput",
26 "assertValidOutput",
27 "getInitInputs",
28 "lintConnections",
29 "makeQuantum",
30 "runTestQuantum",
31]
34import collections.abc
35import itertools
36import unittest.mock
37from collections import defaultdict
38from collections.abc import Mapping, Sequence, Set
39from typing import TYPE_CHECKING, Any
41from lsst.daf.butler import (
42 Butler,
43 DataCoordinate,
44 DataId,
45 DatasetRef,
46 DatasetType,
47 Dimension,
48 DimensionUniverse,
49 Quantum,
50 SkyPixDimension,
51 StorageClassFactory,
52)
53from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection
55from .butlerQuantumContext import ButlerQuantumContext
57if TYPE_CHECKING:
58 from .config import PipelineTaskConfig
59 from .connections import PipelineTaskConnections
60 from .pipelineTask import PipelineTask
61 from .struct import Struct
def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, DataId | Sequence[DataId]],
) -> Quantum:
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId: any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataIds``.

    Raises
    ------
    ValueError
        Raised if the quantum or connection data IDs are invalid.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    try:
        _checkDimensionsMatch(butler.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    # Inputs and outputs are resolved identically; share one code path.
    inputs = _resolveConnectionRefs(
        butler, connections, itertools.chain(connections.inputs, connections.prerequisiteInputs), ioDataIds
    )
    outputs = _resolveConnectionRefs(butler, connections, connections.outputs, ioDataIds)
    quantum = Quantum(
        taskClass=type(task),
        dataId=DataCoordinate.standardize(dataId, universe=butler.dimensions),
        inputs=inputs,
        outputs=outputs,
    )
    return quantum


def _resolveConnectionRefs(
    butler: Butler,
    connections: PipelineTaskConnections,
    names: collections.abc.Iterable[str],
    ioDataIds: Mapping[str, DataId | Sequence[DataId]],
) -> defaultdict[DatasetType, list[DatasetRef]]:
    """Resolve the dataset refs for a set of connections.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection the refs point to.
    connections : `lsst.pipe.base.PipelineTaskConnections`
        The instantiated connections holding each named connection.
    names : `collections.abc.Iterable` [`str`]
        The connection names to resolve.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping from connection name to data ID(s).

    Returns
    -------
    refs : `collections.defaultdict` [`~lsst.daf.butler.DatasetType`, \
            `list` [`~lsst.daf.butler.DatasetRef`]]
        The resolved refs, grouped by dataset type.

    Raises
    ------
    ValueError
        Raised if any connection's data ID(s) are missing or invalid.
    """
    refs: defaultdict[DatasetType, list[DatasetRef]] = defaultdict(list)
    for name in names:
        try:
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            # `dataId` instead of `id` to avoid shadowing the builtin.
            for dataId in _normalizeDataIds(ioDataIds[name]):
                ref = _refFromConnection(butler, connection, dataId)
                refs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    return refs
def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Set[str] | Set[Dimension],
    actual: Set[str] | Set[Dimension],
) -> None:
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    expectedNames = _simplify(universe, expected)
    actualNames = _simplify(universe, actual)
    if expectedNames != actualNames:
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")
155def _simplify(universe: DimensionUniverse, dimensions: Set[str] | Set[Dimension]) -> set[str]:
156 """Reduce a set of dimensions to a string-only form.
158 Parameters
159 ----------
160 universe : `lsst.daf.butler.DimensionUniverse`
161 The set of all known dimensions.
162 dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
163 A set of dimensions to simplify.
165 Returns
166 -------
167 dimensions : `Set` [`str`]
168 A copy of ``dimensions`` reduced to string form, with all spatial
169 dimensions simplified to ``skypix``.
170 """
171 simplified: set[str] = set()
172 for dimension in dimensions:
173 # skypix not a real Dimension, handle it first
174 if dimension == "skypix":
175 simplified.add(dimension) # type: ignore
176 else:
177 # Need a Dimension to test spatialness
178 fullDimension = universe[dimension] if isinstance(dimension, str) else dimension
179 if isinstance(fullDimension, SkyPixDimension):
180 simplified.add("skypix")
181 else:
182 simplified.add(fullDimension.name)
183 return simplified
186def _checkDataIdMultiplicity(name: str, dataIds: DataId | Sequence[DataId], multiple: bool) -> None:
187 """Test whether data IDs are scalars for scalar connections and sequences
188 for multiple connections.
190 Parameters
191 ----------
192 name : `str`
193 The name of the connection being tested.
194 dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
195 The data ID(s) provided for the connection.
196 multiple : `bool`
197 The ``multiple`` field of the connection.
199 Raises
200 ------
201 ValueError
202 Raised if ``dataIds`` and ``multiple`` do not match.
203 """
204 if multiple:
205 if not isinstance(dataIds, collections.abc.Sequence):
206 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
207 else:
208 # DataCoordinate is a Mapping
209 if not isinstance(dataIds, collections.abc.Mapping):
210 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")
213def _normalizeDataIds(dataIds: DataId | Sequence[DataId]) -> Sequence[DataId]:
214 """Represent both single and multiple data IDs as a list.
216 Parameters
217 ----------
218 dataIds : any data ID type or `~collections.abc.Sequence` thereof
219 The data ID(s) provided for a particular input or output connection.
221 Returns
222 -------
223 normalizedIds : `~collections.abc.Sequence` [data ID]
224 A sequence equal to ``dataIds`` if it was already a sequence, or
225 ``[dataIds]`` if it was a single ID.
226 """
227 if isinstance(dataIds, collections.abc.Sequence):
228 return dataIds
229 else:
230 return [dataIds]
def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dataset type is not registered, the butler has no
        default run, or the data ID is incompatible with the dataset type.
    """
    universe = butler.dimensions
    # DatasetRef only tests if required dimension is missing, but not extras
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # Guard the registry lookup itself; previously the unguarded first lookup
    # could raise KeyError while a redundant second lookup was wrapped.
    try:
        datasetType = butler.registry.getDatasetType(connection.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    if not butler.run:
        raise ValueError("Can not create a resolved DatasetRef since the butler has no default run defined.")
    try:
        # Reuse an existing registry entry for this dataset if there is one;
        # otherwise register a new resolved ref in the default run.
        registry_ref = butler.registry.findDataset(datasetType, dataId, collections=[butler.run])
        if registry_ref:
            ref = registry_ref
        else:
            ref = DatasetRef(datasetType=datasetType, dataId=dataId, run=butler.run)
            butler.registry._importDatasets([ref])
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e
def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> unittest.mock.Mock | None:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    butlerQc = ButlerQuantumContext(butler, quantum)
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if not mockRun:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None
    # Patch ``put`` as well so nothing is written to the butler while
    # ``run`` is mocked.
    with unittest.mock.patch.object(task, "run") as mock:
        with unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
        return mock
321def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
322 """Test that an attribute on an object matches the specification given in
323 a connection.
325 Parameters
326 ----------
327 obj
328 An object expected to contain the attribute ``attrName``.
329 attrName : `str`
330 The name of the attribute to be tested.
331 connection : `lsst.pipe.base.connectionTypes.BaseConnection`
332 The connection, usually some type of output, specifying ``attrName``.
334 Raises
335 ------
336 AssertionError:
337 Raised if ``obj.attrName`` does not match what's expected
338 from ``connection``.
339 """
340 # name
341 try:
342 attrValue = obj.__getattribute__(attrName)
343 except AttributeError:
344 raise AssertionError(f"No such attribute on {obj!r}: {attrName}")
345 # multiple
346 if connection.multiple:
347 if not isinstance(attrValue, collections.abc.Sequence):
348 raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
349 else:
350 # use lazy evaluation to not use StorageClassFactory unless
351 # necessary
352 if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
353 StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
354 ):
355 raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
356 # no test for storageClass, as I'm not sure how much persistence
357 # depends on duck-typing
def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task's``
        connections.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for name in connections.outputs:
        _assertAttributeMatchesConnection(result, name, getattr(connections, name))
def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from ``task's``
        connections.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for name in connections.initOutputs:
        _assertAttributeMatchesConnection(task, name, getattr(connections, name))
def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    initInputs = {}
    for name in connections.initInputs:
        connection = getattr(connections, name)
        # Build the full dataset type so the butler can detect consistency
        # problems with its registered definition.
        datasetType = DatasetType(connection.name, butler.dimensions.extract(set()), connection.storageClass)
        # All initInputs have empty data IDs.
        initInputs[name] = butler.get(datasetType)

    return initInputs
def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # All comparisons stay inside the class, so skypix need not
    # be normalized.
    quantumDimensions = connections.dimensions

    problems: list[str] = []
    # connectionTypes.DimensionedConnection is implementation detail,
    # don't use it.
    allNames = itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs)
    for name in allNames:
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDimensions = set(connection.dimensions)
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            problems.append(
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            problems.append(
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDimensions} for each {quantumDimensions}.\n"
            )
    if problems:
        raise AssertionError("".join(problems))