Coverage for python/lsst/pipe/base/pipelineIR.py: 19%
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("ConfigIR", "ContractError", "ContractIR", "ImportIR", "PipelineIR", "TaskIR", "LabeledSubset")

import copy
import enum
import os
import re
import warnings
from collections import Counter
from collections.abc import Iterable as abcIterable
from dataclasses import dataclass, field
from typing import Any, Dict, Generator, List, Literal, Mapping, MutableMapping, Optional, Set, Union

import yaml
from deprecated.sphinx import deprecated
from lsst.resources import ResourcePath, ResourcePathExpression


class _Tags(enum.Enum):
    KeepInstrument = enum.auto()


class PipelineYamlLoader(yaml.SafeLoader):
    """This is a specialized version of yaml's SafeLoader. It checks for, and
    raises an exception on, multiple instances of the same key found inside a
    pipeline file at a given scope.
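
    Examples
    --------
    An illustrative sketch of the duplicate-key check:

    >>> yaml.load("{a: 1, a: 2}", Loader=PipelineYamlLoader)
    Traceback (most recent call last):
        ...
    KeyError: "Pipeline files must not have duplicated keys, {'a'} appeared multiple times"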
48 """
50 def construct_mapping(self, node: yaml.Node, deep: bool = False) -> Mapping[str, Any]:
51 # do the call to super first so that it can do all the other forms of
52 # checking on this node. If you check the uniqueness of keys first
53 # it would save the work that super does in the case of a failure, but
54 # it might fail in the case that the node was the incorrect node due
55 # to a parsing error, and the resulting exception would be difficult to
56 # understand.
57 mapping = super().construct_mapping(node, deep)
58 # Check if there are any duplicate keys
59 all_keys = Counter(key_node.value for key_node, _ in node.value)
60 duplicates = {k for k, i in all_keys.items() if i != 1}
61 if duplicates:
62 raise KeyError(
63 f"Pipeline files must not have duplicated keys, {duplicates} appeared multiple times"
64 )
65 return mapping


class ContractError(Exception):
    """An exception that is raised when a pipeline contract is not satisfied."""

    pass


@dataclass
class ContractIR:
    """Intermediate representation of configuration contracts read from a
    pipeline yaml file.
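
    Examples
    --------
    A sketch of a contract as it might be declared in a pipeline definition
    document (the task labels and fields are hypothetical)::

        contracts:
            - contract: "taskA.connections.outputName == taskB.connections.inputName"
              msg: "taskB must be configured to consume what taskA produces"
    """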

    contract: str
    """A string of python code representing one or more conditions on configs
    in a pipeline. This code-as-string should, once evaluated, be `True` if
    the configs are fine, and `False` otherwise.
    """
    msg: Union[str, None] = None
    """An optional message to be shown to the user if a contract fails.
    """

    def to_primitives(self) -> Dict[str, str]:
        """Convert to a representation used in yaml serialization"""
        accumulate = {"contract": self.contract}
        if self.msg is not None:
            accumulate["msg"] = self.msg
        return accumulate

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ContractIR):
            return False
        return self.contract == other.contract and self.msg == other.msg


@dataclass
class LabeledSubset:
    """Intermediate representation of a named subset of task labels read from
    a pipeline yaml file.
    """

    label: str
    """The label used to identify the subset of task labels.
    """
    subset: Set[str]
    """A set of task labels contained in this subset.
    """
    description: Optional[str]
    """A description of what this subset of tasks is intended to do.
    """

    @staticmethod
    def from_primitives(label: str, value: Union[List[str], dict]) -> LabeledSubset:
        """Generate `LabeledSubset` objects given a properly formatted object
        that has been created by a yaml loader.

        Parameters
        ----------
        label : `str`
            The label that will be used to identify this labeled subset.
        value : `list` of `str` or `dict`
            Object returned from loading a labeled subset section from a yaml
            document.

        Returns
        -------
        labeledSubset : `LabeledSubset`
            A `LabeledSubset` object built from the inputs.

        Raises
        ------
        ValueError
            Raised if the value input is not properly formatted for parsing
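
        Examples
        --------
        Both accepted forms, sketched with a hypothetical task label:

        >>> LabeledSubset.from_primitives("sub", ["taskA"])
        LabeledSubset(label='sub', subset={'taskA'}, description=None)
        >>> LabeledSubset.from_primitives(
        ...     "sub", {"subset": ["taskA"], "description": "one task"}
        ... )
        LabeledSubset(label='sub', subset={'taskA'}, description='one task')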
142 """
143 if isinstance(value, MutableMapping):
144 subset = value.pop("subset", None)
145 if subset is None:
146 raise ValueError(
147 "If a labeled subset is specified as a mapping, it must contain the key 'subset'"
148 )
149 description = value.pop("description", None)
150 elif isinstance(value, abcIterable):
151 subset = value
152 description = None
153 else:
154 raise ValueError(
155 f"There was a problem parsing the labeled subset {label}, make sure the "
156 "definition is either a valid yaml list, or a mapping with keys "
157 "(subset, description) where subset points to a yaml list, and description is "
158 "associated with a string"
159 )
160 return LabeledSubset(label, set(subset), description)
162 def to_primitives(self) -> Dict[str, Union[List[str], str]]:
163 """Convert to a representation used in yaml serialization"""
164 accumulate: Dict[str, Union[List[str], str]] = {"subset": list(self.subset)}
165 if self.description is not None:
166 accumulate["description"] = self.description
167 return accumulate


@dataclass
class ParametersIR:
    """Intermediate representation of parameters that are global to a pipeline.

    These parameters are specified under a top level key named ``parameters``
    and are declared as a yaml mapping. These entries can then be used inside
    task configuration blocks to specify configuration values. They may not be
    used in the special ``file`` or ``python`` blocks.

    Example::

        parameters:
            shared_value: 14
        tasks:
            taskA:
                class: modA
                config:
                    field1: parameters.shared_value
            taskB:
                class: modB
                config:
                    field2: parameters.shared_value
    """

    mapping: MutableMapping[str, str]
    """A mutable mapping of identifiers as keys, and shared configuration
    as values.
    """

    def update(self, other: Optional[ParametersIR]) -> None:
        if other is not None:
            self.mapping.update(other.mapping)

    def to_primitives(self) -> MutableMapping[str, str]:
        """Convert to a representation used in yaml serialization"""
        return self.mapping

    def __contains__(self, value: str) -> bool:
        return value in self.mapping

    def __getitem__(self, item: str) -> Any:
        return self.mapping[item]

    def __bool__(self) -> bool:
        return bool(self.mapping)


@dataclass
class ConfigIR:
    """Intermediate representation of configurations read from a pipeline yaml
    file.
    """

    python: Union[str, None] = None
    """A string of python code that is used to modify a configuration. This can
    also be None if there are no modifications to do.
    """
    dataId: Union[dict, None] = None
    """A dataId that is used to constrain these config overrides to only quanta
    with matching dataIds. This field can be None if there is no constraint.
    This is currently an unimplemented feature, and is placed here for future
    use.
    """
    file: List[str] = field(default_factory=list)
    """A list of paths which point to files containing config overrides to be
    applied. This value may be an empty list if there are no overrides to
    apply.
    """
    rest: dict = field(default_factory=dict)
    """This is a dictionary of key value pairs, where the keys are strings
    corresponding to qualified fields on a config to override, and the values
    are strings representing the values to apply.
    """

    def to_primitives(self) -> Dict[str, Union[str, dict, List[str]]]:
        """Convert to a representation used in yaml serialization"""
        accumulate = {}
        for name in ("python", "dataId", "file"):
            # if this attribute is truthy add it to the accumulation
            # dictionary
            if getattr(self, name):
                accumulate[name] = getattr(self, name)
        # Add the dictionary containing the rest of the config keys to the
        # accumulated dictionary
        accumulate.update(self.rest)
        return accumulate

    def formatted(self, parameters: ParametersIR) -> ConfigIR:
        """Return a new ConfigIR object that is formatted according to the
        specified parameters.

        Parameters
        ----------
        parameters : `ParametersIR`
            Object that contains variable mappings used in substitution.

        Returns
        -------
        config : `ConfigIR`
            A new ConfigIR object formatted with the input parameters.
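
        Examples
        --------
        An illustrative substitution with hypothetical values:

        >>> params = ParametersIR({"shared_value": "14"})
        >>> config = ConfigIR(rest={"field1": "parameters.shared_value"})
        >>> config.formatted(params).rest
        {'field1': '14'}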
269 """
270 new_config = copy.deepcopy(self)
271 for key, value in new_config.rest.items():
272 if not isinstance(value, str):
273 continue
274 match = re.match("parameters[.](.*)", value)
275 if match and match.group(1) in parameters:
276 new_config.rest[key] = parameters[match.group(1)]
277 if match and match.group(1) not in parameters:
278 warnings.warn(
279 f"config {key} contains value {match.group(0)} which is formatted like a "
280 "Pipeline parameter but was not found within the Pipeline, if this was not "
281 "intentional, check for a typo"
282 )
283 return new_config

    def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]:
        """Merge another instance of a `ConfigIR` into this instance if
        possible. This function returns a generator that yields only self
        if the configs were merged, or self and other_config if they could
        not be merged.

        Parameters
        ----------
        other_config : `ConfigIR`
            An instance of `ConfigIR` to merge into this instance.

        Returns
        -------
        Generator : `ConfigIR`
            A generator containing either self, or self and other_config,
            depending on whether the configs could be merged.
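
        Examples
        --------
        An illustrative sketch: two compatible plain-value configs merge into
        a single `ConfigIR`.

        >>> a = ConfigIR(rest={"field1": 1})
        >>> b = ConfigIR(rest={"field2": 2})
        >>> list(a.maybe_merge(b))
        [ConfigIR(python=None, dataId=None, file=[], rest={'field1': 1, 'field2': 2})]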
301 """
302 # Verify that the config blocks can be merged
303 if (
304 self.dataId != other_config.dataId
305 or self.python
306 or other_config.python
307 or self.file
308 or other_config.file
309 ):
310 yield from (self, other_config)
311 return
313 # create a set of all keys, and verify two keys do not have different
314 # values
315 key_union = self.rest.keys() & other_config.rest.keys()
316 for key in key_union:
317 if self.rest[key] != other_config.rest[key]:
318 yield from (self, other_config)
319 return
320 self.rest.update(other_config.rest)
322 # Combine the lists of override files to load
323 self_file_set = set(self.file)
324 other_file_set = set(other_config.file)
325 self.file = list(self_file_set.union(other_file_set))
327 yield self

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ConfigIR):
            return False
        return all(
            getattr(self, attr) == getattr(other, attr) for attr in ("python", "dataId", "file", "rest")
        )


@dataclass
class TaskIR:
    """Intermediate representation of tasks read from a pipeline yaml file."""

    label: str
    """An identifier used to refer to a task.
    """
    klass: str
    """A string containing a fully qualified python class to be run in a
    pipeline.
    """
    config: Union[List[ConfigIR], None] = None
    """List of all config overrides associated with this task; may be
    `None` if there are no config overrides.
    """

    def to_primitives(self) -> Dict[str, Union[str, List[dict]]]:
        """Convert to a representation used in yaml serialization"""
        accumulate: Dict[str, Union[str, List[dict]]] = {"class": self.klass}
        if self.config:
            accumulate["config"] = [c.to_primitives() for c in self.config]
        return accumulate

    def add_or_update_config(self, other_config: ConfigIR) -> None:
        """Add a `ConfigIR` to this task if one is not present. Merge configs
        if there is a `ConfigIR` present and the dataId keys of both configs
        match, otherwise add a new entry to the config list. The exception to
        the above is that if either the last config or other_config has a
        python block, then other_config is always added, as python blocks can
        modify configs in ways that cannot be predicted.

        Parameters
        ----------
        other_config : `ConfigIR`
            A `ConfigIR` instance to add or merge into the config attribute of
            this task.
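
        Examples
        --------
        An illustrative sketch with hypothetical values:

        >>> task = TaskIR("taskA", "test.moduleA")
        >>> task.add_or_update_config(ConfigIR(rest={"field1": 1}))
        >>> task.add_or_update_config(ConfigIR(rest={"field2": 2}))
        >>> len(task.config)  # the two compatible configs were merged
        1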
376 """
377 if not self.config:
378 self.config = [other_config]
379 return
380 self.config.extend(self.config.pop().maybe_merge(other_config))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config"))


@dataclass
class ImportIR:
    """An intermediate representation of imported pipelines.

    location: str
    """This is the location of the pipeline to inherit. The path should be
    specified as an absolute path. Environment variables may be used in the
    path and should be specified as a python string template, with the name of
    the environment variable inside braces.
    """
    include: Union[List[str], None] = None
    """List of tasks that should be included when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    exclude: Union[List[str], None] = None
    """List of tasks that should be excluded when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    importContracts: bool = True
    """Boolean attribute to dictate if contracts should be inherited with the
    pipeline or not.
    """
    instrument: Union[Literal[_Tags.KeepInstrument], str, None] = _Tags.KeepInstrument
    """Instrument to assign to the Pipeline at import. The default value of
    ``_Tags.KeepInstrument`` indicates that whatever instrument the pipeline is
    declared with will not be modified. Setting this value to None will drop
    any declared instrument prior to import.
    """

    def toPipelineIR(self) -> "PipelineIR":
        """Load in the Pipeline specified by this object, and turn it into a
        PipelineIR instance.

        Returns
        -------
        pipeline : `PipelineIR`
            A pipeline generated from the imported pipeline file.
        """
        if self.include and self.exclude:
            raise ValueError(
                "An include list and an exclude list cannot both be specified when declaring a "
                "pipeline import"
            )
        tmp_pipeline = PipelineIR.from_uri(os.path.expandvars(self.location))
        if self.instrument is not _Tags.KeepInstrument:
            tmp_pipeline.instrument = self.instrument

        included_labels = set()
        for label in tmp_pipeline.tasks:
            if (
                (self.include and label in self.include)
                or (self.exclude and label not in self.exclude)
                or (self.include is None and self.exclude is None)
            ):
                included_labels.add(label)

        # Handle labeled subsets being specified in the include or exclude
        # list, adding or removing labels.
        if self.include is not None:
            subsets_in_include = tmp_pipeline.labeled_subsets.keys() & self.include
            for label in subsets_in_include:
                included_labels.update(tmp_pipeline.labeled_subsets[label].subset)

        elif self.exclude is not None:
            subsets_in_exclude = tmp_pipeline.labeled_subsets.keys() & self.exclude
            for label in subsets_in_exclude:
                included_labels.difference_update(tmp_pipeline.labeled_subsets[label].subset)

        tmp_pipeline = tmp_pipeline.subset_from_labels(included_labels)

        if not self.importContracts:
            tmp_pipeline.contracts = []

        return tmp_pipeline

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ImportIR):
            return False
        return all(
            getattr(self, attr) == getattr(other, attr)
            for attr in ("location", "include", "exclude", "importContracts")
        )


class PipelineIR:
    """Intermediate representation of a pipeline definition

    Parameters
    ----------
    loaded_yaml : `dict`
        A dictionary which matches the structure that would be produced by a
        yaml reader which parses a pipeline definition document

    Raises
    ------
    ValueError
        - If a pipeline is declared without a description
        - If no tasks are declared in a pipeline, and no pipelines are to be
          inherited
        - If more than one instrument is specified
        - If more than one inherited pipeline shares a label
    """

    def __init__(self, loaded_yaml: Dict[str, Any]):
        # Check required fields are present
        if "description" not in loaded_yaml:
            raise ValueError("A pipeline must be declared with a description")
        if "tasks" not in loaded_yaml and len({"imports", "inherits"} - loaded_yaml.keys()) == 2:
            raise ValueError("A pipeline must be declared with one or more tasks")

        # These steps below must happen in this call order

        # Process pipeline description
        self.description = loaded_yaml.pop("description")

        # Process tasks
        self._read_tasks(loaded_yaml)

        # Process instrument keys
        inst = loaded_yaml.pop("instrument", None)
        if isinstance(inst, list):
            raise ValueError("Only one top level instrument can be defined in a pipeline")
        self.instrument: Optional[str] = inst

        # Process any contracts
        self._read_contracts(loaded_yaml)

        # Process any defined parameters
        self._read_parameters(loaded_yaml)

        # Process any named label subsets
        self._read_labeled_subsets(loaded_yaml)

        # Process any inherited pipelines
        self._read_imports(loaded_yaml)

        # verify named subsets, must be done after inheriting
        self._verify_labeled_subsets()

    def _read_contracts(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the contracts portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        loaded_contracts = loaded_yaml.pop("contracts", [])
        if isinstance(loaded_contracts, str):
            loaded_contracts = [loaded_contracts]
        self.contracts: List[ContractIR] = []
        for contract in loaded_contracts:
            if isinstance(contract, dict):
                self.contracts.append(ContractIR(**contract))
            if isinstance(contract, str):
                self.contracts.append(ContractIR(contract=contract))

    def _read_parameters(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the parameters portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        loaded_parameters = loaded_yaml.pop("parameters", {})
        if not isinstance(loaded_parameters, dict):
            raise ValueError("The parameters section must be a yaml mapping")
        self.parameters = ParametersIR(loaded_parameters)

    def _read_labeled_subsets(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the subsets portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `MutableMapping`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document
        """
        loaded_subsets = loaded_yaml.pop("subsets", {})
        self.labeled_subsets: Dict[str, LabeledSubset] = {}
        if not loaded_subsets and "subset" in loaded_yaml:
            raise ValueError("Top level key should be subsets and not subset, add an s")
        for key, value in loaded_subsets.items():
            self.labeled_subsets[key] = LabeledSubset.from_primitives(key, value)

    def _verify_labeled_subsets(self) -> None:
        """Verify that all the labels in each named subset exist within the
        pipeline.
        """
        # Verify that all labels defined in a labeled subset are in the
        # Pipeline
        for labeled_subset in self.labeled_subsets.values():
            if not labeled_subset.subset.issubset(self.tasks.keys()):
                raise ValueError(
                    f"Labels {labeled_subset.subset - self.tasks.keys()} were not found in the "
                    "declared pipeline"
                )
        # Verify subset labels are not already task labels
        label_intersection = self.labeled_subsets.keys() & self.tasks.keys()
        if label_intersection:
            raise ValueError(f"Labeled subsets cannot use the same label as a task: {label_intersection}")

    def _read_imports(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the imports (or deprecated inherits) portion of the loaded
        yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """

        def process_args(argument: Union[str, dict]) -> dict:
            if isinstance(argument, str):
                return {"location": argument}
            elif isinstance(argument, dict):
                if "exclude" in argument and isinstance(argument["exclude"], str):
                    argument["exclude"] = [argument["exclude"]]
                if "include" in argument and isinstance(argument["include"], str):
                    argument["include"] = [argument["include"]]
                if "instrument" in argument and argument["instrument"] == "None":
                    argument["instrument"] = None
                return argument

        if not {"inherits", "imports"} - loaded_yaml.keys():
            raise ValueError("Cannot define both inherits and imports sections, use imports")
        tmp_import = loaded_yaml.pop("inherits", None)
        if tmp_import is None:
            tmp_import = loaded_yaml.pop("imports", None)
        else:
            warnings.warn(
                "The 'inherits' key is deprecated, and will be "
                "removed around June 2021. Please use the key "
                "'imports' instead"
            )
        if tmp_import is None:
            self.imports: List[ImportIR] = []
        elif isinstance(tmp_import, list):
            self.imports = [ImportIR(**process_args(args)) for args in tmp_import]
        else:
            self.imports = [ImportIR(**process_args(tmp_import))]

        # integrate any imported pipelines
        accumulate_tasks: Dict[str, TaskIR] = {}
        accumulate_labeled_subsets: Dict[str, LabeledSubset] = {}
        accumulated_parameters = ParametersIR({})
        for other_pipeline in self.imports:
            tmp_IR = other_pipeline.toPipelineIR()
            if self.instrument is None:
                self.instrument = tmp_IR.instrument
            elif self.instrument != tmp_IR.instrument and tmp_IR.instrument is not None:
                msg = (
                    "Only one instrument can be declared in a pipeline or its imports. "
                    f"Top level pipeline defines {self.instrument} but {other_pipeline.location} "
                    f"defines {tmp_IR.instrument}."
                )
                raise ValueError(msg)
            if duplicate_labels := accumulate_tasks.keys() & tmp_IR.tasks.keys():
                msg = (
                    "Task labels in the imported pipelines must be unique. "
                    f"These labels appear multiple times: {duplicate_labels}"
                )
                raise ValueError(msg)
            accumulate_tasks.update(tmp_IR.tasks)
            self.contracts.extend(tmp_IR.contracts)
            # verify that tmp_IR has unique labels for named subsets among
            # existing labeled subsets, and with existing task labels.
            overlapping_subsets = accumulate_labeled_subsets.keys() & tmp_IR.labeled_subsets.keys()
            task_subset_overlap = (
                accumulate_labeled_subsets.keys() | tmp_IR.labeled_subsets.keys()
            ) & accumulate_tasks.keys()
            if overlapping_subsets or task_subset_overlap:
                raise ValueError(
                    "Labeled subset names must be unique amongst imports in both labels and "
                    f"named Subsets. Duplicate: {overlapping_subsets | task_subset_overlap}"
                )
            accumulate_labeled_subsets.update(tmp_IR.labeled_subsets)
            accumulated_parameters.update(tmp_IR.parameters)

        # verify that any accumulated labeled subsets don't clash with a label
        # from this pipeline
        if accumulate_labeled_subsets.keys() & self.tasks.keys():
            raise ValueError(
                "Labeled subset names must be unique amongst imports in both labels and named Subsets"
            )
        # merge in the named subsets for self so this document can override any
        # that have been declared
        accumulate_labeled_subsets.update(self.labeled_subsets)
        self.labeled_subsets = accumulate_labeled_subsets

        # merge the dict of label:TaskIR objects, preserving any configs in the
        # imported pipeline if the labels point to the same class
        for label, task in self.tasks.items():
            if label not in accumulate_tasks:
                accumulate_tasks[label] = task
            elif accumulate_tasks[label].klass == task.klass:
                if task.config is not None:
                    for config in task.config:
                        accumulate_tasks[label].add_or_update_config(config)
            else:
                accumulate_tasks[label] = task
        self.tasks: Dict[str, TaskIR] = accumulate_tasks
        accumulated_parameters.update(self.parameters)
        self.parameters = accumulated_parameters

    def _read_tasks(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the tasks portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        self.tasks = {}
        tmp_tasks = loaded_yaml.pop("tasks", None)
        if tmp_tasks is None:
            tmp_tasks = {}

        if "parameters" in tmp_tasks:
            raise ValueError("parameters is a reserved word and cannot be used as a task label")

        for label, definition in tmp_tasks.items():
            if isinstance(definition, str):
                definition = {"class": definition}
            config = definition.get("config", None)
            if config is None:
                task_config_ir = None
            else:
                if isinstance(config, dict):
                    config = [config]
                task_config_ir = []
                for c in config:
                    file = c.pop("file", None)
                    if file is None:
                        file = []
                    elif not isinstance(file, list):
                        file = [file]
                    task_config_ir.append(
                        ConfigIR(
                            python=c.pop("python", None), dataId=c.pop("dataId", None), file=file, rest=c
                        )
                    )
            self.tasks[label] = TaskIR(label, definition["class"], task_config_ir)

    def _remove_contracts(self, label: str) -> None:
        """Remove any contracts that contain the given label

        String comparison used in this way is not the most elegant and may
        have issues, but it is the only feasible way when users can specify
        contracts with generic strings.
        """
        new_contracts = []
        for contract in self.contracts:
            # match a label that is not preceded by an ASCII identifier
            # character, or is at the start of a line, and is followed by a dot
            if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract):
                continue
            new_contracts.append(contract)
        self.contracts = new_contracts

    def subset_from_labels(self, labelSpecifier: Set[str]) -> PipelineIR:
        """Subset a pipelineIR to contain only labels specified in
        labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `set` of `str`
            Set containing labels that describe how to subset a pipeline.

        Returns
        -------
        pipeline : `PipelineIR`
            A new pipelineIR object that is a subset of the old pipelineIR

        Raises
        ------
        ValueError
            Raised if there is an issue with specified labels

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string-based matching due to the nature of contracts, and may prune
        more than it should. Any labeled subsets defined that no longer have
        all members of the subset present in the pipeline will be removed from
        the resulting pipeline.
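
        Examples
        --------
        An illustrative subset with hypothetical task labels:

        >>> pipeline = PipelineIR.from_string('''
        ... description: An example pipeline
        ... tasks:
        ...     taskA: test.moduleA
        ...     taskB: test.moduleB
        ... ''')
        >>> list(pipeline.subset_from_labels({"taskA"}).tasks)
        ['taskA']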
784 """
786 pipeline = copy.deepcopy(self)
788 # update the label specifier to expand any named subsets
789 toRemove = set()
790 toAdd = set()
791 for label in labelSpecifier:
792 if label in pipeline.labeled_subsets:
793 toRemove.add(label)
794 toAdd.update(pipeline.labeled_subsets[label].subset)
795 labelSpecifier.difference_update(toRemove)
796 labelSpecifier.update(toAdd)
797 # verify all the labels are in the pipeline
798 if not labelSpecifier.issubset(pipeline.tasks.keys() | pipeline.labeled_subsets):
799 difference = labelSpecifier.difference(pipeline.tasks.keys())
800 raise ValueError(
801 "Not all supplied labels (specified or named subsets) are in the pipeline "
802 f"definition, extra labels: {difference}"
803 )
804 # copy needed so as to not modify while iterating
805 pipeline_labels = set(pipeline.tasks.keys())
806 # Remove the labels from the pipelineIR, and any contracts that contain
807 # those labels (see docstring on _remove_contracts for why this may
808 # cause issues)
809 for label in pipeline_labels:
810 if label not in labelSpecifier:
811 pipeline.tasks.pop(label)
812 pipeline._remove_contracts(label)
814 # create a copy of the object to iterate over
815 labeled_subsets = copy.copy(pipeline.labeled_subsets)
816 # remove any labeled subsets that no longer have a complete set
817 for label, labeled_subset in labeled_subsets.items():
818 if labeled_subset.subset - pipeline.tasks.keys():
819 pipeline.labeled_subsets.pop(label)
821 return pipeline

    @classmethod
    def from_string(cls, pipeline_string: str) -> PipelineIR:
        """Create a `PipelineIR` object from a string formatted like a pipeline
        document

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document
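
        Examples
        --------
        A minimal, illustrative pipeline document:

        >>> pipeline = PipelineIR.from_string('''
        ... description: An example pipeline
        ... tasks:
        ...     taskA: test.moduleA
        ... ''')
        >>> list(pipeline.tasks)
        ['taskA']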
832 """
833 loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader)
834 return cls(loaded_yaml)

    @classmethod
    @deprecated(
        reason="This has been replaced with `from_uri`. Will be removed after v23.",
        version="v21.0",
        category=FutureWarning,
    )
    def from_file(cls, filename: str) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input path.

        Parameters
        ----------
        filename : `str`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline

        Notes
        -----
        This method is deprecated, please use from_uri
        """
        return cls.from_uri(filename)

    @classmethod
    def from_uri(cls, uri: ResourcePathExpression) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input uri.

        Parameters
        ----------
        uri : convertible to `ResourcePath`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline
        """
        loaded_uri = ResourcePath(uri)
        with loaded_uri.open("r") as buffer:
            # explicitly read here; there was some issue with yaml trying
            # to read the ResourcePath itself (I think because it only
            # pretends to be conformant to the io api)
            loaded_yaml = yaml.load(buffer.read(), Loader=PipelineYamlLoader)
        return cls(loaded_yaml)

    @deprecated(
        reason="This has been replaced with `write_to_uri`. Will be removed after v23.",
        version="v21.0",
        category=FutureWarning,
    )  # type: ignore
    def to_file(self, filename: str) -> None:
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified path.

        Parameters
        ----------
        filename : `str`
            Location of document to write a `PipelineIR` object.
        """
        self.write_to_uri(filename)

    def write_to_uri(
        self,
        uri: ResourcePathExpression,
    ) -> None:
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified uri.

        Parameters
        ----------
        uri : convertible to `ResourcePath`
            Location of document to write a `PipelineIR` object.
        """
        with ResourcePath(uri).open("w") as buffer:
            yaml.dump(self.to_primitives(), buffer, sort_keys=False)

    def to_primitives(self) -> Dict[str, Any]:
        """Convert to a representation used in yaml serialization"""
        accumulate = {"description": self.description}
        if self.instrument is not None:
            accumulate["instrument"] = self.instrument
        if self.parameters:
            accumulate["parameters"] = self._sort_by_str(self.parameters.to_primitives())
        accumulate["tasks"] = {m: t.to_primitives() for m, t in self.tasks.items()}
        if len(self.contracts) > 0:
            # sort contracts in lexicographical order by the contract string,
            # in the absence of any other ordering principle
            contracts_list = [c.to_primitives() for c in self.contracts]
            contracts_list.sort(key=lambda x: x["contract"])
            accumulate["contracts"] = contracts_list
        if self.labeled_subsets:
            accumulate["subsets"] = self._sort_by_str(
                {k: v.to_primitives() for k, v in self.labeled_subsets.items()}
            )
        return accumulate

    def reorder_tasks(self, task_labels: List[str]) -> None:
        """Change the order in which tasks are stored internally. Useful for
        determining the order things will appear in the serialized (or
        printed) form.

        Parameters
        ----------
        task_labels : `list` of `str`
            A list corresponding to all the labels in the pipeline, inserted
            in the order the tasks are to be stored.

        Raises
        ------
        KeyError
            Raised if labels are supplied that are not in the pipeline, or if
            not all labels in the pipeline were supplied in the task_labels
            input.
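
        Examples
        --------
        An illustrative reordering with hypothetical task labels:

        >>> pipeline = PipelineIR.from_string('''
        ... description: An example pipeline
        ... tasks:
        ...     taskA: test.moduleA
        ...     taskB: test.moduleB
        ... ''')
        >>> pipeline.reorder_tasks(["taskB", "taskA"])
        >>> list(pipeline.tasks)
        ['taskB', 'taskA']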
952 """
953 # verify that all labels are in the input
954 _tmp_set = set(task_labels)
955 if remainder := (self.tasks.keys() - _tmp_set):
956 raise KeyError(f"Label(s) {remainder} are missing from the task label list")
957 if extra := (_tmp_set - self.tasks.keys()):
958 raise KeyError(f"Extra label(s) {extra} were in the input and are not in the pipeline")
960 newTasks = {key: self.tasks[key] for key in task_labels}
961 self.tasks = newTasks

    @staticmethod
    def _sort_by_str(arg: Mapping[str, Any]) -> Mapping[str, Any]:
        keys = sorted(arg.keys())
        return {key: arg[key] for key in keys}

    def __str__(self) -> str:
        """Instance formatting as how it would look in yaml representation"""
        return yaml.dump(self.to_primitives(), sort_keys=False)

    def __repr__(self) -> str:
        """Instance formatting as how it would look in yaml representation"""
        return str(self)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PipelineIR):
            return False
        # special case contracts because it is a list, but order is not
        # important
        return (
            all(
                getattr(self, attr) == getattr(other, attr)
                for attr in ("tasks", "instrument", "labeled_subsets", "parameters")
            )
            and len(self.contracts) == len(other.contracts)
            and all(c in self.contracts for c in other.contracts)
        )