Coverage for python/lsst/pipe/base/pipelineIR.py: 17%

__all__ = ("ConfigIR", "ContractError", "ContractIR", "InheritIR", "PipelineIR", "TaskIR")

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from collections import Counter
from dataclasses import dataclass, field
from typing import List, Union, Generator

import os
import yaml
import warnings


class PipelineYamlLoader(yaml.SafeLoader):
    """A specialized version of yaml's SafeLoader. It checks for, and raises
    an exception on, multiple instances of the same key at a given scope in a
    pipeline file.
    """
    def construct_mapping(self, node, deep=False):
        # Call super first so that it performs all its other checks on this
        # node. Checking key uniqueness first would save super's work in the
        # failure case, but if the node were malformed due to a parsing error
        # the resulting exception would be harder to understand.
        mapping = super().construct_mapping(node, deep)
        # Check if there are any duplicate keys
        all_keys = Counter(key_node.value for key_node, _ in node.value)
        duplicates = {k for k, count in all_keys.items() if count != 1}
        if duplicates:
            raise KeyError("Pipeline files must not have duplicated keys, "
                           f"{duplicates} appeared multiple times")
        return mapping
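

# A minimal usage sketch (illustrative, not part of the original file): the
# loader behaves like SafeLoader on well-formed documents, and raises a
# KeyError when a key repeats at the same scope.
#
# >>> yaml.load("a: 1\nb: 2", Loader=PipelineYamlLoader)
# {'a': 1, 'b': 2}
# >>> yaml.load("a: 1\na: 2", Loader=PipelineYamlLoader)
# Traceback (most recent call last):
#     ...
# KeyError: "Pipeline files must not have duplicated keys, {'a'} appeared multiple times"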


class ContractError(Exception):
    """An exception that is raised when a pipeline contract is not satisfied.
    """
    pass


@dataclass
class ContractIR:
    """Intermediate representation of contracts read from a pipeline yaml file.
    """
    contract: str
    """A string of python code representing one or more conditions on configs
    in a pipeline. This code-as-string should, once evaluated, return True if
    the configs are fine, and False otherwise.
    """
    msg: Union[str, None] = None
    """An optional message to be shown to the user if a contract fails.
    """

    def to_primitives(self) -> dict:
        """Convert to a representation used in yaml serialization.
        """
        accumulate = {"contract": self.contract}
        if self.msg is not None:
            accumulate['msg'] = self.msg
        return accumulate

    def __eq__(self, other: "ContractIR"):
        if not isinstance(other, ContractIR):
            return False
        return self.contract == other.contract and self.msg == other.msg
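

# Illustrative round-trip (assumed usage; the contract expression and message
# are hypothetical):
#
# >>> c = ContractIR(contract="taskA.doWrite == taskB.doWrite", msg="must match")
# >>> c.to_primitives()
# {'contract': 'taskA.doWrite == taskB.doWrite', 'msg': 'must match'}
# >>> c == ContractIR(contract="taskA.doWrite == taskB.doWrite", msg="must match")
# True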


@dataclass
class ConfigIR:
    """Intermediate representation of configurations read from a pipeline yaml
    file.
    """
    python: Union[str, None] = None
    """A string of python code that is used to modify a configuration. This
    can also be None if there are no modifications to do.
    """
    dataId: Union[dict, None] = None
    """A dataId that is used to constrain these config overrides to only
    quanta with matching dataIds. This field can be None if there is no
    constraint. This is currently an unimplemented feature, and is placed here
    for future use.
    """
    file: List[str] = field(default_factory=list)
    """A list of paths which point to files containing config overrides to be
    applied. This value may be an empty list if there are no overrides to
    apply.
    """
    rest: dict = field(default_factory=dict)
    """A dictionary of key value pairs, where the keys are strings
    corresponding to qualified fields on a config to override, and the values
    are strings representing the values to apply.
    """

    def to_primitives(self) -> dict:
        """Convert to a representation used in yaml serialization.
        """
        accumulate = {}
        for name in ("python", "dataId", "file"):
            # If this attribute is truthy, add it to the accumulation
            # dictionary
            if getattr(self, name):
                accumulate[name] = getattr(self, name)
        # Add the dictionary containing the rest of the config keys to the
        # accumulated dictionary
        accumulate.update(self.rest)
        return accumulate
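
    # Flattening behaviour of to_primitives (assumed usage, hypothetical field
    # names): the 'rest' mapping is folded into the top level of the returned
    # dictionary, alongside any truthy python/dataId/file attributes.
    #
    # >>> ConfigIR(python="config.x = 1", rest={"doWrite": "False"}).to_primitives()
    # {'python': 'config.x = 1', 'doWrite': 'False'}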

    def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]:
        """Merge another instance of a `ConfigIR` into this instance if
        possible. This function returns a generator that yields either self,
        if the configs were merged, or self and other_config if they could
        not be merged.

        Parameters
        ----------
        other_config : `ConfigIR`
            An instance of `ConfigIR` to merge into this instance.

        Returns
        -------
        Generator : `ConfigIR`
            A generator yielding either self, if the configs were merged, or
            self and other_config if they could not be.
        """
        # Verify that the config blocks can be merged
        if (self.dataId != other_config.dataId or self.python or other_config.python
                or self.file or other_config.file):
            yield from (self, other_config)
            return

        # Find the keys common to both config blocks, and verify that none of
        # them map to different values
        shared_keys = self.rest.keys() & other_config.rest.keys()
        for key in shared_keys:
            if self.rest[key] != other_config.rest[key]:
                yield from (self, other_config)
                return
        self.rest.update(other_config.rest)

        # Combine the lists of override files to load
        self_file_set = set(self.file)
        other_file_set = set(other_config.file)
        self.file = list(self_file_set.union(other_file_set))

        yield self

    def __eq__(self, other: "ConfigIR"):
        if not isinstance(other, ConfigIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr)
                   for attr in ("python", "dataId", "file", "rest"))
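

# A small sketch of the merge behaviour (assumed usage, hypothetical keys):
# two blocks with matching (here absent) dataIds and no python or file entries
# collapse into one; otherwise both blocks are yielded unchanged.
#
# >>> a = ConfigIR(rest={"doWrite": "False"})
# >>> b = ConfigIR(rest={"doAstrometry": "True"})
# >>> merged = list(a.maybe_merge(b))
# >>> len(merged)
# 1
# >>> merged[0].rest
# {'doWrite': 'False', 'doAstrometry': 'True'}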


@dataclass
class TaskIR:
    """Intermediate representation of tasks read from a pipeline yaml file.
    """
    label: str
    """An identifier used to refer to a task.
    """
    klass: str
    """A string containing the fully qualified python class of a task to be
    run in a pipeline.
    """
    config: Union[List[ConfigIR], None] = None
    """List of all config overrides associated with this task; may be `None`
    if there are no config overrides.
    """

    def to_primitives(self) -> dict:
        """Convert to a representation used in yaml serialization.
        """
        accumulate = {'class': self.klass}
        if self.config:
            accumulate['config'] = [c.to_primitives() for c in self.config]
        return accumulate

    def add_or_update_config(self, other_config: ConfigIR):
        """Add a `ConfigIR` to this task if one is not present. If a
        `ConfigIR` is present and the dataId keys of both configs match, the
        configs are merged; otherwise a new entry is added to the config list.
        The exception is that if either the last config or other_config has a
        python block, other_config is always added, as python blocks can
        modify configs in ways that cannot be predicted.

        Parameters
        ----------
        other_config : `ConfigIR`
            A `ConfigIR` instance to add or merge into the config attribute of
            this task.
        """
        if not self.config:
            self.config = [other_config]
            return
        # Attempt to merge other_config into the most recent config block;
        # maybe_merge yields either the merged block, or both blocks when a
        # merge is not possible.
        self.config.extend(self.config.pop().maybe_merge(other_config))

    def __eq__(self, other: "TaskIR"):
        if not isinstance(other, TaskIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr)
                   for attr in ("label", "klass", "config"))
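

# Sketch of config accumulation (assumed usage; the label and class path are
# hypothetical). Two mergeable override blocks collapse into a single
# ConfigIR entry:
#
# >>> t = TaskIR(label="calibrate", klass="some.module.CalibrateTask")
# >>> t.add_or_update_config(ConfigIR(rest={"doWrite": "False"}))
# >>> t.add_or_update_config(ConfigIR(rest={"doAstrometry": "True"}))
# >>> len(t.config)
# 1
# >>> t.config[0].rest
# {'doWrite': 'False', 'doAstrometry': 'True'}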


@dataclass
class InheritIR:
    """An intermediate representation of inherited pipelines.
    """
    location: str
    """The location of the pipeline to inherit. The path should be specified
    as an absolute path. Environment variables may be used in the path and
    should be specified as a python string template, with the name of the
    environment variable inside braces.
    """
    include: Union[List[str], None] = None
    """List of tasks that should be included when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    exclude: Union[List[str], None] = None
    """List of tasks that should be excluded when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    importContracts: bool = True
    """Boolean attribute that dictates whether contracts should be inherited
    with the pipeline or not.
    """

    def toPipelineIR(self) -> "PipelineIR":
        """Load the pipeline specified by this object and convert it into a
        `PipelineIR` instance.
        """
        if self.include and self.exclude:
            raise ValueError("Both an include and an exclude list can't be specified"
                             " when declaring a pipeline import")
        tmp_pipeline = PipelineIR.from_file(os.path.expandvars(self.location))
        if tmp_pipeline.instrument is not None:
            warnings.warn("Any instrument definitions in imported pipelines are ignored. "
                          "If an instrument is desired please define it in the top-most pipeline")

        new_tasks = {}
        for label, task in tmp_pipeline.tasks.items():
            if ((self.include and label in self.include)
                    or (self.exclude and label not in self.exclude)
                    or (self.include is None and self.exclude is None)):
                new_tasks[label] = task
        tmp_pipeline.tasks = new_tasks

        if not self.importContracts:
            tmp_pipeline.contracts = []

        return tmp_pipeline

    def __eq__(self, other: "InheritIR"):
        if not isinstance(other, InheritIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr)
                   for attr in ("location", "include", "exclude", "importContracts"))
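

# An illustrative pipeline snippet that maps onto an InheritIR (the path and
# label are hypothetical). A bare string location is also accepted, and
# scalar include/exclude values are promoted to single-element lists during
# parsing:
#
#     inherits:
#         - location: ${PIPELINE_DIR}/pipelines/base_pipeline.yaml
#           exclude: taskA
#           importContracts: false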


class PipelineIR:
    """Intermediate representation of a pipeline definition

    Parameters
    ----------
    loaded_yaml : `dict`
        A dictionary which matches the structure that would be produced by a
        yaml reader which parses a pipeline definition document

    Raises
    ------
    ValueError :
        - If a pipeline is declared without a description
        - If no tasks are declared in a pipeline, and no pipelines are to be
          inherited
        - If more than one instrument is specified
        - If more than one inherited pipeline share a label
    """
    def __init__(self, loaded_yaml):
        # Check required fields are present
        if "description" not in loaded_yaml:
            raise ValueError("A pipeline must be declared with a description")
        if "tasks" not in loaded_yaml and "inherits" not in loaded_yaml:
            raise ValueError("A pipeline must be declared with one or more tasks")

        # Process pipeline description
        self.description = loaded_yaml.pop("description")

        # Process tasks
        self._read_tasks(loaded_yaml)

        # Process instrument keys
        inst = loaded_yaml.pop("instrument", None)
        if isinstance(inst, list):
            raise ValueError("Only one top level instrument can be defined in a pipeline")
        self.instrument = inst

        # Process any contracts
        self._read_contracts(loaded_yaml)

        # Process any inherited pipelines
        self._read_inherits(loaded_yaml)

    def _read_contracts(self, loaded_yaml):
        """Process the contracts portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        loaded_contracts = loaded_yaml.pop("contracts", [])
        if isinstance(loaded_contracts, str):
            loaded_contracts = [loaded_contracts]
        self.contracts = []
        for contract in loaded_contracts:
            if isinstance(contract, dict):
                self.contracts.append(ContractIR(**contract))
            elif isinstance(contract, str):
                self.contracts.append(ContractIR(contract=contract))
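
    # Both accepted contract forms, as an illustrative snippet (the config
    # fields are hypothetical): a bare string becomes ContractIR(contract=...),
    # and a mapping is expanded into the ContractIR fields directly:
    #
    #     contracts:
    #         - "taskA.doWrite == taskB.doWrite"
    #         - contract: "taskA.outputName == taskB.inputName"
    #           msg: "taskA must feed taskB"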

    def _read_inherits(self, loaded_yaml):
        """Process the inherits portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        def process_args(argument: Union[str, dict]) -> dict:
            if isinstance(argument, str):
                return {"location": argument}
            elif isinstance(argument, dict):
                if "exclude" in argument and isinstance(argument["exclude"], str):
                    argument["exclude"] = [argument["exclude"]]
                if "include" in argument and isinstance(argument["include"], str):
                    argument["include"] = [argument["include"]]
                return argument

        tmp_inherit = loaded_yaml.pop("inherits", None)
        if tmp_inherit is None:
            self.inherits = []
        elif isinstance(tmp_inherit, list):
            self.inherits = [InheritIR(**process_args(args)) for args in tmp_inherit]
        else:
            self.inherits = [InheritIR(**process_args(tmp_inherit))]

        # Integrate any imported pipelines
        accumulate_tasks = {}
        for other_pipeline in self.inherits:
            tmp_IR = other_pipeline.toPipelineIR()
            if accumulate_tasks.keys() & tmp_IR.tasks.keys():
                raise ValueError("Task labels in the imported pipelines must "
                                 "be unique")
            accumulate_tasks.update(tmp_IR.tasks)
            self.contracts.extend(tmp_IR.contracts)

        # Merge the dict of label:TaskIR objects, preserving any configs in the
        # imported pipeline if the labels point to the same class
        for label, task in self.tasks.items():
            if label not in accumulate_tasks:
                accumulate_tasks[label] = task
            elif accumulate_tasks[label].klass == task.klass:
                if task.config is not None:
                    for config in task.config:
                        accumulate_tasks[label].add_or_update_config(config)
            else:
                accumulate_tasks[label] = task
        self.tasks = accumulate_tasks

    def _read_tasks(self, loaded_yaml):
        """Process the tasks portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        self.tasks = {}
        tmp_tasks = loaded_yaml.pop("tasks", None)
        if tmp_tasks is None:
            tmp_tasks = {}

        for label, definition in tmp_tasks.items():
            if isinstance(definition, str):
                definition = {"class": definition}
            config = definition.get('config', None)
            if config is None:
                task_config_ir = None
            else:
                if isinstance(config, dict):
                    config = [config]
                task_config_ir = []
                for c in config:
                    file = c.pop("file", None)
                    if file is None:
                        file = []
                    elif not isinstance(file, list):
                        file = [file]
                    task_config_ir.append(ConfigIR(python=c.pop("python", None),
                                                   dataId=c.pop("dataId", None),
                                                   file=file,
                                                   rest=c))
            self.tasks[label] = TaskIR(label, definition["class"], task_config_ir)

    @classmethod
    def from_string(cls, pipeline_string: str):
        """Create a `PipelineIR` object from a string formatted like a
        pipeline document

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline definition document
        """
        loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader)
        return cls(loaded_yaml)

    @classmethod
    def from_file(cls, filename: str):
        """Create a `PipelineIR` object from the document specified by the
        input path.

        Parameters
        ----------
        filename : `str`
            Location of document to use in creating a `PipelineIR` object.
        """
        with open(filename, 'r') as f:
            loaded_yaml = yaml.load(f, Loader=PipelineYamlLoader)
        return cls(loaded_yaml)

    def to_file(self, filename: str):
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified path.

        Parameters
        ----------
        filename : `str`
            Location at which to write the serialized `PipelineIR`.
        """
        with open(filename, 'w') as f:
            yaml.dump(self.to_primitives(), f, sort_keys=False)

    def to_primitives(self):
        """Convert to a representation used in yaml serialization.
        """
        accumulate = {"description": self.description}
        if self.instrument is not None:
            accumulate['instrument'] = self.instrument
        accumulate['tasks'] = {m: t.to_primitives() for m, t in self.tasks.items()}
        if len(self.contracts) > 0:
            accumulate['contracts'] = [c.to_primitives() for c in self.contracts]
        return accumulate

    def __str__(self) -> str:
        """Format the instance as it would appear in yaml representation.
        """
        return yaml.dump(self.to_primitives(), sort_keys=False)

    def __repr__(self) -> str:
        """Format the instance as it would appear in yaml representation.
        """
        return str(self)

    def __eq__(self, other: "PipelineIR"):
        if not isinstance(other, PipelineIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr)
                   for attr in ("contracts", "tasks", "instrument"))
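

# End-to-end sketch (illustrative; the document and class paths below are
# hypothetical). A bare string task definition is shorthand for a mapping
# with only a 'class' key, and an inline config mapping becomes a ConfigIR:
#
# >>> pipeline = PipelineIR.from_string(
# ...     "description: demo\n"
# ...     "tasks:\n"
# ...     "  taskA: some.module.TaskA\n"
# ...     "  taskB:\n"
# ...     "    class: some.module.TaskB\n"
# ...     "    config:\n"
# ...     "      doThing: true\n"
# ... )
# >>> pipeline.tasks["taskA"].klass
# 'some.module.TaskA'
# >>> pipeline.tasks["taskB"].config[0].rest
# {'doThing': True}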