Coverage for python/lsst/pipe/base/pipelineIR.py : 18%

__all__ = ("ConfigIR", "ContractError", "ContractIR", "InheritIR", "PipelineIR", "TaskIR")

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from collections import Counter
from dataclasses import dataclass, field
from typing import List, Union, Generator

import os
import yaml
import warnings


class PipelineYamlLoader(yaml.SafeLoader):
    """This is a specialized version of yaml's SafeLoader. It checks and
    raises an exception if it finds that there are multiple instances of the
    same key found inside a pipeline file at a given scope.
    """
    def construct_mapping(self, node, deep=False):
        # Call super first so that it can do all of its other forms of
        # checking on this node. Checking the uniqueness of keys first would
        # save the work that super does in the case of a failure, but if the
        # node were malformed due to a parsing error the resulting exception
        # would be difficult to understand.
        mapping = super().construct_mapping(node, deep)
        # Check if there are any duplicate keys
        all_keys = Counter(key_node.value for key_node, _ in node.value)
        duplicates = {k for k, i in all_keys.items() if i != 1}
        if duplicates:
            raise KeyError("Pipeline files must not have duplicated keys, "
                           f"{duplicates} appeared multiple times")
        return mapping
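

# A minimal sketch (not part of the original module) of the duplicate-key
# check above: loading a document with a repeated key at one scope raises
# KeyError. The task class strings here are hypothetical.
_duplicated_document = """
tasks:
  measure: hypothetical.module.FirstTask
  measure: hypothetical.module.SecondTask
"""
try:
    yaml.load(_duplicated_document, Loader=PipelineYamlLoader)
except KeyError:
    pass  # reports that 'measure' appeared multiple times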


class ContractError(Exception):
    """An exception that is raised when a pipeline contract is not satisfied.
    """
    pass


@dataclass
class ContractIR:
    """Intermediate representation of contracts read from a pipeline yaml
    file.
    """
    contract: str
    """A string of python code representing one or more conditions on configs
    in a pipeline. This code-as-string should, once evaluated, be True if the
    configs are fine, and False otherwise.
    """
    msg: Union[str, None] = None
    """An optional message to be shown to the user if a contract fails.
    """

    def to_primitives(self) -> dict:
        """Convert to a representation used in yaml serialization.
        """
        accumulate = {"contract": self.contract}
        if self.msg is not None:
            accumulate['msg'] = self.msg
        return accumulate

    def __eq__(self, other: "ContractIR"):
        if not isinstance(other, ContractIR):
            return False
        return self.contract == other.contract and self.msg == other.msg
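

# A minimal sketch (not part of the original module) of a ContractIR and its
# yaml-ready form; the contract expression and labels are hypothetical.
_contract = ContractIR(contract="measure.doApCorr == calibrate.doApCorr",
                       msg="Aperture correction settings must agree")
assert _contract.to_primitives() == {
    "contract": "measure.doApCorr == calibrate.doApCorr",
    "msg": "Aperture correction settings must agree",
}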


@dataclass
class ConfigIR:
    """Intermediate representation of configurations read from a pipeline yaml
    file.
    """
    python: Union[str, None] = None
    """A string of python code that is used to modify a configuration. This
    can also be None if there are no modifications to do.
    """
    dataId: Union[dict, None] = None
    """A dataId that is used to constrain these config overrides to only
    quanta with matching dataIds. This field can be None if there is no
    constraint. This is currently an unimplemented feature, and is placed here
    for future use.
    """
    file: List[str] = field(default_factory=list)
    """A list of paths to files containing config overrides to be applied.
    This value may be an empty list if there are no overrides to apply.
    """
    rest: dict = field(default_factory=dict)
    """This is a dictionary of key value pairs, where the keys are strings
    corresponding to qualified fields on a config to override, and the values
    are strings representing the values to apply.
    """

    def to_primitives(self) -> dict:
        """Convert to a representation used in yaml serialization.
        """
        accumulate = {}
        for name in ("python", "dataId", "file"):
            # If this attribute is truthy, add it to the accumulation
            # dictionary
            if getattr(self, name):
                accumulate[name] = getattr(self, name)
        # Add the dictionary containing the rest of the config keys to the
        # accumulated dictionary
        accumulate.update(self.rest)
        return accumulate

    def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]:
        """Merge another instance of a `ConfigIR` into this instance if
        possible. This function returns a generator that yields only self if
        the configs were merged, or self and other_config if they could not
        be merged.

        Parameters
        ----------
        other_config : `ConfigIR`
            An instance of `ConfigIR` to merge into this instance.

        Returns
        -------
        Generator : `ConfigIR`
            A generator containing either self, or self and other_config,
            depending on whether or not the configs could be merged.
        """
        # Verify that the config blocks can be merged
        if self.dataId != other_config.dataId or self.python or other_config.python or\
                self.file or other_config.file:
            yield from (self, other_config)
            return

        # Gather the keys common to both configs, and verify that no shared
        # key has a different value in each
        shared_keys = self.rest.keys() & other_config.rest.keys()
        for key in shared_keys:
            if self.rest[key] != other_config.rest[key]:
                yield from (self, other_config)
                return
        self.rest.update(other_config.rest)

        # Combine the lists of override files to load
        self_file_set = set(self.file)
        other_file_set = set(other_config.file)
        self.file = list(self_file_set.union(other_file_set))

        yield self

    def __eq__(self, other: "ConfigIR"):
        if not isinstance(other, ConfigIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr)
                   for attr in ("python", "dataId", "file", "rest"))
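

# A minimal sketch (not part of the original module) of maybe_merge; the
# override keys are hypothetical. Two configs with plain key/value overrides
# collapse into one, while a python block prevents merging.
_a = ConfigIR(rest={"doWrite": "True"})
_b = ConfigIR(rest={"connections.outputName": "deep"})
assert len(list(_a.maybe_merge(_b))) == 1  # merged: _a now holds both keys
_c = ConfigIR(python="config.doWrite = False")
assert len(list(_a.maybe_merge(_c))) == 2  # unmergeable: both yielded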


@dataclass
class TaskIR:
    """Intermediate representation of tasks read from a pipeline yaml file.
    """
    label: str
    """An identifier used to refer to a task.
    """
    klass: str
    """A string containing a fully qualified python class to be run in a
    pipeline.
    """
    config: Union[List[ConfigIR], None] = None
    """List of all config overrides associated with this task; may be `None`
    if there are no config overrides.
    """

    def to_primitives(self) -> dict:
        """Convert to a representation used in yaml serialization.
        """
        accumulate = {'class': self.klass}
        if self.config:
            accumulate['config'] = [c.to_primitives() for c in self.config]
        return accumulate

    def add_or_update_config(self, other_config: ConfigIR):
        """Add a `ConfigIR` to this task if one is not present. If configs
        are present, attempt to merge other_config into the most recently
        added config, appending other_config as a new entry if the two cannot
        be merged. In particular, if either the last config or other_config
        has a python block, other_config is always appended, as python blocks
        can modify configs in ways that cannot be predicted.

        Parameters
        ----------
        other_config : `ConfigIR`
            A `ConfigIR` instance to add or merge into the config attribute of
            this task.
        """
        if not self.config:
            self.config = [other_config]
            return
        self.config.extend(self.config.pop().maybe_merge(other_config))

    def __eq__(self, other: "TaskIR"):
        if not isinstance(other, TaskIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr)
                   for attr in ("label", "klass", "config"))
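

# A minimal sketch (not part of the original module) of add_or_update_config;
# the task class and override keys are hypothetical. The second override
# merges into the first because neither has a python block, a dataId, or
# override files.
_task = TaskIR(label="characterize", klass="hypothetical.module.CharacterizeTask")
_task.add_or_update_config(ConfigIR(rest={"doMeasurePsf": "True"}))
_task.add_or_update_config(ConfigIR(rest={"doApCorr": "False"}))
assert len(_task.config) == 1 and len(_task.config[0].rest) == 2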


@dataclass
class InheritIR:
    """An intermediate representation of inherited pipelines.
    """
    location: str
    """This is the location of the pipeline to inherit. The path should be
    specified as an absolute path. Environment variables may be used in the
    path and should be specified as a python string template, with the name of
    the environment variable inside braces.
    """
    include: Union[List[str], None] = None
    """List of tasks that should be included when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    exclude: Union[List[str], None] = None
    """List of tasks that should be excluded when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    importContracts: bool = True
    """Boolean attribute that dictates whether contracts should be inherited
    with the pipeline or not.
    """

    def toPipelineIR(self) -> "PipelineIR":
        """Load the pipeline at the configured location and convert it to a
        `PipelineIR` object, applying the include/exclude and contract
        options.
        """
        if self.include and self.exclude:
            raise ValueError("An include list and an exclude list cannot both be specified"
                             " when declaring a pipeline import")
        tmp_pipeline = PipelineIR.from_file(os.path.expandvars(self.location))
        if tmp_pipeline.instrument is not None:
            warnings.warn("Any instrument definitions in imported pipelines are ignored. "
                          "If an instrument is desired please define it in the top most pipeline")

        new_tasks = {}
        for label, task in tmp_pipeline.tasks.items():
            if (self.include and label in self.include) or (self.exclude and label not in self.exclude)\
                    or (self.include is None and self.exclude is None):
                new_tasks[label] = task
        tmp_pipeline.tasks = new_tasks

        if not self.importContracts:
            tmp_pipeline.contracts = []

        return tmp_pipeline

    def __eq__(self, other: "InheritIR"):
        if not isinstance(other, InheritIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr)
                   for attr in ("location", "include", "exclude", "importContracts"))
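

# A minimal sketch (not part of the original module) of the include/exclude
# validation in toPipelineIR; the location is hypothetical and never read,
# because the check fails before the file would be loaded.
_bad_inherit = InheritIR(location="/hypothetical/pipeline.yaml",
                         include=["isr"], exclude=["calibrate"])
try:
    _bad_inherit.toPipelineIR()
except ValueError:
    pass  # an include list and an exclude list were both specified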


class PipelineIR:
    """Intermediate representation of a pipeline definition.

    Parameters
    ----------
    loaded_yaml : `dict`
        A dictionary which matches the structure that would be produced by a
        yaml reader which parses a pipeline definition document

    Raises
    ------
    ValueError
        - If a pipeline is declared without a description
        - If no tasks are declared in a pipeline, and no pipelines are to be
          inherited
        - If more than one instrument is specified
        - If more than one inherited pipeline shares a label
    """
    def __init__(self, loaded_yaml):
        # Check required fields are present
        if "description" not in loaded_yaml:
            raise ValueError("A pipeline must be declared with a description")
        if "tasks" not in loaded_yaml and "inherits" not in loaded_yaml:
            raise ValueError("A pipeline must be declared with one or more tasks")

        # Process pipeline description
        self.description = loaded_yaml.pop("description")

        # Process tasks
        self._read_tasks(loaded_yaml)

        # Process instrument keys
        inst = loaded_yaml.pop("instrument", None)
        if isinstance(inst, list):
            raise ValueError("Only one top level instrument can be defined in a pipeline")
        self.instrument = inst

        # Process any contracts
        self._read_contracts(loaded_yaml)

        # Process any inherited pipelines
        self._read_inherits(loaded_yaml)

    def _read_contracts(self, loaded_yaml):
        """Process the contracts portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        loaded_contracts = loaded_yaml.pop("contracts", [])
        if isinstance(loaded_contracts, str):
            loaded_contracts = [loaded_contracts]
        self.contracts = []
        for contract in loaded_contracts:
            if isinstance(contract, dict):
                self.contracts.append(ContractIR(**contract))
            elif isinstance(contract, str):
                self.contracts.append(ContractIR(contract=contract))

    def _read_inherits(self, loaded_yaml):
        """Process the inherits portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        def process_args(argument: Union[str, dict]) -> dict:
            # A bare string is shorthand for a pipeline location; normalize
            # scalar include/exclude values into single-element lists
            if isinstance(argument, str):
                return {"location": argument}
            elif isinstance(argument, dict):
                if "exclude" in argument and isinstance(argument["exclude"], str):
                    argument["exclude"] = [argument["exclude"]]
                if "include" in argument and isinstance(argument["include"], str):
                    argument["include"] = [argument["include"]]
                return argument
        tmp_inherit = loaded_yaml.pop("inherits", None)
        if tmp_inherit is None:
            self.inherits = []
        elif isinstance(tmp_inherit, list):
            self.inherits = [InheritIR(**process_args(args)) for args in tmp_inherit]
        else:
            self.inherits = [InheritIR(**process_args(tmp_inherit))]

        # Integrate any imported pipelines
        accumulate_tasks = {}
        for other_pipeline in self.inherits:
            tmp_IR = other_pipeline.toPipelineIR()
            if accumulate_tasks.keys() & tmp_IR.tasks.keys():
                raise ValueError("Task labels in the imported pipelines must "
                                 "be unique")
            accumulate_tasks.update(tmp_IR.tasks)
            self.contracts.extend(tmp_IR.contracts)
        accumulate_tasks.update(self.tasks)
        self.tasks = accumulate_tasks

    def _read_tasks(self, loaded_yaml):
        """Process the tasks portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        self.tasks = {}
        tmp_tasks = loaded_yaml.pop("tasks", None)
        if tmp_tasks is None:
            tmp_tasks = {}

        for label, definition in tmp_tasks.items():
            if isinstance(definition, str):
                definition = {"class": definition}
            config = definition.get('config', None)
            if config is None:
                task_config_ir = None
            else:
                if isinstance(config, dict):
                    config = [config]
                task_config_ir = []
                for c in config:
                    file = c.pop("file", None)
                    if file is None:
                        file = []
                    elif not isinstance(file, list):
                        file = [file]
                    task_config_ir.append(ConfigIR(python=c.pop("python", None),
                                                   dataId=c.pop("dataId", None),
                                                   file=file,
                                                   rest=c))
            self.tasks[label] = TaskIR(label, definition["class"], task_config_ir)

    @classmethod
    def from_string(cls, pipeline_string: str):
        """Create a `PipelineIR` object from a string formatted like a
        pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document
        """
        loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader)
        return cls(loaded_yaml)

    @classmethod
    def from_file(cls, filename: str):
        """Create a `PipelineIR` object from the document specified by the
        input path.

        Parameters
        ----------
        filename : `str`
            Location of document to use in creating a `PipelineIR` object.
        """
        with open(filename, 'r') as f:
            loaded_yaml = yaml.load(f, Loader=PipelineYamlLoader)
        return cls(loaded_yaml)

    def to_file(self, filename: str):
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified path.

        Parameters
        ----------
        filename : `str`
            Location at which to write the serialized `PipelineIR` object.
        """
        with open(filename, 'w') as f:
            yaml.dump(self.to_primitives(), f, sort_keys=False)

    def to_primitives(self):
        """Convert to a representation used in yaml serialization.
        """
        accumulate = {"description": self.description}
        if self.instrument is not None:
            accumulate['instrument'] = self.instrument
        accumulate['tasks'] = {m: t.to_primitives() for m, t in self.tasks.items()}
        if len(self.contracts) > 0:
            accumulate['contracts'] = [c.to_primitives() for c in self.contracts]
        return accumulate

    def __str__(self) -> str:
        """Instance formatting as how it would look in yaml representation.
        """
        return yaml.dump(self.to_primitives(), sort_keys=False)

    def __repr__(self) -> str:
        """Instance formatting as how it would look in yaml representation.
        """
        return str(self)

    def __eq__(self, other: "PipelineIR"):
        if not isinstance(other, PipelineIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr)
                   for attr in ("contracts", "tasks", "instrument"))
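

# A minimal end-to-end sketch (not part of the original module): parse a small
# pipeline document and inspect the result. The task class is hypothetical.
_example_document = """
description: A trivial example pipeline
tasks:
  isr: hypothetical.module.IsrTask
"""
_pipeline = PipelineIR.from_string(_example_document)
assert _pipeline.tasks["isr"].klass == "hypothetical.module.IsrTask"
assert "description" in _pipeline.to_primitives()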