22 """Module defining connection classes for PipelineTask.
25 __all__ = [
"PipelineTaskConnections",
"InputQuantizedConnection",
"OutputQuantizedConnection",
26 "DeferredDatasetRef",
"iterConnections"]
import itertools
import string
import typing
from collections import UserDict, namedtuple
from types import SimpleNamespace

from lsst.daf.butler import DatasetRef, DatasetType, NamedKeyDict, Quantum

from . import config as configMod
from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput,
                              Output, BaseConnection)

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig
class ScalarError(TypeError):
    """Exception raised when dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """
class PipelineTaskConnectionDict(UserDict):
    """This is a special dict class used by PipelineTaskConnectionMetaclass

    This dict is used in PipelineTaskConnection class creation, as the
    dictionary that is initially used as __dict__. It exists to
    intercept connection fields declared in a PipelineTaskConnection, and
    what name is used to identify them. The names are then added to class
    level list according to the connection type of the class attribute. The
    names are also used as keys in a class level dictionary associated with
    the corresponding class attribute. This information is a duplicate of
    what exists in __dict__, but provides a simple place to lookup and
    iterate on only these variables.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize the bookkeeping buckets that record, per connection
        # type, the attribute names declared on the class body.
        self.data['inputs'] = []
        self.data['prerequisiteInputs'] = []
        self.data['outputs'] = []
        self.data['initInputs'] = []
        self.data['initOutputs'] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        # Record the attribute name in the bucket matching its connection
        # type so the metaclass can later iterate only over connections.
        if isinstance(value, Input):
            self.data['inputs'].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data['prerequisiteInputs'].append(name)
        elif isinstance(value, Output):
            self.data['outputs'].append(name)
        elif isinstance(value, InitInput):
            self.data['initInputs'].append(name)
        elif isinstance(value, InitOutput):
            self.data['initOutputs'].append(name)
        # Stamp the connection with the variable name it was declared under
        # and remember it in the combined mapping.
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # Defer to the default dict behavior for the actual storage.
        super().__setitem__(name, value)
class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes
    """
    def __prepare__(name, bases, **kwargs):  # noqa: N804
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection
        dct = PipelineTaskConnectionDict()
        # Seed the namespace with connections declared on any connection
        # base classes, so subclasses inherit them.
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")
        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in the class
            # declaration, falling back to a base class value if present.
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs['dimensions'], str):
                    raise TypeError("Dimensions must be iterable of dimensions, got str,"
                                    "possibly omitted trailing comma")
                if not isinstance(kwargs['dimensions'], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute.
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])
            # Merge default templates from base classes (lowest priority
            # first) with any passed in the class declaration.
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])
            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict
            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the class.
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "default templates were provided, add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not.
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable
                # names used for a connection.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(f"Template parameters cannot share names with Class attributes"
                                    f" (conflicts are {nameTemplateIntersection}).")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert the connection name containers into frozensets so they
        # cannot be modified at the class scope.
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # The special dict must be converted into a plain dict before it is
        # handed to type.__new__.
        return super().__new__(cls, name, bases, dict(dct))
class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to their
    `lsst.daf.butler.DatasetRef`s

    This class maps the names used to define a connection on a
    PipelineTaskConnectionsClass to the corresponding
    `lsst.daf.butler.DatasetRef`s provided by a `lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnectionsClass`.
    """
    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                           typing.List[DatasetRef]]], None, None]:
        """Make an Iterator for this QuantizedConnection

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather than
        dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Returns an iterator over all the attributes added to a
        QuantizedConnection class
        """
        yield from self._attributes
242 class OutputQuantizedConnection(QuantizedConnection):
class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Class which denotes that a datasetRef should be treated as deferred when
    interacting with the butler

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will be eventually used to
        resolve a dataset
    """
    # namedtuple subclasses should stay free of per-instance __dict__
    __slots__ = ()
260 """PipelineTaskConnections is a class used to declare desired IO when a
261 PipelineTask is run by an activator
265 config : `PipelineTaskConfig`
266 A `PipelineTaskConfig` class instance whose class has been configured
267 to use this `PipelineTaskConnectionsClass`
271 ``PipelineTaskConnection`` classes are created by declaring class
272 attributes of types defined in `lsst.pipe.base.connectionTypes` and are
275 * ``InitInput`` - Defines connections in a quantum graph which are used as
276 inputs to the ``__init__`` function of the `PipelineTask` corresponding
278 * ``InitOutput`` - Defines connections in a quantum graph which are to be
279 persisted using a butler at the end of the ``__init__`` function of the
280 `PipelineTask` corresponding to this class. The variable name used to
281 define this connection should be the same as an attribute name on the
282 `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
283 the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
284 a `PipelineTask` instance should have an attribute
285 ``self.outputSchema`` defined. Its value is what will be saved by the
287 * ``PrerequisiteInput`` - An input connection type that defines a
288 `lsst.daf.butler.DatasetType` that must be present at execution time,
289 but that will not be used during the course of creating the quantum
290 graph to be executed. These most often are things produced outside the
291 processing pipeline, such as reference catalogs.
292 * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be used
293 in the ``run`` method of a `PipelineTask`. The name used to declare
294 class attribute must match a function argument name in the ``run``
295 method of a `PipelineTask`. E.g. If the ``PipelineTaskConnections``
296 defines an ``Input`` with the name ``calexp``, then the corresponding
297 signature should be ``PipelineTask.run(calexp, ...)``
298 * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
299 execution of a `PipelineTask`. The name used to declare the connection
300 must correspond to an attribute of a `Struct` that is returned by a
301 `PipelineTask` ``run`` method. E.g. if an output connection is
302 defined with the name ``measCat``, then the corresponding
303 ``PipelineTask.run`` method must return ``Struct(measCat=X,..)`` where
304 X matches the ``storageClass`` type defined on the output connection.
306 The process of declaring a ``PipelineTaskConnection`` class involves
307 parameters passed in the declaration statement.
309 The first parameter is ``dimensions`` which is an iterable of strings which
310 defines the unit of processing the run method of a corresponding
311 `PipelineTask` will operate on. These dimensions must match dimensions that
312 exist in the butler registry which will be used in executing the
313 corresponding `PipelineTask`.
315 The second parameter is labeled ``defaultTemplates`` and is conditionally
316 optional. The name attributes of connections can be specified as python
317 format strings, with named format arguments. If any of the name parameters
318 on connections defined in a `PipelineTaskConnections` class contain a
319 template, then a default template value must be specified in the
320 ``defaultTemplates`` argument. This is done by passing a dictionary with
321 keys corresponding to a template identifier, and values corresponding to
322 the value to use as a default when formatting the string. For example if
323 ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
324 ``defaultTemplates`` = {'input': 'deep'}.
326 Once a `PipelineTaskConnections` class is created, it is used in the
327 creation of a `PipelineTaskConfig`. This is further documented in the
328 documentation of `PipelineTaskConfig`. For the purposes of this
329 documentation, the relevant information is that the config class allows
330 configuration of connection names by users when running a pipeline.
332 Instances of a `PipelineTaskConnections` class are used by the pipeline
333 task execution framework to introspect what a corresponding `PipelineTask`
334 will require, and what it will produce.
338 >>> from lsst.pipe.base import connectionTypes as cT
339 >>> from lsst.pipe.base import PipelineTaskConnections
340 >>> from lsst.pipe.base import PipelineTaskConfig
341 >>> class ExampleConnections(PipelineTaskConnections,
342 ... dimensions=("A", "B"),
343 ... defaultTemplates={"foo": "Example"}):
344 ... inputConnection = cT.Input(doc="Example input",
345 ... dimensions=("A", "B"),
346 ... storageClass=Exposure,
347 ... name="{foo}Dataset")
348 ... outputConnection = cT.Output(doc="Example output",
349 ... dimensions=("A", "B"),
350 ... storageClass=Exposure,
351 ... name="{foo}output")
352 >>> class ExampleConfig(PipelineTaskConfig,
353 ... pipelineConnections=ExampleConnections):
355 >>> config = ExampleConfig()
356 >>> config.connections.foo = "Modified"
357 >>> config.connections.outputConnection = "TotallyDifferent"
358 >>> connections = ExampleConnections(config=config)
359 >>> assert(connections.inputConnection.name == "ModifiedDataset")
360 >>> assert(connections.outputConnection.name == "TotallyDifferent")
363 def __init__(self, *, config:
'PipelineTaskConfig' =
None):
370 if config
is None or not isinstance(config, configMod.PipelineTaskConfig):
371 raise ValueError(
"PipelineTaskConnections must be instantiated with"
372 " a PipelineTaskConfig instance")
377 templateValues = {name: getattr(config.connections, name)
for name
in getattr(self,
378 'defaultTemplates').keys()}
382 self.
_nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
383 for name
in self.allConnections.keys()}
391 OutputQuantizedConnection]:
392 """Builds QuantizedConnections corresponding to input Quantum
396 quantum : `lsst.daf.butler.Quantum`
397 Quantum object which defines the inputs and outputs for a given
402 retVal : `tuple` of (`InputQuantizedConnection`,
403 `OutputQuantizedConnection`) Namespaces mapping attribute names
404 (identifiers of connections) to butler references defined in the
405 input `lsst.daf.butler.Quantum`
411 for refs, names
in zip((inputDatasetRefs, outputDatasetRefs),
414 for attributeName
in names:
416 attribute = getattr(self, attributeName)
418 if attribute.name
in quantum.predictedInputs:
420 quantumInputRefs = quantum.predictedInputs[attribute.name]
423 if attribute.deferLoad:
427 if not attribute.multiple:
428 if len(quantumInputRefs) > 1:
430 f
"Received multiple datasets "
431 f
"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
432 f
"for scalar connection {attributeName} "
433 f
"({quantumInputRefs[0].datasetType.name}) "
434 f
"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
436 if len(quantumInputRefs) == 0:
438 quantumInputRefs = quantumInputRefs[0]
440 setattr(refs, attributeName, quantumInputRefs)
442 elif attribute.name
in quantum.outputs:
443 value = quantum.outputs[attribute.name]
446 if not attribute.multiple:
449 setattr(refs, attributeName, value)
453 raise ValueError(f
"Attribute with name {attributeName} has no counterpoint "
455 return inputDatasetRefs, outputDatasetRefs
457 def adjustQuantum(self, datasetRefMap: NamedKeyDict[DatasetType, typing.Set[DatasetRef]]
458 ) -> NamedKeyDict[DatasetType, typing.Set[DatasetRef]]:
459 """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
460 in the `lsst.daf.butler.core.Quantum` during the graph generation stage
463 The base class implementation simply checks that input connections with
464 ``multiple`` set to `False` have no more than one dataset.
468 datasetRefMap : `NamedKeyDict`
469 Mapping from dataset type to a `set` of
470 `lsst.daf.butler.DatasetRef` objects
474 datasetRefMap : `NamedKeyDict`
475 Modified mapping of input with possibly adjusted
476 `lsst.daf.butler.DatasetRef` objects.
481 Raised if any `Input` or `PrerequisiteInput` connection has
482 ``multiple`` set to `False`, but multiple datasets.
484 Overrides of this function have the option of raising an Exception
485 if a field in the input does not satisfy a need for a corresponding
486 pipelineTask, i.e. no reference catalogs are found.
490 refs = datasetRefMap[connection.name]
491 if not connection.multiple
and len(refs) > 1:
493 f
"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
494 f
"for scalar connection {connection.name} ({refs[0].datasetType.name})."
def iterConnections(connections: PipelineTaskConnections, connectionType: str) -> typing.Generator:
    """Iterate over all the connections of a given type defined on a
    connections instance.

    Parameters
    ----------
    connections: `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be iterated
        over.
    connectionType: `str`
        The type of connections to iterate over, valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection: `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `BaseConnection`.
    """
    attributeNames = getattr(connections, connectionType)
    yield from (getattr(connections, attrName) for attrName in attributeNames)