# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("ConfigIR", "ContractError", "ContractIR", "ImportIR", "PipelineIR", "TaskIR", "LabeledSubset")

import copy
import enum
import os
import re
import warnings
from collections import Counter
from collections.abc import Iterable as abcIterable
from dataclasses import dataclass, field
from typing import (
    Any,
    Dict,
    Generator,
    Hashable,
    List,
    Literal,
    Mapping,
    MutableMapping,
    Optional,
    Set,
    Union,
)

import yaml
from deprecated.sphinx import deprecated
from lsst.resources import ResourcePath, ResourcePathExpression


class _Tags(enum.Enum):
    KeepInstrument = enum.auto()


class PipelineYamlLoader(yaml.SafeLoader):
    """This is a specialized version of yaml's SafeLoader. It raises an
    exception if it finds multiple instances of the same key at a given scope
    inside a pipeline file.
    """

    def construct_mapping(self, node: yaml.MappingNode, deep: bool = False) -> dict[Hashable, Any]:
        # Do the call to super first so that it can do all the other forms of
        # checking on this node. Checking the uniqueness of keys first would
        # save the work that super does in the case of a failure, but the node
        # might be the incorrect node due to a parsing error, and the
        # resulting exception would be difficult to understand.
        mapping = super().construct_mapping(node, deep)
        # Check if there are any duplicate keys.
        all_keys = Counter(key_node.value for key_node, _ in node.value)
        duplicates = {k for k, i in all_keys.items() if i != 1}
        if duplicates:
            raise KeyError(
                f"Pipeline files must not have duplicated keys, {duplicates} appeared multiple times"
            )
        return mapping
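

# A minimal sketch of the duplicate-key check (illustrative YAML; the task
# class names are hypothetical): the ``tasks`` mapping repeats a label, so
# loading raises a KeyError.
#
#     document = """
#     tasks:
#       taskA: test.module.TaskA
#       taskA: test.module.OtherTaskA
#     """
#     yaml.load(document, Loader=PipelineYamlLoader)  # raises KeyError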


class MultilineStringDumper(yaml.Dumper):
    """Custom YAML dumper that makes multi-line strings use the '|'
    continuation style instead of unreadable newlines and tons of quotes.

    Basic approach is taken from
    https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data,
    but is written as a Dumper subclass to make its effects non-global (vs
    `yaml.add_representer`).
    """

    def represent_scalar(self, tag: str, value: Any, style: Optional[str] = None) -> yaml.ScalarNode:
        if style is None and tag == "tag:yaml.org,2002:str" and len(value.splitlines()) > 1:
            style = "|"
        return super().represent_scalar(tag, value, style)
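

# A minimal sketch of the dumper's effect (illustrative values): multi-line
# strings are emitted in '|' block style rather than as quoted scalars.
#
#     yaml.dump({"python": "config.fieldA = 1\nconfig.fieldB = 2\n"},
#               Dumper=MultilineStringDumper)
#     # python: |
#     #   config.fieldA = 1
#     #   config.fieldB = 2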


class ContractError(Exception):
    """An exception that is raised when a pipeline contract is not satisfied"""

    pass


@dataclass
class ContractIR:
    """Intermediate representation of configuration contracts read from a
    pipeline yaml file."""

    contract: str
    """A string of python code representing one or more conditions on configs
    in a pipeline. This code-as-string should, once evaluated, be True if the
    configs are fine, and False otherwise.
    """
    msg: Union[str, None] = None
    """An optional message to be shown to the user if a contract fails
    """

    def to_primitives(self) -> Dict[str, str]:
        """Convert to a representation used in yaml serialization"""
        accumulate = {"contract": self.contract}
        if self.msg is not None:
            accumulate["msg"] = self.msg
        return accumulate

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ContractIR):
            return False
        elif self.contract == other.contract and self.msg == other.msg:
            return True
        else:
            return False
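

# A minimal usage sketch (hypothetical field names): a contract relating two
# task configs, and its yaml-ready form.
#
#     contract = ContractIR(
#         contract="taskA.connections.out == taskB.connections.inp",
#         msg="taskB must consume taskA's output",
#     )
#     contract.to_primitives()
#     # {'contract': "taskA.connections.out == taskB.connections.inp",
#     #  'msg': "taskB must consume taskA's output"}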


@dataclass
class LabeledSubset:
    """Intermediate representation of named subset of task labels read from
    a pipeline yaml file.
    """

    label: str
    """The label used to identify the subset of task labels.
    """
    subset: Set[str]
    """A set of task labels contained in this subset.
    """
    description: Optional[str]
    """A description of what this subset of tasks is intended to do
    """

    @staticmethod
    def from_primitives(label: str, value: Union[List[str], dict]) -> LabeledSubset:
        """Generate `LabeledSubset` objects given a properly formatted object
        that has been created by a yaml loader.

        Parameters
        ----------
        label : `str`
            The label that will be used to identify this labeled subset.
        value : `list` of `str` or `dict`
            Object returned from loading a labeled subset section from a yaml
            document.

        Returns
        -------
        labeledSubset : `LabeledSubset`
            A `LabeledSubset` object built from the inputs.

        Raises
        ------
        ValueError
            Raised if the value input is not properly formatted for parsing
        """
        if isinstance(value, MutableMapping):
            subset = value.pop("subset", None)
            if subset is None:
                raise ValueError(
                    "If a labeled subset is specified as a mapping, it must contain the key 'subset'"
                )
            description = value.pop("description", None)
        elif isinstance(value, abcIterable):
            subset = value
            description = None
        else:
            raise ValueError(
                f"There was a problem parsing the labeled subset {label}; make sure the "
                "definition is either a valid yaml list, or a mapping with keys "
                "(subset, description), where subset points to a yaml list and "
                "description is associated with a string"
            )
        return LabeledSubset(label, set(subset), description)

    def to_primitives(self) -> Dict[str, Union[List[str], str]]:
        """Convert to a representation used in yaml serialization"""
        accumulate: Dict[str, Union[List[str], str]] = {"subset": list(self.subset)}
        if self.description is not None:
            accumulate["description"] = self.description
        return accumulate
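

# A minimal sketch of both accepted forms (hypothetical labels): a plain yaml
# list, or a mapping with an optional description.
#
#     LabeledSubset.from_primitives("step1", ["taskA", "taskB"])
#     LabeledSubset.from_primitives(
#         "step1", {"subset": ["taskA"], "description": "just taskA"}
#     )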


@dataclass
class ParametersIR:
    """Intermediate representation of parameters that are global to a pipeline

    These parameters are specified under a top level key named ``parameters``
    and are declared as a yaml mapping. These entries can then be used inside
    task configuration blocks to specify configuration values. They may not be
    used in the special ``file`` or ``python`` blocks.

    Example:

        parameters:
          shared_value: 14
        tasks:
          taskA:
            class: modA
            config:
              field1: parameters.shared_value
          taskB:
            class: modB
            config:
              field2: parameters.shared_value
    """

    mapping: MutableMapping[str, str]
    """A mutable mapping of identifiers as keys, and shared configuration
    as values.
    """

    def update(self, other: Optional[ParametersIR]) -> None:
        if other is not None:
            self.mapping.update(other.mapping)

    def to_primitives(self) -> MutableMapping[str, str]:
        """Convert to a representation used in yaml serialization"""
        return self.mapping

    def __contains__(self, value: str) -> bool:
        return value in self.mapping

    def __getitem__(self, item: str) -> Any:
        return self.mapping[item]

    def __bool__(self) -> bool:
        return bool(self.mapping)
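

# A minimal sketch of how parameters behave (illustrative values):
#
#     params = ParametersIR({"shared_value": 14})
#     "shared_value" in params  # True
#     params["shared_value"]    # 14
#     bool(ParametersIR({}))    # False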


@dataclass
class ConfigIR:
    """Intermediate representation of configurations read from a pipeline yaml
    file.
    """

    python: Union[str, None] = None
    """A string of python code that is used to modify a configuration. This
    can also be None if there are no modifications to do.
    """
    dataId: Union[dict, None] = None
    """A dataId that is used to constrain these config overrides to only
    quanta with matching dataIds. This field can be None if there is no
    constraint. This is currently an unimplemented feature, and is placed here
    for future use.
    """
    file: List[str] = field(default_factory=list)
    """A list of paths which point to files containing config overrides to be
    applied. This value may be an empty list if there are no overrides to
    apply.
    """
    rest: dict = field(default_factory=dict)
    """This is a dictionary of key value pairs, where the keys are strings
    corresponding to qualified fields on a config to override, and the values
    are strings representing the values to apply.
    """

    def to_primitives(self) -> Dict[str, Union[str, dict, List[str]]]:
        """Convert to a representation used in yaml serialization"""
        accumulate = {}
        for name in ("python", "dataId", "file"):
            # If this attribute is truthy, add it to the accumulation
            # dictionary.
            if getattr(self, name):
                accumulate[name] = getattr(self, name)
        # Add the dictionary containing the rest of the config keys to the
        # accumulated dictionary.
        accumulate.update(self.rest)
        return accumulate

    def formatted(self, parameters: ParametersIR) -> ConfigIR:
        """Return a new ConfigIR object that is formatted according to the
        specified parameters.

        Parameters
        ----------
        parameters : ParametersIR
            Object that contains variable mappings used in substitution.

        Returns
        -------
        config : ConfigIR
            A new ConfigIR object formatted with the input parameters
        """
        new_config = copy.deepcopy(self)
        for key, value in new_config.rest.items():
            if not isinstance(value, str):
                continue
            match = re.match("parameters[.](.*)", value)
            if match and match.group(1) in parameters:
                new_config.rest[key] = parameters[match.group(1)]
            if match and match.group(1) not in parameters:
                warnings.warn(
                    f"config {key} contains value {match.group(0)}, which is formatted like a "
                    "Pipeline parameter but was not found within the Pipeline; if this was not "
                    "intentional, check for a typo"
                )
        return new_config
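
    # A minimal sketch of parameter substitution (hypothetical field and
    # parameter names):
    #
    #     config = ConfigIR(rest={"field1": "parameters.shared_value"})
    #     config.formatted(ParametersIR({"shared_value": 14})).rest
    #     # {'field1': 14}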

    def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]:
        """Merge another instance of a `ConfigIR` into this instance if
        possible. This function returns a generator that yields either just
        self, if the configs were merged, or self and other_config, if they
        could not be merged.

        Parameters
        ----------
        other_config : `ConfigIR`
            An instance of `ConfigIR` to merge into this instance.

        Returns
        -------
        Generator : `ConfigIR`
            A generator yielding either self, if the configs could be merged,
            or self and other_config if they could not.
        """
        # Verify that the config blocks can be merged.
        if (
            self.dataId != other_config.dataId
            or self.python
            or other_config.python
            or self.file
            or other_config.file
        ):
            yield from (self, other_config)
            return

        # Find the keys common to both configs, and verify no shared key has
        # different values.
        shared_keys = self.rest.keys() & other_config.rest.keys()
        for key in shared_keys:
            if self.rest[key] != other_config.rest[key]:
                yield from (self, other_config)
                return
        self.rest.update(other_config.rest)

        # Combine the lists of override files to load.
        self_file_set = set(self.file)
        other_file_set = set(other_config.file)
        self.file = list(self_file_set.union(other_file_set))

        yield self
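
    # A minimal sketch of merge behavior (hypothetical fields): compatible
    # overrides collapse into one block, conflicting ones stay separate.
    #
    #     a = ConfigIR(rest={"field1": "1"})
    #     b = ConfigIR(rest={"field2": "2"})
    #     list(a.maybe_merge(b))  # [a]; a.rest now holds both fields
    #
    #     c = ConfigIR(rest={"field1": "3"})
    #     list(a.maybe_merge(c))  # [a, c]; field1 values conflict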

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ConfigIR):
            return False
        elif all(
            getattr(self, attr) == getattr(other, attr) for attr in ("python", "dataId", "file", "rest")
        ):
            return True
        else:
            return False


@dataclass
class TaskIR:
    """Intermediate representation of tasks read from a pipeline yaml file."""

    label: str
    """An identifier used to refer to a task.
    """
    klass: str
    """A string containing a fully qualified python class to be run in a
    pipeline.
    """
    config: Union[List[ConfigIR], None] = None
    """List of all config overrides associated with this task; may be
    `None` if there are no config overrides.
    """

    def to_primitives(self) -> Dict[str, Union[str, List[dict]]]:
        """Convert to a representation used in yaml serialization"""
        accumulate: Dict[str, Union[str, List[dict]]] = {"class": self.klass}
        if self.config:
            accumulate["config"] = [c.to_primitives() for c in self.config]
        return accumulate

    def add_or_update_config(self, other_config: ConfigIR) -> None:
        """Add a `ConfigIR` to this task if one is not present. If a
        `ConfigIR` is present and the dataId keys of both configs match,
        the configs are merged; otherwise a new entry is added to the config
        list. The exception to the above is that if either the last config or
        other_config has a python block, then other_config is always added,
        as python blocks can modify configs in ways that cannot be predicted.

        Parameters
        ----------
        other_config : `ConfigIR`
            A `ConfigIR` instance to add or merge into the config attribute of
            this task.
        """
        if not self.config:
            self.config = [other_config]
            return
        self.config.extend(self.config.pop().maybe_merge(other_config))
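
    # A minimal sketch (hypothetical task and fields): successive overrides
    # are merged into the task's last config block when possible.
    #
    #     task = TaskIR("taskA", "test.module.TaskA")
    #     task.add_or_update_config(ConfigIR(rest={"field1": "1"}))
    #     task.add_or_update_config(ConfigIR(rest={"field2": "2"}))
    #     len(task.config)  # 1; the two override blocks were merged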

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskIR):
            return False
        elif all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config")):
            return True
        else:
            return False


@dataclass
class ImportIR:
    """An intermediate representation of imported pipelines"""

    location: str
    """This is the location of the pipeline to inherit. The path should be
    specified as an absolute path. Environment variables may be used in the
    path and should be specified as a python string template, with the name of
    the environment variable inside braces.
    """
    include: Union[List[str], None] = None
    """List of tasks that should be included when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    exclude: Union[List[str], None] = None
    """List of tasks that should be excluded when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    importContracts: bool = True
    """Boolean attribute to dictate if contracts should be inherited with the
    pipeline or not.
    """
    instrument: Union[Literal[_Tags.KeepInstrument], str, None] = _Tags.KeepInstrument
    """Instrument to assign to the Pipeline at import. The default value of
    `_Tags.KeepInstrument` indicates that whatever instrument the pipeline is
    declared with will not be modified. Setting this value to None will drop
    any declared instrument prior to import.
    """

    def toPipelineIR(self) -> "PipelineIR":
        """Load in the Pipeline specified by this object, and turn it into a
        PipelineIR instance.

        Returns
        -------
        pipeline : `PipelineIR`
            A pipeline generated from the imported pipeline file
        """
        if self.include and self.exclude:
            raise ValueError(
                "Both an include and an exclude list can't be specified when declaring a pipeline import"
            )
        tmp_pipeline = PipelineIR.from_uri(os.path.expandvars(self.location))
        if self.instrument is not _Tags.KeepInstrument:
            tmp_pipeline.instrument = self.instrument

        included_labels = set()
        for label in tmp_pipeline.tasks:
            if (
                (self.include and label in self.include)
                or (self.exclude and label not in self.exclude)
                or (self.include is None and self.exclude is None)
            ):
                included_labels.add(label)

        # Handle labeled subsets being specified in the include or exclude
        # list, adding or removing labels.
        if self.include is not None:
            subsets_in_include = tmp_pipeline.labeled_subsets.keys() & self.include
            for label in subsets_in_include:
                included_labels.update(tmp_pipeline.labeled_subsets[label].subset)

        elif self.exclude is not None:
            subsets_in_exclude = tmp_pipeline.labeled_subsets.keys() & self.exclude
            for label in subsets_in_exclude:
                included_labels.difference_update(tmp_pipeline.labeled_subsets[label].subset)

        tmp_pipeline = tmp_pipeline.subset_from_labels(included_labels)

        if not self.importContracts:
            tmp_pipeline.contracts = []

        return tmp_pipeline
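
    # A minimal sketch of a typical import (hypothetical path and labels):
    # load the referenced pipeline, keep only taskA, and drop its contracts.
    #
    #     import_ir = ImportIR(
    #         location="${SOME_PACKAGE_DIR}/pipelines/example.yaml",
    #         include=["taskA"],
    #         importContracts=False,
    #     )
    #     imported = import_ir.toPipelineIR()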

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ImportIR):
            return False
        elif all(
            getattr(self, attr) == getattr(other, attr)
            for attr in ("location", "include", "exclude", "importContracts")
        ):
            return True
        else:
            return False


class PipelineIR:
    """Intermediate representation of a pipeline definition

    Parameters
    ----------
    loaded_yaml : `dict`
        A dictionary which matches the structure that would be produced by a
        yaml reader which parses a pipeline definition document

    Raises
    ------
    ValueError
        Raised if:
        - a pipeline is declared without a description;
        - no tasks are declared in a pipeline, and no pipelines are to be
          inherited;
        - more than one instrument is specified;
        - more than one inherited pipeline shares a label.
    """

    def __init__(self, loaded_yaml: Dict[str, Any]):
        # Check required fields are present.
        if "description" not in loaded_yaml:
            raise ValueError("A pipeline must be declared with a description")
        if "tasks" not in loaded_yaml and len({"imports", "inherits"} - loaded_yaml.keys()) == 2:
            raise ValueError("A pipeline must be declared with one or more tasks")

        # The steps below must happen in this call order.

        # Process pipeline description.
        self.description = loaded_yaml.pop("description")

        # Process tasks.
        self._read_tasks(loaded_yaml)

        # Process instrument keys.
        inst = loaded_yaml.pop("instrument", None)
        if isinstance(inst, list):
            raise ValueError("Only one top level instrument can be defined in a pipeline")
        self.instrument: Optional[str] = inst

        # Process any contracts.
        self._read_contracts(loaded_yaml)

        # Process any defined parameters.
        self._read_parameters(loaded_yaml)

        # Process any named label subsets.
        self._read_labeled_subsets(loaded_yaml)

        # Process any inherited pipelines.
        self._read_imports(loaded_yaml)

        # Verify named subsets; must be done after inheriting.
        self._verify_labeled_subsets()

    def _read_contracts(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the contracts portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        loaded_contracts = loaded_yaml.pop("contracts", [])
        if isinstance(loaded_contracts, str):
            loaded_contracts = [loaded_contracts]
        self.contracts: List[ContractIR] = []
        for contract in loaded_contracts:
            if isinstance(contract, dict):
                self.contracts.append(ContractIR(**contract))
            if isinstance(contract, str):
                self.contracts.append(ContractIR(contract=contract))

    def _read_parameters(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the parameters portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        loaded_parameters = loaded_yaml.pop("parameters", {})
        if not isinstance(loaded_parameters, dict):
            raise ValueError("The parameters section must be a yaml mapping")
        self.parameters = ParametersIR(loaded_parameters)

    def _read_labeled_subsets(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the subsets portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document
        """
        loaded_subsets = loaded_yaml.pop("subsets", {})
        self.labeled_subsets: Dict[str, LabeledSubset] = {}
        if not loaded_subsets and "subset" in loaded_yaml:
            raise ValueError("Top level key should be subsets and not subset, add an s")
        for key, value in loaded_subsets.items():
            self.labeled_subsets[key] = LabeledSubset.from_primitives(key, value)

    def _verify_labeled_subsets(self) -> None:
        """Verifies that all the labels in each named subset exist within the
        pipeline.
        """
        # Verify that all labels defined in a labeled subset are in the
        # Pipeline.
        for labeled_subset in self.labeled_subsets.values():
            if not labeled_subset.subset.issubset(self.tasks.keys()):
                raise ValueError(
                    f"Labels {labeled_subset.subset - self.tasks.keys()} were not found in the "
                    "declared pipeline"
                )
        # Verify subset labels are not already task labels.
        label_intersection = self.labeled_subsets.keys() & self.tasks.keys()
        if label_intersection:
            raise ValueError(f"Labeled subsets can not use the same label as a task: {label_intersection}")

    def _read_imports(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the imports portion (or the deprecated inherits portion) of
        the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """

        def process_args(argument: Union[str, dict]) -> dict:
            if isinstance(argument, str):
                return {"location": argument}
            elif isinstance(argument, dict):
                if "exclude" in argument and isinstance(argument["exclude"], str):
                    argument["exclude"] = [argument["exclude"]]
                if "include" in argument and isinstance(argument["include"], str):
                    argument["include"] = [argument["include"]]
                if "instrument" in argument and argument["instrument"] == "None":
                    argument["instrument"] = None
                return argument
649 if not {"inherits", "imports"} - loaded_yaml.keys():
650 raise ValueError("Cannot define both inherits and imports sections, use imports")
651 tmp_import = loaded_yaml.pop("inherits", None)
652 if tmp_import is None:
653 tmp_import = loaded_yaml.pop("imports", None)
654 else:
655 warnings.warn(
656 "The 'inherits' key is deprecated, and will be "
657 "removed around June 2021. Please use the key "
658 "'imports' instead"
659 )
660 if tmp_import is None:
661 self.imports: List[ImportIR] = []
662 elif isinstance(tmp_import, list):
663 self.imports = [ImportIR(**process_args(args)) for args in tmp_import]
664 else:
665 self.imports = [ImportIR(**process_args(tmp_import))]

        # Integrate any imported pipelines.
        accumulate_tasks: Dict[str, TaskIR] = {}
        accumulate_labeled_subsets: Dict[str, LabeledSubset] = {}
        accumulated_parameters = ParametersIR({})
        for other_pipeline in self.imports:
            tmp_IR = other_pipeline.toPipelineIR()
            if self.instrument is None:
                self.instrument = tmp_IR.instrument
            elif self.instrument != tmp_IR.instrument and tmp_IR.instrument is not None:
                msg = (
                    "Only one instrument can be declared in a pipeline or its imports. "
                    f"Top level pipeline defines {self.instrument} but {other_pipeline.location} "
                    f"defines {tmp_IR.instrument}."
                )
                raise ValueError(msg)
            if duplicate_labels := accumulate_tasks.keys() & tmp_IR.tasks.keys():
                msg = (
                    "Task labels in the imported pipelines must be unique. "
                    f"These labels appear multiple times: {duplicate_labels}"
                )
                raise ValueError(msg)
            accumulate_tasks.update(tmp_IR.tasks)
            self.contracts.extend(tmp_IR.contracts)
            # Verify that tmp_IR has unique labels for named subsets among
            # existing labeled subsets, and with existing task labels.
            overlapping_subsets = accumulate_labeled_subsets.keys() & tmp_IR.labeled_subsets.keys()
            task_subset_overlap = (
                accumulate_labeled_subsets.keys() | tmp_IR.labeled_subsets.keys()
            ) & accumulate_tasks.keys()
            if overlapping_subsets or task_subset_overlap:
                raise ValueError(
                    "Labeled subset names must be unique amongst imports in both labels and "
                    f"named Subsets. Duplicate: {overlapping_subsets | task_subset_overlap}"
                )
            accumulate_labeled_subsets.update(tmp_IR.labeled_subsets)
            accumulated_parameters.update(tmp_IR.parameters)

        # Verify that any accumulated labeled subsets don't clash with a label
        # from this pipeline.
        if accumulate_labeled_subsets.keys() & self.tasks.keys():
            raise ValueError(
                "Labeled subset names must be unique amongst imports in both labels and named Subsets"
            )
        # Merge in the named subsets for self so this document can override
        # any that have been declared.
        accumulate_labeled_subsets.update(self.labeled_subsets)
        self.labeled_subsets = accumulate_labeled_subsets

        # Merge the dict of label:TaskIR objects, preserving any configs in
        # the imported pipeline if the labels point to the same class.
        for label, task in self.tasks.items():
            if label not in accumulate_tasks:
                accumulate_tasks[label] = task
            elif accumulate_tasks[label].klass == task.klass:
                if task.config is not None:
                    for config in task.config:
                        accumulate_tasks[label].add_or_update_config(config)
            else:
                accumulate_tasks[label] = task
        self.tasks: Dict[str, TaskIR] = accumulate_tasks
        accumulated_parameters.update(self.parameters)
        self.parameters = accumulated_parameters

    def _read_tasks(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the tasks portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        self.tasks = {}
        tmp_tasks = loaded_yaml.pop("tasks", None)
        if tmp_tasks is None:
            tmp_tasks = {}

        if "parameters" in tmp_tasks:
            raise ValueError("parameters is a reserved word and cannot be used as a task label")

        for label, definition in tmp_tasks.items():
            if isinstance(definition, str):
                definition = {"class": definition}
            config = definition.get("config", None)
            if config is None:
                task_config_ir = None
            else:
                if isinstance(config, dict):
                    config = [config]
                task_config_ir = []
                for c in config:
                    file = c.pop("file", None)
                    if file is None:
                        file = []
                    elif not isinstance(file, list):
                        file = [file]
                    task_config_ir.append(
                        ConfigIR(
                            python=c.pop("python", None), dataId=c.pop("dataId", None), file=file, rest=c
                        )
                    )
            self.tasks[label] = TaskIR(label, definition["class"], task_config_ir)

    def _remove_contracts(self, label: str) -> None:
        """Remove any contracts that contain the given label

        String comparison used in this way is not the most elegant and may
        have issues, but it is the only feasible way when users can specify
        contracts with generic strings.
        """
        new_contracts = []
        for contract in self.contracts:
            # Match the label when it is not preceded by an ASCII identifier
            # character (or is at the start of the string) and is followed by
            # a dot.
            if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract):
                continue
            new_contracts.append(contract)
        self.contracts = new_contracts

    def subset_from_labels(self, labelSpecifier: Set[str]) -> PipelineIR:
        """Subset a pipelineIR to contain only labels specified in
        labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `set` of `str`
            Set containing labels that describe how to subset a pipeline.

        Returns
        -------
        pipeline : `PipelineIR`
            A new pipelineIR object that is a subset of the old pipelineIR

        Raises
        ------
        ValueError
            Raised if there is an issue with specified labels

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune
        more than it should. Any labeled subsets defined that no longer have
        all members of the subset present in the pipeline will be removed from
        the resulting pipeline.
        """

        pipeline = copy.deepcopy(self)

        # Update the label specifier to expand any named subsets.
        toRemove = set()
        toAdd = set()
        for label in labelSpecifier:
            if label in pipeline.labeled_subsets:
                toRemove.add(label)
                toAdd.update(pipeline.labeled_subsets[label].subset)
        labelSpecifier.difference_update(toRemove)
        labelSpecifier.update(toAdd)
        # Verify all the labels are in the pipeline.
        if not labelSpecifier.issubset(pipeline.tasks.keys() | pipeline.labeled_subsets):
            difference = labelSpecifier.difference(pipeline.tasks.keys())
            raise ValueError(
                "Not all supplied labels (specified or named subsets) are in the pipeline "
                f"definition, extra labels: {difference}"
            )
        # Copy needed so as to not modify while iterating.
        pipeline_labels = set(pipeline.tasks.keys())
        # Remove the labels from the pipelineIR, and any contracts that
        # contain those labels (see docstring on _remove_contracts for why
        # this may cause issues).
        for label in pipeline_labels:
            if label not in labelSpecifier:
                pipeline.tasks.pop(label)
                pipeline._remove_contracts(label)

        # Create a copy of the object to iterate over.
        labeled_subsets = copy.copy(pipeline.labeled_subsets)
        # Remove any labeled subsets that no longer have a complete set of
        # labels.
        for label, labeled_subset in labeled_subsets.items():
            if labeled_subset.subset - pipeline.tasks.keys():
                pipeline.labeled_subsets.pop(label)

        return pipeline
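
    # A minimal sketch (hypothetical labels, given a PipelineIR instance
    # ``pipeline_ir`` with tasks taskA and taskB): keep only taskA; contracts
    # and subsets that mention dropped labels are pruned.
    #
    #     subset = pipeline_ir.subset_from_labels({"taskA"})
    #     list(subset.tasks)  # ['taskA']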

    @classmethod
    def from_string(cls, pipeline_string: str) -> PipelineIR:
        """Create a `PipelineIR` object from a string formatted like a
        pipeline document

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document
        """
        loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader)
        return cls(loaded_yaml)
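
    # A minimal sketch of building a PipelineIR from a string (hypothetical
    # task class):
    #
    #     pipeline_ir = PipelineIR.from_string(
    #         "description: An example pipeline\n"
    #         "tasks:\n"
    #         "  taskA: test.module.TaskA\n"
    #     )
    #     list(pipeline_ir.tasks)  # ['taskA']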

    @classmethod
    @deprecated(
        reason="This has been replaced with `from_uri` and will be removed after v23.",
        version="v21.0",
        category=FutureWarning,
    )
    def from_file(cls, filename: str) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input path.

        Parameters
        ----------
        filename : `str`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline

        Notes
        -----
        This method is deprecated; please use from_uri
        """
        return cls.from_uri(filename)

    @classmethod
    def from_uri(cls, uri: ResourcePathExpression) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input uri.

        Parameters
        ----------
        uri : convertible to `ResourcePath`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline
        """
        loaded_uri = ResourcePath(uri)
        with loaded_uri.open("r") as buffer:
            loaded_yaml = yaml.load(buffer, Loader=PipelineYamlLoader)
            return cls(loaded_yaml)

    @deprecated(
        reason="This has been replaced with `write_to_uri` and will be removed after v23.",
        version="v21.0",
        category=FutureWarning,
    )  # type: ignore
    def to_file(self, filename: str) -> None:
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified path.

        Parameters
        ----------
        filename : `str`
            Location of document to write a `PipelineIR` object.
        """
        self.write_to_uri(filename)

    def write_to_uri(
        self,
        uri: ResourcePathExpression,
    ) -> None:
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified uri.

        Parameters
        ----------
        uri : convertible to `ResourcePath`
            Location of document to write a `PipelineIR` object.
        """
        with ResourcePath(uri).open("w") as buffer:
            yaml.dump(self.to_primitives(), buffer, sort_keys=False, Dumper=MultilineStringDumper)

    def to_primitives(self) -> Dict[str, Any]:
        """Convert to a representation used in yaml serialization"""
        accumulate = {"description": self.description}
        if self.instrument is not None:
            accumulate["instrument"] = self.instrument
        if self.parameters:
            accumulate["parameters"] = self._sort_by_str(self.parameters.to_primitives())
        accumulate["tasks"] = {m: t.to_primitives() for m, t in self.tasks.items()}
        if len(self.contracts) > 0:
            # Sort contracts in lexicographical order by the contract string,
            # in the absence of any other ordering principle.
            contracts_list = [c.to_primitives() for c in self.contracts]
            contracts_list.sort(key=lambda x: x["contract"])
            accumulate["contracts"] = contracts_list
        if self.labeled_subsets:
            accumulate["subsets"] = self._sort_by_str(
                {k: v.to_primitives() for k, v in self.labeled_subsets.items()}
            )
        return accumulate

    def reorder_tasks(self, task_labels: List[str]) -> None:
        """Change the order tasks are stored internally. Useful for
        determining the order things will appear in the serialized (or
        printed) form.

        Parameters
        ----------
        task_labels : `list` of `str`
            A list containing all the labels in the pipeline, inserted in
            the order the tasks are to be stored.

        Raises
        ------
        KeyError
            Raised if labels are supplied that are not in the pipeline, or if
            not all labels in the pipeline were supplied in the task_labels
            input.
        """
        # Verify that all labels are in the input.
        _tmp_set = set(task_labels)
        if remainder := (self.tasks.keys() - _tmp_set):
            raise KeyError(f"Label(s) {remainder} are missing from the task label list")
        if extra := (_tmp_set - self.tasks.keys()):
            raise KeyError(f"Extra label(s) {extra} were in the input and are not in the pipeline")

        newTasks = {key: self.tasks[key] for key in task_labels}
        self.tasks = newTasks
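
    # A minimal sketch (hypothetical labels): with tasks taskA and taskB
    # defined, make taskB serialize first.
    #
    #     pipeline_ir.reorder_tasks(["taskB", "taskA"])
    #     list(pipeline_ir.tasks)  # ['taskB', 'taskA']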

    @staticmethod
    def _sort_by_str(arg: Mapping[str, Any]) -> Mapping[str, Any]:
        keys = sorted(arg.keys())
        return {key: arg[key] for key in keys}

    def __str__(self) -> str:
        """Instance formatting as how it would look in yaml representation"""
        return yaml.dump(self.to_primitives(), sort_keys=False, Dumper=MultilineStringDumper)

    def __repr__(self) -> str:
        """Instance formatting as how it would look in yaml representation"""
        return str(self)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PipelineIR):
            return False
        # Special case contracts because it is a list, but order is not
        # important.
        elif (
            all(
                getattr(self, attr) == getattr(other, attr)
                for attr in ("tasks", "instrument", "labeled_subsets", "parameters")
            )
            and len(self.contracts) == len(other.contracts)
            and all(c in self.contracts for c in other.contracts)
        ):
            return True
        else:
            return False