Coverage for python/lsst/pipe/base/connections.py: 46%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""

__all__ = ["PipelineTaskConnections", "InputQuantizedConnection", "OutputQuantizedConnection",
           "DeferredDatasetRef", "iterConnections"]

from collections import UserDict, namedtuple
from types import SimpleNamespace
import typing

import itertools
import string

from . import config as configMod
from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput,
                              Output, BaseConnection)
from lsst.daf.butler import DatasetRef, DatasetType, NamedKeyDict, Quantum

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by PipelineTaskConnectionsMetaclass.

    This dict is used in PipelineTaskConnections class creation, as the
    dictionary that is initially used as __dict__. It exists to intercept
    connection fields declared in a PipelineTaskConnections class, and the
    names used to identify them. The names are then added to a class level
    list according to the connection type of the class attribute. The names
    are also used as keys in a class level dictionary associated with the
    corresponding class attribute. This information is a duplicate of what
    exists in __dict__, but provides a simple place to look up and iterate
    over only these variables.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection
        self.data['inputs'] = []
        self.data['prerequisiteInputs'] = []
        self.data['outputs'] = []
        self.data['initInputs'] = []
        self.data['initOutputs'] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        if isinstance(value, Input):
            self.data['inputs'].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data['prerequisiteInputs'].append(name)
        elif isinstance(value, Output):
            self.data['outputs'].append(name)
        elif isinstance(value, InitInput):
            self.data['initInputs'].append(name)
        elif isinstance(value, InitOutput):
            self.data['initOutputs'].append(name)
        # This should not be an elif, as it needs to be tested for
        # everything that inherits from BaseConnection
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # defer to the default behavior
        super().__setitem__(name, value)
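

# An illustrative sketch (not executed here, names hypothetical) of what the
# dict above records during class body evaluation; the Input signature
# mirrors the example in the PipelineTaskConnections docstring below:
#
#     dct = PipelineTaskConnectionDict()
#     dct["calexp"] = Input(doc="input", dimensions=("visit",),
#                           storageClass="Exposure", name="calexp")
#     dct["inputs"]          # -> ["calexp"]
#     dct["allConnections"]  # -> {"calexp": <the Input instance>}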


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes
    """
    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection
        # Copy any existing connections from a parent class
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in class
            # declaration
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs['dimensions'], str):
                    raise TypeError("Dimensions must be iterable of dimensions, got str, "
                                    "possibly omitted trailing comma")
                if not isinstance(kwargs['dimensions'], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used
            # in the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])
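
            # For reference (illustrative): stringFormatter.parse("{foo}Dataset")
            # yields ('', 'foo', '', None) followed by ('Dataset', None, None,
            # None), so param[1] is the template field name, or None for
            # runs of literal text.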

            # look up any template from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "default templates were provided, add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do not
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection, this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(f"Template parameters cannot share names with Class attributes"
                                    f" (conflicts are {nameTemplateIntersection}).")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used in
        # type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses
        super().__init__(name, bases, dct)
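

# A minimal sketch (illustrative names) of what the metaclass machinery above
# produces: declaring a connection as a class attribute registers it in the
# class-level containers, and ``dimensions`` must be passed as a class
# keyword argument.
#
#     class SketchConnections(PipelineTaskConnections, dimensions=("visit",)):
#         calexp = Input(doc="input", dimensions=("visit",),
#                        storageClass="Exposure", name="calexp")
#
#     SketchConnections.inputs          # -> frozenset({'calexp'})
#     SketchConnections.dimensions      # -> {'visit'}
#     SketchConnections.allConnections  # -> {'calexp': <the Input instance>}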


class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to their
    `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    PipelineTaskConnections class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
    """
    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                           typing.List[DatasetRef]]], None, None]:
        """Make an iterator for this QuantizedConnection.

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Returns an iterator over all the attributes added to a
        QuantizedConnection class
        """
        yield from self._attributes
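

# Illustrative usage (names hypothetical): once buildDatasetRefs (defined
# below) has populated an InputQuantizedConnection, it can be read by
# attribute or iterated like dict.items():
#
#     calexpRef = inputRefs.calexp      # a scalar connection -> DatasetRef
#     for attributeName, ref in inputRefs:
#         print(attributeName, ref)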


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Class which denotes that a datasetRef should be treated as deferred
    when interacting with the butler.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset
    """
    __slots__ = ()


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the
      corresponding ``PipelineTask.run`` method must return
      ``Struct(measCat=X, ...)`` where X matches the ``storageClass`` type
      defined on the output connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert(connections.inputConnection.name == "ModifiedDataset")
    >>> assert(connections.outputConnection.name == "TotallyDifferent")
    """

    def __init__(self, *, config: 'PipelineTaskConfig' = None):
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name)
                          for name in getattr(self, 'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}

        # connections.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}
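
    # For example (illustrative, reusing the docstring example above): with
    # defaultTemplates={"foo": "Example"} and an input named "{foo}Dataset",
    # leaving the config untouched gives
    # self._nameOverrides == {"inputConnection": "ExampleDataset", ...} and
    # self._typeNameToVarName == {"ExampleDataset": "inputConnection", ...}.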

    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Builds QuantizedConnections corresponding to input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
            `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip((inputDatasetRefs, outputDatasetRefs),
                               (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.predictedInputs:
                    # Get the DatasetRefs
                    quantumInputRefs = quantum.predictedInputs[attribute.name]
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    if attribute.deferLoad:
                        quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                f"Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        quantumInputRefs = quantumInputRefs[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        value = value[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; we don't
                # know how to handle it, so raise
                else:
                    raise ValueError(f"Attribute with name {attributeName} has no counterpart "
                                     "in input quantum")
        return inputDatasetRefs, outputDatasetRefs
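
    # Illustrative usage inside an executor (``quantum`` and ``connections``
    # are hypothetical here):
    #
    #     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #
    # Scalar connections appear as a single DatasetRef; connections with
    # ``multiple=True`` appear as lists; ``deferLoad=True`` inputs are
    # wrapped in DeferredDatasetRef.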

    def adjustQuantum(self, datasetRefMap: NamedKeyDict[DatasetType, typing.Set[DatasetRef]]
                      ) -> NamedKeyDict[DatasetType, typing.Set[DatasetRef]]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        The base class implementation simply checks that input connections
        with ``multiple`` set to `False` have no more than one dataset.

        Parameters
        ----------
        datasetRefMap : `NamedKeyDict`
            Mapping from dataset type to a `set` of
            `lsst.daf.butler.DatasetRef` objects

        Returns
        -------
        datasetRefMap : `NamedKeyDict`
            Modified mapping of input with possibly adjusted
            `lsst.daf.butler.DatasetRef` objects.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets.
        Exception
            Overrides of this function have the option of raising an
            Exception if a field in the input does not satisfy a need for a
            corresponding pipelineTask, e.g. no reference catalogs are found.
        """
        for connection in itertools.chain(iterConnections(self, "inputs"),
                                          iterConnections(self, "prerequisiteInputs")):
            refs = datasetRefMap[connection.name]
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for scalar connection {connection.name} ({refs[0].datasetType.name})."
                )
        return datasetRefMap
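
    # A hedged sketch of a subclass override (all names hypothetical): fail
    # early when a prerequisite is missing, while keeping the base class
    # scalar check.
    #
    #     def adjustQuantum(self, datasetRefMap):
    #         datasetRefMap = super().adjustQuantum(datasetRefMap)
    #         if not datasetRefMap[self.refCat.name]:
    #             raise RuntimeError("No reference catalogs for this quantum")
    #         return datasetRefMap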


def iterConnections(connections: PipelineTaskConnections, connectionType: str) -> typing.Generator:
    """Creates an iterator over the selected connections type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str`
        The type of connections to iterate over; valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a type derived from
        `BaseConnection`.
    """
    for name in getattr(connections, connectionType):
        yield getattr(connections, name)
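

# Illustrative usage (``connections`` is a hypothetical instantiated
# PipelineTaskConnections object):
#
#     for connection in iterConnections(connections, "inputs"):
#         print(connection.name)  # the configured dataset type name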