# python/lsst/pipe/base/pipelineIR.py

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("ConfigIR", "ContractError", "ContractIR", "ImportIR", "PipelineIR", "TaskIR", "LabeledSubset")

import copy
import enum
import os
import re
import warnings
from collections import Counter
from collections.abc import Iterable as abcIterable
from dataclasses import dataclass, field
from typing import (
    Any,
    Dict,
    Generator,
    Hashable,
    Iterable,
    List,
    Literal,
    Mapping,
    MutableMapping,
    Optional,
    Set,
    Union,
)

import yaml
from lsst.resources import ResourcePath, ResourcePathExpression


class _Tags(enum.Enum):
    KeepInstrument = enum.auto()


class PipelineYamlLoader(yaml.SafeLoader):
    """This is a specialized version of yaml's SafeLoader. It checks for,
    and raises an exception on, multiple instances of the same key found
    inside a pipeline file at a given scope.
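
    Examples
    --------
    A minimal sketch of the duplicate-key check (the yaml content is
    illustrative):

    >>> import yaml
    >>> document = '''
    ... a: 1
    ... a: 2
    ... '''
    >>> yaml.load(document, Loader=PipelineYamlLoader)
    Traceback (most recent call last):
        ...
    KeyError: "Pipeline files must not have duplicated keys, {'a'} appeared multiple times"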

60 """ 

61 

62 def construct_mapping(self, node: yaml.MappingNode, deep: bool = False) -> dict[Hashable, Any]: 

63 # do the call to super first so that it can do all the other forms of 

64 # checking on this node. If you check the uniqueness of keys first 

65 # it would save the work that super does in the case of a failure, but 

66 # it might fail in the case that the node was the incorrect node due 

67 # to a parsing error, and the resulting exception would be difficult to 

68 # understand. 

69 mapping = super().construct_mapping(node, deep) 

70 # Check if there are any duplicate keys 

71 all_keys = Counter(key_node.value for key_node, _ in node.value) 

72 duplicates = {k for k, i in all_keys.items() if i != 1} 

73 if duplicates: 

74 raise KeyError( 

75 f"Pipeline files must not have duplicated keys, {duplicates} appeared multiple times" 

76 ) 

77 return mapping 


class MultilineStringDumper(yaml.Dumper):
    """Custom YAML dumper that makes multi-line strings use the '|'
    continuation style instead of unreadable newlines and tons of quotes.

    The basic approach is taken from
    https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data,
    but it is written as a Dumper subclass to make its effects non-global
    (vs `yaml.add_representer`).
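
    Examples
    --------
    A minimal sketch of the block-style output (the mapping is
    illustrative):

    >>> import yaml
    >>> two_lines = "line one" + chr(10) + "line two"
    >>> print(yaml.dump({"key": two_lines}, Dumper=MultilineStringDumper), end="")
    key: |-
      line one
      line two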

88 """ 

89 

90 def represent_scalar(self, tag: str, value: Any, style: Optional[str] = None) -> yaml.ScalarNode: 

91 if style is None and tag == "tag:yaml.org,2002:str" and len(value.splitlines()) > 1: 

92 style = "|" 

93 return super().represent_scalar(tag, value, style) 

94 

95 

96class ContractError(Exception): 

97 """An exception that is raised when a pipeline contract is not satisfied""" 

98 

99 pass 


@dataclass
class ContractIR:
    """Intermediate representation of configuration contracts read from a
    pipeline yaml file.
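
    Examples
    --------
    A minimal sketch of constructing one directly (the contract expression
    here is illustrative):

    >>> c = ContractIR(contract="taskA.field1 == taskB.field2", msg="fields must match")
    >>> c.to_primitives()
    {'contract': 'taskA.field1 == taskB.field2', 'msg': 'fields must match'}
    """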


    contract: str
    """A string of python code representing one or more conditions on configs
    in a pipeline. This code-as-string should, once evaluated, be `True` if
    the configs are fine, and `False` otherwise.
    """
    msg: Union[str, None] = None
    """An optional message to be shown to the user if a contract fails.
    """


    def to_primitives(self) -> Dict[str, str]:
        """Convert to a representation used in yaml serialization."""
        accumulate = {"contract": self.contract}
        if self.msg is not None:
            accumulate["msg"] = self.msg
        return accumulate

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ContractIR):
            return False
        elif self.contract == other.contract and self.msg == other.msg:
            return True
        else:
            return False


@dataclass
class LabeledSubset:
    """Intermediate representation of a named subset of task labels read
    from a pipeline yaml file.
    """

    label: str
    """The label used to identify the subset of task labels.
    """
    subset: Set[str]
    """A set of task labels contained in this subset.
    """
    description: Optional[str]
    """A description of what this subset of tasks is intended to do.
    """

    @staticmethod
    def from_primitives(label: str, value: Union[List[str], dict]) -> LabeledSubset:
        """Generate `LabeledSubset` objects given a properly formatted object
        that has been created by a yaml loader.

        Parameters
        ----------
        label : `str`
            The label that will be used to identify this labeled subset.
        value : `list` of `str` or `dict`
            Object returned from loading a labeled subset section from a yaml
            document.

        Returns
        -------
        labeledSubset : `LabeledSubset`
            A `LabeledSubset` object built from the inputs.

        Raises
        ------
        ValueError
            Raised if the value input is not properly formatted for parsing.
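
        Examples
        --------
        A sketch of the two accepted forms (the labels here are
        hypothetical):

        >>> LabeledSubset.from_primitives("demo", ["taskA", "taskB"]).subset == {"taskA", "taskB"}
        True
        >>> ls = LabeledSubset.from_primitives("demo", {"subset": ["taskA"], "description": "example"})
        >>> ls.description
        'example'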

170 """ 

171 if isinstance(value, MutableMapping): 

172 subset = value.pop("subset", None) 

173 if subset is None: 

174 raise ValueError( 

175 "If a labeled subset is specified as a mapping, it must contain the key 'subset'" 

176 ) 

177 description = value.pop("description", None) 

178 elif isinstance(value, abcIterable): 

179 subset = value 

180 description = None 

181 else: 

182 raise ValueError( 

183 f"There was a problem parsing the labeled subset {label}, make sure the " 

184 "definition is either a valid yaml list, or a mapping with keys " 

185 "(subset, description) where subset points to a yaml list, and description is " 

186 "associated with a string" 

187 ) 

188 return LabeledSubset(label, set(subset), description) 

189 

190 def to_primitives(self) -> Dict[str, Union[List[str], str]]: 

191 """Convert to a representation used in yaml serialization""" 

192 accumulate: Dict[str, Union[List[str], str]] = {"subset": list(self.subset)} 

193 if self.description is not None: 

194 accumulate["description"] = self.description 

195 return accumulate 


@dataclass
class ParametersIR:
    """Intermediate representation of parameters that are global to a
    pipeline.

    These parameters are specified under a top level key named `parameters`
    and are declared as a yaml mapping. These entries can then be used inside
    task configuration blocks to specify configuration values. They may not
    be used in the special ``file`` or ``python`` blocks.

    Example:

        parameters:
          shared_value: 14
        tasks:
          taskA:
            class: modA
            config:
              field1: parameters.shared_value
          taskB:
            class: modB
            config:
              field2: parameters.shared_value
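
    A sketch of how the resulting object behaves once constructed (the
    values here are illustrative):

    >>> p = ParametersIR({"shared_value": "14"})
    >>> "shared_value" in p
    True
    >>> p["shared_value"]
    '14'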

219 """ 

220 

221 mapping: MutableMapping[str, str] 

222 """A mutable mapping of identifiers as keys, and shared configuration 

223 as values. 

224 """ 

225 

226 def update(self, other: Optional[ParametersIR]) -> None: 

227 if other is not None: 

228 self.mapping.update(other.mapping) 

229 

230 def to_primitives(self) -> MutableMapping[str, str]: 

231 """Convert to a representation used in yaml serialization""" 

232 return self.mapping 

233 

234 def __contains__(self, value: str) -> bool: 

235 return value in self.mapping 

236 

237 def __getitem__(self, item: str) -> Any: 

238 return self.mapping[item] 

239 

240 def __bool__(self) -> bool: 

241 return bool(self.mapping) 

242 


@dataclass
class ConfigIR:
    """Intermediate representation of configurations read from a pipeline
    yaml file.
    """

    python: Union[str, None] = None
    """A string of python code that is used to modify a configuration. This
    can also be None if there are no modifications to do.
    """
    dataId: Union[dict, None] = None
    """A dataId that is used to constrain these config overrides to only
    quanta with matching dataIds. This field can be None if there is no
    constraint. This is currently an unimplemented feature, and is placed
    here for future use.
    """
    file: List[str] = field(default_factory=list)
    """A list of paths pointing to files containing config overrides to be
    applied. This value may be an empty list if there are no overrides to
    apply.
    """
    rest: dict = field(default_factory=dict)
    """This is a dictionary of key value pairs, where the keys are strings
    corresponding to qualified fields on a config to override, and the values
    are strings representing the values to apply.
    """


    def to_primitives(self) -> Dict[str, Union[str, dict, List[str]]]:
        """Convert to a representation used in yaml serialization."""
        accumulate = {}
        for name in ("python", "dataId", "file"):
            # If this attribute is truthy, add it to the accumulation
            # dictionary.
            if getattr(self, name):
                accumulate[name] = getattr(self, name)
        # Add the dictionary containing the rest of the config keys to the
        # accumulated dictionary.
        accumulate.update(self.rest)
        return accumulate


    def formatted(self, parameters: ParametersIR) -> ConfigIR:
        """Return a new ConfigIR object that is formatted according to the
        specified parameters.

        Parameters
        ----------
        parameters : ParametersIR
            Object that contains variable mappings used in substitution.

        Returns
        -------
        config : ConfigIR
            A new ConfigIR object formatted with the input parameters.
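
        Examples
        --------
        A minimal sketch of parameter substitution (names and values are
        illustrative):

        >>> params = ParametersIR({"value": "14"})
        >>> config = ConfigIR(rest={"field1": "parameters.value"})
        >>> config.formatted(params).rest
        {'field1': '14'}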

297 """ 

298 new_config = copy.deepcopy(self) 

299 for key, value in new_config.rest.items(): 

300 if not isinstance(value, str): 

301 continue 

302 match = re.match("parameters[.](.*)", value) 

303 if match and match.group(1) in parameters: 

304 new_config.rest[key] = parameters[match.group(1)] 

305 if match and match.group(1) not in parameters: 

306 warnings.warn( 

307 f"config {key} contains value {match.group(0)} which is formatted like a " 

308 "Pipeline parameter but was not found within the Pipeline, if this was not " 

309 "intentional, check for a typo" 

310 ) 

311 return new_config 


    def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]:
        """Merge another instance of a `ConfigIR` into this instance if
        possible. This function returns a generator that yields either just
        ``self`` if the configs were merged, or ``self`` and ``other_config``
        if they could not be merged.

        Parameters
        ----------
        other_config : `ConfigIR`
            An instance of `ConfigIR` to merge into this instance.

        Returns
        -------
        Generator : `ConfigIR`
            A generator yielding only ``self`` if the configs could be
            merged, or ``self`` and ``other_config`` if they could not.
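
        Examples
        --------
        A sketch of a successful merge (the field names are illustrative):

        >>> c1 = ConfigIR(rest={"fieldA": "1"})
        >>> c2 = ConfigIR(rest={"fieldB": "2"})
        >>> merged = list(c1.maybe_merge(c2))
        >>> len(merged)
        1
        >>> merged[0].rest == {"fieldA": "1", "fieldB": "2"}
        True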

329 """ 

330 # Verify that the config blocks can be merged 

331 if ( 

332 self.dataId != other_config.dataId 

333 or self.python 

334 or other_config.python 

335 or self.file 

336 or other_config.file 

337 ): 

338 yield from (self, other_config) 

339 return 

340 

341 # create a set of all keys, and verify two keys do not have different 

342 # values 

343 key_union = self.rest.keys() & other_config.rest.keys() 

344 for key in key_union: 

345 if self.rest[key] != other_config.rest[key]: 

346 yield from (self, other_config) 

347 return 

348 self.rest.update(other_config.rest) 

349 

350 # Combine the lists of override files to load 

351 self_file_set = set(self.file) 

352 other_file_set = set(other_config.file) 

353 self.file = list(self_file_set.union(other_file_set)) 

354 

355 yield self 

356 

357 def __eq__(self, other: object) -> bool: 

358 if not isinstance(other, ConfigIR): 

359 return False 

360 elif all( 

361 getattr(self, attr) == getattr(other, attr) for attr in ("python", "dataId", "file", "rest") 

362 ): 

363 return True 

364 else: 

365 return False 

366 


@dataclass
class TaskIR:
    """Intermediate representation of tasks read from a pipeline yaml file."""

    label: str
    """An identifier used to refer to a task.
    """
    klass: str
    """A string containing a fully qualified python class to be run in a
    pipeline.
    """
    config: Union[List[ConfigIR], None] = None
    """List of all config overrides associated with this task; may be
    `None` if there are no config overrides.
    """


    def to_primitives(self) -> Dict[str, Union[str, List[dict]]]:
        """Convert to a representation used in yaml serialization."""
        accumulate: Dict[str, Union[str, List[dict]]] = {"class": self.klass}
        if self.config:
            accumulate["config"] = [c.to_primitives() for c in self.config]
        return accumulate

    def add_or_update_config(self, other_config: ConfigIR) -> None:
        """Add a `ConfigIR` to this task if one is not present. Merge configs
        if there is a `ConfigIR` present and the dataId keys of both configs
        match, otherwise add a new entry to the config list. The exception to
        the above is that if either the last config or other_config has a
        python block, then other_config is always added, as python blocks can
        modify configs in ways that cannot be predicted.

        Parameters
        ----------
        other_config : `ConfigIR`
            A `ConfigIR` instance to add or merge into the config attribute
            of this task.
        """
        if not self.config:
            self.config = [other_config]
            return
        self.config.extend(self.config.pop().maybe_merge(other_config))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskIR):
            return False
        elif all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config")):
            return True
        else:
            return False


@dataclass
class ImportIR:
    """An intermediate representation of imported pipelines.


    location: str
    """This is the location of the pipeline to inherit. The path should be
    specified as an absolute path. Environment variables may be used in the
    path and should be specified as a python string template, with the name
    of the environment variable inside braces.
    """
    include: Union[List[str], None] = None
    """List of tasks that should be included when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    exclude: Union[List[str], None] = None
    """List of tasks that should be excluded when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    importContracts: bool = True
    """Boolean attribute to dictate if contracts should be inherited with
    the pipeline or not.
    """
    instrument: Union[Literal[_Tags.KeepInstrument], str, None] = _Tags.KeepInstrument
    """Instrument to assign to the Pipeline at import. The default value of
    ``_Tags.KeepInstrument`` indicates that whatever instrument the pipeline
    is declared with will not be modified. Setting this value to None will
    drop any declared instrument prior to import.
    """


    def toPipelineIR(self) -> "PipelineIR":
        """Load in the Pipeline specified by this object, and turn it into a
        PipelineIR instance.

        Returns
        -------
        pipeline : `PipelineIR`
            A pipeline generated from the imported pipeline file.
        """
        if self.include and self.exclude:
            raise ValueError(
                "An include list and an exclude list cannot both be specified when declaring a "
                "pipeline import"
            )
        tmp_pipeline = PipelineIR.from_uri(os.path.expandvars(self.location))
        if self.instrument is not _Tags.KeepInstrument:
            tmp_pipeline.instrument = self.instrument

        included_labels = set()
        for label in tmp_pipeline.tasks:
            if (
                (self.include and label in self.include)
                or (self.exclude and label not in self.exclude)
                or (self.include is None and self.exclude is None)
            ):
                included_labels.add(label)

        # Handle labeled subsets being specified in the include or exclude
        # list, adding or removing labels.
        if self.include is not None:
            subsets_in_include = tmp_pipeline.labeled_subsets.keys() & self.include
            for label in subsets_in_include:
                included_labels.update(tmp_pipeline.labeled_subsets[label].subset)

        elif self.exclude is not None:
            subsets_in_exclude = tmp_pipeline.labeled_subsets.keys() & self.exclude
            for label in subsets_in_exclude:
                included_labels.difference_update(tmp_pipeline.labeled_subsets[label].subset)

        tmp_pipeline = tmp_pipeline.subset_from_labels(included_labels)

        if not self.importContracts:
            tmp_pipeline.contracts = []

        return tmp_pipeline


    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ImportIR):
            return False
        elif all(
            getattr(self, attr) == getattr(other, attr)
            for attr in ("location", "include", "exclude", "importContracts")
        ):
            return True
        else:
            return False


class PipelineIR:
    """Intermediate representation of a pipeline definition.

    Parameters
    ----------
    loaded_yaml : `dict`
        A dictionary which matches the structure that would be produced by a
        yaml reader which parses a pipeline definition document.

    Raises
    ------
    ValueError
        Raised if:

        - a pipeline is declared without a description;
        - no tasks are declared in a pipeline, and no pipelines are to be
          inherited;
        - more than one instrument is specified;
        - more than one inherited pipeline share a label.
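
    Examples
    --------
    A minimal sketch of loading a pipeline document (the task class here is
    hypothetical):

    >>> pipeline = PipelineIR.from_string(
    ...     '''
    ... description: A demo pipeline
    ... tasks:
    ...   demoTask: lsst.pipe.tasks.demo.DemoTask
    ... '''
    ... )
    >>> list(pipeline.tasks)
    ['demoTask']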

524 """ 

525 

526 def __init__(self, loaded_yaml: Dict[str, Any]): 

527 # Check required fields are present 

528 if "description" not in loaded_yaml: 

529 raise ValueError("A pipeline must be declared with a description") 

530 if "tasks" not in loaded_yaml and len({"imports", "inherits"} - loaded_yaml.keys()) == 2: 

531 raise ValueError("A pipeline must be declared with one or more tasks") 

532 

533 # These steps below must happen in this call order 

534 

535 # Process pipeline description 

536 self.description = loaded_yaml.pop("description") 

537 

538 # Process tasks 

539 self._read_tasks(loaded_yaml) 

540 

541 # Process instrument keys 

542 inst = loaded_yaml.pop("instrument", None) 

543 if isinstance(inst, list): 

544 raise ValueError("Only one top level instrument can be defined in a pipeline") 

545 self.instrument: Optional[str] = inst 

546 

547 # Process any contracts 

548 self._read_contracts(loaded_yaml) 

549 

550 # Process any defined parameters 

551 self._read_parameters(loaded_yaml) 

552 

553 # Process any named label subsets 

554 self._read_labeled_subsets(loaded_yaml) 

555 

556 # Process any inherited pipelines 

557 self._read_imports(loaded_yaml) 

558 

559 # verify named subsets, must be done after inheriting 

560 self._verify_labeled_subsets() 

561 

    def _read_contracts(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the contracts portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document.
        """
        loaded_contracts = loaded_yaml.pop("contracts", [])
        if isinstance(loaded_contracts, str):
            loaded_contracts = [loaded_contracts]
        self.contracts: List[ContractIR] = []
        for contract in loaded_contracts:
            if isinstance(contract, dict):
                self.contracts.append(ContractIR(**contract))
            if isinstance(contract, str):
                self.contracts.append(ContractIR(contract=contract))

    def _read_parameters(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the parameters portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document.
        """
        loaded_parameters = loaded_yaml.pop("parameters", {})
        if not isinstance(loaded_parameters, dict):
            raise ValueError("The parameters section must be a yaml mapping")
        self.parameters = ParametersIR(loaded_parameters)

    def _read_labeled_subsets(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the subsets portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `MutableMapping`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document.
        """
        loaded_subsets = loaded_yaml.pop("subsets", {})
        self.labeled_subsets: Dict[str, LabeledSubset] = {}
        if not loaded_subsets and "subset" in loaded_yaml:
            raise ValueError("Top level key should be subsets and not subset, add an s")
        for key, value in loaded_subsets.items():
            self.labeled_subsets[key] = LabeledSubset.from_primitives(key, value)

    def _verify_labeled_subsets(self) -> None:
        """Verify that all the labels in each named subset exist within the
        pipeline.
        """
        # Verify that all labels defined in a labeled subset are in the
        # Pipeline.
        for labeled_subset in self.labeled_subsets.values():
            if not labeled_subset.subset.issubset(self.tasks.keys()):
                raise ValueError(
                    f"Labels {labeled_subset.subset - self.tasks.keys()} were not found in the "
                    "declared pipeline"
                )
        # Verify subset labels are not already task labels.
        label_intersection = self.labeled_subsets.keys() & self.tasks.keys()
        if label_intersection:
            raise ValueError(f"Labeled subsets cannot use the same label as a task: {label_intersection}")

    def _read_imports(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the imports (formerly inherits) portion of the loaded
        yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document.
        """

        def process_args(argument: Union[str, dict]) -> dict:
            if isinstance(argument, str):
                return {"location": argument}
            elif isinstance(argument, dict):
                if "exclude" in argument and isinstance(argument["exclude"], str):
                    argument["exclude"] = [argument["exclude"]]
                if "include" in argument and isinstance(argument["include"], str):
                    argument["include"] = [argument["include"]]
                if "instrument" in argument and argument["instrument"] == "None":
                    argument["instrument"] = None
                return argument

        if not {"inherits", "imports"} - loaded_yaml.keys():
            raise ValueError("Cannot define both inherits and imports sections, use imports")
        tmp_import = loaded_yaml.pop("inherits", None)
        if tmp_import is None:
            tmp_import = loaded_yaml.pop("imports", None)
        else:
            warnings.warn(
                "The 'inherits' key is deprecated, and will be "
                "removed around June 2021. Please use the key "
                "'imports' instead"
            )
        if tmp_import is None:
            self.imports: List[ImportIR] = []
        elif isinstance(tmp_import, list):
            self.imports = [ImportIR(**process_args(args)) for args in tmp_import]
        else:
            self.imports = [ImportIR(**process_args(tmp_import))]

        self.merge_pipelines([fragment.toPipelineIR() for fragment in self.imports])

    def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None:
        """Merge one or more other `PipelineIR` objects into this object.

        Parameters
        ----------
        pipelines : `Iterable` of `PipelineIR` objects
            An `Iterable` that contains one or more `PipelineIR` objects to
            merge into this object.

        Raises
        ------
        ValueError
            Raised if there is a conflict in instrument specifications.
            Raised if a task label appears in more than one of the input
            `PipelineIR` objects which are to be merged.
            Raised if a labeled subset appears in more than one of the input
            `PipelineIR` objects which are to be merged, or clashes with a
            subset already existing in this object.
        """
        # Integrate any imported pipelines.
        accumulate_tasks: Dict[str, TaskIR] = {}
        accumulate_labeled_subsets: Dict[str, LabeledSubset] = {}
        accumulated_parameters = ParametersIR({})

        for tmp_IR in pipelines:
            if self.instrument is None:
                self.instrument = tmp_IR.instrument
            elif self.instrument != tmp_IR.instrument and tmp_IR.instrument is not None:
                msg = (
                    "Only one instrument can be declared in a pipeline or its imports. "
                    f"Top level pipeline defines {self.instrument} but pipeline to merge "
                    f"defines {tmp_IR.instrument}."
                )
                raise ValueError(msg)
            if duplicate_labels := accumulate_tasks.keys() & tmp_IR.tasks.keys():
                msg = (
                    "Task labels in the imported pipelines must be unique. "
                    f"These labels appear multiple times: {duplicate_labels}"
                )
                raise ValueError(msg)
            accumulate_tasks.update(tmp_IR.tasks)
            self.contracts.extend(tmp_IR.contracts)
            # Verify that tmp_IR has unique labels for named subsets among
            # the existing labeled subsets, and among existing task labels.
            overlapping_subsets = accumulate_labeled_subsets.keys() & tmp_IR.labeled_subsets.keys()
            task_subset_overlap = (
                accumulate_labeled_subsets.keys() | tmp_IR.labeled_subsets.keys()
            ) & accumulate_tasks.keys()
            if overlapping_subsets or task_subset_overlap:
                raise ValueError(
                    "Labeled subset names must be unique amongst imports in both labels and "
                    f"named Subsets. Duplicate: {overlapping_subsets | task_subset_overlap}"
                )
            accumulate_labeled_subsets.update(tmp_IR.labeled_subsets)
            accumulated_parameters.update(tmp_IR.parameters)

        # Verify that any accumulated labeled subsets don't clash with a
        # label from this pipeline.
        if accumulate_labeled_subsets.keys() & self.tasks.keys():
            raise ValueError(
                "Labeled subset names must be unique amongst imports in both labels and named Subsets"
            )
        # Merge in the named subsets from self so this document can override
        # any that have been declared.
        accumulate_labeled_subsets.update(self.labeled_subsets)
        self.labeled_subsets = accumulate_labeled_subsets

        # Merge the dict of label:TaskIR objects, preserving any configs in
        # the imported pipeline if the labels point to the same class.
        for label, task in self.tasks.items():
            if label not in accumulate_tasks:
                accumulate_tasks[label] = task
            elif accumulate_tasks[label].klass == task.klass:
                if task.config is not None:
                    for config in task.config:
                        accumulate_tasks[label].add_or_update_config(config)
            else:
                accumulate_tasks[label] = task
        self.tasks: Dict[str, TaskIR] = accumulate_tasks
        accumulated_parameters.update(self.parameters)
        self.parameters = accumulated_parameters

    def _read_tasks(self, loaded_yaml: Dict[str, Any]) -> None:
        """Process the tasks portion of the loaded yaml document.

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document.
        """
        self.tasks = {}
        tmp_tasks = loaded_yaml.pop("tasks", None)
        if tmp_tasks is None:
            tmp_tasks = {}

        if "parameters" in tmp_tasks:
            raise ValueError("parameters is a reserved word and cannot be used as a task label")

        for label, definition in tmp_tasks.items():
            if isinstance(definition, str):
                definition = {"class": definition}
            config = definition.get("config", None)
            if config is None:
                task_config_ir = None
            else:
                if isinstance(config, dict):
                    config = [config]
                task_config_ir = []
                for c in config:
                    file = c.pop("file", None)
                    if file is None:
                        file = []
                    elif not isinstance(file, list):
                        file = [file]
                    task_config_ir.append(
                        ConfigIR(
                            python=c.pop("python", None), dataId=c.pop("dataId", None), file=file, rest=c
                        )
                    )
            self.tasks[label] = TaskIR(label, definition["class"], task_config_ir)

    def _remove_contracts(self, label: str) -> None:
        """Remove any contracts that contain the given label.

        String comparison used in this way is not the most elegant and may
        have issues, but it is the only feasible way when users can specify
        contracts with generic strings.
        """
        new_contracts = []
        for contract in self.contracts:
            # Match the label when it is not preceded by an ASCII identifier
            # character (or is at the start of the string) and is followed
            # by a dot.
            if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract):
                continue
            new_contracts.append(contract)
        self.contracts = new_contracts


    def subset_from_labels(self, labelSpecifier: Set[str]) -> PipelineIR:
        """Subset a pipelineIR to contain only labels specified in
        labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `set` of `str`
            Set containing labels that describes how to subset a pipeline.

        Returns
        -------
        pipeline : `PipelineIR`
            A new pipelineIR object that is a subset of the old pipelineIR.

        Raises
        ------
        ValueError
            Raised if there is an issue with the specified labels.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string-based matching due to the nature of contracts, and may prune
        more than it should. Any labeled subsets defined that no longer have
        all members of the subset present in the pipeline will be removed
        from the resulting pipeline.
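
        Examples
        --------
        A minimal sketch (the task classes here are hypothetical):

        >>> pipeline = PipelineIR.from_string(
        ...     '''
        ... description: demo
        ... tasks:
        ...   taskA: mod.TaskA
        ...   taskB: mod.TaskB
        ... '''
        ... )
        >>> subset = pipeline.subset_from_labels({"taskA"})
        >>> list(subset.tasks)
        ['taskA']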

835 """ 

836 

837 pipeline = copy.deepcopy(self) 

838 

839 # update the label specifier to expand any named subsets 

840 toRemove = set() 

841 toAdd = set() 

842 for label in labelSpecifier: 

843 if label in pipeline.labeled_subsets: 

844 toRemove.add(label) 

845 toAdd.update(pipeline.labeled_subsets[label].subset) 

846 labelSpecifier.difference_update(toRemove) 

847 labelSpecifier.update(toAdd) 

848 # verify all the labels are in the pipeline 

849 if not labelSpecifier.issubset(pipeline.tasks.keys() | pipeline.labeled_subsets): 

850 difference = labelSpecifier.difference(pipeline.tasks.keys()) 

851 raise ValueError( 

852 "Not all supplied labels (specified or named subsets) are in the pipeline " 

853 f"definition, extra labels: {difference}" 

854 ) 

855 # copy needed so as to not modify while iterating 

856 pipeline_labels = set(pipeline.tasks.keys()) 

857 # Remove the labels from the pipelineIR, and any contracts that contain 

858 # those labels (see docstring on _remove_contracts for why this may 

859 # cause issues) 

860 for label in pipeline_labels: 

861 if label not in labelSpecifier: 

862 pipeline.tasks.pop(label) 

863 pipeline._remove_contracts(label) 

864 

865 # create a copy of the object to iterate over 

866 labeled_subsets = copy.copy(pipeline.labeled_subsets) 

867 # remove any labeled subsets that no longer have a complete set 

868 for label, labeled_subset in labeled_subsets.items(): 

869 if labeled_subset.subset - pipeline.tasks.keys(): 

870 pipeline.labeled_subsets.pop(label) 

871 

872 return pipeline 


    @classmethod
    def from_string(cls, pipeline_string: str) -> PipelineIR:
        """Create a `PipelineIR` object from a string formatted like a
        pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document.
        """
        loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader)
        return cls(loaded_yaml)

    @classmethod
    def from_uri(cls, uri: ResourcePathExpression) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input uri.

        Parameters
        ----------
        uri : convertible to `ResourcePath`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline.
        """
        loaded_uri = ResourcePath(uri)
        with loaded_uri.open("r") as buffer:
            loaded_yaml = yaml.load(buffer, Loader=PipelineYamlLoader)
            return cls(loaded_yaml)

    def write_to_uri(
        self,
        uri: ResourcePathExpression,
    ) -> None:
        """Serialize this `PipelineIR` object into a yaml formatted string
        and write the output to a file at the specified uri.

        Parameters
        ----------
        uri : convertible to `ResourcePath`
            Location of document to write a `PipelineIR` object.
        """
        with ResourcePath(uri).open("w") as buffer:
            yaml.dump(self.to_primitives(), buffer, sort_keys=False, Dumper=MultilineStringDumper)


    def to_primitives(self) -> Dict[str, Any]:
        """Convert to a representation used in yaml serialization."""
        accumulate = {"description": self.description}
        if self.instrument is not None:
            accumulate["instrument"] = self.instrument
        if self.parameters:
            accumulate["parameters"] = self._sort_by_str(self.parameters.to_primitives())
        accumulate["tasks"] = {m: t.to_primitives() for m, t in self.tasks.items()}
        if len(self.contracts) > 0:
            # Sort contracts in lexicographical order by the contract string,
            # in the absence of any other ordering principle.
            contracts_list = [c.to_primitives() for c in self.contracts]
            contracts_list.sort(key=lambda x: x["contract"])
            accumulate["contracts"] = contracts_list
        if self.labeled_subsets:
            accumulate["subsets"] = self._sort_by_str(
                {k: v.to_primitives() for k, v in self.labeled_subsets.items()}
            )
        return accumulate

    def reorder_tasks(self, task_labels: List[str]) -> None:
        """Change the order in which tasks are stored internally. Useful for
        determining the order things will appear in the serialized (or
        printed) form.

        Parameters
        ----------
        task_labels : `list` of `str`
            A list containing all the labels in the pipeline, inserted in
            the order the tasks are to be stored.

        Raises
        ------
        KeyError
            Raised if labels are supplied that are not in the pipeline, or
            if not all labels in the pipeline were supplied in the
            task_labels input.
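
        Examples
        --------
        A minimal sketch (the task classes here are hypothetical):

        >>> pipeline = PipelineIR.from_string(
        ...     '''
        ... description: demo
        ... tasks:
        ...   taskA: mod.TaskA
        ...   taskB: mod.TaskB
        ... '''
        ... )
        >>> pipeline.reorder_tasks(["taskB", "taskA"])
        >>> list(pipeline.tasks)
        ['taskB', 'taskA']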

958 """ 

959 # verify that all labels are in the input 

960 _tmp_set = set(task_labels) 

961 if remainder := (self.tasks.keys() - _tmp_set): 

962 raise KeyError(f"Label(s) {remainder} are missing from the task label list") 

963 if extra := (_tmp_set - self.tasks.keys()): 

964 raise KeyError(f"Extra label(s) {extra} were in the input and are not in the pipeline") 

965 

966 newTasks = {key: self.tasks[key] for key in task_labels} 

967 self.tasks = newTasks 

968 

969 @staticmethod 

970 def _sort_by_str(arg: Mapping[str, Any]) -> Mapping[str, Any]: 

971 keys = sorted(arg.keys()) 

972 return {key: arg[key] for key in keys} 

973 

974 def __str__(self) -> str: 

975 """Instance formatting as how it would look in yaml representation""" 

976 return yaml.dump(self.to_primitives(), sort_keys=False, Dumper=MultilineStringDumper) 

977 

978 def __repr__(self) -> str: 

979 """Instance formatting as how it would look in yaml representation""" 

980 return str(self) 

981 

982 def __eq__(self, other: object) -> bool: 

983 if not isinstance(other, PipelineIR): 

984 return False 

985 # special case contracts because it is a list, but order is not 

986 # important 

987 elif ( 

988 all( 

989 getattr(self, attr) == getattr(other, attr) 

990 for attr in ("tasks", "instrument", "labeled_subsets", "parameters") 

991 ) 

992 and len(self.contracts) == len(other.contracts) 

993 and all(c in self.contracts for c in other.contracts) 

994 ): 

995 return True 

996 else: 

997 return False