1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23"""Module defining Pipeline class and related methods. 

24""" 

25 

26__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"] 

27 

28# ------------------------------- 

29# Imports of standard modules -- 

30# ------------------------------- 

31from dataclasses import dataclass 

32from types import MappingProxyType 

33from typing import Mapping, Set, Union, Generator, TYPE_CHECKING, Optional 

34 

35import copy 

36import os 

37import re 

38 

39# ----------------------------- 

40# Imports for other modules -- 

41from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension 

42from lsst.utils import doImport 

43from .configOverrides import ConfigOverrides 

44from .connections import iterConnections 

45from .pipelineTask import PipelineTask 

46 

47from . import pipelineIR 

48from . import pipeTools 

49 

50if TYPE_CHECKING: # Imports needed only for type annotations; may be circular. 

51 from lsst.obs.base.instrument import Instrument 

52 

53# ---------------------------------- 

54# Local non-exported definitions -- 

55# ---------------------------------- 

56 

57# ------------------------ 

58# Exported definitions -- 

59# ------------------------ 

60 

61 

62@dataclass 

63class LabelSpecifier: 

64 """A structure to specify a subset of labels to load 

65 

66 This structure may contain a set of labels to be used in subsetting a 

67 pipeline, or a beginning and end point. Beginning or end may be empty, 

68 in which case the range will be a half open interval. Unlike python 

69 iteration bounds, end bounds are *INCLUDED*. Note that range based selection 

70 is not well defined for pipelines that are not linear in nature, and 

71 correct behavior is not guaranteed, or may vary from run to run. 

72 """ 

73 labels: Optional[Set[str]] = None 

74 begin: Optional[str] = None 

75 end: Optional[str] = None 

76 

77 def __post_init__(self): 

78 if self.labels is not None and (self.begin or self.end): 

79 raise ValueError("This struct can only be initialized with a labels set or " 

80 "a begin (and/or) end specifier") 

81 

82 

83class TaskDef: 

84 """TaskDef is a collection of information about task needed by Pipeline. 

85 

86 The information includes task name, configuration object and optional 

87 task class. This class is just a collection of attributes and it exposes 

88 all of them so that attributes could potentially be modified in place 

89 (e.g. if configuration needs extra overrides). 

90 

91 Attributes 

92 ---------- 

93 taskName : `str` 

94 `PipelineTask` class name; currently it is not specified whether this 

95 is a fully-qualified name or a partial name (e.g. ``module.TaskClass``). 

96 The framework should be prepared to handle both cases. 

97 config : `lsst.pex.config.Config` 

98 Instance of the configuration class corresponding to this task class, 

99 usually with all overrides applied. This config will be frozen. 

100 taskClass : `type` or ``None`` 

101 `PipelineTask` class object, can be ``None``. If ``None`` then 

102 framework will have to locate and load class. 

103 label : `str`, optional 

104 Task label, usually a short string unique in a pipeline. 

105 """ 

106 def __init__(self, taskName, config, taskClass=None, label=""): 

107 self.taskName = taskName 

108 config.freeze() 

109 self.config = config 

110 self.taskClass = taskClass 

111 self.label = label 

112 self.connections = config.connections.ConnectionsClass(config=config) 

113 

114 @property 

115 def configDatasetName(self): 

116 """Name of a dataset type for configuration of this task (`str`) 

117 """ 

118 return self.label + "_config" 

119 

120 @property 

121 def metadataDatasetName(self): 

122 """Name of a dataset type for metadata of this task, `None` if 

123 metadata is not to be saved (`str`) 

124 """ 

125 if self.config.saveMetadata: 

126 return self.label + "_metadata" 

127 else: 

128 return None 

129 

130 def __str__(self): 

131 rep = "TaskDef(" + self.taskName 

132 if self.label: 

133 rep += ", label=" + self.label 

134 rep += ")" 

135 return rep 

136 

137 def __eq__(self, other: object) -> bool: 

138 if not isinstance(other, TaskDef): 

139 return False 

140 return self.config == other.config and\ 

141 self.taskClass == other.taskClass and\ 

142 self.label == other.label 

143 

144 def __hash__(self): 

145 return hash((self.taskClass, self.label)) 
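
# Usage sketch for TaskDef (MyPipelineTask and its label are hypothetical; the
# config instance comes from the task's own ConfigClass and is frozen here):
#
#     config = MyPipelineTask.ConfigClass()
#     taskDef = TaskDef("mypkg.MyPipelineTask", config, MyPipelineTask, label="myTask")
#     taskDef.configDatasetName    # "myTask_config"
#     taskDef.metadataDatasetName  # "myTask_metadata", or None if saveMetadata is False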

146 

147 

148class Pipeline: 

149 """A `Pipeline` is a representation of a series of tasks to run, and the 

150 configuration for those tasks. 

151 

152 Parameters 

153 ---------- 

154 description : `str` 

155 A description of what this pipeline does. 

156 """ 

157 def __init__(self, description: str): 

158 pipeline_dict = {"description": description, "tasks": {}} 

159 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict) 

160 

161 @classmethod 

162 def fromFile(cls, filename: str) -> Pipeline: 

163 """Load a pipeline defined in a pipeline yaml file. 

164 

165 Parameters 

166 ---------- 

167 filename: `str` 

168 A path that points to a pipeline defined in yaml format. This 

169 filename may also supply additional labels to be used in 

170 subsetting the loaded Pipeline. These labels are separated from 

171 the path by a colon, and may be specified as a comma separated 

172 list, or a range denoted as beginning..end. Beginning or end may 

173 be empty, in which case the range will be a half open interval. 

174 Unlike python iteration bounds, end bounds are *INCLUDED*. Note 

175 that range based selection is not well defined for pipelines that 

176 are not linear in nature, and correct behavior is not guaranteed, 

177 or may vary from run to run. 

178 

179 Returns 

180 ------- 

181 pipeline: `Pipeline` 

182 The pipeline loaded from specified location with appropriate (if 

183 any) subsetting 

184 

185 Notes 

186 ----- 

187 This method attempts to prune any contracts that contain labels which 

188 are not in the declared subset of labels. This pruning is done using a 

189 string based matching due to the nature of contracts and may prune more 

190 than it should. 

191 """ 

192 # Split up the filename and any labels that were supplied 

193 filename, labelSpecifier = cls._parseFileSpecifier(filename) 

194 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename)) 

195 

196 # If there are labels supplied, only keep those 

197 if labelSpecifier is not None: 

198 pipeline = pipeline.subsetFromLabels(labelSpecifier) 

199 return pipeline 
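
# Sketch of the file specifiers accepted by fromFile (file name and labels are
# hypothetical):
#
#     Pipeline.fromFile("pipeline.yaml")                 # the full pipeline
#     Pipeline.fromFile("pipeline.yaml:isr,calibrate")   # explicit label set
#     Pipeline.fromFile("pipeline.yaml:isr..calibrate")  # inclusive label range
#     Pipeline.fromFile("pipeline.yaml:..calibrate")     # half-open range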

200 

201 def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline: 

202 """Subset a pipeline to contain only labels specified in labelSpecifier 

203 

204 Parameters 

205 ---------- 

206 labelSpecifier : `LabelSpecifier` 

207 Object describing how to subset a pipeline by labels. 

208 

209 Returns 

210 ------- 

211 pipeline : `Pipeline` 

212 A new pipeline object that is a subset of the old pipeline 

213 

214 Raises 

215 ------ 

216 ValueError 

217 Raised if there is an issue with specified labels 

218 

219 Notes 

220 ----- 

221 This method attempts to prune any contracts that contain labels which 

222 are not in the declared subset of labels. This pruning is done using a 

223 string based matching due to the nature of contracts and may prune more 

224 than it should. 

225 """ 

226 

227 pipeline = copy.deepcopy(self) 

228 

229 def remove_contracts(label: str): 

230 """Remove any contracts that contain the given label 

231 

232 String comparison used in this way is not the most elegant and may 

233 have issues, but it is the only feasible way when users can specify 

234 contracts with generic strings. 

235 """ 

236 new_contracts = [] 

237 for contract in pipeline._pipelineIR.contracts: 

238 # Match the label when it is not preceded by an identifier character 

239 # (or is at the start of the string) and is followed by a dot. 

240 if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract): 

241 continue 

242 new_contracts.append(contract) 

243 pipeline._pipelineIR.contracts = new_contracts 
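
# For example, with label "isr" the regex above would drop a contract such as
# "isr.doDark == calibrate.doApplyDark" but keep "visr.doDark == True", because
# the label must not be preceded by an identifier character (labels here are
# hypothetical).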

244 

245 # Labels supplied as a set; explicitly remove any tasks whose labels 

246 # are not in that set. 

247 if labelSpecifier.labels: 

248 # verify all the labels are in the pipeline 

249 if not labelSpecifier.labels.issubset(pipeline._pipelineIR.tasks.keys()): 

250 difference = labelSpecifier.labels.difference(pipeline._pipelineIR.tasks.keys()) 

251 raise ValueError("Not all supplied labels are in the pipeline definition, extra labels: " 

252 f"{difference}") 

253 # copy needed so as to not modify while iterating 

254 pipeline_labels = list(pipeline._pipelineIR.tasks.keys()) 

255 for label in pipeline_labels: 

256 if label not in labelSpecifier.labels: 

257 pipeline.removeTask(label) 

258 remove_contracts(label) 

259 # Labels supplied as a range, first create a list of all the labels 

260 # in the pipeline sorted according to task dependency. Then only 

261 # keep labels that lie between the supplied bounds 

262 else: 

263 # Use a dict for fast searching while preserving order 

264 # Save contracts and remove them so they do not fail in the 

265 # expansion step; they will be restored afterwards. This is needed 

266 # because a user may only configure the tasks they intend to run, 

267 # which may cause some contracts to fail if those tasks are later dropped 

268 contractSave = pipeline._pipelineIR.contracts 

269 pipeline._pipelineIR.contracts = [] 

270 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()} 

271 pipeline._pipelineIR.contracts = contractSave 

272 

273 # Verify the bounds are in the labels 

274 if labelSpecifier.begin is not None: 

275 if labelSpecifier.begin not in labels: 

276 raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in " 

277 "pipeline definition") 

278 if labelSpecifier.end is not None: 

279 if labelSpecifier.end not in labels: 

280 raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline " 

281 "definition") 

282 

283 closed = False 

284 for label in labels: 

285 # if there is a begin label delete all labels until it is 

286 # reached. 

287 if labelSpecifier.begin: 

288 if label != labelSpecifier.begin: 

289 pipeline.removeTask(label) 

290 remove_contracts(label) 

291 continue 

292 else: 

293 labelSpecifier.begin = None 

294 # if there is an end specifier, keep all tasks until the 

295 # end specifier is reached, afterwards delete the labels 

296 if labelSpecifier.end: 

297 if label != labelSpecifier.end: 

298 if closed: 

299 pipeline.removeTask(label) 

300 remove_contracts(label) 

301 continue 

302 else: 

303 closed = True 

304 return pipeline 
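
# Subsetting sketch (hypothetical labels); an equivalent subset can also be
# requested directly in the fromFile specifier shown above:
#
#     sub = pipeline.subsetFromLabels(LabelSpecifier(labels={"isr", "calibrate"}))
#     sub = pipeline.subsetFromLabels(LabelSpecifier(begin="isr", end="calibrate"))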

305 

306 @staticmethod 

307 def _parseFileSpecifier(fileSpecifer): 

308 """Split appart a filename path from label subsets 

309 """ 

310 split = fileSpecifer.split(':') 

311 # There is only a filename, return just that 

312 if len(split) == 1: 

313 return fileSpecifer, None 

314 # More than one specifier provided, bail out 

315 if len(split) > 2: 

316 raise ValueError("Only one : is allowed when specifying a pipeline to load") 

317 else: 

318 labelSubset: str 

319 filename: str 

320 filename, labelSubset = split[0], split[1] 

321 # labels supplied as a list 

322 if ',' in labelSubset: 

323 if '..' in labelSubset: 

324 raise ValueError("Can only specify a list of labels or a range " 

325 "when loading a Pipeline, not both") 

326 labels = set(labelSubset.split(",")) 

327 specifier = LabelSpecifier(labels=labels) 

328 # labels supplied as a range 

329 elif '..' in labelSubset: 

330 # Try to destructure the labelSubset, this will fail if more 

331 # than one range is specified 

332 try: 

333 begin, end = labelSubset.split("..") 

334 except ValueError: 

335 raise ValueError("Only one range can be specified when loading a pipeline") 

336 specifier = LabelSpecifier(begin=begin if begin else None, end=end if end else None) 

337 # Assume anything else is a single label 

338 else: 

339 labels = {labelSubset} 

340 specifier = LabelSpecifier(labels=labels) 

341 

342 return filename, specifier 

343 

344 @classmethod 

345 def fromString(cls, pipeline_string: str) -> Pipeline: 

346 """Create a pipeline from string formatted as a pipeline document. 

347 

348 Parameters 

349 ---------- 

350 pipeline_string : `str` 

351 A string formatted like a pipeline document. 

352 

353 Returns 

354 ------- 

355 pipeline: `Pipeline` 

356 """ 

357 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string)) 

358 return pipeline 

359 

360 @classmethod 

361 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline: 

362 """Create a pipeline from an already created `PipelineIR` object. 

363 

364 Parameters 

365 ---------- 

366 deserialized_pipeline: `PipelineIR` 

367 An already created pipeline intermediate representation object 

368 

369 Returns 

370 ------- 

371 pipeline: `Pipeline` 

372 """ 

373 pipeline = cls.__new__(cls) 

374 pipeline._pipelineIR = deserialized_pipeline 

375 return pipeline 

376 

377 @classmethod 

378 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline: 

379 """Create a new pipeline by copying an already existing `Pipeline`. 

380 

381 Parameters 

382 ---------- 

383 pipeline: `Pipeline` 

384 The existing `Pipeline` object to copy. 

385 

386 Returns 

387 ------- 

388 pipeline: `Pipeline` 

389 """ 

390 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR)) 

391 

392 def __str__(self) -> str: 

393 return str(self._pipelineIR) 

394 

395 def addInstrument(self, instrument: Union[Instrument, str]): 

396 """Add an instrument to the pipeline, or replace an instrument that is 

397 already defined. 

398 

399 Parameters 

400 ---------- 

401 instrument : `~lsst.obs.base.Instrument` or `str` 

402 Either an `Instrument` subclass or a string corresponding to a 

403 fully qualified `Instrument` class name. In either case the fully 

404 qualified name is what is stored in the pipeline. 

405 """ 

406 if isinstance(instrument, str): 

407 pass 

408 else: 

409 # TODO: assume that this is a subclass of Instrument, no type checking 

410 instrument = f"{instrument.__module__}.{instrument.__qualname__}" 

411 self._pipelineIR.instrument = instrument 

412 

413 def getInstrument(self): 

414 """Get the instrument from the pipeline. 

415 

416 Returns 

417 ------- 

418 instrument : `~lsst.obs.base.Instrument`, `str`, or None 

419 An `Instrument` subclass, a string corresponding to a fully 

420 qualified `Instrument` class name, or None if the pipeline does 

421 not have an instrument. 

422 """ 

423 return self._pipelineIR.instrument 

424 

425 def addTask(self, task: Union[PipelineTask, str], label: str): 

426 """Add a new task to the pipeline, or replace a task that is already 

427 associated with the supplied label. 

428 

429 Parameters 

430 ---------- 

431 task: `PipelineTask` or `str` 

432 Either a subclass of `PipelineTask` or a string corresponding to a 

433 fully qualified `PipelineTask` name. 

434 label: `str` 

435 A label that is used to identify the `PipelineTask` being added 

436 """ 

437 if isinstance(task, str): 

438 taskName = task 

439 elif issubclass(task, PipelineTask): 

440 taskName = f"{task.__module__}.{task.__qualname__}" 

441 else: 

442 raise ValueError("task must be either a child class of PipelineTask or a string containing" 

443 " a fully qualified name to one") 

444 if not label: 

445 # In some cases (e.g. a command-line-generated pipeline) tasks can 

446 # be defined without a label, which is not acceptable; use the 

447 # task's _DefaultName in that case 

448 if isinstance(task, str): 

449 task = doImport(task) 

450 label = task._DefaultName 

451 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName) 

452 

453 def removeTask(self, label: str): 

454 """Remove a task from the pipeline. 

455 

456 Parameters 

457 ---------- 

458 label : `str` 

459 The label used to identify the task that is to be removed 

460 

461 Raises 

462 ------ 

463 KeyError 

464 If no task with that label exists in the pipeline 

465 

466 """ 

467 self._pipelineIR.tasks.pop(label) 

468 

469 def addConfigOverride(self, label: str, key: str, value: object): 

470 """Apply single config override. 

471 

472 Parameters 

473 ---------- 

474 label : `str` 

475 Label of the task. 

476 key: `str` 

477 Fully-qualified field name. 

478 value : object 

479 Value to be given to a field. 

480 """ 

481 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value})) 

482 

483 def addConfigFile(self, label: str, filename: str): 

484 """Add overrides from a specified file. 

485 

486 Parameters 

487 ---------- 

488 label : `str` 

489 The label used to identify the task associated with config to 

490 modify 

491 filename : `str` 

492 Path to the override file. 

493 """ 

494 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename])) 

495 

496 def addConfigPython(self, label: str, pythonString: str): 

497 """Add Overrides by running a snippet of python code against a config. 

498 

499 Parameters 

500 ---------- 

501 label : `str` 

502 The label used to identify the task associated with the config to 

503 modify. 

504 pythonString: `str` 

505 A string which is valid python code to be executed. This is done 

506 with config as the only local accessible value. 

507 """ 

508 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString)) 

509 

510 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR): 

511 if label not in self._pipelineIR.tasks: 

512 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline") 

513 self._pipelineIR.tasks[label].add_or_update_config(newConfig) 
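
# Programmatic construction sketch (task names and config fields are
# hypothetical):
#
#     p = Pipeline("demo pipeline")
#     p.addTask("mypkg.IsrTask", label="isr")
#     p.addConfigOverride("isr", "doDark", False)
#     p.addConfigFile("isr", "isrOverrides.py")
#     p.addConfigPython("isr", "config.doFlat = False")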

514 

515 def toFile(self, filename: str): 

516 self._pipelineIR.to_file(filename) 

517 

518 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]: 

519 """Returns a generator of TaskDefs which can be used to create quantum 

520 graphs. 

521 

522 Returns 

523 ------- 

524 generator : generator of `TaskDef` 

525 The generator returned will be the sorted iterator of tasks which 

526 are to be used in constructing a quantum graph. 

527 

528 Raises 

529 ------ 

530 NotImplementedError 

531 If a dataId is supplied in a config block. This is in place for 

532 future use 

533 """ 

534 taskDefs = [] 

535 for label, taskIR in self._pipelineIR.tasks.items(): 

536 taskClass = doImport(taskIR.klass) 

537 taskName = taskClass.__qualname__ 

538 config = taskClass.ConfigClass() 

539 overrides = ConfigOverrides() 

540 if self._pipelineIR.instrument is not None: 

541 overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName) 

542 if taskIR.config is not None: 

543 for configIR in taskIR.config: 

544 if configIR.dataId is not None: 

545 raise NotImplementedError("Specializing a config on a partial data id is not yet " 

546 "supported in Pipeline definition") 

547 # only apply override if it applies to everything 

548 if configIR.dataId is None: 

549 if configIR.file: 

550 for configFile in configIR.file: 

551 overrides.addFileOverride(os.path.expandvars(configFile)) 

552 if configIR.python is not None: 

553 overrides.addPythonOverride(configIR.python) 

554 for key, value in configIR.rest.items(): 

555 overrides.addValueOverride(key, value) 

556 overrides.applyTo(config) 

557 # This may need to be revisited 

558 config.validate() 

559 taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label)) 

560 

561 # Now evaluate the contracts 

562 if self._pipelineIR.contracts is not None: 

563 label_to_config = {x.label: x.config for x in taskDefs} 

564 for contract in self._pipelineIR.contracts: 

565 # Execute this on its own line so it can raise a good error message 

566 # if there were problems with the eval 

567 success = eval(contract.contract, None, label_to_config) 

568 if not success: 

569 extra_info = f": {contract.msg}" if contract.msg is not None else "" 

570 raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not " 

571 f"satisfied{extra_info}") 

572 

573 yield from pipeTools.orderPipeline(taskDefs) 
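
# Consumption sketch: expansion imports each task class, applies overrides,
# checks contracts, and yields TaskDefs in dependency order:
#
#     for taskDef in pipeline.toExpandedPipeline():
#         print(taskDef.label, taskDef.taskName)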

574 

575 def __len__(self): 

576 return len(self._pipelineIR.tasks) 

577 

578 def __eq__(self, other: "Pipeline"): 

579 if not isinstance(other, Pipeline): 

580 return False 

581 return self._pipelineIR == other._pipelineIR 

582 

583 

584@dataclass(frozen=True) 

585class TaskDatasetTypes: 

586 """An immutable struct that extracts and classifies the dataset types used 

587 by a `PipelineTask` 

588 """ 

589 

590 initInputs: NamedValueSet[DatasetType] 

591 """Dataset types that are needed as inputs in order to construct this Task. 

592 

593 Task-level `initInputs` may be classified as either 

594 `~PipelineDatasetTypes.initInputs` or 

595 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level. 

596 """ 

597 

598 initOutputs: NamedValueSet[DatasetType] 

599 """Dataset types that may be written after constructing this Task. 

600 

601 Task-level `initOutputs` may be classified as either 

602 `~PipelineDatasetTypes.initOutputs` or 

603 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level. 

604 """ 

605 

606 inputs: NamedValueSet[DatasetType] 

607 """Dataset types that are regular inputs to this Task. 

608 

609 If an input dataset needed for a Quantum cannot be found in the input 

610 collection(s) or produced by another Task in the Pipeline, that Quantum 

611 (and all dependent Quanta) will not be produced. 

612 

613 Task-level `inputs` may be classified as either 

614 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates` 

615 at the Pipeline level. 

616 """ 

617 

618 prerequisites: NamedValueSet[DatasetType] 

619 """Dataset types that are prerequisite inputs to this Task. 

620 

621 Prerequisite inputs must exist in the input collection(s) before the 

622 pipeline is run, but do not constrain the graph - if a prerequisite is 

623 missing for a Quantum, `PrerequisiteMissingError` is raised. 

624 

625 Prerequisite inputs are not resolved until the second stage of 

626 QuantumGraph generation. 

627 """ 

628 

629 outputs: NamedValueSet[DatasetType] 

630 """Dataset types that are produced by this Task. 

631 

632 Task-level `outputs` may be classified as either 

633 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates` 

634 at the Pipeline level. 

635 """ 

636 

637 @classmethod 

638 def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes: 

639 """Extract and classify the dataset types from a single `PipelineTask`. 

640 

641 Parameters 

642 ---------- 

643 taskDef: `TaskDef` 

644 An instance of a `TaskDef` class for a particular `PipelineTask`. 

645 registry: `Registry` 

646 Registry used to construct normalized `DatasetType` objects and 

647 retrieve those that are incomplete. 

648 

649 Returns 

650 ------- 

651 types: `TaskDatasetTypes` 

652 The dataset types used by this task. 

653 """ 

654 def makeDatasetTypesSet(connectionType, freeze=True): 

655 """Constructs a set of true `DatasetType` objects 

656 

657 Parameters 

658 ---------- 

659 connectionType : `str` 

660 Name of the connection type to produce a set for, corresponds 

661 to an attribute of type `list` on the connection class instance 

662 freeze : `bool`, optional 

663 If `True`, call `NamedValueSet.freeze` on the object returned. 

664 

665 Returns 

666 ------- 

667 datasetTypes : `NamedValueSet` 

668 A set of all datasetTypes which correspond to the input 

669 connection type specified in the connection class of this 

670 `PipelineTask` 

671 

672 Notes 

673 ----- 

674 This function is a closure over the variables ``registry`` and 

675 ``taskDef``. 

676 """ 

677 datasetTypes = NamedValueSet() 

678 for c in iterConnections(taskDef.connections, connectionType): 

679 dimensions = set(getattr(c, 'dimensions', set())) 

680 if "skypix" in dimensions: 

681 try: 

682 datasetType = registry.getDatasetType(c.name) 

683 except LookupError as err: 

684 raise LookupError( 

685 f"DatasetType '{c.name}' referenced by " 

686 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension " 

687 f"placeholder, but does not already exist in the registry. " 

688 f"Note that reference catalog names are now used as the dataset " 

689 f"type name instead of 'ref_cat'." 

690 ) from err 

691 rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names) 

692 rest2 = set(dim.name for dim in datasetType.dimensions 

693 if not isinstance(dim, SkyPixDimension)) 

694 if rest1 != rest2: 

695 raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in " 

696 f"connections ({rest1}) are inconsistent with those in " 

697 f"registry's version of this dataset ({rest2}).") 

698 else: 

699 # Component dataset types are not explicitly in the 

700 # registry. This complicates consistency checks with 

701 # registry and requires we work out the composite storage 

702 # class. 

703 registryDatasetType = None 

704 try: 

705 registryDatasetType = registry.getDatasetType(c.name) 

706 except KeyError: 

707 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name) 

708 parentStorageClass = DatasetType.PlaceholderParentStorageClass \ 

709 if componentName else None 

710 datasetType = c.makeDatasetType( 

711 registry.dimensions, 

712 parentStorageClass=parentStorageClass 

713 ) 

714 registryDatasetType = datasetType 

715 else: 

716 datasetType = c.makeDatasetType( 

717 registry.dimensions, 

718 parentStorageClass=registryDatasetType.parentStorageClass 

719 ) 

720 

721 if registryDatasetType and datasetType != registryDatasetType: 

722 raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with " 

723 f"registry definition ({registryDatasetType}) " 

724 f"for {taskDef.label}.") 

725 datasetTypes.add(datasetType) 

726 if freeze: 

727 datasetTypes.freeze() 

728 return datasetTypes 

729 

730 # optionally add output dataset for metadata 

731 outputs = makeDatasetTypesSet("outputs", freeze=False) 

732 if taskDef.metadataDatasetName is not None: 

733 # Metadata is supposed to be of the PropertySet type; its dimensions 

734 # correspond to those of a task quantum 

735 dimensions = registry.dimensions.extract(taskDef.connections.dimensions) 

736 outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")} 

737 outputs.freeze() 

738 

739 return cls( 

740 initInputs=makeDatasetTypesSet("initInputs"), 

741 initOutputs=makeDatasetTypesSet("initOutputs"), 

742 inputs=makeDatasetTypesSet("inputs"), 

743 prerequisites=makeDatasetTypesSet("prerequisiteInputs"), 

744 outputs=outputs, 

745 ) 
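
# Usage sketch (requires a `lsst.daf.butler.Registry`, e.g. from a Butler;
# the taskDef is assumed to exist already):
#
#     types = TaskDatasetTypes.fromTaskDef(taskDef, registry=butler.registry)
#     print(types.inputs.names, types.outputs.names)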

746 

747 

748@dataclass(frozen=True) 

749class PipelineDatasetTypes: 

750 """An immutable struct that classifies the dataset types used in a 

751 `Pipeline`. 

752 """ 

753 

754 initInputs: NamedValueSet[DatasetType] 

755 """Dataset types that are needed as inputs in order to construct the Tasks 

756 in this Pipeline. 

757 

758 This does not include dataset types that are produced when constructing 

759 other Tasks in the Pipeline (these are classified as `initIntermediates`). 

760 """ 

761 

762 initOutputs: NamedValueSet[DatasetType] 

763 """Dataset types that may be written after constructing the Tasks in this 

764 Pipeline. 

765 

766 This does not include dataset types that are also used as inputs when 

767 constructing other Tasks in the Pipeline (these are classified as 

768 `initIntermediates`). 

769 """ 

770 

771 initIntermediates: NamedValueSet[DatasetType] 

772 """Dataset types that are both used when constructing one or more Tasks 

773 in the Pipeline and produced as a side-effect of constructing another 

774 Task in the Pipeline. 

775 """ 

776 

777 inputs: NamedValueSet[DatasetType] 

778 """Dataset types that are regular inputs for the full pipeline. 

779 

780 If an input dataset needed for a Quantum cannot be found in the input 

781 collection(s), that Quantum (and all dependent Quanta) will not be 

782 produced. 

783 """ 

784 

785 prerequisites: NamedValueSet[DatasetType] 

786 """Dataset types that are prerequisite inputs for the full Pipeline. 

787 

788 Prerequisite inputs must exist in the input collection(s) before the 

789 pipeline is run, but do not constrain the graph - if a prerequisite is 

790 missing for a Quantum, `PrerequisiteMissingError` is raised. 

791 

792 Prerequisite inputs are not resolved until the second stage of 

793 QuantumGraph generation. 

794 """ 

795 

796 intermediates: NamedValueSet[DatasetType] 

797 """Dataset types that are output by one Task in the Pipeline and consumed 

798 as inputs by one or more other Tasks in the Pipeline. 

799 """ 

800 

801 outputs: NamedValueSet[DatasetType] 

802 """Dataset types that are output by a Task in the Pipeline and not consumed 

803 by any other Task in the Pipeline. 

804 """ 

805 

806 byTask: Mapping[str, TaskDatasetTypes] 

807 """Per-Task dataset types, keyed by label in the `Pipeline`. 

808 

809 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming 

810 neither has been modified since the dataset types were extracted, of 

811 course). 

812 """ 

813 

814 @classmethod 

815 def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes: 

816 """Extract and classify the dataset types from all tasks in a 

817 `Pipeline`. 

818 

819 Parameters 

820 ---------- 

821 pipeline: `Pipeline` 

822 An ordered collection of tasks that can be run together. 

823 registry: `Registry` 

824 Registry used to construct normalized `DatasetType` objects and 

825 retrieve those that are incomplete. 

826 

827 Returns 

828 ------- 

829 types: `PipelineDatasetTypes` 

830 The dataset types used by this `Pipeline`. 

831 

832 Raises 

833 ------ 

834 ValueError 

835 Raised if Tasks are inconsistent about which datasets are marked 

836 prerequisite. This indicates that the Tasks cannot be run as part 

837 of the same `Pipeline`. 

838 """ 

839 allInputs = NamedValueSet() 

840 allOutputs = NamedValueSet() 

841 allInitInputs = NamedValueSet() 

842 allInitOutputs = NamedValueSet() 

843 prerequisites = NamedValueSet() 

844 byTask = dict() 

845 if isinstance(pipeline, Pipeline): 

846 pipeline = pipeline.toExpandedPipeline() 

847 for taskDef in pipeline: 

848 thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry) 

849 allInitInputs |= thisTask.initInputs 

850 allInitOutputs |= thisTask.initOutputs 

851 allInputs |= thisTask.inputs 

852 prerequisites |= thisTask.prerequisites 

853 allOutputs |= thisTask.outputs 

854 byTask[taskDef.label] = thisTask 

855 if not prerequisites.isdisjoint(allInputs): 

856 raise ValueError("{} marked as both prerequisites and regular inputs".format( 

857 {dt.name for dt in allInputs & prerequisites} 

858 )) 

859 if not prerequisites.isdisjoint(allOutputs): 

860 raise ValueError("{} marked as both prerequisites and outputs".format( 

861 {dt.name for dt in allOutputs & prerequisites} 

862 )) 

863 # Make sure that components which are marked as inputs get treated as 

864 # intermediates if there is an output which produces the composite 

865 # containing the component 

866 intermediateComponents = NamedValueSet() 

867 intermediateComposites = NamedValueSet() 

868 outputNameMapping = {dsType.name: dsType for dsType in allOutputs} 

869 for dsType in allInputs: 

870 # get the name of a possible component 

871 name, component = dsType.nameAndComponent() 

872 # if there is a component name, that means this is a component 

873 # DatasetType, if there is an output which produces the parent of 

874 # this component, treat this input as an intermediate 

875 if component is not None: 

876 if name in outputNameMapping: 

877 if outputNameMapping[name].dimensions != dsType.dimensions: 

878 raise ValueError(f"Component dataset type {dsType.name} has different " 

879 f"dimensions ({dsType.dimensions}) than its parent " 

880 f"({outputNameMapping[name].dimensions}).") 

881 composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass, 

882 universe=registry.dimensions) 

883 intermediateComponents.add(dsType) 

884 intermediateComposites.add(composite) 

885 

886 def checkConsistency(a: NamedValueSet, b: NamedValueSet): 

887 common = a.names & b.names 

888 for name in common: 

889 if a[name] != b[name]: 

890 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.") 

891 

892 checkConsistency(allInitInputs, allInitOutputs) 

893 checkConsistency(allInputs, allOutputs) 

894 checkConsistency(allInputs, intermediateComposites) 

895 checkConsistency(allOutputs, intermediateComposites) 

896 

897 def frozen(s: NamedValueSet) -> NamedValueSet: 

898 s.freeze() 

899 return s 

900 

901 return cls( 

902 initInputs=frozen(allInitInputs - allInitOutputs), 

903 initIntermediates=frozen(allInitInputs & allInitOutputs), 

904 initOutputs=frozen(allInitOutputs - allInitInputs), 

905 inputs=frozen(allInputs - allOutputs - intermediateComponents), 

906 intermediates=frozen(allInputs & allOutputs | intermediateComponents), 

907 outputs=frozen(allOutputs - allInputs - intermediateComposites), 

908 prerequisites=frozen(prerequisites), 

909 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability 

910 )
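
# Usage sketch (requires a `lsst.daf.butler.Registry`; the "isr" label is
# hypothetical):
#
#     dsTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
#     dsTypes.inputs          # overall inputs the pipeline needs from the butler
#     dsTypes.intermediates   # produced and consumed within the pipeline
#     dsTypes.byTask["isr"]   # the TaskDatasetTypes for a single task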