Coverage for python/lsst/pipe/base/connections.py: 44% (303 statements)
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Module defining connection classes for PipelineTask.
29"""
31from __future__ import annotations
33__all__ = [
34 "AdjustQuantumHelper",
35 "DeferredDatasetRef",
36 "InputQuantizedConnection",
37 "OutputQuantizedConnection",
38 "PipelineTaskConnections",
39 "ScalarError",
40 "iterConnections",
41 "ScalarError",
42 "QuantizedConnection",
43]

import dataclasses
import itertools
import string
import warnings
from collections import UserDict
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

from ._status import NoWorkFound
from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput

if TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by `PipelineTaskConnectionMetaclass`.

    This dict is used in `PipelineTaskConnection` class creation, as the
    dictionary that is initially used as ``__dict__``. It exists to
    intercept connection fields declared in a `PipelineTaskConnection` and
    to record the names used to identify them. The names are then added to
    a class-level list according to the connection type of the class
    attribute. The names are also used as keys in a class-level dictionary
    associated with the corresponding class attribute. This information is
    a duplicate of what exists in ``__dict__``, but provides a simple place
    to look up and iterate over only these variables.

    Parameters
    ----------
    *args : `~typing.Any`
        Passed to `dict` constructor.
    **kwargs : `~typing.Any`
        Passed to `dict` constructor.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection
        self.data["inputs"] = set()
        self.data["prerequisiteInputs"] = set()
        self.data["outputs"] = set()
        self.data["initInputs"] = set()
        self.data["initOutputs"] = set()
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            if name in {
                "dimensions",
                "inputs",
                "prerequisiteInputs",
                "outputs",
                "initInputs",
                "initOutputs",
                "allConnections",
            }:
                # Guard against connections whose names are reserved.
                raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
            if (previous := self.data.get(name)) is not None:
                # Guard against changing the type of an inherited connection
                # by first removing it from the set it's currently in.
                self.data[previous._connection_type_set].discard(name)
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
            self.data[value._connection_type_set].add(name)
        # defer to the default behavior
        super().__setitem__(name, value)


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes.

    Parameters
    ----------
    name : `str`
        Name of the class being declared.
    bases : `~collections.abc.Collection`
        Base classes.
    dct : `~collections.abc.Mapping`
        Connections dict.
    **kwargs : `~typing.Any`
        Additional parameters.
    """

    # We can annotate these attributes as `collections.abc.Set` to discourage
    # undesirable modifications in type-checked code, since the internal code
    # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
    # these annotations anyway.

    dimensions: Set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    inputs: Set[str]
    """Set with the names of all `~connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added. Note that
    this attribute is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    prerequisiteInputs: Set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: Set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: Set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: Set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping containing all connection attributes.

    See `inputs` for additional information.
    """

    def __prepare__(name, bases, **kwargs):  # noqa: N804
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in class
            # declaration
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be iterable of dimensions, got str; possibly omitted trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Lookup any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # look up any template from base classes and merge them all
            # together
            mergeDict = {}
            mergeDeprecationsDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
                if hasattr(base, "deprecatedTemplates"):
                    mergeDeprecationsDict.update(base.deprecatedTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])
            if "deprecatedTemplates" in kwargs:
                mergeDeprecationsDict.update(kwargs["deprecatedTemplates"])
            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict
            if len(mergeDeprecationsDict) > 0:
                kwargs["deprecatedTemplates"] = mergeDeprecationsDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnection class contains templated attribute names, but no "
                    "default templates were provided; add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do not
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with Class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})
            dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses.
        super().__init__(name, bases, dct)

    def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
        # MyPy appears not to really understand metaclass.__call__ at all, so
        # we need to tell it to ignore __new__ and __init__ calls here.
        instance: PipelineTaskConnections = cls.__new__(cls)  # type: ignore

        # Make mutable copies of all set-like class attributes so derived
        # __init__ implementations can modify them in place.
        instance.dimensions = set(cls.dimensions)
        instance.inputs = set(cls.inputs)
        instance.prerequisiteInputs = set(cls.prerequisiteInputs)
        instance.outputs = set(cls.outputs)
        instance.initInputs = set(cls.initInputs)
        instance.initOutputs = set(cls.initOutputs)

        # Set self.config. It's a bit strange that we claim to accept None but
        # really just raise here, but it's not worth changing now.
        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        instance.config = config

        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time.
        templateValues = {
            name: getattr(config.connections, name) for name in cls.defaultTemplates  # type: ignore
        }

        # We now assemble a mapping of all connection instances keyed by
        # internal name, applying the configuration and templates to make new
        # configurations from the class-attribute defaults. This will be
        # private, but with a public read-only view. This mapping is what the
        # descriptor interface of the class-level attributes will return when
        # they are accessed on an instance. This is better than just assigning
        # regular instance attributes as it makes it so removed connections
        # cannot be accessed on instances, instead of having access to them
        # silently fall through to the not-removed class connection instance.
        instance._allConnections = {}
        instance.allConnections = MappingProxyType(instance._allConnections)
        for internal_name, connection in cls.allConnections.items():
            dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
            instance_connection = dataclasses.replace(
                connection,
                name=dataset_type_name,
                doc=(
                    connection.doc
                    if connection.deprecated is None
                    else f"{connection.doc}\n{connection.deprecated}"
                ),
                _deprecation_context=connection._deprecation_context,
            )
            instance._allConnections[internal_name] = instance_connection

        # Finally call __init__. The base class implementation does nothing;
        # we could have left some of the above implementation there (where it
        # originated), but putting it here instead makes it hard for derived
        # class implementors to get things into a weird state by delegating to
        # super().__init__ in the wrong place, or by forgetting to do that
        # entirely.
        instance.__init__(config=config)  # type: ignore

        # Derived-class implementations may have changed the contents of the
        # various kinds-of-connection sets; update allConnections to have keys
        # that are a union of all those. We get values for the new
        # allConnections from the attributes, since any dynamically added new
        # ones will not be present in the old allConnections. Typically those
        # getattrs will invoke the descriptors and get things from the old
        # allConnections anyway. After processing each set we replace it with
        # a frozenset.
        updated_all_connections = {}
        for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
            updated_connection_names = getattr(instance, attrName)
            updated_all_connections.update(
                {name: getattr(instance, name) for name in updated_connection_names}
            )
            # Setting these to frozenset is at odds with the type annotation,
            # but MyPy can't tell because we're using setattr, and we want to
            # lie to it anyway to get runtime guards against post-__init__
            # mutation.
            setattr(instance, attrName, frozenset(updated_connection_names))
        # Update the existing dict in place, since we already have a view of
        # that.
        instance._allConnections.clear()
        instance._allConnections.update(updated_all_connections)

        for connection_name, obj in instance._allConnections.items():
            if obj.deprecated is not None:
                warnings.warn(
                    f"Connection {connection_name} with datasetType {obj.name} "
                    f"(from {obj._deprecation_context}): {obj.deprecated}",
                    FutureWarning,
                    stacklevel=1,  # Report from this location.
                )

        # Freeze the connection instance dimensions now. This is at odds with
        # the type annotation, which says [mutable] `set`, just like the
        # connection type attributes (e.g. `inputs`, `outputs`, etc.), though
        # MyPy can't tell with those since we're using setattr for them.
        instance.dimensions = frozenset(instance.dimensions)  # type: ignore

        return instance


class QuantizedConnection(SimpleNamespace):
    r"""A Namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `~lsst.daf.butler.DatasetRef`\s provided by a `~lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnections` class.

    Parameters
    ----------
    **kwargs : `~typing.Any`
        Not used.
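
    Examples
    --------
    A minimal sketch of typical use; ``connections``, ``quantum``, and the
    ``calexp`` attribute are hypothetical stand-ins for whatever a concrete
    task actually defines:

    .. code-block:: python

        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        # Attributes mirror connection names; scalar connections hold a
        # single DatasetRef while ``multiple`` connections hold a list.
        ref = inputRefs.calexp
        # Iteration yields (name, ref-or-list-of-refs) pairs, like
        # dict.items() but over namespace attributes.
        for name, refs in inputRefs:
            print(name, refs)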
431 """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance.
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
        # Capture the attribute name as it is added to this object.
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]:
        """Make an iterator for this `QuantizedConnection`.

        Iterating over a `QuantizedConnection` will yield a tuple with the
        name of an attribute and the value associated with that name. This is
        similar to ``dict.items()``, but is over the namespace attributes
        rather than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        `QuantizedConnection` class.
        """
        yield from self._attributes


class InputQuantizedConnection(QuantizedConnection):
    """Input variant of a `QuantizedConnection`."""

    pass


class OutputQuantizedConnection(QuantizedConnection):
    """Output variant of a `QuantizedConnection`."""

    pass


@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that a
    `PipelineTask` should receive a `~lsst.daf.butler.DeferredDatasetHandle`
    instead of an in-memory dataset.

    Attributes
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
    """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections : Iterator over selected connections.

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes`, listed
    as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class.
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``.
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)``
      where X matches the ``storageClass`` type defined on the output
      connection.

    Attributes of these types can also be created, replaced, or deleted on
    the `PipelineTaskConnections` instance in the ``__init__`` method, if
    more than just the name depends on the configuration. It is preferred to
    define them in the class when possible (even if configuration may cause
    the connection to be removed from the instance); a sketch of dynamic
    removal follows below.
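
    For example, a connection can be removed in ``__init__`` when
    configuration makes it unnecessary (a minimal sketch; ``cT`` is
    ``lsst.pipe.base.connectionTypes``, and the ``maskedImage`` connection
    and ``doWriteMask`` config field are hypothetical):

    .. code-block:: python

        class ExampleConnections(PipelineTaskConnections,
                                 dimensions=("A", "B")):
            maskedImage = cT.Output(doc="Optional mask output",
                                    dimensions=("A", "B"),
                                    storageClass="Exposure",
                                    name="exampleMask")

            def __init__(self, *, config=None):
                super().__init__(config=config)
                if not config.doWriteMask:
                    del self.maskedImage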

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``: an iterable of strings that
    defines the unit of processing the ``run`` method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`. The dimensions may also be modified in
    subclass ``__init__`` methods if they need to depend on configuration.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example, if
    ``ConnectionsClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    # We annotate these attributes as mutable sets because that's what they
    # are inside derived ``__init__`` implementations, and that's what
    # matters most. After __init__ is done, the metaclass __call__ makes them
    # into frozensets, but relatively little code interacts with them then,
    # and that code knows not to try to modify them without having to be told
    # that by mypy.

    dimensions: set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later
    and need not be provided.

    This may be replaced or modified in ``__init__`` to change the dimensions
    of the task. After ``__init__`` it will be a `frozenset` and may not be
    replaced.
    """

    inputs: set[str]
    """Set with the names of all `connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added, removed, or
    replaced in ``__init__``. Removing entries from this set will cause those
    connections to be removed after ``__init__`` completes, but this is
    supported only for backwards compatibility; new code should instead just
    delete the connection attribute directly. After ``__init__`` this will
    be a `frozenset` and may not be replaced.
    """

    prerequisiteInputs: set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping holding all connection attributes.

    This is a read-only view that is automatically updated when connection
    attributes are added, removed, or replaced in ``__init__``. It is also
    updated after ``__init__`` completes to reflect changes in `inputs`,
    `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
    """

    _allConnections: dict[str, BaseConnection]

    def __init__(self, *, config: PipelineTaskConfig | None = None):
        pass

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            previous = self._allConnections.get(name)
            try:
                getattr(self, value._connection_type_set).add(name)
            except AttributeError:
                # Attempt to call add on a frozenset, which is what these sets
                # are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            if previous is not None and value._connection_type_set != previous._connection_type_set:
                # Connection has changed type, e.g. Input to PrerequisiteInput;
                # update the sets accordingly. To be extra defensive about
                # multiple assignments we use the type of the previous instance
                # instead of assuming it's the same as the type of self,
                # which is just the default. Use discard instead of remove so
                # manually removing from these sets first is never an error.
                getattr(self, previous._connection_type_set).discard(name)
            self._allConnections[name] = value
            if hasattr(self.__class__, name):
                # Don't actually set the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # return the value we just added to allConnections.
                return
        # Actually add the attribute.
        super().__setattr__(name, value)

    def __delattr__(self, name):
        """Descriptor delete method."""
        previous = self._allConnections.get(name)
        if previous is not None:
            # Delete this connection's name from the appropriate set, which we
            # have to get from the previous instance instead of assuming it's
            # the same set that was appropriate for the class-level default.
            # Use discard instead of remove so manually removing from these
            # sets first is never an error.
            try:
                getattr(self, previous._connection_type_set).discard(name)
            except AttributeError:
                # Attempt to call discard on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            del self._allConnections[name]
            if hasattr(self.__class__, name):
                # Don't actually delete the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # see that it's no longer present in allConnections.
                return
        # Actually delete the attribute.
        super().__delattr__(name)

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build `QuantizedConnection` objects corresponding to the input
        `~lsst.daf.butler.Quantum`.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`, \
                `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `~lsst.daf.butler.Quantum`.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()

        # populate inputDatasetRefs from quantum inputs
        for attributeName in itertools.chain(self.inputs, self.prerequisiteInputs):
            # get the attribute identified by name
            attribute = getattr(self, attributeName)
            # if the dataset is marked to load deferred, wrap it in a
            # DeferredDatasetRef
            quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
            if attribute.deferLoad:
                quantumInputRefs = [
                    DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                ]
            else:
                quantumInputRefs = list(quantum.inputs[attribute.name])
            # Unpack arguments that are not marked multiples (list of
            # length one)
            if not attribute.multiple:
                if len(quantumInputRefs) > 1:
                    raise ScalarError(
                        "Received multiple datasets "
                        f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                        f"for scalar connection {attributeName} "
                        f"({quantumInputRefs[0].datasetType.name}) "
                        f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                    )
                if len(quantumInputRefs) == 0:
                    continue
                setattr(inputDatasetRefs, attributeName, quantumInputRefs[0])
            else:
                # Add to the QuantizedConnection identifier
                setattr(inputDatasetRefs, attributeName, quantumInputRefs)

        # populate outputDatasetRefs from quantum outputs
        for attributeName in self.outputs:
            # get the attribute identified by name
            attribute = getattr(self, attributeName)
            value = quantum.outputs[attribute.name]
            # Unpack arguments that are not marked multiples (list of
            # length one)
            if not attribute.multiple:
                setattr(outputDatasetRefs, attributeName, value[0])
            else:
                setattr(outputDatasetRefs, attributeName, value)

        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.Quantum` during the graph generation
        stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated
            `~lsst.daf.butler.DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that are
            true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `~collections.abc.Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `~collections.abc.Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `~lsst.daf.butler.DatasetRef` objects. Connections that are
            not changed should not be returned at all. Datasets may only be
            removed, not added. Nested collections may be of any multi-pass
            iterable type, and the order of iteration will set the order of
            iteration within `PipelineTask.runQuantum`.
        adjusted_outputs : `~collections.abc.Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets are present.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the message
            included in this exception); not enough datasets were found for a
            `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.:

        .. code-block:: python

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combination
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG generation,
                    # or if someone deleted the dataset between making the QG
                    # and trying to run it. Either one should be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection {label}.{name} "
                        f"({dataset_type_name}) with minimum={input_connection.minimum} for quantum data ID "
                        f"{data_id}."
                    )
                else:
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        """Return the names of regular input and output connections whose
        data IDs should be used to compute the spatial bounds of this task's
        quanta.

        The spatial bound for a quantum is defined as the union of the
        regions of all data IDs of all connections returned here, along with
        the region of the quantum data ID (if the task has spatial
        dimensions).

        Returns
        -------
        connection_names : `collections.abc.Iterable` [ `str` ]
            Names of connections with spatial dimensions. These are the
            task-internal connection names, not butler dataset type names.

        Notes
        -----
        The spatial bound is used to search for prerequisite inputs that have
        skypix dimensions. The default implementation returns an empty
        iterable, which is usually sufficient for tasks with spatial
        dimensions, but if a task's inputs or outputs are associated with
        spatial regions that extend beyond the quantum data ID's region, this
        method may need to be overridden to expand the set of prerequisite
        inputs found.

        Tasks that do not have spatial dimensions but do have skypix
        prerequisite inputs should always override this method, as the
        default spatial bounds otherwise cover the full sky.
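
        Examples
        --------
        A minimal sketch of an override (the ``inputCatalogs`` connection
        name is hypothetical); `getTemporalBoundsConnections` follows the
        same pattern for temporal bounds:

        .. code-block:: python

            def getSpatialBoundsConnections(self):
                return ("inputCatalogs",)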
968 """
969 return ()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        """Return the names of regular input and output connections whose
        data IDs should be used to compute the temporal bounds of this task's
        quanta.

        The temporal bound for a quantum is defined as the union of the
        timespans of all data IDs of all connections returned here, along
        with the timespan of the quantum data ID (if the task has temporal
        dimensions).

        Returns
        -------
        connection_names : `collections.abc.Iterable` [ `str` ]
            Names of connections with temporal dimensions. These are the
            task-internal connection names, not butler dataset type names.

        Notes
        -----
        The temporal bound is used to search for prerequisite inputs that are
        calibration datasets. The default implementation returns an empty
        iterable, which is usually sufficient for tasks with temporal
        dimensions, but if a task's inputs or outputs are associated with
        timespans that extend beyond the quantum data ID's timespan, this
        method may need to be overridden to expand the set of prerequisite
        inputs found.

        Tasks that do not have temporal dimensions and do not override this
        method will use an infinite timespan for any calibration lookups.
        """
        return ()


def iterConnections(
    connections: PipelineTaskConnections, connectionType: str | Iterable[str]
) -> Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or iterable of `str`
        The type (or types) of connections to iterate over; valid values are
        ``inputs``, ``outputs``, ``prerequisiteInputs``, ``initInputs``, and
        ``initOutputs``.

    Yields
    ------
    connection : `~.connectionTypes.BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `~.connectionTypes.BaseConnection`.
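
    Examples
    --------
    A short sketch; ``connections`` is assumed to be an instance of some
    `PipelineTaskConnections` subclass:

    .. code-block:: python

        for connection in iterConnections(connections, "inputs"):
            print(connection.name)

        # Multiple connection types can be requested at once.
        for connection in iterConnections(connections, ("inputs", "outputs")):
            print(connection.name)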
1024 """
1025 if isinstance(connectionType, str):
1026 connectionType = (connectionType,)
1027 for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
1028 yield getattr(connections, name)


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `input` and `output` mappings in the form used by
    `Quantum` and execution harness code, i.e. with
    `~lsst.daf.butler.DatasetType` keys, translating them to and from the
    connection-oriented mappings used inside `PipelineTaskConnections`.
    """

    inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `~lsst.daf.butler.DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
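
        Examples
        --------
        A minimal sketch of how harness-level code might use this helper;
        ``quantum``, ``connections``, ``label``, and ``data_id`` are assumed
        to come from the surrounding context:

        .. code-block:: python

            helper = AdjustQuantumHelper(inputs=quantum.inputs,
                                         outputs=quantum.outputs)
            helper.adjust_in_place(connections, label, data_id)
            if helper.inputs_adjusted or helper.outputs_adjusted:
                # Rebuild the Quantum with the adjusted mappings.
                ...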
1077 """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in itertools.chain(connections.outputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior we
        # want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = tuple(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = tuple(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False