Coverage for python/lsst/pipe/base/pipelineIR.py: 21% of 387 statements

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = (
    "ConfigIR",
    "ContractError",
    "ContractIR",
    "ImportIR",
    "LabeledSubset",
    "ParametersIR",
    "PipelineIR",
    "TaskIR",
)

import copy
import enum
import os
import re
import warnings
from collections import Counter
from collections.abc import Generator, Hashable, Iterable, MutableMapping
from dataclasses import dataclass, field
from typing import Any, Literal

import yaml
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils.introspection import find_outside_stacklevel


class _Tags(enum.Enum):
    KeepInstrument = enum.auto()


class PipelineYamlLoader(yaml.SafeLoader):
    """Specialized version of yaml's SafeLoader.

    It checks for, and raises an exception on, multiple instances of the
    same key found inside a pipeline file at a given scope.
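
    Examples
    --------
    An illustrative sketch of the duplicate-key check (the error text is
    abbreviated here):

    >>> import yaml
    >>> yaml.load("{a: 1, b: 2}", Loader=PipelineYamlLoader)
    {'a': 1, 'b': 2}
    >>> yaml.load("{a: 1, a: 2}", Loader=PipelineYamlLoader)
    Traceback (most recent call last):
        ...
    KeyError: "Pipeline files must not have duplicated keys, ..."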
    """

    def construct_mapping(self, node: yaml.MappingNode, deep: bool = False) -> dict[Hashable, Any]:
        # Call super first so that it can do all its other checks on this
        # node. Checking key uniqueness first would save super's work in the
        # failure case, but if the node were the wrong kind due to a parsing
        # error, the resulting exception would be difficult to understand.
        mapping = super().construct_mapping(node, deep)
        # Check if there are any duplicate keys.
        all_keys = Counter(key_node.value for key_node, _ in node.value)
        duplicates = {k for k, i in all_keys.items() if i != 1}
        if duplicates:
            raise KeyError(
                f"Pipeline files must not have duplicated keys, {duplicates} appeared multiple times"
            )
        return mapping


class MultilineStringDumper(yaml.Dumper):
    """Custom YAML dumper that makes multi-line strings use the '|'
    continuation style instead of unreadable newlines and tons of quotes.

    The basic approach is taken from
    https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data,
    but is written as a Dumper subclass to make its effects non-global (vs
    `yaml.add_representer`).
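
    Examples
    --------
    A small illustrative sketch (not taken from the pipeline code); chr(10)
    is a newline, used here to avoid backslash escapes in this docstring:

    >>> doc = {"python": "a = 1" + chr(10) + "b = 2" + chr(10)}
    >>> print(yaml.dump(doc, Dumper=MultilineStringDumper))
    python: |
      a = 1
      b = 2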
    """

    def represent_scalar(self, tag: str, value: Any, style: str | None = None) -> yaml.ScalarNode:
        if style is None and tag == "tag:yaml.org,2002:str" and len(value.splitlines()) > 1:
            style = "|"
        return super().represent_scalar(tag, value, style)


class ContractError(Exception):
    """An exception that is raised when a pipeline contract is not satisfied."""

    pass


@dataclass
class ContractIR:
    """Intermediate representation of configuration contracts read from a
    pipeline yaml file.
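
    Examples
    --------
    A hypothetical contract as it might appear in a pipeline document (the
    task labels and config fields are made up for illustration):

    .. code-block:: yaml

        contracts:
            - contract: "taskA.connections.output == taskB.connections.input"
              msg: "taskA must produce the dataset that taskB consumes"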
    """

    contract: str
    """A string of python code representing one or more conditions on configs
    in a pipeline. This code-as-string should, once evaluated, be `True` if
    the configs are fine, and `False` otherwise.
    """
    msg: str | None = None
    """An optional message to be shown to the user if a contract fails.
    """

    def to_primitives(self) -> dict[str, str]:
        """Convert to a representation used in yaml serialization."""
        accumulate = {"contract": self.contract}
        if self.msg is not None:
            accumulate["msg"] = self.msg
        return accumulate

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ContractIR):
            return False
        return self.contract == other.contract and self.msg == other.msg


@dataclass
class LabeledSubset:
    """Intermediate representation of a named subset of task labels read from
    a pipeline yaml file.
    """

    label: str
    """The label used to identify the subset of task labels.
    """
    subset: set[str]
    """A set of task labels contained in this subset.
    """
    description: str | None
    """A description of what this subset of tasks is intended to do.
    """

    @staticmethod
    def from_primitives(label: str, value: list[str] | dict) -> LabeledSubset:
        """Generate `LabeledSubset` objects given a properly formatted object
        that has been created by a yaml loader.

        Parameters
        ----------
        label : `str`
            The label that will be used to identify this labeled subset.
        value : `list` of `str` or `dict`
            Object returned from loading a labeled subset section from a yaml
            document.

        Returns
        -------
        labeledSubset : `LabeledSubset`
            A `LabeledSubset` object built from the inputs.

        Raises
        ------
        ValueError
            Raised if the value input is not properly formatted for parsing.
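
        Examples
        --------
        Both accepted forms, as an illustrative sketch (the task labels are
        hypothetical):

        >>> LabeledSubset.from_primitives("step1", ["taskA"]).subset
        {'taskA'}
        >>> LabeledSubset.from_primitives(
        ...     "step1", {"subset": ["taskA"], "description": "first step"}
        ... ).description
        'first step'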
        """
        if isinstance(value, MutableMapping):
            subset = value.pop("subset", None)
            if subset is None:
                raise ValueError(
                    "If a labeled subset is specified as a mapping, it must contain the key 'subset'"
                )
            description = value.pop("description", None)
        elif isinstance(value, Iterable):
            subset = value
            description = None
        else:
            raise ValueError(
                f"There was a problem parsing the labeled subset {label}; make sure the "
                "definition is either a valid yaml list, or a mapping with keys "
                "(subset, description) where subset points to a yaml list, and description is "
                "associated with a string"
            )
        return LabeledSubset(label, set(subset), description)

    def to_primitives(self) -> dict[str, list[str] | str]:
        """Convert to a representation used in yaml serialization."""
        accumulate: dict[str, list[str] | str] = {"subset": list(self.subset)}
        if self.description is not None:
            accumulate["description"] = self.description
        return accumulate


@dataclass
class ParametersIR:
    """Intermediate representation of parameters that are global to a pipeline.

    Parameters
    ----------
    mapping : `dict` [`str`, `str`]
        A mutable mapping of identifiers as keys, and shared configuration
        as values.

    Notes
    -----
    These parameters are specified under a top level key named ``parameters``
    and are declared as a yaml mapping. These entries can then be used inside
    task configuration blocks to specify configuration values. They may not be
    used in the special ``file`` or ``python`` blocks.

    Examples
    --------
    .. code-block:: yaml

        parameters:
            shared_value: 14
        tasks:
            taskA:
                class: modA
                config:
                    field1: parameters.shared_value
            taskB:
                class: modB
                config:
                    field2: parameters.shared_value
    """

    mapping: MutableMapping[str, Any]
    """A mutable mapping of identifiers as keys, and shared configuration
    as values.
    """

    def update(self, other: ParametersIR | None) -> None:
        if other is not None:
            self.mapping.update(other.mapping)

    def to_primitives(self) -> MutableMapping[str, str]:
        """Convert to a representation used in yaml serialization."""
        return self.mapping

    def __contains__(self, value: str) -> bool:
        return value in self.mapping

    def __getitem__(self, item: str) -> Any:
        return self.mapping[item]

    def __bool__(self) -> bool:
        return bool(self.mapping)


@dataclass
class ConfigIR:
    """Intermediate representation of configurations read from a pipeline yaml
    file.
    """

    python: str | None = None
    """A string of python code that is used to modify a configuration. This
    can also be `None` if there are no modifications to do.
    """
    dataId: dict | None = None
    """A dataId that is used to constrain these config overrides to only
    quanta with matching dataIds. This field can be `None` if there is no
    constraint. This is currently an unimplemented feature, and is placed
    here for future use.
    """
    file: list[str] = field(default_factory=list)
    """A list of paths that point to files containing config overrides to be
    applied. This value may be an empty list if there are no overrides to
    apply.
    """
    rest: dict = field(default_factory=dict)
    """This is a dictionary of key value pairs, where the keys are strings
    corresponding to qualified fields on a config to override, and the values
    are strings representing the values to apply.
    """

    def to_primitives(self) -> dict[str, str | dict | list[str]]:
        """Convert to a representation used in yaml serialization."""
        accumulate = {}
        for name in ("python", "dataId", "file"):
            # If this attribute is truthy, add it to the accumulation
            # dictionary.
            if getattr(self, name):
                accumulate[name] = getattr(self, name)
        # Add the dictionary containing the rest of the config keys to the
        # accumulated dictionary.
        accumulate.update(self.rest)
        return accumulate

    def formatted(self, parameters: ParametersIR) -> ConfigIR:
        """Return a new ConfigIR object that is formatted according to the
        specified parameters.

        Parameters
        ----------
        parameters : `ParametersIR`
            Object that contains variable mappings used in substitution.

        Returns
        -------
        config : `ConfigIR`
            A new ConfigIR object formatted with the input parameters.
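
        Examples
        --------
        An illustrative sketch of parameter substitution (the field and
        parameter names are hypothetical):

        >>> params = ParametersIR({"shared_value": 14})
        >>> config = ConfigIR(rest={"field1": "parameters.shared_value"})
        >>> config.formatted(params).rest
        {'field1': 14}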
        """
        new_config = copy.deepcopy(self)
        for key, value in new_config.rest.items():
            if not isinstance(value, str):
                continue
            match = re.match("parameters[.](.*)", value)
            if match and match.group(1) in parameters:
                new_config.rest[key] = parameters[match.group(1)]
            if match and match.group(1) not in parameters:
                warnings.warn(
                    f"config {key} contains value {match.group(0)} which is formatted like a "
                    "Pipeline parameter but was not found within the Pipeline; if this was not "
                    "intentional, check for a typo",
                    stacklevel=find_outside_stacklevel("lsst.pipe.base"),
                )
        return new_config

    def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]:
        """Merge another instance of a `ConfigIR` into this instance if
        possible. This function returns a generator that yields only self
        if the configs were merged, or self and other_config if they could
        not be merged.

        Parameters
        ----------
        other_config : `ConfigIR`
            An instance of `ConfigIR` to merge into this instance.

        Returns
        -------
        Generator : `ConfigIR`
            A generator yielding either self alone, if the configs could be
            merged, or self and other_config if they could not be.
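
        Examples
        --------
        An illustrative sketch; two overrides with no conflicting keys merge
        into one (the field names are made up):

        >>> a = ConfigIR(rest={"field1": 1})
        >>> b = ConfigIR(rest={"field2": 2})
        >>> [c.rest for c in a.maybe_merge(b)]
        [{'field1': 1, 'field2': 2}]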
        """
        # Verify that the config blocks can be merged.
        if (
            self.dataId != other_config.dataId
            or self.python
            or other_config.python
            or self.file
            or other_config.file
        ):
            yield from (self, other_config)
            return

        # Find the keys that appear in both configs, and verify that no such
        # key has different values.
        shared_keys = self.rest.keys() & other_config.rest.keys()
        for key in shared_keys:
            if self.rest[key] != other_config.rest[key]:
                yield from (self, other_config)
                return
        self.rest.update(other_config.rest)

        # Combine the lists of override files to load.
        self_file_set = set(self.file)
        other_file_set = set(other_config.file)
        self.file = list(self_file_set.union(other_file_set))

        yield self

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ConfigIR):
            return False
        return all(
            getattr(self, attr) == getattr(other, attr) for attr in ("python", "dataId", "file", "rest")
        )


@dataclass
class TaskIR:
    """Intermediate representation of tasks read from a pipeline yaml file."""

    label: str
    """An identifier used to refer to a task.
    """
    klass: str
    """A string containing a fully qualified python class to be run in a
    pipeline.
    """
    config: list[ConfigIR] | None = None
    """List of all config overrides associated with this task; may be
    `None` if there are no config overrides.
    """

    def to_primitives(self) -> dict[str, str | list[dict]]:
        """Convert to a representation used in yaml serialization."""
        accumulate: dict[str, str | list[dict]] = {"class": self.klass}
        if self.config:
            accumulate["config"] = [c.to_primitives() for c in self.config]
        return accumulate

    def add_or_update_config(self, other_config: ConfigIR) -> None:
        """Add a `ConfigIR` to this task if one is not present. Merge configs
        if there is a `ConfigIR` present and the dataId keys of both configs
        match; otherwise add a new entry to the config list. The exception to
        the above is that if either the last config or other_config has a
        python block, then other_config is always added, as python blocks can
        modify configs in ways that cannot be predicted.

        Parameters
        ----------
        other_config : `ConfigIR`
            A `ConfigIR` instance to add or merge into the config attribute of
            this task.
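
        Examples
        --------
        An illustrative sketch (the module and field names are made up):

        >>> task = TaskIR("taskA", "module.TaskA")
        >>> task.add_or_update_config(ConfigIR(rest={"field1": 1}))
        >>> task.add_or_update_config(ConfigIR(rest={"field2": 2}))
        >>> len(task.config)  # the two overrides merged into one ConfigIR
        1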
        """
        if not self.config:
            self.config = [other_config]
            return
        self.config.extend(self.config.pop().maybe_merge(other_config))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config"))


@dataclass
class ImportIR:
    """An intermediate representation of imported pipelines."""

    location: str
    """This is the location of the pipeline to inherit. The path should be
    specified as an absolute path. Environment variables may be used in the
    path and should be specified as a python string template, with the name of
    the environment variable inside braces.
    """
    include: list[str] | None = None
    """List of tasks that should be included when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    exclude: list[str] | None = None
    """List of tasks that should be excluded when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    importContracts: bool = True
    """Boolean attribute to dictate if contracts should be inherited with the
    pipeline or not.
    """
    instrument: Literal[_Tags.KeepInstrument] | str | None = _Tags.KeepInstrument
    """Instrument to assign to the Pipeline at import. The default value of
    `_Tags.KeepInstrument` indicates that whatever instrument the pipeline is
    declared with will not be modified. Setting this value to `None` will drop
    any declared instrument prior to import.
    """

    def toPipelineIR(self) -> "PipelineIR":
        """Load in the Pipeline specified by this object, and turn it into a
        PipelineIR instance.

        Returns
        -------
        pipeline : `PipelineIR`
            A pipeline generated from the imported pipeline file.
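
        Examples
        --------
        A hypothetical ``imports`` section that an `ImportIR` instance may
        represent (the location and labels are made up):

        .. code-block:: yaml

            imports:
                - location: ${SOME_DIR}/pipelines/example.yaml
                  include:
                      - taskA
                  importContracts: false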
        """
        if self.include and self.exclude:
            raise ValueError(
                "An include list and an exclude list cannot both be specified"
                " when declaring a pipeline import."
            )
        tmp_pipeline = PipelineIR.from_uri(os.path.expandvars(self.location))
        if self.instrument is not _Tags.KeepInstrument:
            tmp_pipeline.instrument = self.instrument

        included_labels = set()
        for label in tmp_pipeline.tasks:
            if (
                (self.include and label in self.include)
                or (self.exclude and label not in self.exclude)
                or (self.include is None and self.exclude is None)
            ):
                included_labels.add(label)

        # Handle labeled subsets being specified in the include or exclude
        # list, adding or removing labels.
        if self.include is not None:
            subsets_in_include = tmp_pipeline.labeled_subsets.keys() & self.include
            for label in subsets_in_include:
                included_labels.update(tmp_pipeline.labeled_subsets[label].subset)

        elif self.exclude is not None:
            subsets_in_exclude = tmp_pipeline.labeled_subsets.keys() & self.exclude
            for label in subsets_in_exclude:
                included_labels.difference_update(tmp_pipeline.labeled_subsets[label].subset)

        tmp_pipeline = tmp_pipeline.subset_from_labels(included_labels)

        if not self.importContracts:
            tmp_pipeline.contracts = []

        return tmp_pipeline

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ImportIR):
            return False
        return all(
            getattr(self, attr) == getattr(other, attr)
            for attr in ("location", "include", "exclude", "importContracts")
        )


class PipelineIR:
    """Intermediate representation of a pipeline definition.

    Parameters
    ----------
    loaded_yaml : `dict`
        A dictionary which matches the structure that would be produced by a
        yaml reader which parses a pipeline definition document.

    Raises
    ------
    ValueError
        Raised if:

        - a pipeline is declared without a description;
        - no tasks are declared in a pipeline, and no pipelines are to be
          inherited;
        - more than one instrument is specified;
        - more than one inherited pipeline share a label.
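
    Examples
    --------
    A minimal pipeline document (the task class is hypothetical):

    .. code-block:: yaml

        description: A demonstration pipeline
        tasks:
            taskA: module.TaskA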
    """

    def __init__(self, loaded_yaml: dict[str, Any]):
        # Check required fields are present.
        if "description" not in loaded_yaml:
            raise ValueError("A pipeline must be declared with a description")
        if "tasks" not in loaded_yaml and len({"imports", "inherits"} - loaded_yaml.keys()) == 2:
            raise ValueError("A pipeline must be declared with one or more tasks")

        # The steps below must happen in this call order.

        # Process pipeline description.
        self.description = loaded_yaml.pop("description")

        # Process tasks.
        self._read_tasks(loaded_yaml)

        # Process instrument keys.
        inst = loaded_yaml.pop("instrument", None)
        if isinstance(inst, list):
            raise ValueError("Only one top level instrument can be defined in a pipeline")
        self.instrument: str | None = inst

        # Process any contracts.
        self._read_contracts(loaded_yaml)

        # Process any defined parameters.
        self._read_parameters(loaded_yaml)

        # Process any named label subsets.
        self._read_labeled_subsets(loaded_yaml)

        # Process any inherited pipelines.
        self._read_imports(loaded_yaml)

        # Verify named subsets; must be done after inheriting.
        self._verify_labeled_subsets()

    def _read_contracts(self, loaded_yaml: dict[str, Any]) -> None:
        """Process the contracts portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document.
        """
        loaded_contracts = loaded_yaml.pop("contracts", [])
        if isinstance(loaded_contracts, str):
            loaded_contracts = [loaded_contracts]
        self.contracts: list[ContractIR] = []
        for contract in loaded_contracts:
            if isinstance(contract, dict):
                self.contracts.append(ContractIR(**contract))
            if isinstance(contract, str):
                self.contracts.append(ContractIR(contract=contract))

    def _read_parameters(self, loaded_yaml: dict[str, Any]) -> None:
        """Process the parameters portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document.
        """
        loaded_parameters = loaded_yaml.pop("parameters", {})
        if not isinstance(loaded_parameters, dict):
            raise ValueError("The parameters section must be a yaml mapping")
        self.parameters = ParametersIR(loaded_parameters)

    def _read_labeled_subsets(self, loaded_yaml: dict[str, Any]) -> None:
        """Process the subsets portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `MutableMapping`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document.
        """
        loaded_subsets = loaded_yaml.pop("subsets", {})
        self.labeled_subsets: dict[str, LabeledSubset] = {}
        if not loaded_subsets and "subset" in loaded_yaml:
            raise ValueError("Top level key should be subsets and not subset, add an s")
        for key, value in loaded_subsets.items():
            self.labeled_subsets[key] = LabeledSubset.from_primitives(key, value)

    def _verify_labeled_subsets(self) -> None:
        """Verify that all the labels in each named subset exist within the
        pipeline.
        """
        # Verify that all labels defined in a labeled subset are in the
        # Pipeline.
        for labeled_subset in self.labeled_subsets.values():
            if not labeled_subset.subset.issubset(self.tasks.keys()):
                raise ValueError(
                    f"Labels {labeled_subset.subset - self.tasks.keys()} were not found in the "
                    "declared pipeline"
                )
        # Verify subset labels are not already task labels.
        label_intersection = self.labeled_subsets.keys() & self.tasks.keys()
        if label_intersection:
            raise ValueError(f"Labeled subsets can not use the same label as a task: {label_intersection}")

    def _read_imports(self, loaded_yaml: dict[str, Any]) -> None:
        """Process the imports portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document.
        """

        def process_args(argument: str | dict) -> dict:
            if isinstance(argument, str):
                return {"location": argument}
            elif isinstance(argument, dict):
                if "exclude" in argument and isinstance(argument["exclude"], str):
                    argument["exclude"] = [argument["exclude"]]
                if "include" in argument and isinstance(argument["include"], str):
                    argument["include"] = [argument["include"]]
                if "instrument" in argument and argument["instrument"] == "None":
                    argument["instrument"] = None
                return argument

        if not {"inherits", "imports"} - loaded_yaml.keys():
            raise ValueError("Cannot define both inherits and imports sections, use imports")
        tmp_import = loaded_yaml.pop("inherits", None)
        if tmp_import is None:
            tmp_import = loaded_yaml.pop("imports", None)
        else:
            raise ValueError("The 'inherits' key is not supported. Please use the key 'imports' instead")
        if tmp_import is None:
            self.imports: list[ImportIR] = []
        elif isinstance(tmp_import, list):
            self.imports = [ImportIR(**process_args(args)) for args in tmp_import]
        else:
            self.imports = [ImportIR(**process_args(tmp_import))]

        self.merge_pipelines([fragment.toPipelineIR() for fragment in self.imports])

    def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None:
        """Merge one or more other `PipelineIR` objects into this object.

        Parameters
        ----------
        pipelines : `~collections.abc.Iterable` of `PipelineIR` objects
            An `~collections.abc.Iterable` that contains one or more
            `PipelineIR` objects to merge into this object.

        Raises
        ------
        ValueError
            Raised if there is a conflict in instrument specifications.
            Raised if a task label appears in more than one of the input
            `PipelineIR` objects which are to be merged.
            Raised if a labeled subset appears in more than one of the input
            `PipelineIR` objects which are to be merged, or clashes with a
            subset existing in this object.
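
        Examples
        --------
        An illustrative sketch with two single-task pipelines (the task
        classes are made up):

        >>> main = PipelineIR.from_string("{description: a, tasks: {taskA: module.TaskA}}")
        >>> other = PipelineIR.from_string("{description: b, tasks: {taskB: module.TaskB}}")
        >>> main.merge_pipelines([other])
        >>> sorted(main.tasks)
        ['taskA', 'taskB']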
        """
        # Integrate any imported pipelines.
        accumulate_tasks: dict[str, TaskIR] = {}
        accumulate_labeled_subsets: dict[str, LabeledSubset] = {}
        accumulated_parameters = ParametersIR({})

        for tmp_IR in pipelines:
            if self.instrument is None:
                self.instrument = tmp_IR.instrument
            elif self.instrument != tmp_IR.instrument and tmp_IR.instrument is not None:
                msg = (
                    "Only one instrument can be declared in a pipeline or its imports. "
                    f"Top level pipeline defines {self.instrument} but pipeline to merge "
                    f"defines {tmp_IR.instrument}."
                )
                raise ValueError(msg)
            if duplicate_labels := accumulate_tasks.keys() & tmp_IR.tasks.keys():
                msg = (
                    "Task labels in the imported pipelines must be unique. "
                    f"These labels appear multiple times: {duplicate_labels}"
                )
                raise ValueError(msg)
            accumulate_tasks.update(tmp_IR.tasks)
            self.contracts.extend(tmp_IR.contracts)
            # Verify that tmp_IR has unique labels for named subsets among
            # existing labeled subsets, and with existing task labels.
            overlapping_subsets = accumulate_labeled_subsets.keys() & tmp_IR.labeled_subsets.keys()
            task_subset_overlap = (
                accumulate_labeled_subsets.keys() | tmp_IR.labeled_subsets.keys()
            ) & accumulate_tasks.keys()
            if overlapping_subsets or task_subset_overlap:
                raise ValueError(
                    "Labeled subset names must be unique amongst imports in both labels and "
                    f"named Subsets. Duplicate: {overlapping_subsets | task_subset_overlap}"
                )
            accumulate_labeled_subsets.update(tmp_IR.labeled_subsets)
            accumulated_parameters.update(tmp_IR.parameters)

        # Verify that any accumulated labeled subsets don't clash with a label
        # from this pipeline.
        if accumulate_labeled_subsets.keys() & self.tasks.keys():
            raise ValueError(
                "Labeled subset names must be unique amongst imports in both labels and named Subsets"
            )
        # Merge in the named subsets for self so this document can override any
        # that have been declared.
        accumulate_labeled_subsets.update(self.labeled_subsets)
        self.labeled_subsets = accumulate_labeled_subsets

        # Merge the dict of label:TaskIR objects, preserving any configs in the
        # imported pipeline if the labels point to the same class.
        for label, task in self.tasks.items():
            if label not in accumulate_tasks:
                accumulate_tasks[label] = task
            elif accumulate_tasks[label].klass == task.klass:
                if task.config is not None:
                    for config in task.config:
                        accumulate_tasks[label].add_or_update_config(config)
            else:
                accumulate_tasks[label] = task
        self.tasks: dict[str, TaskIR] = accumulate_tasks
        accumulated_parameters.update(self.parameters)
        self.parameters = accumulated_parameters

    def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None:
        """Process the tasks portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document.
        """
        self.tasks = {}
        tmp_tasks = loaded_yaml.pop("tasks", None)
        if tmp_tasks is None:
            tmp_tasks = {}

        if "parameters" in tmp_tasks:
            raise ValueError("parameters is a reserved word and cannot be used as a task label")

        for label, definition in tmp_tasks.items():
            if isinstance(definition, str):
                definition = {"class": definition}
            config = definition.get("config", None)
            if config is None:
                task_config_ir = None
            else:
                if isinstance(config, dict):
                    config = [config]
                task_config_ir = []
                for c in config:
                    file = c.pop("file", None)
                    if file is None:
                        file = []
                    elif not isinstance(file, list):
                        file = [file]
                    task_config_ir.append(
                        ConfigIR(
                            python=c.pop("python", None), dataId=c.pop("dataId", None), file=file, rest=c
                        )
                    )
            self.tasks[label] = TaskIR(label, definition["class"], task_config_ir)

    def _remove_contracts(self, label: str) -> None:
        """Remove any contracts that contain the given label.

        String comparison used in this way is not the most elegant and may
        have issues, but it is the only feasible way when users can specify
        contracts with generic strings.
        """
        new_contracts = []
        for contract in self.contracts:
            # Match a label that is not preceded by an ASCII identifier
            # character, or that starts the line, and is followed by a dot.
            if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract):
                continue
            new_contracts.append(contract)
        self.contracts = new_contracts

    def subset_from_labels(self, labelSpecifier: set[str]) -> PipelineIR:
        """Subset a pipelineIR to contain only labels specified in
        labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `set` of `str`
            Set containing labels that describe how to subset a pipeline.

        Returns
        -------
        pipeline : `PipelineIR`
            A new pipelineIR object that is a subset of the old pipelineIR.

        Raises
        ------
        ValueError
            Raised if there is an issue with the specified labels.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should. Any labeled subsets defined that no longer have all
        members of the subset present in the pipeline will be removed from the
        resulting pipeline.
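
        Examples
        --------
        An illustrative sketch, assuming a two-task pipeline (the labels are
        made up):

        >>> pipeline = PipelineIR.from_string(
        ...     "{description: demo, tasks: {taskA: module.TaskA, taskB: module.TaskB}}"
        ... )
        >>> sorted(pipeline.subset_from_labels({"taskA"}).tasks)
        ['taskA']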
        """
        pipeline = copy.deepcopy(self)

        # Update the label specifier to expand any named subsets.
        toRemove = set()
        toAdd = set()
        for label in labelSpecifier:
            if label in pipeline.labeled_subsets:
                toRemove.add(label)
                toAdd.update(pipeline.labeled_subsets[label].subset)
        labelSpecifier.difference_update(toRemove)
        labelSpecifier.update(toAdd)
        # Verify all the labels are in the pipeline.
        if not labelSpecifier.issubset(pipeline.tasks.keys() | pipeline.labeled_subsets):
            difference = labelSpecifier.difference(pipeline.tasks.keys())
            raise ValueError(
                "Not all supplied labels (specified or named subsets) are in the pipeline "
                f"definition, extra labels: {difference}"
            )
        # Copy needed so as to not modify while iterating.
        pipeline_labels = set(pipeline.tasks.keys())
        # Remove the labels from the pipelineIR, and any contracts that contain
        # those labels (see docstring on _remove_contracts for why this may
        # cause issues).
        for label in pipeline_labels:
            if label not in labelSpecifier:
                pipeline.tasks.pop(label)
                pipeline._remove_contracts(label)

        # Create a copy of the object to iterate over.
        labeled_subsets = copy.copy(pipeline.labeled_subsets)
        # Remove any labeled subsets that no longer have a complete set.
        for label, labeled_subset in labeled_subsets.items():
            if labeled_subset.subset - pipeline.tasks.keys():
                pipeline.labeled_subsets.pop(label)

        return pipeline

    @classmethod
    def from_string(cls, pipeline_string: str) -> PipelineIR:
        """Create a `PipelineIR` object from a string formatted like a
        pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document.
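
        Examples
        --------
        An illustrative sketch with a hypothetical task class:

        >>> pipeline = PipelineIR.from_string("{description: demo, tasks: {taskA: module.TaskA}}")
        >>> pipeline.tasks["taskA"].klass
        'module.TaskA'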
        """
        loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader)
        return cls(loaded_yaml)

    @classmethod
    def from_uri(cls, uri: ResourcePathExpression) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input uri.

        Parameters
        ----------
        uri : convertible to `~lsst.resources.ResourcePath`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline.
        """
        loaded_uri = ResourcePath(uri)
        with loaded_uri.open("r") as buffer:
            loaded_yaml = yaml.load(buffer, Loader=PipelineYamlLoader)
            return cls(loaded_yaml)

    def write_to_uri(self, uri: ResourcePathExpression) -> None:
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified uri.

        Parameters
        ----------
        uri : convertible to `~lsst.resources.ResourcePath`
            Location of document to write a `PipelineIR` object.
        """
        with ResourcePath(uri).open("w") as buffer:
            yaml.dump(self.to_primitives(), buffer, sort_keys=False, Dumper=MultilineStringDumper)

    def to_primitives(self) -> dict[str, Any]:
        """Convert to a representation used in yaml serialization.

        Returns
        -------
        primitives : `dict`
            Dictionary that maps directly to the serialized YAML form.
        """
        accumulate = {"description": self.description}
        if self.instrument is not None:
            accumulate["instrument"] = self.instrument
        if self.parameters:
            accumulate["parameters"] = self.parameters.to_primitives()
        accumulate["tasks"] = {m: t.to_primitives() for m, t in self.tasks.items()}
        if len(self.contracts) > 0:
            # Sort contracts in lexicographical order by the contract string,
            # in the absence of any other ordering principle.
            contracts_list = [c.to_primitives() for c in self.contracts]
            contracts_list.sort(key=lambda x: x["contract"])
            accumulate["contracts"] = contracts_list
        if self.labeled_subsets:
            accumulate["subsets"] = {k: v.to_primitives() for k, v in self.labeled_subsets.items()}
        return accumulate

    def __str__(self) -> str:
        """Instance formatting as how it would look in yaml representation."""
        return yaml.dump(self.to_primitives(), sort_keys=False, Dumper=MultilineStringDumper)

    def __repr__(self) -> str:
        """Instance formatting as how it would look in yaml representation."""
        return str(self)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PipelineIR):
            return False
        # Special-case contracts because it is a list, but order is not
        # important.
        return (
            all(
                getattr(self, attr) == getattr(other, attr)
                for attr in ("tasks", "instrument", "labeled_subsets", "parameters")
            )
            and len(self.contracts) == len(other.contracts)
            and all(c in self.contracts for c in other.contracts)
        )