Coverage for python/lsst/pipe/base/connections.py: 46%
225 statements
coverage.py v6.5.0, created at 2023-02-14 02:24 -0800

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

import itertools
import string
import typing
from collections import UserDict
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, ClassVar, Dict, Iterable, List, Set, Union

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

from ._status import NoWorkFound
from .connectionTypes import (
    BaseConnection,
    BaseInput,
    InitInput,
    InitOutput,
    Input,
    Output,
    PrerequisiteInput,
)

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by ``PipelineTaskConnectionsMetaclass``.

    This dict is used in PipelineTaskConnection class creation, as the
    dictionary that is initially used as ``__dict__``. It exists to
    intercept connection fields declared in a PipelineTaskConnection and to
    record the names used to identify them. The names are then added to a
    class-level list according to the connection type of the class
    attribute. The names are also used as keys in a class-level dictionary
    associated with the corresponding class attribute. This information is a
    duplicate of what exists in ``__dict__``, but provides a simple place to
    look up and iterate over only these variables.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection
        self.data["inputs"] = []
        self.data["prerequisiteInputs"] = []
        self.data["outputs"] = []
        self.data["initInputs"] = []
        self.data["initOutputs"] = []
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, Input):
            self.data["inputs"].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data["prerequisiteInputs"].append(name)
        elif isinstance(value, Output):
            self.data["outputs"].append(name)
        elif isinstance(value, InitInput):
            self.data["initInputs"].append(name)
        elif isinstance(value, InitOutput):
            self.data["initOutputs"].append(name)
        # This should not be an elif, as it needs to be tested for
        # everything that inherits from BaseConnection
        if isinstance(value, BaseConnection):
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
        # defer to the default behavior
        super().__setitem__(name, value)
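
# Illustrative sketch (not part of the module): declaring connection
# attributes in a PipelineTaskConnections subclass routes each attribute
# name through PipelineTaskConnectionDict.__setitem__ and into the
# class-level containers above.  The connection arguments and dataset type
# names below are placeholder assumptions, not values taken from this
# package.
#
#     class DemoConnections(PipelineTaskConnections, dimensions=("visit",)):
#         calexp = Input(doc="input image", dimensions=("visit",),
#                        storageClass="ExposureF", name="calexp")
#         srcCat = Output(doc="output catalog", dimensions=("visit",),
#                         storageClass="SourceCatalog", name="src")
#
#     # After class creation:
#     #   "calexp" in DemoConnections.inputs        -> True
#     #   "srcCat" in DemoConnections.outputs       -> True
#     #   set(DemoConnections.allConnections)       -> {"calexp", "srcCat"}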


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes"""

    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection
        # Copy any existing connections from a parent class
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in class
            # declaration
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be iterable of dimensions, got str; possibly omitted trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])

            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnection class contains templated attribute names, but no "
                    "default templates were provided, add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do not
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection, this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with Class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses
        super().__init__(name, bases, dct)


class QuantizedConnection(SimpleNamespace):
    """A namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]) -> None:
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef, typing.List[DatasetRef]]], None, None]:
        """Make an iterator for this QuantizedConnection.

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        QuantizedConnection class.
        """
        yield from self._attributes
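
# Illustrative sketch (not part of the module): a QuantizedConnection is a
# small namespace whose items can be counted and iterated.  The attribute
# name "calexp" and ``some_dataset_ref`` are placeholder assumptions.
#
#     refs = QuantizedConnection()
#     refs.calexp = some_dataset_ref        # an lsst.daf.butler.DatasetRef
#     len(refs)                             # -> 1
#     list(refs.keys())                     # -> ["calexp"]
#     dict(refs)                            # -> {"calexp": some_dataset_ref}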


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `DatasetRef` that indicates that a `PipelineTask`
    should receive a `DeferredDatasetHandle` instead of an in-memory dataset.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
    """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId
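
# Illustrative sketch (not part of the module): when a connection is
# declared with ``deferLoad=True``, ``buildDatasetRefs`` below wraps each
# ref in a DeferredDatasetRef so the executor knows to hand the task a
# `DeferredDatasetHandle` rather than the loaded dataset.
# ``some_dataset_ref`` is a placeholder assumption.
#
#     deferred = DeferredDatasetRef(datasetRef=some_dataset_ref)
#     deferred.datasetType    # same as some_dataset_ref.datasetType
#     deferred.dataId         # same as some_dataset_ref.dataId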


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class.
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``.
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)``
      where X matches the ``storageClass`` type defined on the output
      connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    dimensions: ClassVar[Set[str]]

    def __init__(self, *, config: "PipelineTaskConfig" | None = None):
        self.inputs: Set[str] = set(self.inputs)
        self.prerequisiteInputs: Set[str] = set(self.prerequisiteInputs)
        self.outputs: Set[str] = set(self.outputs)
        self.initInputs: Set[str] = set(self.initInputs)
        self.initOutputs: Set[str] = set(self.initOutputs)
        self.allConnections: Dict[str, BaseConnection] = dict(self.allConnections)

        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {
            name: getattr(config.connections, name) for name in getattr(self, "defaultTemplates").keys()
        }
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {
            name: getattr(config.connections, name).format(**templateValues)
            for name in self.allConnections.keys()
        }

        # connection.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> typing.Tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build QuantizedConnections corresponding to an input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`, `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip(
            (inputDatasetRefs, outputDatasetRefs),
            (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
        ):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    quantumInputRefs: Union[List[DatasetRef], List[DeferredDatasetRef]]
                    if attribute.deferLoad:
                        quantumInputRefs = [
                            DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                        ]
                    else:
                        quantumInputRefs = list(quantum.inputs[attribute.name])
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                "Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        setattr(refs, attributeName, quantumInputRefs[0])
                    else:
                        # Add to the QuantizedConnection identifier
                        setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        setattr(refs, attributeName, value[0])
                    else:
                        setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; don't know
                # how to handle, throw
                else:
                    raise ValueError(
                        f"Attribute with name {attributeName} has no counterpart in input quantum"
                    )
        return inputDatasetRefs, outputDatasetRefs
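
    # Illustrative sketch (not part of the module): given a Quantum from the
    # graph-generation code, buildDatasetRefs splits its refs into the two
    # namespaces consumed by PipelineTask.runQuantum.  ``connections`` and
    # ``quantum`` are placeholder assumptions for an instantiated
    # PipelineTaskConnections object and an lsst.daf.butler.Quantum.
    #
    #     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #     # Scalar connections appear as single DatasetRefs, connections with
    #     # multiple=True as lists, and deferLoad connections are wrapped in
    #     # DeferredDatasetRef.
    #     for name, refs in inputRefs:
    #         ...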

    def adjustQuantum(
        self,
        inputs: typing.Dict[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        outputs: typing.Dict[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> typing.Tuple[
        typing.Mapping[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        typing.Mapping[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
    ]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that are
            true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `DatasetRef` objects. Connections that are not changed
            should not be returned at all. Datasets may only be removed, not
            added. Nested collections may be of any multi-pass iterable type,
            and the order of iteration will set the order of iteration within
            `PipelineTask.runQuantum`.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the message
            included in this exception); not enough datasets were found for a
            `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.::

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={input_connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose
                    # inputs are either already present or should be created
                    # during execution. It can trigger during execution if
                    # the input wasn't actually created by an upstream task
                    # in the same graph.
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(
    connections: PipelineTaskConnections, connectionType: Union[str, Iterable[str]]
) -> typing.Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str`
        The type of connections to iterate over; valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `BaseConnection`.
    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)
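
# Illustrative sketch (not part of the module): iterate over every input
# connection, including prerequisites, of an instantiated connections
# object.  ``connections`` is a placeholder assumption for a
# PipelineTaskConnections instance.
#
#     for connection in iterConnections(connections, ("inputs", "prerequisiteInputs")):
#         print(connection.name, connection.storageClass)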


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds ``inputs`` and ``outputs`` mappings in the form used by
    `Quantum` and execution harness code, i.e. with `DatasetType` keys,
    translating them to and from the connection-oriented mappings used
    inside `PipelineTaskConnections`.
    """

    inputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
    """Mapping of output datasets, grouped by `DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to
    `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: typing.Dict[str, typing.Tuple[BaseInput, typing.Tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: typing.Dict[str, typing.Tuple[Output, typing.Tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in itertools.chain(connections.outputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior
        # we want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but it's not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = list(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = list(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False
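
# Illustrative sketch (not part of the module): an execution harness would
# typically build the helper from a quantum's ref mappings and then apply
# the task's adjustQuantum logic in place.  ``quantum``, ``connections``,
# and ``label`` are placeholder assumptions supplied by the caller.
#
#     helper = AdjustQuantumHelper(
#         inputs=NamedKeyDict({t: list(refs) for t, refs in quantum.inputs.items()}).freeze(),
#         outputs=NamedKeyDict({t: list(refs) for t, refs in quantum.outputs.items()}).freeze(),
#     )
#     helper.adjust_in_place(connections, label, quantum.dataId)
#     if helper.inputs_adjusted or helper.outputs_adjusted:
#         ...  # rebuild the Quantum from helper.inputs / helper.outputs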