# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("ConfigIR", "ContractError", "ContractIR", "ImportIR", "PipelineIR", "TaskIR", "LabeledSubset")

import copy
import enum
import os
import re
import warnings
from collections import Counter
from collections.abc import Generator, Hashable, Iterable, MutableMapping
from dataclasses import dataclass, field
from typing import Any, Literal

import yaml
from lsst.resources import ResourcePath, ResourcePathExpression


class _Tags(enum.Enum):
    KeepInstrument = enum.auto()


class PipelineYamlLoader(yaml.SafeLoader):
    """This is a specialized version of yaml's SafeLoader. It checks for, and
    raises an exception on, multiple instances of the same key inside a
    pipeline file at a given scope.
    """

    def construct_mapping(self, node: yaml.MappingNode, deep: bool = False) -> dict[Hashable, Any]:
        # Do the call to super first so that it can do all the other forms of
        # checking on this node. Checking the uniqueness of keys first would
        # save the work that super does in the case of a failure, but super
        # might fail because the node was the incorrect node due to a parsing
        # error, and the resulting exception would be difficult to understand.
        mapping = super().construct_mapping(node, deep)
        # Check if there are any duplicate keys
        all_keys = Counter(key_node.value for key_node, _ in node.value)
        duplicates = {k for k, i in all_keys.items() if i != 1}
        if duplicates:
            raise KeyError(
                f"Pipeline files must not have duplicated keys, {duplicates} appeared multiple times"
            )
        return mapping
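
# A minimal usage sketch (illustrative only, not part of the library): loading
# a document with a duplicated key raises, where the plain SafeLoader would
# silently keep the last value. The YAML text here is hypothetical.
#
#     >>> import yaml
#     >>> yaml.load("a: 1\na: 2", Loader=PipelineYamlLoader)
#     Traceback (most recent call last):
#         ...
#     KeyError: "Pipeline files must not have duplicated keys, {'a'} appeared multiple times"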


class MultilineStringDumper(yaml.Dumper):
    """Custom YAML dumper that makes multi-line strings use the '|'
    continuation style instead of unreadable newlines and tons of quotes.

    Basic approach is taken from
    https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data,
    but is written as a Dumper subclass to make its effects non-global (vs
    `yaml.add_representer`).
    """

    def represent_scalar(self, tag: str, value: Any, style: str | None = None) -> yaml.ScalarNode:
        if style is None and tag == "tag:yaml.org,2002:str" and len(value.splitlines()) > 1:
            style = "|"
        return super().represent_scalar(tag, value, style)
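
# Illustrative sketch only: dumping a mapping that contains a multi-line
# string with this Dumper emits the readable '|' block style.
#
#     >>> print(yaml.dump({"msg": "line one\nline two"}, Dumper=MultilineStringDumper))
#     msg: |-
#       line one
#       line two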


class ContractError(Exception):
    """An exception that is raised when a pipeline contract is not satisfied."""

    pass


@dataclass
class ContractIR:
    """Intermediate representation of configuration contracts read from a
    pipeline yaml file."""

    contract: str
    """A string of python code representing one or more conditions on configs
    in a pipeline. This code-as-string should, once evaluated, be True
    if the configs are fine, and False otherwise.
    """
    msg: str | None = None
    """An optional message to be shown to the user if a contract fails.
    """

    def to_primitives(self) -> dict[str, str]:
        """Convert to a representation used in yaml serialization."""
        accumulate = {"contract": self.contract}
        if self.msg is not None:
            accumulate["msg"] = self.msg
        return accumulate

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ContractIR):
            return False
        elif self.contract == other.contract and self.msg == other.msg:
            return True
        else:
            return False
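
# Sketch of the round trip this class supports (illustrative only; the
# contract expression and labels are hypothetical):
#
#     >>> c = ContractIR(contract="taskA.field1 == taskB.field2", msg="fields must match")
#     >>> c.to_primitives()
#     {'contract': 'taskA.field1 == taskB.field2', 'msg': 'fields must match'}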


@dataclass
class LabeledSubset:
    """Intermediate representation of a named subset of task labels read from
    a pipeline yaml file.
    """

    label: str
    """The label used to identify the subset of task labels.
    """
    subset: set[str]
    """A set of task labels contained in this subset.
    """
    description: str | None
    """A description of what this subset of tasks is intended to do.
    """

    @staticmethod
    def from_primitives(label: str, value: list[str] | dict) -> LabeledSubset:
        """Generate `LabeledSubset` objects given a properly formatted object
        that has been created by a yaml loader.

        Parameters
        ----------
        label : `str`
            The label that will be used to identify this labeled subset.
        value : `list` of `str` or `dict`
            Object returned from loading a labeled subset section from a yaml
            document.

        Returns
        -------
        labeledSubset : `LabeledSubset`
            A `LabeledSubset` object built from the inputs.

        Raises
        ------
        ValueError
            Raised if the value input is not properly formatted for parsing.
        """
        if isinstance(value, MutableMapping):
            subset = value.pop("subset", None)
            if subset is None:
                raise ValueError(
                    "If a labeled subset is specified as a mapping, it must contain the key 'subset'"
                )
            description = value.pop("description", None)
        elif isinstance(value, Iterable):
            subset = value
            description = None
        else:
            raise ValueError(
                f"There was a problem parsing the labeled subset {label}; make sure the "
                "definition is either a valid yaml list, or a mapping with keys "
                "(subset, description) where subset points to a yaml list, and description is "
                "associated with a string"
            )
        return LabeledSubset(label, set(subset), description)

    def to_primitives(self) -> dict[str, list[str] | str]:
        """Convert to a representation used in yaml serialization."""
        accumulate: dict[str, list[str] | str] = {"subset": list(self.subset)}
        if self.description is not None:
            accumulate["description"] = self.description
        return accumulate
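
# Both accepted input shapes, sketched (illustrative only; set ordering in the
# repr is not deterministic):
#
#     >>> LabeledSubset.from_primitives("nightly", ["taskA"])
#     LabeledSubset(label='nightly', subset={'taskA'}, description=None)
#     >>> LabeledSubset.from_primitives(
#     ...     "nightly", {"subset": ["taskA"], "description": "nightly processing"}
#     ... )
#     LabeledSubset(label='nightly', subset={'taskA'}, description='nightly processing')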


@dataclass
class ParametersIR:
    """Intermediate representation of parameters that are global to a pipeline.

    These parameters are specified under a top level key named `parameters`
    and are declared as a yaml mapping. These entries can then be used inside
    task configuration blocks to specify configuration values. They may not be
    used in the special ``file`` or ``python`` blocks.

    Example:
    parameters:
      shared_value: 14
    tasks:
      taskA:
        class: modA
        config:
          field1: parameters.shared_value
      taskB:
        class: modB
        config:
          field2: parameters.shared_value
    """

    mapping: MutableMapping[str, str]
    """A mutable mapping of identifiers as keys, and shared configuration
    as values.
    """

    def update(self, other: ParametersIR | None) -> None:
        if other is not None:
            self.mapping.update(other.mapping)

    def to_primitives(self) -> MutableMapping[str, str]:
        """Convert to a representation used in yaml serialization."""
        return self.mapping

    def __contains__(self, value: str) -> bool:
        return value in self.mapping

    def __getitem__(self, item: str) -> Any:
        return self.mapping[item]

    def __bool__(self) -> bool:
        return bool(self.mapping)
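
# Minimal usage sketch (illustrative only):
#
#     >>> params = ParametersIR({"shared_value": "14"})
#     >>> "shared_value" in params
#     True
#     >>> params["shared_value"]
#     '14'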


@dataclass
class ConfigIR:
    """Intermediate representation of configurations read from a pipeline yaml
    file.
    """

    python: str | None = None
    """A string of python code that is used to modify a configuration. This
    can also be None if there are no modifications to do.
    """
    dataId: dict | None = None
    """A dataId that is used to constrain these config overrides to only
    quanta with matching dataIds. This field can be None if there is no
    constraint. This is currently an unimplemented feature, and is placed here
    for future use.
    """
    file: list[str] = field(default_factory=list)
    """A list of paths to files containing config overrides to be applied.
    This value may be an empty list if there are no overrides to apply.
    """
    rest: dict = field(default_factory=dict)
    """This is a dictionary of key value pairs, where the keys are strings
    corresponding to qualified fields on a config to override, and the values
    are strings representing the values to apply.
    """

    def to_primitives(self) -> dict[str, str | dict | list[str]]:
        """Convert to a representation used in yaml serialization."""
        accumulate = {}
        for name in ("python", "dataId", "file"):
            # if this attribute is truthy add it to the accumulation
            # dictionary
            if getattr(self, name):
                accumulate[name] = getattr(self, name)
        # Add the dictionary containing the rest of the config keys to the
        # accumulated dictionary
        accumulate.update(self.rest)
        return accumulate

    def formatted(self, parameters: ParametersIR) -> ConfigIR:
        """Return a new ConfigIR object that is formatted according to the
        specified parameters.

        Parameters
        ----------
        parameters : ParametersIR
            Object that contains variable mappings used in substitution.

        Returns
        -------
        config : ConfigIR
            A new ConfigIR object formatted with the input parameters.
        """
        new_config = copy.deepcopy(self)
        for key, value in new_config.rest.items():
            if not isinstance(value, str):
                continue
            match = re.match("parameters[.](.*)", value)
            if match and match.group(1) in parameters:
                new_config.rest[key] = parameters[match.group(1)]
            if match and match.group(1) not in parameters:
                warnings.warn(
                    f"config {key} contains value {match.group(0)} which is formatted like a "
                    "Pipeline parameter but was not found within the Pipeline; if this was not "
                    "intentional, check for a typo"
                )
        return new_config
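
    # Sketch of parameter substitution (illustrative only; the field name is
    # hypothetical):
    #
    #     >>> config = ConfigIR(rest={"field1": "parameters.shared_value"})
    #     >>> config.formatted(ParametersIR({"shared_value": "14"})).rest
    #     {'field1': '14'}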

    def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]:
        """Merge another instance of a `ConfigIR` into this instance if
        possible. This function returns a generator that yields only self
        if the configs were merged, or self and other_config if they could
        not be merged.

        Parameters
        ----------
        other_config : `ConfigIR`
            An instance of `ConfigIR` to merge into this instance.

        Returns
        -------
        Generator : `ConfigIR`
            A generator yielding self if the configs could be merged, or
            self and other_config if they could not.
        """
        # Verify that the config blocks can be merged
        if (
            self.dataId != other_config.dataId
            or self.python
            or other_config.python
            or self.file
            or other_config.file
        ):
            yield from (self, other_config)
            return

        # Find the keys shared by both configs, and verify that no shared key
        # has conflicting values
        shared_keys = self.rest.keys() & other_config.rest.keys()
        for key in shared_keys:
            if self.rest[key] != other_config.rest[key]:
                yield from (self, other_config)
                return
        self.rest.update(other_config.rest)

        # Combine the lists of override files to load
        self_file_set = set(self.file)
        other_file_set = set(other_config.file)
        self.file = list(self_file_set.union(other_file_set))

        yield self

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ConfigIR):
            return False
        elif all(
            getattr(self, attr) == getattr(other, attr) for attr in ("python", "dataId", "file", "rest")
        ):
            return True
        else:
            return False
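
# Merge behavior sketch (illustrative only): two compatible config blocks
# collapse into one; conflicting blocks would be yielded separately.
#
#     >>> a = ConfigIR(rest={"field1": "1"})
#     >>> b = ConfigIR(rest={"field2": "2"})
#     >>> [c.rest for c in a.maybe_merge(b)]
#     [{'field1': '1', 'field2': '2'}]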


@dataclass
class TaskIR:
    """Intermediate representation of tasks read from a pipeline yaml file."""

    label: str
    """An identifier used to refer to a task.
    """
    klass: str
    """A string containing a fully qualified python class to be run in a
    pipeline.
    """
    config: list[ConfigIR] | None = None
    """List of all config overrides associated with this task; may be
    `None` if there are no config overrides.
    """

    def to_primitives(self) -> dict[str, str | list[dict]]:
        """Convert to a representation used in yaml serialization."""
        accumulate: dict[str, str | list[dict]] = {"class": self.klass}
        if self.config:
            accumulate["config"] = [c.to_primitives() for c in self.config]
        return accumulate

    def add_or_update_config(self, other_config: ConfigIR) -> None:
        """Adds a `ConfigIR` to this task if one is not present. Merges configs
        if there is a `ConfigIR` present and the dataId keys of both configs
        match, otherwise adds a new entry to the config list. The exception to
        the above is that if either the last config or other_config has a
        python block, then other_config is always added, as python blocks can
        modify configs in ways that cannot be predicted.

        Parameters
        ----------
        other_config : `ConfigIR`
            A `ConfigIR` instance to add or merge into the config attribute of
            this task.
        """
        if not self.config:
            self.config = [other_config]
            return
        self.config.extend(self.config.pop().maybe_merge(other_config))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskIR):
            return False
        elif all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config")):
            return True
        else:
            return False
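
# Sketch of add_or_update_config (illustrative only; the task class name is
# hypothetical): compatible overrides are merged into the last config block
# rather than appended.
#
#     >>> task = TaskIR("taskA", "lsst.fake.module.TaskA")
#     >>> task.add_or_update_config(ConfigIR(rest={"field1": "1"}))
#     >>> task.add_or_update_config(ConfigIR(rest={"field2": "2"}))
#     >>> len(task.config)
#     1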


@dataclass
class ImportIR:
    """An intermediate representation of imported pipelines."""

    location: str
    """This is the location of the pipeline to inherit. The path should be
    specified as an absolute path. Environment variables may be used in the
    path and should be specified as a python string template, with the name of
    the environment variable inside braces.
    """
    include: list[str] | None = None
    """List of tasks that should be included when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    exclude: list[str] | None = None
    """List of tasks that should be excluded when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    importContracts: bool = True
    """Boolean attribute to dictate if contracts should be inherited with the
    pipeline or not.
    """
    instrument: Literal[_Tags.KeepInstrument] | str | None = _Tags.KeepInstrument
    """Instrument to assign to the Pipeline at import. The default value of
    ``_Tags.KeepInstrument`` indicates that whatever instrument the pipeline
    is declared with will not be modified. Setting this value to None will
    drop any declared instrument prior to import.
    """

    def toPipelineIR(self) -> "PipelineIR":
        """Load in the Pipeline specified by this object, and turn it into a
        PipelineIR instance.

        Returns
        -------
        pipeline : `PipelineIR`
            A pipeline generated from the imported pipeline file.
        """
        if self.include and self.exclude:
            raise ValueError(
                "An include list and an exclude list cannot both be specified"
                " when declaring a pipeline import."
            )
        tmp_pipeline = PipelineIR.from_uri(os.path.expandvars(self.location))
        if self.instrument is not _Tags.KeepInstrument:
            tmp_pipeline.instrument = self.instrument

        included_labels = set()
        for label in tmp_pipeline.tasks:
            if (
                (self.include and label in self.include)
                or (self.exclude and label not in self.exclude)
                or (self.include is None and self.exclude is None)
            ):
                included_labels.add(label)

        # Handle labeled subsets being specified in the include or exclude
        # list, adding or removing labels.
        if self.include is not None:
            subsets_in_include = tmp_pipeline.labeled_subsets.keys() & self.include
            for label in subsets_in_include:
                included_labels.update(tmp_pipeline.labeled_subsets[label].subset)

        elif self.exclude is not None:
            subsets_in_exclude = tmp_pipeline.labeled_subsets.keys() & self.exclude
            for label in subsets_in_exclude:
                included_labels.difference_update(tmp_pipeline.labeled_subsets[label].subset)

        tmp_pipeline = tmp_pipeline.subset_from_labels(included_labels)

        if not self.importContracts:
            tmp_pipeline.contracts = []

        return tmp_pipeline

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ImportIR):
            return False
        elif all(
            getattr(self, attr) == getattr(other, attr)
            for attr in ("location", "include", "exclude", "importContracts")
        ):
            return True
        else:
            return False
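
# How an import is typically declared in a pipeline yaml document (sketch
# only; the path and labels are hypothetical):
#
#     imports:
#       - location: $FAKE_PIPELINES_DIR/fakePipeline.yaml
#         include:
#           - taskA
#         importContracts: false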


class PipelineIR:
    """Intermediate representation of a pipeline definition.

    Parameters
    ----------
    loaded_yaml : `dict`
        A dictionary which matches the structure that would be produced by a
        yaml reader which parses a pipeline definition document.

    Raises
    ------
    ValueError
        Raised if:

        - a pipeline is declared without a description;
        - no tasks are declared in a pipeline, and no pipelines are to be
          inherited;
        - more than one instrument is specified;
        - more than one inherited pipeline shares a label.
    """

    def __init__(self, loaded_yaml: dict[str, Any]):
        # Check required fields are present
        if "description" not in loaded_yaml:
            raise ValueError("A pipeline must be declared with a description")
        if "tasks" not in loaded_yaml and len({"imports", "inherits"} - loaded_yaml.keys()) == 2:
            raise ValueError("A pipeline must be declared with one or more tasks")

        # These steps below must happen in this call order

        # Process pipeline description
        self.description = loaded_yaml.pop("description")

        # Process tasks
        self._read_tasks(loaded_yaml)

        # Process instrument keys
        inst = loaded_yaml.pop("instrument", None)
        if isinstance(inst, list):
            raise ValueError("Only one top level instrument can be defined in a pipeline")
        self.instrument: str | None = inst

        # Process any contracts
        self._read_contracts(loaded_yaml)

        # Process any defined parameters
        self._read_parameters(loaded_yaml)

        # Process any named label subsets
        self._read_labeled_subsets(loaded_yaml)

        # Process any inherited pipelines
        self._read_imports(loaded_yaml)

        # verify named subsets, must be done after inheriting
        self._verify_labeled_subsets()

    def _read_contracts(self, loaded_yaml: dict[str, Any]) -> None:
        """Process the contracts portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document.
        """
        loaded_contracts = loaded_yaml.pop("contracts", [])
        if isinstance(loaded_contracts, str):
            loaded_contracts = [loaded_contracts]
        self.contracts: list[ContractIR] = []
        for contract in loaded_contracts:
            if isinstance(contract, dict):
                self.contracts.append(ContractIR(**contract))
            if isinstance(contract, str):
                self.contracts.append(ContractIR(contract=contract))

    def _read_parameters(self, loaded_yaml: dict[str, Any]) -> None:
        """Process the parameters portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document.
        """
        loaded_parameters = loaded_yaml.pop("parameters", {})
        if not isinstance(loaded_parameters, dict):
            raise ValueError("The parameters section must be a yaml mapping")
        self.parameters = ParametersIR(loaded_parameters)

    def _read_labeled_subsets(self, loaded_yaml: dict[str, Any]) -> None:
        """Process the subsets portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `MutableMapping`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document.
        """
        loaded_subsets = loaded_yaml.pop("subsets", {})
        self.labeled_subsets: dict[str, LabeledSubset] = {}
        if not loaded_subsets and "subset" in loaded_yaml:
            raise ValueError("Top level key should be subsets and not subset, add an s")
        for key, value in loaded_subsets.items():
            self.labeled_subsets[key] = LabeledSubset.from_primitives(key, value)

    def _verify_labeled_subsets(self) -> None:
        """Verifies that all the labels in each named subset exist within the
        pipeline.
        """
        # Verify that all labels defined in a labeled subset are in the
        # Pipeline
        for labeled_subset in self.labeled_subsets.values():
            if not labeled_subset.subset.issubset(self.tasks.keys()):
                raise ValueError(
                    f"Labels {labeled_subset.subset - self.tasks.keys()} were not found in the "
                    "declared pipeline"
                )
        # Verify subset labels are not already task labels
        label_intersection = self.labeled_subsets.keys() & self.tasks.keys()
        if label_intersection:
            raise ValueError(f"Labeled subsets can not use the same label as a task: {label_intersection}")

    def _read_imports(self, loaded_yaml: dict[str, Any]) -> None:
        """Process the imports portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document.
        """

        def process_args(argument: str | dict) -> dict:
            if isinstance(argument, str):
                return {"location": argument}
            elif isinstance(argument, dict):
                if "exclude" in argument and isinstance(argument["exclude"], str):
                    argument["exclude"] = [argument["exclude"]]
                if "include" in argument and isinstance(argument["include"], str):
                    argument["include"] = [argument["include"]]
                if "instrument" in argument and argument["instrument"] == "None":
                    argument["instrument"] = None
                return argument

        if not {"inherits", "imports"} - loaded_yaml.keys():
            raise ValueError("Cannot define both inherits and imports sections, use imports")
        tmp_import = loaded_yaml.pop("inherits", None)
        if tmp_import is None:
            tmp_import = loaded_yaml.pop("imports", None)
        else:
            raise ValueError("The 'inherits' key is not supported. Please use the key 'imports' instead")
        if tmp_import is None:
            self.imports: list[ImportIR] = []
        elif isinstance(tmp_import, list):
            self.imports = [ImportIR(**process_args(args)) for args in tmp_import]
        else:
            self.imports = [ImportIR(**process_args(tmp_import))]

        self.merge_pipelines([fragment.toPipelineIR() for fragment in self.imports])

    def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None:
        """Merge one or more other `PipelineIR` objects into this object.

        Parameters
        ----------
        pipelines : `Iterable` of `PipelineIR` objects
            An `Iterable` that contains one or more `PipelineIR` objects to
            merge into this object.

        Raises
        ------
        ValueError
            Raised if there is a conflict in instrument specifications.
            Raised if a task label appears in more than one of the input
            `PipelineIR` objects which are to be merged.
            Raised if a labeled subset appears in more than one of the input
            `PipelineIR` objects which are to be merged, or clashes with a
            subset already existing in this object.
        """
        # integrate any imported pipelines
        accumulate_tasks: dict[str, TaskIR] = {}
        accumulate_labeled_subsets: dict[str, LabeledSubset] = {}
        accumulated_parameters = ParametersIR({})

        for tmp_IR in pipelines:
            if self.instrument is None:
                self.instrument = tmp_IR.instrument
            elif self.instrument != tmp_IR.instrument and tmp_IR.instrument is not None:
                msg = (
                    "Only one instrument can be declared in a pipeline or its imports. "
                    f"Top level pipeline defines {self.instrument} but pipeline to merge "
                    f"defines {tmp_IR.instrument}."
                )
                raise ValueError(msg)
            if duplicate_labels := accumulate_tasks.keys() & tmp_IR.tasks.keys():
                msg = (
                    "Task labels in the imported pipelines must be unique. "
                    f"These labels appear multiple times: {duplicate_labels}"
                )
                raise ValueError(msg)
            accumulate_tasks.update(tmp_IR.tasks)
            self.contracts.extend(tmp_IR.contracts)
            # verify that tmp_IR has unique labels for named subsets among
            # existing labeled subsets, and with existing task labels.
            overlapping_subsets = accumulate_labeled_subsets.keys() & tmp_IR.labeled_subsets.keys()
            task_subset_overlap = (
                accumulate_labeled_subsets.keys() | tmp_IR.labeled_subsets.keys()
            ) & accumulate_tasks.keys()
            if overlapping_subsets or task_subset_overlap:
                raise ValueError(
                    "Labeled subset names must be unique amongst imports in both labels and "
                    f"named subsets. Duplicate: {overlapping_subsets | task_subset_overlap}"
                )
            accumulate_labeled_subsets.update(tmp_IR.labeled_subsets)
            accumulated_parameters.update(tmp_IR.parameters)

        # verify that any accumulated labeled subsets don't clash with a label
        # from this pipeline
        if accumulate_labeled_subsets.keys() & self.tasks.keys():
            raise ValueError(
                "Labeled subset names must be unique amongst imports in both labels and named subsets"
            )
        # merge in the named subsets for self so this document can override any
        # that have been declared
        accumulate_labeled_subsets.update(self.labeled_subsets)
        self.labeled_subsets = accumulate_labeled_subsets

        # merge the dict of label:TaskIR objects, preserving any configs in the
        # imported pipeline if the labels point to the same class
        for label, task in self.tasks.items():
            if label not in accumulate_tasks:
                accumulate_tasks[label] = task
            elif accumulate_tasks[label].klass == task.klass:
                if task.config is not None:
                    for config in task.config:
                        accumulate_tasks[label].add_or_update_config(config)
            else:
                accumulate_tasks[label] = task
        self.tasks: dict[str, TaskIR] = accumulate_tasks
        accumulated_parameters.update(self.parameters)
        self.parameters = accumulated_parameters

    def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None:
        """Process the tasks portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document.
        """
        self.tasks = {}
        tmp_tasks = loaded_yaml.pop("tasks", None)
        if tmp_tasks is None:
            tmp_tasks = {}

        if "parameters" in tmp_tasks:
            raise ValueError("parameters is a reserved word and cannot be used as a task label")

        for label, definition in tmp_tasks.items():
            if isinstance(definition, str):
                definition = {"class": definition}
            config = definition.get("config", None)
            if config is None:
                task_config_ir = None
            else:
                if isinstance(config, dict):
                    config = [config]
                task_config_ir = []
                for c in config:
                    file = c.pop("file", None)
                    if file is None:
                        file = []
                    elif not isinstance(file, list):
                        file = [file]
                    task_config_ir.append(
                        ConfigIR(
                            python=c.pop("python", None), dataId=c.pop("dataId", None), file=file, rest=c
                        )
                    )
            self.tasks[label] = TaskIR(label, definition["class"], task_config_ir)

    def _remove_contracts(self, label: str) -> None:
        """Remove any contracts that contain the given label.

        String comparison used in this way is not the most elegant and may
        have issues, but it is the only feasible way when users can specify
        contracts with generic strings.
        """
        new_contracts = []
        for contract in self.contracts:
            # match a label that is not preceded by an ASCII identifier
            # character, or is at the start of a line, and is followed by a dot
            if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract):
                continue
            new_contracts.append(contract)
        self.contracts = new_contracts

    def subset_from_labels(self, labelSpecifier: set[str]) -> PipelineIR:
        """Subset a pipelineIR to contain only labels specified in
        labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `set` of `str`
            Set containing labels that describe how to subset a pipeline.

        Returns
        -------
        pipeline : `PipelineIR`
            A new pipelineIR object that is a subset of the old pipelineIR.

        Raises
        ------
        ValueError
            Raised if there is an issue with specified labels.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should. Any labeled subsets defined that no longer have all
        members of the subset present in the pipeline will be removed from the
        resulting pipeline.
        """
        pipeline = copy.deepcopy(self)

        # update the label specifier to expand any named subsets
        toRemove = set()
        toAdd = set()
        for label in labelSpecifier:
            if label in pipeline.labeled_subsets:
                toRemove.add(label)
                toAdd.update(pipeline.labeled_subsets[label].subset)
        labelSpecifier.difference_update(toRemove)
        labelSpecifier.update(toAdd)
        # verify all the labels are in the pipeline
        if not labelSpecifier.issubset(pipeline.tasks.keys() | pipeline.labeled_subsets):
            difference = labelSpecifier.difference(pipeline.tasks.keys())
            raise ValueError(
                "Not all supplied labels (specified or named subsets) are in the pipeline "
                f"definition, extra labels: {difference}"
            )
        # copy needed so as to not modify while iterating
        pipeline_labels = set(pipeline.tasks.keys())
        # Remove the labels from the pipelineIR, and any contracts that contain
        # those labels (see docstring on _remove_contracts for why this may
        # cause issues)
        for label in pipeline_labels:
            if label not in labelSpecifier:
                pipeline.tasks.pop(label)
                pipeline._remove_contracts(label)

        # create a copy of the object to iterate over
        labeled_subsets = copy.copy(pipeline.labeled_subsets)
        # remove any labeled subsets that no longer have a complete set
        for label, labeled_subset in labeled_subsets.items():
            if labeled_subset.subset - pipeline.tasks.keys():
                pipeline.labeled_subsets.pop(label)

        return pipeline

    @classmethod
    def from_string(cls, pipeline_string: str) -> PipelineIR:
        """Create a `PipelineIR` object from a string formatted like a
        pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string formatted like a pipeline document.
        """
        loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader)
        return cls(loaded_yaml)
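
    # End-to-end sketch (illustrative only; the task class is hypothetical):
    #
    #     >>> pipeline = PipelineIR.from_string(
    #     ...     "description: demo\ntasks:\n  taskA: lsst.fake.module.TaskA"
    #     ... )
    #     >>> list(pipeline.tasks)
    #     ['taskA']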

    @classmethod
    def from_uri(cls, uri: ResourcePathExpression) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input uri.

        Parameters
        ----------
        uri : convertible to `ResourcePath`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline.
        """
        loaded_uri = ResourcePath(uri)
        with loaded_uri.open("r") as buffer:
            loaded_yaml = yaml.load(buffer, Loader=PipelineYamlLoader)
            return cls(loaded_yaml)

    def write_to_uri(self, uri: ResourcePathExpression) -> None:
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified uri.

        Parameters
        ----------
        uri : convertible to `ResourcePath`
            Location of document to write a `PipelineIR` object.
        """
        with ResourcePath(uri).open("w") as buffer:
            yaml.dump(self.to_primitives(), buffer, sort_keys=False, Dumper=MultilineStringDumper)

    def to_primitives(self) -> dict[str, Any]:
        """Convert to a representation used in yaml serialization.

        Returns
        -------
        primitives : `dict`
            Dictionary that maps directly to the serialized YAML form.
        """
        accumulate = {"description": self.description}
        if self.instrument is not None:
            accumulate["instrument"] = self.instrument
        if self.parameters:
            accumulate["parameters"] = self.parameters.to_primitives()
        accumulate["tasks"] = {m: t.to_primitives() for m, t in self.tasks.items()}
        if len(self.contracts) > 0:
            # sort contracts in lexicographical order by the contract string,
            # in the absence of any other ordering principle
            contracts_list = [c.to_primitives() for c in self.contracts]
            contracts_list.sort(key=lambda x: x["contract"])
            accumulate["contracts"] = contracts_list
        if self.labeled_subsets:
            accumulate["subsets"] = {k: v.to_primitives() for k, v in self.labeled_subsets.items()}
        return accumulate

    def __str__(self) -> str:
        """Format the instance as it would look in yaml representation."""
        return yaml.dump(self.to_primitives(), sort_keys=False, Dumper=MultilineStringDumper)

    def __repr__(self) -> str:
        """Format the instance as it would look in yaml representation."""
        return str(self)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PipelineIR):
            return False
        # special case contracts because it is a list, but order is not
        # important
        elif (
            all(
                getattr(self, attr) == getattr(other, attr)
                for attr in ("tasks", "instrument", "labeled_subsets", "parameters")
            )
            and len(self.contracts) == len(other.contracts)
            and all(c in self.contracts for c in other.contracts)
        ):
            return True
        else:
            return False