
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Set, Union, Generator, TYPE_CHECKING, Optional

import copy
import os
import re

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


@dataclass
class LabelSpecifier:
    """A structure to specify a subset of labels to load.

    This structure may contain a set of labels to be used in subsetting a
    pipeline, or a beginning and end point. Beginning or end may be empty,
    in which case the range will be a half-open interval. Unlike Python
    iteration bounds, end bounds are *INCLUDED*. Note that range-based
    selection is not well defined for pipelines that are not linear in
    nature, and correct behavior is not guaranteed, or may vary from run
    to run.
    """
    labels: Optional[Set[str]] = None
    begin: Optional[str] = None
    end: Optional[str] = None

    def __post_init__(self):
        if self.labels is not None and (self.begin or self.end):
            raise ValueError("This struct can only be initialized with a labels set or "
                             "a begin (and/or) end specifier")
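
# Example (an illustrative sketch, not part of the module API): a
# LabelSpecifier carries either an explicit label set or an inclusive
# begin/end range, and is typically passed to Pipeline.subsetFromLabels.
# The task labels "isr" and "calibrate" below are hypothetical.
#
#     spec_by_set = LabelSpecifier(labels={"isr", "calibrate"})
#     spec_by_range = LabelSpecifier(begin="isr", end="calibrate")
#     LabelSpecifier(labels={"isr"}, end="calibrate")  # raises ValueError
#
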

class TaskDef:
    """TaskDef is a collection of information about a task needed by a
    Pipeline.

    The information includes task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name, currently it is not specified whether this
        is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
        Framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied. This config will be frozen.
    taskClass : `type` or ``None``
        `PipelineTask` class object, can be ``None``. If ``None`` then
        framework will have to locate and load class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        config.freeze()
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self):
        """Name of a dataset type for configuration of this task (`str`)
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self):
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`)
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskDef):
            return False
        return self.config == other.config and \
            self.taskClass == other.taskClass and \
            self.label == other.label

    def __hash__(self):
        return hash((self.taskClass, self.label))
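
# Illustrative note: the per-task dataset type names above are derived from
# the label. For a hypothetical TaskDef labeled "isr" (ExampleTask is a
# placeholder, not a real class in this module):
#
#     taskDef = TaskDef("ExampleTask", config, taskClass=ExampleTask, label="isr")
#     taskDef.configDatasetName    # -> "isr_config"
#     taskDef.metadataDatasetName  # -> "isr_metadata", or None if saveMetadata is False
#
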

class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename: `str`
            A path that points to a pipeline defined in yaml format. This
            filename may also supply additional labels to be used in
            subsetting the loaded Pipeline. These labels are separated from
            the path by a colon, and may be specified as a comma separated
            list, or a range denoted as beginning..end. Beginning or end may
            be empty, in which case the range will be a half-open interval.
            Unlike Python iteration bounds, end bounds are *INCLUDED*. Note
            that range-based selection is not well defined for pipelines that
            are not linear in nature, and correct behavior is not guaranteed,
            or may vary from run to run.

        Returns
        -------
        pipeline: `Pipeline`
            The pipeline loaded from the specified location with appropriate
            (if any) subsetting.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string-based matching due to the nature of contracts and may prune
        more than it should.
        """
        # Split up the filename and any labels that were supplied
        filename, labelSpecifier = cls._parseFileSpecifier(filename)
        pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))

        # If there are labels supplied, only keep those
        if labelSpecifier is not None:
            pipeline = pipeline.subsetFromLabels(labelSpecifier)
        return pipeline
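
    # The label-subsetting syntax accepted above, shown with hypothetical file
    # and label names (an illustrative sketch, not executed anywhere):
    #
    #     Pipeline.fromFile("pipeline.yaml")                 # whole pipeline
    #     Pipeline.fromFile("pipeline.yaml:isr,calibrate")   # explicit label set
    #     Pipeline.fromFile("pipeline.yaml:isr..calibrate")  # inclusive range
    #     Pipeline.fromFile("pipeline.yaml:..calibrate")     # everything up to and including "calibrate"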

    def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
        """Subset a pipeline to contain only labels specified in
        labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `LabelSpecifier`
            Object containing labels that describes how to subset a pipeline.

        Returns
        -------
        pipeline : `Pipeline`
            A new pipeline object that is a subset of the old pipeline.

        Raises
        ------
        ValueError
            Raised if there is an issue with the specified labels.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string-based matching due to the nature of contracts and may prune
        more than it should.
        """

        pipeline = copy.deepcopy(self)

        def remove_contracts(label: str):
            """Remove any contracts that contain the given label.

            String comparison used in this way is not the most elegant and may
            have issues, but it is the only feasible way when users can
            specify contracts with generic strings.
            """
            new_contracts = []
            for contract in pipeline._pipelineIR.contracts:
                # match a label that is not preceded by an ASCII identifier,
                # or is at the start of a line, and is followed by a dot
                if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract):
                    continue
                new_contracts.append(contract)
            pipeline._pipelineIR.contracts = new_contracts

        # Labels supplied as a set: explicitly remove any that are not in
        # that set
        if labelSpecifier.labels:
            # verify all the labels are in the pipeline
            if not labelSpecifier.labels.issubset(pipeline._pipelineIR.tasks.keys()):
                difference = labelSpecifier.labels.difference(pipeline._pipelineIR.tasks.keys())
                raise ValueError("Not all supplied labels are in the pipeline definition, extra labels: "
                                 f"{difference}")
            # copy needed so as to not modify while iterating
            pipeline_labels = list(pipeline._pipelineIR.tasks.keys())
            for label in pipeline_labels:
                if label not in labelSpecifier.labels:
                    pipeline.removeTask(label)
                    remove_contracts(label)
        # Labels supplied as a range: first create a list of all the labels
        # in the pipeline sorted according to task dependency, then only
        # keep labels that lie between the supplied bounds
        else:
            # Use a dict for fast searching while preserving order.
            # Save contracts and remove them so they do not fail in the
            # expansion step; they will be restored after. This is needed
            # because a user may only configure the tasks they intend to run,
            # which may cause some contracts to fail if they will later be
            # dropped
            contractSave = pipeline._pipelineIR.contracts
            pipeline._pipelineIR.contracts = []
            labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
            pipeline._pipelineIR.contracts = contractSave

            # Verify the bounds are in the labels
            if labelSpecifier.begin is not None:
                if labelSpecifier.begin not in labels:
                    raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
                                     "pipeline definition")
            if labelSpecifier.end is not None:
                if labelSpecifier.end not in labels:
                    raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
                                     "definition")

            closed = False
            for label in labels:
                # if there is a begin label, delete all labels until it is
                # reached
                if labelSpecifier.begin:
                    if label != labelSpecifier.begin:
                        pipeline.removeTask(label)
                        remove_contracts(label)
                        continue
                    else:
                        labelSpecifier.begin = None
                # if there is an end specifier, keep all tasks until the
                # end specifier is reached, afterwards delete the labels
                if labelSpecifier.end:
                    if label != labelSpecifier.end:
                        if closed:
                            pipeline.removeTask(label)
                            remove_contracts(label)
                            continue
                    else:
                        closed = True
        return pipeline

    @staticmethod
    def _parseFileSpecifier(fileSpecifier):
        """Split apart a filename path from label subsets.
        """
        split = fileSpecifier.split(':')
        # There is only a filename, return just that
        if len(split) == 1:
            return fileSpecifier, None
        # More than one specifier provided, bail out
        if len(split) > 2:
            raise ValueError("Only one : is allowed when specifying a pipeline to load")
        else:
            labelSubset: str
            filename: str
            filename, labelSubset = split[0], split[1]
            # labels supplied as a list
            if ',' in labelSubset:
                if '..' in labelSubset:
                    raise ValueError("Can only specify a list of labels or a range "
                                     "when loading a Pipeline, not both")
                labels = set(labelSubset.split(","))
                specifier = LabelSpecifier(labels=labels)
            # labels supplied as a range
            elif '..' in labelSubset:
                # Try to destructure the labelSubset, this will fail if more
                # than one range is specified
                try:
                    begin, end = labelSubset.split("..")
                except ValueError:
                    raise ValueError("Only one range can be specified when loading a pipeline")
                specifier = LabelSpecifier(begin=begin if begin else None, end=end if end else None)
            # Assume anything else is a single label
            else:
                labels = {labelSubset}
                specifier = LabelSpecifier(labels=labels)

        return filename, specifier
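
    # What this private parser returns (illustrative; label names hypothetical):
    #
    #     Pipeline._parseFileSpecifier("p.yaml")       # -> ("p.yaml", None)
    #     Pipeline._parseFileSpecifier("p.yaml:a,b")   # -> ("p.yaml", LabelSpecifier(labels={"a", "b"}))
    #     Pipeline._parseFileSpecifier("p.yaml:a..b")  # -> ("p.yaml", LabelSpecifier(begin="a", end="b"))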

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document.

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline: `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An already created `Pipeline` object to copy.

        Returns
        -------
        pipeline: `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.daf.butler.instrument.Instrument` or `str`
            Either a derived class object of a `lsst.daf.butler.instrument` or
            a string corresponding to a fully qualified
            `lsst.daf.butler.instrument` name.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type
            # checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def getInstrument(self):
        """Get the instrument from the pipeline.

        Returns
        -------
        instrument : `~lsst.daf.butler.instrument.Instrument`, `str`, or None
            A derived class object of a `lsst.daf.butler.instrument`, a string
            corresponding to a fully qualified `lsst.daf.butler.instrument`
            name, or None if the pipeline does not have an instrument.
        """
        return self._pipelineIR.instrument
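
    # Both call forms store the instrument in the pipeline IR (an illustrative
    # sketch; ExampleInstrument and its module path are hypothetical):
    #
    #     pipeline.addInstrument("lsst.obs.example.ExampleInstrument")  # string stored as-is
    #     pipeline.addInstrument(ExampleInstrument)  # stored as f"{cls.__module__}.{cls.__qualname__}"
    #     pipeline.getInstrument()  # returns the stored value, or None if unset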

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task: `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label: `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # in some cases (with command line-generated pipeline) tasks can
            # be defined without a label, which is not acceptable; use the
            # task _DefaultName in that case
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)

    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed.

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline.

        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply a single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key: `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str):
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        pythonString: `str`
            A string which is valid python code to be executed. This is done
            with ``config`` as the only locally accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        """Write the pipeline definition to a file in pipeline yaml format."""
        self._pipelineIR.to_file(filename)
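
    # A sketch of the three override entry points above (the task label "isr"
    # and the config field/file names are hypothetical):
    #
    #     pipeline.addConfigOverride("isr", "doWrite", False)
    #     pipeline.addConfigFile("isr", "isrOverrides.py")
    #     pipeline.addConfigPython("isr", "config.doWrite = False")
    #
    # Each call raises LookupError if no task labeled "isr" has been added.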

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Return a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use.
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in taskIR.config:
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # only apply override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(os.path.expandvars(configFile))
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # let's evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # execute this on its own line so that it can raise a good
                # error message if there were problems with the eval
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR

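# Putting the Pipeline API together (an illustrative sketch; the task class
# path and label are hypothetical, and expansion requires the named task to
# be importable):
#
#     pipeline = Pipeline("A demonstration pipeline")
#     pipeline.addTask("lsst.example.tasks.ExampleTask", label="example")
#     pipeline.addConfigOverride("example", "someField", 42)
#     for taskDef in pipeline.toExpandedPipeline():
#         print(taskDef.label, taskDef.configDatasetName)
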

@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this
    Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef: `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types: `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Construct a set of true `DatasetType` objects.

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for; corresponds
                to an attribute of type `list` on the connection class
                instance.
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`.

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry. This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=parentStorageClass
                        )
                        registryDatasetType = datasetType
                    else:
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=registryDatasetType.parentStorageClass
                        )

                    if registryDatasetType and datasetType != registryDatasetType:
                        raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                         f"registry definition ({registryDatasetType}) "
                                         f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertySet type, its
            # dimensions correspond to a task quantum
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the
    Tasks in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not
    consumed by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An ordered collection of tasks that can be run together.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types: `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, that means this is a component
            # DatasetType; if there is an output which produces the parent of
            # this component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )
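
# A sketch of classifying an entire pipeline's dataset types (assumes a Butler
# repository is available; the paths and file names are hypothetical):
#
#     from lsst.daf.butler import Butler
#
#     butler = Butler("/path/to/repo")
#     pipeline = Pipeline.fromFile("pipeline.yaml")
#     datasetTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
#     print(datasetTypes.inputs.names, datasetTypes.outputs.names)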