Coverage for python/lsst/pipe/base/pipelineIR.py: 18%


# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ("ConfigIR", "ContractError", "ContractIR", "ImportIR", "PipelineIR", "TaskIR", "LabeledSubset")

from collections import Counter
from collections.abc import Iterable as abcIterable
from dataclasses import dataclass, field
from deprecated.sphinx import deprecated
from typing import Any, List, Set, Union, Generator, MutableMapping, Optional, Dict, Type, Mapping

import copy
import re
import os
import yaml
import warnings

from lsst.daf.butler import ButlerURI


class KeepInstrument:
    """Sentinel class indicating that the instrument declared in an imported
    pipeline should be kept unchanged.
    """
    pass


class PipelineYamlLoader(yaml.SafeLoader):
    """This is a specialized version of yaml's SafeLoader. It raises an
    exception if it finds multiple instances of the same key at a given
    scope inside a pipeline file.
    """
    def construct_mapping(self, node, deep=False):
        # Make the call to super first so that it can do all its other forms
        # of checking on this node. Checking the uniqueness of keys first
        # would save the work super does in the case of a failure, but if the
        # node were the wrong kind due to a parsing error, the resulting
        # exception would be difficult to understand.
        mapping = super().construct_mapping(node, deep)
        # Check if there are any duplicate keys
        all_keys = Counter(key_node.value for key_node, _ in node.value)
        duplicates = {k for k, i in all_keys.items() if i != 1}
        if duplicates:
            raise KeyError("Pipeline files must not have duplicated keys, "
                           f"{duplicates} appeared multiple times")
        return mapping
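
# A minimal sketch of the duplicate-key check above (not part of the original
# module): a repeated key at the same scope raises KeyError, whereas a plain
# yaml.SafeLoader would silently keep the last value.
#
#     yaml.load("a: 1\na: 2", Loader=PipelineYamlLoader)  # raises KeyError
#     yaml.load("a: 1\na: 2", Loader=yaml.SafeLoader)     # returns {'a': 2}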


class ContractError(Exception):
    """An exception that is raised when a pipeline contract is not satisfied
    """
    pass


@dataclass
class ContractIR:
    """Intermediate representation of contracts read from a pipeline yaml
    file.
    """
    contract: str
    """A string of python code representing one or more conditions on configs
    in a pipeline. This code-as-string should, once evaluated, be True if the
    configs are fine, and False otherwise.
    """
    msg: Union[str, None] = None
    """An optional message to be shown to the user if a contract fails
    """

    def to_primitives(self) -> Dict[str, str]:
        """Convert to a representation used in yaml serialization
        """
        accumulate = {"contract": self.contract}
        if self.msg is not None:
            accumulate['msg'] = self.msg
        return accumulate

    def __eq__(self, other: object):
        if not isinstance(other, ContractIR):
            return False
        return self.contract == other.contract and self.msg == other.msg
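
# A hedged sketch of how a contract appears in a pipeline yaml document (not
# part of the original module); the task labels and field names below are
# illustrative only:
#
#     contracts:
#       - contract: "taskA.field1 == taskB.field2"
#         msg: "taskA and taskB must agree on the shared field"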


@dataclass
class LabeledSubset:
    """Intermediate representation of named subset of task labels read from
    a pipeline yaml file.
    """
    label: str
    """The label used to identify the subset of task labels.
    """
    subset: Set[str]
    """A set of task labels contained in this subset.
    """
    description: Optional[str]
    """A description of what this subset of tasks is intended to do
    """

    @staticmethod
    def from_primitives(label: str, value: Union[List[str], dict]) -> LabeledSubset:
        """Generate `LabeledSubset` objects given a properly formatted object
        that has been created by a yaml loader.

        Parameters
        ----------
        label : `str`
            The label that will be used to identify this labeled subset.
        value : `list` of `str` or `dict`
            Object returned from loading a labeled subset section from a yaml
            document.

        Returns
        -------
        labeledSubset : `LabeledSubset`
            A `LabeledSubset` object built from the inputs.

        Raises
        ------
        ValueError
            Raised if the value input is not properly formatted for parsing
        """
        if isinstance(value, MutableMapping):
            subset = value.pop("subset", None)
            if subset is None:
                raise ValueError("If a labeled subset is specified as a mapping, it must contain the key "
                                 "'subset'")
            description = value.pop("description", None)
        elif isinstance(value, abcIterable):
            subset = value
            description = None
        else:
            raise ValueError(f"There was a problem parsing the labeled subset {label}; make sure the "
                             "definition is either a valid yaml list, or a mapping with keys "
                             "(subset, description) where subset points to a yaml list, and description is "
                             "associated with a string")
        return LabeledSubset(label, set(subset), description)

    def to_primitives(self) -> Dict[str, Union[List[str], str]]:
        """Convert to a representation used in yaml serialization
        """
        accumulate: Dict[str, Union[List[str], str]] = {"subset": list(self.subset)}
        if self.description is not None:
            accumulate["description"] = self.description
        return accumulate
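
# A hedged sketch of the two yaml forms from_primitives accepts (not part of
# the original module); the labels are illustrative only:
#
#     subsets:
#       # plain list form: no description
#       processCcd: [isr, characterizeImage, calibrate]
#       # mapping form: subset plus an optional description
#       coaddition:
#         subset: [makeWarp, assembleCoadd]
#         description: Tasks that assemble coadded images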


@dataclass
class ParametersIR:
    """Intermediate representation of parameters that are global to a pipeline

    These parameters are specified under a top level key named `parameters`
    and are declared as a yaml mapping. These entries can then be used inside
    task configuration blocks to specify configuration values. They may not be
    used in the special ``file`` or ``python`` blocks.

    Example:
        parameters:
          shared_value: 14
        tasks:
          taskA:
            class: modA
            config:
              field1: parameters.shared_value
          taskB:
            class: modB
            config:
              field2: parameters.shared_value
    """
    mapping: MutableMapping[str, str]
    """A mutable mapping of identifiers as keys, and shared configuration
    as values.
    """

    def update(self, other: Optional[ParametersIR]):
        if other is not None:
            self.mapping.update(other.mapping)

    def to_primitives(self) -> MutableMapping[str, str]:
        """Convert to a representation used in yaml serialization
        """
        return self.mapping

    def __contains__(self, value: str) -> bool:
        return value in self.mapping

    def __getitem__(self, item: str) -> Any:
        return self.mapping[item]

    def __bool__(self) -> bool:
        return bool(self.mapping)


@dataclass
class ConfigIR:
    """Intermediate representation of configurations read from a pipeline yaml
    file.
    """
    python: Union[str, None] = None
    """A string of python code that is used to modify a configuration. This
    can also be None if there are no modifications to do.
    """
    dataId: Union[dict, None] = None
    """A dataId that is used to constrain these config overrides to only
    quanta with matching dataIds. This field can be None if there is no
    constraint. This is currently an unimplemented feature, and is placed here
    for future use.
    """
    file: List[str] = field(default_factory=list)
    """A list of paths which point to files containing config overrides to be
    applied. This value may be an empty list if there are no overrides to
    apply.
    """
    rest: dict = field(default_factory=dict)
    """This is a dictionary of key value pairs, where the keys are strings
    corresponding to qualified fields on a config to override, and the values
    are strings representing the values to apply.
    """

    def to_primitives(self) -> Dict[str, Union[str, dict, List[str]]]:
        """Convert to a representation used in yaml serialization
        """
        accumulate = {}
        for name in ("python", "dataId", "file"):
            # If this attribute is truthy, add it to the accumulation
            # dictionary
            if getattr(self, name):
                accumulate[name] = getattr(self, name)
        # Add the dictionary containing the rest of the config keys to the
        # accumulated dictionary
        accumulate.update(self.rest)
        return accumulate

    def formatted(self, parameters: ParametersIR) -> ConfigIR:
        """Return a new ConfigIR object that is formatted according to the
        specified parameters.

        Parameters
        ----------
        parameters : ParametersIR
            Object that contains variable mappings used in substitution.

        Returns
        -------
        config : ConfigIR
            A new ConfigIR object formatted with the input parameters
        """
        new_config = copy.deepcopy(self)
        for key, value in new_config.rest.items():
            if not isinstance(value, str):
                continue
            match = re.match("parameters[.](.*)", value)
            if match and match.group(1) in parameters:
                new_config.rest[key] = parameters[match.group(1)]
            elif match:
                warnings.warn(f"config {key} contains value {match.group(0)} which is formatted like a "
                              "Pipeline parameter but was not found within the Pipeline; if this was not "
                              "intentional, check for a typo")
        return new_config
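
    # A minimal usage sketch of formatted (not part of the original module);
    # the field and parameter names are illustrative only:
    #
    #     params = ParametersIR({"shared_value": "14"})
    #     config = ConfigIR(rest={"field1": "parameters.shared_value"})
    #     config.formatted(params).rest  # -> {"field1": "14"}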

    def maybe_merge(self, other_config: "ConfigIR") -> Generator["ConfigIR", None, None]:
        """Merge another instance of a `ConfigIR` into this instance if
        possible. This function returns a generator that yields only self if
        the configs were merged, or self and other_config if they could not
        be merged.

        Parameters
        ----------
        other_config : `ConfigIR`
            An instance of `ConfigIR` to merge into this instance.

        Returns
        -------
        Generator : `ConfigIR`
            A generator containing either self, or self and other_config,
            depending on whether or not the configs could be merged.
        """
        # Verify that the config blocks can be merged
        if self.dataId != other_config.dataId or self.python or other_config.python or\
                self.file or other_config.file:
            yield from (self, other_config)
            return

        # Check that any keys shared by both configs do not have conflicting
        # values
        shared_keys = self.rest.keys() & other_config.rest.keys()
        for key in shared_keys:
            if self.rest[key] != other_config.rest[key]:
                yield from (self, other_config)
                return
        self.rest.update(other_config.rest)

        # Combine the lists of override files to load
        self_file_set = set(self.file)
        other_file_set = set(other_config.file)
        self.file = list(self_file_set.union(other_file_set))

        yield self
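
    # A minimal sketch of maybe_merge behavior (not part of the original
    # module); field names are illustrative only:
    #
    #     a = ConfigIR(rest={"x": "1"})
    #     b = ConfigIR(rest={"y": "2"})
    #     list(a.maybe_merge(b))  # -> [a], with a.rest == {"x": "1", "y": "2"}
    #
    #     c = ConfigIR(python="config.x = 1")
    #     list(a.maybe_merge(c))  # -> [a, c]; python blocks are never merged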

    def __eq__(self, other: object):
        if not isinstance(other, ConfigIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr) for attr in
                   ("python", "dataId", "file", "rest"))


@dataclass
class TaskIR:
    """Intermediate representation of tasks read from a pipeline yaml file.
    """
    label: str
    """An identifier used to refer to a task.
    """
    klass: str
    """A string containing a fully qualified python class to be run in a
    pipeline.
    """
    config: Union[List[ConfigIR], None] = None
    """List of all config overrides associated with this task; may be
    `None` if there are no config overrides.
    """

    def to_primitives(self) -> Dict[str, Union[str, List[dict]]]:
        """Convert to a representation used in yaml serialization
        """
        accumulate: Dict[str, Union[str, List[dict]]] = {'class': self.klass}
        if self.config:
            accumulate['config'] = [c.to_primitives() for c in self.config]
        return accumulate

    def add_or_update_config(self, other_config: ConfigIR):
        """Add a `ConfigIR` to this task if one is not present. If a
        `ConfigIR` is present and the dataId keys of both configs match, the
        configs are merged; otherwise a new entry is added to the config
        list. The exception to the above is that if either the last config or
        other_config has a python block, then other_config is always added,
        as python blocks can modify configs in ways that cannot be predicted.

        Parameters
        ----------
        other_config : `ConfigIR`
            A `ConfigIR` instance to add or merge into the config attribute of
            this task.
        """
        if not self.config:
            self.config = [other_config]
            return
        self.config.extend(self.config.pop().maybe_merge(other_config))

    def __eq__(self, other: object):
        if not isinstance(other, TaskIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr) for attr in
                   ("label", "klass", "config"))
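
# A minimal sketch of add_or_update_config (not part of the original module);
# the labels and field names are illustrative only:
#
#     task = TaskIR("taskA", "modA.TaskA")
#     task.add_or_update_config(ConfigIR(rest={"x": "1"}))
#     task.add_or_update_config(ConfigIR(rest={"y": "2"}))
#     len(task.config)  # -> 1; compatible override-only configs were merged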


@dataclass
class ImportIR:
    """An intermediate representation of imported pipelines
    """
    location: str
    """This is the location of the pipeline to import. The path should be
    specified as an absolute path. Environment variables may be used in the
    path and should be specified as a python string template, with the name of
    the environment variable inside braces.
    """
    include: Union[List[str], None] = None
    """List of tasks that should be included when importing this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    exclude: Union[List[str], None] = None
    """List of tasks that should be excluded when importing this pipeline.
    Either the include or exclude attributes may be specified, but not both.
    """
    importContracts: bool = True
    """Boolean attribute to dictate if contracts should be inherited with the
    pipeline or not.
    """
    instrument: Union[Type[KeepInstrument], str, None] = KeepInstrument
    """Instrument to assign to the Pipeline at import. The default value of
    `KeepInstrument` indicates that whatever instrument the pipeline is
    declared with will not be modified. Setting this value to None will drop
    any declared instrument prior to import.
    """

    def toPipelineIR(self) -> "PipelineIR":
        """Load in the Pipeline specified by this object, and turn it into a
        PipelineIR instance.

        Returns
        -------
        pipeline : `PipelineIR`
            A pipeline generated from the imported pipeline file
        """
        if self.include and self.exclude:
            raise ValueError("Both an include and an exclude list cannot be specified"
                             " when declaring a pipeline import")
        tmp_pipeline = PipelineIR.from_uri(os.path.expandvars(self.location))
        if self.instrument is not KeepInstrument:
            tmp_pipeline.instrument = self.instrument

        included_labels = set()
        for label in tmp_pipeline.tasks:
            if (self.include and label in self.include) or (self.exclude and label not in self.exclude)\
                    or (self.include is None and self.exclude is None):
                included_labels.add(label)

        # Handle labeled subsets being specified in the include or exclude
        # list, adding or removing labels.
        if self.include is not None:
            subsets_in_include = tmp_pipeline.labeled_subsets.keys() & self.include
            for label in subsets_in_include:
                included_labels.update(tmp_pipeline.labeled_subsets[label].subset)

        elif self.exclude is not None:
            subsets_in_exclude = tmp_pipeline.labeled_subsets.keys() & self.exclude
            for label in subsets_in_exclude:
                included_labels.difference_update(tmp_pipeline.labeled_subsets[label].subset)

        tmp_pipeline = tmp_pipeline.subset_from_labels(included_labels)

        if not self.importContracts:
            tmp_pipeline.contracts = []

        return tmp_pipeline

    def __eq__(self, other: object):
        if not isinstance(other, ImportIR):
            return False
        return all(getattr(self, attr) == getattr(other, attr) for attr in
                   ("location", "include", "exclude", "importContracts"))
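
# A hedged sketch of an imports section that ImportIR represents (not part of
# the original module); the path and labels are illustrative only:
#
#     imports:
#       - location: ${MY_PIPELINES_DIR}/base_pipeline.yaml
#         include: [taskA, taskB]
#         importContracts: false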


class PipelineIR:
    """Intermediate representation of a pipeline definition

    Parameters
    ----------
    loaded_yaml : `dict`
        A dictionary which matches the structure that would be produced by a
        yaml reader which parses a pipeline definition document

    Raises
    ------
    ValueError
        Raised if:
        - a pipeline is declared without a description;
        - no tasks are declared in a pipeline, and no pipelines are to be
          imported;
        - more than one instrument is specified;
        - more than one imported pipeline shares a label.
    """
    def __init__(self, loaded_yaml):
        # Check required fields are present
        if "description" not in loaded_yaml:
            raise ValueError("A pipeline must be declared with a description")
        if "tasks" not in loaded_yaml and len({"imports", "inherits"} - loaded_yaml.keys()) == 2:
            raise ValueError("A pipeline must be declared with one or more tasks")

        # These steps below must happen in this call order

        # Process pipeline description
        self.description = loaded_yaml.pop("description")

        # Process tasks
        self._read_tasks(loaded_yaml)

        # Process instrument keys
        inst = loaded_yaml.pop("instrument", None)
        if isinstance(inst, list):
            raise ValueError("Only one top level instrument can be defined in a pipeline")
        self.instrument = inst

        # Process any contracts
        self._read_contracts(loaded_yaml)

        # Process any defined parameters
        self._read_parameters(loaded_yaml)

        # Process any named label subsets
        self._read_labeled_subsets(loaded_yaml)

        # Process any imported pipelines
        self._read_imports(loaded_yaml)

        # Verify named subsets. This must be done after importing.
        self._verify_labeled_subsets()

    def _read_contracts(self, loaded_yaml):
        """Process the contracts portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        loaded_contracts = loaded_yaml.pop("contracts", [])
        if isinstance(loaded_contracts, str):
            loaded_contracts = [loaded_contracts]
        self.contracts = []
        for contract in loaded_contracts:
            if isinstance(contract, dict):
                self.contracts.append(ContractIR(**contract))
            elif isinstance(contract, str):
                self.contracts.append(ContractIR(contract=contract))

    def _read_parameters(self, loaded_yaml):
        """Process the parameters portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        loaded_parameters = loaded_yaml.pop("parameters", {})
        if not isinstance(loaded_parameters, dict):
            raise ValueError("The parameters section must be a yaml mapping")
        self.parameters = ParametersIR(loaded_parameters)

    def _read_labeled_subsets(self, loaded_yaml: dict):
        """Process the subsets portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced
            by a yaml reader which parses a pipeline definition document
        """
        loaded_subsets = loaded_yaml.pop("subsets", {})
        self.labeled_subsets = {}
        if not loaded_subsets and "subset" in loaded_yaml:
            raise ValueError("Top level key should be subsets and not subset, add an s")
        for key, value in loaded_subsets.items():
            self.labeled_subsets[key] = LabeledSubset.from_primitives(key, value)

    def _verify_labeled_subsets(self):
        """Verify that all the labels in each named subset exist within the
        pipeline.
        """
        # Verify that all labels defined in a labeled subset are in the
        # Pipeline
        for labeled_subset in self.labeled_subsets.values():
            if not labeled_subset.subset.issubset(self.tasks.keys()):
                raise ValueError(f"Labels {labeled_subset.subset - self.tasks.keys()} were not found in the "
                                 "declared pipeline")
        # Verify subset labels are not already task labels
        label_intersection = self.labeled_subsets.keys() & self.tasks.keys()
        if label_intersection:
            raise ValueError(f"Labeled subsets can not use the same label as a task: {label_intersection}")

    def _read_imports(self, loaded_yaml):
        """Process the imports portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        def process_args(argument: Union[str, dict]) -> dict:
            if isinstance(argument, str):
                return {"location": argument}
            elif isinstance(argument, dict):
                if "exclude" in argument and isinstance(argument["exclude"], str):
                    argument["exclude"] = [argument["exclude"]]
                if "include" in argument and isinstance(argument["include"], str):
                    argument["include"] = [argument["include"]]
                if "instrument" in argument and argument["instrument"] == "None":
                    argument["instrument"] = None
                return argument
        if not {"inherits", "imports"} - loaded_yaml.keys():
            raise ValueError("Cannot define both inherits and imports sections, use imports")
        tmp_import = loaded_yaml.pop("inherits", None)
        if tmp_import is None:
            tmp_import = loaded_yaml.pop("imports", None)
        else:
            warnings.warn("The 'inherits' key is deprecated, and will be "
                          "removed around June 2021. Please use the key "
                          "'imports' instead")
        if tmp_import is None:
            self.imports = []
        elif isinstance(tmp_import, list):
            self.imports = [ImportIR(**process_args(args)) for args in tmp_import]
        else:
            self.imports = [ImportIR(**process_args(tmp_import))]

        # Integrate any imported pipelines
        accumulate_tasks = {}
        accumulate_labeled_subsets = {}
        accumulated_parameters = ParametersIR({})
        for other_pipeline in self.imports:
            tmp_IR = other_pipeline.toPipelineIR()
            if self.instrument is None:
                self.instrument = tmp_IR.instrument
            elif self.instrument != tmp_IR.instrument and tmp_IR.instrument is not None:
                msg = ("Only one instrument can be declared in a pipeline or its imports. "
                       f"Top level pipeline defines {self.instrument} but {other_pipeline.location} "
                       f"defines {tmp_IR.instrument}.")
                raise ValueError(msg)
            if duplicate_labels := accumulate_tasks.keys() & tmp_IR.tasks.keys():
                msg = ("Task labels in the imported pipelines must be unique. "
                       f"These labels appear multiple times: {duplicate_labels}")
                raise ValueError(msg)
            accumulate_tasks.update(tmp_IR.tasks)
            self.contracts.extend(tmp_IR.contracts)
            # Verify that tmp_IR's named subset labels are unique among the
            # existing labeled subsets and the existing task labels.
            overlapping_subsets = accumulate_labeled_subsets.keys() & tmp_IR.labeled_subsets.keys()
            task_subset_overlap = ((accumulate_labeled_subsets.keys() | tmp_IR.labeled_subsets.keys())
                                   & accumulate_tasks.keys())
            if overlapping_subsets or task_subset_overlap:
                raise ValueError("Labeled subset names must be unique amongst imports in both labels and "
                                 f"named subsets. Duplicate: {overlapping_subsets | task_subset_overlap}")
            accumulate_labeled_subsets.update(tmp_IR.labeled_subsets)
            accumulated_parameters.update(tmp_IR.parameters)

        # Verify that any accumulated labeled subsets don't clash with a label
        # from this pipeline
        if accumulate_labeled_subsets.keys() & self.tasks.keys():
            raise ValueError("Labeled subset names must be unique amongst imports in both labels and "
                             "named subsets")
        # Merge in the named subsets for self so this document can override
        # any that have been declared
        accumulate_labeled_subsets.update(self.labeled_subsets)
        self.labeled_subsets = accumulate_labeled_subsets

        # Merge the dict of label:TaskIR objects, preserving any configs in
        # the imported pipeline if the labels point to the same class
        for label, task in self.tasks.items():
            if label not in accumulate_tasks:
                accumulate_tasks[label] = task
            elif accumulate_tasks[label].klass == task.klass:
                if task.config is not None:
                    for config in task.config:
                        accumulate_tasks[label].add_or_update_config(config)
            else:
                accumulate_tasks[label] = task
        self.tasks = accumulate_tasks
        accumulated_parameters.update(self.parameters)
        self.parameters = accumulated_parameters

    def _read_tasks(self, loaded_yaml):
        """Process the tasks portion of the loaded yaml document

        Parameters
        ----------
        loaded_yaml : `dict`
            A dictionary which matches the structure that would be produced by
            a yaml reader which parses a pipeline definition document
        """
        self.tasks = {}
        tmp_tasks = loaded_yaml.pop("tasks", None)
        if tmp_tasks is None:
            tmp_tasks = {}

        if "parameters" in tmp_tasks:
            raise ValueError("parameters is a reserved word and cannot be used as a task label")

        for label, definition in tmp_tasks.items():
            if isinstance(definition, str):
                definition = {"class": definition}
            config = definition.get('config', None)
            if config is None:
                task_config_ir = None
            else:
                if isinstance(config, dict):
                    config = [config]
                task_config_ir = []
                for c in config:
                    file = c.pop("file", None)
                    if file is None:
                        file = []
                    elif not isinstance(file, list):
                        file = [file]
                    task_config_ir.append(ConfigIR(python=c.pop("python", None),
                                                   dataId=c.pop("dataId", None),
                                                   file=file,
                                                   rest=c))
            self.tasks[label] = TaskIR(label, definition["class"], task_config_ir)

    def _remove_contracts(self, label: str):
        """Remove any contracts that contain the given label

        String comparison used in this way is not the most elegant and may
        have issues, but it is the only feasible way when users can specify
        contracts with generic strings.
        """
        new_contracts = []
        for contract in self.contracts:
            # Match the label when it is preceded by something other than an
            # ASCII identifier character (or is at the start of the string)
            # and is followed by a dot
            if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract):
                continue
            new_contracts.append(contract)
        self.contracts = new_contracts

    def subset_from_labels(self, labelSpecifier: Set[str]) -> PipelineIR:
        """Subset a pipelineIR to contain only labels specified in
        labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `set` of `str`
            Set containing labels that describe how to subset a pipeline.

        Returns
        -------
        pipeline : `PipelineIR`
            A new pipelineIR object that is a subset of the old pipelineIR

        Raises
        ------
        ValueError
            Raised if there is an issue with specified labels

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string based matching due to the nature of contracts and may prune
        more than it should. Any labeled subsets defined that no longer have
        all members of the subset present in the pipeline will be removed
        from the resulting pipeline.
        """
        pipeline = copy.deepcopy(self)

        # Update the label specifier to expand any named subsets
        toRemove = set()
        toAdd = set()
        for label in labelSpecifier:
            if label in pipeline.labeled_subsets:
                toRemove.add(label)
                toAdd.update(pipeline.labeled_subsets[label].subset)
        labelSpecifier.difference_update(toRemove)
        labelSpecifier.update(toAdd)
        # Verify all the labels are in the pipeline
        if not labelSpecifier.issubset(pipeline.tasks.keys()
                                       | pipeline.labeled_subsets):
            difference = labelSpecifier.difference(pipeline.tasks.keys())
            raise ValueError("Not all supplied labels (specified or named subsets) are in the pipeline "
                             f"definition, extra labels: {difference}")
        # Copy needed so as to not modify while iterating
        pipeline_labels = set(pipeline.tasks.keys())
        # Remove the labels from the pipelineIR, and any contracts that
        # contain those labels (see docstring on _remove_contracts for why
        # this may cause issues)
        for label in pipeline_labels:
            if label not in labelSpecifier:
                pipeline.tasks.pop(label)
                pipeline._remove_contracts(label)

        # Create a copy of the object to iterate over
        labeled_subsets = copy.copy(pipeline.labeled_subsets)
        # Remove any labeled subsets that no longer have a complete set of
        # labels
        for label, labeled_subset in labeled_subsets.items():
            if labeled_subset.subset - pipeline.tasks.keys():
                pipeline.labeled_subsets.pop(label)

        return pipeline
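
    # A minimal usage sketch of subset_from_labels (not part of the original
    # module); pipeline_yaml_text and the task labels are illustrative only:
    #
    #     full = PipelineIR.from_string(pipeline_yaml_text)
    #     sub = full.subset_from_labels({"isr", "calibrate"})
    #     list(sub.tasks)  # -> only the requested labels remain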

    @classmethod
    def from_string(cls, pipeline_string: str):
        """Create a `PipelineIR` object from a string formatted like a
        pipeline document

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document
        """
        loaded_yaml = yaml.load(pipeline_string, Loader=PipelineYamlLoader)
        return cls(loaded_yaml)
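
    # A hedged usage sketch of from_string (not part of the original module);
    # the document body and task class are illustrative only:
    #
    #     pipeline = PipelineIR.from_string(
    #         "description: demo\n"
    #         "tasks:\n"
    #         "  isr: lsst.ip.isr.IsrTask\n"
    #     )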

    @classmethod
    @deprecated(reason="This has been replaced with `from_uri`. Will be removed after v23.",
                version="v21.0", category=FutureWarning)  # type: ignore
    def from_file(cls, filename: str) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input path.

        Parameters
        ----------
        filename : `str`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline

        Notes
        -----
        This method is deprecated, please use from_uri
        """
        return cls.from_uri(filename)

    @classmethod
    def from_uri(cls, uri: Union[str, ButlerURI]) -> PipelineIR:
        """Create a `PipelineIR` object from the document specified by the
        input uri.

        Parameters
        ----------
        uri : `str` or `ButlerURI`
            Location of document to use in creating a `PipelineIR` object.

        Returns
        -------
        pipelineIR : `PipelineIR`
            The loaded pipeline
        """
        loaded_uri = ButlerURI(uri)
        # With ButlerURI we have the choice of always using a local file or
        # reading in the bytes directly. Reading in bytes can be more
        # efficient for reasonably-sized files when the resource is remote.
        # For now use the local file variant. For a local file as_local()
        # does nothing.
        with loaded_uri.as_local() as local:
            # Explicitly read here; there was some issue with yaml trying to
            # read the ButlerURI itself (likely because it only pretends to
            # be conformant to the io api)
            loaded_yaml = yaml.load(local.read(), Loader=PipelineYamlLoader)
        return cls(loaded_yaml)

    @deprecated(reason="This has been replaced with `write_to_uri`. Will be removed after v23.",
                version="v21.0", category=FutureWarning)  # type: ignore
    def to_file(self, filename: str):
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified path.

        Parameters
        ----------
        filename : `str`
            Location of document to write a `PipelineIR` object.
        """
        self.write_to_uri(filename)

    def write_to_uri(self, uri: Union[ButlerURI, str]):
        """Serialize this `PipelineIR` object into a yaml formatted string and
        write the output to a file at the specified uri.

        Parameters
        ----------
        uri : `str` or `ButlerURI`
            Location of document to write a `PipelineIR` object.
        """
        butlerUri = ButlerURI(uri)
        butlerUri.write(yaml.dump(self.to_primitives(), sort_keys=False).encode())
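
    # A hedged round-trip sketch (not part of the original module); the paths
    # are illustrative only:
    #
    #     pipeline = PipelineIR.from_uri("pipeline.yaml")
    #     pipeline.write_to_uri("copy.yaml")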

    def to_primitives(self) -> Dict[str, Any]:
        """Convert to a representation used in yaml serialization
        """
        accumulate = {"description": self.description}
        if self.instrument is not None:
            accumulate['instrument'] = self.instrument
        if self.parameters:
            accumulate['parameters'] = self._sort_by_str(self.parameters.to_primitives())
        accumulate['tasks'] = {m: t.to_primitives() for m, t in self.tasks.items()}
        if len(self.contracts) > 0:
            # Sort contracts in lexicographical order by the contract string,
            # in the absence of any other ordering principle
            contracts_list = [c.to_primitives() for c in self.contracts]
            contracts_list.sort(key=lambda x: x["contract"])
            accumulate['contracts'] = contracts_list
        if self.labeled_subsets:
            accumulate['subsets'] = self._sort_by_str({k: v.to_primitives() for
                                                       k, v in self.labeled_subsets.items()})
        return accumulate

    def reorder_tasks(self, task_labels: List[str]) -> None:
        """Change the order in which tasks are stored internally. Useful for
        determining the order things will appear in the serialized (or
        printed) form.

        Parameters
        ----------
        task_labels : `list` of `str`
            A list containing all the labels in the pipeline, inserted in
            the order the tasks are to be stored.

        Raises
        ------
        KeyError
            Raised if labels are supplied that are not in the pipeline, or if
            not all labels in the pipeline were supplied in the task_labels
            input.
        """
        # Verify that all labels are in the input
        _tmp_set = set(task_labels)
        if remainder := (self.tasks.keys() - _tmp_set):
            raise KeyError(f"Label(s) {remainder} are missing from the task label list")
        if extra := (_tmp_set - self.tasks.keys()):
            raise KeyError(f"Extra label(s) {extra} were in the input and are not in the pipeline")

        newTasks = {key: self.tasks[key] for key in task_labels}
        self.tasks = newTasks

    @staticmethod
    def _sort_by_str(arg: Mapping[str, Any]) -> Mapping[str, Any]:
        keys = sorted(arg.keys())
        return {key: arg[key] for key in keys}

    def __str__(self) -> str:
        """Instance formatting as how it would look in yaml representation
        """
        return yaml.dump(self.to_primitives(), sort_keys=False)

    def __repr__(self) -> str:
        """Instance formatting as how it would look in yaml representation
        """
        return str(self)

    def __eq__(self, other: object):
        if not isinstance(other, PipelineIR):
            return False
        # Special case contracts because it is a list whose order is not
        # important
        return (all(getattr(self, attr) == getattr(other, attr) for attr in
                    ("tasks", "instrument", "labeled_subsets", "parameters"))
                and len(self.contracts) == len(other.contracts)
                and all(c in self.contracts for c in other.contracts))