Coverage for python/lsst/pipe/base/testUtils.py: 12%
146 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-23 23:10 +0000
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-23 23:10 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = [
25 "assertValidInitOutput",
26 "assertValidOutput",
27 "getInitInputs",
28 "lintConnections",
29 "makeQuantum",
30 "runTestQuantum",
31]
34import collections.abc
35import itertools
36import unittest.mock
37from collections import defaultdict
38from typing import TYPE_CHECKING, AbstractSet, Any, Dict, Mapping, Optional, Sequence, Set, Union
40from lsst.daf.butler import (
41 Butler,
42 DataCoordinate,
43 DataId,
44 DatasetRef,
45 DatasetType,
46 Dimension,
47 DimensionUniverse,
48 Quantum,
49 SkyPixDimension,
50 StorageClassFactory,
51)
52from lsst.pipe.base.connectionTypes import BaseConnection, DimensionedConnection
54from .butlerQuantumContext import ButlerQuantumContext
56if TYPE_CHECKING: 56 ↛ 57line 56 didn't jump to line 57, because the condition on line 56 was never true
57 from .config import PipelineTaskConfig
58 from .connections import PipelineTaskConnections
59 from .pipelineTask import PipelineTask
60 from .struct import Struct
def makeQuantum(
    task: PipelineTask,
    butler: Butler,
    dataId: DataId,
    ioDataIds: Mapping[str, Union[DataId, Sequence[DataId]]],
) -> Quantum:
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId: any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``ioDataIds``.

    Raises
    ------
    ValueError
        Raised if ``dataId`` does not match the dimensions of the task's
        connections class, or if the data ID(s) given for any connection
        raise `ValueError` or `KeyError` while being checked and converted
        into dataset references (for example a connection name missing from
        ``ioDataIds`` or a scalar/sequence mismatch).
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    try:
        _checkDimensionsMatch(butler.registry.dimensions, connections.dimensions, dataId.keys())
    except ValueError as e:
        raise ValueError("Error in quantum dimensions.") from e

    # Prerequisite inputs are ordinary inputs as far as the Quantum is
    # concerned, so they share one mapping.
    inputs = _collectRefs(
        butler,
        connections,
        itertools.chain(connections.inputs, connections.prerequisiteInputs),
        ioDataIds,
    )
    outputs = _collectRefs(butler, connections, connections.outputs, ioDataIds)
    return Quantum(
        taskClass=type(task),
        dataId=DataCoordinate.standardize(dataId, universe=butler.registry.dimensions),
        inputs=inputs,
        outputs=outputs,
    )


def _collectRefs(
    butler: Butler,
    connections: PipelineTaskConnections,
    names: collections.abc.Iterable[str],
    ioDataIds: Mapping[str, Union[DataId, Sequence[DataId]]],
) -> Mapping[DatasetType, Sequence[DatasetRef]]:
    """Build dataset references for a group of connections.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection the references point to.
    connections : `lsst.pipe.base.PipelineTaskConnections`
        The connections object the names are looked up on.
    names : `~collections.abc.Iterable` [`str`]
        The names of the connections to process.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping from connection name to a data ID (single connections) or
        a sequence of data IDs (multiple connections).

    Returns
    -------
    refs : `collections.defaultdict`
        The references created for ``names``, grouped into lists keyed by
        dataset type.

    Raises
    ------
    ValueError
        Raised if processing any connection raises `ValueError` or
        `KeyError`; the original exception is chained as the cause.
    """
    refs = defaultdict(list)
    for name in names:
        try:
            connection = getattr(connections, name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            for connectionDataId in _normalizeDataIds(ioDataIds[name]):
                ref = _refFromConnection(butler, connection, connectionDataId)
                refs[ref.datasetType].append(ref)
        except (ValueError, KeyError) as e:
            raise ValueError(f"Error in connection {name}.") from e
    return refs
def _checkDimensionsMatch(
    universe: DimensionUniverse,
    expected: Union[AbstractSet[str], AbstractSet[Dimension]],
    actual: Union[AbstractSet[str], AbstractSet[Dimension]],
) -> None:
    """Verify that two collections of dimensions are equivalent once both
    have been reduced to simplified string form.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions.
    expected : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions required by a task specification.
    actual : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions supplied by the caller.

    Raises
    ------
    ValueError
        Raised if the simplified forms of ``expected`` and ``actual``
        differ.
    """
    wanted = _simplify(universe, expected)
    found = _simplify(universe, actual)
    if wanted == found:
        return
    # Report the caller's original arguments, not the simplified sets.
    raise ValueError(f"Mismatch in dimensions; expected {expected} but got {actual}.")
def _simplify(
    universe: DimensionUniverse, dimensions: Union[AbstractSet[str], AbstractSet[Dimension]]
) -> Set[str]:
    """Convert a collection of dimensions into a set of plain names.

    Parameters
    ----------
    universe : `lsst.daf.butler.DimensionUniverse`
        The set of all known dimensions; used to look up dimensions that
        are given by name.
    dimensions : `Set` [`str`] or `Set` [`~lsst.daf.butler.Dimension`]
        The dimensions to convert.

    Returns
    -------
    dimensions : `Set` [`str`]
        A new set holding the name of each input dimension, except that
        every `~lsst.daf.butler.SkyPixDimension` (and the literal
        ``"skypix"``) is represented by the single name ``skypix``.
    """
    names: Set[str] = set()
    for dim in dimensions:
        # "skypix" is not a real Dimension and cannot be looked up in the
        # universe, so it has to be dealt with before anything else.
        if dim == "skypix":
            names.add(dim)  # type: ignore
            continue
        # A full Dimension object is needed to tell whether it is a skypix.
        resolved = universe[dim] if isinstance(dim, str) else dim
        names.add("skypix" if isinstance(resolved, SkyPixDimension) else resolved.name)
    return names
def _checkDataIdMultiplicity(name: str, dataIds: Union[DataId, Sequence[DataId]], multiple: bool) -> None:
    """Verify that the data ID(s) given for a connection have the container
    shape the connection calls for.

    Parameters
    ----------
    name : `str`
        The name of the connection being tested; only used in error
        messages.
    dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
        The data ID(s) provided for the connection.
    multiple : `bool`
        The ``multiple`` field of the connection.

    Raises
    ------
    ValueError
        Raised if ``multiple`` is set but ``dataIds`` is not a sequence, or
        if ``multiple`` is unset but ``dataIds`` is not a mapping.
    """
    if multiple and not isinstance(dataIds, collections.abc.Sequence):
        raise ValueError(f"Expected multiple data IDs for {name}, got {dataIds}.")
    # A single data ID is recognized by being a Mapping (DataCoordinate is
    # a Mapping).
    if not multiple and not isinstance(dataIds, collections.abc.Mapping):
        raise ValueError(f"Expected single data ID for {name}, got {dataIds}.")
def _normalizeDataIds(dataIds: Union[DataId, Sequence[DataId]]) -> Sequence[DataId]:
    """Return data ID(s) in sequence form regardless of how they were given.

    Parameters
    ----------
    dataIds : any data ID type or `~collections.abc.Sequence` thereof
        The data ID(s) provided for a particular input or output connection.

    Returns
    -------
    normalizedIds : `~collections.abc.Sequence` [data ID]
        ``dataIds`` itself if it is already a sequence, otherwise a
        one-element list containing it.
    """
    alreadySequence = isinstance(dataIds, collections.abc.Sequence)
    return dataIds if alreadySequence else [dataIds]
def _refFromConnection(
    butler: Butler, connection: DimensionedConnection, dataId: DataId, **kwargs: Any
) -> DatasetRef:
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dimensions of ``dataId`` do not match those of
        ``connection``, if the connection's dataset type cannot be looked up
        in the registry, or if constructing the reference from the dataset
        type and data ID fails with `KeyError`.
    """
    universe = butler.registry.dimensions
    # DatasetRef only tests if required dimension is missing, but not extras
    _checkDimensionsMatch(universe, set(connection.dimensions), dataId.keys())
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # skypix is a PipelineTask alias for "some spatial index", Butler doesn't
    # understand it. Code copied from TaskDatasetTypes.fromTaskDef
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    # The lookup result is discarded; the call only verifies that the
    # dataset type is known to the registry.
    try:
        butler.registry.getDatasetType(datasetType.name)
    except KeyError as e:
        # Chain the cause so the failed registry lookup stays visible.
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        return DatasetRef(datasetType=datasetType, dataId=dataId)
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") from e
def _resolveTestQuantumInputs(butler: Butler, quantum: Quantum) -> None:
    """Look up all input datasets of a test quantum in the `Registry` to
    resolve all `DatasetRef` objects (i.e. ensure they have not-`None`
    ``id`` and ``run`` attributes).

    The per-dataset-type reference lists held by ``quantum.inputs`` are
    updated in place.

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler.
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.

    Raises
    ------
    ValueError
        Raised if an unresolved input cannot be found in the butler's
        collections.
    """
    # TODO (DM-26819): This function is a direct copy of
    # `lsst.ctrl.mpexec.SingleQuantumExecutor.updateQuantumInputs`, but the
    # `runTestQuantum` function that calls it is essentially duplicating logic
    # in that class as well (albeit not verbatim). We should probably move
    # `SingleQuantumExecutor` to ``pipe_base`` and see if it is directly usable
    # in test code instead of having these classes at all.
    for refs in quantum.inputs.values():
        resolved = []
        for ref in refs:
            # References that already carry an ID are kept as they are.
            if ref.id is not None:
                resolved.append(ref)
                continue
            found = butler.registry.findDataset(
                ref.datasetType, ref.dataId, collections=butler.collections
            )
            if found is None:
                raise ValueError(
                    f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                    f"in collections {butler.collections}."
                )
            resolved.append(found)
        # Slice assignment replaces the contents of the existing list, so
        # the quantum sees the resolved references.
        refs[:] = resolved
def runTestQuantum(
    task: PipelineTask, butler: Butler, quantum: Quantum, mockRun: bool = True
) -> Optional[unittest.mock.Mock]:
    """Execute a PipelineTask's ``runQuantum`` on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    _resolveTestQuantumInputs(butler, quantum)
    quantumContext = ButlerQuantumContext(butler, quantum)
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)

    if not mockRun:
        task.runQuantum(quantumContext, inputRefs, outputRefs)
        return None

    # Both ``run`` and the quantum context's ``put`` are patched for the
    # duration of the ``runQuantum`` call.
    runPatcher = unittest.mock.patch.object(task, "run")
    putPatcher = unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put")
    with runPatcher as runMock, putPatcher:
        task.runQuantum(quantumContext, inputRefs, outputRefs)
        return runMock
def _assertAttributeMatchesConnection(obj: Any, attrName: str, connection: BaseConnection) -> None:
    """Test that an attribute on an object matches the specification given in
    a connection.

    Parameters
    ----------
    obj
        An object expected to contain the attribute ``attrName``.
    attrName : `str`
        The name of the attribute to be tested.
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        The connection, usually some type of output, specifying ``attrName``.

    Raises
    ------
    AssertionError
        Raised if ``obj.attrName`` does not match what's expected
        from ``connection``.
    """
    # name
    # Use ordinary attribute access so that attributes supplied through
    # ``__getattr__`` are found as well; calling ``__getattribute__``
    # directly would skip that fallback.
    try:
        attrValue = getattr(obj, attrName)
    except AttributeError as e:
        raise AssertionError(f"No such attribute on {obj!r}: {attrName}") from e
    # multiple
    if connection.multiple:
        if not isinstance(attrValue, collections.abc.Sequence):
            raise AssertionError(f"Expected {attrName} to be a sequence, got {attrValue!r} instead.")
    else:
        # use lazy evaluation to not use StorageClassFactory unless
        # necessary
        if isinstance(attrValue, collections.abc.Sequence) and not issubclass(
            StorageClassFactory().getStorageClass(connection.storageClass).pytype, collections.abc.Sequence
        ):
            raise AssertionError(f"Expected {attrName} to be a single value, got {attrValue!r} instead.")
    # no test for storageClass, as I'm not sure how much persistence
    # depends on duck-typing
def assertValidOutput(task: PipelineTask, result: Struct) -> None:
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from the
        connections of ``task``.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    # Each output connection must appear as a same-named attribute of
    # ``result``.
    for name in connections.outputs:
        connection = getattr(connections, name)
        _assertAttributeMatchesConnection(result, name, connection)
def assertValidInitOutput(task: PipelineTask) -> None:
    """Test that a constructed task conforms to its own init-connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation.

    Raises
    ------
    AssertionError
        Raised if ``task`` does not have the state expected from its own
        connections.
    """
    # This is a type ignore, because `connections` is a dynamic class, but
    # it for sure will have this property
    connections = task.config.ConnectionsClass(config=task.config)  # type: ignore

    # Each init-output connection must appear as a same-named attribute of
    # the task itself.
    for name in connections.initOutputs:
        connection = getattr(connections, name)
        _assertAttributeMatchesConnection(task, name, connection)
def getInitInputs(butler: Butler, config: PipelineTaskConfig) -> Dict[str, Any]:
    """Load the ``initInputs`` that a `~lsst.pipe.base.PipelineTask`
    constructor would have received.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The repository to search for input datasets. Must have
        pre-configured collections.
    config : `lsst.pipe.base.PipelineTaskConfig`
        The config for the task to be constructed.

    Returns
    -------
    initInputs : `dict` [`str`]
        A dictionary of objects in the format of the ``initInputs`` parameter
        to `lsst.pipe.base.PipelineTask`.
    """
    connections = config.connections.ConnectionsClass(config=config)

    def fullDatasetType(connection: Any) -> DatasetType:
        # Build the full dataset type to check for consistency problems;
        # the dimensions are extracted from an empty set because all
        # initInputs have empty data IDs.
        return DatasetType(
            connection.name, butler.registry.dimensions.extract(set()), connection.storageClass
        )

    return {
        name: butler.get(fullDatasetType(getattr(connections, name))) for name in connections.initInputs
    }
def lintConnections(
    connections: PipelineTaskConnections,
    *,
    checkMissingMultiple: bool = True,
    checkUnnecessaryMultiple: bool = True,
) -> None:
    """Check a connections class for common mistakes.

    These tests are designed to detect misuse of connections features in
    standard designs. An unusually designed connections class may trigger
    alerts despite being correctly written; specific checks can be turned off
    using keywords.

    Parameters
    ----------
    connections : `lsst.pipe.base.PipelineTaskConnections`-type
        The connections class to test.
    checkMissingMultiple : `bool`
        Whether to test for single connections that would match multiple
        datasets at run time.
    checkUnnecessaryMultiple : `bool`
        Whether to test for multiple connections that would only match
        one dataset.

    Raises
    ------
    AssertionError
        Raised if any of the selected checks fail for any connection. The
        message lists every failure, one per line.
    """
    # Since all comparisons are inside the class, don't bother
    # normalizing skypix.
    quantumDimensions = connections.dimensions

    problems = []
    names = itertools.chain(connections.inputs, connections.prerequisiteInputs, connections.outputs)
    # connectionTypes.DimensionedConnection is implementation detail,
    # don't use it.
    for name in names:
        connection: DimensionedConnection = connections.allConnections[name]  # type: ignore
        connDimensions = set(connection.dimensions)
        if connection.multiple:
            # A multiple connection with no dimensions beyond the quantum's.
            if checkUnnecessaryMultiple and connDimensions <= quantumDimensions:
                problems.append(
                    f"Connection {name} has multiple=True but can only be called with one "
                    f"value of {connDimensions} for each {quantumDimensions}.\n"
                )
        elif checkMissingMultiple and connDimensions > quantumDimensions:
            # A single connection whose dimensions strictly extend the
            # quantum's.
            problems.append(
                f"Connection {name} may be called with multiple values of "
                f"{connDimensions - quantumDimensions} but has multiple=False.\n"
            )
    if problems:
        raise AssertionError("".join(problems))