Coverage for python/lsst/pipe/base/connections.py : 52%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining connection classes for PipelineTask.
23"""
25__all__ = ["PipelineTaskConnections", "InputQuantizedConnection", "OutputQuantizedConnection",
26 "DeferredDatasetRef", "iterConnections"]
28from collections import UserDict, namedtuple
29from types import SimpleNamespace
30import typing
31from typing import Union, Iterable
33import itertools
34import string
36from . import config as configMod
37from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput,
38 Output, BaseConnection)
39from lsst.daf.butler import DatasetRef, DatasetType, NamedKeyDict, Quantum
41if typing.TYPE_CHECKING: 41 ↛ 42line 41 didn't jump to line 42, because the condition on line 41 was never true
42 from .config import PipelineTaskConfig
class ScalarError(TypeError):
    """Raised when a connection is declared as scalar (``multiple=False``)
    but the `~lsst.daf.butler.Quantum` supplies more than one data ID for
    that dataset type.
    """
class PipelineTaskConnectionDict(UserDict):
    """Special dict class used by PipelineTaskConnectionsMetaclass.

    This dict serves as the initial ``__dict__`` (class namespace) during
    PipelineTaskConnection class creation. It intercepts connection fields
    declared on a PipelineTaskConnection and records the names used to
    identify them: each name is appended to a class-level list chosen by the
    connection type of the attribute, and is also used as a key in a
    class-level dictionary mapping to the attribute itself. This duplicates
    information in ``__dict__``, but gives a simple place to look up and
    iterate over only these variables.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Seed the class-level containers used to track declared
        # connectionTypes.BaseConnection attributes, one list per
        # connection category plus a catch-all mapping.
        for trackerKey in ('inputs', 'prerequisiteInputs', 'outputs',
                           'initInputs', 'initOutputs'):
            self.data[trackerKey] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        # Categorize the value; the pairs are checked in the same order as
        # the original if/elif chain so subclass relationships among the
        # connection types resolve identically.
        for connectionClass, trackerKey in ((Input, 'inputs'),
                                            (PrerequisiteInput, 'prerequisiteInputs'),
                                            (Output, 'outputs'),
                                            (InitInput, 'initInputs'),
                                            (InitOutput, 'initOutputs')):
            if isinstance(value, connectionClass):
                self.data[trackerKey].append(name)
                break
        # Every BaseConnection (regardless of category above) records its
        # variable name and is registered in allConnections.
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # defer to the default dict behavior
        super().__setitem__(name, value)
class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes.

    ``__prepare__`` supplies a `PipelineTaskConnectionDict` as the class
    namespace so connection attributes are tracked as they are declared;
    ``__new__`` validates the ``dimensions`` and ``defaultTemplates`` class
    keyword arguments and freezes the connection-name containers.
    """
    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                # Use a distinct loop variable so the ``name`` parameter is
                # not shadowed.
                for connectionName, connection in base.allConnections.items():
                    dct[connectionName] = connection
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        # The base PipelineTaskConnections class itself is exempt from the
        # dimensions/templates requirements below.
        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in class
            # declaration; fall back to the first base class that defines
            # them.
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                # A bare str is iterable, so check for it explicitly: a
                # trailing comma was probably omitted in the declaration.
                if isinstance(kwargs['dimensions'], str):
                    raise TypeError("Dimensions must be iterable of dimensions, got str,"
                                    "possibly omitted trailing comma")
                if not isinstance(kwargs['dimensions'], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute.
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections, collecting every named format field.
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # Formatter.parse yields (literal, field_name, spec, conv);
                # field_name is None for literal-only segments.
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together, later bases and then this class taking precedence.
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class.
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "default templates were provided, add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not.
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection, this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(f"Template parameters cannot share names with Class attributes"
                                    f" (conflicts are {nameTemplateIntersection}).")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope.
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses.
        super().__init__(name, bases, dct)
class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to their
    `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    PipelineTaskConnectionsClass to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance.  This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnectionsClass`.
    """
    def __init__(self, **kwargs):
        # Track which attributes have been assigned, for iteration later.
        # NOTE(review): any keyword arguments are accepted but not applied
        # as attributes here — confirm callers never rely on them.
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
        # Record the attribute name, then delegate the actual assignment.
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        # Remove the attribute itself and forget that it was ever tracked.
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                           typing.List[DatasetRef]]], None, None]:
        """Iterate over ``(name, value)`` pairs of the tracked attributes.

        This is similar to dict.items(), but operates on the namespace
        attributes rather than dict keys.
        """
        for attrName in self._attributes:
            yield attrName, getattr(self, attrName)

    def keys(self) -> typing.Generator[str, None, None]:
        """Yield the names of all attributes added to this
        QuantizedConnection instance.
        """
        for attrName in self._attributes:
            yield attrName
class InputQuantizedConnection(QuantizedConnection):
    """Namespace of input dataset references for a quantum; populated by
    `PipelineTaskConnections.buildDatasetRefs` from the ``inputs`` and
    ``prerequisiteInputs`` connections.
    """
    pass
class OutputQuantizedConnection(QuantizedConnection):
    """Namespace of output dataset references for a quantum; populated by
    `PipelineTaskConnections.buildDatasetRefs` from the ``outputs``
    connections.
    """
    pass
class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Class which denotes that a datasetRef should be treated as deferred
    when interacting with the butler.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will be eventually used to
        resolve a dataset
    """
    # No per-instance __dict__; instances carry only the single tuple field.
    __slots__ = ()
class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnectionsClass`

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to declare
      class attribute must match a function argument name in the ``run``
      method of a `PipelineTask`. E.g. If the ``PipelineTaskConnections``
      defines an ``Input`` with the name ``calexp``, then the corresponding
      signature should be ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
      execution of a `PipelineTask`. The name used to declare the connection
      must correspond to an attribute of a `Struct` that is returned by a
      `PipelineTask` ``run`` method. E.g. if an output connection is
      defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X,..)`` where
      X matches the ``storageClass`` type defined on the output connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions`` which is an iterable of strings
    which defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name parameters
    on connections defined in a `PipelineTaskConnections` class contain a
    template, then a default template value must be specified in the
    ``defaultTemplates`` argument. This is done by passing a dictionary with
    keys corresponding to a template identifier, and values corresponding to
    the value to use as a default when formatting the string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates`` = {'input': 'deep'}.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding `PipelineTask`
    will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass=Exposure,
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass=Exposure,
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert(connections.inputConnection.name == "ModifiedDataset")
    >>> assert(connections.outputConnection.name == "TotallyDifferent")
    """

    def __init__(self, *, config: typing.Optional['PipelineTaskConfig'] = None):
        # Copy the class-level frozensets (built by the metaclass) into
        # mutable per-instance containers.
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)
        self.allConnections = dict(self.allConnections)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name) for name in getattr(self,
                          'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}

        # connections.name corresponds to a dataset type name, create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}

    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Builds QuantizedConnections corresponding to input Quantum

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
            `OutputQuantizedConnection`) Namespaces mapping attribute names
            (identifiers of connections) to butler references defined in the
            input `lsst.daf.butler.Quantum`

        Raises
        ------
        ScalarError
            Raised if a scalar (``multiple=False``) input connection receives
            more than one dataset reference.
        ValueError
            Raised if a connection name appears in neither the inputs nor the
            outputs of ``quantum``.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip((inputDatasetRefs, outputDatasetRefs),
                               (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # Get the DatasetRefs
                    quantumInputRefs = quantum.inputs[attribute.name]
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    if attribute.deferLoad:
                        quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                f"Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            # A scalar connection with no refs is simply left
                            # unset on the namespace.
                            continue
                        quantumInputRefs = quantumInputRefs[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        value = value[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs, don't know
                # how to handle, throw
                else:
                    raise ValueError(f"Attribute with name {attributeName} has no counterpart "
                                     "in input quantum")
        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(self, datasetRefMap: NamedKeyDict[DatasetType, typing.Set[DatasetRef]]
                      ) -> NamedKeyDict[DatasetType, typing.Set[DatasetRef]]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        The base class implementation simply checks that input connections
        with ``multiple`` set to `False` have no more than one dataset.

        Parameters
        ----------
        datasetRefMap : `NamedKeyDict`
            Mapping from dataset type to a `set` of
            `lsst.daf.butler.DatasetRef` objects

        Returns
        -------
        datasetRefMap : `NamedKeyDict`
            Modified mapping of input with possibly adjusted
            `lsst.daf.butler.DatasetRef` objects.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets.
        Exception
            Overrides of this function have the option of raising an
            Exception if a field in the input does not satisfy a need for a
            corresponding pipelineTask, i.e. no reference catalogs are found.
        """
        for connection in itertools.chain(iterConnections(self, "inputs"),
                                          iterConnections(self, "prerequisiteInputs")):
            refs = datasetRefMap[connection.name]
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for scalar connection {connection.name} ({refs[0].datasetType.name})."
                )
        return datasetRefMap
def iterConnections(connections: PipelineTaskConnections,
                    connectionType: Union[str, Iterable[str]]
                    ) -> typing.Generator[BaseConnection, None, None]:
    """Create an iterator over all defined connections of the selected
    type(s).

    Parameters
    ----------
    connections: `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType: `str` or iterable of `str`
        The type(s) of connections to iterate over; valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection: `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `BaseConnection`.
    """
    # Normalize a single type name to a one-element tuple.
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    # Walk each requested category, resolving every connection name on the
    # connections object to the connection attribute itself.
    for category in connectionType:
        for connectionName in getattr(connections, category):
            yield getattr(connections, connectionName)