Coverage for python/lsst/pipe/base/pipelineIR.py: 19%


# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("ConfigIR", "ContractError", "ContractIR", "ImportIR", "PipelineIR", "TaskIR", "LabeledSubset")

import copy
import enum
import os
import re
import warnings
from collections import Counter
from collections.abc import Iterable as abcIterable
from dataclasses import dataclass, field
from typing import Any, Dict, Generator, List, Literal, Mapping, MutableMapping, Optional, Set, Union

import yaml
from deprecated.sphinx import deprecated
from lsst.resources import ResourcePath, ResourcePathExpression


class _Tags(enum.Enum):
    KeepInstrument = enum.auto()


class PipelineYamlLoader(yaml.SafeLoader):
    """A specialized version of yaml's SafeLoader that raises an exception
    if it finds multiple instances of the same key at a given scope in a
    pipeline file.
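
    Examples
    --------
    A minimal sketch of the duplication this loader rejects (the task label
    and class here are hypothetical)::

        tasks:
          taskA:
            class: mod.TaskA
          taskA:
            class: mod.TaskA

    Loading such a document with ``yaml.load(document,
    Loader=PipelineYamlLoader)`` raises a `KeyError` naming ``taskA``.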

48 """ 

49 

50 def construct_mapping(self, node: yaml.Node, deep: bool = False) -> Mapping[str, Any]: 

51 # do the call to super first so that it can do all the other forms of 

52 # checking on this node. If you check the uniqueness of keys first 

53 # it would save the work that super does in the case of a failure, but 

54 # it might fail in the case that the node was the incorrect node due 

55 # to a parsing error, and the resulting exception would be difficult to 

56 # understand. 

57 mapping = super().construct_mapping(node, deep) 

58 # Check if there are any duplicate keys 

59 all_keys = Counter(key_node.value for key_node, _ in node.value) 

60 duplicates = {k for k, i in all_keys.items() if i != 1} 

61 if duplicates: 

62 raise KeyError( 

63 f"Pipeline files must not have duplicated keys, {duplicates} appeared multiple times" 

64 ) 

65 return mapping 

66 

67 

68class ContractError(Exception): 

69 """An exception that is raised when a pipeline contract is not satisfied""" 

70 

71 pass 

72 

73 

74@dataclass 

75class ContractIR: 

76 """Intermediate representation of configuration contracts read from a 

77 pipeline yaml file.""" 

78 

79 contract: str 

80 """A string of python code representing one or more conditions on configs 

81 in a pipeline. This code-as-string should, once evaluated, should be True 

82 if the configs are fine, and False otherwise. 
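
    A hedged sketch of what such a condition might look like (the task
    labels and config field are hypothetical)::

        taskA.doFoo == taskB.doFoo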

83 """ 

84 msg: Union[str, None] = None 

85 """An optional message to be shown to the user if a contract fails 

86 """ 

87 

88 def to_primitives(self) -> Dict[str, str]: 

89 """Convert to a representation used in yaml serialization""" 

90 accumulate = {"contract": self.contract} 

91 if self.msg is not None: 

92 accumulate["msg"] = self.msg 

93 return accumulate 

94 

95 def __eq__(self, other: object) -> bool: 

96 if not isinstance(other, ContractIR): 

97 return False 

98 elif self.contract == other.contract and self.msg == other.msg: 

99 return True 

100 else: 

101 return False 

102 

103 

104@dataclass 

105class LabeledSubset: 

106 """Intermediate representation of named subset of task labels read from 

107 a pipeline yaml file. 

108 """ 

109 

110 label: str 

111 """The label used to identify the subset of task labels. 

112 """ 

113 subset: Set[str] 

114 """A set of task labels contained in this subset. 

115 """ 

116 description: Optional[str] 

117 """A description of what this subset of tasks is intended to do 

118 """ 

119 

120 @staticmethod 

121 def from_primitives(label: str, value: Union[List[str], dict]) -> LabeledSubset: 

122 """Generate `LabeledSubset` objects given a properly formatted object 

123 that as been created by a yaml loader. 

124 

125 Parameters 

126 ---------- 

127 label : `str` 

128 The label that will be used to identify this labeled subset. 

129 value : `list` of `str` or `dict` 

130 Object returned from loading a labeled subset section from a yaml 

131 document. 

132 

133 Returns 

134 ------- 

135 labeledSubset : `LabeledSubset` 

136 A `LabeledSubset` object build from the inputs. 

137 

138 Raises 

139 ------ 

140 ValueError 

141 Raised if the value input is not properly formatted for parsing 
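
        Examples
        --------
        A sketch of the two accepted forms, as they would appear in a
        pipeline document (the labels are hypothetical)::

            subsets:
              asList: [taskA, taskB]
              asMapping:
                subset: [taskA, taskB]
                description: An example subset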

142 """ 

143 if isinstance(value, MutableMapping): 

144 subset = value.pop("subset", None) 

145 if subset is None: 

146 raise ValueError( 

147 "If a labeled subset is specified as a mapping, it must contain the key 'subset'" 

148 ) 

149 description = value.pop("description", None) 

150 elif isinstance(value, abcIterable): 

151 subset = value 

152 description = None 

153 else: 

154 raise ValueError( 

155 f"There was a problem parsing the labeled subset {label}, make sure the " 

156 "definition is either a valid yaml list, or a mapping with keys " 

157 "(subset, description) where subset points to a yaml list, and description is " 

158 "associated with a string" 

159 ) 

160 return LabeledSubset(label, set(subset), description) 

161 

162 def to_primitives(self) -> Dict[str, Union[List[str], str]]: 

163 """Convert to a representation used in yaml serialization""" 

164 accumulate: Dict[str, Union[List[str], str]] = {"subset": list(self.subset)} 

165 if self.description is not None: 

166 accumulate["description"] = self.description 

167 return accumulate 

168 

169 

170@dataclass 

171class ParametersIR: 

172 """Intermediate representation of parameters that are global to a pipeline 

173 

174 These parameters are specified under a top level key named `parameters` 

175 and are declared as a yaml mapping. These entries can then be used inside 

176 task configuration blocks to specify configuration values. They may not be 

177 used in the special ``file`` or ``python`` blocks. 

178 

179 Example: 

180 paramters: 

181 shared_value: 14 

182 tasks: 

183 taskA: 

184 class: modA 

185 config: 

186 field1: parameters.shared_value 

187 taskB: 

188 class: modB 

189 config: 

190 field2: parameters.shared_value 

191 """ 

192 

193 mapping: MutableMapping[str, str] 

194 """A mutable mapping of identifiers as keys, and shared configuration 

195 as values. 

196 """ 

197 

198 def update(self, other: Optional[ParametersIR]) -> None: 

199 if other is not None: 

200 self.mapping.update(other.mapping) 

201 

202 def to_primitives(self) -> MutableMapping[str, str]: 

203 """Convert to a representation used in yaml serialization""" 

204 return self.mapping 

205 

206 def __contains__(self, value: str) -> bool: 

207 return value in self.mapping 

208 

209 def __getitem__(self, item: str) -> Any: 

210 return self.mapping[item] 

211 

212 def __bool__(self) -> bool: 

213 return bool(self.mapping) 

214 

215 

216@dataclass 

217class ConfigIR: 

218 """Intermediate representation of configurations read from a pipeline yaml 

219 file. 

220 """ 

221 

222 python: Union[str, None] = None 

223 """A string of python code that is used to modify a configuration. This can 

224 also be None if there are no modifications to do. 

225 """ 

226 dataId: Union[dict, None] = None 

227 """A dataId that is used to constrain these config overrides to only quanta 

228 with matching dataIds. This field can be None if there is no constraint. 

229 This is currently an unimplemented feature, and is placed here for future 

230 use. 

231 """ 

232 file: List[str] = field(default_factory=list) 

233 """A list of paths which points to a file containing config overrides to be 

234 applied. This value may be an empty list if there are no overrides to 

235 apply. 

236 """ 

237 rest: dict = field(default_factory=dict) 

238 """This is a dictionary of key value pairs, where the keys are strings 

239 corresponding to qualified fields on a config to override, and the values 

240 are strings representing the values to apply. 

241 """ 

242 

243 def to_primitives(self) -> Dict[str, Union[str, dict, List[str]]]: 

244 """Convert to a representation used in yaml serialization""" 

245 accumulate = {} 

246 for name in ("python", "dataId", "file"): 

247 # if this attribute is thruthy add it to the accumulation 

248 # dictionary 

249 if getattr(self, name): 

250 accumulate[name] = getattr(self, name) 

251 # Add the dictionary containing the rest of the config keys to the 

252 # # accumulated dictionary 

253 accumulate.update(self.rest) 

254 return accumulate 

255 

256 def formatted(self, parameters: ParametersIR) -> ConfigIR: 

257 """Returns a new ConfigIR object that is formatted according to the 

258 specified parameters 

259 

260 Parameters 

261 ---------- 

262 parameters : ParametersIR 

263 Object that contains variable mappings used in substitution. 

264 

265 Returns 

266 ------- 

267 config : ConfigIR 

268 A new ConfigIR object formatted with the input parameters 
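
        Examples
        --------
        A small doctest-style sketch of the substitution this performs (the
        parameter and field names are hypothetical)::

            >>> params = ParametersIR({"shared_value": "14"})
            >>> config = ConfigIR(rest={"field1": "parameters.shared_value"})
            >>> config.formatted(params).rest
            {'field1': '14'}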

269 """ 

270 new_config = copy.deepcopy(self) 

271 for key, value in new_config.rest.items(): 

272 if not isinstance(value, str): 

273 continue 

274 match = re.match("parameters[.](.*)", value) 

275 if match and match.group(1) in parameters: 

276 new_config.rest[key] = parameters[match.group(1)] 

277 if match and match.group(1) not in parameters: 

278 warnings.warn( 

279 f"config {key} contains value {match.group(0)} which is formatted like a " 

280 "Pipeline parameter but was not found within the Pipeline, if this was not " 

281 "intentional, check for a typo" 

282 ) 

283 return new_config 

284 

285 def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]: 

286 """Merges another instance of a `ConfigIR` into this instance if 

287 possible. This function returns a generator that is either self 

288 if the configs were merged, or self, and other_config if that could 

289 not be merged. 

290 

291 Parameters 

292 ---------- 

293 other_config : `ConfigIR` 

294 An instance of `ConfigIR` to merge into this instance. 

295 

296 Returns 

297 ------- 

298 Generator : `ConfigIR` 

299 A generator containing either self, or self and other_config if 

300 the configs could be merged or not respectively. 
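
        Examples
        --------
        A small sketch of the two outcomes (the field names are
        hypothetical)::

            >>> a = ConfigIR(rest={"field1": "1"})
            >>> b = ConfigIR(rest={"field2": "2"})
            >>> [c.rest for c in a.maybe_merge(b)]
            [{'field1': '1', 'field2': '2'}]

            >>> a = ConfigIR(rest={"field1": "1"})
            >>> b = ConfigIR(rest={"field1": "2"})
            >>> len(list(a.maybe_merge(b)))
            2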

301 """ 

302 # Verify that the config blocks can be merged 

303 if ( 

304 self.dataId != other_config.dataId 

305 or self.python 

306 or other_config.python 

307 or self.file 

308 or other_config.file 

309 ): 

310 yield from (self, other_config) 

311 return 

312 

313 # create a set of all keys, and verify two keys do not have different 

314 # values 

315 key_union = self.rest.keys() & other_config.rest.keys() 

316 for key in key_union: 

317 if self.rest[key] != other_config.rest[key]: 

318 yield from (self, other_config) 

319 return 

320 self.rest.update(other_config.rest) 

321 

322 # Combine the lists of override files to load 

323 self_file_set = set(self.file) 

324 other_file_set = set(other_config.file) 

325 self.file = list(self_file_set.union(other_file_set)) 

326 

327 yield self 

328 

329 def __eq__(self, other: object) -> bool: 

330 if not isinstance(other, ConfigIR): 

331 return False 

332 elif all( 

333 getattr(self, attr) == getattr(other, attr) for attr in ("python", "dataId", "file", "rest") 

334 ): 

335 return True 

336 else: 

337 return False 

338 

339 

340@dataclass 

341class TaskIR: 

342 """Intermediate representation of tasks read from a pipeline yaml file.""" 

343 

344 label: str 

345 """An identifier used to refer to a task. 

346 """ 

347 klass: str 

348 """A string containing a fully qualified python class to be run in a 

349 pipeline. 

350 """ 

351 config: Union[List[ConfigIR], None] = None 

352 """List of all configs overrides associated with this task, and may be 

353 `None` if there are no config overrides. 

354 """ 

355 

356 def to_primitives(self) -> Dict[str, Union[str, List[dict]]]: 

357 """Convert to a representation used in yaml serialization""" 

358 accumulate: Dict[str, Union[str, List[dict]]] = {"class": self.klass} 

359 if self.config: 

360 accumulate["config"] = [c.to_primitives() for c in self.config] 

361 return accumulate 

362 

363 def add_or_update_config(self, other_config: ConfigIR) -> None: 

364 """Adds a `ConfigIR` to this task if one is not present. Merges configs 

365 if there is a `ConfigIR` present and the dataId keys of both configs 

366 match, otherwise adds a new entry to the config list. The exception to 

367 the above is that if either the last config or other_config has a 

368 python block, then other_config is always added, as python blocks can 

369 modify configs in ways that cannot be predicted. 

370 

371 Parameters 

372 ---------- 

373 other_config : `ConfigIR` 

374 A `ConfigIR` instance to add or merge into the config attribute of 

375 this task. 

376 """ 

377 if not self.config: 

378 self.config = [other_config] 

379 return 

380 self.config.extend(self.config.pop().maybe_merge(other_config)) 

381 

382 def __eq__(self, other: object) -> bool: 

383 if not isinstance(other, TaskIR): 

384 return False 

385 elif all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config")): 

386 return True 

387 else: 

388 return False 

389 

390 

391@dataclass 

392class ImportIR: 

393 """An intermediate representation of imported pipelines""" 

    location: str
    """This is the location of the pipeline to inherit. The path should be
    specified as an absolute path. Environment variables may be used in the
    path and should be specified as a python string template, with the name of
    the environment variable inside braces.
    """
    include: Union[List[str], None] = None
    """List of tasks that should be included when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    exclude: Union[List[str], None] = None
    """List of tasks that should be excluded when inheriting this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    importContracts: bool = True
    """Boolean attribute that dictates whether contracts should be inherited
    with the pipeline or not.
    """
    instrument: Union[Literal[_Tags.KeepInstrument], str, None] = _Tags.KeepInstrument
    """Instrument to assign to the Pipeline at import. The default value of
    `_Tags.KeepInstrument` indicates that whatever instrument the pipeline is
    declared with will not be modified. Setting this value to None will drop
    any declared instrument prior to import.
    """

    def toPipelineIR(self) -> "PipelineIR":
        """Load in the Pipeline specified by this object, and turn it into a
        PipelineIR instance.

        Returns
        -------
        pipeline : `PipelineIR`
            A pipeline generated from the imported pipeline file.
        """
        if self.include and self.exclude:
            raise ValueError(
                "An include list and an exclude list cannot both be specified when declaring a "
                "pipeline import"
            )
        tmp_pipeline = PipelineIR.from_uri(os.path.expandvars(self.location))
        if self.instrument is not _Tags.KeepInstrument:
            tmp_pipeline.instrument = self.instrument

        included_labels = set()
        for label in tmp_pipeline.tasks:
            if (
                (self.include and label in self.include)
                or (self.exclude and label not in self.exclude)
                or (self.include is None and self.exclude is None)
            ):
                included_labels.add(label)

        # Handle labeled subsets being specified in the include or exclude
        # list, adding or removing labels.
        if self.include is not None:
            subsets_in_include = tmp_pipeline.labeled_subsets.keys() & self.include
            for label in subsets_in_include:
                included_labels.update(tmp_pipeline.labeled_subsets[label].subset)

        elif self.exclude is not None:
            subsets_in_exclude = tmp_pipeline.labeled_subsets.keys() & self.exclude
            for label in subsets_in_exclude:
                included_labels.difference_update(tmp_pipeline.labeled_subsets[label].subset)

        tmp_pipeline = tmp_pipeline.subset_from_labels(included_labels)

        if not self.importContracts:
            tmp_pipeline.contracts = []

        return tmp_pipeline

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ImportIR):
            return False
        elif all(
            getattr(self, attr) == getattr(other, attr)
            for attr in ("location", "include", "exclude", "importContracts")
        ):
            return True
        else:
            return False


class PipelineIR:
    """Intermediate representation of a pipeline definition.

    Parameters
    ----------
    loaded_yaml : `dict`
        A dictionary which matches the structure that would be produced by a
        yaml reader which parses a pipeline definition document.

    Raises
    ------
    ValueError
        Raised if:

        - a pipeline is declared without a description;
        - no tasks are declared in a pipeline, and no pipelines are to be
          inherited;
        - more than one instrument is specified;
        - more than one inherited pipeline shares a label.
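
    Examples
    --------
    A minimal sketch of a document this class can represent (the task label
    and class are hypothetical)::

        description: An example pipeline
        tasks:
          taskA:
            class: mod.TaskA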

494 """ 

495 

496 def __init__(self, loaded_yaml: Dict[str, Any]): 

497 # Check required fields are present 

498 if "description" not in loaded_yaml: 

499 raise ValueError("A pipeline must be declared with a description") 

500 if "tasks" not in loaded_yaml and len({"imports", "inherits"} - loaded_yaml.keys()) == 2: 

501 raise ValueError("A pipeline must be declared with one or more tasks") 

502 

503 # These steps below must happen in this call order 

504 

505 # Process pipeline description 

506 self.description = loaded_yaml.pop("description") 

507 

508 # Process tasks 

509 self._read_tasks(loaded_yaml) 

510 

511 # Process instrument keys 

512 inst = loaded_yaml.pop("instrument", None) 

513 if isinstance(inst, list): 

514 raise ValueError("Only one top level instrument can be defined in a pipeline") 

515 self.instrument: Optional[str] = inst 

516 

517 # Process any contracts 

518 self._read_contracts(loaded_yaml) 

519 

520 # Process any defined parameters 

521 self._read_parameters(loaded_yaml) 

522 

523 # Process any named label subsets 

524 self._read_labeled_subsets(loaded_yaml) 

525 

526 # Process any inherited pipelines 

527 self._read_imports(loaded_yaml) 

528 

529 # verify named subsets, must be done after inheriting 

530 self._verify_labeled_subsets() 

531 

532 def _read_contracts(self, loaded_yaml: Dict[str, Any]) -> None: 

533 """Process the contracts portion of the loaded yaml document 

534 

535 Parameters 

536 --------- 

537 loaded_yaml : `dict` 

538 A dictionary which matches the structure that would be produced by 

539 a yaml reader which parses a pipeline definition document 

540 """ 

541 loaded_contracts = loaded_yaml.pop("contracts", []) 

542 if isinstance(loaded_contracts, str): 

543 loaded_contracts = [loaded_contracts] 

544 self.contracts: List[ContractIR] = [] 

545 for contract in loaded_contracts: 

546 if isinstance(contract, dict): 

547 self.contracts.append(ContractIR(**contract)) 

548 if isinstance(contract, str): 

549 self.contracts.append(ContractIR(contract=contract)) 

550 

551 def _read_parameters(self, loaded_yaml: Dict[str, Any]) -> None: 

552 """Process the parameters portion of the loaded yaml document 

553 

554 Parameters 

555 --------- 

556 loaded_yaml : `dict` 

557 A dictionary which matches the structure that would be produced by 

558 a yaml reader which parses a pipeline definition document 

559 """ 

560 loaded_parameters = loaded_yaml.pop("parameters", {}) 

561 if not isinstance(loaded_parameters, dict): 

562 raise ValueError("The parameters section must be a yaml mapping") 

563 self.parameters = ParametersIR(loaded_parameters) 

564 

565 def _read_labeled_subsets(self, loaded_yaml: Dict[str, Any]) -> None: 

566 """Process the subsets portion of the loaded yaml document 

567 

568 Parameters 

569 ---------- 

570 loaded_yaml: `MutableMapping` 

571 A dictionary which matches the structure that would be produced 

572 by a yaml reader which parses a pipeline definition document 

573 """ 

574 loaded_subsets = loaded_yaml.pop("subsets", {}) 

575 self.labeled_subsets: Dict[str, LabeledSubset] = {} 

576 if not loaded_subsets and "subset" in loaded_yaml: 

577 raise ValueError("Top level key should be subsets and not subset, add an s") 

578 for key, value in loaded_subsets.items(): 

579 self.labeled_subsets[key] = LabeledSubset.from_primitives(key, value) 

580 

581 def _verify_labeled_subsets(self) -> None: 

582 """Verifies that all the labels in each named subset exist within the 

583 pipeline. 

584 """ 

585 # Verify that all labels defined in a labeled subset are in the 

586 # Pipeline 

587 for labeled_subset in self.labeled_subsets.values(): 

588 if not labeled_subset.subset.issubset(self.tasks.keys()): 

589 raise ValueError( 

590 f"Labels {labeled_subset.subset - self.tasks.keys()} were not found in the " 

591 "declared pipeline" 

592 ) 

593 # Verify subset labels are not already task labels 

594 label_intersection = self.labeled_subsets.keys() & self.tasks.keys() 

595 if label_intersection: 

596 raise ValueError(f"Labeled subsets can not use the same label as a task: {label_intersection}") 

597 

598 def _read_imports(self, loaded_yaml: Dict[str, Any]) -> None: 

599 """Process the inherits portion of the loaded yaml document 

600 

601 Parameters 

602 --------- 

603 loaded_yaml : `dict` 

604 A dictionary which matches the structure that would be produced by 

605 a yaml reader which parses a pipeline definition document 

606 """ 

607 

608 def process_args(argument: Union[str, dict]) -> dict: 

609 if isinstance(argument, str): 

610 return {"location": argument} 

611 elif isinstance(argument, dict): 

612 if "exclude" in argument and isinstance(argument["exclude"], str): 

613 argument["exclude"] = [argument["exclude"]] 

614 if "include" in argument and isinstance(argument["include"], str): 

615 argument["include"] = [argument["include"]] 

616 if "instrument" in argument and argument["instrument"] == "None": 

617 argument["instrument"] = None 

618 return argument 

619 

620 if not {"inherits", "imports"} - loaded_yaml.keys(): 

621 raise ValueError("Cannot define both inherits and imports sections, use imports") 

622 tmp_import = loaded_yaml.pop("inherits", None) 

623 if tmp_import is None: 

624 tmp_import = loaded_yaml.pop("imports", None) 

625 else: 

626 warnings.warn( 

627 "The 'inherits' key is deprecated, and will be " 

628 "removed around June 2021. Please use the key " 

629 "'imports' instead" 

630 ) 

631 if tmp_import is None: 

632 self.imports: List[ImportIR] = [] 

633 elif isinstance(tmp_import, list): 

634 self.imports = [ImportIR(**process_args(args)) for args in tmp_import] 

635 else: 

636 self.imports = [ImportIR(**process_args(tmp_import))] 

637 

638 # integrate any imported pipelines 

639 accumulate_tasks: Dict[str, TaskIR] = {} 

640 accumulate_labeled_subsets: Dict[str, LabeledSubset] = {} 

641 accumulated_parameters = ParametersIR({}) 

642 for other_pipeline in self.imports: 

643 tmp_IR = other_pipeline.toPipelineIR() 

644 if self.instrument is None: 

645 self.instrument = tmp_IR.instrument 

646 elif self.instrument != tmp_IR.instrument and tmp_IR.instrument is not None: 

647 msg = ( 

648 "Only one instrument can be declared in a pipeline or its imports. " 

649 f"Top level pipeline defines {self.instrument} but {other_pipeline.location} " 

650 f"defines {tmp_IR.instrument}." 

651 ) 

652 raise ValueError(msg) 

653 if duplicate_labels := accumulate_tasks.keys() & tmp_IR.tasks.keys(): 

654 msg = ( 

655 "Task labels in the imported pipelines must be unique. " 

656 f"These labels appear multiple times: {duplicate_labels}" 

657 ) 

658 raise ValueError(msg) 

659 accumulate_tasks.update(tmp_IR.tasks) 

660 self.contracts.extend(tmp_IR.contracts) 

661 # verify that tmp_IR has unique labels for named subset among 

662 # existing labeled subsets, and with existing task labels. 

663 overlapping_subsets = accumulate_labeled_subsets.keys() & tmp_IR.labeled_subsets.keys() 

664 task_subset_overlap = ( 

665 accumulate_labeled_subsets.keys() | tmp_IR.labeled_subsets.keys() 

666 ) & accumulate_tasks.keys() 

667 if overlapping_subsets or task_subset_overlap: 

668 raise ValueError( 

669 "Labeled subset names must be unique amongst imports in both labels and " 

670 f" named Subsets. Duplicate: {overlapping_subsets | task_subset_overlap}" 

671 ) 

672 accumulate_labeled_subsets.update(tmp_IR.labeled_subsets) 

673 accumulated_parameters.update(tmp_IR.parameters) 

674 

675 # verify that any accumulated labeled subsets dont clash with a label 

676 # from this pipeline 

677 if accumulate_labeled_subsets.keys() & self.tasks.keys(): 

678 raise ValueError( 

679 "Labeled subset names must be unique amongst imports in both labels and named Subsets" 

680 ) 

681 # merge in the named subsets for self so this document can override any 

682 # that have been delcared 

683 accumulate_labeled_subsets.update(self.labeled_subsets) 

684 self.labeled_subsets = accumulate_labeled_subsets 

685 

686 # merge the dict of label:TaskIR objects, preserving any configs in the 

687 # imported pipeline if the labels point to the same class 

688 for label, task in self.tasks.items(): 

689 if label not in accumulate_tasks: 

690 accumulate_tasks[label] = task 

691 elif accumulate_tasks[label].klass == task.klass: 

692 if task.config is not None: 

693 for config in task.config: 

694 accumulate_tasks[label].add_or_update_config(config) 

695 else: 

696 accumulate_tasks[label] = task 

697 self.tasks: Dict[str, TaskIR] = accumulate_tasks 

698 accumulated_parameters.update(self.parameters) 

699 self.parameters = accumulated_parameters 

700 

701 def _read_tasks(self, loaded_yaml: Dict[str, Any]) -> None: 

702 """Process the tasks portion of the loaded yaml document 

703 

704 Parameters 

705 --------- 

706 loaded_yaml : `dict` 

707 A dictionary which matches the structure that would be produced by 

708 a yaml reader which parses a pipeline definition document 

709 """ 

710 self.tasks = {} 

711 tmp_tasks = loaded_yaml.pop("tasks", None) 

712 if tmp_tasks is None: 

713 tmp_tasks = {} 

714 

715 if "parameters" in tmp_tasks: 

716 raise ValueError("parameters is a reserved word and cannot be used as a task label") 

717 

718 for label, definition in tmp_tasks.items(): 

719 if isinstance(definition, str): 

720 definition = {"class": definition} 

721 config = definition.get("config", None) 

722 if config is None: 

723 task_config_ir = None 

724 else: 

725 if isinstance(config, dict): 

726 config = [config] 

727 task_config_ir = [] 

728 for c in config: 

729 file = c.pop("file", None) 

730 if file is None: 

731 file = [] 

732 elif not isinstance(file, list): 

733 file = [file] 

734 task_config_ir.append( 

735 ConfigIR( 

736 python=c.pop("python", None), dataId=c.pop("dataId", None), file=file, rest=c 

737 ) 

738 ) 

739 self.tasks[label] = TaskIR(label, definition["class"], task_config_ir) 

740 

741 def _remove_contracts(self, label: str) -> None: 

742 """Remove any contracts that contain the given label 

743 

744 String comparison used in this way is not the most elegant and may 

745 have issues, but it is the only feasible way when users can specify 

746 contracts with generic strings. 
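
        For example (with hypothetical labels), removing the label ``taskA``
        would prune a contract written as ``taskA.doFoo == taskB.doFoo``,
        since the string ``taskA.`` appears within it.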

747 """ 

748 new_contracts = [] 

749 for contract in self.contracts: 

750 # match a label that is not preceded by an ASCII identifier, or 

751 # is the start of a line and is followed by a dot 

752 if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract): 

753 continue 

754 new_contracts.append(contract) 

755 self.contracts = new_contracts 

756 

757 def subset_from_labels(self, labelSpecifier: Set[str]) -> PipelineIR: 

758 """Subset a pipelineIR to contain only labels specified in 

759 labelSpecifier. 

760 

761 Parameters 

762 ---------- 

763 labelSpecifier : `set` of `str` 

764 Set containing labels that describes how to subset a pipeline. 

765 

766 Returns 

767 ------- 

768 pipeline : `PipelineIR` 

769 A new pipelineIR object that is a subset of the old pipelineIR 

770 

771 Raises 

772 ------ 

773 ValueError 

774 Raised if there is an issue with specified labels 

775 

776 Notes 

777 ----- 

778 This method attempts to prune any contracts that contain labels which 

779 are not in the declared subset of labels. This pruning is done using a 

780 string based matching due to the nature of contracts and may prune more 

781 than it should. Any labeled subsets defined that no longer have all 

782 members of the subset present in the pipeline will be removed from the 

783 resulting pipeline. 
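
        Examples
        --------
        A sketch with hypothetical labels: given a pipeline with tasks
        ``taskA`` and ``taskB`` and a labeled subset ``mySubset`` containing
        only ``taskA``, calling ``pipeline.subset_from_labels({"taskA"})``
        returns a pipeline with only ``taskA``, keeps ``mySubset``, and drops
        any contract mentioning ``taskB``.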

784 """ 

785 

786 pipeline = copy.deepcopy(self) 

787 

788 # update the label specifier to expand any named subsets 

789 toRemove = set() 

790 toAdd = set() 

791 for label in labelSpecifier: 

792 if label in pipeline.labeled_subsets: 

793 toRemove.add(label) 

794 toAdd.update(pipeline.labeled_subsets[label].subset) 

795 labelSpecifier.difference_update(toRemove) 

796 labelSpecifier.update(toAdd) 

797 # verify all the labels are in the pipeline 

798 if not labelSpecifier.issubset(pipeline.tasks.keys() | pipeline.labeled_subsets): 

799 difference = labelSpecifier.difference(pipeline.tasks.keys()) 

800 raise ValueError( 

801 "Not all supplied labels (specified or named subsets) are in the pipeline " 

802 f"definition, extra labels: {difference}" 

803 ) 

804 # copy needed so as to not modify while iterating 

805 pipeline_labels = set(pipeline.tasks.keys()) 

806 # Remove the labels from the pipelineIR, and any contracts that contain 

807 # those labels (see docstring on _remove_contracts for why this may 

808 # cause issues) 

809 for label in pipeline_labels: 

810 if label not in labelSpecifier: 

811 pipeline.tasks.pop(label) 

812 pipeline._remove_contracts(label) 

813 

814 # create a copy of the object to iterate over 

815 labeled_subsets = copy.copy(pipeline.labeled_subsets) 

816 # remove any labeled subsets that no longer have a complete set 

817 for label, labeled_subset in labeled_subsets.items(): 

818 if labeled_subset.subset - pipeline.tasks.keys(): 

819 pipeline.labeled_subsets.pop(label) 

820 

821 return pipeline 

822 

823 @classmethod 

824 def from_string(cls, pipeline_string: str) -> PipelineIR: 

825 """Create a `PipelineIR` object from a string formatted like a pipeline 

826 document 

827 

828 Parameters 

829 ---------- 

830 pipeline_string : `str` 

831 A string that is formatted according like a pipeline document 
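
        Examples
        --------
        A small sketch (the task label and class are hypothetical)::

            >>> pipeline = PipelineIR.from_string('''
            ... description: An example pipeline
            ... tasks:
            ...   taskA: mod.TaskA
            ... ''')
            >>> list(pipeline.tasks)
            ['taskA']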

832 """ 

833 loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader) 

834 return cls(loaded_yaml) 

835 

836 @classmethod 

837 @deprecated( 

838 reason="This has been replaced with `from_uri`. will be removed after v23", 

839 version="v21.0,", 

840 category=FutureWarning, 

841 ) 

842 def from_file(cls, filename: str) -> PipelineIR: 

843 """Create a `PipelineIR` object from the document specified by the 

844 input path. 

845 

846 Parameters 

847 ---------- 

848 filename : `str` 

849 Location of document to use in creating a `PipelineIR` object. 

850 

851 Returns 

852 ------- 

853 pipelineIR : `PipelineIR` 

854 The loaded pipeline 

855 

856 Note 

857 ---- 

858 This method is deprecated, please use from_uri 

859 """ 

860 return cls.from_uri(filename) 

861 

862 @classmethod 

863 def from_uri(cls, uri: ResourcePathExpression) -> PipelineIR: 

864 """Create a `PipelineIR` object from the document specified by the 

865 input uri. 

866 

867 Parameters 

868 ---------- 

869 uri: convertible to `ResourcePath` 

870 Location of document to use in creating a `PipelineIR` object. 

871 

872 Returns 

873 ------- 

874 pipelineIR : `PipelineIR` 

875 The loaded pipeline 

876 """ 

877 loaded_uri = ResourcePath(uri) 

878 with loaded_uri.open("r") as buffer: 

879 # explicitly read here, there was some issue with yaml trying 

880 # to read the ResourcePath itself (I think because it only 

881 # pretends to be conformant to the io api) 

882 loaded_yaml = yaml.load(buffer.read(), Loader=PipelineYamlLoader) 

883 return cls(loaded_yaml) 

884 

885 @deprecated( 

886 reason="This has been replaced with `write_to_uri`. will be removed after v23", 

887 version="v21.0,", 

888 category=FutureWarning, 

889 ) # type: ignore 

890 def to_file(self, filename: str): 

891 """Serialize this `PipelineIR` object into a yaml formatted string and 

892 write the output to a file at the specified path. 

893 

894 Parameters 

895 ---------- 

896 filename : `str` 

897 Location of document to write a `PipelineIR` object. 

898 """ 

899 self.write_to_uri(filename) 

900 

901 def write_to_uri( 

902 self, 

903 uri: ResourcePathExpression, 

904 ) -> None: 

905 """Serialize this `PipelineIR` object into a yaml formatted string and 

906 write the output to a file at the specified uri. 

907 

908 Parameters 

909 ---------- 

910 uri: convertible to `ResourcePath` 

911 Location of document to write a `PipelineIR` object. 

912 """ 

913 with ResourcePath(uri).open("w") as buffer: 

914 yaml.dump(self.to_primitives(), buffer, sort_keys=False) 

915 

916 def to_primitives(self) -> Dict[str, Any]: 

917 """Convert to a representation used in yaml serialization""" 

918 accumulate = {"description": self.description} 

919 if self.instrument is not None: 

920 accumulate["instrument"] = self.instrument 

921 if self.parameters: 

922 accumulate["parameters"] = self._sort_by_str(self.parameters.to_primitives()) 

923 accumulate["tasks"] = {m: t.to_primitives() for m, t in self.tasks.items()} 

924 if len(self.contracts) > 0: 

925 # sort contracts lexicographical order by the contract string in 

926 # absence of any other ordering principle 

927 contracts_list = [c.to_primitives() for c in self.contracts] 

928 contracts_list.sort(key=lambda x: x["contract"]) 

929 accumulate["contracts"] = contracts_list 

930 if self.labeled_subsets: 

931 accumulate["subsets"] = self._sort_by_str( 

932 {k: v.to_primitives() for k, v in self.labeled_subsets.items()} 

933 ) 

934 return accumulate 

935 

936 def reorder_tasks(self, task_labels: List[str]) -> None: 

937 """Changes the order tasks are stored internally. Useful for 

938 determining the order things will appear in the serialized (or printed) 

939 form. 

940 

941 Parameters 

942 ---------- 

943 task_labels : `list` of `str` 

944 A list corresponding to all the labels in the pipeline inserted in 

945 the order the tasks are to be stored. 

946 

947 Raises 

948 ------ 

949 KeyError 

950 Raised if labels are supplied that are not in the pipeline, or if 

951 not all labels in the pipeline were supplied in task_labels input. 
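
        Examples
        --------
        A small sketch with hypothetical labels: for a pipeline whose tasks
        are stored as ``{"taskB": ..., "taskA": ...}``, calling
        ``pipeline.reorder_tasks(["taskA", "taskB"])`` makes ``taskA``
        appear first in the serialized form.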

952 """ 

953 # verify that all labels are in the input 

954 _tmp_set = set(task_labels) 

955 if remainder := (self.tasks.keys() - _tmp_set): 

956 raise KeyError(f"Label(s) {remainder} are missing from the task label list") 

957 if extra := (_tmp_set - self.tasks.keys()): 

958 raise KeyError(f"Extra label(s) {extra} were in the input and are not in the pipeline") 

959 

960 newTasks = {key: self.tasks[key] for key in task_labels} 

961 self.tasks = newTasks 

962 

963 @staticmethod 

964 def _sort_by_str(arg: Mapping[str, Any]) -> Mapping[str, Any]: 

965 keys = sorted(arg.keys()) 

966 return {key: arg[key] for key in keys} 

967 

968 def __str__(self) -> str: 

969 """Instance formatting as how it would look in yaml representation""" 

970 return yaml.dump(self.to_primitives(), sort_keys=False) 

971 

972 def __repr__(self) -> str: 

973 """Instance formatting as how it would look in yaml representation""" 

974 return str(self) 

975 

976 def __eq__(self, other: object) -> bool: 

977 if not isinstance(other, PipelineIR): 

978 return False 

979 # special case contracts because it is a list, but order is not 

980 # important 

981 elif ( 

982 all( 

983 getattr(self, attr) == getattr(other, attr) 

984 for attr in ("tasks", "instrument", "labeled_subsets", "parameters") 

985 ) 

986 and len(self.contracts) == len(other.contracts) 

987 and all(c in self.contracts for c in other.contracts) 

988 ): 

989 return True 

990 else: 

991 return False