Coverage for python/lsst/pipe/base/pipelineIR.py: 19%

407 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-10 10:55 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("ConfigIR", "ContractError", "ContractIR", "ImportIR", "PipelineIR", "TaskIR", "LabeledSubset") 

24 

25import copy 

26import enum 

27import os 

28import re 

29import warnings 

30from collections import Counter 

31from collections.abc import Iterable as abcIterable 

32from dataclasses import dataclass, field 

33from typing import ( 

34 Any, 

35 Dict, 

36 Generator, 

37 Hashable, 

38 List, 

39 Literal, 

40 Mapping, 

41 MutableMapping, 

42 Optional, 

43 Set, 

44 Union, 

45) 

46 

47import yaml 

48from lsst.resources import ResourcePath, ResourcePathExpression 

49 

50 

51class _Tags(enum.Enum): 

52 KeepInstrument = enum.auto() 

53 

54 

55class PipelineYamlLoader(yaml.SafeLoader): 

56 """This is a specialized version of yaml's SafeLoader. It checks and raises 

57 an exception if it finds that there are multiple instances of the same key 

58 found inside a pipeline file at a given scope. 

59 """ 

60 

61 def construct_mapping(self, node: yaml.MappingNode, deep: bool = False) -> dict[Hashable, Any]: 

62 # do the call to super first so that it can do all the other forms of 

63 # checking on this node. If you check the uniqueness of keys first 

64 # it would save the work that super does in the case of a failure, but 

65 # it might fail in the case that the node was the incorrect node due 

66 # to a parsing error, and the resulting exception would be difficult to 

67 # understand. 

68 mapping = super().construct_mapping(node, deep) 

69 # Check if there are any duplicate keys 

70 all_keys = Counter(key_node.value for key_node, _ in node.value) 

71 duplicates = {k for k, i in all_keys.items() if i != 1} 

72 if duplicates: 

73 raise KeyError( 

74 f"Pipeline files must not have duplicated keys, {duplicates} appeared multiple times" 

75 ) 

76 return mapping 

77 

78 

79class MultilineStringDumper(yaml.Dumper): 

80 """Custom YAML dumper that makes multi-line strings use the '|' 

81 continuation style instead of unreadable newlines and tons of quotes. 

82 

83 Basic approach is taken from 

84 https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data, 

85 but is written as a Dumper subclass to make its effects non-global (vs 

86 `yaml.add_representer`). 

87 """ 

88 

89 def represent_scalar(self, tag: str, value: Any, style: Optional[str] = None) -> yaml.ScalarNode: 

90 if style is None and tag == "tag:yaml.org,2002:str" and len(value.splitlines()) > 1: 

91 style = "|" 

92 return super().represent_scalar(tag, value, style) 

93 

94 

95class ContractError(Exception): 

96 """An exception that is raised when a pipeline contract is not satisfied""" 

97 

98 pass 

99 

100 

101@dataclass 

102class ContractIR: 

103 """Intermediate representation of configuration contracts read from a 

104 pipeline yaml file.""" 

105 

106 contract: str 

107 """A string of python code representing one or more conditions on configs 

108 in a pipeline. This code-as-string should, once evaluated, should be True 

109 if the configs are fine, and False otherwise. 

110 """ 

111 msg: Union[str, None] = None 

112 """An optional message to be shown to the user if a contract fails 

113 """ 

114 

115 def to_primitives(self) -> Dict[str, str]: 

116 """Convert to a representation used in yaml serialization""" 

117 accumulate = {"contract": self.contract} 

118 if self.msg is not None: 

119 accumulate["msg"] = self.msg 

120 return accumulate 

121 

122 def __eq__(self, other: object) -> bool: 

123 if not isinstance(other, ContractIR): 

124 return False 

125 elif self.contract == other.contract and self.msg == other.msg: 

126 return True 

127 else: 

128 return False 

129 

130 

131@dataclass 

132class LabeledSubset: 

133 """Intermediate representation of named subset of task labels read from 

134 a pipeline yaml file. 

135 """ 

136 

137 label: str 

138 """The label used to identify the subset of task labels. 

139 """ 

140 subset: Set[str] 

141 """A set of task labels contained in this subset. 

142 """ 

143 description: Optional[str] 

144 """A description of what this subset of tasks is intended to do 

145 """ 

146 

147 @staticmethod 

148 def from_primitives(label: str, value: Union[List[str], dict]) -> LabeledSubset: 

149 """Generate `LabeledSubset` objects given a properly formatted object 

150 that as been created by a yaml loader. 

151 

152 Parameters 

153 ---------- 

154 label : `str` 

155 The label that will be used to identify this labeled subset. 

156 value : `list` of `str` or `dict` 

157 Object returned from loading a labeled subset section from a yaml 

158 document. 

159 

160 Returns 

161 ------- 

162 labeledSubset : `LabeledSubset` 

163 A `LabeledSubset` object build from the inputs. 

164 

165 Raises 

166 ------ 

167 ValueError 

168 Raised if the value input is not properly formatted for parsing 

169 """ 

170 if isinstance(value, MutableMapping): 

171 subset = value.pop("subset", None) 

172 if subset is None: 

173 raise ValueError( 

174 "If a labeled subset is specified as a mapping, it must contain the key 'subset'" 

175 ) 

176 description = value.pop("description", None) 

177 elif isinstance(value, abcIterable): 

178 subset = value 

179 description = None 

180 else: 

181 raise ValueError( 

182 f"There was a problem parsing the labeled subset {label}, make sure the " 

183 "definition is either a valid yaml list, or a mapping with keys " 

184 "(subset, description) where subset points to a yaml list, and description is " 

185 "associated with a string" 

186 ) 

187 return LabeledSubset(label, set(subset), description) 

188 

189 def to_primitives(self) -> Dict[str, Union[List[str], str]]: 

190 """Convert to a representation used in yaml serialization""" 

191 accumulate: Dict[str, Union[List[str], str]] = {"subset": list(self.subset)} 

192 if self.description is not None: 

193 accumulate["description"] = self.description 

194 return accumulate 

195 

196 

197@dataclass 

198class ParametersIR: 

199 """Intermediate representation of parameters that are global to a pipeline 

200 

201 These parameters are specified under a top level key named `parameters` 

202 and are declared as a yaml mapping. These entries can then be used inside 

203 task configuration blocks to specify configuration values. They may not be 

204 used in the special ``file`` or ``python`` blocks. 

205 

206 Example: 

207 paramters: 

208 shared_value: 14 

209 tasks: 

210 taskA: 

211 class: modA 

212 config: 

213 field1: parameters.shared_value 

214 taskB: 

215 class: modB 

216 config: 

217 field2: parameters.shared_value 

218 """ 

219 

220 mapping: MutableMapping[str, str] 

221 """A mutable mapping of identifiers as keys, and shared configuration 

222 as values. 

223 """ 

224 

225 def update(self, other: Optional[ParametersIR]) -> None: 

226 if other is not None: 

227 self.mapping.update(other.mapping) 

228 

229 def to_primitives(self) -> MutableMapping[str, str]: 

230 """Convert to a representation used in yaml serialization""" 

231 return self.mapping 

232 

233 def __contains__(self, value: str) -> bool: 

234 return value in self.mapping 

235 

236 def __getitem__(self, item: str) -> Any: 

237 return self.mapping[item] 

238 

239 def __bool__(self) -> bool: 

240 return bool(self.mapping) 

241 

242 

243@dataclass 

244class ConfigIR: 

245 """Intermediate representation of configurations read from a pipeline yaml 

246 file. 

247 """ 

248 

249 python: Union[str, None] = None 

250 """A string of python code that is used to modify a configuration. This can 

251 also be None if there are no modifications to do. 

252 """ 

253 dataId: Union[dict, None] = None 

254 """A dataId that is used to constrain these config overrides to only quanta 

255 with matching dataIds. This field can be None if there is no constraint. 

256 This is currently an unimplemented feature, and is placed here for future 

257 use. 

258 """ 

259 file: List[str] = field(default_factory=list) 

260 """A list of paths which points to a file containing config overrides to be 

261 applied. This value may be an empty list if there are no overrides to 

262 apply. 

263 """ 

264 rest: dict = field(default_factory=dict) 

265 """This is a dictionary of key value pairs, where the keys are strings 

266 corresponding to qualified fields on a config to override, and the values 

267 are strings representing the values to apply. 

268 """ 

269 

270 def to_primitives(self) -> Dict[str, Union[str, dict, List[str]]]: 

271 """Convert to a representation used in yaml serialization""" 

272 accumulate = {} 

273 for name in ("python", "dataId", "file"): 

274 # if this attribute is thruthy add it to the accumulation 

275 # dictionary 

276 if getattr(self, name): 

277 accumulate[name] = getattr(self, name) 

278 # Add the dictionary containing the rest of the config keys to the 

279 # # accumulated dictionary 

280 accumulate.update(self.rest) 

281 return accumulate 

282 

283 def formatted(self, parameters: ParametersIR) -> ConfigIR: 

284 """Returns a new ConfigIR object that is formatted according to the 

285 specified parameters 

286 

287 Parameters 

288 ---------- 

289 parameters : ParametersIR 

290 Object that contains variable mappings used in substitution. 

291 

292 Returns 

293 ------- 

294 config : ConfigIR 

295 A new ConfigIR object formatted with the input parameters 

296 """ 

297 new_config = copy.deepcopy(self) 

298 for key, value in new_config.rest.items(): 

299 if not isinstance(value, str): 

300 continue 

301 match = re.match("parameters[.](.*)", value) 

302 if match and match.group(1) in parameters: 

303 new_config.rest[key] = parameters[match.group(1)] 

304 if match and match.group(1) not in parameters: 

305 warnings.warn( 

306 f"config {key} contains value {match.group(0)} which is formatted like a " 

307 "Pipeline parameter but was not found within the Pipeline, if this was not " 

308 "intentional, check for a typo" 

309 ) 

310 return new_config 

311 

312 def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]: 

313 """Merges another instance of a `ConfigIR` into this instance if 

314 possible. This function returns a generator that is either self 

315 if the configs were merged, or self, and other_config if that could 

316 not be merged. 

317 

318 Parameters 

319 ---------- 

320 other_config : `ConfigIR` 

321 An instance of `ConfigIR` to merge into this instance. 

322 

323 Returns 

324 ------- 

325 Generator : `ConfigIR` 

326 A generator containing either self, or self and other_config if 

327 the configs could be merged or not respectively. 

328 """ 

329 # Verify that the config blocks can be merged 

330 if ( 

331 self.dataId != other_config.dataId 

332 or self.python 

333 or other_config.python 

334 or self.file 

335 or other_config.file 

336 ): 

337 yield from (self, other_config) 

338 return 

339 

340 # create a set of all keys, and verify two keys do not have different 

341 # values 

342 key_union = self.rest.keys() & other_config.rest.keys() 

343 for key in key_union: 

344 if self.rest[key] != other_config.rest[key]: 

345 yield from (self, other_config) 

346 return 

347 self.rest.update(other_config.rest) 

348 

349 # Combine the lists of override files to load 

350 self_file_set = set(self.file) 

351 other_file_set = set(other_config.file) 

352 self.file = list(self_file_set.union(other_file_set)) 

353 

354 yield self 

355 

356 def __eq__(self, other: object) -> bool: 

357 if not isinstance(other, ConfigIR): 

358 return False 

359 elif all( 

360 getattr(self, attr) == getattr(other, attr) for attr in ("python", "dataId", "file", "rest") 

361 ): 

362 return True 

363 else: 

364 return False 

365 

366 

367@dataclass 

368class TaskIR: 

369 """Intermediate representation of tasks read from a pipeline yaml file.""" 

370 

371 label: str 

372 """An identifier used to refer to a task. 

373 """ 

374 klass: str 

375 """A string containing a fully qualified python class to be run in a 

376 pipeline. 

377 """ 

378 config: Union[List[ConfigIR], None] = None 

379 """List of all configs overrides associated with this task, and may be 

380 `None` if there are no config overrides. 

381 """ 

382 

383 def to_primitives(self) -> Dict[str, Union[str, List[dict]]]: 

384 """Convert to a representation used in yaml serialization""" 

385 accumulate: Dict[str, Union[str, List[dict]]] = {"class": self.klass} 

386 if self.config: 

387 accumulate["config"] = [c.to_primitives() for c in self.config] 

388 return accumulate 

389 

390 def add_or_update_config(self, other_config: ConfigIR) -> None: 

391 """Adds a `ConfigIR` to this task if one is not present. Merges configs 

392 if there is a `ConfigIR` present and the dataId keys of both configs 

393 match, otherwise adds a new entry to the config list. The exception to 

394 the above is that if either the last config or other_config has a 

395 python block, then other_config is always added, as python blocks can 

396 modify configs in ways that cannot be predicted. 

397 

398 Parameters 

399 ---------- 

400 other_config : `ConfigIR` 

401 A `ConfigIR` instance to add or merge into the config attribute of 

402 this task. 

403 """ 

404 if not self.config: 

405 self.config = [other_config] 

406 return 

407 self.config.extend(self.config.pop().maybe_merge(other_config)) 

408 

409 def __eq__(self, other: object) -> bool: 

410 if not isinstance(other, TaskIR): 

411 return False 

412 elif all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config")): 

413 return True 

414 else: 

415 return False 

416 

417 

418@dataclass 

419class ImportIR: 

420 """An intermediate representation of imported pipelines""" 

421 

422 location: str 

423 """This is the location of the pipeline to inherit. The path should be 

424 specified as an absolute path. Environment variables may be used in the 

425 path and should be specified as a python string template, with the name of 

426 the environment variable inside braces. 

427 """ 

428 include: Union[List[str], None] = None 

429 """List of tasks that should be included when inheriting this pipeline. 

430 Either the include or exclude attributes may be specified, but not both. 

431 """ 

432 exclude: Union[List[str], None] = None 

433 """List of tasks that should be excluded when inheriting this pipeline. 

434 Either the include or exclude attributes may be specified, but not both. 

435 """ 

436 importContracts: bool = True 

437 """Boolean attribute to dictate if contracts should be inherited with the 

438 pipeline or not. 

439 """ 

440 instrument: Union[Literal[_Tags.KeepInstrument], str, None] = _Tags.KeepInstrument 

441 """Instrument to assign to the Pipeline at import. The default value of 

442 `_Tags.KeepInstrument`` indicates that whatever instrument the pipeline is 

443 declared with will not be modified. Setting this value to None will drop 

444 any declared instrument prior to import. 

445 """ 

446 

447 def toPipelineIR(self) -> "PipelineIR": 

448 """Load in the Pipeline specified by this object, and turn it into a 

449 PipelineIR instance. 

450 

451 Returns 

452 ------- 

453 pipeline : `PipelineIR` 

454 A pipeline generated from the imported pipeline file 

455 """ 

456 if self.include and self.exclude: 

457 raise ValueError( 

458 "Both an include and an exclude list cant be specified when declaring a pipeline import" 

459 ) 

460 tmp_pipeline = PipelineIR.from_uri(os.path.expandvars(self.location)) 

461 if self.instrument is not _Tags.KeepInstrument: 

462 tmp_pipeline.instrument = self.instrument 

463 

464 included_labels = set() 

465 for label in tmp_pipeline.tasks: 

466 if ( 

467 (self.include and label in self.include) 

468 or (self.exclude and label not in self.exclude) 

469 or (self.include is None and self.exclude is None) 

470 ): 

471 included_labels.add(label) 

472 

473 # Handle labeled subsets being specified in the include or exclude 

474 # list, adding or removing labels. 

475 if self.include is not None: 

476 subsets_in_include = tmp_pipeline.labeled_subsets.keys() & self.include 

477 for label in subsets_in_include: 

478 included_labels.update(tmp_pipeline.labeled_subsets[label].subset) 

479 

480 elif self.exclude is not None: 

481 subsets_in_exclude = tmp_pipeline.labeled_subsets.keys() & self.exclude 

482 for label in subsets_in_exclude: 

483 included_labels.difference_update(tmp_pipeline.labeled_subsets[label].subset) 

484 

485 tmp_pipeline = tmp_pipeline.subset_from_labels(included_labels) 

486 

487 if not self.importContracts: 

488 tmp_pipeline.contracts = [] 

489 

490 return tmp_pipeline 

491 

492 def __eq__(self, other: object) -> bool: 

493 if not isinstance(other, ImportIR): 

494 return False 

495 elif all( 

496 getattr(self, attr) == getattr(other, attr) 

497 for attr in ("location", "include", "exclude", "importContracts") 

498 ): 

499 return True 

500 else: 

501 return False 

502 

503 

504class PipelineIR: 

505 """Intermediate representation of a pipeline definition 

506 

507 Parameters 

508 ---------- 

509 loaded_yaml : `dict` 

510 A dictionary which matches the structure that would be produced by a 

511 yaml reader which parses a pipeline definition document 

512 

513 Raises 

514 ------ 

515 ValueError 

516 Raised if: 

517 

518 - a pipeline is declared without a description; 

519 - no tasks are declared in a pipeline, and no pipelines are to be 

520 inherited; 

521 - more than one instrument is specified; 

522 - more than one inherited pipeline share a label. 

523 """ 

524 

525 def __init__(self, loaded_yaml: Dict[str, Any]): 

526 # Check required fields are present 

527 if "description" not in loaded_yaml: 

528 raise ValueError("A pipeline must be declared with a description") 

529 if "tasks" not in loaded_yaml and len({"imports", "inherits"} - loaded_yaml.keys()) == 2: 

530 raise ValueError("A pipeline must be declared with one or more tasks") 

531 

532 # These steps below must happen in this call order 

533 

534 # Process pipeline description 

535 self.description = loaded_yaml.pop("description") 

536 

537 # Process tasks 

538 self._read_tasks(loaded_yaml) 

539 

540 # Process instrument keys 

541 inst = loaded_yaml.pop("instrument", None) 

542 if isinstance(inst, list): 

543 raise ValueError("Only one top level instrument can be defined in a pipeline") 

544 self.instrument: Optional[str] = inst 

545 

546 # Process any contracts 

547 self._read_contracts(loaded_yaml) 

548 

549 # Process any defined parameters 

550 self._read_parameters(loaded_yaml) 

551 

552 # Process any named label subsets 

553 self._read_labeled_subsets(loaded_yaml) 

554 

555 # Process any inherited pipelines 

556 self._read_imports(loaded_yaml) 

557 

558 # verify named subsets, must be done after inheriting 

559 self._verify_labeled_subsets() 

560 

561 def _read_contracts(self, loaded_yaml: Dict[str, Any]) -> None: 

562 """Process the contracts portion of the loaded yaml document 

563 

564 Parameters 

565 --------- 

566 loaded_yaml : `dict` 

567 A dictionary which matches the structure that would be produced by 

568 a yaml reader which parses a pipeline definition document 

569 """ 

570 loaded_contracts = loaded_yaml.pop("contracts", []) 

571 if isinstance(loaded_contracts, str): 

572 loaded_contracts = [loaded_contracts] 

573 self.contracts: List[ContractIR] = [] 

574 for contract in loaded_contracts: 

575 if isinstance(contract, dict): 

576 self.contracts.append(ContractIR(**contract)) 

577 if isinstance(contract, str): 

578 self.contracts.append(ContractIR(contract=contract)) 

579 

580 def _read_parameters(self, loaded_yaml: Dict[str, Any]) -> None: 

581 """Process the parameters portion of the loaded yaml document 

582 

583 Parameters 

584 --------- 

585 loaded_yaml : `dict` 

586 A dictionary which matches the structure that would be produced by 

587 a yaml reader which parses a pipeline definition document 

588 """ 

589 loaded_parameters = loaded_yaml.pop("parameters", {}) 

590 if not isinstance(loaded_parameters, dict): 

591 raise ValueError("The parameters section must be a yaml mapping") 

592 self.parameters = ParametersIR(loaded_parameters) 

593 

594 def _read_labeled_subsets(self, loaded_yaml: Dict[str, Any]) -> None: 

595 """Process the subsets portion of the loaded yaml document 

596 

597 Parameters 

598 ---------- 

599 loaded_yaml: `MutableMapping` 

600 A dictionary which matches the structure that would be produced 

601 by a yaml reader which parses a pipeline definition document 

602 """ 

603 loaded_subsets = loaded_yaml.pop("subsets", {}) 

604 self.labeled_subsets: Dict[str, LabeledSubset] = {} 

605 if not loaded_subsets and "subset" in loaded_yaml: 

606 raise ValueError("Top level key should be subsets and not subset, add an s") 

607 for key, value in loaded_subsets.items(): 

608 self.labeled_subsets[key] = LabeledSubset.from_primitives(key, value) 

609 

610 def _verify_labeled_subsets(self) -> None: 

611 """Verifies that all the labels in each named subset exist within the 

612 pipeline. 

613 """ 

614 # Verify that all labels defined in a labeled subset are in the 

615 # Pipeline 

616 for labeled_subset in self.labeled_subsets.values(): 

617 if not labeled_subset.subset.issubset(self.tasks.keys()): 

618 raise ValueError( 

619 f"Labels {labeled_subset.subset - self.tasks.keys()} were not found in the " 

620 "declared pipeline" 

621 ) 

622 # Verify subset labels are not already task labels 

623 label_intersection = self.labeled_subsets.keys() & self.tasks.keys() 

624 if label_intersection: 

625 raise ValueError(f"Labeled subsets can not use the same label as a task: {label_intersection}") 

626 

627 def _read_imports(self, loaded_yaml: Dict[str, Any]) -> None: 

628 """Process the inherits portion of the loaded yaml document 

629 

630 Parameters 

631 --------- 

632 loaded_yaml : `dict` 

633 A dictionary which matches the structure that would be produced by 

634 a yaml reader which parses a pipeline definition document 

635 """ 

636 

637 def process_args(argument: Union[str, dict]) -> dict: 

638 if isinstance(argument, str): 

639 return {"location": argument} 

640 elif isinstance(argument, dict): 

641 if "exclude" in argument and isinstance(argument["exclude"], str): 

642 argument["exclude"] = [argument["exclude"]] 

643 if "include" in argument and isinstance(argument["include"], str): 

644 argument["include"] = [argument["include"]] 

645 if "instrument" in argument and argument["instrument"] == "None": 

646 argument["instrument"] = None 

647 return argument 

648 

649 if not {"inherits", "imports"} - loaded_yaml.keys(): 

650 raise ValueError("Cannot define both inherits and imports sections, use imports") 

651 tmp_import = loaded_yaml.pop("inherits", None) 

652 if tmp_import is None: 

653 tmp_import = loaded_yaml.pop("imports", None) 

654 else: 

655 warnings.warn( 

656 "The 'inherits' key is deprecated, and will be " 

657 "removed around June 2021. Please use the key " 

658 "'imports' instead" 

659 ) 

660 if tmp_import is None: 

661 self.imports: List[ImportIR] = [] 

662 elif isinstance(tmp_import, list): 

663 self.imports = [ImportIR(**process_args(args)) for args in tmp_import] 

664 else: 

665 self.imports = [ImportIR(**process_args(tmp_import))] 

666 

667 # integrate any imported pipelines 

668 accumulate_tasks: Dict[str, TaskIR] = {} 

669 accumulate_labeled_subsets: Dict[str, LabeledSubset] = {} 

670 accumulated_parameters = ParametersIR({}) 

671 for other_pipeline in self.imports: 

672 tmp_IR = other_pipeline.toPipelineIR() 

673 if self.instrument is None: 

674 self.instrument = tmp_IR.instrument 

675 elif self.instrument != tmp_IR.instrument and tmp_IR.instrument is not None: 

676 msg = ( 

677 "Only one instrument can be declared in a pipeline or its imports. " 

678 f"Top level pipeline defines {self.instrument} but {other_pipeline.location} " 

679 f"defines {tmp_IR.instrument}." 

680 ) 

681 raise ValueError(msg) 

682 if duplicate_labels := accumulate_tasks.keys() & tmp_IR.tasks.keys(): 

683 msg = ( 

684 "Task labels in the imported pipelines must be unique. " 

685 f"These labels appear multiple times: {duplicate_labels}" 

686 ) 

687 raise ValueError(msg) 

688 accumulate_tasks.update(tmp_IR.tasks) 

689 self.contracts.extend(tmp_IR.contracts) 

690 # verify that tmp_IR has unique labels for named subset among 

691 # existing labeled subsets, and with existing task labels. 

692 overlapping_subsets = accumulate_labeled_subsets.keys() & tmp_IR.labeled_subsets.keys() 

693 task_subset_overlap = ( 

694 accumulate_labeled_subsets.keys() | tmp_IR.labeled_subsets.keys() 

695 ) & accumulate_tasks.keys() 

696 if overlapping_subsets or task_subset_overlap: 

697 raise ValueError( 

698 "Labeled subset names must be unique amongst imports in both labels and " 

699 f" named Subsets. Duplicate: {overlapping_subsets | task_subset_overlap}" 

700 ) 

701 accumulate_labeled_subsets.update(tmp_IR.labeled_subsets) 

702 accumulated_parameters.update(tmp_IR.parameters) 

703 

704 # verify that any accumulated labeled subsets dont clash with a label 

705 # from this pipeline 

706 if accumulate_labeled_subsets.keys() & self.tasks.keys(): 

707 raise ValueError( 

708 "Labeled subset names must be unique amongst imports in both labels and named Subsets" 

709 ) 

710 # merge in the named subsets for self so this document can override any 

711 # that have been delcared 

712 accumulate_labeled_subsets.update(self.labeled_subsets) 

713 self.labeled_subsets = accumulate_labeled_subsets 

714 

715 # merge the dict of label:TaskIR objects, preserving any configs in the 

716 # imported pipeline if the labels point to the same class 

717 for label, task in self.tasks.items(): 

718 if label not in accumulate_tasks: 

719 accumulate_tasks[label] = task 

720 elif accumulate_tasks[label].klass == task.klass: 

721 if task.config is not None: 

722 for config in task.config: 

723 accumulate_tasks[label].add_or_update_config(config) 

724 else: 

725 accumulate_tasks[label] = task 

726 self.tasks: Dict[str, TaskIR] = accumulate_tasks 

727 accumulated_parameters.update(self.parameters) 

728 self.parameters = accumulated_parameters 

729 

730 def _read_tasks(self, loaded_yaml: Dict[str, Any]) -> None: 

731 """Process the tasks portion of the loaded yaml document 

732 

733 Parameters 

734 --------- 

735 loaded_yaml : `dict` 

736 A dictionary which matches the structure that would be produced by 

737 a yaml reader which parses a pipeline definition document 

738 """ 

739 self.tasks = {} 

740 tmp_tasks = loaded_yaml.pop("tasks", None) 

741 if tmp_tasks is None: 

742 tmp_tasks = {} 

743 

744 if "parameters" in tmp_tasks: 

745 raise ValueError("parameters is a reserved word and cannot be used as a task label") 

746 

747 for label, definition in tmp_tasks.items(): 

748 if isinstance(definition, str): 

749 definition = {"class": definition} 

750 config = definition.get("config", None) 

751 if config is None: 

752 task_config_ir = None 

753 else: 

754 if isinstance(config, dict): 

755 config = [config] 

756 task_config_ir = [] 

757 for c in config: 

758 file = c.pop("file", None) 

759 if file is None: 

760 file = [] 

761 elif not isinstance(file, list): 

762 file = [file] 

763 task_config_ir.append( 

764 ConfigIR( 

765 python=c.pop("python", None), dataId=c.pop("dataId", None), file=file, rest=c 

766 ) 

767 ) 

768 self.tasks[label] = TaskIR(label, definition["class"], task_config_ir) 

769 

770 def _remove_contracts(self, label: str) -> None: 

771 """Remove any contracts that contain the given label 

772 

773 String comparison used in this way is not the most elegant and may 

774 have issues, but it is the only feasible way when users can specify 

775 contracts with generic strings. 

776 """ 

777 new_contracts = [] 

778 for contract in self.contracts: 

779 # match a label that is not preceded by an ASCII identifier, or 

780 # is the start of a line and is followed by a dot 

781 if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract): 

782 continue 

783 new_contracts.append(contract) 

784 self.contracts = new_contracts 

785 

786 def subset_from_labels(self, labelSpecifier: Set[str]) -> PipelineIR: 

787 """Subset a pipelineIR to contain only labels specified in 

788 labelSpecifier. 

789 

790 Parameters 

791 ---------- 

792 labelSpecifier : `set` of `str` 

793 Set containing labels that describes how to subset a pipeline. 

794 

795 Returns 

796 ------- 

797 pipeline : `PipelineIR` 

798 A new pipelineIR object that is a subset of the old pipelineIR 

799 

800 Raises 

801 ------ 

802 ValueError 

803 Raised if there is an issue with specified labels 

804 

805 Notes 

806 ----- 

807 This method attempts to prune any contracts that contain labels which 

808 are not in the declared subset of labels. This pruning is done using a 

809 string based matching due to the nature of contracts and may prune more 

810 than it should. Any labeled subsets defined that no longer have all 

811 members of the subset present in the pipeline will be removed from the 

812 resulting pipeline. 

813 """ 

814 

815 pipeline = copy.deepcopy(self) 

816 

817 # update the label specifier to expand any named subsets 

818 toRemove = set() 

819 toAdd = set() 

820 for label in labelSpecifier: 

821 if label in pipeline.labeled_subsets: 

822 toRemove.add(label) 

823 toAdd.update(pipeline.labeled_subsets[label].subset) 

824 labelSpecifier.difference_update(toRemove) 

825 labelSpecifier.update(toAdd) 

826 # verify all the labels are in the pipeline 

827 if not labelSpecifier.issubset(pipeline.tasks.keys() | pipeline.labeled_subsets): 

828 difference = labelSpecifier.difference(pipeline.tasks.keys()) 

829 raise ValueError( 

830 "Not all supplied labels (specified or named subsets) are in the pipeline " 

831 f"definition, extra labels: {difference}" 

832 ) 

833 # copy needed so as to not modify while iterating 

834 pipeline_labels = set(pipeline.tasks.keys()) 

835 # Remove the labels from the pipelineIR, and any contracts that contain 

836 # those labels (see docstring on _remove_contracts for why this may 

837 # cause issues) 

838 for label in pipeline_labels: 

839 if label not in labelSpecifier: 

840 pipeline.tasks.pop(label) 

841 pipeline._remove_contracts(label) 

842 

843 # create a copy of the object to iterate over 

844 labeled_subsets = copy.copy(pipeline.labeled_subsets) 

845 # remove any labeled subsets that no longer have a complete set 

846 for label, labeled_subset in labeled_subsets.items(): 

847 if labeled_subset.subset - pipeline.tasks.keys(): 

848 pipeline.labeled_subsets.pop(label) 

849 

850 return pipeline 

851 

852 @classmethod 

853 def from_string(cls, pipeline_string: str) -> PipelineIR: 

854 """Create a `PipelineIR` object from a string formatted like a pipeline 

855 document 

856 

857 Parameters 

858 ---------- 

859 pipeline_string : `str` 

860 A string that is formatted according like a pipeline document 

861 """ 

862 loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader) 

863 return cls(loaded_yaml) 

864 

865 @classmethod 

866 def from_uri(cls, uri: ResourcePathExpression) -> PipelineIR: 

867 """Create a `PipelineIR` object from the document specified by the 

868 input uri. 

869 

870 Parameters 

871 ---------- 

872 uri: convertible to `ResourcePath` 

873 Location of document to use in creating a `PipelineIR` object. 

874 

875 Returns 

876 ------- 

877 pipelineIR : `PipelineIR` 

878 The loaded pipeline 

879 """ 

880 loaded_uri = ResourcePath(uri) 

881 with loaded_uri.open("r") as buffer: 

882 loaded_yaml = yaml.load(buffer, Loader=PipelineYamlLoader) 

883 return cls(loaded_yaml) 

884 

885 def write_to_uri( 

886 self, 

887 uri: ResourcePathExpression, 

888 ) -> None: 

889 """Serialize this `PipelineIR` object into a yaml formatted string and 

890 write the output to a file at the specified uri. 

891 

892 Parameters 

893 ---------- 

894 uri: convertible to `ResourcePath` 

895 Location of document to write a `PipelineIR` object. 

896 """ 

897 with ResourcePath(uri).open("w") as buffer: 

898 yaml.dump(self.to_primitives(), buffer, sort_keys=False, Dumper=MultilineStringDumper) 

899 

900 def to_primitives(self) -> Dict[str, Any]: 

901 """Convert to a representation used in yaml serialization""" 

902 accumulate = {"description": self.description} 

903 if self.instrument is not None: 

904 accumulate["instrument"] = self.instrument 

905 if self.parameters: 

906 accumulate["parameters"] = self._sort_by_str(self.parameters.to_primitives()) 

907 accumulate["tasks"] = {m: t.to_primitives() for m, t in self.tasks.items()} 

908 if len(self.contracts) > 0: 

909 # sort contracts lexicographical order by the contract string in 

910 # absence of any other ordering principle 

911 contracts_list = [c.to_primitives() for c in self.contracts] 

912 contracts_list.sort(key=lambda x: x["contract"]) 

913 accumulate["contracts"] = contracts_list 

914 if self.labeled_subsets: 

915 accumulate["subsets"] = self._sort_by_str( 

916 {k: v.to_primitives() for k, v in self.labeled_subsets.items()} 

917 ) 

918 return accumulate 

919 

920 def reorder_tasks(self, task_labels: List[str]) -> None: 

921 """Changes the order tasks are stored internally. Useful for 

922 determining the order things will appear in the serialized (or printed) 

923 form. 

924 

925 Parameters 

926 ---------- 

927 task_labels : `list` of `str` 

928 A list corresponding to all the labels in the pipeline inserted in 

929 the order the tasks are to be stored. 

930 

931 Raises 

932 ------ 

933 KeyError 

934 Raised if labels are supplied that are not in the pipeline, or if 

935 not all labels in the pipeline were supplied in task_labels input. 

936 """ 

937 # verify that all labels are in the input 

938 _tmp_set = set(task_labels) 

939 if remainder := (self.tasks.keys() - _tmp_set): 

940 raise KeyError(f"Label(s) {remainder} are missing from the task label list") 

941 if extra := (_tmp_set - self.tasks.keys()): 

942 raise KeyError(f"Extra label(s) {extra} were in the input and are not in the pipeline") 

943 

944 newTasks = {key: self.tasks[key] for key in task_labels} 

945 self.tasks = newTasks 

946 

947 @staticmethod 

948 def _sort_by_str(arg: Mapping[str, Any]) -> Mapping[str, Any]: 

949 keys = sorted(arg.keys()) 

950 return {key: arg[key] for key in keys} 

951 

952 def __str__(self) -> str: 

953 """Instance formatting as how it would look in yaml representation""" 

954 return yaml.dump(self.to_primitives(), sort_keys=False, Dumper=MultilineStringDumper) 

955 

956 def __repr__(self) -> str: 

957 """Instance formatting as how it would look in yaml representation""" 

958 return str(self) 

959 

960 def __eq__(self, other: object) -> bool: 

961 if not isinstance(other, PipelineIR): 

962 return False 

963 # special case contracts because it is a list, but order is not 

964 # important 

965 elif ( 

966 all( 

967 getattr(self, attr) == getattr(other, attr) 

968 for attr in ("tasks", "instrument", "labeled_subsets", "parameters") 

969 ) 

970 and len(self.contracts) == len(other.contracts) 

971 and all(c in self.contracts for c in other.contracts) 

972 ): 

973 return True 

974 else: 

975 return False