Coverage for python/lsst/pipe/base/connections.py: 42%
302 statements
coverage.py v7.2.7, created at 2023-07-12 11:14 -0700

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

import dataclasses
import itertools
import string
import warnings
from collections import UserDict
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum
from lsst.utils.introspection import find_outside_stacklevel

from ._status import NoWorkFound
from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput

if TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by `PipelineTaskConnectionsMetaclass`.

    This dict is used in `PipelineTaskConnections` class creation, as the
    dictionary that is initially used as ``__dict__``. It exists to
    intercept connection fields declared in a `PipelineTaskConnections`
    class and record the names used to identify them. The names are then
    added to a class-level list according to the connection type of the
    class attribute. The names are also used as keys in a class-level
    dictionary associated with the corresponding class attribute. This
    information duplicates what exists in ``__dict__``, but provides a
    simple place to look up and iterate over only these variables.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection.
        self.data["inputs"] = set()
        self.data["prerequisiteInputs"] = set()
        self.data["outputs"] = set()
        self.data["initInputs"] = set()
        self.data["initOutputs"] = set()
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            if name in {
                "dimensions",
                "inputs",
                "prerequisiteInputs",
                "outputs",
                "initInputs",
                "initOutputs",
                "allConnections",
            }:
                # Guard against connections whose names are reserved.
                raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
            if (previous := self.data.get(name)) is not None:
                # Guard against changing the type of an inherited connection
                # by first removing it from the set it's currently in.
                self.data[previous._connection_type_set].discard(name)
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
            self.data[value._connection_type_set].add(name)
        # Defer to the default behavior.
        super().__setitem__(name, value)

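
# For illustration only (not part of the original module): a sketch of how
# this dict behaves while a connections class body executes. ``Input`` here
# is assumed to be `lsst.pipe.base.connectionTypes.Input`; the connection
# arguments are hypothetical.
#
#     >>> dct = PipelineTaskConnectionDict()
#     >>> dct["calexp"] = Input(doc="example", name="calexp",
#     ...                       storageClass="ExposureF", dimensions=())
#     >>> sorted(dct["inputs"])
#     ['calexp']
#     >>> list(dct["allConnections"])
#     ['calexp']
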
class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes"""

    # We can annotate these attributes as `collections.abc.Set` to discourage
    # undesirable modifications in type-checked code, since the internal code
    # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
    # these annotations anyway.

    dimensions: Set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    inputs: Set[str]
    """Set with the names of all `~connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added. Note that
    this attribute is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    prerequisiteInputs: Set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: Set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: Set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: Set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping containing all connection attributes.

    See `inputs` for additional information.
    """

    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in the class
            # declaration.
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be iterable of dimensions, got str; "
                        "possibly omitted trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute.
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections.
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # Add all the parameters to the set of templates.
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together.
            mergeDict = {}
            mergeDeprecationsDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
                if hasattr(base, "deprecatedTemplates"):
                    mergeDeprecationsDict.update(base.deprecatedTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])
            if "deprecatedTemplates" in kwargs:
                mergeDeprecationsDict.update(kwargs["deprecatedTemplates"])
            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict
            if len(mergeDeprecationsDict) > 0:
                kwargs["deprecatedTemplates"] = mergeDeprecationsDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class.
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnections class contains templated attribute names, but no "
                    "default templates were provided; add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not.
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with Class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})
            dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope.
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # Our custom dict type must be turned into an actual dict to be used
        # in type.__new__.
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses.
        super().__init__(name, bases, dct)

    def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
        # MyPy appears not to really understand metaclass.__call__ at all, so
        # we need to tell it to ignore __new__ and __init__ calls here.
        instance: PipelineTaskConnections = cls.__new__(cls)  # type: ignore

        # Make mutable copies of all set-like class attributes so derived
        # __init__ implementations can modify them in place.
        instance.dimensions = set(cls.dimensions)
        instance.inputs = set(cls.inputs)
        instance.prerequisiteInputs = set(cls.prerequisiteInputs)
        instance.outputs = set(cls.outputs)
        instance.initInputs = set(cls.initInputs)
        instance.initOutputs = set(cls.initOutputs)

        # Set self.config. It's a bit strange that we claim to accept None
        # but really just raise here, but it's not worth changing now.
        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        instance.config = config

        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time.
        templateValues = {
            name: getattr(config.connections, name) for name in getattr(cls, "defaultTemplates").keys()
        }

        # We now assemble a mapping of all connection instances keyed by
        # internal name, applying the configuration and templates to make new
        # configurations from the class-attribute defaults. This will be
        # private, but with a public read-only view. This mapping is what the
        # descriptor interface of the class-level attributes will return when
        # they are accessed on an instance. This is better than just
        # assigning regular instance attributes as it makes it so removed
        # connections cannot be accessed on instances, instead of having
        # access to them silently fall through to the not-removed class
        # connection instance.
        instance._allConnections = {}
        instance.allConnections = MappingProxyType(instance._allConnections)
        for internal_name, connection in cls.allConnections.items():
            dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
            instance_connection = dataclasses.replace(
                connection,
                name=dataset_type_name,
                doc=(
                    connection.doc
                    if connection.deprecated is None
                    else f"{connection.doc}\n{connection.deprecated}"
                ),
            )
            instance._allConnections[internal_name] = instance_connection

        # Finally call __init__. The base class implementation does nothing;
        # we could have left some of the above implementation there (where it
        # originated), but putting it here instead makes it hard for derived
        # class implementors to get things into a weird state by delegating
        # to super().__init__ in the wrong place, or by forgetting to do that
        # entirely.
        instance.__init__(config=config)  # type: ignore

        # Derived-class implementations may have changed the contents of the
        # various kinds-of-connection sets; update allConnections to have
        # keys that are a union of all those. We get values for the new
        # allConnections from the attributes, since any dynamically added new
        # ones will not be present in the old allConnections. Typically those
        # getattrs will invoke the descriptors and get things from the old
        # allConnections anyway. After processing each set we replace it with
        # a frozenset.
        updated_all_connections = {}
        for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
            updated_connection_names = getattr(instance, attrName)
            updated_all_connections.update(
                {name: getattr(instance, name) for name in updated_connection_names}
            )
            # Setting these to frozenset is at odds with the type annotation,
            # but MyPy can't tell because we're using setattr, and we want to
            # lie to it anyway to get runtime guards against post-__init__
            # mutation.
            setattr(instance, attrName, frozenset(updated_connection_names))
        # Update the existing dict in place, since we already have a view of
        # that.
        instance._allConnections.clear()
        instance._allConnections.update(updated_all_connections)

        for obj in instance._allConnections.values():
            if obj.deprecated is not None:
                warnings.warn(
                    obj.deprecated, FutureWarning, stacklevel=find_outside_stacklevel("lsst.pipe.base")
                )

        # Freeze the connection instance dimensions now. This is at odds with
        # the type annotation, which says [mutable] `set`, just like the
        # connection type attributes (e.g. `inputs`, `outputs`, etc.), though
        # MyPy can't tell with those since we're using setattr for them.
        instance.dimensions = frozenset(instance.dimensions)  # type: ignore

        return instance

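
# For illustration only (not part of the original module): connection name
# templates are ordinary Python format strings, so the substitution applied
# in ``__call__`` above is equivalent to:
#
#     >>> "{foo}Dataset".format(**{"foo": "Example"})
#     'ExampleDataset'
#
# and the template keys gathered in ``__new__`` come from
# `string.Formatter.parse`:
#
#     >>> [field for _, field, _, _ in string.Formatter().parse("{foo}Dataset")
#     ...  if field is not None]
#     ['foo']
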
class QuantizedConnection(SimpleNamespace):
    r"""A Namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `~lsst.daf.butler.DatasetRef`\s provided by a `~lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnections` class.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance.
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
        # Capture the attribute name as it is added to this object.
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]:
        """Make an iterator for this `QuantizedConnection`.

        Iterating over a `QuantizedConnection` will yield a tuple with the
        name of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        `QuantizedConnection` class.
        """
        yield from self._attributes

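
# For illustration only (not part of the original module): a sketch of the
# namespace behavior, with a hypothetical `~lsst.daf.butler.DatasetRef`
# named ``ref``.
#
#     >>> refs = QuantizedConnection()
#     >>> refs.calexp = ref
#     >>> [name for name, _ in refs]
#     ['calexp']
#     >>> list(refs.keys())
#     ['calexp']
#     >>> len(refs)
#     1
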
class InputQuantizedConnection(QuantizedConnection):
    """Input variant of a `QuantizedConnection`."""

    pass


class OutputQuantizedConnection(QuantizedConnection):
    """Output variant of a `QuantizedConnection`."""

    pass


@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that a
    `PipelineTask` should receive a `~lsst.daf.butler.DeferredDatasetHandle`
    instead of an in-memory dataset.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
    """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId

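
# For illustration only (not part of the original module): connections
# declared with ``deferLoad=True`` have their refs wrapped like this in
# `buildDatasetRefs` below, so the task receives a handle it can load
# lazily. ``ref`` is a hypothetical resolved `~lsst.daf.butler.DatasetRef`.
#
#     >>> deferred = DeferredDatasetRef(datasetRef=ref)
#     >>> deferred.dataId == ref.dataId
#     True
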
class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class.
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``.
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)``
      where X matches the ``storageClass`` type defined on the output
      connection.

    Attributes of these types can also be created, replaced, or deleted on
    the `PipelineTaskConnections` instance in the ``__init__`` method, if
    more than just the name depends on the configuration. It is preferred to
    define them in the class when possible (even if configuration may cause
    the connection to be removed from the instance).

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, which is an iterable of strings
    that defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`. The dimensions may also be modified in
    subclass ``__init__`` methods if they need to depend on configuration.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example if ``ConnectionsClass.calexp.name =
    '{input}Coadd_calexp'``, then ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    # We annotate these attributes as mutable sets because that's what they
    # are inside derived ``__init__`` implementations and that's what matters
    # most. After that's done, the metaclass __call__ makes them into
    # frozensets, but relatively little code interacts with them then, and
    # that code knows not to try to modify them without having to be told
    # that by mypy.

    dimensions: set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This may be replaced or modified in ``__init__`` to change the dimensions
    of the task. After ``__init__`` it will be a `frozenset` and may not be
    replaced.
    """

    inputs: set[str]
    """Set with the names of all `connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added, removed, or
    replaced in ``__init__``. Removing entries from this set will cause those
    connections to be removed after ``__init__`` completes, but this is
    supported only for backwards compatibility; new code should instead just
    delete the connection attribute directly. After ``__init__`` this will
    be a `frozenset` and may not be replaced.
    """

    prerequisiteInputs: set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping holding all connection attributes.

    This is a read-only view that is automatically updated when connection
    attributes are added, removed, or replaced in ``__init__``. It is also
    updated after ``__init__`` completes to reflect changes in `inputs`,
    `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
    """

    _allConnections: dict[str, BaseConnection]

    def __init__(self, *, config: PipelineTaskConfig | None = None):
        pass

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            previous = self._allConnections.get(name)
            try:
                getattr(self, value._connection_type_set).add(name)
            except AttributeError:
                # Attempt to call add on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            if previous is not None and value._connection_type_set != previous._connection_type_set:
                # Connection has changed type, e.g. Input to
                # PrerequisiteInput; update the sets accordingly. To be extra
                # defensive about multiple assignments we use the type of the
                # previous instance instead of assuming it's the same as the
                # class-level default. Use discard instead of remove so
                # manually removing from these sets first is never an error.
                getattr(self, previous._connection_type_set).discard(name)
            self._allConnections[name] = value
            if hasattr(self.__class__, name):
                # Don't actually set the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # return the value we just added to allConnections.
                return
        # Actually add the attribute.
        super().__setattr__(name, value)

    def __delattr__(self, name):
        """Descriptor delete method."""
        previous = self._allConnections.get(name)
        if previous is not None:
            # Delete this connection's name from the appropriate set, which
            # we have to get from the previous instance instead of assuming
            # it's the same set that was appropriate for the class-level
            # default. Use discard instead of remove so manually removing
            # from these sets first is never an error.
            try:
                getattr(self, previous._connection_type_set).discard(name)
            except AttributeError:
                # Attempt to call discard on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            del self._allConnections[name]
            if hasattr(self.__class__, name):
                # Don't actually delete the attribute if this was a
                # connection declared in the class; in that case we let the
                # descriptor see that it's no longer present in
                # allConnections.
                return
        # Actually delete the attribute.
        super().__delattr__(name)

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build `QuantizedConnection` corresponding to input
        `~lsst.daf.butler.Quantum`.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`, `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # Operate on a reference object and an iterable of names of class
        # connection attributes.
        for refs, names in zip(
            (inputDatasetRefs, outputDatasetRefs),
            (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
        ):
            # Get a name of a class connection attribute.
            for attributeName in names:
                # Get the attribute identified by name.
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input.
                if attribute.name in quantum.inputs:
                    # If the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef.
                    quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
                    if attribute.deferLoad:
                        quantumInputRefs = [
                            DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                        ]
                    else:
                        quantumInputRefs = list(quantum.inputs[attribute.name])
                    # Unpack arguments that are not marked multiples (list of
                    # length one).
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                "Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        setattr(refs, attributeName, quantumInputRefs[0])
                    else:
                        # Add to the QuantizedConnection identifier.
                        setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output.
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one).
                    if not attribute.multiple:
                        setattr(refs, attributeName, value[0])
                    else:
                        setattr(refs, attributeName, value)
                # The specified attribute is not in inputs or outputs; we
                # don't know how to handle it, so throw.
                else:
                    raise ValueError(
                        f"Attribute with name {attributeName} has no counterpart in input quantum"
                    )
        return inputDatasetRefs, outputDatasetRefs

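
    # For illustration only (not part of the original module): a sketch of
    # how an execution harness might consume the result, given a configured
    # ``connections`` instance and a hypothetical ``quantum``:
    #
    #     >>> inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #     >>> for attributeName, refOrRefs in inputRefs:
    #     ...     ...  # scalar DatasetRef, DeferredDatasetRef, or a list
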
    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated
            `~lsst.daf.butler.DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that are
            true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `~collections.abc.Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `~collections.abc.Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `~lsst.daf.butler.DatasetRef` objects. Connections that are
            not changed should not be returned at all. Datasets may only be
            removed, not added. Nested collections may be of any multi-pass
            iterable type, and the order of iteration will set the order of
            iteration within `PipelineTask.runQuantum`.
        adjusted_outputs : `~collections.abc.Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets are present.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the message
            included in this exception); not enough datasets were found for a
            `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.:

        .. code-block:: python

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={input_connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose
                    # inputs are either already present or should be created
                    # during execution. It can trigger during execution if
                    # the input wasn't actually created by an upstream task
                    # in the same graph.
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(
    connections: PipelineTaskConnections, connectionType: str | Iterable[str]
) -> Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connections type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or iterable of `str`
        The type of connections to iterate over; valid values are
        ``inputs``, ``outputs``, ``prerequisiteInputs``, ``initInputs``, and
        ``initOutputs``.

    Yields
    ------
    connection : `~.connectionTypes.BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `~.connectionTypes.BaseConnection`.
    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)

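
# For illustration only (not part of the original module): a usage sketch,
# assuming ``connections`` is a configured `PipelineTaskConnections`
# instance.
#
#     >>> for connection in iterConnections(connections, "inputs"):
#     ...     print(connection.name)
#
# Several connection types can be walked at once:
#
#     >>> for connection in iterConnections(connections,
#     ...                                   ("inputs", "outputs")):
#     ...     print(connection.name)
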
@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `input` and `output` mappings in the form used by
    `Quantum` and execution harness code, i.e. with
    `~lsst.daf.butler.DatasetType` keys, translating them to and from the
    connection-oriented mappings used inside `PipelineTaskConnections`.
    """

    inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `~lsst.daf.butler.DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in itertools.chain(connections.outputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior
        # we want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = tuple(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = tuple(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False

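
# For illustration only (not part of the original module): a sketch of how
# harness code might drive the helper, with a hypothetical ``quantum`` and a
# configured ``connections`` instance.
#
#     >>> helper = AdjustQuantumHelper(inputs=quantum.inputs,
#     ...                              outputs=quantum.outputs)
#     >>> helper.adjust_in_place(connections, label="exampleTask",
#     ...                        data_id=quantum.dataId)
#     >>> if helper.inputs_adjusted or helper.outputs_adjusted:
#     ...     ...  # rebuild the Quantum from helper.inputs / helper.outputs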