Coverage for python/lsst/pipe/base/connections.py: 44%
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""

from __future__ import annotations
__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]
from collections import UserDict, namedtuple
from dataclasses import dataclass
from types import SimpleNamespace
import typing
from typing import Union, Iterable

import itertools
import string

from . import config as configMod
from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput,
                              Output, BaseConnection, BaseInput)
from ._status import NoWorkFound
from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum
if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by PipelineTaskConnectionsMetaclass.

    This dict is used in PipelineTaskConnections class creation, as the
    dictionary that is initially used as ``__dict__``. It exists to
    intercept connection fields declared in a PipelineTaskConnections class
    and to record the names used to identify them. The names are then added
    to a class-level list according to the connection type of the class
    attribute. The names are also used as keys in a class-level dictionary
    associated with the corresponding class attribute. This information is
    a duplicate of what exists in ``__dict__``, but provides a simple place
    to look up and iterate over only these variables.
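
    For example (an illustrative sketch; the `Input` constructor arguments
    are those defined in `lsst.pipe.base.connectionTypes`, and the names
    here are hypothetical)::

        d = PipelineTaskConnectionDict()
        d["calexp"] = Input(doc="Example input", dimensions=("visit",),
                            storageClass="ExposureF", name="calexp")
        assert d["inputs"] == ["calexp"]
        assert "calexp" in d["allConnections"]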
75 """
76 def __init__(self, *args, **kwargs):
77 super().__init__(*args, **kwargs)
78 # Initialize class level variables used to track any declared
79 # class level variables that are instances of
80 # connectionTypes.BaseConnection
81 self.data['inputs'] = []
82 self.data['prerequisiteInputs'] = []
83 self.data['outputs'] = []
84 self.data['initInputs'] = []
85 self.data['initOutputs'] = []
86 self.data['allConnections'] = {}
88 def __setitem__(self, name, value):
89 if isinstance(value, Input):
90 self.data['inputs'].append(name)
91 elif isinstance(value, PrerequisiteInput):
92 self.data['prerequisiteInputs'].append(name)
93 elif isinstance(value, Output):
94 self.data['outputs'].append(name)
95 elif isinstance(value, InitInput):
96 self.data['initInputs'].append(name)
97 elif isinstance(value, InitOutput):
98 self.data['initOutputs'].append(name)
99 # This should not be an elif, as it needs tested for
100 # everything that inherits from BaseConnection
101 if isinstance(value, BaseConnection):
102 object.__setattr__(value, 'varName', name)
103 self.data['allConnections'][name] = value
104 # defer to the default behavior
105 super().__setitem__(name, value)


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes.
110 """
111 def __prepare__(name, bases, **kwargs): # noqa: 805
112 # Create an instance of our special dict to catch and track all
113 # variables that are instances of connectionTypes.BaseConnection
114 # Copy any existing connections from a parent class
115 dct = PipelineTaskConnectionDict()
116 for base in bases:
117 if isinstance(base, PipelineTaskConnectionsMetaclass): 117 ↛ 116line 117 didn't jump to line 116, because the condition on line 117 was never false
118 for name, value in base.allConnections.items(): 118 ↛ 119line 118 didn't jump to line 119, because the loop on line 118 never started
119 dct[name] = value
120 return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in class
            # declaration
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs['dimensions'], str):
                    raise TypeError("Dimensions must be iterable of dimensions, got str, "
                                    "possibly omitted trailing comma")
                if not isinstance(kwargs['dimensions'], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Lookup any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])
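            # For example, a connection name of "{input}Coadd_calexp" would
            # contribute the template parameter "input" to allTemplates here.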
            # look up any template from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "default templates were provided, add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do not
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection, this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(f"Template parameters cannot share names with Class attributes"
                                    f" (conflicts are {nameTemplateIntersection}).")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses
        super().__init__(name, bases, dct)


class QuantizedConnection(SimpleNamespace):
    """A namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                           typing.List[DatasetRef]]], None, None]:
        """Make an iterator for this QuantizedConnection.

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but iterates over the namespace attributes
        rather than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Return an iterator over all the attributes added to this
        QuantizedConnection instance.
        """
        yield from self._attributes


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Class which denotes that a datasetRef should be treated as deferred
    when interacting with the butler.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
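
    Notes
    -----
    An illustrative sketch (``someDatasetRef`` is a hypothetical
    `lsst.daf.butler.DatasetRef`)::

        deferred = DeferredDatasetRef(datasetRef=someDatasetRef)
        deferred.datasetRef  # the wrapped `lsst.daf.butler.DatasetRef`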
268 """
269 __slots__ = ()


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to declare
      the class attribute must match a function argument name in the ``run``
      method of a `PipelineTask`. E.g. if the ``PipelineTaskConnections``
      defines an ``Input`` with the name ``calexp``, then the corresponding
      signature should be ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
      execution of a `PipelineTask`. The name used to declare the connection
      must correspond to an attribute of a `Struct` that is returned by a
      `PipelineTask` ``run`` method. E.g. if an output connection is
      defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)`` where
      ``X`` matches the ``storageClass`` type defined on the output
      connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings which
    defines the unit of processing the ``run`` method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name parameters
    on connections defined in a `PipelineTaskConnections` class contain a
    template, then a default template value must be specified in the
    ``defaultTemplates`` argument. This is done by passing a dictionary with
    keys corresponding to a template identifier, and values corresponding to
    the value to use as a default when formatting the string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding `PipelineTask`
    will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    def __init__(self, *, config: 'PipelineTaskConfig' = None):
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)
        self.allConnections = dict(self.allConnections)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name)
                          for name in getattr(self, 'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}
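        # For example, with template values {"foo": "Modified"} a configured
        # connection name of "{foo}Dataset" is resolved here to
        # "ModifiedDataset" (mirroring the doctest in the class docstring).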

        # connections.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}

    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Build QuantizedConnections corresponding to an input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`, `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
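
        Notes
        -----
        A hypothetical usage sketch (the connection names are illustrative)::

            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            calexpRef = inputRefs.calexp   # scalar connection -> single ref
            srcRefs = inputRefs.srcCat     # multiple=True -> list of refs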
424 """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip((inputDatasetRefs, outputDatasetRefs),
                               (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # Get the DatasetRefs
                    quantumInputRefs = quantum.inputs[attribute.name]
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    if attribute.deferLoad:
                        quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                f"Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        quantumInputRefs = quantumInputRefs[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        value = value[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; don't know
                # how to handle it, so raise
                else:
                    raise ValueError(f"Attribute with name {attributeName} has no counterpart "
                                     "in input quantum")
        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(
        self,
        inputs: typing.Dict[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        outputs: typing.Dict[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> typing.Tuple[typing.Mapping[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
                      typing.Mapping[str, typing.Tuple[Output, typing.Collection[DatasetRef]]]]:
483 """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
484 in the `lsst.daf.butler.core.Quantum` during the graph generation stage
485 of the activator.
487 Parameters
488 ----------
489 inputs : `dict`
490 Dictionary whose keys are an input (regular or prerequisite)
491 connection name and whose values are a tuple of the connection
492 instance and a collection of associated `DatasetRef` objects.
493 The exact type of the nested collections is unspecified; it can be
494 assumed to be multi-pass iterable and support `len` and ``in``, but
495 it should not be mutated in place. In contrast, the outer
496 dictionaries are guaranteed to be temporary copies that are true
497 `dict` instances, and hence may be modified and even returned; this
498 is especially useful for delegating to `super` (see notes below).
499 outputs : `Mapping`
500 Mapping of output datasets, with the same structure as ``inputs``.
501 label : `str`
502 Label for this task in the pipeline (should be used in all
503 diagnostic messages).
504 data_id : `lsst.daf.butler.DataCoordinate`
505 Data ID for this quantum in the pipeline (should be used in all
506 diagnostic messages).
508 Returns
509 -------
510 adjusted_inputs : `Mapping`
511 Mapping of the same form as ``inputs`` with updated containers of
512 input `DatasetRef` objects. Connections that are not changed
513 should not be returned at all. Datasets may only be removed, not
514 added. Nested collections may be of any multi-pass iterable type,
515 and the order of iteration will set the order of iteration within
516 `PipelineTask.runQuantum`.
517 adjusted_outputs : `Mapping`
518 Mapping of updated output datasets, with the same structure and
519 interpretation as ``adjusted_inputs``.
521 Raises
522 ------
523 ScalarError
524 Raised if any `Input` or `PrerequisiteInput` connection has
525 ``multiple`` set to `False`, but multiple datasets.
526 NoWorkFound
527 Raised to indicate that this quantum should not be run; not enough
528 datasets were found for a regular `Input` connection, and the
529 quantum should be pruned or skipped.
530 FileNotFoundError
531 Raised to cause QuantumGraph generation to fail (with the message
532 included in this exception); not enough datasets were found for a
533 `PrerequisiteInput` connection.
535 Notes
536 -----
537 The base class implementation performs important checks. It always
538 returns an empty mapping (i.e. makes no adjustments). It should
539 always called be via `super` by custom implementations, ideally at the
540 end of the custom implementation with already-adjusted mappings when
541 any datasets are actually dropped, e.g.::
543 def adjustQuantum(self, inputs, outputs, label, data_id):
544 # Filter out some dataset refs for one connection.
545 connection, old_refs = inputs["my_input"]
546 new_refs = [ref for ref in old_refs if ...]
547 adjusted_inputs = {"my_input", (connection, new_refs)}
548 # Update the original inputs so we can pass them to super.
549 inputs.update(adjusted_inputs)
550 # Can ignore outputs from super because they are guaranteed
551 # to be empty.
552 super().adjustQuantum(inputs, outputs, label_data_id)
553 # Return only the connections we modified.
554 return adjusted_inputs, {}
556 Removing outputs here is guaranteed to affect what is actually
557 passed to `PipelineTask.runQuantum`, but its effect on the larger
558 graph may be deferred to execution, depending on the context in
559 which `adjustQuantum` is being run: if one quantum removes an output
560 that is needed by a second quantum as input, the second quantum may not
561 be adjusted (and hence pruned or skipped) until that output is actually
562 found to be missing at execution time.
564 Tasks that desire zip-iteration consistency between any combinations of
565 connections that have the same data ID should generally implement
566 `adjustQuantum` to achieve this, even if they could also run that
567 logic during execution; this allows the system to see outputs that will
568 not be produced because the corresponding input is missing as early as
569 possible.
570 """
        for name, (connection, refs) in inputs.items():
            dataset_type_name = connection.name
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < connection.minimum:
                if isinstance(connection, PrerequisiteInput):
                    # This branch should only be possible during QG generation,
                    # or if someone deleted the dataset between making the QG
                    # and trying to run it. Either one should be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose inputs
                    # are either already present or should be created during
                    # execution. It can trigger during execution if the input
                    # wasn't actually created by an upstream task in the same
                    # graph.
                    raise NoWorkFound(label, name, connection)
        for name, (connection, refs) in outputs.items():
            dataset_type_name = connection.name
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(connections: PipelineTaskConnections,
                    connectionType: Union[str, Iterable[str]]
                    ) -> typing.Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or iterable of `str`
        The type of connections to iterate over; valid values are
        ``inputs``, ``outputs``, ``prerequisiteInputs``, ``initInputs``,
        and ``initOutputs``.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be of a type derived from
        `BaseConnection`.
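
    Examples
    --------
    An illustrative sketch, assuming ``connections`` is an instance of a
    `PipelineTaskConnections` subclass::

        for connection in iterConnections(connections, ("inputs", "outputs")):
            print(connection.name)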
629 """
630 if isinstance(connectionType, str):
631 connectionType = (connectionType,)
632 for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
633 yield getattr(connections, name)


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `input` and `output` mappings in the form used by
    `Quantum` and execution harness code, i.e. with `DatasetType` keys,
    translating them to and from the connection-oriented mappings used inside
    `PipelineTaskConnections`.
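
    A hypothetical usage sketch (assuming ``quantum``, ``connections``,
    ``label``, and ``data_id`` are already in hand, and that the quantum's
    input/output mappings have the `DatasetType`-keyed form this class
    expects)::

        helper = AdjustQuantumHelper(inputs=quantum.inputs,
                                     outputs=quantum.outputs)
        helper.adjust_in_place(connections, label, data_id)
        if helper.inputs_adjusted or helper.outputs_adjusted:
            ...  # rebuild the Quantum from helper.inputs / helper.outputs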
644 """
646 inputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
647 """Mapping of regular input and prerequisite input datasets, grouped by
648 `DatasetType`.
649 """
651 outputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
652 """Mapping of output datasets, grouped by `DatasetType`.
653 """
655 inputs_adjusted: bool = False
656 """Whether any inputs were removed in the last call to `adjust_in_place`.
657 """
659 outputs_adjusted: bool = False
660 """Whether any outputs were removed in the last call to `adjust_in_place`.
661 """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: typing.Dict[str, typing.Tuple[BaseInput, typing.Tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: typing.Dict[str, typing.Tuple[Output, typing.Tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (
                connection,
                tuple(self.inputs.get(dataset_type_name, ()))
            )
        for name in connections.outputs:
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (
                connection,
                tuple(self.outputs.get(dataset_type_name, ()))
            )
        # Actually call adjustQuantum.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,
            outputs_by_connection,
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = list(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = list(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False