Coverage for python/lsst/pipe/base/connections.py: 43%
305 statements
coverage.py v7.3.0, created at 2023-08-23 10:31 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining connection classes for PipelineTask.
23"""
25from __future__ import annotations
27__all__ = [
28 "AdjustQuantumHelper",
29 "DeferredDatasetRef",
30 "InputQuantizedConnection",
31 "OutputQuantizedConnection",
32 "PipelineTaskConnections",
33 "ScalarError",
34 "iterConnections",
35 "ScalarError",
36]
38import dataclasses
39import itertools
40import string
41import warnings
42from collections import UserDict
43from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
44from dataclasses import dataclass
45from types import MappingProxyType, SimpleNamespace
46from typing import TYPE_CHECKING, Any
48from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum
50from ._status import NoWorkFound
51from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput
53if TYPE_CHECKING:
54 from .config import PipelineTaskConfig
57class ScalarError(TypeError):
58 """Exception raised when dataset type is configured as scalar
59 but there are multiple data IDs in a Quantum for that dataset.
60 """
63class PipelineTaskConnectionDict(UserDict):
64 """A special dict class used by `PipelineTaskConnectionMetaclass`.
66 This dict is used in `PipelineTaskConnection` class creation, as the
67 dictionary that is initially used as ``__dict__``. It exists to
68 intercept connection fields declared in a `PipelineTaskConnection`, and
69 what name is used to identify them. The names are then added to class
70 level list according to the connection type of the class attribute. The
71 names are also used as keys in a class level dictionary associated with
72 the corresponding class attribute. This information is a duplicate of
73 what exists in ``__dict__``, but provides a simple place to lookup and
74 iterate on only these variables.
75 """
77 def __init__(self, *args: Any, **kwargs: Any):
78 super().__init__(*args, **kwargs)
79 # Initialize class level variables used to track any declared
80 # class level variables that are instances of
81 # connectionTypes.BaseConnection
82 self.data["inputs"] = set()
83 self.data["prerequisiteInputs"] = set()
84 self.data["outputs"] = set()
85 self.data["initInputs"] = set()
86 self.data["initOutputs"] = set()
87 self.data["allConnections"] = {}
89 def __setitem__(self, name: str, value: Any) -> None:
90 if isinstance(value, BaseConnection):
 91 if name in {
92 "dimensions",
93 "inputs",
94 "prerequisiteInputs",
95 "outputs",
96 "initInputs",
97 "initOutputs",
98 "allConnections",
99 }:
100 # Guard against connections whose names are reserved.
101 raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
 102 if (previous := self.data.get(name)) is not None:
 103 # Guard against changing the type of an inherited connection
 104 # by first removing it from the set it's currently in.
105 self.data[previous._connection_type_set].discard(name)
106 object.__setattr__(value, "varName", name)
107 self.data["allConnections"][name] = value
108 self.data[value._connection_type_set].add(name)
109 # defer to the default behavior
110 super().__setitem__(name, value)
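# Editor's note: an illustrative sketch (not part of the module, shown as a
# comment) of the bookkeeping __setitem__ performs; the connection arguments
# below are hypothetical:
#
#     from lsst.pipe.base import connectionTypes as cT
#
#     d = PipelineTaskConnectionDict()
#     d["calexp"] = cT.Input(doc="example", name="calexp",
#                            storageClass="ExposureF", dimensions=("visit",))
#     assert d["inputs"] == {"calexp"}
#     assert d["allConnections"]["calexp"].varName == "calexp"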
113class PipelineTaskConnectionsMetaclass(type):
114 """Metaclass used in the declaration of PipelineTaskConnections classes"""
116 # We can annotate these attributes as `collections.abc.Set` to discourage
117 # undesirable modifications in type-checked code, since the internal code
118 # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
119 # these annotations anyway.
121 dimensions: Set[str]
122 """Set of dimension names that define the unit of work for this task.
124 Required and implied dependencies will automatically be expanded later and
125 need not be provided.
127 This is shadowed by an instance-level attribute on
128 `PipelineTaskConnections` instances.
129 """
131 inputs: Set[str]
132 """Set with the names of all `~connectionTypes.Input` connection
133 attributes.
135 This is updated automatically as class attributes are added. Note that
136 this attribute is shadowed by an instance-level attribute on
137 `PipelineTaskConnections` instances.
138 """
140 prerequisiteInputs: Set[str]
141 """Set with the names of all `~connectionTypes.PrerequisiteInput`
142 connection attributes.
144 See `inputs` for additional information.
145 """
147 outputs: Set[str]
148 """Set with the names of all `~connectionTypes.Output` connection
149 attributes.
151 See `inputs` for additional information.
152 """
154 initInputs: Set[str]
155 """Set with the names of all `~connectionTypes.InitInput` connection
156 attributes.
158 See `inputs` for additional information.
159 """
161 initOutputs: Set[str]
162 """Set with the names of all `~connectionTypes.InitOutput` connection
163 attributes.
165 See `inputs` for additional information.
166 """
168 allConnections: Mapping[str, BaseConnection]
169 """Mapping containing all connection attributes.
171 See `inputs` for additional information.
172 """
174 def __prepare__(name, bases, **kwargs): # noqa: N804
175 # Create an instance of our special dict to catch and track all
 176 # variables that are instances of connectionTypes.BaseConnection.
 177 # Copy any existing connections from a parent class.
178 dct = PipelineTaskConnectionDict()
179 for base in bases:
 180 if isinstance(base, PipelineTaskConnectionsMetaclass):
 181 for name, value in base.allConnections.items():
182 dct[name] = value
183 return dct
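    # Editor's note: a hedged sketch of the effect of __prepare__ above —
    # connections declared on a base class are copied into the namespace of
    # any subclass, so they are inherited unless overridden (names below are
    # hypothetical):
    #
    #     class Base(PipelineTaskConnections, dimensions=("visit",)):
    #         calexp = cT.Input(doc="", name="calexp",
    #                           storageClass="ExposureF", dimensions=("visit",))
    #
    #     class Derived(Base, dimensions=("visit",)):
    #         pass  # "calexp" still appears in Derived.allConnections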
185 def __new__(cls, name, bases, dct, **kwargs):
186 dimensionsValueError = TypeError(
187 "PipelineTaskConnections class must be created with a dimensions "
188 "attribute which is an iterable of dimension names"
189 )
191 if name != "PipelineTaskConnections":
192 # Verify that dimensions are passed as a keyword in class
193 # declaration
194 if "dimensions" not in kwargs: 194 ↛ 195line 194 didn't jump to line 195, because the condition on line 194 was never true
195 for base in bases:
196 if hasattr(base, "dimensions"):
197 kwargs["dimensions"] = base.dimensions
198 break
199 if "dimensions" not in kwargs:
200 raise dimensionsValueError
201 try:
 202 if isinstance(kwargs["dimensions"], str):
203 raise TypeError(
204 "Dimensions must be iterable of dimensions, got str,possibly omitted trailing comma"
205 )
 206 if not isinstance(kwargs["dimensions"], Iterable):
 207 raise TypeError("Dimensions must be an iterable of dimensions")
208 dct["dimensions"] = set(kwargs["dimensions"])
209 except TypeError as exc:
210 raise dimensionsValueError from exc
211 # Lookup any python string templates that may have been used in the
212 # declaration of the name field of a class connection attribute
213 allTemplates = set()
214 stringFormatter = string.Formatter()
215 # Loop over all connections
216 for obj in dct["allConnections"].values():
217 nameValue = obj.name
218 # add all the parameters to the set of templates
219 for param in stringFormatter.parse(nameValue):
220 if param[1] is not None:
221 allTemplates.add(param[1])
223 # look up any template from base classes and merge them all
224 # together
225 mergeDict = {}
226 mergeDeprecationsDict = {}
227 for base in bases[::-1]:
228 if hasattr(base, "defaultTemplates"):
229 mergeDict.update(base.defaultTemplates)
230 if hasattr(base, "deprecatedTemplates"):
231 mergeDeprecationsDict.update(base.deprecatedTemplates)
232 if "defaultTemplates" in kwargs:
233 mergeDict.update(kwargs["defaultTemplates"])
234 if "deprecatedTemplates" in kwargs: 234 ↛ 235line 234 didn't jump to line 235, because the condition on line 234 was never true
235 mergeDeprecationsDict.update(kwargs["deprecatedTemplates"])
236 if len(mergeDict) > 0:
237 kwargs["defaultTemplates"] = mergeDict
 238 if len(mergeDeprecationsDict) > 0:
239 kwargs["deprecatedTemplates"] = mergeDeprecationsDict
241 # Verify that if templated strings were used, defaults were
242 # supplied as an argument in the declaration of the connection
243 # class
 244 if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
245 raise TypeError(
246 "PipelineTaskConnection class contains templated attribute names, but no "
247 "defaut templates were provided, add a dictionary attribute named "
248 "defaultTemplates which contains the mapping between template key and value"
249 )
250 if len(allTemplates) > 0:
251 # Verify all templates have a default, and throw if they do not
252 defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
253 templateDifference = allTemplates.difference(defaultTemplateKeys)
 254 if templateDifference:
255 raise TypeError(f"Default template keys were not provided for {templateDifference}")
256 # Verify that templates do not share names with variable names
257 # used for a connection, this is needed because of how
258 # templates are specified in an associated config class.
259 nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
 260 if len(nameTemplateIntersection) > 0:
261 raise TypeError(
262 "Template parameters cannot share names with Class attributes"
263 f" (conflicts are {nameTemplateIntersection})."
264 )
265 dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})
266 dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {})
268 # Convert all the connection containers into frozensets so they cannot
269 # be modified at the class scope
270 for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
271 dct[connectionName] = frozenset(dct[connectionName])
272 # our custom dict type must be turned into an actual dict to be used in
273 # type.__new__
274 return super().__new__(cls, name, bases, dict(dct))
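    # Editor's note: the template discovery in __new__ relies on
    # string.Formatter().parse, which yields (literal, field, spec, conversion)
    # tuples; a minimal self-contained illustration:
    #
    #     import string
    #     fields = [f[1] for f in string.Formatter().parse("{foo}Dataset")
    #               if f[1] is not None]
    #     assert fields == ["foo"]  # so "foo" must have a default template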
276 def __init__(cls, name, bases, dct, **kwargs):
277 # This overrides the default init to drop the kwargs argument. Python
 278 metaclasses will have this argument set if any kwargs are passed at
279 # class construction time, but should be consumed before calling
280 # __init__ on the type metaclass. This is in accordance with python
281 # documentation on metaclasses
282 super().__init__(name, bases, dct)
284 def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
285 # MyPy appears not to really understand metaclass.__call__ at all, so
286 # we need to tell it to ignore __new__ and __init__ calls here.
287 instance: PipelineTaskConnections = cls.__new__(cls) # type: ignore
289 # Make mutable copies of all set-like class attributes so derived
290 # __init__ implementations can modify them in place.
291 instance.dimensions = set(cls.dimensions)
292 instance.inputs = set(cls.inputs)
293 instance.prerequisiteInputs = set(cls.prerequisiteInputs)
294 instance.outputs = set(cls.outputs)
295 instance.initInputs = set(cls.initInputs)
296 instance.initOutputs = set(cls.initOutputs)
298 # Set self.config. It's a bit strange that we claim to accept None but
299 # really just raise here, but it's not worth changing now.
300 from .config import PipelineTaskConfig # local import to avoid cycle
302 if config is None or not isinstance(config, PipelineTaskConfig):
303 raise ValueError(
304 "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
305 )
306 instance.config = config
308 # Extract the template names that were defined in the config instance
309 # by looping over the keys of the defaultTemplates dict specified at
310 # class declaration time.
311 templateValues = {
312 name: getattr(config.connections, name) for name in cls.defaultTemplates # type: ignore
313 }
315 # We now assemble a mapping of all connection instances keyed by
316 # internal name, applying the configuration and templates to make new
317 # configurations from the class-attribute defaults. This will be
318 # private, but with a public read-only view. This mapping is what the
319 # descriptor interface of the class-level attributes will return when
320 # they are accessed on an instance. This is better than just assigning
321 # regular instance attributes as it makes it so removed connections
322 # cannot be accessed on instances, instead of having access to them
 323 silently fall through to the not-removed class connection instance.
324 instance._allConnections = {}
325 instance.allConnections = MappingProxyType(instance._allConnections)
326 for internal_name, connection in cls.allConnections.items():
327 dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
328 instance_connection = dataclasses.replace(
329 connection,
330 name=dataset_type_name,
331 doc=(
332 connection.doc
333 if connection.deprecated is None
334 else f"{connection.doc}\n{connection.deprecated}"
335 ),
336 _deprecation_context=connection._deprecation_context,
337 )
338 instance._allConnections[internal_name] = instance_connection
340 # Finally call __init__. The base class implementation does nothing;
341 # we could have left some of the above implementation there (where it
342 # originated), but putting it here instead makes it hard for derived
343 # class implementors to get things into a weird state by delegating to
344 # super().__init__ in the wrong place, or by forgetting to do that
345 # entirely.
346 instance.__init__(config=config) # type: ignore
348 # Derived-class implementations may have changed the contents of the
349 # various kinds-of-connection sets; update allConnections to have keys
350 # that are a union of all those. We get values for the new
351 # allConnections from the attributes, since any dynamically added new
352 # ones will not be present in the old allConnections. Typically those
353 # getattrs will invoke the descriptors and get things from the old
354 # allConnections anyway. After processing each set we replace it with
355 # a frozenset.
356 updated_all_connections = {}
357 for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
358 updated_connection_names = getattr(instance, attrName)
359 updated_all_connections.update(
360 {name: getattr(instance, name) for name in updated_connection_names}
361 )
362 # Setting these to frozenset is at odds with the type annotation,
363 # but MyPy can't tell because we're using setattr, and we want to
364 # lie to it anyway to get runtime guards against post-__init__
365 # mutation.
366 setattr(instance, attrName, frozenset(updated_connection_names))
367 # Update the existing dict in place, since we already have a view of
368 # that.
369 instance._allConnections.clear()
370 instance._allConnections.update(updated_all_connections)
372 for connection_name, obj in instance._allConnections.items():
373 if obj.deprecated is not None:
374 warnings.warn(
375 f"Connection {connection_name} with datasetType {obj.name} "
376 f"(from {obj._deprecation_context}): {obj.deprecated}",
377 FutureWarning,
378 stacklevel=1, # Report from this location.
379 )
 381 # Freeze the connection instance dimensions now. This is at odds with the
382 # type annotation, which says [mutable] `set`, just like the connection
383 # type attributes (e.g. `inputs`, `outputs`, etc.), though MyPy can't
384 # tell with those since we're using setattr for them.
385 instance.dimensions = frozenset(instance.dimensions) # type: ignore
387 return instance
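# Editor's note: an end-to-end sketch of __call__ above, assuming the
# hypothetical ExampleConnections/ExampleConfig pair from the Examples section
# of PipelineTaskConnections below:
#
#     config = ExampleConfig()
#     config.connections.foo = "Deep"  # override the "foo" template
#     connections = ExampleConnections(config=config)
#     assert connections.inputConnection.name == "DeepDataset"
#     assert isinstance(connections.dimensions, frozenset)  # frozen post-init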
390class QuantizedConnection(SimpleNamespace):
391 r"""A Namespace to map defined variable names of connections to the
392 associated `lsst.daf.butler.DatasetRef` objects.
394 This class maps the names used to define a connection on a
395 `PipelineTaskConnections` class to the corresponding
396 `~lsst.daf.butler.DatasetRef`\s provided by a `~lsst.daf.butler.Quantum`
397 instance. This will be a quantum of execution based on the graph created
398 by examining all the connections defined on the
399 `PipelineTaskConnections` class.
400 """
402 def __init__(self, **kwargs):
403 # Create a variable to track what attributes are added. This is used
404 # later when iterating over this QuantizedConnection instance
405 object.__setattr__(self, "_attributes", set())
407 def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
408 # Capture the attribute name as it is added to this object
409 self._attributes.add(name)
410 super().__setattr__(name, value)
412 def __delattr__(self, name):
413 object.__delattr__(self, name)
414 self._attributes.remove(name)
416 def __len__(self) -> int:
417 return len(self._attributes)
419 def __iter__(
420 self,
421 ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]:
422 """Make an iterator for this `QuantizedConnection`.
424 Iterating over a `QuantizedConnection` will yield a tuple with the name
425 of an attribute and the value associated with that name. This is
426 similar to dict.items() but is on the namespace attributes rather than
427 dict keys.
428 """
429 yield from ((name, getattr(self, name)) for name in self._attributes)
431 def keys(self) -> Generator[str, None, None]:
432 """Return an iterator over all the attributes added to a
 433 `QuantizedConnection` instance.
434 """
435 yield from self._attributes
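# Editor's note: a small usage sketch of the namespace protocol above, where
# ``ref`` stands in for a real lsst.daf.butler.DatasetRef:
#
#     qc = QuantizedConnection()
#     qc.calexp = ref
#     assert list(qc.keys()) == ["calexp"]
#     assert dict(iter(qc)) == {"calexp": ref}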
438class InputQuantizedConnection(QuantizedConnection):
439 """Input variant of a `QuantizedConnection`."""
441 pass
444class OutputQuantizedConnection(QuantizedConnection):
445 """Output variant of a `QuantizedConnection`."""
447 pass
450@dataclass(frozen=True)
451class DeferredDatasetRef:
452 """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that a
453 `PipelineTask` should receive a `~lsst.daf.butler.DeferredDatasetHandle`
454 instead of an in-memory dataset.
456 Parameters
457 ----------
458 datasetRef : `lsst.daf.butler.DatasetRef`
 459 The `lsst.daf.butler.DatasetRef` that will eventually be used to
 460 resolve a dataset.
461 """
463 datasetRef: DatasetRef
465 @property
466 def datasetType(self) -> DatasetType:
467 """The dataset type for this dataset."""
468 return self.datasetRef.datasetType
470 @property
471 def dataId(self) -> DataCoordinate:
472 """The data ID for this dataset."""
473 return self.datasetRef.dataId
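# Editor's note: a sketch of how deferLoad inputs are wrapped (``ref`` again
# stands in for a real DatasetRef); the properties simply forward:
#
#     deferred = DeferredDatasetRef(datasetRef=ref)
#     assert deferred.datasetType == ref.datasetType
#     assert deferred.dataId == ref.dataId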
476class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
477 """PipelineTaskConnections is a class used to declare desired IO when a
478 PipelineTask is run by an activator
480 Parameters
481 ----------
482 config : `PipelineTaskConfig`
483 A `PipelineTaskConfig` class instance whose class has been configured
484 to use this `PipelineTaskConnections` class.
486 See Also
487 --------
488 iterConnections
490 Notes
491 -----
 492 ``PipelineTaskConnections`` classes are created by declaring class
493 attributes of types defined in `lsst.pipe.base.connectionTypes` and are
494 listed as follows:
496 * ``InitInput`` - Defines connections in a quantum graph which are used as
497 inputs to the ``__init__`` function of the `PipelineTask` corresponding
498 to this class
 499 * ``InitOutput`` - Defines connections in a quantum graph which are to be
500 persisted using a butler at the end of the ``__init__`` function of the
501 `PipelineTask` corresponding to this class. The variable name used to
502 define this connection should be the same as an attribute name on the
503 `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
504 the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
505 a `PipelineTask` instance should have an attribute
506 ``self.outputSchema`` defined. Its value is what will be saved by the
507 activator framework.
508 * ``PrerequisiteInput`` - An input connection type that defines a
509 `lsst.daf.butler.DatasetType` that must be present at execution time,
510 but that will not be used during the course of creating the quantum
511 graph to be executed. These most often are things produced outside the
512 processing pipeline, such as reference catalogs.
513 * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be used
514 in the ``run`` method of a `PipelineTask`. The name used to declare
515 class attribute must match a function argument name in the ``run``
516 method of a `PipelineTask`. E.g. If the ``PipelineTaskConnections``
517 defines an ``Input`` with the name ``calexp``, then the corresponding
518 signature should be ``PipelineTask.run(calexp, ...)``
519 * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
520 execution of a `PipelineTask`. The name used to declare the connection
521 must correspond to an attribute of a `Struct` that is returned by a
522 `PipelineTask` ``run`` method. E.g. if an output connection is
523 defined with the name ``measCat``, then the corresponding
 524 ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)`` where
525 X matches the ``storageClass`` type defined on the output connection.
527 Attributes of these types can also be created, replaced, or deleted on the
528 `PipelineTaskConnections` instance in the ``__init__`` method, if more than
529 just the name depends on the configuration. It is preferred to define them
530 in the class when possible (even if configuration may cause the connection
531 to be removed from the instance).
 533 The process of declaring a ``PipelineTaskConnections`` class involves
534 parameters passed in the declaration statement.
 536 The first parameter is ``dimensions``, an iterable of strings that
537 defines the unit of processing the run method of a corresponding
538 `PipelineTask` will operate on. These dimensions must match dimensions that
539 exist in the butler registry which will be used in executing the
540 corresponding `PipelineTask`. The dimensions may be also modified in
541 subclass ``__init__`` methods if they need to depend on configuration.
543 The second parameter is labeled ``defaultTemplates`` and is conditionally
544 optional. The name attributes of connections can be specified as python
545 format strings, with named format arguments. If any of the name parameters
546 on connections defined in a `PipelineTaskConnections` class contain a
547 template, then a default template value must be specified in the
548 ``defaultTemplates`` argument. This is done by passing a dictionary with
549 keys corresponding to a template identifier, and values corresponding to
550 the value to use as a default when formatting the string. For example if
551 ``ConnectionsClass.calexp.name = '{input}Coadd_calexp'`` then
 552 ``defaultTemplates = {'input': 'deep'}``.
554 Once a `PipelineTaskConnections` class is created, it is used in the
555 creation of a `PipelineTaskConfig`. This is further documented in the
556 documentation of `PipelineTaskConfig`. For the purposes of this
557 documentation, the relevant information is that the config class allows
558 configuration of connection names by users when running a pipeline.
560 Instances of a `PipelineTaskConnections` class are used by the pipeline
561 task execution framework to introspect what a corresponding `PipelineTask`
562 will require, and what it will produce.
564 Examples
565 --------
566 >>> from lsst.pipe.base import connectionTypes as cT
567 >>> from lsst.pipe.base import PipelineTaskConnections
568 >>> from lsst.pipe.base import PipelineTaskConfig
569 >>> class ExampleConnections(PipelineTaskConnections,
570 ... dimensions=("A", "B"),
571 ... defaultTemplates={"foo": "Example"}):
572 ... inputConnection = cT.Input(doc="Example input",
573 ... dimensions=("A", "B"),
... storageClass="Exposure",
575 ... name="{foo}Dataset")
576 ... outputConnection = cT.Output(doc="Example output",
577 ... dimensions=("A", "B"),
... storageClass="Exposure",
579 ... name="{foo}output")
580 >>> class ExampleConfig(PipelineTaskConfig,
581 ... pipelineConnections=ExampleConnections):
582 ... pass
583 >>> config = ExampleConfig()
584 >>> config.connections.foo = Modified
585 >>> config.connections.outputConnection = "TotallyDifferent"
586 >>> connections = ExampleConnections(config=config)
587 >>> assert(connections.inputConnection.name == "ModifiedDataset")
588 >>> assert(connections.outputConnection.name == "TotallyDifferent")
589 """
591 # We annotate these attributes as mutable sets because that's what they are
 592 # inside derived ``__init__`` implementations, and that's what matters most.
593 # After that's done, the metaclass __call__ makes them into frozensets, but
594 # relatively little code interacts with them then, and that code knows not
595 # to try to modify them without having to be told that by mypy.
597 dimensions: set[str]
598 """Set of dimension names that define the unit of work for this task.
600 Required and implied dependencies will automatically be expanded later and
601 need not be provided.
603 This may be replaced or modified in ``__init__`` to change the dimensions
604 of the task. After ``__init__`` it will be a `frozenset` and may not be
605 replaced.
606 """
608 inputs: set[str]
609 """Set with the names of all `connectionTypes.Input` connection attributes.
611 This is updated automatically as class attributes are added, removed, or
612 replaced in ``__init__``. Removing entries from this set will cause those
613 connections to be removed after ``__init__`` completes, but this is
614 supported only for backwards compatibility; new code should instead just
 615 delete the connection attribute directly. After ``__init__`` this will be
616 a `frozenset` and may not be replaced.
617 """
619 prerequisiteInputs: set[str]
620 """Set with the names of all `~connectionTypes.PrerequisiteInput`
621 connection attributes.
623 See `inputs` for additional information.
624 """
626 outputs: set[str]
627 """Set with the names of all `~connectionTypes.Output` connection
628 attributes.
630 See `inputs` for additional information.
631 """
633 initInputs: set[str]
634 """Set with the names of all `~connectionTypes.InitInput` connection
635 attributes.
637 See `inputs` for additional information.
638 """
640 initOutputs: set[str]
641 """Set with the names of all `~connectionTypes.InitOutput` connection
642 attributes.
644 See `inputs` for additional information.
645 """
647 allConnections: Mapping[str, BaseConnection]
648 """Mapping holding all connection attributes.
650 This is a read-only view that is automatically updated when connection
651 attributes are added, removed, or replaced in ``__init__``. It is also
652 updated after ``__init__`` completes to reflect changes in `inputs`,
653 `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
654 """
656 _allConnections: dict[str, BaseConnection]
658 def __init__(self, *, config: PipelineTaskConfig | None = None):
659 pass
661 def __setattr__(self, name: str, value: Any) -> None:
662 if isinstance(value, BaseConnection):
663 previous = self._allConnections.get(name)
664 try:
665 getattr(self, value._connection_type_set).add(name)
666 except AttributeError:
667 # Attempt to call add on a frozenset, which is what these sets
668 # are after __init__ is done.
669 raise TypeError("Connections objects are frozen after construction.") from None
670 if previous is not None and value._connection_type_set != previous._connection_type_set:
671 # Connection has changed type, e.g. Input to PrerequisiteInput;
672 # update the sets accordingly. To be extra defensive about
673 # multiple assignments we use the type of the previous instance
 674 instead of assuming that's the same as the type of self,
675 # which is just the default. Use discard instead of remove so
676 # manually removing from these sets first is never an error.
677 getattr(self, previous._connection_type_set).discard(name)
678 self._allConnections[name] = value
679 if hasattr(self.__class__, name):
680 # Don't actually set the attribute if this was a connection
681 # declared in the class; in that case we let the descriptor
682 # return the value we just added to allConnections.
683 return
684 # Actually add the attribute.
685 super().__setattr__(name, value)
687 def __delattr__(self, name):
688 """Descriptor delete method."""
689 previous = self._allConnections.get(name)
690 if previous is not None:
691 # Delete this connection's name from the appropriate set, which we
692 # have to get from the previous instance instead of assuming it's
693 # the same set that was appropriate for the class-level default.
694 # Use discard instead of remove so manually removing from these
695 # sets first is never an error.
696 try:
697 getattr(self, previous._connection_type_set).discard(name)
698 except AttributeError:
699 # Attempt to call discard on a frozenset, which is what these
700 # sets are after __init__ is done.
701 raise TypeError("Connections objects are frozen after construction.") from None
702 del self._allConnections[name]
703 if hasattr(self.__class__, name):
704 # Don't actually delete the attribute if this was a connection
705 # declared in the class; in that case we let the descriptor
706 # see that it's no longer present in allConnections.
707 return
708 # Actually delete the attribute.
709 super().__delattr__(name)
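    # Editor's note: a hedged sketch of config-dependent connection edits in a
    # subclass __init__, which exercise __setattr__/__delattr__ above;
    # ``doWriteMask`` is a hypothetical config field:
    #
    #     class MaskConnections(PipelineTaskConnections, dimensions=("visit",)):
    #         mask = cT.Output(doc="", name="mask",
    #                          storageClass="Mask", dimensions=("visit",))
    #
    #         def __init__(self, *, config=None):
    #             super().__init__(config=config)
    #             if not config.doWriteMask:
    #                 del self.mask  # removed from outputs and allConnections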
711 def buildDatasetRefs(
712 self, quantum: Quantum
713 ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
714 """Build `QuantizedConnection` corresponding to input
715 `~lsst.daf.butler.Quantum`.
717 Parameters
718 ----------
719 quantum : `lsst.daf.butler.Quantum`
720 Quantum object which defines the inputs and outputs for a given
721 unit of processing.
723 Returns
724 -------
 725 retVal : `tuple` of (`InputQuantizedConnection`,
 726 `OutputQuantizedConnection`)
 727 Namespaces mapping attribute names (identifiers of connections) to
 728 butler references defined in the input `lsst.daf.butler.Quantum`.
729 """
730 inputDatasetRefs = InputQuantizedConnection()
731 outputDatasetRefs = OutputQuantizedConnection()
732 # operate on a reference object and an iterable of names of class
733 # connection attributes
734 for refs, names in zip(
735 (inputDatasetRefs, outputDatasetRefs),
736 (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
737 strict=True,
738 ):
739 # get a name of a class connection attribute
740 for attributeName in names:
741 # get the attribute identified by name
742 attribute = getattr(self, attributeName)
743 # Branch if the attribute dataset type is an input
744 if attribute.name in quantum.inputs:
745 # if the dataset is marked to load deferred, wrap it in a
746 # DeferredDatasetRef
747 quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
748 if attribute.deferLoad:
749 quantumInputRefs = [
750 DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
751 ]
752 else:
753 quantumInputRefs = list(quantum.inputs[attribute.name])
754 # Unpack arguments that are not marked multiples (list of
755 # length one)
756 if not attribute.multiple:
757 if len(quantumInputRefs) > 1:
758 raise ScalarError(
759 "Received multiple datasets "
760 f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
761 f"for scalar connection {attributeName} "
762 f"({quantumInputRefs[0].datasetType.name}) "
763 f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
764 )
765 if len(quantumInputRefs) == 0:
766 continue
767 setattr(refs, attributeName, quantumInputRefs[0])
768 else:
769 # Add to the QuantizedConnection identifier
770 setattr(refs, attributeName, quantumInputRefs)
771 # Branch if the attribute dataset type is an output
772 elif attribute.name in quantum.outputs:
773 value = quantum.outputs[attribute.name]
774 # Unpack arguments that are not marked multiples (list of
775 # length one)
776 if not attribute.multiple:
777 setattr(refs, attributeName, value[0])
778 else:
779 setattr(refs, attributeName, value)
 780 # Specified attribute is not in inputs or outputs; don't know
 781 # how to handle it, so raise
782 else:
783 raise ValueError(
784 f"Attribute with name {attributeName} has no counterpart in input quantum"
785 )
786 return inputDatasetRefs, outputDatasetRefs
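    # Editor's note: typical use of buildDatasetRefs from harness code, as a
    # sketch (``quantum`` is a populated lsst.daf.butler.Quantum and
    # ``connections`` an instance of some subclass):
    #
    #     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #     # Each attribute of inputRefs is a DatasetRef, a list of them when
    #     # multiple=True, or DeferredDatasetRef(s) when deferLoad=True.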
788 def adjustQuantum(
789 self,
790 inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
791 outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
792 label: str,
793 data_id: DataCoordinate,
794 ) -> tuple[
795 Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
796 Mapping[str, tuple[Output, Collection[DatasetRef]]],
797 ]:
798 """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
 799 in the `lsst.daf.butler.Quantum` during the graph generation stage
800 of the activator.
802 Parameters
803 ----------
804 inputs : `dict`
805 Dictionary whose keys are an input (regular or prerequisite)
806 connection name and whose values are a tuple of the connection
807 instance and a collection of associated
808 `~lsst.daf.butler.DatasetRef` objects.
809 The exact type of the nested collections is unspecified; it can be
810 assumed to be multi-pass iterable and support `len` and ``in``, but
811 it should not be mutated in place. In contrast, the outer
812 dictionaries are guaranteed to be temporary copies that are true
813 `dict` instances, and hence may be modified and even returned; this
814 is especially useful for delegating to `super` (see notes below).
815 outputs : `~collections.abc.Mapping`
816 Mapping of output datasets, with the same structure as ``inputs``.
817 label : `str`
818 Label for this task in the pipeline (should be used in all
819 diagnostic messages).
820 data_id : `lsst.daf.butler.DataCoordinate`
821 Data ID for this quantum in the pipeline (should be used in all
822 diagnostic messages).
824 Returns
825 -------
826 adjusted_inputs : `~collections.abc.Mapping`
827 Mapping of the same form as ``inputs`` with updated containers of
828 input `~lsst.daf.butler.DatasetRef` objects. Connections that are
829 not changed should not be returned at all. Datasets may only be
830 removed, not added. Nested collections may be of any multi-pass
831 iterable type, and the order of iteration will set the order of
832 iteration within `PipelineTask.runQuantum`.
833 adjusted_outputs : `~collections.abc.Mapping`
834 Mapping of updated output datasets, with the same structure and
835 interpretation as ``adjusted_inputs``.
837 Raises
838 ------
839 ScalarError
840 Raised if any `Input` or `PrerequisiteInput` connection has
841 ``multiple`` set to `False`, but multiple datasets.
842 NoWorkFound
843 Raised to indicate that this quantum should not be run; not enough
844 datasets were found for a regular `Input` connection, and the
845 quantum should be pruned or skipped.
846 FileNotFoundError
847 Raised to cause QuantumGraph generation to fail (with the message
848 included in this exception); not enough datasets were found for a
849 `PrerequisiteInput` connection.
851 Notes
852 -----
853 The base class implementation performs important checks. It always
854 returns an empty mapping (i.e. makes no adjustments). It should
 855 always be called via `super` by custom implementations, ideally at the
856 end of the custom implementation with already-adjusted mappings when
857 any datasets are actually dropped, e.g.:
859 .. code-block:: python
861 def adjustQuantum(self, inputs, outputs, label, data_id):
862 # Filter out some dataset refs for one connection.
863 connection, old_refs = inputs["my_input"]
864 new_refs = [ref for ref in old_refs if ...]
 865 adjusted_inputs = {"my_input": (connection, new_refs)}
866 # Update the original inputs so we can pass them to super.
867 inputs.update(adjusted_inputs)
868 # Can ignore outputs from super because they are guaranteed
869 # to be empty.
 870 super().adjustQuantum(inputs, outputs, label, data_id)
871 # Return only the connections we modified.
872 return adjusted_inputs, {}
874 Removing outputs here is guaranteed to affect what is actually
875 passed to `PipelineTask.runQuantum`, but its effect on the larger
876 graph may be deferred to execution, depending on the context in
877 which `adjustQuantum` is being run: if one quantum removes an output
878 that is needed by a second quantum as input, the second quantum may not
879 be adjusted (and hence pruned or skipped) until that output is actually
880 found to be missing at execution time.
882 Tasks that desire zip-iteration consistency between any combinations of
883 connections that have the same data ID should generally implement
884 `adjustQuantum` to achieve this, even if they could also run that
885 logic during execution; this allows the system to see outputs that will
886 not be produced because the corresponding input is missing as early as
887 possible.
888 """
889 for name, (input_connection, refs) in inputs.items():
890 dataset_type_name = input_connection.name
891 if not input_connection.multiple and len(refs) > 1:
892 raise ScalarError(
893 f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
894 f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
895 f"for quantum data ID {data_id}."
896 )
897 if len(refs) < input_connection.minimum:
898 if isinstance(input_connection, PrerequisiteInput):
899 # This branch should only be possible during QG generation,
900 # or if someone deleted the dataset between making the QG
901 # and trying to run it. Either one should be a hard error.
902 raise FileNotFoundError(
903 f"Not enough datasets ({len(refs)}) found for non-optional connection {label}.{name} "
904 f"({dataset_type_name}) with minimum={input_connection.minimum} for quantum data ID "
905 f"{data_id}."
906 )
907 else:
908 # This branch should be impossible during QG generation,
909 # because that algorithm can only make quanta whose inputs
910 # are either already present or should be created during
911 # execution. It can trigger during execution if the input
912 # wasn't actually created by an upstream task in the same
913 # graph.
914 raise NoWorkFound(label, name, input_connection)
915 for name, (output_connection, refs) in outputs.items():
916 dataset_type_name = output_connection.name
917 if not output_connection.multiple and len(refs) > 1:
918 raise ScalarError(
919 f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
920 f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
921 f"for quantum data ID {data_id}."
922 )
923 return {}, {}
925 def getSpatialBoundsConnections(self) -> Iterable[str]:
926 """Return the names of regular input and output connections whose data
927 IDs should be used to compute the spatial bounds of this task's quanta.
929 The spatial bound for a quantum is defined as the union of the regions
930 of all data IDs of all connections returned here, along with the region
931 of the quantum data ID (if the task has spatial dimensions).
933 Returns
934 -------
935 connection_names : `collections.abc.Iterable` [ `str` ]
 936 Names of connections with spatial dimensions. These are the
937 task-internal connection names, not butler dataset type names.
939 Notes
940 -----
941 The spatial bound is used to search for prerequisite inputs that have
942 skypix dimensions. The default implementation returns an empty
943 iterable, which is usually sufficient for tasks with spatial
944 dimensions, but if a task's inputs or outputs are associated with
945 spatial regions that extend beyond the quantum data ID's region, this
946 method may need to be overridden to expand the set of prerequisite
947 inputs found.
 949 Tasks that do not have spatial dimensions but do have skypix prerequisite
950 inputs should always override this method, as the default spatial
951 bounds otherwise cover the full sky.
952 """
953 return ()
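    # Editor's note: a sketch of an override for a task whose inputs carry
    # spatial regions larger than the quantum's own region; ``inputCatalogs``
    # is a hypothetical Input connection name:
    #
    #     def getSpatialBoundsConnections(self):
    #         return ("inputCatalogs",)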
955 def getTemporalBoundsConnections(self) -> Iterable[str]:
956 """Return the names of regular input and output connections whose data
957 IDs should be used to compute the temporal bounds of this task's
958 quanta.
960 The temporal bound for a quantum is defined as the union of the
961 timespans of all data IDs of all connections returned here, along with
962 the timespan of the quantum data ID (if the task has temporal
963 dimensions).
965 Returns
966 -------
967 connection_names : `collections.abc.Iterable` [ `str` ]
 968 Names of connections with temporal dimensions. These are the
969 task-internal connection names, not butler dataset type names.
971 Notes
972 -----
973 The temporal bound is used to search for prerequisite inputs that are
974 calibration datasets. The default implementation returns an empty
975 iterable, which is usually sufficient for tasks with temporal
976 dimensions, but if a task's inputs or outputs are associated with
977 timespans that extend beyond the quantum data ID's timespan, this
978 method may need to be overridden to expand the set of prerequisite
979 inputs found.
 981 Tasks that do not have temporal dimensions and do not implement this
982 method will use an infinite timespan for any calibration lookups.
983 """
984 return ()
987def iterConnections(
988 connections: PipelineTaskConnections, connectionType: str | Iterable[str]
989) -> Generator[BaseConnection, None, None]:
990 """Create an iterator over the selected connections type which yields
991 all the defined connections of that type.
993 Parameters
994 ----------
995 connections : `PipelineTaskConnections`
996 An instance of a `PipelineTaskConnections` object that will be iterated
997 over.
 998 connectionType : `str` or `~collections.abc.Iterable` [ `str` ]
 999 The types of connections to iterate over; valid values are inputs,
 1000 outputs, prerequisiteInputs, initInputs, and initOutputs.
1002 Yields
1003 ------
1004 connection: `~.connectionTypes.BaseConnection`
1005 A connection defined on the input connections object of the type
 1006 supplied. The yielded value will be a derived type of
1007 `~.connectionTypes.BaseConnection`.
1008 """
1009 if isinstance(connectionType, str):
1010 connectionType = (connectionType,)
1011 for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
1012 yield getattr(connections, name)
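# Editor's note: usage sketch for iterConnections, where ``connections`` is an
# instance of a PipelineTaskConnections subclass:
#
#     for connection in iterConnections(connections, ("inputs", "outputs")):
#         print(connection.name)  # the (templated) dataset type name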
1015@dataclass
1016class AdjustQuantumHelper:
1017 """Helper class for calling `PipelineTaskConnections.adjustQuantum`.
 1019 This class holds `inputs` and `outputs` mappings in the form used by
1020 `Quantum` and execution harness code, i.e. with
1021 `~lsst.daf.butler.DatasetType` keys, translating them to and from the
1022 connection-oriented mappings used inside `PipelineTaskConnections`.
1023 """
1025 inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
1026 """Mapping of regular input and prerequisite input datasets, grouped by
1027 `~lsst.daf.butler.DatasetType`.
1028 """
1030 outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
1031 """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
1032 """
1034 inputs_adjusted: bool = False
1035 """Whether any inputs were removed in the last call to `adjust_in_place`.
1036 """
1038 outputs_adjusted: bool = False
1039 """Whether any outputs were removed in the last call to `adjust_in_place`.
1040 """
1042 def adjust_in_place(
1043 self,
1044 connections: PipelineTaskConnections,
1045 label: str,
1046 data_id: DataCoordinate,
1047 ) -> None:
1048 """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
1049 with its results.
1051 Parameters
1052 ----------
1053 connections : `PipelineTaskConnections`
1054 Instance on which to call `~PipelineTaskConnections.adjustQuantum`.
1055 label : `str`
1056 Label for this task in the pipeline (should be used in all
1057 diagnostic messages).
1058 data_id : `lsst.daf.butler.DataCoordinate`
1059 Data ID for this quantum in the pipeline (should be used in all
1060 diagnostic messages).
1061 """
1062 # Translate self's DatasetType-keyed, Quantum-oriented mappings into
1063 # connection-keyed, PipelineTask-oriented mappings.
1064 inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
1065 outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
1066 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
1067 connection = getattr(connections, name)
1068 dataset_type_name = connection.name
1069 inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
 1070 for name in connections.outputs:
1071 connection = getattr(connections, name)
1072 dataset_type_name = connection.name
1073 outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
1074 # Actually call adjustQuantum.
1075 # MyPy correctly complains that this call is not quite legal, but the
1076 # method docs explain exactly what's expected and it's the behavior we
1077 # want. It'd be nice to avoid this if we ever have to change the
1078 # interface anyway, but not an immediate problem.
1079 adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
1080 inputs_by_connection, # type: ignore
1081 outputs_by_connection, # type: ignore
1082 label,
1083 data_id,
1084 )
1085 # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
1086 # installing new mappings in self if necessary.
1087 if adjusted_inputs_by_connection:
1088 adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
1089 for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
1090 dataset_type_name = connection.name
1091 if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
1092 raise RuntimeError(
1093 f"adjustQuantum implementation for task with label {label} returned {name} "
1094 f"({dataset_type_name}) input datasets that are not a subset of those "
1095 f"it was given for data ID {data_id}."
1096 )
1097 adjusted_inputs[dataset_type_name] = tuple(updated_refs)
1098 self.inputs = adjusted_inputs.freeze()
1099 self.inputs_adjusted = True
1100 else:
1101 self.inputs_adjusted = False
1102 if adjusted_outputs_by_connection:
1103 adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
1104 for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
1105 dataset_type_name = connection.name
1106 if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
1107 raise RuntimeError(
1108 f"adjustQuantum implementation for task with label {label} returned {name} "
1109 f"({dataset_type_name}) output datasets that are not a subset of those "
1110 f"it was given for data ID {data_id}."
1111 )
1112 adjusted_outputs[dataset_type_name] = tuple(updated_refs)
1113 self.outputs = adjusted_outputs.freeze()
1114 self.outputs_adjusted = True
1115 else:
1116 self.outputs_adjusted = False
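    # Editor's note: how harness code is expected to drive this helper, as a
    # hedged sketch (``quantum`` is a populated lsst.daf.butler.Quantum and
    # ``connections`` a PipelineTaskConnections instance):
    #
    #     helper = AdjustQuantumHelper(inputs=quantum.inputs,
    #                                  outputs=quantum.outputs)
    #     helper.adjust_in_place(connections, label="makeWarp",
    #                            data_id=quantum.dataId)
    #     if helper.inputs_adjusted or helper.outputs_adjusted:
    #         ...  # rebuild the Quantum from helper.inputs / helper.outputs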