Coverage for python/lsst/pipe/base/connections.py: 39%
389 statements
coverage.py v7.13.5, created at 2026-05-01 08:20 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Module defining connection classes for PipelineTask."""
30from __future__ import annotations
32__all__ = [
33 "AdjustQuantumHelper",
34 "DeferredDatasetRef",
35 "InputQuantizedConnection",
36 "OutputQuantizedConnection",
37 "PipelineTaskConnections",
38 "QuantaAdjuster",
39 "QuantizedConnection",
40 "ScalarError",
41 "ScalarError",
42 "iterConnections",
43]
45import dataclasses
46import itertools
47import string
48import uuid
49import warnings
50from collections import UserDict, defaultdict
51from collections.abc import Collection, Generator, Iterable, Iterator, Mapping, Sequence, Set
52from dataclasses import dataclass
53from types import MappingProxyType, SimpleNamespace
54from typing import TYPE_CHECKING, Any
56from lsst.daf.butler import (
57 Butler,
58 DataCoordinate,
59 DatasetRef,
60 DatasetType,
61 NamedKeyDict,
62 NamedKeyMapping,
63 Quantum,
64)
66from ._status import NoWorkFound
67from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput
69if TYPE_CHECKING:
70 from .config import PipelineTaskConfig
71 from .pipeline_graph import PipelineGraph, TaskNode
72 from .quantum_graph_skeleton import QuantumGraphSkeleton
75class ScalarError(TypeError):
76 """Exception raised when dataset type is configured as scalar
77 but there are multiple data IDs in a Quantum for that dataset.
78 """
81class PipelineTaskConnectionDict(UserDict):
82 """A special dict class used by `PipelineTaskConnectionMetaclass`.
84 This dict is used in `PipelineTaskConnection` class creation, as the
85 dictionary that is initially used as ``__dict__``. It exists to
86 intercept connection fields declared in a `PipelineTaskConnection`, and
87 what name is used to identify them. The names are then added to class
88 level list according to the connection type of the class attribute. The
89 names are also used as keys in a class level dictionary associated with
90 the corresponding class attribute. This information is a duplicate of
91 what exists in ``__dict__``, but provides a simple place to lookup and
92 iterate on only these variables.
94 Parameters
95 ----------
96 *args : `~typing.Any`
97 Passed to `dict` constructor.
98 **kwargs : `~typing.Any`
99 Passed to `dict` constructor.
100 """
102 def __init__(self, *args: Any, **kwargs: Any):
103 super().__init__(*args, **kwargs)
104 # Initialize class level variables used to track any declared
105 # class level variables that are instances of
106 # connectionTypes.BaseConnection
107 self.data["inputs"] = set()
108 self.data["prerequisiteInputs"] = set()
109 self.data["outputs"] = set()
110 self.data["initInputs"] = set()
111 self.data["initOutputs"] = set()
112 self.data["allConnections"] = {}
114 def __setitem__(self, name: str, value: Any) -> None:
115 if isinstance(value, BaseConnection):
116 if name in {
117 "dimensions",
118 "inputs",
119 "prerequisiteInputs",
120 "outputs",
121 "initInputs",
122 "initOutputs",
123 "allConnections",
124 }:
125 # Guard against connections whose names are reserved.
126 raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
127 if (previous := self.data.get(name)) is not None:
128 # Guard against changing the type of an inherited connection
129 # by first removing it from the set it's currently in.
130 self.data[previous._connection_type_set].discard(name)
131 object.__setattr__(value, "varName", name)
132 self.data["allConnections"][name] = value
133 self.data[value._connection_type_set].add(name)
134 # defer to the default behavior
135 super().__setitem__(name, value)
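# A minimal sketch (hypothetical connection and dataset type names) of what
# the interception above produces: declaring a connection attribute on a
# PipelineTaskConnections subclass records its name in the class-level sets
# and in ``allConnections``.
def _demo_connection_interception() -> None:
    from . import connectionTypes as cT

    class DemoConnections(PipelineTaskConnections, dimensions=("visit",)):
        calexp = cT.Input(
            doc="Demo input.",
            name="calexp",
            storageClass="ExposureF",
            dimensions=("visit",),
        )

    assert DemoConnections.inputs == {"calexp"}
    assert "calexp" in DemoConnections.allConnections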
138class PipelineTaskConnectionsMetaclass(type):
139 """Metaclass used in the declaration of PipelineTaskConnections classes.
141 Parameters
142 ----------
143 name : `str`
144 Name of the class being created.
145 bases : `~collections.abc.Collection`
146 Base classes.
147 dct : `~collections.abc.Mapping`
148 Connections dict.
149 **kwargs : `~typing.Any`
150 Additional parameters.
151 """
153 # We can annotate these attributes as `collections.abc.Set` to discourage
154 # undesirable modifications in type-checked code, since the internal code
155 # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
156 # these annotations anyway.
158 dimensions: Set[str]
159 """Set of dimension names that define the unit of work for this task.
161 Required and implied dependencies will automatically be expanded later and
162 need not be provided.
164 This is shadowed by an instance-level attribute on
165 `PipelineTaskConnections` instances.
166 """
168 inputs: Set[str]
169 """Set with the names of all `~connectionTypes.Input` connection
170 attributes.
172 This is updated automatically as class attributes are added. Note that
173 this attribute is shadowed by an instance-level attribute on
174 `PipelineTaskConnections` instances.
175 """
177 prerequisiteInputs: Set[str]
178 """Set with the names of all `~connectionTypes.PrerequisiteInput`
179 connection attributes.
181 See `inputs` for additional information.
182 """
184 outputs: Set[str]
185 """Set with the names of all `~connectionTypes.Output` connection
186 attributes.
188 See `inputs` for additional information.
189 """
191 initInputs: Set[str]
192 """Set with the names of all `~connectionTypes.InitInput` connection
193 attributes.
195 See `inputs` for additional information.
196 """
198 initOutputs: Set[str]
199 """Set with the names of all `~connectionTypes.InitOutput` connection
200 attributes.
202 See `inputs` for additional information.
203 """
205 allConnections: Mapping[str, BaseConnection]
206 """Mapping containing all connection attributes.
208 See `inputs` for additional information.
209 """
211 def __prepare__(name, bases, **kwargs): # noqa: N804
212 # Create an instance of our special dict to catch and track all
213 # variables that are instances of connectionTypes.BaseConnection
214 # Copy any existing connections from a parent class
215 dct = PipelineTaskConnectionDict()
216 for base in bases:
217 if isinstance(base, PipelineTaskConnectionsMetaclass):
218 for name, value in base.allConnections.items():
219 dct[name] = value
220 return dct
222 def __new__(cls, name, bases, dct, **kwargs):
223 dimensionsValueError = TypeError(
224 "PipelineTaskConnections class must be created with a dimensions "
225 "attribute which is an iterable of dimension names"
226 )
228 if name != "PipelineTaskConnections":
229 # Verify that dimensions are passed as a keyword in class
230 # declaration
231 if "dimensions" not in kwargs: 231 ↛ 232line 231 didn't jump to line 232 because the condition on line 231 was never true
232 for base in bases:
233 if hasattr(base, "dimensions"):
234 kwargs["dimensions"] = base.dimensions
235 break
236 if "dimensions" not in kwargs:
237 raise dimensionsValueError
238 try:
239 if isinstance(kwargs["dimensions"], str):
240 raise TypeError(
241 "Dimensions must be iterable of dimensions, got str,possibly omitted trailing comma"
242 )
243 if not isinstance(kwargs["dimensions"], Iterable):
244 raise TypeError("Dimensions must be iterable of dimensions")
245 dct["dimensions"] = set(kwargs["dimensions"])
246 except TypeError as exc:
247 raise dimensionsValueError from exc
248 # Lookup any python string templates that may have been used in the
249 # declaration of the name field of a class connection attribute
250 allTemplates = set()
251 stringFormatter = string.Formatter()
252 # Loop over all connections
253 for obj in dct["allConnections"].values():
254 nameValue = obj.name
255 # add all the parameters to the set of templates
256 for param in stringFormatter.parse(nameValue):
257 if param[1] is not None:
258 allTemplates.add(param[1])
260 # look up any template from base classes and merge them all
261 # together
262 mergeDict = {}
263 mergeDeprecationsDict = {}
264 for base in bases[::-1]:
265 if hasattr(base, "defaultTemplates"):
266 mergeDict.update(base.defaultTemplates)
267 if hasattr(base, "deprecatedTemplates"):
268 mergeDeprecationsDict.update(base.deprecatedTemplates)
269 if "defaultTemplates" in kwargs:
270 mergeDict.update(kwargs["defaultTemplates"])
271 if "deprecatedTemplates" in kwargs: 271 ↛ 272line 271 didn't jump to line 272 because the condition on line 271 was never true
272 mergeDeprecationsDict.update(kwargs["deprecatedTemplates"])
273 if len(mergeDict) > 0:
274 kwargs["defaultTemplates"] = mergeDict
275 if len(mergeDeprecationsDict) > 0:
276 kwargs["deprecatedTemplates"] = mergeDeprecationsDict
278 # Verify that if templated strings were used, defaults were
279 # supplied as an argument in the declaration of the connection
280 # class
281 if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
282 raise TypeError(
283 "PipelineTaskConnection class contains templated attribute names, but no "
284 "defaut templates were provided, add a dictionary attribute named "
285 "defaultTemplates which contains the mapping between template key and value"
286 )
287 if len(allTemplates) > 0:
288 # Verify all templates have a default, and throw if they do not
289 defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
290 templateDifference = allTemplates.difference(defaultTemplateKeys)
291 if templateDifference:
292 raise TypeError(f"Default template keys were not provided for {templateDifference}")
293 # Verify that templates do not share names with variable names
294 # used for a connection, this is needed because of how
295 # templates are specified in an associated config class.
296 nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
297 if len(nameTemplateIntersection) > 0:
298 raise TypeError(
299 "Template parameters cannot share names with Class attributes"
300 f" (conflicts are {nameTemplateIntersection})."
301 )
302 dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})
303 dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {})
305 # Convert all the connection containers into frozensets so they cannot
306 # be modified at the class scope
307 for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
308 dct[connectionName] = frozenset(dct[connectionName])
309 # our custom dict type must be turned into an actual dict to be used in
310 # type.__new__
311 return super().__new__(cls, name, bases, dict(dct))
313 def __init__(cls, name, bases, dct, **kwargs):
314 # This overrides the default init to drop the kwargs argument. Python
315 # metaclasses will have this argument set if any kwargs are passed at
316 # class construction time, but they should be consumed before calling
317 # __init__ on the type metaclass. This is in accordance with the Python
318 # documentation on metaclasses.
319 super().__init__(name, bases, dct)
321 def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
322 # MyPy appears not to really understand metaclass.__call__ at all, so
323 # we need to tell it to ignore __new__ and __init__ calls here.
324 instance: PipelineTaskConnections = cls.__new__(cls) # type: ignore
326 # Make mutable copies of all set-like class attributes so derived
327 # __init__ implementations can modify them in place.
328 instance.dimensions = set(cls.dimensions)
329 instance.inputs = set(cls.inputs)
330 instance.prerequisiteInputs = set(cls.prerequisiteInputs)
331 instance.outputs = set(cls.outputs)
332 instance.initInputs = set(cls.initInputs)
333 instance.initOutputs = set(cls.initOutputs)
335 # Set self.config. It's a bit strange that we claim to accept None but
336 # really just raise here, but it's not worth changing now.
337 from .config import PipelineTaskConfig # local import to avoid cycle
339 if config is None or not isinstance(config, PipelineTaskConfig):
340 raise ValueError(
341 "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
342 )
343 instance.config = config
345 # Extract the template names that were defined in the config instance
346 # by looping over the keys of the defaultTemplates dict specified at
347 # class declaration time.
348 templateValues = {
349 name: getattr(config.connections, name)
350 for name in cls.defaultTemplates # type: ignore
351 }
353 # We now assemble a mapping of all connection instances keyed by
354 # internal name, applying the configuration and templates to make new
355 # configurations from the class-attribute defaults. This will be
356 # private, but with a public read-only view. This mapping is what the
357 # descriptor interface of the class-level attributes will return when
358 # they are accessed on an instance. This is better than just assigning
359 # regular instance attributes as it makes it so removed connections
360 cannot be accessed on instances, instead of having access to them
361 silently fall through to the not-removed class connection instance.
362 instance._allConnections = {}
363 instance.allConnections = MappingProxyType(instance._allConnections)
364 for internal_name, connection in cls.allConnections.items():
365 dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
366 instance_connection = dataclasses.replace(
367 connection,
368 name=dataset_type_name,
369 doc=(
370 connection.doc
371 if connection.deprecated is None
372 else f"{connection.doc}\n{connection.deprecated}"
373 ),
374 _deprecation_context=connection._deprecation_context,
375 )
376 instance._allConnections[internal_name] = instance_connection
378 # Finally call __init__. The base class implementation does nothing;
379 # we could have left some of the above implementation there (where it
380 # originated), but putting it here instead makes it hard for derived
381 # class implementors to get things into a weird state by delegating to
382 # super().__init__ in the wrong place, or by forgetting to do that
383 # entirely.
384 instance.__init__(config=config) # type: ignore
386 # Derived-class implementations may have changed the contents of the
387 # various kinds-of-connection sets; update allConnections to have keys
388 # that are a union of all those. We get values for the new
389 # allConnections from the attributes, since any dynamically added new
390 # ones will not be present in the old allConnections. Typically those
391 # getattrs will invoke the descriptors and get things from the old
392 # allConnections anyway. After processing each set we replace it with
393 # a frozenset.
394 updated_all_connections = {}
395 for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
396 updated_connection_names = getattr(instance, attrName)
397 updated_all_connections.update(
398 {name: getattr(instance, name) for name in updated_connection_names}
399 )
400 # Setting these to frozenset is at odds with the type annotation,
401 # but MyPy can't tell because we're using setattr, and we want to
402 # lie to it anyway to get runtime guards against post-__init__
403 # mutation.
404 setattr(instance, attrName, frozenset(updated_connection_names))
405 # Update the existing dict in place, since we already have a view of
406 # that.
407 instance._allConnections.clear()
408 instance._allConnections.update(updated_all_connections)
410 for connection_name, obj in instance._allConnections.items():
411 if obj.deprecated is not None:
412 warnings.warn(
413 f"Connection {connection_name} with datasetType {obj.name} "
414 f"(from {obj._deprecation_context}): {obj.deprecated}",
415 FutureWarning,
416 stacklevel=1, # Report from this location.
417 )
419 # Freeze the connection instance dimensions now. This is at odds with the
420 # type annotation, which says [mutable] `set`, just like the connection
421 # type attributes (e.g. `inputs`, `outputs`, etc.), though MyPy can't
422 # tell with those since we're using setattr for them.
423 instance.dimensions = frozenset(instance.dimensions) # type: ignore
425 return instance
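# A minimal sketch (hypothetical template names) of the template merging
# implemented in __new__ above: a subclass inherits ``defaultTemplates``
# from its bases and may override individual keys.
def _demo_template_inheritance() -> None:
    class Base(PipelineTaskConnections, dimensions=(), defaultTemplates={"coaddName": "deep"}):
        pass

    class Derived(Base, defaultTemplates={"coaddName": "goodSeeing"}):
        pass

    assert Base.defaultTemplates == {"coaddName": "deep"}
    assert Derived.defaultTemplates == {"coaddName": "goodSeeing"}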
428class QuantizedConnection(SimpleNamespace):
429 r"""A Namespace to map defined variable names of connections to the
430 associated `lsst.daf.butler.DatasetRef` objects.
432 This class maps the names used to define a connection on a
433 `PipelineTaskConnections` class to the corresponding
434 `~lsst.daf.butler.DatasetRef`\s provided by a `~lsst.daf.butler.Quantum`
435 instance. This will be a quantum of execution based on the graph created
436 by examining all the connections defined on the
437 `PipelineTaskConnections` class.
439 Parameters
440 ----------
441 **kwargs : `~typing.Any`
442 Not used.
443 """
445 def __init__(self, **kwargs):
446 # Create a variable to track what attributes are added. This is used
447 # later when iterating over this QuantizedConnection instance
448 object.__setattr__(self, "_attributes", set())
450 def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
451 # Capture the attribute name as it is added to this object
452 self._attributes.add(name)
453 super().__setattr__(name, value)
455 def __delattr__(self, name):
456 object.__delattr__(self, name)
457 self._attributes.remove(name)
459 def __len__(self) -> int:
460 return len(self._attributes)
462 def __iter__(
463 self,
464 ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]]]:
465 """Make an iterator for this `QuantizedConnection`.
467 Iterating over a `QuantizedConnection` will yield a tuple with the name
468 of an attribute and the value associated with that name. This is
469 similar to dict.items() but is on the namespace attributes rather than
470 dict keys.
471 """
472 yield from ((name, getattr(self, name)) for name in self._attributes)
474 def keys(self) -> Generator[str]:
475 """Return an iterator over all the attributes added to a
476 `QuantizedConnection` instance.
477 """
478 yield from self._attributes
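# A minimal sketch of the namespace behavior above; plain strings stand in
# for the DatasetRef objects a real Quantum would provide, and the
# attribute names are hypothetical connection names.
def _demo_quantized_connection() -> None:
    refs = QuantizedConnection()
    refs.calexp = "ref-1"  # a scalar (multiple=False) connection
    refs.backgrounds = ["ref-2", "ref-3"]  # a multiple=True connection
    assert set(refs.keys()) == {"calexp", "backgrounds"}
    # Iteration yields (name, value) pairs, like dict.items().
    for name, value in refs:
        print(name, value)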
481class InputQuantizedConnection(QuantizedConnection):
482 """Input variant of a `QuantizedConnection`."""
484 pass
487class OutputQuantizedConnection(QuantizedConnection):
488 """Output variant of a `QuantizedConnection`."""
490 pass
493@dataclass(frozen=True)
494class DeferredDatasetRef:
495 """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that a
496 `PipelineTask` should receive a `~lsst.daf.butler.DeferredDatasetHandle`
497 instead of an in-memory dataset.
498 """
500 datasetRef: DatasetRef
501 """The `lsst.daf.butler.DatasetRef` that will be eventually used to
502 resolve a dataset.
503 """
505 def __getattr__(self, name: str) -> Any:
506 # make sure reduce is called on DeferredDatasetRef and not on
507 # the DatasetRef
508 if name in ("__reduce__", "datasetRef", "__deepcopy__"):
509 return object.__getattribute__(self, name)
510 return getattr(self.datasetRef, name)
512 def __deepcopy__(self, memo: dict) -> DeferredDatasetRef:
513 # Dataset refs are immutable, so the deferred version should be too.
514 return self
516 def __reduce__(self) -> tuple:
517 return (self.__class__, (self.datasetRef,))
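# A minimal sketch of the copy/pickle behavior above. A real
# lsst.daf.butler.DatasetRef would normally be wrapped; a plain string
# stands in here purely for illustration.
def _demo_deferred_dataset_ref() -> None:
    import copy
    import pickle

    deferred = DeferredDatasetRef(datasetRef="not-a-real-ref")
    # __deepcopy__ returns the same (immutable) instance.
    assert copy.deepcopy(deferred) is deferred
    # __reduce__ ensures pickling reconstructs the wrapper rather than
    # delegating to the wrapped ref via __getattr__.
    assert pickle.loads(pickle.dumps(deferred)) == deferred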
520class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
521 """PipelineTaskConnections is a class used to declare desired IO when a
522 PipelineTask is run by an activator.
524 Parameters
525 ----------
526 config : `PipelineTaskConfig`
527 A `PipelineTaskConfig` class instance whose class has been configured
528 to use this `PipelineTaskConnections` class.
530 See Also
531 --------
532 iterConnections : Iterator over selected connections.
534 Notes
535 -----
536 ``PipelineTaskConnections`` classes are created by declaring class
537 attributes of types defined in `lsst.pipe.base.connectionTypes` and are
538 listed as follows:
540 * ``InitInput`` - Defines connections in a quantum graph which are used as
541 inputs to the ``__init__`` function of the `PipelineTask` corresponding
542 to this class
543 * ``InitOutput`` - Defines connections in a quantum graph which are to be
544 persisted using a butler at the end of the ``__init__`` function of the
545 `PipelineTask` corresponding to this class. The variable name used to
546 define this connection should be the same as an attribute name on the
547 `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
548 the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
549 a `PipelineTask` instance should have an attribute
550 ``self.outputSchema`` defined. Its value is what will be saved by the
551 activator framework.
552 * ``PrerequisiteInput`` - An input connection type that defines a
553 `lsst.daf.butler.DatasetType` that must be present at execution time,
554 but that will not be used during the course of creating the quantum
555 graph to be executed. These most often are things produced outside the
556 processing pipeline, such as reference catalogs.
557 * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be used
558 in the ``run`` method of a `PipelineTask`. The name used to declare
559 the class attribute must match a function argument name in the ``run``
560 method of a `PipelineTask`. E.g. if the ``PipelineTaskConnections``
561 defines an ``Input`` with the name ``calexp``, then the corresponding
562 signature should be ``PipelineTask.run(calexp, ...)``
563 * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
564 execution of a `PipelineTask`. The name used to declare the connection
565 must correspond to an attribute of a `Struct` that is returned by a
566 `PipelineTask` ``run`` method. E.g. if an output connection is
567 defined with the name ``measCat``, then the corresponding
568 ``PipelineTask.run`` method must return ``Struct(measCat=X,..)`` where
569 X matches the ``storageClass`` type defined on the output connection.
571 Attributes of these types can also be created, replaced, or deleted on the
572 `PipelineTaskConnections` instance in the ``__init__`` method, if more than
573 just the name depends on the configuration. It is preferred to define them
574 in the class when possible (even if configuration may cause the connection
575 to be removed from the instance).
577 The process of declaring a ``PipelineTaskConnections`` class involves
578 parameters passed in the declaration statement.
580 The first parameter is ``dimensions`` which is an iterable of strings which
581 defines the unit of processing the run method of a corresponding
582 `PipelineTask` will operate on. These dimensions must match dimensions that
583 exist in the butler registry which will be used in executing the
584 corresponding `PipelineTask`. The dimensions may be also modified in
585 subclass ``__init__`` methods if they need to depend on configuration.
587 The second parameter is labeled ``defaultTemplates`` and is conditionally
588 optional. The name attributes of connections can be specified as python
589 format strings, with named format arguments. If any of the name parameters
590 on connections defined in a `PipelineTaskConnections` class contain a
591 template, then a default template value must be specified in the
592 ``defaultTemplates`` argument. This is done by passing a dictionary with
593 keys corresponding to a template identifier, and values corresponding to
594 the value to use as a default when formatting the string. For example if
595 ``ConnectionsClass.calexp.name = '{input}Coadd_calexp'`` then
596 ``defaultTemplates = {'input': 'deep'}``.
598 Once a `PipelineTaskConnections` class is created, it is used in the
599 creation of a `PipelineTaskConfig`. This is further documented in the
600 documentation of `PipelineTaskConfig`. For the purposes of this
601 documentation, the relevant information is that the config class allows
602 configuration of connection names by users when running a pipeline.
604 Instances of a `PipelineTaskConnections` class are used by the pipeline
605 task execution framework to introspect what a corresponding `PipelineTask`
606 will require, and what it will produce.
608 Examples
609 --------
610 >>> from lsst.pipe.base import connectionTypes as cT
611 >>> from lsst.pipe.base import PipelineTaskConnections
612 >>> from lsst.pipe.base import PipelineTaskConfig
613 >>> class ExampleConnections(
614 ... PipelineTaskConnections,
615 ... dimensions=("A", "B"),
616 ... defaultTemplates={"foo": "Example"},
617 ... ):
618 ... inputConnection = cT.Input(
619 ... doc="Example input",
620 ... dimensions=("A", "B"),
621 ... storageClass="Exposure",
622 ... name="{foo}Dataset",
623 ... )
624 ... outputConnection = cT.Output(
625 ... doc="Example output",
626 ... dimensions=("A", "B"),
627 ... storageClass="Exposure",
628 ... name="{foo}output",
629 ... )
630 >>> class ExampleConfig(
631 ... PipelineTaskConfig, pipelineConnections=ExampleConnections
632 ... ):
633 ... pass
634 >>> config = ExampleConfig()
635 >>> config.connections.foo = "Modified"
636 >>> config.connections.outputConnection = "TotallyDifferent"
637 >>> connections = ExampleConnections(config=config)
638 >>> assert connections.inputConnection.name == "ModifiedDataset"
639 >>> assert connections.outputConnection.name == "TotallyDifferent"
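Connections may also be removed dynamically in ``__init__``. A minimal
sketch (the ``"NoOutput"`` template value is purely illustrative):

>>> class DynamicConnections(ExampleConnections):
...     def __init__(self, *, config=None):
...         if config.connections.foo == "NoOutput":
...             del self.outputConnection
>>> config = ExampleConfig()
>>> config.connections.foo = "NoOutput"
>>> dynamic = DynamicConnections(config=config)
>>> assert "outputConnection" not in dynamic.allConnections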
640 """
642 # We annotate these attributes as mutable sets because that's what they are
643 # inside derived ``__init__`` implementations, and that's what matters most.
644 # After that's done, the metaclass __call__ makes them into frozensets, but
645 # relatively little code interacts with them then, and that code knows not
646 # to try to modify them without having to be told that by mypy.
648 dimensions: set[str]
649 """Set of dimension names that define the unit of work for this task.
651 Required and implied dependencies will automatically be expanded later and
652 need not be provided.
654 This may be replaced or modified in ``__init__`` to change the dimensions
655 of the task. After ``__init__`` it will be a `frozenset` and may not be
656 replaced.
657 """
659 inputs: set[str]
660 """Set with the names of all `connectionTypes.Input` connection attributes.
662 This is updated automatically as class attributes are added, removed, or
663 replaced in ``__init__``. Removing entries from this set will cause those
664 connections to be removed after ``__init__`` completes, but this is
665 supported only for backwards compatibility; new code should instead just
666 delete the connection attribute directly. After ``__init__`` this will be
667 a `frozenset` and may not be replaced.
668 """
670 prerequisiteInputs: set[str]
671 """Set with the names of all `~connectionTypes.PrerequisiteInput`
672 connection attributes.
674 See `inputs` for additional information.
675 """
677 outputs: set[str]
678 """Set with the names of all `~connectionTypes.Output` connection
679 attributes.
681 See `inputs` for additional information.
682 """
684 initInputs: set[str]
685 """Set with the names of all `~connectionTypes.InitInput` connection
686 attributes.
688 See `inputs` for additional information.
689 """
691 initOutputs: set[str]
692 """Set with the names of all `~connectionTypes.InitOutput` connection
693 attributes.
695 See `inputs` for additional information.
696 """
698 allConnections: Mapping[str, BaseConnection]
699 """Mapping holding all connection attributes.
701 This is a read-only view that is automatically updated when connection
702 attributes are added, removed, or replaced in ``__init__``. It is also
703 updated after ``__init__`` completes to reflect changes in `inputs`,
704 `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
705 """
707 _allConnections: dict[str, BaseConnection]
709 def __init__(self, *, config: PipelineTaskConfig | None = None):
710 pass
712 def __setattr__(self, name: str, value: Any) -> None:
713 if isinstance(value, BaseConnection):
714 previous = self._allConnections.get(name)
715 try:
716 getattr(self, value._connection_type_set).add(name)
717 except AttributeError:
718 # Attempt to call add on a frozenset, which is what these sets
719 # are after __init__ is done.
720 raise TypeError("Connections objects are frozen after construction.") from None
721 if previous is not None and value._connection_type_set != previous._connection_type_set:
722 # Connection has changed type, e.g. Input to PrerequisiteInput;
723 # update the sets accordingly. To be extra defensive about
724 # multiple assignments we use the type of the previous instance
725 # instead of assuming that's the same as the type of the self,
726 # which is just the default. Use discard instead of remove so
727 # manually removing from these sets first is never an error.
728 getattr(self, previous._connection_type_set).discard(name)
729 self._allConnections[name] = value
730 if hasattr(self.__class__, name):
731 # Don't actually set the attribute if this was a connection
732 # declared in the class; in that case we let the descriptor
733 # return the value we just added to allConnections.
734 return
735 # Actually add the attribute.
736 super().__setattr__(name, value)
738 def __delattr__(self, name):
739 """Descriptor delete method."""
740 previous = self._allConnections.get(name)
741 if previous is not None:
742 # Delete this connection's name from the appropriate set, which we
743 # have to get from the previous instance instead of assuming it's
744 # the same set that was appropriate for the class-level default.
745 # Use discard instead of remove so manually removing from these
746 # sets first is never an error.
747 try:
748 getattr(self, previous._connection_type_set).discard(name)
749 except AttributeError:
750 # Attempt to call discard on a frozenset, which is what these
751 # sets are after __init__ is done.
752 raise TypeError("Connections objects are frozen after construction.") from None
753 del self._allConnections[name]
754 if hasattr(self.__class__, name):
755 # Don't actually delete the attribute if this was a connection
756 # declared in the class; in that case we let the descriptor
757 # see that it's no longer present in allConnections.
758 return
759 # Actually delete the attribute.
760 super().__delattr__(name)
762 def buildDatasetRefs(
763 self, quantum: Quantum
764 ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
765 """Build `QuantizedConnection` corresponding to input
766 `~lsst.daf.butler.Quantum`.
768 Parameters
769 ----------
770 quantum : `lsst.daf.butler.Quantum`
771 Quantum object which defines the inputs and outputs for a given
772 unit of processing.
774 Returns
775 -------
776 retVal : `tuple` of (`InputQuantizedConnection`, \
777 `OutputQuantizedConnection`)
778 Namespaces mapping attribute names
779 (identifiers of connections) to butler references defined in the
780 input `~lsst.daf.butler.Quantum`.
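For example, an execution harness might use this as follows (a sketch;
``connections`` is an instance of this class and ``calexp`` a
hypothetical scalar input connection):

.. code-block:: python

    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    calexp_ref = inputRefs.calexp  # single ref for a scalar connection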
781 """
782 inputDatasetRefs = InputQuantizedConnection()
783 outputDatasetRefs = OutputQuantizedConnection()
785 # populate inputDatasetRefs from quantum inputs
786 for attributeName in itertools.chain(self.inputs, self.prerequisiteInputs):
787 # get the attribute identified by name
788 attribute = getattr(self, attributeName)
789 # if the dataset is marked to load deferred, wrap it in a
790 # DeferredDatasetRef
791 quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
792 if attribute.deferLoad:
793 quantumInputRefs = [
794 DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
795 ]
796 else:
797 quantumInputRefs = list(quantum.inputs[attribute.name])
798 # Unpack arguments that are not marked multiples (list of
799 # length one)
800 if not attribute.multiple:
801 if len(quantumInputRefs) > 1:
802 raise ScalarError(
803 "Received multiple datasets "
804 f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
805 f"for scalar connection {attributeName} "
806 f"({quantumInputRefs[0].datasetType.name}) "
807 f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
808 )
809 if len(quantumInputRefs) == 0:
810 continue
811 setattr(inputDatasetRefs, attributeName, quantumInputRefs[0])
812 else:
813 # Add to the QuantizedConnection identifier
814 setattr(inputDatasetRefs, attributeName, quantumInputRefs)
816 # populate outputDatasetRefs from quantum outputs
817 for attributeName in self.outputs:
818 # get the attribute identified by name
819 attribute = getattr(self, attributeName)
820 value = quantum.outputs[attribute.name]
821 # Unpack arguments that are not marked multiples (list of
822 # length one)
823 if not attribute.multiple:
824 setattr(outputDatasetRefs, attributeName, value[0])
825 else:
826 setattr(outputDatasetRefs, attributeName, value)
828 return inputDatasetRefs, outputDatasetRefs
830 def adjustQuantum(
831 self,
832 inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
833 outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
834 label: str,
835 data_id: DataCoordinate,
836 ) -> tuple[
837 Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
838 Mapping[str, tuple[Output, Collection[DatasetRef]]],
839 ]:
840 """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
841 in the `lsst.daf.butler.Quantum` during the graph generation stage
842 of the activator.
844 Parameters
845 ----------
846 inputs : `dict`
847 Dictionary whose keys are an input (regular or prerequisite)
848 connection name and whose values are a tuple of the connection
849 instance and a collection of associated
850 `~lsst.daf.butler.DatasetRef` objects.
851 The exact type of the nested collections is unspecified; it can be
852 assumed to be multi-pass iterable and support `len` and ``in``, but
853 it should not be mutated in place. In contrast, the outer
854 dictionaries are guaranteed to be temporary copies that are true
855 `dict` instances, and hence may be modified and even returned; this
856 is especially useful for delegating to `super` (see notes below).
857 outputs : `~collections.abc.Mapping`
858 Mapping of output datasets, with the same structure as ``inputs``.
859 label : `str`
860 Label for this task in the pipeline (should be used in all
861 diagnostic messages).
862 data_id : `lsst.daf.butler.DataCoordinate`
863 Data ID for this quantum in the pipeline (should be used in all
864 diagnostic messages).
866 Returns
867 -------
868 adjusted_inputs : `~collections.abc.Mapping`
869 Mapping of the same form as ``inputs`` with updated containers of
870 input `~lsst.daf.butler.DatasetRef` objects. Connections that are
871 not changed should not be returned at all. Datasets may only be
872 removed, not added. Nested collections may be of any multi-pass
873 iterable type, and the order of iteration will set the order of
874 iteration within `PipelineTask.runQuantum`.
875 adjusted_outputs : `~collections.abc.Mapping`
876 Mapping of updated output datasets, with the same structure and
877 interpretation as ``adjusted_inputs``.
879 Raises
880 ------
881 ScalarError
882 Raised if any `Input` or `PrerequisiteInput` connection has
883 ``multiple`` set to `False`, but multiple datasets.
884 NoWorkFound
885 Raised to indicate that this quantum should not be run; not enough
886 datasets were found for a regular `Input` connection, and the
887 quantum should be pruned or skipped.
888 FileNotFoundError
889 Raised to cause QuantumGraph generation to fail (with the message
890 included in this exception); not enough datasets were found for a
891 `PrerequisiteInput` connection.
893 Notes
894 -----
895 The base class implementation performs important checks. It always
896 returns an empty mapping (i.e. makes no adjustments). It should
897 always be called via `super` by custom implementations, ideally at the
898 end of the custom implementation with already-adjusted mappings when
899 any datasets are actually dropped, e.g.:
901 .. code-block:: python
903 def adjustQuantum(self, inputs, outputs, label, data_id):
904 # Filter out some dataset refs for one connection.
905 connection, old_refs = inputs["my_input"]
906 new_refs = [ref for ref in old_refs if ...]
907 adjusted_inputs = {"my_input", (connection, new_refs)}
908 # Update the original inputs so we can pass them to super.
909 inputs.update(adjusted_inputs)
910 # Can ignore outputs from super because they are guaranteed
911 # to be empty.
912 super().adjustQuantum(inputs, outputs, label, data_id)
913 # Return only the connections we modified.
914 return adjusted_inputs, {}
916 Removing outputs here is guaranteed to affect what is actually
917 passed to `PipelineTask.runQuantum`, but its effect on the larger
918 graph may be deferred to execution, depending on the context in
919 which `adjustQuantum` is being run: if one quantum removes an output
920 that is needed by a second quantum as input, the second quantum may not
921 be adjusted (and hence pruned or skipped) until that output is actually
922 found to be missing at execution time.
924 Tasks that desire zip-iteration consistency between any combinations of
925 connections that have the same data ID should generally implement
926 `adjustQuantum` to achieve this, even if they could also run that
927 logic during execution; this allows the system to see outputs that will
928 not be produced because the corresponding input is missing as early as
929 possible.
930 """
931 for name, (input_connection, refs) in inputs.items():
932 dataset_type_name = input_connection.name
933 if not input_connection.multiple and len(refs) > 1:
934 raise ScalarError(
935 f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
936 f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
937 f"for quantum data ID {data_id}."
938 )
939 if len(refs) < input_connection.minimum:
940 if isinstance(input_connection, PrerequisiteInput):
941 # This branch should only be possible during QG generation,
942 # or if someone deleted the dataset between making the QG
943 # and trying to run it. Either one should be a hard error.
944 raise FileNotFoundError(
945 f"Not enough datasets ({len(refs)}) found for non-optional connection {label}.{name} "
946 f"({dataset_type_name}) with minimum={input_connection.minimum} for quantum data ID "
947 f"{data_id}."
948 )
949 else:
950 raise NoWorkFound(label, name, input_connection)
951 for name, (output_connection, refs) in outputs.items():
952 dataset_type_name = output_connection.name
953 if not output_connection.multiple and len(refs) > 1:
954 raise ScalarError(
955 f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
956 f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
957 f"for quantum data ID {data_id}."
958 )
959 return {}, {}
961 def getSpatialBoundsConnections(self) -> Iterable[str]:
962 """Return the names of regular input and output connections whose data
963 IDs should be used to compute the spatial bounds of this task's quanta.
965 The spatial bound for a quantum is defined as the union of the regions
966 of all data IDs of all connections returned here, along with the region
967 of the quantum data ID (if the task has spatial dimensions).
969 Returns
970 -------
971 connection_names : `collections.abc.Iterable` [ `str` ]
972 Names of connections with spatial dimensions. These are the
973 task-internal connection names, not butler dataset type names.
975 Notes
976 -----
977 The spatial bound is used to search for prerequisite inputs that have
978 skypix dimensions. The default implementation returns an empty
979 iterable, which is usually sufficient for tasks with spatial
980 dimensions, but if a task's inputs or outputs are associated with
981 spatial regions that extend beyond the quantum data ID's region, this
982 method may need to be overridden to expand the set of prerequisite
983 inputs found.
985 Tasks that do not have spatial dimensions but do have skypix
986 prerequisite inputs should always override this method, as the default
987 spatial bounds otherwise cover the full sky.
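For example, a task whose input datasets cover a larger region than its
quantum data ID might return one of its connection names (a sketch;
``coadd_inputs`` is a hypothetical connection name):

.. code-block:: python

    def getSpatialBoundsConnections(self):
        return ("coadd_inputs",)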
988 """
989 return ()
991 def getTemporalBoundsConnections(self) -> Iterable[str]:
992 """Return the names of regular input and output connections whose data
993 IDs should be used to compute the temporal bounds of this task's
994 quanta.
996 The temporal bound for a quantum is defined as the union of the
997 timespans of all data IDs of all connections returned here, along with
998 the timespan of the quantum data ID (if the task has temporal
999 dimensions).
1001 Returns
1002 -------
1003 connection_names : `collections.abc.Iterable` [ `str` ]
1004 Names of connections with temporal dimensions. These are the
1005 task-internal connection names, not butler dataset type names.
1007 Notes
1008 -----
1009 The temporal bound is used to search for prerequisite inputs that are
1010 calibration datasets. The default implementation returns an empty
1011 iterable, which is usually sufficient for tasks with temporal
1012 dimensions, but if a task's inputs or outputs are associated with
1013 timespans that extend beyond the quantum data ID's timespan, this
1014 method may need to be overridden to expand the set of prerequisite
1015 inputs found.
1017 Tasks that do not have temporal dimensions and do not implement this
1018 method will use an infinite timespan for any calibration lookups.
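For example, a task whose input exposures have timespans that extend
beyond its quantum data ID might return (a sketch; ``input_exposures``
is a hypothetical connection name):

.. code-block:: python

    def getTemporalBoundsConnections(self):
        return ("input_exposures",)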
1019 """
1020 return ()
1022 def adjust_all_quanta(self, adjuster: QuantaAdjuster) -> None:
1023 """Customize the set of quanta predicted for this task during quantum
1024 graph generation.
1026 Parameters
1027 ----------
1028 adjuster : `QuantaAdjuster`
1029 A helper object that implementations can use to modify the
1030 under-construction quantum graph.
1032 Notes
1033 -----
1034 This hook is called before `adjustQuantum`, which is where built-in
1035 checks for `NoWorkFound` cases and missing prerequisites are handled.
1036 This means that the set of preliminary quanta seen by this method could
1037 include some that would normally be dropped later.
1038 """
1039 pass
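# A minimal sketch (hypothetical connection name) of an
# ``adjust_all_quanta`` override: drop every preliminary quantum that
# received no datasets for a particular input connection.
class _DemoAdjustingConnections(PipelineTaskConnections, dimensions=("visit",)):
    def adjust_all_quanta(self, adjuster: QuantaAdjuster) -> None:
        # Snapshot the data IDs first, since quanta are removed below.
        for data_id in list(adjuster.iter_data_ids()):
            if not adjuster.get_inputs(data_id).get("calexp"):
                adjuster.remove_quantum(data_id)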
1042def iterConnections(
1043 connections: PipelineTaskConnections, connectionType: str | Iterable[str]
1044) -> Generator[BaseConnection]:
1045 """Create an iterator over the selected connections type which yields
1046 all the defined connections of that type.
1048 Parameters
1049 ----------
1050 connections : `PipelineTaskConnections`
1051 An instance of a `PipelineTaskConnections` object that will be iterated
1052 over.
1053 connectionType : `str` or `~collections.abc.Iterable` [ `str` ]
1054 The type(s) of connections to iterate over; valid values are inputs,
1055 outputs, prerequisiteInputs, initInputs, and initOutputs.
1057 Yields
1058 ------
1059 connection: `~.connectionTypes.BaseConnection`
1060 A connection defined on the input connections object of the type
1061 supplied. The yielded value will be a derived type of
1062 `~.connectionTypes.BaseConnection`.
1063 """
1064 if isinstance(connectionType, str):
1065 connectionType = (connectionType,)
1066 for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
1067 yield getattr(connections, name)
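# A minimal sketch of iterConnections usage: print the butler dataset
# type name of every input-like connection on a connections instance.
def _demo_iter_connections(connections: PipelineTaskConnections) -> None:
    for connection in iterConnections(connections, ("inputs", "prerequisiteInputs")):
        print(connection.name)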
1070@dataclass
1071class AdjustQuantumHelper:
1072 """Helper class for calling `PipelineTaskConnections.adjustQuantum`.
1074 This class holds `inputs` and `outputs` mappings in the form used by
1075 `lsst.daf.butler.Quantum` and execution harness code, i.e. with
1076 `~lsst.daf.butler.DatasetType` keys, translating them to and from the
1077 connection-oriented mappings used inside `PipelineTaskConnections`.
1078 """
1080 inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
1081 """Mapping of regular input and prerequisite input datasets, grouped by
1082 `~lsst.daf.butler.DatasetType`.
1083 """
1085 outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
1086 """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
1087 """
1089 inputs_adjusted: bool = False
1090 """Whether any inputs were removed in the last call to `adjust_in_place`.
1091 """
1093 outputs_adjusted: bool = False
1094 """Whether any outputs were removed in the last call to `adjust_in_place`.
1095 """
1097 def adjust_in_place(
1098 self,
1099 connections: PipelineTaskConnections,
1100 label: str,
1101 data_id: DataCoordinate,
1102 ) -> None:
1103 """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
1104 with its results.
1106 Parameters
1107 ----------
1108 connections : `PipelineTaskConnections`
1109 Instance on which to call `~PipelineTaskConnections.adjustQuantum`.
1110 label : `str`
1111 Label for this task in the pipeline (should be used in all
1112 diagnostic messages).
1113 data_id : `lsst.daf.butler.DataCoordinate`
1114 Data ID for this quantum in the pipeline (should be used in all
1115 diagnostic messages).
1116 """
1117 # Translate self's DatasetType-keyed, Quantum-oriented mappings into
1118 # connection-keyed, PipelineTask-oriented mappings.
1119 inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
1120 outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
1121 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
1122 connection = getattr(connections, name)
1123 dataset_type_name = connection.name
1124 inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
1125 for name in itertools.chain(connections.outputs):
1126 connection = getattr(connections, name)
1127 dataset_type_name = connection.name
1128 outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
1129 # Actually call adjustQuantum.
1130 # MyPy correctly complains that this call is not quite legal, but the
1131 # method docs explain exactly what's expected and it's the behavior we
1132 # want. It'd be nice to avoid this if we ever have to change the
1133 # interface anyway, but not an immediate problem.
1134 adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
1135 inputs_by_connection, # type: ignore
1136 outputs_by_connection, # type: ignore
1137 label,
1138 data_id,
1139 )
1140 # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
1141 # installing new mappings in self if necessary.
1142 if adjusted_inputs_by_connection:
1143 adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
1144 for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
1145 dataset_type_name = connection.name
1146 if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
1147 raise RuntimeError(
1148 f"adjustQuantum implementation for task with label {label} returned {name} "
1149 f"({dataset_type_name}) input datasets that are not a subset of those "
1150 f"it was given for data ID {data_id}."
1151 )
1152 adjusted_inputs[dataset_type_name] = tuple(updated_refs)
1153 self.inputs = adjusted_inputs.freeze()
1154 self.inputs_adjusted = True
1155 else:
1156 self.inputs_adjusted = False
1157 if adjusted_outputs_by_connection:
1158 adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
1159 for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
1160 dataset_type_name = connection.name
1161 if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
1162 raise RuntimeError(
1163 f"adjustQuantum implementation for task with label {label} returned {name} "
1164 f"({dataset_type_name}) output datasets that are not a subset of those "
1165 f"it was given for data ID {data_id}."
1166 )
1167 adjusted_outputs[dataset_type_name] = tuple(updated_refs)
1168 self.outputs = adjusted_outputs.freeze()
1169 self.outputs_adjusted = True
1170 else:
1171 self.outputs_adjusted = False
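# A minimal sketch of how harness code might drive the helper above;
# ``demoTask`` is a hypothetical task label and ``quantum`` an
# already-built `lsst.daf.butler.Quantum` for that task.
def _demo_adjust_quantum_helper(connections: PipelineTaskConnections, quantum: Quantum) -> None:
    helper = AdjustQuantumHelper(inputs=quantum.inputs, outputs=quantum.outputs)
    helper.adjust_in_place(connections, label="demoTask", data_id=quantum.dataId)
    if helper.inputs_adjusted or helper.outputs_adjusted:
        # A new Quantum would be built here from helper.inputs and
        # helper.outputs before execution.
        pass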
1174class QuantaAdjuster:
1175 """A helper class for the `PipelineTaskConnections.adjust_all_quanta` hook.
1177 Parameters
1178 ----------
1179 task_label : `str`
1180 Label of the task whose quanta are being adjusted.
1181 pipeline_graph : `pipeline_graph.PipelineGraph`
1182 Pipeline graph the quantum graph is being built from.
1183 skeleton : `quantum_graph_skeleton.QuantumGraphSkeleton`
1184 Under-construction quantum graph that will be modified in place.
1185 butler : `lsst.daf.butler.Butler`
1186 Read-only instance with its default collection search path set to the
1187 input collections passed to the quantum-graph builder.
1188 """
1190 def __init__(
1191 self, task_label: str, pipeline_graph: PipelineGraph, skeleton: QuantumGraphSkeleton, butler: Butler
1192 ):
1193 self._task_node = pipeline_graph.tasks[task_label]
1194 self._pipeline_graph = pipeline_graph
1195 self._skeleton = skeleton
1196 self._n_removed = 0
1197 self._butler = butler
1199 @property
1200 def task_label(self) -> str:
1201 """The label this task has been configured with."""
1202 return self._task_node.label
1204 @property
1205 def task_node(self) -> TaskNode:
1206 """The node for this task in the pipeline graph."""
1207 return self._task_node
1209 @property
1210 def butler(self) -> Butler:
1211 """Read-only instance with its default collection search path set to
1212 the input collections passed to the quantum-graph builder.
1213 """
1214 return self._butler
1216 def iter_data_ids(self) -> Iterator[DataCoordinate]:
1217 """Iterate over the data IDs of all quanta for this task.
1219 Returns
1220 -------
1221 data_ids : `~collections.abc.Iterator` [ \
1222 `~lsst.daf.butler.DataCoordinate` ]
1223 Data IDs. These are minimal data IDs without dimension records or
1224 implied values; use `expand_quantum_data_id` to get a full data ID
1225 when needed.
1226 """
1227 for key in self._skeleton.get_quanta(self._task_node.label):
1228 yield DataCoordinate.from_required_values(self._task_node.dimensions, key.data_id_values)
1230 def remove_quantum(self, data_id: DataCoordinate) -> None:
1231 """Remove a quantum from the graph.
1233 Parameters
1234 ----------
1235 data_id : `~lsst.daf.butler.DataCoordinate`
1236 Data ID of the quantum to remove. All outputs will be removed as
1237 well.
1238 """
1239 from .quantum_graph_skeleton import QuantumKey
1241 self._skeleton.remove_quantum_node(
1242 QuantumKey(self._task_node.label, data_id.required_values), remove_outputs=True
1243 )
1244 self._n_removed += 1
1246 def get_inputs(self, quantum_data_id: DataCoordinate) -> dict[str, list[DataCoordinate]]:
1247 """Return the data IDs of all regular inputs to a quantum.
1249 Parameters
1250 ----------
1251 quantum_data_id : `~lsst.daf.butler.DataCoordinate`
1252 Data ID of the quantum to get the inputs of.
1254 Returns
1255 -------
1256 inputs : `dict` [ `str`, `list` [ `~lsst.daf.butler.DataCoordinate` ] ]
1257 Data IDs of inputs, keyed by the connection name (the internal task
1258 name, not the dataset type name). This only contains regular
1259 inputs, not init-inputs or prerequisite inputs.
1261 Notes
1262 -----
1263 If two connections have the same dataset type, the current
1264 implementation assumes the set of datasets is the same for the two
1265 connections. This limitation may be removed in the future.
1266 """
1267 from .quantum_graph_skeleton import DatasetKey, QuantumKey
1269 by_dataset_type_name: defaultdict[str, list[DataCoordinate]] = defaultdict(list)
1270 quantum_key = QuantumKey(self._task_node.label, quantum_data_id.required_values)
1271 for dataset_key in self._skeleton.iter_inputs_of(quantum_key):
1272 if not isinstance(dataset_key, DatasetKey):
1273 continue
1274 dataset_type_node = self._pipeline_graph.dataset_types[dataset_key.parent_dataset_type_name]
1275 by_dataset_type_name[dataset_key.parent_dataset_type_name].append(
1276 DataCoordinate.from_required_values(dataset_type_node.dimensions, dataset_key.data_id_values)
1277 )
1278 return {
1279 edge.connection_name: by_dataset_type_name[edge.parent_dataset_type_name]
1280 for edge in self._task_node.iter_all_inputs()
1281 }
1283 def get_prerequisite_inputs(
1284 self,
1285 quantum_data_id: DataCoordinate,
1286 ) -> dict[str, dict[uuid.UUID, DataCoordinate]]:
1287 """Return the data IDs of all prerequisite inputs to a quantum.
1289 Parameters
1290 ----------
1291 quantum_data_id : `~lsst.daf.butler.DataCoordinate`
1292 Data ID of the quantum to get the inputs of.
1294 Returns
1295 -------
1296 inputs : `dict` [ `str`, \
1297 `dict` [ `uuid.UUID`, `~lsst.daf.butler.DataCoordinate` ] ]
1298 Dataset IDs and data IDs of prerequisite inputs, keyed by the
1299 connection name (the internal task name, not the dataset type
1300 name). This only contains prerequisite inputs, not init-inputs or
1301 regular inputs.
1303 Notes
1304 -----
1305 If two connections have the same dataset type, the current
1306 implementation assumes the set of datasets is the same for the two
1307 connections. This limitation may be removed in the future.
1309 Unlike regular inputs, prerequisite inputs are not looked up from input
1310 collections or indexed by data ID. Instead, they are uniquely
1311 identified by dataset UUID and reused directly between quanta.
1312 """
1313 from .quantum_graph_skeleton import PrerequisiteDatasetKey, QuantumKey
1315 by_dataset_type_name: defaultdict[str, dict[uuid.UUID, DataCoordinate]] = defaultdict(dict)
1316 quantum_key = QuantumKey(self._task_node.label, quantum_data_id.required_values)
1318 for dataset_key in self._skeleton.iter_inputs_of(quantum_key):
1319 if not isinstance(dataset_key, PrerequisiteDatasetKey):
1320 continue
1321 by_dataset_type_name[dataset_key.parent_dataset_type_name][
1322 uuid.UUID(bytes=dataset_key.dataset_id_bytes)
1323 ] = self._skeleton.get_data_id(dataset_key)
1324 return {
1325 edge.connection_name: by_dataset_type_name[edge.parent_dataset_type_name]
1326 for edge in self._task_node.iter_all_inputs()
1327 }
1329 def get_outputs(self, quantum_data_id: DataCoordinate) -> dict[str, list[DataCoordinate]]:
1330 """Return the data IDs of all regular outputs to a quantum.
1332 Parameters
1333 ----------
1334 quantum_data_id : `~lsst.daf.butler.DataCoordinate`
1335 Data ID of the quantum to get the outputs of.
1337 Returns
1338 -------
1339 outputs : `dict` [ `str`, `list` [ `~lsst.daf.butler.DataCoordinate` ] ]
1340 Data IDs of outputs, keyed by the connection name (the internal task
1341 name, not the dataset type name). This only contains regular
1342 outputs, not init-outputs, log, or metadata outputs.
1344 Notes
1345 -----
1346 If two connections have the same dataset type, the current
1347 implementation assumes the set of datasets is the same for the two
1348 connections. This limitation may be removed in the future.
1349 """
1350 from .quantum_graph_skeleton import QuantumKey
1352 by_dataset_type_name: defaultdict[str, list[DataCoordinate]] = defaultdict(list)
1353 quantum_key = QuantumKey(self._task_node.label, quantum_data_id.required_values)
1354 for dataset_key in self._skeleton.iter_outputs_of(quantum_key):
1355 dataset_type_node = self._pipeline_graph.dataset_types[dataset_key.parent_dataset_type_name]
1356 by_dataset_type_name[dataset_key.parent_dataset_type_name].append(
1357 DataCoordinate.from_required_values(dataset_type_node.dimensions, dataset_key.data_id_values)
1358 )
1359 return {
1360 edge.connection_name: by_dataset_type_name[edge.parent_dataset_type_name]
1361 for edge in self._task_node.outputs.values()
1362 }
1364 def add_input(
1365 self, quantum_data_id: DataCoordinate, connection_name: str, dataset_data_id: DataCoordinate
1366 ) -> None:
1367 """Add a new input to a quantum.
1369 Parameters
1370 ----------
1371 quantum_data_id : `~lsst.daf.butler.DataCoordinate`
1372 Data ID of the quantum to add an input to.
1373 connection_name : `str`
1374 Name of the connection (the task-internal name, not the butler
1375 dataset type name).
1376 dataset_data_id : `~lsst.daf.butler.DataCoordinate`
1377 Data ID of the input dataset. Must already exist in the graph
1378 as an input to a different quantum of this task, and must be a
1379 regular input, not a prerequisite input or init-input.
1381 Notes
1382 -----
1383 If two connections have the same dataset type, the current
1384 implementation assumes the set of datasets is the same for the two
1385 connections. This limitation may be removed in the future.
1386 """
1387 from .quantum_graph_skeleton import DatasetKey, QuantumKey
1389 quantum_key = QuantumKey(self._task_node.label, quantum_data_id.required_values)
1390 read_edge = self._task_node.inputs[connection_name]
1391 dataset_key = DatasetKey(read_edge.parent_dataset_type_name, dataset_data_id.required_values)
1392 if dataset_key not in self._skeleton:
1393 raise LookupError(
1394 f"Dataset {read_edge.parent_dataset_type_name}@{dataset_data_id} is not already in the graph."
1395 )
1396 self._skeleton.add_input_edge(quantum_key, dataset_key)
1398 def add_prerequisite_input(
1399 self, quantum_data_id: DataCoordinate, connection_name: str, dataset_uuid: uuid.UUID
1400 ) -> None:
1401 """Add a new prerequisite input to a quantum.
1403 Parameters
1404 ----------
1405 quantum_data_id : `~lsst.daf.butler.DataCoordinate`
1406 Data ID of the quantum to add an input to.
1407 connection_name : `str`
1408 Name of the connection (the task-internal name, not the butler
1409 dataset type name).
1410 dataset_uuid : `uuid.UUID`
1411 UUID of the prerequisite input dataset. Must already exist in the
1412 graph as an input to a different quantum of this task, and must be
1413 a prerequisite input, not a regular input or init-input.
1415 Notes
1416 -----
1417 If two connections have the same dataset type, the current
1418 implementation assumes the set of datasets is the same for the two
1419 connections. This limitation may be removed in the future.
1421 Unlike regular inputs, prerequisite inputs are not looked up from input
1422 collections or indexed by data ID. Instead, they are uniquely
1423 identified by dataset UUID and reused directly between quanta.
1424 """
1425 from .quantum_graph_skeleton import PrerequisiteDatasetKey, QuantumKey
1427 quantum_key = QuantumKey(self._task_node.label, quantum_data_id.required_values)
1428 read_edge = self._task_node.prerequisite_inputs[connection_name]
1429 dataset_key = PrerequisiteDatasetKey(read_edge.parent_dataset_type_name, dataset_uuid.bytes)
1430 if dataset_key not in self._skeleton:
1431 raise LookupError(
1432 f"Prerequisite Dataset {read_edge.parent_dataset_type_name}@{dataset_uuid} "
1433 "is not already in the graph."
1434 )
1435 self._skeleton.add_input_edge(quantum_key, dataset_key)
1437 def move_output(
1438 self, quantum_data_id: DataCoordinate, connection_name: str, dataset_data_id: DataCoordinate
1439 ) -> None:
1440 """Remove an output of one quantum and make it a new output of another.
1442 Parameters
1443 ----------
1444 quantum_data_id : `~lsst.daf.butler.DataCoordinate`
1445 Data ID of the quantum to move the output to.
1446 connection_name : `str`
1447 Name of the connection (the task-internal name, not the butler
1448 dataset type name).
1449 dataset_data_id : `~lsst.daf.butler.DataCoordinate`
1450 Data ID of the output dataset. Must already exist in the graph
1451 as an output of a different quantum of this task.
1452 """
1453 from .quantum_graph_skeleton import DatasetKey, QuantumKey
1455 quantum_key = QuantumKey(self._task_node.label, quantum_data_id.required_values)
1456 write_edge = self._task_node.outputs[connection_name]
1457 dataset_key = DatasetKey(write_edge.parent_dataset_type_name, dataset_data_id.required_values)
1458 if dataset_key not in self._skeleton:
1459 raise LookupError(
1460 f"Dataset {write_edge.parent_dataset_type_name}@{dataset_data_id} is "
1461 "not already in the graph."
1462 )
1463 self._skeleton.remove_output_edge(dataset_key)
1464 self._skeleton.add_output_edge(quantum_key, dataset_key)
1466 def expand_quantum_data_id(self, data_id: DataCoordinate) -> DataCoordinate:
1467 """Expand a quantum data ID to include implied values and dimension
1468 records.
1470 Parameters
1471 ----------
1472 data_id : `~lsst.daf.butler.DataCoordinate`
1473 A data ID of a quantum already in the graph.
1475 Returns
1476 -------
1477 expanded_data_id : `~lsst.daf.butler.DataCoordinate`
1478 The same data ID, with implied values included and dimension
1479 records attached.
1480 """
1481 from .quantum_graph_skeleton import QuantumKey
1483 return self._skeleton.get_data_id(QuantumKey(self._task_node.label, data_id.required_values))
1485 @property
1486 def n_removed(self) -> int:
1487 """The number of quanta that have been removed by this helper."""
1488 return self._n_removed