Coverage for python/lsst/pipe/base/testUtils.py: 13%
147 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-04 09:31 +0000
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-04 09:31 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = [
25 "assertValidInitOutput",
26 "assertValidOutput",
27 "getInitInputs",
28 "lintConnections",
29 "makeQuantum",
30 "runTestQuantum",
31]
34import collections.abc
35import itertools
36import unittest.mock
37import warnings
38from collections import defaultdict
39from typing import TYPE_CHECKING, AbstractSet, Any, Dict, Mapping, Optional, Sequence, Set, Union
41from lsst.daf.butler import (
42 Butler,
43 DataCoordinate,
44 DataId,
45 DatasetRef,
46 DatasetType,
47 Dimension,
48 DimensionUniverse,
49 Quantum,
50 SkyPixDimension,
51 StorageClassFactory,
52 UnresolvedRefWarning,
53)
54from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection
56from .butlerQuantumContext import ButlerQuantumContext
58if TYPE_CHECKING: 58 ↛ 59line 58 didn't jump to line 59, because the condition on line 58 was never true
59 from .config import PipelineTaskConfig
60 from .connections import PipelineTaskConnections
61 from .pipelineTask import PipelineTask
62 from .struct import Struct
def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, Union[DataId, Sequence[DataId]]],
) -> Quantum:
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId: any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataIds``.

    Raises
    ------
    ValueError
        Raised if the quantum or connection data IDs are invalid or
        inconsistent with ``task``'s connections.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    try:
        _checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    # Inputs (regular and prerequisite) and outputs are validated and
    # converted to refs identically; share one implementation.
    inputs = _makeRefsForConnections(
        butler, connections, itertools.chain(connections.inputs, connections.prerequisiteInputs), ioDataIds
    )
    outputs = _makeRefsForConnections(butler, connections, connections.outputs, ioDataIds)
    quantum = Quantum(
        taskClass=type(task),
        dataId=DataCoordinate.standardize(dataId, universe=butler.registry.dimensions),
        inputs=inputs,
        outputs=outputs,
    )
    return quantum


def _makeRefsForConnections(
    butler: Butler,
    connections: PipelineTaskConnections,
    names: collections.abc.Iterable[str],
    ioDataIds: Mapping[str, Union[DataId, Sequence[DataId]]],
) -> Mapping[DatasetType, Sequence[DatasetRef]]:
    """Validate data IDs for the named connections and convert them to refs.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection the refs point to.
    connections : `lsst.pipe.base.PipelineTaskConnections`
        The connections instance providing the connection definitions.
    names : iterable [`str`]
        The connection names to process; each must be a key of ``ioDataIds``.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    refs : `collections.defaultdict` [`~lsst.daf.butler.DatasetType`]
        A mapping from dataset type to the refs created for that type.

    Raises
    ------
    ValueError
        Raised if any data ID is missing, has the wrong multiplicity, or is
        incompatible with its connection.
    """
    refs: defaultdict = defaultdict(list)
    for name in names:
        try:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            for dataId in _normalizeDataIds(ioDataIds[name]):
                ref = _refFromConnection(butler, connection, dataId)
                refs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    return refs
def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Union[AbstractSet[str], AbstractSet[Dimension]],
    actual: Union[AbstractSet[str], AbstractSet[Dimension]],
) -> None:
    """Test whether two sets of dimensions agree after conversions.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions expected from a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions provided by input.

    Raises
    ------
    ValueError
        Raised if ``expected`` and ``actual`` cannot be reconciled.
    """
    expectedNames = _simplify(universe, expected)
    actualNames = _simplify(universe, actual)
    if expectedNames != actualNames:
        raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")
def _simplify(
    universe: DimensionUniverse, dimensions: Union[AbstractSet[str], AbstractSet[Dimension]]
) -> Set[str]:
    """Reduce a set of dimensions to a string-only form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        A set of dimensions to simplify.

    Returns
    -------
    dimensions : `Set` [`str`]
        A copy of ``dimensions`` reduced to string form, with all spatial
        dimensions simplified to ``skypix``.
    """
    names: Set[str] = set()
    for dim in dimensions:
        # skypix not a real Dimension, handle it first
        if dim == "skypix":
            names.add(dim)  # type: ignore
            continue
        # Need a Dimension to test spatialness
        resolved = universe[dim] if isinstance(dim, str) else dim
        names.add("skypix" if isinstance(resolved, SkyPixDimension) else resolved.name)
    return names
189def _checkDataIdMultiplicity(name: str, dataIds: Union[DataId, Sequence[DataId]], multiple: bool) -> None:
190 """Test whether data IDs are scalars for scalar connections and sequences
191 for multiple connections.
193 Parameters
194 ----------
195 name : `str`
196 The name of the connection being tested.
197 dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
198 The data ID(s) provided for the connection.
199 multiple : `bool`
200 The ``multiple`` field of the connection.
202 Raises
203 ------
204 ValueError
205 Raised if ``dataIds`` and ``multiple`` do not match.
206 """
207 if multiple:
208 if not isinstance(dataIds, collections.abc.Sequence):
209 raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
210 else:
211 # DataCoordinate is a Mapping
212 if not isinstance(dataIds, collections.abc.Mapping):
213 raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")
216def _normalizeDataIds(dataIds: Union[DataId, Sequence[DataId]]) -> Sequence[DataId]:
217 """Represent both single and multiple data IDs as a list.
219 Parameters
220 ----------
221 dataIds : any data ID type or `~collections.abc.Sequence` thereof
222 The data ID(s) provided for a particular input or output connection.
224 Returns
225 -------
226 normalizedIds : `~collections.abc.Sequence` [data ID]
227 A sequence equal to ``dataIds`` if it was already a sequence, or
228 ``[dataIds]`` if it was a single ID.
229 """
230 if isinstance(dataIds, collections.abc.Sequence):
231 return dataIds
232 else:
233 return [dataIds]
def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dataset type is not registered, or if ``dataId`` is
        incompatible with the connection's dimensions.
    """
    universe = butler.registry.dimensions
    # DatasetRef only tests if required dimension is missing, but not extras
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # Wrap the registry lookup itself: the original code let this call raise
    # a raw KeyError and then redundantly re-looked-up the already-fetched
    # type inside a try/except that could never fire.
    try:
        datasetType = butler.registry.getDatasetType(connection.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UnresolvedRefWarning)
            ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e
def _resolveTestQuantumInputs(butler: Butler, quantum: Quantum) -> None:
    """Look up all input datasets a test quantum in the `Registry` to resolve
    all `DatasetRef` objects (i.e. ensure they have not-`None` ``id`` and
    ``run`` attributes).

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance; its input refs are resolved in place.

    Raises
    ------
    ValueError
        Raised if an unresolved ref cannot be found in the registry.
    """
    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refs in quantum.inputs.values():
        resolved = []
        for ref in refs:
            if ref.id is not None:
                # Already resolved; keep as-is.
                resolved.append(ref)
                continue
            found = butler.registry.findDataset(ref.datasetType, ref.dataId, collections=butler.collections)
            if found is None:
                raise ValueError(
                    f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                    f"in collections {butler.collections}."
                )
            resolved.append(found)
        # Replace the contents in place so the Quantum sees the change.
        refs[:] = resolved
def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> Optional[unittest.mock.Mock]:
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    _resolveTestQuantumInputs(butler, quantum)
    butlerQc = ButlerQuantumContext.from_full(butler, quantum)
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if not mockRun:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None
    # Also patch out ButlerQuantumContext.put so that a mocked run (whose
    # return value is itself a Mock) does not try to write real datasets.
    with unittest.mock.patch.object(task, "run") as mockedRun:
        with unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mockedRun
356def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
357 """Test that an attribute on an object matches the specification given in
358 a connection.
360 Parameters
361 ----------
362 obj
363 An object expected to contain the attribute ``attrName``.
364 attrName : `str`
365 The name of the attribute to be tested.
366 connection : `lsst.pipe.base.connectionTypes.BaseConnection`
367 The connection, usually some type of output, specifying ``attrName``.
369 Raises
370 ------
371 AssertionError:
372 Raised if ``obj.attrName`` does not match what's expected
373 from ``connection``.
374 """
375 # name
376 try:
377 attrValue = obj.__getattribute__(attrName)
378 except AttributeError:
379 raise AssertionError(f"No such attribute on {obj!r}: {attrName}")
380 # multiple
381 if connection.multiple:
382 if not isinstance(attrValue, collections.abc.Sequence):
383 raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
384 else:
385 # use lazy evaluation to not use StorageClassFactory unless
386 # necessary
387 if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
388 StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
389 ):
390 raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
391 # no test for storageClass, as I'm not sure how much persistence
392 # depends on duck-typing
def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError:
        Raised if ``result`` does not match what's expected from ``task's``
        connections.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for outputName in connections.outputs:
        outputConnection = connections.__getattribute__(outputName)
        _assertAttributeMatchesConnection(result, outputName, outputConnection)
def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError:
        Raised if ``task`` does not have the state expected from ``task's``
        connections.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    for initName in connections.initOutputs:
        initConnection = connections.__getattribute__(initName)
        _assertAttributeMatchesConnection(task, initName, initConnection)
def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> Dict[str, Any]:
    """Return the initInputs object that would have been passed to a
    `~lsst.pipe.base.PipelineTask` constructor.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)
    initInputs = {}
    for name in connections.initInputs:
        connection = getattr(connections, name)
        # Build the full dataset type to check for consistency problems;
        # all initInputs have empty data IDs, hence the empty dimension set.
        datasetType = DatasetType(
            connection.name, butler.registry.dimensions.extract(set()), connection.storageClass
        )
        initInputs[name] = butler.get(datasetType)

    return initInputs
def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Inspect a connections class for common errors.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection.
    """
    # Since all comparisons are inside the class, don't bother
    # normalizing skypix.
    quantumDimensions = connections.dimensions

    problems = []
    # connectionTypes.DimensionedConnection is implementation detail,
    # don't use it.
    allNames = itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs)
    for name in allNames:
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDimensions = set(connection.dimensions)
        # Extra dimensions on a single connection mean several datasets can
        # match one quantum.
        if checkMissingMultiple and not connection.multiple and connDimensions > quantumDimensions:
            problems.append(
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
        # Dimensions no broader than the quantum's mean at most one match.
        if checkUnnecessaryMultiple and connection.multiple and connDimensions <= quantumDimensions:
            problems.append(
                f"Connection {name} has multiple=True but can only be called with one "
                f"value of {connDimensions} for each {quantumDimensions}.\n"
            )
    if problems:
        raise AssertionError("".join(problems))