# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Set, Union, Generator, TYPE_CHECKING, Optional

import copy
import os

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------

@dataclass
class LabelSpecifier:
    """A structure to specify a subset of labels to load.

    This structure may contain a set of labels to be used in subsetting a
    pipeline, or a beginning and end point. Beginning or end may be empty,
    in which case the range will be a half-open interval. Unlike Python
    iteration bounds, end bounds are *INCLUDED*. Note that range-based
    selection is not well defined for pipelines that are not linear in
    nature; correct behavior is not guaranteed and may vary from run to run.
    """
    labels: Optional[Set[str]] = None
    begin: Optional[str] = None
    end: Optional[str] = None

    def __post_init__(self):
        if self.labels is not None and (self.begin or self.end):
            raise ValueError("This struct can only be initialized with a labels set or "
                             "a begin (and/or) end specifier")
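# Illustrative sketch of the two mutually exclusive ways to build a
# LabelSpecifier (the task labels used here are hypothetical):
#
#     # An explicit set of labels.
#     spec = LabelSpecifier(labels={"isr", "characterizeImage"})
#     # Or an inclusive range of labels; begin/end may be None for a
#     # half-open interval.
#     spec = LabelSpecifier(begin="isr", end="calibrate")
#     # Supplying both a label set and a range raises ValueError.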

class TaskDef:
    """TaskDef is a collection of information about a task needed by a
    Pipeline.

    The information includes task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name, currently it is not specified whether this
        is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
        Framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied. This config will be frozen.
    taskClass : `type` or ``None``
        `PipelineTask` class object, can be ``None``. If ``None`` then the
        framework will have to locate and load the class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        config.freeze()
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self):
        """Name of a dataset type for configuration of this task (`str`).
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self):
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`).
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskDef):
            return False
        return (self.config == other.config
                and self.taskClass == other.taskClass
                and self.label == other.label)

    def __hash__(self):
        return hash((self.taskClass, self.label))
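# Illustrative sketch of how a TaskDef is typically built and used; ``MyTask``
# and its module path are hypothetical placeholders for a real PipelineTask:
#
#     config = MyTask.ConfigClass()
#     taskDef = TaskDef("mypackage.myModule.MyTask", config=config,
#                       taskClass=MyTask, label="myTask")
#     taskDef.configDatasetName    # -> "myTask_config"
#     taskDef.metadataDatasetName  # -> "myTask_metadata", or None if
#                                  #    config.saveMetadata is False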

class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename : `str`
            A path that points to a pipeline defined in yaml format. This
            filename may also supply additional labels to be used in
            subsetting the loaded Pipeline. These labels are separated from
            the path by a colon, and may be specified as a comma separated
            list, or a range denoted as beginning..end. Beginning or end may
            be empty, in which case the range will be a half-open interval.
            Unlike Python iteration bounds, end bounds are *INCLUDED*. Note
            that range-based selection is not well defined for pipelines that
            are not linear in nature; correct behavior is not guaranteed and
            may vary from run to run.

        Returns
        -------
        pipeline : `Pipeline`
            The pipeline loaded from the specified location with appropriate
            (if any) subsetting.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string-based matching due to the nature of contracts and may prune
        more than it should.
        """
        # Split up the filename and any labels that were supplied
        filename, labelSpecifier = cls._parseFileSpecifier(filename)
        pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))

        # If there are labels supplied, only keep those
        if labelSpecifier is not None:
            pipeline = pipeline.subsetFromLabels(labelSpecifier)
        return pipeline
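    # Illustrative sketch of the file-specifier syntax accepted by fromFile;
    # the file name and labels are hypothetical:
    #
    #     Pipeline.fromFile("DRP.yaml")                 # whole pipeline
    #     Pipeline.fromFile("DRP.yaml:isr,calibrate")   # just these labels
    #     Pipeline.fromFile("DRP.yaml:isr..calibrate")  # inclusive label range
    #     Pipeline.fromFile("DRP.yaml:..calibrate")     # everything up to and
    #                                                   # including calibrate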

    def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
        """Subset a pipeline to contain only labels specified in
        labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `LabelSpecifier`
            Object containing labels that describes how to subset a pipeline.

        Returns
        -------
        pipeline : `Pipeline`
            A new pipeline object that is a subset of the old pipeline.

        Raises
        ------
        ValueError
            Raised if there is an issue with the specified labels.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string-based matching due to the nature of contracts and may prune
        more than it should.
        """
        # Labels supplied as a set
        if labelSpecifier.labels:
            labelSet = labelSpecifier.labels
        # Labels supplied as a range; first create a list of all the labels
        # in the pipeline sorted according to task dependency. Then only
        # keep labels that lie between the supplied bounds.
        else:
            # Create a copy of the pipeline to use when assessing the label
            # ordering. Use a dict for fast searching while preserving order.
            # Remove contracts so they do not fail in the expansion step. This
            # is needed because a user may only configure the tasks they
            # intend to run, which may cause some contracts to fail if they
            # will later be dropped.
            pipeline = copy.deepcopy(self)
            pipeline._pipelineIR.contracts = []
            labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}

            # Verify the bounds are in the labels
            if labelSpecifier.begin is not None:
                if labelSpecifier.begin not in labels:
                    raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
                                     "pipeline definition")
            if labelSpecifier.end is not None:
                if labelSpecifier.end not in labels:
                    raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
                                     "definition")

            labelSet = set()
            for label in labels:
                if labelSpecifier.begin is not None:
                    if label != labelSpecifier.begin:
                        continue
                    else:
                        labelSpecifier.begin = None
                labelSet.add(label)
                if labelSpecifier.end is not None and label == labelSpecifier.end:
                    break
        return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))
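    # Illustrative sketch of subsetting (file name and labels are
    # hypothetical):
    #
    #     full = Pipeline.fromFile("DRP.yaml")
    #     spec = LabelSpecifier(begin="isr", end="calibrate")
    #     subset = full.subsetFromLabels(spec)  # new Pipeline; original unchanged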

    @staticmethod
    def _parseFileSpecifier(fileSpecifier):
        """Split apart a filename path from label subsets.
        """
        split = fileSpecifier.split(':')
        # There is only a filename, return just that
        if len(split) == 1:
            return fileSpecifier, None
        # More than one specifier provided, bail out
        if len(split) > 2:
            raise ValueError("Only one : is allowed when specifying a pipeline to load")
        else:
            labelSubset: str
            filename: str
            filename, labelSubset = split[0], split[1]
            # labels supplied as a list
            if ',' in labelSubset:
                if '..' in labelSubset:
                    raise ValueError("Can only specify a list of labels or a range "
                                     "when loading a Pipeline, not both")
                labels = set(labelSubset.split(","))
                specifier = LabelSpecifier(labels=labels)
            # labels supplied as a range
            elif '..' in labelSubset:
                # Try to destructure the labelSubset, this will fail if more
                # than one range is specified
                try:
                    begin, end = labelSubset.split("..")
                except ValueError:
                    raise ValueError("Only one range can be specified when loading a pipeline")
                specifier = LabelSpecifier(begin=begin if begin else None, end=end if end else None)
            # Assume anything else is a single label
            else:
                labels = {labelSubset}
                specifier = LabelSpecifier(labels=labels)

            return filename, specifier

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline
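    # Illustrative sketch of building a Pipeline from an in-memory document;
    # the description, label, and task class path are hypothetical:
    #
    #     pipeline_str = """
    #     description: A one-task example pipeline
    #     tasks:
    #       myTask: mypackage.myModule.MyTask
    #     """
    #     pipeline = Pipeline.fromString(pipeline_str)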

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline : `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An already created pipeline to copy.

        Returns
        -------
        pipeline : `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.instrument.Instrument` or `str`
            Either a derived class object of `Instrument` or a string
            corresponding to a fully qualified `Instrument` name.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type
            # checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def getInstrument(self):
        """Get the instrument from the pipeline.

        Returns
        -------
        instrument : `~lsst.obs.base.instrument.Instrument`, `str`, or `None`
            Either a derived class object of `Instrument`, a string
            corresponding to a fully qualified `Instrument` name, or `None`
            if the pipeline does not have an instrument.
        """
        return self._pipelineIR.instrument

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task : `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label : `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name of one")
        if not label:
            # in some cases (with command line-generated pipeline) tasks can
            # be defined without a label, which is not acceptable; use the
            # task _DefaultName in that case
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
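    # Illustrative sketch of assembling a pipeline programmatically; the
    # instrument and task paths are hypothetical:
    #
    #     pipeline = Pipeline("An example pipeline")
    #     pipeline.addInstrument("lsst.obs.example.ExampleInstrument")
    #     pipeline.addTask("mypackage.myModule.MyTask", label="myTask")
    #     pipeline.removeTask("myTask")  # raises KeyError if the label is absent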

    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline

        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply a single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key : `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str):
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        pythonString : `str`
            A string which is valid python code to be executed. This is done
            with config as the only local accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
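    # Illustrative sketch of the three override mechanisms; the label, field
    # names, and file path are hypothetical:
    #
    #     pipeline.addConfigOverride("myTask", "doSomething", True)
    #     pipeline.addConfigFile("myTask", "/path/to/overrides.py")
    #     pipeline.addConfigPython("myTask", "config.threshold = 5.0")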

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
        if label == "parameters":
            if newConfig.rest.keys() - self._pipelineIR.parameters.mapping.keys():
                raise ValueError("Cannot override parameters that are not defined in pipeline")
            self._pipelineIR.parameters.mapping.update(newConfig.rest)
            if newConfig.file:
                raise ValueError("Setting parameters section with config file is not supported")
            if newConfig.python:
                raise ValueError("Setting parameters section using a python block is not supported")
            return
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        self._pipelineIR.to_file(filename)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Return a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            Raised if a dataId is supplied in a config block. This is in
            place for future use.
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in (configIr.formatted(self._pipelineIR.parameters)
                                 for configIr in taskIR.config):
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # only apply override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(os.path.expandvars(configFile))
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # Now evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # execute this in its own line so it can raise a good error
                # message if there were problems with the eval
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)
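    # Illustrative sketch of expanding a pipeline into ordered TaskDefs; the
    # file name is hypothetical:
    #
    #     pipeline = Pipeline.fromFile("DRP.yaml")
    #     for taskDef in pipeline.toExpandedPipeline():
    #         print(taskDef.label, taskDef.taskName)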

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR

@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef : `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Construct a set of true `DatasetType` objects.

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for; corresponds
                to an attribute of type `list` on the connection class
                instance.
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`.

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry. This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=parentStorageClass
                        )
                        registryDatasetType = datasetType
                    else:
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=registryDatasetType.parentStorageClass
                        )

                if registryDatasetType and datasetType != registryDatasetType:
                    raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                     f"registry definition ({registryDatasetType}) "
                                     f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertySet type; its
            # dimensions correspond to a task quantum
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )
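# Illustrative sketch of classifying a single task's dataset types; the
# pipeline file and butler repository paths are hypothetical:
#
#     from lsst.daf.butler import Butler
#     pipeline = Pipeline.fromFile("DRP.yaml")
#     butler = Butler("/path/to/repo")
#     taskDef = next(iter(pipeline.toExpandedPipeline()))
#     taskTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=butler.registry)
#     taskTypes.inputs, taskTypes.outputs, taskTypes.prerequisites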

@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not
    consumed by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An ordered collection of tasks that can be run together.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, that means this is a component
            # DatasetType, if there is an output which produces the parent of
            # this component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )
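# Illustrative sketch of classifying every dataset type in a pipeline; the
# pipeline file and repository paths are hypothetical:
#
#     from lsst.daf.butler import Butler
#     pipeline = Pipeline.fromFile("DRP.yaml")
#     butler = Butler("/path/to/repo")
#     datasetTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
#     datasetTypes.inputs         # pure inputs to the whole pipeline
#     datasetTypes.intermediates  # produced by one task, consumed by another
#     datasetTypes.outputs        # produced but never consumed within the pipeline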