Coverage for python/lsst/pipe/base/pipelineIR.py: 18%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("ConfigIR", "ContractError", "ContractIR", "ImportIR", "PipelineIR", "TaskIR", "LabeledSubset")

import copy
import os
import re
import warnings
from collections import Counter
from collections.abc import Iterable as abcIterable
from dataclasses import dataclass, field
from typing import Any, Dict, Generator, List, Mapping, MutableMapping, Optional, Set, Type, Union

import yaml
from deprecated.sphinx import deprecated
from lsst.resources import ResourcePath, ResourcePathExpression


# Sentinel used by ImportIR.instrument to mean "keep whatever instrument the
# imported pipeline declares".
class KeepInstrument:
    pass


class PipelineYamlLoader(yaml.SafeLoader):
    """This is a specialized version of yaml's SafeLoader. It raises an
    exception if it finds multiple instances of the same key at a given
    scope in a pipeline file.
    """

    def construct_mapping(self, node, deep=False):
        # Do the call to super first so that it can do all the other forms of
        # checking on this node. Checking the uniqueness of keys first would
        # save the work super does in the case of a failure, but super might
        # fail because the node was the incorrect node due to a parsing error,
        # and the resulting exception would be difficult to understand.
        mapping = super().construct_mapping(node, deep)
        # Check if there are any duplicate keys
        all_keys = Counter(key_node.value for key_node, _ in node.value)
        duplicates = {k for k, i in all_keys.items() if i != 1}
        if duplicates:
            raise KeyError(
                f"Pipeline files must not have duplicated keys, {duplicates} appeared multiple times"
            )
        return mapping
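
# Example (an illustrative sketch, not part of the original module): loading
# a document with a duplicated key through this loader raises KeyError,
# whereas the stock SafeLoader would silently keep the last value.
#
#     yaml.load("a: 1\nb: 2", Loader=PipelineYamlLoader)   # -> {'a': 1, 'b': 2}
#     yaml.load("a: 1\na: 2", Loader=PipelineYamlLoader)   # raises KeyError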


class ContractError(Exception):
    """An exception that is raised when a pipeline contract is not satisfied"""

    pass


@dataclass
class ContractIR:
    """Intermediate representation of configuration contracts read from a
    pipeline yaml file."""

    contract: str
    """A string of python code representing one or more conditions on configs
    in a pipeline. Once evaluated, this code-as-string should be `True` if
    the configs are fine, and `False` otherwise.
    """
    msg: Union[str, None] = None
    """An optional message to be shown to the user if a contract fails
    """

    def to_primitives(self) -> Dict[str, str]:
        """Convert to a representation used in yaml serialization"""
        accumulate = {"contract": self.contract}
        if self.msg is not None:
            accumulate["msg"] = self.msg
        return accumulate

    def __eq__(self, other: object):
        if not isinstance(other, ContractIR):
            return False
        return self.contract == other.contract and self.msg == other.msg
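
# Example (an illustrative sketch, not part of the original module): a
# contract as it would appear in a pipeline document, and its round trip
# through ContractIR; the task labels and fields are hypothetical.
#
#     contracts:
#       - contract: taskA.field1 == taskB.field2
#         msg: taskA and taskB must agree
#
#     ContractIR(contract="taskA.field1 == taskB.field2",
#                msg="taskA and taskB must agree").to_primitives()
#     # -> {'contract': 'taskA.field1 == taskB.field2',
#     #     'msg': 'taskA and taskB must agree'}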


@dataclass
class LabeledSubset:
    """Intermediate representation of named subset of task labels read from
    a pipeline yaml file.
    """

    label: str
    """The label used to identify the subset of task labels.
    """
    subset: Set[str]
    """A set of task labels contained in this subset.
    """
    description: Optional[str]
    """A description of what this subset of tasks is intended to do
    """

    @staticmethod
    def from_primitives(label: str, value: Union[List[str], dict]) -> LabeledSubset:
        """Generate `LabeledSubset` objects given a properly formatted object
        that has been created by a yaml loader.

        Parameters
        ----------
        label : `str`
            The label that will be used to identify this labeled subset.
        value : `list` of `str` or `dict`
            Object returned from loading a labeled subset section from a yaml
            document.

        Returns
        -------
        labeledSubset : `LabeledSubset`
            A `LabeledSubset` object built from the inputs.

        Raises
        ------
        ValueError
            Raised if the value input is not properly formatted for parsing
        """
        if isinstance(value, MutableMapping):
            subset = value.pop("subset", None)
            if subset is None:
                raise ValueError(
                    "If a labeled subset is specified as a mapping, it must contain the key 'subset'"
                )
            description = value.pop("description", None)
        elif isinstance(value, abcIterable):
            subset = value
            description = None
        else:
            raise ValueError(
                f"There was a problem parsing the labeled subset {label}; make sure the "
                "definition is either a valid yaml list, or a mapping with keys "
                "(subset, description) where subset points to a yaml list, and description is "
                "associated with a string"
            )
        return LabeledSubset(label, set(subset), description)

    def to_primitives(self) -> Dict[str, Union[List[str], str]]:
        """Convert to a representation used in yaml serialization"""
        accumulate: Dict[str, Union[List[str], str]] = {"subset": list(self.subset)}
        if self.description is not None:
            accumulate["description"] = self.description
        return accumulate
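
# Example (an illustrative sketch, not part of the original module): both of
# the accepted yaml forms; the mapping form also carries a description. The
# task labels used here are hypothetical.
#
#     LabeledSubset.from_primitives("step1", ["isr", "calibrate"])
#     LabeledSubset.from_primitives(
#         "step1", {"subset": ["isr", "calibrate"], "description": "early steps"}
#     )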


@dataclass
class ParametersIR:
    """Intermediate representation of parameters that are global to a pipeline

    These parameters are specified under a top level key named `parameters`
    and are declared as a yaml mapping. These entries can then be used inside
    task configuration blocks to specify configuration values. They may not be
    used in the special ``file`` or ``python`` blocks.

    Example:
    parameters:
      shared_value: 14
    tasks:
      taskA:
        class: modA
        config:
          field1: parameters.shared_value
      taskB:
        class: modB
        config:
          field2: parameters.shared_value
    """

    mapping: MutableMapping[str, str]
    """A mutable mapping of identifiers as keys, and shared configuration
    as values.
    """

    def update(self, other: Optional[ParametersIR]):
        if other is not None:
            self.mapping.update(other.mapping)

    def to_primitives(self) -> MutableMapping[str, str]:
        """Convert to a representation used in yaml serialization"""
        return self.mapping

    def __contains__(self, value: str) -> bool:
        return value in self.mapping

    def __getitem__(self, item: str) -> Any:
        return self.mapping[item]

    def __bool__(self) -> bool:
        return bool(self.mapping)
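
# Example (an illustrative sketch, not part of the original module):
#
#     params = ParametersIR({"shared_value": "14"})
#     "shared_value" in params        # -> True
#     params["shared_value"]          # -> "14"
#     bool(ParametersIR({}))          # -> False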


@dataclass
class ConfigIR:
    """Intermediate representation of configurations read from a pipeline yaml
    file.
    """

    python: Union[str, None] = None
    """A string of python code that is used to modify a configuration. This
    can also be None if there are no modifications to do.
    """
    dataId: Union[dict, None] = None
    """A dataId that is used to constrain these config overrides to only
    quanta with matching dataIds. This field can be None if there is no
    constraint. This is currently an unimplemented feature, and is placed here
    for future use.
    """
    file: List[str] = field(default_factory=list)
    """A list of paths which point to files containing config overrides to be
    applied. This value may be an empty list if there are no overrides to
    apply.
    """
    rest: dict = field(default_factory=dict)
    """This is a dictionary of key value pairs, where the keys are strings
    corresponding to qualified fields on a config to override, and the values
    are strings representing the values to apply.
    """

    def to_primitives(self) -> Dict[str, Union[str, dict, List[str]]]:
        """Convert to a representation used in yaml serialization"""
        accumulate = {}
        for name in ("python", "dataId", "file"):
            # if this attribute is truthy add it to the accumulation
            # dictionary
            if getattr(self, name):
                accumulate[name] = getattr(self, name)
        # Add the dictionary containing the rest of the config keys to the
        # accumulated dictionary
        accumulate.update(self.rest)
        return accumulate

    def formatted(self, parameters: ParametersIR) -> ConfigIR:
        """Return a new ConfigIR object that is formatted according to the
        specified parameters

        Parameters
        ----------
        parameters : `ParametersIR`
            Object that contains variable mappings used in substitution.

        Returns
        -------
        config : `ConfigIR`
            A new ConfigIR object formatted with the input parameters
        """
        new_config = copy.deepcopy(self)
        for key, value in new_config.rest.items():
            if not isinstance(value, str):
                continue
            match = re.match("parameters[.](.*)", value)
            if match and match.group(1) in parameters:
                new_config.rest[key] = parameters[match.group(1)]
            elif match:
                warnings.warn(
                    f"config {key} contains value {match.group(0)} which is formatted like a "
                    "Pipeline parameter but was not found within the Pipeline; if this was not "
                    "intentional, check for a typo"
                )
        return new_config
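
    # Example (an illustrative sketch, not part of the original module):
    # values written as "parameters.<name>" are substituted from the
    # pipeline-level parameters mapping; the names here are hypothetical.
    #
    #     params = ParametersIR({"shared_value": "14"})
    #     config = ConfigIR(rest={"field1": "parameters.shared_value"})
    #     config.formatted(params).rest   # -> {"field1": "14"}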

    def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]:
        """Merge another instance of a `ConfigIR` into this instance if
        possible. This function yields only self if the configs were merged,
        or self and other_config if they could not be merged.

        Parameters
        ----------
        other_config : `ConfigIR`
            An instance of `ConfigIR` to merge into this instance.

        Returns
        -------
        Generator : `ConfigIR`
            A generator yielding self if the configs were merged, or self and
            other_config if they were not.
        """
        # Verify that the config blocks can be merged
        if (
            self.dataId != other_config.dataId
            or self.python
            or other_config.python
            or self.file
            or other_config.file
        ):
            yield from (self, other_config)
            return

        # Verify that keys shared between the two configs do not have
        # different values
        key_union = self.rest.keys() & other_config.rest.keys()
        for key in key_union:
            if self.rest[key] != other_config.rest[key]:
                yield from (self, other_config)
                return
        self.rest.update(other_config.rest)

        # Combine the lists of override files to load
        self_file_set = set(self.file)
        other_file_set = set(other_config.file)
        self.file = list(self_file_set.union(other_file_set))

        yield self
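
    # Example (an illustrative sketch, not part of the original module):
    #
    #     a = ConfigIR(rest={"x": "1"})
    #     b = ConfigIR(rest={"y": "2"})
    #     list(a.maybe_merge(b))   # -> [a], with a.rest == {"x": "1", "y": "2"}
    #
    #     c = ConfigIR(python="config.x = 1")
    #     list(a.maybe_merge(c))   # -> [a, c]; python blocks are never merged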

    def __eq__(self, other: object):
        if not isinstance(other, ConfigIR):
            return False
        return all(
            getattr(self, attr) == getattr(other, attr) for attr in ("python", "dataId", "file", "rest")
        )


@dataclass
class TaskIR:
    """Intermediate representation of tasks read from a pipeline yaml file."""

    label: str
    """An identifier used to refer to a task.
    """
    klass: str
    """A string containing a fully qualified python class to be run in a
    pipeline.
    """
    config: Union[List[ConfigIR], None] = None
    """List of all config overrides associated with this task; may be
    `None` if there are no config overrides.
    """

    def to_primitives(self) -> Dict[str, Union[str, List[dict]]]:
        """Convert to a representation used in yaml serialization"""
        accumulate: Dict[str, Union[str, List[dict]]] = {"class": self.klass}
        if self.config:
            accumulate["config"] = [c.to_primitives() for c in self.config]
        return accumulate

    def add_or_update_config(self, other_config: ConfigIR):
        """Add a `ConfigIR` to this task if one is not present. Merge configs
        if there is a `ConfigIR` present and the dataId keys of both configs
        match, otherwise add a new entry to the config list. The exception to
        the above is that if either the last config or other_config has a
        python block, then other_config is always added, as python blocks can
        modify configs in ways that cannot be predicted.

        Parameters
        ----------
        other_config : `ConfigIR`
            A `ConfigIR` instance to add or merge into the config attribute of
            this task.
        """
        if not self.config:
            self.config = [other_config]
            return
        self.config.extend(self.config.pop().maybe_merge(other_config))

    def __eq__(self, other: object):
        if not isinstance(other, TaskIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config"))


@dataclass
class ImportIR:
    """An intermediate representation of imported pipelines"""

    location: str
    """This is the location of the pipeline to inherit. The path should be
    specified as an absolute path. Environment variables may be used in the
    path and should be specified as a python string template, with the name of
    the environment variable inside braces.
    """
    include: Union[List[str], None] = None
    """List of tasks that should be included when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    exclude: Union[List[str], None] = None
    """List of tasks that should be excluded when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    importContracts: bool = True
    """Boolean attribute to dictate if contracts should be inherited with the
    pipeline or not.
    """
    instrument: Union[Type[KeepInstrument], str, None] = KeepInstrument
    """Instrument to assign to the Pipeline at import. The default value of
    `KeepInstrument` indicates that whatever instrument the pipeline is
    declared with will not be modified. Setting this value to None will drop
    any declared instrument prior to import.
    """

    def toPipelineIR(self) -> "PipelineIR":
        """Load in the Pipeline specified by this object, and turn it into a
        PipelineIR instance.

        Returns
        -------
        pipeline : `PipelineIR`
            A pipeline generated from the imported pipeline file
        """
        if self.include and self.exclude:
            raise ValueError(
                "Both an include and an exclude list cannot be specified when declaring a pipeline import"
            )
        tmp_pipeline = PipelineIR.from_uri(os.path.expandvars(self.location))
        if self.instrument is not KeepInstrument:
            tmp_pipeline.instrument = self.instrument

        included_labels = set()
        for label in tmp_pipeline.tasks:
            if (
                (self.include and label in self.include)
                or (self.exclude and label not in self.exclude)
                or (self.include is None and self.exclude is None)
            ):
                included_labels.add(label)

        # Handle labeled subsets being specified in the include or exclude
        # list, adding or removing labels.
        if self.include is not None:
            subsets_in_include = tmp_pipeline.labeled_subsets.keys() & self.include
            for label in subsets_in_include:
                included_labels.update(tmp_pipeline.labeled_subsets[label].subset)

        elif self.exclude is not None:
            subsets_in_exclude = tmp_pipeline.labeled_subsets.keys() & self.exclude
            for label in subsets_in_exclude:
                included_labels.difference_update(tmp_pipeline.labeled_subsets[label].subset)

        tmp_pipeline = tmp_pipeline.subset_from_labels(included_labels)

        if not self.importContracts:
            tmp_pipeline.contracts = []

        return tmp_pipeline

    def __eq__(self, other: object):
        if not isinstance(other, ImportIR):
            return False
        return all(
            getattr(self, attr) == getattr(other, attr)
            for attr in ("location", "include", "exclude", "importContracts")
        )
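
# Example (an illustrative sketch, not part of the original module): an
# imports section of a pipeline document mapping onto ImportIR fields; the
# environment variable, path, and task label are hypothetical.
#
#     imports:
#       - location: ${SOME_PKG_DIR}/pipelines/base.yaml
#         exclude: [taskWeDoNotWant]
#         importContracts: false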


class PipelineIR:
    """Intermediate representation of a pipeline definition

    Parameters
    ----------
    loaded_yaml : `dict`
        A dictionary which matches the structure that would be produced by a
        yaml reader which parses a pipeline definition document

    Raises
    ------
    ValueError
        - If a pipeline is declared without a description
        - If no tasks are declared in a pipeline, and no pipelines are to be
          inherited
        - If more than one instrument is specified
        - If more than one inherited pipeline share a label
    """

    def __init__(self, loaded_yaml):
        # Check required fields are present
        if "description" not in loaded_yaml:
            raise ValueError("A pipeline must be declared with a description")
        if "tasks" not in loaded_yaml and len({"imports", "inherits"} - loaded_yaml.keys()) == 2:
            raise ValueError("A pipeline must be declared with one or more tasks")

        # These steps below must happen in this call order

        # Process pipeline description
        self.description = loaded_yaml.pop("description")

        # Process tasks
        self._read_tasks(loaded_yaml)

        # Process instrument keys
        inst = loaded_yaml.pop("instrument", None)
        if isinstance(inst, list):
            raise ValueError("Only one top level instrument can be defined in a pipeline")
        self.instrument = inst

        # Process any contracts
        self._read_contracts(loaded_yaml)

        # Process any defined parameters
        self._read_parameters(loaded_yaml)

        # Process any named label subsets
        self._read_labeled_subsets(loaded_yaml)

        # Process any inherited pipelines
        self._read_imports(loaded_yaml)

        # verify named subsets, must be done after inheriting
        self._verify_labeled_subsets()

    def _read_contracts(self, loaded_yaml):
        """Process the contracts portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        loaded_contracts = loaded_yaml.pop("contracts", [])
        if isinstance(loaded_contracts, str):
            loaded_contracts = [loaded_contracts]
        self.contracts = []
        for contract in loaded_contracts:
            if isinstance(contract, dict):
                self.contracts.append(ContractIR(**contract))
            if isinstance(contract, str):
                self.contracts.append(ContractIR(contract=contract))

    def _read_parameters(self, loaded_yaml):
        """Process the parameters portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        loaded_parameters = loaded_yaml.pop("parameters", {})
        if not isinstance(loaded_parameters, dict):
            raise ValueError("The parameters section must be a yaml mapping")
        self.parameters = ParametersIR(loaded_parameters)

    def _read_labeled_subsets(self, loaded_yaml: dict):
        """Process the subsets portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `MutableMapping`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document
        """
        loaded_subsets = loaded_yaml.pop("subsets", {})
        self.labeled_subsets = {}
        if not loaded_subsets and "subset" in loaded_yaml:
            raise ValueError("Top level key should be subsets and not subset, add an s")
        for key, value in loaded_subsets.items():
            self.labeled_subsets[key] = LabeledSubset.from_primitives(key, value)

    def _verify_labeled_subsets(self):
        """Verifies that all the labels in each named subset exist within the
        pipeline.
        """
        # Verify that all labels defined in a labeled subset are in the
        # Pipeline
        for labeled_subset in self.labeled_subsets.values():
            if not labeled_subset.subset.issubset(self.tasks.keys()):
                raise ValueError(
                    f"Labels {labeled_subset.subset - self.tasks.keys()} were not found in the "
                    "declared pipeline"
                )
        # Verify subset labels are not already task labels
        label_intersection = self.labeled_subsets.keys() & self.tasks.keys()
        if label_intersection:
            raise ValueError(f"Labeled subsets can not use the same label as a task: {label_intersection}")

    def _read_imports(self, loaded_yaml):
        """Process the imports (or deprecated inherits) portion of the loaded
        yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """

        def process_args(argument: Union[str, dict]) -> dict:
            if isinstance(argument, str):
                return {"location": argument}
            elif isinstance(argument, dict):
                if "exclude" in argument and isinstance(argument["exclude"], str):
                    argument["exclude"] = [argument["exclude"]]
                if "include" in argument and isinstance(argument["include"], str):
                    argument["include"] = [argument["include"]]
                if "instrument" in argument and argument["instrument"] == "None":
                    argument["instrument"] = None
                return argument

        if not {"inherits", "imports"} - loaded_yaml.keys():
            raise ValueError("Cannot define both inherits and imports sections, use imports")
        tmp_import = loaded_yaml.pop("inherits", None)
        if tmp_import is None:
            tmp_import = loaded_yaml.pop("imports", None)
        else:
            warnings.warn(
                "The 'inherits' key is deprecated, and will be "
                "removed around June 2021. Please use the key "
                "'imports' instead"
            )
        if tmp_import is None:
            self.imports = []
        elif isinstance(tmp_import, list):
            self.imports = [ImportIR(**process_args(args)) for args in tmp_import]
        else:
            self.imports = [ImportIR(**process_args(tmp_import))]

        # integrate any imported pipelines
        accumulate_tasks = {}
        accumulate_labeled_subsets = {}
        accumulated_parameters = ParametersIR({})
        for other_pipeline in self.imports:
            tmp_IR = other_pipeline.toPipelineIR()
            if self.instrument is None:
                self.instrument = tmp_IR.instrument
            elif self.instrument != tmp_IR.instrument and tmp_IR.instrument is not None:
                msg = (
                    "Only one instrument can be declared in a pipeline or its imports. "
                    f"Top level pipeline defines {self.instrument} but {other_pipeline.location} "
                    f"defines {tmp_IR.instrument}."
                )
                raise ValueError(msg)
            if duplicate_labels := accumulate_tasks.keys() & tmp_IR.tasks.keys():
                msg = (
                    "Task labels in the imported pipelines must be unique. "
                    f"These labels appear multiple times: {duplicate_labels}"
                )
                raise ValueError(msg)
            accumulate_tasks.update(tmp_IR.tasks)
            self.contracts.extend(tmp_IR.contracts)
            # Verify that tmp_IR's named subset labels are unique among both
            # the existing labeled subsets and the existing task labels.
            overlapping_subsets = accumulate_labeled_subsets.keys() & tmp_IR.labeled_subsets.keys()
            task_subset_overlap = (
                accumulate_labeled_subsets.keys() | tmp_IR.labeled_subsets.keys()
            ) & accumulate_tasks.keys()
            if overlapping_subsets or task_subset_overlap:
                raise ValueError(
                    "Labeled subset names must be unique amongst imports in both labels and "
                    f"named subsets. Duplicate: {overlapping_subsets | task_subset_overlap}"
                )
            accumulate_labeled_subsets.update(tmp_IR.labeled_subsets)
            accumulated_parameters.update(tmp_IR.parameters)

        # verify that any accumulated labeled subsets don't clash with a label
        # from this pipeline
        if accumulate_labeled_subsets.keys() & self.tasks.keys():
            raise ValueError(
                "Labeled subset names must be unique amongst imports in both labels and named subsets"
            )
        # merge in the named subsets for self so this document can override any
        # that have been declared
        accumulate_labeled_subsets.update(self.labeled_subsets)
        self.labeled_subsets = accumulate_labeled_subsets

        # merge the dict of label:TaskIR objects, preserving any configs in the
        # imported pipeline if the labels point to the same class
        for label, task in self.tasks.items():
            if label not in accumulate_tasks:
                accumulate_tasks[label] = task
            elif accumulate_tasks[label].klass == task.klass:
                if task.config is not None:
                    for config in task.config:
                        accumulate_tasks[label].add_or_update_config(config)
            else:
                accumulate_tasks[label] = task
        self.tasks = accumulate_tasks
        accumulated_parameters.update(self.parameters)
        self.parameters = accumulated_parameters

    def _read_tasks(self, loaded_yaml):
        """Process the tasks portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        self.tasks = {}
        tmp_tasks = loaded_yaml.pop("tasks", None)
        if tmp_tasks is None:
            tmp_tasks = {}

        if "parameters" in tmp_tasks:
            raise ValueError("parameters is a reserved word and cannot be used as a task label")

        for label, definition in tmp_tasks.items():
            if isinstance(definition, str):
                definition = {"class": definition}
            config = definition.get("config", None)
            if config is None:
                task_config_ir = None
            else:
                if isinstance(config, dict):
                    config = [config]
                task_config_ir = []
                for c in config:
                    file = c.pop("file", None)
                    if file is None:
                        file = []
                    elif not isinstance(file, list):
                        file = [file]
                    task_config_ir.append(
                        ConfigIR(
                            python=c.pop("python", None), dataId=c.pop("dataId", None), file=file, rest=c
                        )
                    )
            self.tasks[label] = TaskIR(label, definition["class"], task_config_ir)

    def _remove_contracts(self, label: str):
        """Remove any contracts that contain the given label

        String comparison used in this way is not the most elegant and may
        have issues, but it is the only feasible way when users can specify
        contracts with generic strings.
        """
        new_contracts = []
        for contract in self.contracts:
            # match a label that is not preceded by an ASCII identifier
            # character (or is at the start of a line) and is followed by a
            # dot
            if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract):
                continue
            new_contracts.append(contract)
        self.contracts = new_contracts

    def subset_from_labels(self, labelSpecifier: Set[str]) -> PipelineIR:
        """Subset a pipelineIR to contain only labels specified in
        labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `set` of `str`
            Set containing labels that describes how to subset a pipeline.

        Returns
        -------
        pipeline : `PipelineIR`
            A new pipelineIR object that is a subset of the old pipelineIR

        Raises
        ------
        ValueError
            Raised if there is an issue with specified labels

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string-based matching due to the nature of contracts, and may prune
        more than it should. Any labeled subsets defined that no longer have
        all members of the subset present in the pipeline will be removed from
        the resulting pipeline.
        """

        pipeline = copy.deepcopy(self)

        # update the label specifier to expand any named subsets
        toRemove = set()
        toAdd = set()
        for label in labelSpecifier:
            if label in pipeline.labeled_subsets:
                toRemove.add(label)
                toAdd.update(pipeline.labeled_subsets[label].subset)
        labelSpecifier.difference_update(toRemove)
        labelSpecifier.update(toAdd)
        # verify all the labels are in the pipeline
        if not labelSpecifier.issubset(pipeline.tasks.keys() | pipeline.labeled_subsets):
            difference = labelSpecifier.difference(pipeline.tasks.keys())
            raise ValueError(
                "Not all supplied labels (specified or named subsets) are in the pipeline "
                f"definition, extra labels: {difference}"
            )
        # copy needed so as to not modify while iterating
        pipeline_labels = set(pipeline.tasks.keys())
        # Remove the labels from the pipelineIR, and any contracts that
        # contain those labels (see docstring on _remove_contracts for why
        # this may cause issues)
        for label in pipeline_labels:
            if label not in labelSpecifier:
                pipeline.tasks.pop(label)
                pipeline._remove_contracts(label)

        # create a copy of the object to iterate over
        labeled_subsets = copy.copy(pipeline.labeled_subsets)
        # remove any labeled subsets that no longer have a complete set
        for label, labeled_subset in labeled_subsets.items():
            if labeled_subset.subset - pipeline.tasks.keys():
                pipeline.labeled_subsets.pop(label)

        return pipeline
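
    # Example (an illustrative sketch, not part of the original module):
    # subset a pipeline to two of its task labels; contracts and labeled
    # subsets referencing dropped labels are pruned. The variable and labels
    # are hypothetical.
    #
    #     subset_ir = pipeline_ir.subset_from_labels({"isr", "calibrate"})
    #     set(subset_ir.tasks)   # -> {"isr", "calibrate"}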

    @classmethod
    def from_string(cls, pipeline_string: str):
        """Create a `PipelineIR` object from a string formatted like a
        pipeline document

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document
        """
        loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader)
        return cls(loaded_yaml)
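
    # Example (an illustrative sketch, not part of the original module; the
    # task label and class used below are hypothetical):
    #
    #     pipeline_ir = PipelineIR.from_string(
    #         "description: demo\n"
    #         "tasks:\n"
    #         "  isr: lsst.ip.isr.IsrTask\n"
    #     )
    #     pipeline_ir.tasks["isr"].klass   # -> "lsst.ip.isr.IsrTask"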

    @classmethod
    @deprecated(
        reason="This has been replaced with `from_uri`. Will be removed after v23.",
        version="v21.0",
        category=FutureWarning,
    )  # type: ignore
    def from_file(cls, filename: str) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input path.

        Parameters
        ----------
        filename : `str`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline

        Notes
        -----
        This method is deprecated, please use from_uri
        """
        return cls.from_uri(filename)

    @classmethod
    def from_uri(cls, uri: ResourcePathExpression) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input uri.

        Parameters
        ----------
        uri : convertible to `ResourcePath`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline
        """
        loaded_uri = ResourcePath(uri)
        with loaded_uri.open("r") as buffer:
            # explicitly read here; there was some issue with yaml trying
            # to read the ResourcePath itself (I think because it only
            # pretends to be conformant to the io api)
            loaded_yaml = yaml.load(buffer.read(), Loader=PipelineYamlLoader)
        return cls(loaded_yaml)

    @deprecated(
        reason="This has been replaced with `write_to_uri`. Will be removed after v23.",
        version="v21.0",
        category=FutureWarning,
    )  # type: ignore
    def to_file(self, filename: str):
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified path.

        Parameters
        ----------
        filename : `str`
            Location of document to write a `PipelineIR` object.
        """
        self.write_to_uri(filename)

    def write_to_uri(
        self,
        uri: ResourcePathExpression,
    ):
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified uri.

        Parameters
        ----------
        uri : convertible to `ResourcePath`
            Location of document to write a `PipelineIR` object.
        """
        with ResourcePath(uri).open("w") as buffer:
            yaml.dump(self.to_primitives(), buffer, sort_keys=False)

    def to_primitives(self) -> Dict[str, Any]:
        """Convert to a representation used in yaml serialization"""
        accumulate = {"description": self.description}
        if self.instrument is not None:
            accumulate["instrument"] = self.instrument
        if self.parameters:
            accumulate["parameters"] = self._sort_by_str(self.parameters.to_primitives())
        accumulate["tasks"] = {m: t.to_primitives() for m, t in self.tasks.items()}
        if len(self.contracts) > 0:
            # sort contracts in lexicographical order by the contract string
            # in absence of any other ordering principle
            contracts_list = [c.to_primitives() for c in self.contracts]
            contracts_list.sort(key=lambda x: x["contract"])
            accumulate["contracts"] = contracts_list
        if self.labeled_subsets:
            accumulate["subsets"] = self._sort_by_str(
                {k: v.to_primitives() for k, v in self.labeled_subsets.items()}
            )
        return accumulate

    def reorder_tasks(self, task_labels: List[str]) -> None:
        """Change the order tasks are stored internally. Useful for
        determining the order things will appear in the serialized (or
        printed) form.

        Parameters
        ----------
        task_labels : `list` of `str`
            A list corresponding to all the labels in the pipeline, inserted
            in the order the tasks are to be stored.

        Raises
        ------
        KeyError
            Raised if labels are supplied that are not in the pipeline, or if
            not all labels in the pipeline were supplied in the task_labels
            input.
        """
        # verify that all labels are in the input
        _tmp_set = set(task_labels)
        if remainder := (self.tasks.keys() - _tmp_set):
            raise KeyError(f"Label(s) {remainder} are missing from the task label list")
        if extra := (_tmp_set - self.tasks.keys()):
            raise KeyError(f"Extra label(s) {extra} were in the input and are not in the pipeline")

        newTasks = {key: self.tasks[key] for key in task_labels}
        self.tasks = newTasks
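
    # Example (an illustrative sketch, not part of the original module; the
    # labels are hypothetical):
    #
    #     pipeline_ir.reorder_tasks(["calibrate", "isr"])
    #     list(pipeline_ir.tasks)   # -> ["calibrate", "isr"]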

    @staticmethod
    def _sort_by_str(arg: Mapping[str, Any]) -> Mapping[str, Any]:
        keys = sorted(arg.keys())
        return {key: arg[key] for key in keys}

    def __str__(self) -> str:
        """Instance formatting as how it would look in yaml representation"""
        return yaml.dump(self.to_primitives(), sort_keys=False)

    def __repr__(self) -> str:
        """Instance formatting as how it would look in yaml representation"""
        return str(self)

    def __eq__(self, other: object):
        if not isinstance(other, PipelineIR):
            return False
        # special case contracts because it is a list, but order is not
        # important
        return (
            all(
                getattr(self, attr) == getattr(other, attr)
                for attr in ("tasks", "instrument", "labeled_subsets", "parameters")
            )
            and len(self.contracts) == len(other.contracts)
            and all(c in self.contracts for c in other.contracts)
        )