Coverage for python/lsst/pipe/base/connections.py: 47%
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

import itertools
import string
import typing
from collections import UserDict
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, ClassVar, Dict, Iterable, List, Set, Union

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

from ._status import NoWorkFound
from .connectionTypes import (
    BaseConnection,
    BaseInput,
    InitInput,
    InitOutput,
    Input,
    Output,
    PrerequisiteInput,
)

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by ``PipelineTaskConnectionsMetaclass``.

    This dict is used in ``PipelineTaskConnections`` class creation, as the
    dictionary that is initially used as ``__dict__``. It exists to
    intercept connection fields declared in a ``PipelineTaskConnections``
    class and to record the names used to identify them. The names are then
    added to a class-level list according to the connection type of the
    class attribute. The names are also used as keys in a class-level
    dictionary associated with the corresponding class attribute. This
    information duplicates what exists in ``__dict__``, but provides a
    simple place to look up and iterate over only these variables.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class-level variables used to track any declared
        # class-level variables that are instances of
        # connectionTypes.BaseConnection
        self.data["inputs"] = []
        self.data["prerequisiteInputs"] = []
        self.data["outputs"] = []
        self.data["initInputs"] = []
        self.data["initOutputs"] = []
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, Input):
            self.data["inputs"].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data["prerequisiteInputs"].append(name)
        elif isinstance(value, Output):
            self.data["outputs"].append(name)
        elif isinstance(value, InitInput):
            self.data["initInputs"].append(name)
        elif isinstance(value, InitOutput):
            self.data["initOutputs"].append(name)
        # This should not be an elif, as it needs to be tested for
        # everything that inherits from BaseConnection
        if isinstance(value, BaseConnection):
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
        # Defer to the default behavior
        super().__setitem__(name, value)
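
# Illustrative sketch (not part of the module): assigning a connection
# instance into a PipelineTaskConnectionDict records the attribute name both
# in the per-type list and in ``allConnections``. The connection arguments
# below are placeholder values.
#
#     d = PipelineTaskConnectionDict()
#     d["calexp"] = Input(doc="example", name="calexp",
#                         storageClass="ExposureF",
#                         dimensions=("instrument", "visit", "detector"))
#     assert d["inputs"] == ["calexp"]
#     assert d["allConnections"]["calexp"].varName == "calexp"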


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes"""

    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in the class
            # declaration
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be iterable of dimensions, got str, "
                        "possibly omitted trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # Add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])

            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnections class contains templated attribute names, but no "
                    "default templates were provided, add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # Our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with the
        # Python documentation on metaclasses.
        super().__init__(name, bases, dct)
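
# Illustrative sketch (not part of the module): the metaclass consumes the
# ``dimensions`` and ``defaultTemplates`` keyword arguments supplied in the
# class statement itself; the dimension and template names below are
# placeholder values.
#
#     class ExampleConnections(PipelineTaskConnections,
#                              dimensions=("instrument", "visit"),
#                              defaultTemplates={"coaddName": "deep"}):
#         coadd = Input(doc="example templated input",
#                       name="{coaddName}Coadd",
#                       storageClass="ExposureF",
#                       dimensions=("skymap", "tract", "patch"))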


class QuantizedConnection(SimpleNamespace):
    """A namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]) -> None:
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(
        self,
    ) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef, typing.List[DatasetRef]]], None, None]:
        """Make an iterator for this QuantizedConnection.

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        QuantizedConnection class.
        """
        yield from self._attributes
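
# Illustrative sketch (not part of the module): a QuantizedConnection is a
# namespace whose attributes can also be iterated dict-style; ``someRef`` is
# a placeholder for a DatasetRef.
#
#     qc = InputQuantizedConnection()
#     qc.calexp = someRef
#     assert set(qc.keys()) == {"calexp"}
#     for name, refs in qc:
#         print(name, refs)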


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `DatasetRef` that indicates that a `PipelineTask`
    should receive a `DeferredDatasetHandle` instead of an in-memory dataset.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
    """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId
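
# Illustrative sketch (not part of the module): DeferredDatasetRef forwards
# the key properties of the wrapped ref, so graph code can treat it much
# like a plain DatasetRef; ``someRef`` is a placeholder.
#
#     deferred = DeferredDatasetRef(datasetRef=someRef)
#     assert deferred.dataId == someRef.dataId
#     assert deferred.datasetType == someRef.datasetType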


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class.
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``.
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)``
      where X matches the ``storageClass`` type defined on the output
      connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, which is an iterable of strings
    that defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    dimensions: ClassVar[Set[str]]

    def __init__(self, *, config: typing.Optional["PipelineTaskConfig"] = None):
        self.inputs: Set[str] = set(self.inputs)
        self.prerequisiteInputs: Set[str] = set(self.prerequisiteInputs)
        self.outputs: Set[str] = set(self.outputs)
        self.initInputs: Set[str] = set(self.initInputs)
        self.initOutputs: Set[str] = set(self.initOutputs)
        self.allConnections: Dict[str, BaseConnection] = dict(self.allConnections)

        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {
            name: getattr(config.connections, name) for name in getattr(self, "defaultTemplates").keys()
        }
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {
            name: getattr(config.connections, name).format(**templateValues)
            for name in self.allConnections.keys()
        }

        # connections.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}
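
    # Illustrative sketch (not part of the module): the name overrides built
    # above are plain str.format substitutions, so a configured template
    # value rewrites every connection name that references it; the names
    # below are placeholder values.
    #
    #     templateValues = {"coaddName": "goodSeeing"}
    #     assert "{coaddName}Coadd".format(**templateValues) == "goodSeeingCoadd"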

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> typing.Tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build QuantizedConnections corresponding to the input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
                `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # Operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip(
            (inputDatasetRefs, outputDatasetRefs),
            (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
        ):
            # Get the name of a class connection attribute
            for attributeName in names:
                # Get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # If the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    quantumInputRefs: Union[List[DatasetRef], List[DeferredDatasetRef]]
                    if attribute.deferLoad:
                        quantumInputRefs = [
                            DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                        ]
                    else:
                        quantumInputRefs = list(quantum.inputs[attribute.name])
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                "Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        setattr(refs, attributeName, quantumInputRefs[0])
                    else:
                        # Add to the QuantizedConnection identifier
                        setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        setattr(refs, attributeName, value[0])
                    else:
                        setattr(refs, attributeName, value)
                # The specified attribute is in neither inputs nor outputs;
                # we don't know how to handle it, so raise
                else:
                    raise ValueError(
                        f"Attribute with name {attributeName} has no counterpart in input quantum"
                    )
        return inputDatasetRefs, outputDatasetRefs
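
    # Illustrative sketch (not part of the module): execution harness code
    # typically splits a Quantum into the two namespaces and fetches inputs
    # by attribute name; ``config``, ``quantum``, and ``butlerQC`` are
    # placeholders for objects supplied by the harness.
    #
    #     connections = ExampleConnections(config=config)
    #     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #     calexp = butlerQC.get(inputRefs.calexp)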

    def adjustQuantum(
        self,
        inputs: typing.Dict[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        outputs: typing.Dict[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> typing.Tuple[
        typing.Mapping[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        typing.Mapping[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
    ]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.Quantum` during the graph generation
        stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that are
            true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `DatasetRef` objects. Connections that are not changed
            should not be returned at all. Datasets may only be removed, not
            added. Nested collections may be of any multi-pass iterable type,
            and the order of iteration will set the order of iteration within
            `PipelineTask.runQuantum`.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the message
            included in this exception); not enough datasets were found for a
            `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.::

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={input_connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose
                    # inputs are either already present or should be created
                    # during execution. It can trigger during execution if
                    # the input wasn't actually created by an upstream task
                    # in the same graph.
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(
    connections: PipelineTaskConnections, connectionType: Union[str, Iterable[str]]
) -> typing.Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or iterable of `str`
        The type of connections to iterate over; valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a type derived from
        `BaseConnection`.
    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)
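
# Illustrative sketch (not part of the module): iterConnections accepts one
# category name or several, which is convenient for walking every input-like
# connection at once; ``connections`` is a placeholder instance.
#
#     for connection in iterConnections(connections,
#                                       ("inputs", "prerequisiteInputs")):
#         print(connection.varName, connection.name)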


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds ``inputs`` and ``outputs`` mappings in the form used by
    `Quantum` and execution harness code, i.e. with `DatasetType` keys,
    translating them to and from the connection-oriented mappings used inside
    `PipelineTaskConnections`.
    """

    inputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
    """Mapping of output datasets, grouped by `DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: typing.Dict[str, typing.Tuple[BaseInput, typing.Tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: typing.Dict[str, typing.Tuple[Output, typing.Tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in connections.outputs:
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior
        # we want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but it's not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = list(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = list(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False
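
# Illustrative sketch (not part of the module): harness code builds a helper
# from a Quantum's mappings and lets adjustQuantum prune them in place;
# ``quantum``, ``connections``, and ``task_label`` are placeholders.
#
#     helper = AdjustQuantumHelper(inputs=quantum.inputs,
#                                  outputs=quantum.outputs)
#     helper.adjust_in_place(connections, task_label, quantum.dataId)
#     if helper.inputs_adjusted or helper.outputs_adjusted:
#         ...  # rebuild the Quantum from helper.inputs / helper.outputs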