Coverage for python/lsst/pipe/base/connections.py: 42% (290 statements)
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining connection classes for PipelineTask.
23"""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

import dataclasses
import itertools
import string
from collections import UserDict
from collections.abc import Collection, Generator, Iterable, Mapping, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

from ._status import NoWorkFound
from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput

if TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by `PipelineTaskConnectionsMetaclass`.

    This dict is used in `PipelineTaskConnections` class creation, as the
    dictionary that is initially used as ``__dict__``. It exists to intercept
    connection fields declared in a `PipelineTaskConnections` subclass and
    record the names used to identify them. The names are then added to a
    class-level list according to the connection type of the class attribute.
    The names are also used as keys in a class-level dictionary associated
    with the corresponding class attribute. This information is a duplicate
    of what exists in ``__dict__``, but provides a simple place to look up
    and iterate over only these variables.
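
    Examples
    --------
    A minimal sketch of the bookkeeping this dict performs, assuming the
    connection types are imported as ``cT`` (the connection arguments here
    are illustrative only):

    >>> from lsst.pipe.base import connectionTypes as cT
    >>> d = PipelineTaskConnectionDict()
    >>> d["calexp"] = cT.Input(
    ...     doc="Example input", name="calexp", storageClass="ExposureF", dimensions=()
    ... )
    >>> "calexp" in d["inputs"]
    True
    >>> d["allConnections"]["calexp"].varName
    'calexp'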
74 """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection.
        self.data["inputs"] = set()
        self.data["prerequisiteInputs"] = set()
        self.data["outputs"] = set()
        self.data["initInputs"] = set()
        self.data["initOutputs"] = set()
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            if name in {
                "dimensions",
                "inputs",
                "prerequisiteInputs",
                "outputs",
                "initInputs",
                "initOutputs",
                "allConnections",
            }:
                # Guard against connections whose names are reserved.
                raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
            if (previous := self.data.get(name)) is not None:
                # Guard against changing the type of an inherited connection
                # by first removing it from the set it's currently in.
                self.data[previous._connection_type_set].discard(name)
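            # Record the attribute name on the connection instance itself;
            # connection types are frozen dataclasses, so we bypass their
            # __setattr__ via object.__setattr__ below.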
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
            self.data[value._connection_type_set].add(name)
        # Defer to the default behavior.
        super().__setitem__(name, value)


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes."""

    # We can annotate these attributes as `collections.abc.Set` to discourage
    # undesirable modifications in type-checked code, since the internal code
    # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
    # these annotations anyway.

    dimensions: Set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    inputs: Set[str]
    """Set with the names of all `~connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added. Note that
    this attribute is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    prerequisiteInputs: Set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: Set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: Set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: Set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping containing all connection attributes.

    See `inputs` for additional information.
    """

    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in the class
            # declaration.
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
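                # Note that dimensions=("visit") is a plain str (the trailing
                # comma was omitted), not a one-element tuple; the isinstance
                # check below catches exactly that mistake.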
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be iterable of dimensions, got str, "
                        "possibly omitted trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc

            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute.
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections.
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # Add all the parameters to the set of templates.
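                # For example, string.Formatter().parse("{foo}Dataset")
                # yields ('', 'foo', '', None) and ('Dataset', None, None,
                # None); the second element of each tuple is the template
                # field name, or None for purely literal text.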
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together.
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])

            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class.
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnections class contains templated attribute names, but no "
                    "default templates were provided; add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they
                # do not.
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with Class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope.
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # Our custom dict type must be turned into an actual dict to be used
        # in type.__new__.
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with the
        # python documentation on metaclasses.
        super().__init__(name, bases, dct)

    def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
        # MyPy appears not to really understand metaclass.__call__ at all, so
        # we need to tell it to ignore __new__ and __init__ calls here.
        instance: PipelineTaskConnections = cls.__new__(cls)  # type: ignore

        # Make mutable copies of all set-like class attributes so derived
        # __init__ implementations can modify them in place.
        instance.dimensions = set(cls.dimensions)
        instance.inputs = set(cls.inputs)
        instance.prerequisiteInputs = set(cls.prerequisiteInputs)
        instance.outputs = set(cls.outputs)
        instance.initInputs = set(cls.initInputs)
        instance.initOutputs = set(cls.initOutputs)

        # Set self.config. It's a bit strange that we claim to accept None
        # but really just raise here, but it's not worth changing now.
        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        instance.config = config

        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time.
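        # For example, with defaultTemplates={"foo": "Example"} declared on
        # the class and a user override config.connections.foo = "Modified",
        # templateValues would be {"foo": "Modified"}.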
        templateValues = {
            name: getattr(config.connections, name) for name in getattr(cls, "defaultTemplates").keys()
        }

        # We now assemble a mapping of all connection instances keyed by
        # internal name, applying the configuration and templates to make new
        # configurations from the class-attribute defaults. This will be
        # private, but with a public read-only view. This mapping is what the
        # descriptor interface of the class-level attributes will return when
        # they are accessed on an instance. This is better than just assigning
        # regular instance attributes as it makes it so removed connections
        # cannot be accessed on instances, instead of having access to them
        # silently fall through to the not-removed class connection instance.
        instance._allConnections = {}
        instance.allConnections = MappingProxyType(instance._allConnections)
        for internal_name, connection in cls.allConnections.items():
            dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
            instance_connection = dataclasses.replace(connection, name=dataset_type_name)
            instance._allConnections[internal_name] = instance_connection

        # Finally call __init__. The base class implementation does nothing;
        # we could have left some of the above implementation there (where it
        # originated), but putting it here instead makes it hard for derived
        # class implementors to get things into a weird state by delegating to
        # super().__init__ in the wrong place, or by forgetting to do that
        # entirely.
        instance.__init__(config=config)  # type: ignore

        # Derived-class implementations may have changed the contents of the
        # various kinds-of-connection sets; update allConnections to have keys
        # that are a union of all those. We get values for the new
        # allConnections from the attributes, since any dynamically added new
        # ones will not be present in the old allConnections. Typically those
        # getattrs will invoke the descriptors and get things from the old
        # allConnections anyway. After processing each set we replace it with
        # a frozenset.
        updated_all_connections = {}
        for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
            updated_connection_names = getattr(instance, attrName)
            updated_all_connections.update(
                {name: getattr(instance, name) for name in updated_connection_names}
            )
            # Setting these to frozenset is at odds with the type annotation,
            # but MyPy can't tell because we're using setattr, and we want to
            # lie to it anyway to get runtime guards against post-__init__
            # mutation.
            setattr(instance, attrName, frozenset(updated_connection_names))
        # Update the existing dict in place, since we already have a view of
        # that.
        instance._allConnections.clear()
        instance._allConnections.update(updated_all_connections)

        # Freeze the connection instance dimensions now. This is at odds with
        # the type annotation, which says [mutable] `set`, just like the
        # connection type attributes (e.g. `inputs`, `outputs`, etc.), though
        # MyPy can't tell with those since we're using setattr for them.
        instance.dimensions = frozenset(instance.dimensions)  # type: ignore

        return instance


class QuantizedConnection(SimpleNamespace):
    """A namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
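
    Examples
    --------
    A minimal sketch of the namespace behavior (``ref`` stands in for a real
    `lsst.daf.butler.DatasetRef`):

    >>> refs = InputQuantizedConnection()
    >>> refs.calexp = ref
    >>> list(refs) == [("calexp", ref)]
    True
    >>> list(refs.keys())
    ['calexp']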
374 """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance.
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
        # Capture the attribute name as it is added to this object.
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]:
        """Make an iterator for this QuantizedConnection.

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        QuantizedConnection class.
        """
        yield from self._attributes


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `DatasetRef` that indicates that a `PipelineTask`
    should receive a `DeferredDatasetHandle` instead of an in-memory dataset.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
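
    Examples
    --------
    A sketch of the wrapping behavior (``ref`` stands in for a real
    `lsst.daf.butler.DatasetRef`):

    >>> deferred = DeferredDatasetRef(datasetRef=ref)
    >>> deferred.dataId == ref.dataId
    True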
430 """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class.
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``.
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)``
      where X matches the ``storageClass`` type defined on the output
      connection.

    Attributes of these types can also be created, replaced, or deleted on
    the `PipelineTaskConnections` instance in the ``__init__`` method, if
    more than just the name depends on the configuration. It is preferred to
    define them in the class when possible (even if configuration may cause
    the connection to be removed from the instance).

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, which is an iterable of strings
    that defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`. The dimensions may also be modified in
    subclass ``__init__`` methods if they need to depend on configuration.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example, if
    ``ConnectionsClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    # We annotate these attributes as mutable sets because that's what they
    # are inside derived ``__init__`` implementations and that's what matters
    # most. After that's done, the metaclass __call__ makes them into
    # frozensets, but relatively little code interacts with them then, and
    # that code knows not to try to modify them without having to be told
    # that by mypy.

    dimensions: set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later
    and need not be provided.

    This may be replaced or modified in ``__init__`` to change the dimensions
    of the task. After ``__init__`` it will be a `frozenset` and may not be
    replaced.
    """

    inputs: set[str]
    """Set with the names of all `connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added, removed, or
    replaced in ``__init__``. Removing entries from this set will cause those
    connections to be removed after ``__init__`` completes, but this is
    supported only for backwards compatibility; new code should instead just
    delete the connection attribute directly. After ``__init__`` this will
    be a `frozenset` and may not be replaced.
    """

    prerequisiteInputs: set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping holding all connection attributes.

    This is a read-only view that is automatically updated when connection
    attributes are added, removed, or replaced in ``__init__``. It is also
    updated after ``__init__`` completes to reflect changes in `inputs`,
    `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
    """

    _allConnections: dict[str, BaseConnection]

    def __init__(self, *, config: PipelineTaskConfig | None = None):
        pass

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            previous = self._allConnections.get(name)
            try:
                getattr(self, value._connection_type_set).add(name)
            except AttributeError:
                # Attempted to call add on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            if previous is not None and value._connection_type_set != previous._connection_type_set:
                # Connection has changed type, e.g. Input to
                # PrerequisiteInput; update the sets accordingly. To be extra
                # defensive about multiple assignments we use the type of the
                # previous instance instead of assuming it's the same as the
                # type of self, which is just the default. Use discard
                # instead of remove so manually removing from these sets
                # first is never an error.
                getattr(self, previous._connection_type_set).discard(name)
            self._allConnections[name] = value
            if hasattr(self.__class__, name):
                # Don't actually set the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # return the value we just added to allConnections.
                return
        # Actually add the attribute.
        super().__setattr__(name, value)

    def __delattr__(self, name):
        """Descriptor delete method."""
        previous = self._allConnections.get(name)
        if previous is not None:
            # Delete this connection's name from the appropriate set, which
            # we have to get from the previous instance instead of assuming
            # it's the same set that was appropriate for the class-level
            # default. Use discard instead of remove so manually removing
            # from these sets first is never an error.
            try:
                getattr(self, previous._connection_type_set).discard(name)
            except AttributeError:
                # Attempted to call discard on a frozenset, which is what
                # these sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            del self._allConnections[name]
            if hasattr(self.__class__, name):
                # Don't actually delete the attribute if this was a
                # connection declared in the class; in that case we let the
                # descriptor see that it's no longer present in
                # allConnections.
                return
        # Actually delete the attribute.
        super().__delattr__(name)

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build QuantizedConnections corresponding to an input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
            `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
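
        Examples
        --------
        A sketch of the intended call pattern inside an execution harness
        (``quantum`` stands in for a real `lsst.daf.butler.Quantum`, and the
        ``calexp`` attribute is illustrative):

        >>> inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        >>> inputRefs.calexp  # doctest: +SKIP
        DatasetRef(...)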
697 """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # Operate on a reference object and an iterable of names of class
        # connection attributes.
        for refs, names in zip(
            (inputDatasetRefs, outputDatasetRefs),
            (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
        ):
            # Get a name of a class connection attribute.
            for attributeName in names:
                # Get the attribute identified by name.
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input.
                if attribute.name in quantum.inputs:
                    # If the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef.
                    quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
                    if attribute.deferLoad:
                        quantumInputRefs = [
                            DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                        ]
                    else:
                        quantumInputRefs = list(quantum.inputs[attribute.name])
                    # Unpack arguments that are not marked multiples (list of
                    # length one).
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                "Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        setattr(refs, attributeName, quantumInputRefs[0])
                    else:
                        # Add to the QuantizedConnection identifier.
                        setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output.
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one).
                    if not attribute.multiple:
                        setattr(refs, attributeName, value[0])
                    else:
                        setattr(refs, attributeName, value)
                # The specified attribute is in neither inputs nor outputs;
                # we don't know how to handle it, so throw.
                else:
                    raise ValueError(
                        f"Attribute with name {attributeName} has no counterpart in input quantum"
                    )
        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
765 """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
766 in the `lsst.daf.butler.core.Quantum` during the graph generation stage
767 of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that are
            true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `DatasetRef` objects. Connections that are not changed
            should not be returned at all. Datasets may only be removed, not
            added. Nested collections may be of any multi-pass iterable type,
            and the order of iteration will set the order of iteration within
            `PipelineTask.runQuantum`.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets are present.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the message
            included in this exception); not enough datasets were found for a
            `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.::

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={input_connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose inputs
                    # are either already present or should be created during
                    # execution. It can trigger during execution if the input
                    # wasn't actually created by an upstream task in the same
                    # graph.
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(
    connections: PipelineTaskConnections, connectionType: str | Iterable[str]
) -> Generator[BaseConnection, None, None]:
893 """Creates an iterator over the selected connections type which yields
894 all the defined connections of that type.
896 Parameters
897 ----------
898 connections: `PipelineTaskConnections`
899 An instance of a `PipelineTaskConnections` object that will be iterated
900 over.
901 connectionType: `str`
902 The type of connections to iterate over, valid values are inputs,
903 outputs, prerequisiteInputs, initInputs, initOutputs.
905 Yields
906 ------
907 connection: `BaseConnection`
908 A connection defined on the input connections object of the type
909 supplied. The yielded value Will be an derived type of
910 `BaseConnection`.
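
    Examples
    --------
    A sketch of typical usage (``connections`` stands in for a
    `PipelineTaskConnections` instance):

    >>> for connection in iterConnections(connections, "inputs"):
    ...     print(connection.name)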
911 """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `inputs` and `outputs` mappings in the form used by
    `Quantum` and execution harness code, i.e. with `DatasetType` keys,
    translating them to and from the connection-oriented mappings used inside
    `PipelineTaskConnections`.
    """

    inputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
    """Mapping of output datasets, grouped by `DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
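
        Examples
        --------
        A sketch of the intended call pattern (``quantum``, ``connections``,
        and ``label`` stand in for real objects supplied by the execution
        harness):

        >>> helper = AdjustQuantumHelper(inputs=quantum.inputs, outputs=quantum.outputs)
        >>> helper.adjust_in_place(connections, label, quantum.dataId)
        >>> if helper.inputs_adjusted:
        ...     print("some inputs were dropped")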
964 """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in connections.outputs:
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior
        # we want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but it's not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = list(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = list(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False