Coverage for python/lsst/pipe/base/connections.py: 44%
303 statements
coverage.py v7.4.4, created at 2024-04-10 03:25 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Module defining connection classes for PipelineTask.
29"""
31from __future__ import annotations
33__all__ = [
34 "AdjustQuantumHelper",
35 "DeferredDatasetRef",
36 "InputQuantizedConnection",
37 "OutputQuantizedConnection",
38 "PipelineTaskConnections",
39 "ScalarError",
40 "iterConnections",
41 "ScalarError",
42]
44import dataclasses
45import itertools
46import string
47import warnings
48from collections import UserDict
49from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
50from dataclasses import dataclass
51from types import MappingProxyType, SimpleNamespace
52from typing import TYPE_CHECKING, Any
54from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum
56from ._status import NoWorkFound
57from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput
59if TYPE_CHECKING:
60 from .config import PipelineTaskConfig
63class ScalarError(TypeError):
64 """Exception raised when dataset type is configured as scalar
65 but there are multiple data IDs in a Quantum for that dataset.
66 """
69class PipelineTaskConnectionDict(UserDict):
70 """A special dict class used by `PipelineTaskConnectionMetaclass`.
72 This dict is used in `PipelineTaskConnection` class creation, as the
73 dictionary that is initially used as ``__dict__``. It exists to
74 intercept connection fields declared in a `PipelineTaskConnection` and
75 record the names used to identify them. The names are then added to a
76 class-level list according to the connection type of the class attribute.
77 The names are also used as keys in a class-level dictionary associated
78 with the corresponding class attribute. This information duplicates what
79 exists in ``__dict__``, but provides a simple place to look up and
80 iterate over only these variables.
82 Parameters
83 ----------
84 *args : `~typing.Any`
85 Passed to `dict` constructor.
86 **kwargs : `~typing.Any`
87 Passed to `dict` constructor.
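 Examples
 --------
 A sketch of the bookkeeping this dict performs during class creation
 (the connection arguments here are illustrative, not from a real task):
 .. code-block:: python
     from lsst.pipe.base import connectionTypes

     dct = PipelineTaskConnectionDict()
     dct["calexp"] = connectionTypes.Input(
         doc="An example input.",
         name="calexp",
         storageClass="ExposureF",
         dimensions=("visit", "detector"),
     )
     assert "calexp" in dct["inputs"]
     assert "calexp" in dct["allConnections"]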
88 """
90 def __init__(self, *args: Any, **kwargs: Any):
91 super().__init__(*args, **kwargs)
92 # Initialize the class-level containers used to track any declared
93 # class attributes that are instances of
94 # connectionTypes.BaseConnection.
95 self.data["inputs"] = set()
96 self.data["prerequisiteInputs"] = set()
97 self.data["outputs"] = set()
98 self.data["initInputs"] = set()
99 self.data["initOutputs"] = set()
100 self.data["allConnections"] = {}
102 def __setitem__(self, name: str, value: Any) -> None:
103 if isinstance(value, BaseConnection):
104 if name in {
105 "dimensions",
106 "inputs",
107 "prerequisiteInputs",
108 "outputs",
109 "initInputs",
110 "initOutputs",
111 "allConnections",
112 }:
113 # Guard against connections whose names are reserved.
114 raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
115 if (previous := self.data.get(name)) is not None:
116 # Guard against changing the type of an inherited connection
117 # by first removing it from the set it's currently in.
118 self.data[previous._connection_type_set].discard(name)
119 object.__setattr__(value, "varName", name)
120 self.data["allConnections"][name] = value
121 self.data[value._connection_type_set].add(name)
122 # defer to the default behavior
123 super().__setitem__(name, value)
126class PipelineTaskConnectionsMetaclass(type):
127 """Metaclass used in the declaration of PipelineTaskConnections classes.
129 Parameters
130 ----------
131 name : `str`
132 Name of the class being created.
133 bases : `~collections.abc.Collection`
134 Base classes.
135 dct : `~collections.abc.Mapping`
136 Mapping of class attributes, including connections.
137 **kwargs : `~typing.Any`
138 Additional parameters.
139 """
141 # We can annotate these attributes as `collections.abc.Set` to discourage
142 # undesirable modifications in type-checked code, since the internal code
143 # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
144 # these annotations anyway.
146 dimensions: Set[str]
147 """Set of dimension names that define the unit of work for this task.
149 Required and implied dependencies will automatically be expanded later and
150 need not be provided.
152 This is shadowed by an instance-level attribute on
153 `PipelineTaskConnections` instances.
154 """
156 inputs: Set[str]
157 """Set with the names of all `~connectionTypes.Input` connection
158 attributes.
160 This is updated automatically as class attributes are added. Note that
161 this attribute is shadowed by an instance-level attribute on
162 `PipelineTaskConnections` instances.
163 """
165 prerequisiteInputs: Set[str]
166 """Set with the names of all `~connectionTypes.PrerequisiteInput`
167 connection attributes.
169 See `inputs` for additional information.
170 """
172 outputs: Set[str]
173 """Set with the names of all `~connectionTypes.Output` connection
174 attributes.
176 See `inputs` for additional information.
177 """
179 initInputs: Set[str]
180 """Set with the names of all `~connectionTypes.InitInput` connection
181 attributes.
183 See `inputs` for additional information.
184 """
186 initOutputs: Set[str]
187 """Set with the names of all `~connectionTypes.InitOutput` connection
188 attributes.
190 See `inputs` for additional information.
191 """
193 allConnections: Mapping[str, BaseConnection]
194 """Mapping containing all connection attributes.
196 See `inputs` for additional information.
197 """
199 def __prepare__(name, bases, **kwargs): # noqa: N804
200 # Create an instance of our special dict to catch and track all
201 # variables that are instances of connectionTypes.BaseConnection
202 # Copy any existing connections from a parent class
203 dct = PipelineTaskConnectionDict()
204 for base in bases:
205 if isinstance(base, PipelineTaskConnectionsMetaclass):
206 for name, value in base.allConnections.items():
207 dct[name] = value
208 return dct
210 def __new__(cls, name, bases, dct, **kwargs):
211 dimensionsValueError = TypeError(
212 "PipelineTaskConnections class must be created with a dimensions "
213 "attribute which is an iterable of dimension names"
214 )
216 if name != "PipelineTaskConnections":
217 # Verify that dimensions are passed as a keyword in class
218 # declaration
219 if "dimensions" not in kwargs: 219 ↛ 220line 219 didn't jump to line 220, because the condition on line 219 was never true
220 for base in bases:
221 if hasattr(base, "dimensions"):
222 kwargs["dimensions"] = base.dimensions
223 break
224 if "dimensions" not in kwargs:
225 raise dimensionsValueError
226 try:
227 if isinstance(kwargs["dimensions"], str): 227 ↛ 228line 227 didn't jump to line 228, because the condition on line 227 was never true
228 raise TypeError(
229 "Dimensions must be iterable of dimensions, got str,possibly omitted trailing comma"
230 )
231 if not isinstance(kwargs["dimensions"], Iterable):
232 raise TypeError("Dimensions must be iterable of dimensions")
233 dct["dimensions"] = set(kwargs["dimensions"])
234 except TypeError as exc:
235 raise dimensionsValueError from exc
236 # Look up any Python string templates that may have been used in the
237 # declaration of the name field of a class connection attribute
238 allTemplates = set()
239 stringFormatter = string.Formatter()
240 # Loop over all connections
241 for obj in dct["allConnections"].values():
242 nameValue = obj.name
243 # add all the parameters to the set of templates
244 for param in stringFormatter.parse(nameValue):
245 if param[1] is not None:
246 allTemplates.add(param[1])
248 # look up any template from base classes and merge them all
249 # together
250 mergeDict = {}
251 mergeDeprecationsDict = {}
252 for base in bases[::-1]:
253 if hasattr(base, "defaultTemplates"):
254 mergeDict.update(base.defaultTemplates)
255 if hasattr(base, "deprecatedTemplates"):
256 mergeDeprecationsDict.update(base.deprecatedTemplates)
257 if "defaultTemplates" in kwargs:
258 mergeDict.update(kwargs["defaultTemplates"])
259 if "deprecatedTemplates" in kwargs: 259 ↛ 260line 259 didn't jump to line 260, because the condition on line 259 was never true
260 mergeDeprecationsDict.update(kwargs["deprecatedTemplates"])
261 if len(mergeDict) > 0:
262 kwargs["defaultTemplates"] = mergeDict
263 if len(mergeDeprecationsDict) > 0:
264 kwargs["deprecatedTemplates"] = mergeDeprecationsDict
266 # Verify that if templated strings were used, defaults were
267 # supplied as an argument in the declaration of the connection
268 # class
269 if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
270 raise TypeError(
271 "PipelineTaskConnection class contains templated attribute names, but no "
272 "defaut templates were provided, add a dictionary attribute named "
273 "defaultTemplates which contains the mapping between template key and value"
274 )
275 if len(allTemplates) > 0:
276 # Verify all templates have a default, and throw if they do not
277 defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
278 templateDifference = allTemplates.difference(defaultTemplateKeys)
279 if templateDifference:
280 raise TypeError(f"Default template keys were not provided for {templateDifference}")
281 # Verify that templates do not share names with variable names
282 # used for a connection; this is needed because of how
283 # templates are specified in an associated config class.
284 nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
285 if len(nameTemplateIntersection) > 0:
286 raise TypeError(
287 "Template parameters cannot share names with Class attributes"
288 f" (conflicts are {nameTemplateIntersection})."
289 )
290 dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})
291 dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {})
293 # Convert all the connection containers into frozensets so they cannot
294 # be modified at the class scope
295 for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
296 dct[connectionName] = frozenset(dct[connectionName])
297 # our custom dict type must be turned into an actual dict to be used in
298 # type.__new__
299 return super().__new__(cls, name, bases, dict(dct))
301 def __init__(cls, name, bases, dct, **kwargs):
302 # This overrides the default init to drop the kwargs argument. Python
303 # metaclasses will have this argument set if any kwargs are passed at
304 # class construction time, but should be consumed before calling
305 # __init__ on the type metaclass. This is in accordance with python
306 # documentation on metaclasses
307 super().__init__(name, bases, dct)
309 def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
310 # MyPy appears not to really understand metaclass.__call__ at all, so
311 # we need to tell it to ignore __new__ and __init__ calls here.
312 instance: PipelineTaskConnections = cls.__new__(cls) # type: ignore
314 # Make mutable copies of all set-like class attributes so derived
315 # __init__ implementations can modify them in place.
316 instance.dimensions = set(cls.dimensions)
317 instance.inputs = set(cls.inputs)
318 instance.prerequisiteInputs = set(cls.prerequisiteInputs)
319 instance.outputs = set(cls.outputs)
320 instance.initInputs = set(cls.initInputs)
321 instance.initOutputs = set(cls.initOutputs)
323 # Set self.config. It's a bit strange that we claim to accept None but
324 # really just raise here, but it's not worth changing now.
325 from .config import PipelineTaskConfig # local import to avoid cycle
327 if config is None or not isinstance(config, PipelineTaskConfig):
328 raise ValueError(
329 "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
330 )
331 instance.config = config
333 # Extract the template names that were defined in the config instance
334 # by looping over the keys of the defaultTemplates dict specified at
335 # class declaration time.
336 templateValues = {
337 name: getattr(config.connections, name) for name in cls.defaultTemplates # type: ignore
338 }
340 # We now assemble a mapping of all connection instances keyed by
341 # internal name, applying the configuration and templates to make new
342 # configurations from the class-attribute defaults. This will be
343 # private, but with a public read-only view. This mapping is what the
344 # descriptor interface of the class-level attributes will return when
345 # they are accessed on an instance. This is better than just assigning
346 # regular instance attributes as it makes it so removed connections
347 # cannot be accessed on instances, instead of having access to them
348 # silently fall through to the not-removed class connection instance.
349 instance._allConnections = {}
350 instance.allConnections = MappingProxyType(instance._allConnections)
351 for internal_name, connection in cls.allConnections.items():
352 dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
353 instance_connection = dataclasses.replace(
354 connection,
355 name=dataset_type_name,
356 doc=(
357 connection.doc
358 if connection.deprecated is None
359 else f"{connection.doc}\n{connection.deprecated}"
360 ),
361 _deprecation_context=connection._deprecation_context,
362 )
363 instance._allConnections[internal_name] = instance_connection
365 # Finally call __init__. The base class implementation does nothing;
366 # we could have left some of the above implementation there (where it
367 # originated), but putting it here instead makes it hard for derived
368 # class implementors to get things into a weird state by delegating to
369 # super().__init__ in the wrong place, or by forgetting to do that
370 # entirely.
371 instance.__init__(config=config) # type: ignore
373 # Derived-class implementations may have changed the contents of the
374 # various kinds-of-connection sets; update allConnections to have keys
375 # that are a union of all those. We get values for the new
376 # allConnections from the attributes, since any dynamically added new
377 # ones will not be present in the old allConnections. Typically those
378 # getattrs will invoke the descriptors and get things from the old
379 # allConnections anyway. After processing each set we replace it with
380 # a frozenset.
381 updated_all_connections = {}
382 for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
383 updated_connection_names = getattr(instance, attrName)
384 updated_all_connections.update(
385 {name: getattr(instance, name) for name in updated_connection_names}
386 )
387 # Setting these to frozenset is at odds with the type annotation,
388 # but MyPy can't tell because we're using setattr, and we want to
389 # lie to it anyway to get runtime guards against post-__init__
390 # mutation.
391 setattr(instance, attrName, frozenset(updated_connection_names))
392 # Update the existing dict in place, since we already have a view of
393 # that.
394 instance._allConnections.clear()
395 instance._allConnections.update(updated_all_connections)
397 for connection_name, obj in instance._allConnections.items():
398 if obj.deprecated is not None:
399 warnings.warn(
400 f"Connection {connection_name} with datasetType {obj.name} "
401 f"(from {obj._deprecation_context}): {obj.deprecated}",
402 FutureWarning,
403 stacklevel=1, # Report from this location.
404 )
406 # Freeze the connection instance dimensions now. This is at odds with the
407 # type annotation, which says [mutable] `set`, just like the connection
408 # type attributes (e.g. `inputs`, `outputs`, etc.), though MyPy can't
409 # tell with those since we're using setattr for them.
410 instance.dimensions = frozenset(instance.dimensions) # type: ignore
412 return instance
415class QuantizedConnection(SimpleNamespace):
416 r"""A Namespace to map defined variable names of connections to the
417 associated `lsst.daf.butler.DatasetRef` objects.
419 This class maps the names used to define a connection on a
420 `PipelineTaskConnections` class to the corresponding
421 `~lsst.daf.butler.DatasetRef`\s provided by a `~lsst.daf.butler.Quantum`
422 instance. This will be a quantum of execution based on the graph created
423 by examining all the connections defined on the
424 `PipelineTaskConnections` class.
426 Parameters
427 ----------
428 **kwargs : `~typing.Any`
429 Not used.
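 Examples
 --------
 A sketch of typical access patterns (the ``calexp`` attribute name and
 ``ref`` value are hypothetical):
 .. code-block:: python
     inputRefs = InputQuantizedConnection()
     inputRefs.calexp = ref  # a lsst.daf.butler.DatasetRef
     for name, refs in inputRefs:
         print(name, refs)
     assert "calexp" in set(inputRefs.keys())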
430 """
432 def __init__(self, **kwargs):
433 # Create a variable to track what attributes are added. This is used
434 # later when iterating over this QuantizedConnection instance
435 object.__setattr__(self, "_attributes", set())
437 def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
438 # Capture the attribute name as it is added to this object
439 self._attributes.add(name)
440 super().__setattr__(name, value)
442 def __delattr__(self, name):
443 object.__delattr__(self, name)
444 self._attributes.remove(name)
446 def __len__(self) -> int:
447 return len(self._attributes)
449 def __iter__(
450 self,
451 ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]:
452 """Make an iterator for this `QuantizedConnection`.
454 Iterating over a `QuantizedConnection` will yield a tuple with the name
455 of an attribute and the value associated with that name. This is
456 similar to dict.items() but is on the namespace attributes rather than
457 dict keys.
458 """
459 yield from ((name, getattr(self, name)) for name in self._attributes)
461 def keys(self) -> Generator[str, None, None]:
462 """Return an iterator over all the attributes added to a
463 `QuantizedConnection` instance.
464 """
465 yield from self._attributes
468class InputQuantizedConnection(QuantizedConnection):
469 """Input variant of a `QuantizedConnection`."""
471 pass
474class OutputQuantizedConnection(QuantizedConnection):
475 """Output variant of a `QuantizedConnection`."""
477 pass
480@dataclass(frozen=True)
481class DeferredDatasetRef:
482 """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that a
483 `PipelineTask` should receive a `~lsst.daf.butler.DeferredDatasetHandle`
484 instead of an in-memory dataset.
486 Attributes
487 ----------
488 datasetRef : `lsst.daf.butler.DatasetRef`
489 The `lsst.daf.butler.DatasetRef` that will eventually be used to
490 resolve a dataset.
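 Notes
 -----
 `PipelineTaskConnections.buildDatasetRefs` applies this wrapper
 automatically when the corresponding connection is declared with
 ``deferLoad=True``, e.g. (a sketch; the connection and dataset type
 names are hypothetical):
 .. code-block:: python
     calexps = connectionTypes.Input(
         doc="Input exposures, loaded on demand at execution time.",
         name="calexp",
         storageClass="ExposureF",
         dimensions=("visit", "detector"),
         multiple=True,
         deferLoad=True,  # refs arrive wrapped as DeferredDatasetRef
     )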
491 """
493 datasetRef: DatasetRef
495 @property
496 def datasetType(self) -> DatasetType:
497 """The dataset type for this dataset."""
498 return self.datasetRef.datasetType
500 @property
501 def dataId(self) -> DataCoordinate:
502 """The data ID for this dataset."""
503 return self.datasetRef.dataId
506class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
507 """PipelineTaskConnections is a class used to declare desired IO when a
508 PipelineTask is run by an activator.
510 Parameters
511 ----------
512 config : `PipelineTaskConfig`
513 A `PipelineTaskConfig` class instance whose class has been configured
514 to use this `PipelineTaskConnections` class.
516 See Also
517 --------
518 iterConnections : Iterator over selected connections.
520 Notes
521 -----
522 ``PipelineTaskConnection`` classes are created by declaring class
523 attributes of types defined in `lsst.pipe.base.connectionTypes` and are
524 listed as follows:
526 * ``InitInput`` - Defines connections in a quantum graph which are used as
527 inputs to the ``__init__`` function of the `PipelineTask` corresponding
528 to this class
529 * ``InitOutput`` - Defines connections in a quantum graph which are to be
530 persisted using a butler at the end of the ``__init__`` function of the
531 `PipelineTask` corresponding to this class. The variable name used to
532 define this connection should be the same as an attribute name on the
533 `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
534 the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
535 a `PipelineTask` instance should have an attribute
536 ``self.outputSchema`` defined. Its value is what will be saved by the
537 activator framework.
538 * ``PrerequisiteInput`` - An input connection type that defines a
539 `lsst.daf.butler.DatasetType` that must be present at execution time,
540 but that will not be used during the course of creating the quantum
541 graph to be executed. These most often are things produced outside the
542 processing pipeline, such as reference catalogs.
543 * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be used
544 in the ``run`` method of a `PipelineTask`. The name used to declare the
545 class attribute must match a function argument name in the ``run``
546 method of a `PipelineTask`. E.g. if the ``PipelineTaskConnections``
547 defines an ``Input`` with the name ``calexp``, then the corresponding
548 signature should be ``PipelineTask.run(calexp, ...)``
549 * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
550 execution of a `PipelineTask`. The name used to declare the connection
551 must correspond to an attribute of a `Struct` that is returned by a
552 `PipelineTask` ``run`` method. E.g. if an output connection is
553 defined with the name ``measCat``, then the corresponding
554 ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)`` where
555 ``X`` matches the ``storageClass`` type defined on the output connection.
557 Attributes of these types can also be created, replaced, or deleted on the
558 `PipelineTaskConnections` instance in the ``__init__`` method, if more than
559 just the name depends on the configuration. It is preferred to define them
560 in the class when possible (even if configuration may cause the connection
561 to be removed from the instance).
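 For example, a minimal sketch of removing an output connection in
 ``__init__`` (``doWriteExtras`` is a hypothetical config field):
 .. code-block:: python
     class ExampleConnections(PipelineTaskConnections,
                              dimensions=("visit",)):
         extras = connectionTypes.Output(
             doc="Optional extra output.",
             name="extras",
             storageClass="StructuredDataDict",
             dimensions=("visit",),
         )

         def __init__(self, *, config=None):
             super().__init__(config=config)
             if not config.doWriteExtras:
                 del self.extras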
563 The process of declaring a ``PipelineTaskConnection`` class involves
564 parameters passed in the declaration statement.
566 The first parameter is ``dimensions`` which is an iterable of strings which
567 defines the unit of processing the run method of a corresponding
568 `PipelineTask` will operate on. These dimensions must match dimensions that
569 exist in the butler registry which will be used in executing the
570 corresponding `PipelineTask`. The dimensions may be also modified in
571 subclass ``__init__`` methods if they need to depend on configuration.
573 The second parameter is labeled ``defaultTemplates`` and is conditionally
574 optional. The name attributes of connections can be specified as python
575 format strings, with named format arguments. If any of the name parameters
576 on connections defined in a `PipelineTaskConnections` class contain a
577 template, then a default template value must be specified in the
578 ``defaultTemplates`` argument. This is done by passing a dictionary with
579 keys corresponding to a template identifier, and values corresponding to
580 the value to use as a default when formatting the string. For example,
581 if ``ConnectionsClass.calexp.name = '{input}Coadd_calexp'`` then
582 ``defaultTemplates = {'input': 'deep'}``.
584 Once a `PipelineTaskConnections` class is created, it is used in the
585 creation of a `PipelineTaskConfig`. This is further documented in the
586 documentation of `PipelineTaskConfig`. For the purposes of this
587 documentation, the relevant information is that the config class allows
588 configuration of connection names by users when running a pipeline.
590 Instances of a `PipelineTaskConnections` class are used by the pipeline
591 task execution framework to introspect what a corresponding `PipelineTask`
592 will require, and what it will produce.
594 Examples
595 --------
596 >>> from lsst.pipe.base import connectionTypes as cT
597 >>> from lsst.pipe.base import PipelineTaskConnections
598 >>> from lsst.pipe.base import PipelineTaskConfig
599 >>> class ExampleConnections(PipelineTaskConnections,
600 ... dimensions=("A", "B"),
601 ... defaultTemplates={"foo": "Example"}):
602 ... inputConnection = cT.Input(doc="Example input",
603 ... dimensions=("A", "B"),
604 ... storageClass="Exposure",
605 ... name="{foo}Dataset")
606 ... outputConnection = cT.Output(doc="Example output",
607 ... dimensions=("A", "B"),
608 ... storageClass="Exposure",
609 ... name="{foo}output")
610 >>> class ExampleConfig(PipelineTaskConfig,
611 ... pipelineConnections=ExampleConnections):
612 ... pass
613 >>> config = ExampleConfig()
614 >>> config.connections.foo = "Modified"
615 >>> config.connections.outputConnection = "TotallyDifferent"
616 >>> connections = ExampleConnections(config=config)
617 >>> assert(connections.inputConnection.name == "ModifiedDataset")
618 >>> assert(connections.outputConnection.name == "TotallyDifferent")
619 """
621 # We annotate these attributes as mutable sets because that's what they are
622 # inside derived ``__init__`` implementations, and that's what matters most.
623 # After that's done, the metaclass __call__ makes them into frozensets, but
624 # relatively little code interacts with them then, and that code knows not
625 # to modify them without needing mypy to say so.
627 dimensions: set[str]
628 """Set of dimension names that define the unit of work for this task.
630 Required and implied dependencies will automatically be expanded later and
631 need not be provided.
633 This may be replaced or modified in ``__init__`` to change the dimensions
634 of the task. After ``__init__`` it will be a `frozenset` and may not be
635 replaced.
636 """
638 inputs: set[str]
639 """Set with the names of all `connectionTypes.Input` connection attributes.
641 This is updated automatically as class attributes are added, removed, or
642 replaced in ``__init__``. Removing entries from this set will cause those
643 connections to be removed after ``__init__`` completes, but this is
644 supported only for backwards compatibility; new code should instead just
645 delete the connection attribute directly. After ``__init__`` this will be
646 a `frozenset` and may not be replaced.
647 """
649 prerequisiteInputs: set[str]
650 """Set with the names of all `~connectionTypes.PrerequisiteInput`
651 connection attributes.
653 See `inputs` for additional information.
654 """
656 outputs: set[str]
657 """Set with the names of all `~connectionTypes.Output` connection
658 attributes.
660 See `inputs` for additional information.
661 """
663 initInputs: set[str]
664 """Set with the names of all `~connectionTypes.InitInput` connection
665 attributes.
667 See `inputs` for additional information.
668 """
670 initOutputs: set[str]
671 """Set with the names of all `~connectionTypes.InitOutput` connection
672 attributes.
674 See `inputs` for additional information.
675 """
677 allConnections: Mapping[str, BaseConnection]
678 """Mapping holding all connection attributes.
680 This is a read-only view that is automatically updated when connection
681 attributes are added, removed, or replaced in ``__init__``. It is also
682 updated after ``__init__`` completes to reflect changes in `inputs`,
683 `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
684 """
686 _allConnections: dict[str, BaseConnection]
688 def __init__(self, *, config: PipelineTaskConfig | None = None):
689 pass
691 def __setattr__(self, name: str, value: Any) -> None:
692 if isinstance(value, BaseConnection):
693 previous = self._allConnections.get(name)
694 try:
695 getattr(self, value._connection_type_set).add(name)
696 except AttributeError:
697 # Attempt to call add on a frozenset, which is what these sets
698 # are after __init__ is done.
699 raise TypeError("Connections objects are frozen after construction.") from None
700 if previous is not None and value._connection_type_set != previous._connection_type_set:
701 # Connection has changed type, e.g. Input to PrerequisiteInput;
702 # update the sets accordingly. To be extra defensive about
703 # multiple assignments we use the type of the previous instance
704 # instead of assuming it's the same as the type of self,
705 # which is just the default. Use discard instead of remove so
706 # manually removing from these sets first is never an error.
707 getattr(self, previous._connection_type_set).discard(name)
708 self._allConnections[name] = value
709 if hasattr(self.__class__, name):
710 # Don't actually set the attribute if this was a connection
711 # declared in the class; in that case we let the descriptor
712 # return the value we just added to allConnections.
713 return
714 # Actually add the attribute.
715 super().__setattr__(name, value)
717 def __delattr__(self, name):
718 """Descriptor delete method."""
719 previous = self._allConnections.get(name)
720 if previous is not None:
721 # Delete this connection's name from the appropriate set, which we
722 # have to get from the previous instance instead of assuming it's
723 # the same set that was appropriate for the class-level default.
724 # Use discard instead of remove so manually removing from these
725 # sets first is never an error.
726 try:
727 getattr(self, previous._connection_type_set).discard(name)
728 except AttributeError:
729 # Attempt to call discard on a frozenset, which is what these
730 # sets are after __init__ is done.
731 raise TypeError("Connections objects are frozen after construction.") from None
732 del self._allConnections[name]
733 if hasattr(self.__class__, name):
734 # Don't actually delete the attribute if this was a connection
735 # declared in the class; in that case we let the descriptor
736 # see that it's no longer present in allConnections.
737 return
738 # Actually delete the attribute.
739 super().__delattr__(name)
741 def buildDatasetRefs(
742 self, quantum: Quantum
743 ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
744 """Build `QuantizedConnection` corresponding to input
745 `~lsst.daf.butler.Quantum`.
747 Parameters
748 ----------
749 quantum : `lsst.daf.butler.Quantum`
750 Quantum object which defines the inputs and outputs for a given
751 unit of processing.
753 Returns
754 -------
755 retVal : `tuple` of (`InputQuantizedConnection`, \
756 `OutputQuantizedConnection`)
757 Namespaces mapping attribute names
758 (identifiers of connections) to butler references defined in the
759 input `~lsst.daf.butler.Quantum`.
760 """
761 inputDatasetRefs = InputQuantizedConnection()
762 outputDatasetRefs = OutputQuantizedConnection()
764 # populate inputDatasetRefs from quantum inputs
765 for attributeName in itertools.chain(self.inputs, self.prerequisiteInputs):
766 # get the attribute identified by name
767 attribute = getattr(self, attributeName)
768 # if the dataset is marked to load deferred, wrap it in a
769 # DeferredDatasetRef
770 quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
771 if attribute.deferLoad:
772 quantumInputRefs = [
773 DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
774 ]
775 else:
776 quantumInputRefs = list(quantum.inputs[attribute.name])
777 # Unpack arguments that are not marked multiples (list of
778 # length one)
779 if not attribute.multiple:
780 if len(quantumInputRefs) > 1:
781 raise ScalarError(
782 "Received multiple datasets "
783 f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
784 f"for scalar connection {attributeName} "
785 f"({quantumInputRefs[0].datasetType.name}) "
786 f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
787 )
788 if len(quantumInputRefs) == 0:
789 continue
790 setattr(inputDatasetRefs, attributeName, quantumInputRefs[0])
791 else:
792 # Add to the QuantizedConnection identifier
793 setattr(inputDatasetRefs, attributeName, quantumInputRefs)
795 # populate outputDatasetRefs from quantum outputs
796 for attributeName in self.outputs:
797 # get the attribute identified by name
798 attribute = getattr(self, attributeName)
799 value = quantum.outputs[attribute.name]
800 # Unpack arguments that are not marked multiples (list of
801 # length one)
802 if not attribute.multiple:
803 setattr(outputDatasetRefs, attributeName, value[0])
804 else:
805 setattr(outputDatasetRefs, attributeName, value)
807 return inputDatasetRefs, outputDatasetRefs
809 def adjustQuantum(
810 self,
811 inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
812 outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
813 label: str,
814 data_id: DataCoordinate,
815 ) -> tuple[
816 Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
817 Mapping[str, tuple[Output, Collection[DatasetRef]]],
818 ]:
819 """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
820 in the `lsst.daf.butler.Quantum` during the graph generation stage
821 of the activator.
823 Parameters
824 ----------
825 inputs : `dict`
826 Dictionary whose keys are an input (regular or prerequisite)
827 connection name and whose values are a tuple of the connection
828 instance and a collection of associated
829 `~lsst.daf.butler.DatasetRef` objects.
830 The exact type of the nested collections is unspecified; it can be
831 assumed to be multi-pass iterable and support `len` and ``in``, but
832 it should not be mutated in place. In contrast, the outer
833 dictionaries are guaranteed to be temporary copies that are true
834 `dict` instances, and hence may be modified and even returned; this
835 is especially useful for delegating to `super` (see notes below).
836 outputs : `~collections.abc.Mapping`
837 Mapping of output datasets, with the same structure as ``inputs``.
838 label : `str`
839 Label for this task in the pipeline (should be used in all
840 diagnostic messages).
841 data_id : `lsst.daf.butler.DataCoordinate`
842 Data ID for this quantum in the pipeline (should be used in all
843 diagnostic messages).
845 Returns
846 -------
847 adjusted_inputs : `~collections.abc.Mapping`
848 Mapping of the same form as ``inputs`` with updated containers of
849 input `~lsst.daf.butler.DatasetRef` objects. Connections that are
850 not changed should not be returned at all. Datasets may only be
851 removed, not added. Nested collections may be of any multi-pass
852 iterable type, and the order of iteration will set the order of
853 iteration within `PipelineTask.runQuantum`.
854 adjusted_outputs : `~collections.abc.Mapping`
855 Mapping of updated output datasets, with the same structure and
856 interpretation as ``adjusted_inputs``.
858 Raises
859 ------
860 ScalarError
861 Raised if any `Input` or `PrerequisiteInput` connection has
862 ``multiple`` set to `False`, but multiple datasets are present.
863 NoWorkFound
864 Raised to indicate that this quantum should not be run; not enough
865 datasets were found for a regular `Input` connection, and the
866 quantum should be pruned or skipped.
867 FileNotFoundError
868 Raised to cause QuantumGraph generation to fail (with the message
869 included in this exception); not enough datasets were found for a
870 `PrerequisiteInput` connection.
872 Notes
873 -----
874 The base class implementation performs important checks. It always
875 returns an empty mapping (i.e. makes no adjustments). It should
876 always be called via `super` by custom implementations, ideally at the
877 end of the custom implementation with already-adjusted mappings when
878 any datasets are actually dropped, e.g.:
880 .. code-block:: python
882 def adjustQuantum(self, inputs, outputs, label, data_id):
883 # Filter out some dataset refs for one connection.
884 connection, old_refs = inputs["my_input"]
885 new_refs = [ref for ref in old_refs if ...]
886 adjusted_inputs = {"my_input", (connection, new_refs)}
887 # Update the original inputs so we can pass them to super.
888 inputs.update(adjusted_inputs)
889 # Can ignore outputs from super because they are guaranteed
890 # to be empty.
891 super().adjustQuantum(inputs, outputs, label, data_id)
892 # Return only the connections we modified.
893 return adjusted_inputs, {}
895 Removing outputs here is guaranteed to affect what is actually
896 passed to `PipelineTask.runQuantum`, but its effect on the larger
897 graph may be deferred to execution, depending on the context in
898 which `adjustQuantum` is being run: if one quantum removes an output
899 that is needed by a second quantum as input, the second quantum may not
900 be adjusted (and hence pruned or skipped) until that output is actually
901 found to be missing at execution time.
903 Tasks that desire zip-iteration consistency between any combinations of
904 connections that have the same data ID should generally implement
905 `adjustQuantum` to achieve this, even if they could also run that
906 logic during execution; this allows the system to see outputs that will
907 not be produced because the corresponding input is missing as early as
908 possible.
909 """
910 for name, (input_connection, refs) in inputs.items():
911 dataset_type_name = input_connection.name
912 if not input_connection.multiple and len(refs) > 1:
913 raise ScalarError(
914 f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
915 f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
916 f"for quantum data ID {data_id}."
917 )
918 if len(refs) < input_connection.minimum:
919 if isinstance(input_connection, PrerequisiteInput):
920 # This branch should only be possible during QG generation,
921 # or if someone deleted the dataset between making the QG
922 # and trying to run it. Either one should be a hard error.
923 raise FileNotFoundError(
924 f"Not enough datasets ({len(refs)}) found for non-optional connection {label}.{name} "
925 f"({dataset_type_name}) with minimum={input_connection.minimum} for quantum data ID "
926 f"{data_id}."
927 )
928 else:
929 raise NoWorkFound(label, name, input_connection)
930 for name, (output_connection, refs) in outputs.items():
931 dataset_type_name = output_connection.name
932 if not output_connection.multiple and len(refs) > 1:
933 raise ScalarError(
934 f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
935 f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
936 f"for quantum data ID {data_id}."
937 )
938 return {}, {}
940 def getSpatialBoundsConnections(self) -> Iterable[str]:
941 """Return the names of regular input and output connections whose data
942 IDs should be used to compute the spatial bounds of this task's quanta.
944 The spatial bound for a quantum is defined as the union of the regions
945 of all data IDs of all connections returned here, along with the region
946 of the quantum data ID (if the task has spatial dimensions).
948 Returns
949 -------
950 connection_names : `collections.abc.Iterable` [ `str` ]
951 Names of connections with spatial dimensions. These are the
952 task-internal connection names, not butler dataset type names.
954 Notes
955 -----
956 The spatial bound is used to search for prerequisite inputs that have
957 skypix dimensions. The default implementation returns an empty
958 iterable, which is usually sufficient for tasks with spatial
959 dimensions, but if a task's inputs or outputs are associated with
960 spatial regions that extend beyond the quantum data ID's region, this
961 method may need to be overridden to expand the set of prerequisite
962 inputs found.
964 Tasks that do not have spatial dimensions but do have skypix
965 prerequisite inputs should always override this method, as the default
966 spatial bounds otherwise cover the full sky.
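 For example (a sketch; the connection name is hypothetical):
 .. code-block:: python
     def getSpatialBoundsConnections(self):
         # Include the regions of this input's data IDs in the
         # spatial bounds used for prerequisite lookups.
         return ("coaddExposures",)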
967 """
968 return ()
970 def getTemporalBoundsConnections(self) -> Iterable[str]:
971 """Return the names of regular input and output connections whose data
972 IDs should be used to compute the temporal bounds of this task's
973 quanta.
975 The temporal bound for a quantum is defined as the union of the
976 timespans of all data IDs of all connections returned here, along with
977 the timespan of the quantum data ID (if the task has temporal
978 dimensions).
980 Returns
981 -------
982 connection_names : `collections.abc.Iterable` [ `str` ]
983 Names of connections with temporal dimensions. These are the
984 task-internal connection names, not butler dataset type names.
986 Notes
987 -----
988 The temporal bound is used to search for prerequisite inputs that are
989 calibration datasets. The default implementation returns an empty
990 iterable, which is usually sufficient for tasks with temporal
991 dimensions, but if a task's inputs or outputs are associated with
992 timespans that extend beyond the quantum data ID's timespan, this
993 method may need to be overridden to expand the set of prerequisite
994 inputs found.
996 Tasks that do not have temporal dimensions and do not implement this
997 method will use an infinite timespan for any calibration lookups.
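 For example (a sketch; the connection name is hypothetical):
 .. code-block:: python
     def getTemporalBoundsConnections(self):
         # Use the timespans of this input's data IDs when looking
         # up calibration datasets.
         return ("visitSummaries",)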
998 """
999 return ()
1002def iterConnections(
1003 connections: PipelineTaskConnections, connectionType: str | Iterable[str]
1004) -> Generator[BaseConnection, None, None]:
1005 """Create an iterator over the selected connections type which yields
1006 all the defined connections of that type.
1008 Parameters
1009 ----------
1010 connections : `PipelineTaskConnections`
1011 An instance of a `PipelineTaskConnections` object that will be iterated
1012 over.
1013 connectionType : `str` or `~collections.abc.Iterable` [ `str` ]
1014 The type of connections to iterate over; valid values are inputs,
1015 outputs, prerequisiteInputs, initInputs, and initOutputs.
1017 Yields
1018 ------
1019 connection : `~.connectionTypes.BaseConnection`
1020 A connection defined on the input connections object of the type
1021 supplied. The yielded value will be a derived type of
1022 `~.connectionTypes.BaseConnection`.
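 Examples
 --------
 A sketch of typical use (``connections`` is any instantiated
 `PipelineTaskConnections`):
 .. code-block:: python
     for connection in iterConnections(connections, "inputs"):
         print(connection.name)

     # Several connection types may be combined in one pass.
     for connection in iterConnections(
         connections, ("inputs", "prerequisiteInputs")
     ):
         print(connection.name)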
1023 """
1024 if isinstance(connectionType, str):
1025 connectionType = (connectionType,)
1026 for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
1027 yield getattr(connections, name)
1030@dataclass
1031class AdjustQuantumHelper:
1032 """Helper class for calling `PipelineTaskConnections.adjustQuantum`.
1034 This class holds `inputs` and `outputs` mappings in the form used by
1035 `Quantum` and execution harness code, i.e. with
1036 `~lsst.daf.butler.DatasetType` keys, translating them to and from the
1037 connection-oriented mappings used inside `PipelineTaskConnections`.
1038 """
1040 inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
1041 """Mapping of regular input and prerequisite input datasets, grouped by
1042 `~lsst.daf.butler.DatasetType`.
1043 """
1045 outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
1046 """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
1047 """
1049 inputs_adjusted: bool = False
1050 """Whether any inputs were removed in the last call to `adjust_in_place`.
1051 """
1053 outputs_adjusted: bool = False
1054 """Whether any outputs were removed in the last call to `adjust_in_place`.
1055 """
1057 def adjust_in_place(
1058 self,
1059 connections: PipelineTaskConnections,
1060 label: str,
1061 data_id: DataCoordinate,
1062 ) -> None:
1063 """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
1064 with its results.
1066 Parameters
1067 ----------
1068 connections : `PipelineTaskConnections`
1069 Instance on which to call `~PipelineTaskConnections.adjustQuantum`.
1070 label : `str`
1071 Label for this task in the pipeline (should be used in all
1072 diagnostic messages).
1073 data_id : `lsst.daf.butler.DataCoordinate`
1074 Data ID for this quantum in the pipeline (should be used in all
1075 diagnostic messages).
1076 """
1077 # Translate self's DatasetType-keyed, Quantum-oriented mappings into
1078 # connection-keyed, PipelineTask-oriented mappings.
1079 inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
1080 outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
1081 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
1082 connection = getattr(connections, name)
1083 dataset_type_name = connection.name
1084 inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
1085 for name in itertools.chain(connections.outputs):
1086 connection = getattr(connections, name)
1087 dataset_type_name = connection.name
1088 outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
1089 # Actually call adjustQuantum.
1090 # MyPy correctly complains that this call is not quite legal, but the
1091 # method docs explain exactly what's expected and it's the behavior we
1092 # want. It'd be nice to avoid this if we ever have to change the
1093 # interface anyway, but not an immediate problem.
1094 adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
1095 inputs_by_connection, # type: ignore
1096 outputs_by_connection, # type: ignore
1097 label,
1098 data_id,
1099 )
1100 # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
1101 # installing new mappings in self if necessary.
1102 if adjusted_inputs_by_connection:
1103 adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
1104 for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
1105 dataset_type_name = connection.name
1106 if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
1107 raise RuntimeError(
1108 f"adjustQuantum implementation for task with label {label} returned {name} "
1109 f"({dataset_type_name}) input datasets that are not a subset of those "
1110 f"it was given for data ID {data_id}."
1111 )
1112 adjusted_inputs[dataset_type_name] = tuple(updated_refs)
1113 self.inputs = adjusted_inputs.freeze()
1114 self.inputs_adjusted = True
1115 else:
1116 self.inputs_adjusted = False
1117 if adjusted_outputs_by_connection:
1118 adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
1119 for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
1120 dataset_type_name = connection.name
1121 if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
1122 raise RuntimeError(
1123 f"adjustQuantum implementation for task with label {label} returned {name} "
1124 f"({dataset_type_name}) output datasets that are not a subset of those "
1125 f"it was given for data ID {data_id}."
1126 )
1127 adjusted_outputs[dataset_type_name] = tuple(updated_refs)
1128 self.outputs = adjusted_outputs.freeze()
1129 self.outputs_adjusted = True
1130 else:
1131 self.outputs_adjusted = False