lsst.pipe.base  19.0.0-14-g91c0010
connections.py
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

22 """Module defining connection classes for PipelineTask.
23 """
24 
25 __all__ = ["PipelineTaskConnections", "InputQuantizedConnection", "OutputQuantizedConnection",
26  "DeferredDatasetRef", "iterConnections"]
27 
28 from collections import UserDict, namedtuple
29 from types import SimpleNamespace
30 import typing
31 
32 import itertools
33 import string
34 
35 from . import config as configMod
36 from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput,
37  Output, BaseConnection)
38 from lsst.daf.butler import DatasetRef, Quantum
39 
40 if typing.TYPE_CHECKING:
41  from .config import PipelineTaskConfig
42 
43 
44 class ScalarError(TypeError):
45  """Exception raised when dataset type is configured as scalar
46  but there are multiple DataIds in a Quantum for that dataset.
47 
48  Parameters
49  ----------
50  key : `str`
51  Name of the configuration field for dataset type.
52  If ``numDataIds`` is not specified, it is assumed that this parameter
53  is the full message to be reported and not the key.
54  numDataIds : `int`, optional
55  Actual number of DataIds in a Quantum for this dataset type.
56  """
57  def __init__(self, key, numDataIds=None):
58  if numDataIds is None:
59  # Assume we are receiving a normal TypeError message
60  err_msg = key
61  else:
62  err_msg = f"Expected scalar for output dataset field {key}, " \
63  f"received {numDataIds} DataIds"
64  super().__init__(err_msg)
65 
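# A minimal usage sketch (illustrative only, not part of the module):
# ScalarError carries either a fully formatted message or a (field, count)
# pair that it formats into a scalar-mismatch message.
#
#     raise ScalarError("calexp", 3)
#     # TypeError: Expected scalar for output dataset field calexp,
#     # received 3 DataIds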

class PipelineTaskConnectionDict(UserDict):
    """This is a special dict class used by PipelineTaskConnectionsMetaclass

    This dict is used in PipelineTaskConnections class creation, as the
    dictionary that is initially used as __dict__. It exists to
    intercept connection fields declared in a PipelineTaskConnections class,
    and record the names used to identify them. The names are then added to a
    class level list according to the connection type of the class attribute.
    The names are also used as keys in a class level dictionary associated
    with the corresponding class attribute. This information is a duplicate
    of what exists in __dict__, but provides a simple place to lookup and
    iterate on only these variables.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection
        self.data['inputs'] = []
        self.data['prerequisiteInputs'] = []
        self.data['outputs'] = []
        self.data['initInputs'] = []
        self.data['initOutputs'] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        if isinstance(value, Input):
            self.data['inputs'].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data['prerequisiteInputs'].append(name)
        elif isinstance(value, Output):
            self.data['outputs'].append(name)
        elif isinstance(value, InitInput):
            self.data['initInputs'].append(name)
        elif isinstance(value, InitOutput):
            self.data['initOutputs'].append(name)
        # This should not be an elif, as it needs to be tested for
        # everything that inherits from BaseConnection
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # defer to the default behavior
        super().__setitem__(name, value)

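# Illustrative sketch (hypothetical values, not module code): assigning a
# connection instance into this dict both stores it and records its variable
# name, so a class body line like ``calexp = Input(...)`` lands in both
# ``data['inputs']`` and ``data['allConnections']`` via ``__setitem__``.
#
#     d = PipelineTaskConnectionDict()
#     d["calexp"] = Input(doc="demo", name="calexp",
#                         storageClass="ExposureF", dimensions=("visit",))
#     assert d["inputs"] == ["calexp"]
#     assert d["allConnections"]["calexp"].varName == "calexp"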

class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes
    """
    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection
        dct = PipelineTaskConnectionDict()
        # Copy any existing connections from a parent class
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in class
            # declaration
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Lookup any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # look up any template from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "default templates were provided, add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection, this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError("Template parameters cannot share names with class attributes")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses
        super().__init__(name, bases, dct)

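# Illustrative sketch (hypothetical class, not module code): the metaclass
# rejects templated connection names declared without defaultTemplates at
# class-creation time, before any instance exists.
#
#     class BadConnections(PipelineTaskConnections, dimensions=("visit",)):
#         calexp = Input(doc="demo", name="{prefix}Dataset",
#                        storageClass="ExposureF", dimensions=("visit",))
#     # raises TypeError: PipelineTaskConnection class contains templated
#     # attribute names, but no default templates were provided, ...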

class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to their
    `lsst.daf.butler.DatasetRef` objects

    This class maps the names used to define a connection on a
    PipelineTaskConnections class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
    """
    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                                        typing.List[DatasetRef]]], None, None]:
        """Make an Iterator for this QuantizedConnection

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Returns an iterator over all the attributes added to a
        QuantizedConnection class
        """
        yield from self._attributes

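# Illustrative sketch (hypothetical values, not module code): attribute
# assignment is tracked by __setattr__, so iteration behaves like
# dict.items() over the namespace.
#
#     qc = QuantizedConnection()
#     qc.calexp = someDatasetRef          # an lsst.daf.butler.DatasetRef
#     for name, refs in qc:
#         print(name, refs)
#     list(qc.keys())                     # -> ["calexp"]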

class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Class which denotes that a datasetRef should be treated as deferred
    when interacting with the butler

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset
    """
    __slots__ = ()

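# Illustrative sketch (not module code): execution harness code can detect a
# deferred ref by type and unwrap the underlying DatasetRef before fetching.
#
#     if isinstance(ref, DeferredDatasetRef):
#         realRef = ref.datasetRef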

class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the
      corresponding ``PipelineTask.run`` method must return
      ``Struct(measCat=X, ...)`` where X matches the ``storageClass`` type
      defined on the output connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    def __init__(self, *, config: 'PipelineTaskConfig' = None):
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name)
                          for name in getattr(self, 'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}

        # connections.name corresponds to a dataset type name, create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}

    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Builds QuantizedConnections corresponding to input Quantum

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
            `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip((inputDatasetRefs, outputDatasetRefs),
                               (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.predictedInputs:
                    # Get the DatasetRefs
                    quantumInputRefs = quantum.predictedInputs[attribute.name]
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    if attribute.deferLoad:
                        quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(attributeName, len(quantumInputRefs))
                        if len(quantumInputRefs) == 0:
                            continue
                        quantumInputRefs = quantumInputRefs[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        value = value[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; don't know
                # how to handle it, so throw
                else:
                    raise ValueError(f"Attribute with name {attributeName} has no counterpart "
                                     "in input quantum")
        return inputDatasetRefs, outputDatasetRefs

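    # Illustrative sketch (hypothetical ``connections``, ``quantum`` and
    # ``butler`` objects, not module code): an activator typically splits a
    # Quantum into the two namespaces and resolves the inputs for the task.
    #
    #     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #     inputs = {name: butler.get(ref) for name, ref in inputRefs}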
    def adjustQuantum(self, datasetRefMap: InputQuantizedConnection):
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        Parameters
        ----------
        datasetRefMap : `dict`
            Mapping with keys of dataset type name to `list` of
            `lsst.daf.butler.DatasetRef` objects

        Returns
        -------
        datasetRefMap : `dict`
            Modified mapping of input with possibly adjusted
            `lsst.daf.butler.DatasetRef` objects

        Raises
        ------
        Exception
            Overrides of this function have the option of raising an
            Exception if a field in the input does not satisfy a need for a
            corresponding pipelineTask, e.g. no reference catalogs are found.
        """
        return datasetRefMap

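    # Illustrative sketch (hypothetical subclass, not module code): an
    # override can veto quanta whose prerequisites are missing.
    #
    #     class MyConnections(PipelineTaskConnections, dimensions=("visit",)):
    #         ...
    #         def adjustQuantum(self, datasetRefMap):
    #             if not datasetRefMap.get("refCat"):
    #                 raise FileNotFoundError("No reference catalogs found")
    #             return datasetRefMap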

def iterConnections(connections: PipelineTaskConnections, connectionType: str) -> typing.Generator:
    """Creates an iterator over the selected connections type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str`
        The type of connections to iterate over, valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `BaseConnection`.
    """
    for name in getattr(connections, connectionType):
        yield getattr(connections, name)
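
# Illustrative usage sketch (hypothetical ``connections`` instance, not
# module code): resolve each declared input connection without touching the
# other connection types.
#
#     for connection in iterConnections(connections, "inputs"):
#         print(connection.name, connection.storageClass)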