Coverage for python/lsst/pipe/base/pipeline.py: 23%

454 statements  


1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining Pipeline class and related methods. 

23""" 

24 

25from __future__ import annotations 

26 

27__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"] 

28 

29import copy 

30import logging 

31import re 

32import urllib.parse 

33 

34# ------------------------------- 

35# Imports of standard modules -- 

36# ------------------------------- 

37from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Set 

38from dataclasses import dataclass 

39from types import MappingProxyType 

40from typing import TYPE_CHECKING, ClassVar, cast 

41 

42# ----------------------------- 

43# Imports for other modules -- 

44from lsst.daf.butler import ( 

45 DataCoordinate, 

46 DatasetType, 

47 DimensionUniverse, 

48 NamedValueSet, 

49 Registry, 

50 SkyPixDimension, 

51) 

52from lsst.resources import ResourcePath, ResourcePathExpression 

53from lsst.utils import doImportType 

54from lsst.utils.introspection import get_full_type_name 

55 

56from . import automatic_connection_constants as acc 

57from . import pipeline_graph, pipelineIR 

58from ._instrument import Instrument as PipeBaseInstrument 

59from .config import PipelineTaskConfig 

60from .connections import PipelineTaskConnections, iterConnections 

61from .connectionTypes import Input 

62from .pipelineTask import PipelineTask 

63 

64if TYPE_CHECKING: # Imports needed only for type annotations; may be circular. 

65 from lsst.obs.base import Instrument 

66 from lsst.pex.config import Config 

67 

68# ---------------------------------- 

69# Local non-exported definitions -- 

70# ---------------------------------- 

71 

72_LOG = logging.getLogger(__name__) 

73 

74# ------------------------ 

75# Exported definitions -- 

76# ------------------------ 

77 

78 

79@dataclass 

80class LabelSpecifier: 

81 """A structure to specify a subset of labels to load 

82 

83 This structure may contain a set of labels to be used in subsetting a 

84 pipeline, or a beginning and end point. Beginning or end may be empty, 

85 in which case the range will be a half open interval. Unlike python 

86 iteration bounds, end bounds are *INCLUDED*. Note that range based 

87 selection is not well defined for pipelines that are not linear in nature, 

88 and correct behavior is not guaranteed, or may vary from run to run. 

89 """ 

90 

91 labels: set[str] | None = None 

92 begin: str | None = None 

93 end: str | None = None 

94 

95 def __post_init__(self) -> None: 

96 if self.labels is not None and (self.begin or self.end): 

97 raise ValueError( 

98 "This struct can only be initialized with a labels set or a begin (and/or) end specifier" 

99 ) 

100 

101 
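# Example (an illustrative sketch, not part of the original module; the labels
# used are hypothetical): a LabelSpecifier holds either an explicit set of
# labels or an inclusive begin/end range, never both.
spec_by_labels = LabelSpecifier(labels={"isr", "characterizeImage"})
spec_by_range = LabelSpecifier(begin="isr", end="calibrate")  # end label is INCLUDED
try:
    LabelSpecifier(labels={"isr"}, begin="isr")  # mixing both forms is rejected
except ValueError as err:
    print(err)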

102class TaskDef: 

103 """TaskDef is a collection of information about task needed by Pipeline. 

104 

105 The information includes the task name, a configuration object, and an

106 optional task class. This class is just a collection of attributes and exposes

107 all of them so that attributes could potentially be modified in place 

108 (e.g. if configuration needs extra overrides). 

109 

110 Attributes 

111 ---------- 

112 taskName : `str`, optional 

113 The fully-qualified `PipelineTask` class name. If not provided, 

114 ``taskClass`` must be. 

115 config : `lsst.pipe.base.config.PipelineTaskConfig`, optional 

116 Instance of the configuration class corresponding to this task class, 

117 usually with all overrides applied. This config will be frozen. If 

118 not provided, ``taskClass`` must be provided and 

119 ``taskClass.ConfigClass()`` will be used. 

120 taskClass : `type`, optional 

121 `PipelineTask` class object; if provided and ``taskName`` is as well, 

122 the caller guarantees that they are consistent. If not provided, 

123 ``taskName`` is used to import the type. 

124 label : `str`, optional 

125 Task label, usually a short string unique in a pipeline. If not 

126 provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will 

127 be used. 

128 connections : `PipelineTaskConnections`, optional 

129 Object that describes the dataset types used by the task. If not 

130 provided, one will be constructed from the given configuration. If 

131 provided, it is assumed that ``config`` has already been validated 

132 and frozen. 

133 """ 

134 

135 def __init__( 

136 self, 

137 taskName: str | None = None, 

138 config: PipelineTaskConfig | None = None, 

139 taskClass: type[PipelineTask] | None = None, 

140 label: str | None = None, 

141 connections: PipelineTaskConnections | None = None, 

142 ): 

143 if taskName is None: 

144 if taskClass is None: 

145 raise ValueError("At least one of `taskName` and `taskClass` must be provided.") 

146 taskName = get_full_type_name(taskClass) 

147 elif taskClass is None: 

148 taskClass = doImportType(taskName) 

149 if config is None: 

150 if taskClass is None: 

151 raise ValueError("`taskClass` must be provided if `config` is not.") 

152 config = taskClass.ConfigClass() 

153 if label is None: 

154 if taskClass is None: 

155 raise ValueError("`taskClass` must be provided if `label` is not.") 

156 label = taskClass._DefaultName 

157 self.taskName = taskName 

158 if connections is None: 

159 # If we don't have connections yet, assume the config hasn't been 

160 # validated yet. 

161 try: 

162 config.validate() 

163 except Exception: 

164 _LOG.error("Configuration validation failed for task %s (%s)", label, taskName) 

165 raise 

166 config.freeze() 

167 connections = config.connections.ConnectionsClass(config=config) 

168 self.config = config 

169 self.taskClass = taskClass 

170 self.label = label 

171 self.connections = connections 

172 

173 @property 

174 def configDatasetName(self) -> str: 

175 """Name of a dataset type for configuration of this task (`str`)""" 

176 return acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=self.label) 

177 

178 @property 

179 def metadataDatasetName(self) -> str: 

180 """Name of a dataset type for metadata of this task (`str`)""" 

181 return self.makeMetadataDatasetName(self.label) 

182 

183 @classmethod 

184 def makeMetadataDatasetName(cls, label: str) -> str: 

185 """Construct the name of the dataset type for metadata for a task. 

186 

187 Parameters 

188 ---------- 

189 label : `str` 

190 Label for the task within its pipeline. 

191 

192 Returns 

193 ------- 

194 name : `str` 

195 Name of the task's metadata dataset type. 

196 """ 

197 return acc.METADATA_OUTPUT_TEMPLATE.format(label=label) 

198 

199 @property 

200 def logOutputDatasetName(self) -> str | None: 

201 """Name of a dataset type for log output from this task, `None` if 

202 logs are not to be saved (`str`) 

203 """ 

204 if self.config.saveLogOutput: 

205 return acc.LOG_OUTPUT_TEMPLATE.format(label=self.label) 

206 else: 

207 return None 

208 

209 def __str__(self) -> str: 

210 rep = "TaskDef(" + self.taskName 

211 if self.label: 

212 rep += ", label=" + self.label 

213 rep += ")" 

214 return rep 

215 

216 def __eq__(self, other: object) -> bool: 

217 if not isinstance(other, TaskDef): 

218 return False 

219 # This does not consider equality of configs when determining equality 

220 # as config equality is a difficult thing to define. Should be updated 

221 # after DM-27847 

222 return self.taskClass == other.taskClass and self.label == other.label 

223 

224 def __hash__(self) -> int: 

225 return hash((self.taskClass, self.label)) 

226 

227 @classmethod 

228 def _unreduce(cls, taskName: str, config: PipelineTaskConfig, label: str) -> TaskDef: 

229 """Unpickle pickle. Custom callable for unpickling. 

230 

231 All arguments are forwarded directly to the constructor; this 

232 trampoline is only needed because ``__reduce__`` callables can't be 

233 called with keyword arguments. 

234 """ 

235 return cls(taskName=taskName, config=config, label=label) 

236 

237 def __reduce__(self) -> tuple[Callable[[str, PipelineTaskConfig, str], TaskDef], tuple[str, Config, str]]: 

238 return (self._unreduce, (self.taskName, self.config, self.label)) 

239 

240 
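# Example (illustrative sketch; ``MyTask`` is a hypothetical PipelineTask
# subclass): a TaskDef validates and freezes the config and derives the
# per-task dataset type names from its label.
task_def = TaskDef(taskClass=MyTask, label="myTask")
print(task_def.configDatasetName)     # per-task config init-output dataset type
print(task_def.metadataDatasetName)   # per-task metadata dataset type
print(task_def.logOutputDatasetName)  # None if config.saveLogOutput is False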

241class Pipeline: 

242 """A `Pipeline` is a representation of a series of tasks to run, and the 

243 configuration for those tasks. 

244 

245 Parameters 

246 ---------- 

247 description : `str` 

248 A description of what this pipeline does.

249 """ 

250 

251 def __init__(self, description: str): 

252 pipeline_dict = {"description": description, "tasks": {}} 

253 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict) 

254 

255 @classmethod 

256 def fromFile(cls, filename: str) -> Pipeline: 

257 """Load a pipeline defined in a pipeline yaml file. 

258 

259 Parameters 

260 ---------- 

261 filename: `str` 

262 A path that points to a pipeline defined in yaml format. This 

263 filename may also supply additional labels to be used in 

264 subsetting the loaded Pipeline. These labels are separated from 

265 the path by a ``#``, and may be specified as a comma separated 

266 list, or a range denoted as beginning..end. Beginning or end may 

267 be empty, in which case the range will be a half open interval. 

268 Unlike python iteration bounds, end bounds are *INCLUDED*. Note 

269 that range based selection is not well defined for pipelines that 

270 are not linear in nature, and correct behavior is not guaranteed;

271 it may vary from run to run.

272 

273 Returns 

274 ------- 

275 pipeline: `Pipeline` 

276 The pipeline loaded from specified location with appropriate (if 

277 any) subsetting. 

278 

279 Notes 

280 ----- 

281 This method attempts to prune any contracts that contain labels which 

282 are not in the declared subset of labels. This pruning is done using a 

283 string based matching due to the nature of contracts and may prune more 

284 than it should. 

285 """ 

286 return cls.from_uri(filename) 

287 

288 @classmethod 

289 def from_uri(cls, uri: ResourcePathExpression) -> Pipeline: 

290 """Load a pipeline defined in a pipeline yaml file at a location 

291 specified by a URI. 

292 

293 Parameters 

294 ---------- 

295 uri : convertible to `~lsst.resources.ResourcePath` 

296 If a string is supplied this should be a URI path that points to a 

297 pipeline defined in yaml format, either as a direct path to the 

298 yaml file, or as a directory containing a ``pipeline.yaml`` file

299 (the form used by `write_to_uri` with ``expand=True``). This URI may

300 also supply additional labels to be used in subsetting the loaded 

301 `Pipeline`. These labels are separated from the path by a ``#``, 

302 and may be specified as a comma separated list, or a range denoted 

303 as beginning..end. Beginning or end may be empty, in which case the 

304 range will be a half open interval. Unlike python iteration bounds, 

305 end bounds are *INCLUDED*. Note that range based selection is not 

306 well defined for pipelines that are not linear in nature, and

307 correct behavior is not guaranteed; it may vary from run to run.

308 The same specifiers can be used with a 

309 `~lsst.resources.ResourcePath` object, by being the sole contents 

310 in the fragments attribute. 

311 

312 Returns 

313 ------- 

314 pipeline : `Pipeline` 

315 The pipeline loaded from specified location with appropriate (if 

316 any) subsetting. 

317 

318 Notes 

319 ----- 

320 This method attempts to prune any contracts that contain labels which 

321 are not in the declared subset of labels. This pruning is done using a 

322 string based matching due to the nature of contracts and may prune more 

323 than it should. 

324 """ 

325 # Split up the uri and any labels that were supplied 

326 uri, label_specifier = cls._parse_file_specifier(uri) 

327 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri)) 

328 

329 # If there are labels supplied, only keep those 

330 if label_specifier is not None: 

331 pipeline = pipeline.subsetFromLabels(label_specifier) 

332 return pipeline 

333 
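# Example (illustrative sketch; the path and labels are hypothetical): the
# ``#`` fragment syntax described above selects a subset at load time.
full = Pipeline.from_uri("pipelines/DRP.yaml")
by_list = Pipeline.from_uri("pipelines/DRP.yaml#isr,characterizeImage")
by_range = Pipeline.from_uri("pipelines/DRP.yaml#isr..calibrate")  # end label included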

334 def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline: 

335 """Subset a pipeline to contain only labels specified in labelSpecifier 

336 

337 Parameters 

338 ---------- 

339 labelSpecifier : `LabelSpecifier`

340 Object containing labels that describes how to subset a pipeline. 

341 

342 Returns 

343 ------- 

344 pipeline : `Pipeline` 

345 A new pipeline object that is a subset of the old pipeline 

346 

347 Raises 

348 ------ 

349 ValueError 

350 Raised if there is an issue with specified labels 

351 

352 Notes 

353 ----- 

354 This method attempts to prune any contracts that contain labels which 

355 are not in the declared subset of labels. This pruning is done using a 

356 string based matching due to the nature of contracts and may prune more 

357 than it should. 

358 """ 

359 # Labels supplied as a set 

360 if labelSpecifier.labels: 

361 labelSet = labelSpecifier.labels 

362 # Labels supplied as a range, first create a list of all the labels 

363 # in the pipeline sorted according to task dependency. Then only 

364 # keep labels that lie between the supplied bounds 

365 else: 

366 # Create a copy of the pipeline to use when assessing the label 

367 # ordering. Use a dict for fast searching while preserving order. 

368 # Remove contracts so they do not fail in the expansion step. This 

369 # is needed because a user may only configure the tasks they intend 

370 # to run, which may cause some contracts to fail if they will later 

371 # be dropped 

372 pipeline = copy.deepcopy(self) 

373 pipeline._pipelineIR.contracts = [] 

374 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()} 

375 

376 # Verify the bounds are in the labels 

377 if labelSpecifier.begin is not None: 

378 if labelSpecifier.begin not in labels: 

379 raise ValueError( 

380 f"Beginning of range subset, {labelSpecifier.begin}, not found in pipeline definition" 

381 ) 

382 if labelSpecifier.end is not None: 

383 if labelSpecifier.end not in labels: 

384 raise ValueError( 

385 f"End of range subset, {labelSpecifier.end}, not found in pipeline definition" 

386 ) 

387 

388 labelSet = set() 

389 for label in labels: 

390 if labelSpecifier.begin is not None: 

391 if label != labelSpecifier.begin: 

392 continue 

393 else: 

394 labelSpecifier.begin = None 

395 labelSet.add(label) 

396 if labelSpecifier.end is not None and label == labelSpecifier.end: 

397 break 

398 return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet)) 

399 
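# Example (illustrative sketch; ``pipeline`` and the labels are hypothetical):
# subsetting an already-loaded pipeline with a LabelSpecifier, equivalent to
# the ``#`` fragment forms accepted by from_uri.
subset = pipeline.subsetFromLabels(LabelSpecifier(begin="isr", end="calibrate"))
print(len(subset))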

400 @staticmethod 

401 def _parse_file_specifier(uri: ResourcePathExpression) -> tuple[ResourcePath, LabelSpecifier | None]: 

402 """Split appart a uri and any possible label subsets""" 

403 if isinstance(uri, str): 

404 # This is to support legacy pipelines during transition 

405 uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri) 

406 if num_replace: 

407 raise ValueError( 

408 f"The pipeline file {uri} seems to use the legacy :" 

409 " to separate labels, please use # instead." 

410 ) 

411 if uri.count("#") > 1: 

412 raise ValueError("Only one set of labels is allowed when specifying a pipeline to load") 

413 # Everything else can be converted directly to ResourcePath. 

414 uri = ResourcePath(uri) 

415 label_subset = uri.fragment or None 

416 

417 specifier: LabelSpecifier | None 

418 if label_subset is not None: 

419 label_subset = urllib.parse.unquote(label_subset) 

420 args: dict[str, set[str] | str | None] 

421 # labels supplied as a list 

422 if "," in label_subset: 

423 if ".." in label_subset: 

424 raise ValueError( 

425 "Can only specify a list of labels or a rangewhen loading a Pipline not both" 

426 ) 

427 args = {"labels": set(label_subset.split(","))} 

428 # labels supplied as a range 

429 elif ".." in label_subset: 

430 # Try to de-structure the labelSubset, this will fail if more 

431 # than one range is specified 

432 begin, end, *rest = label_subset.split("..") 

433 if rest: 

434 raise ValueError("Only one range can be specified when loading a pipeline") 

435 args = {"begin": begin if begin else None, "end": end if end else None} 

436 # Assume anything else is a single label 

437 else: 

438 args = {"labels": {label_subset}} 

439 

440 # MyPy doesn't like how cavalier kwarg construction is with types. 

441 specifier = LabelSpecifier(**args) # type: ignore 

442 else: 

443 specifier = None 

444 

445 return uri, specifier 

446 

447 @classmethod 

448 def fromString(cls, pipeline_string: str) -> Pipeline: 

449 """Create a pipeline from string formatted as a pipeline document. 

450 

451 Parameters 

452 ---------- 

453 pipeline_string : `str` 

454 A string that is formatted like a pipeline document.

455 

456 Returns 

457 ------- 

458 pipeline: `Pipeline` 

459 """ 

460 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string)) 

461 return pipeline 

462 

463 @classmethod 

464 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline: 

465 """Create a pipeline from an already created `PipelineIR` object. 

466 

467 Parameters 

468 ---------- 

469 deserialized_pipeline: `PipelineIR` 

470 An already created pipeline intermediate representation object 

471 

472 Returns 

473 ------- 

474 pipeline: `Pipeline` 

475 """ 

476 pipeline = cls.__new__(cls) 

477 pipeline._pipelineIR = deserialized_pipeline 

478 return pipeline 

479 

480 @classmethod 

481 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline: 

482 """Create a new pipeline by copying an already existing `Pipeline`. 

483 

484 Parameters 

485 ---------- 

486 pipeline: `Pipeline` 

487 An already created `Pipeline` object to be copied.

488 

489 Returns 

490 ------- 

491 pipeline: `Pipeline` 

492 """ 

493 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR)) 

494 

495 def __str__(self) -> str: 

496 return str(self._pipelineIR) 

497 

498 def mergePipeline(self, pipeline: Pipeline) -> None: 

499 """Merge another in-memory `Pipeline` object into this one. 

500 

501 This merges another pipeline into this object, as if it were declared 

502 in the import block of the yaml definition of this pipeline. This 

503 modifies this pipeline in place. 

504 

505 Parameters 

506 ---------- 

507 pipeline : `Pipeline` 

508 The `Pipeline` object that is to be merged into this object. 

509 """ 

510 self._pipelineIR.merge_pipelines((pipeline._pipelineIR,)) 

511 

512 def addLabelToSubset(self, subset: str, label: str) -> None: 

513 """Add a task label from the specified subset. 

514 

515 Parameters 

516 ---------- 

517 subset : `str` 

518 The labeled subset to modify 

519 label : `str` 

520 The task label to add to the specified subset. 

521 

522 Raises 

523 ------ 

524 ValueError 

525 Raised if the specified subset does not exist within the pipeline. 

526 Raised if the specified label does not exist within the pipeline. 

527 """ 

528 if label not in self._pipelineIR.tasks: 

529 raise ValueError(f"Label {label} does not appear within the pipeline") 

530 if subset not in self._pipelineIR.labeled_subsets: 

531 raise ValueError(f"Subset {subset} does not appear within the pipeline") 

532 self._pipelineIR.labeled_subsets[subset].subset.add(label) 

533 

534 def removeLabelFromSubset(self, subset: str, label: str) -> None: 

535 """Remove a task label from the specified subset. 

536 

537 Parameters 

538 ---------- 

539 subset : `str` 

540 The labeled subset to modify 

541 label : `str` 

542 The task label to remove from the specified subset. 

543 

544 Raises 

545 ------ 

546 ValueError 

547 Raised if the specified subset does not exist in the pipeline. 

548 Raised if the specified label does not exist within the specified 

549 subset. 

550 """ 

551 if subset not in self._pipelineIR.labeled_subsets: 

552 raise ValueError(f"Subset {subset} does not appear within the pipeline") 

553 if label not in self._pipelineIR.labeled_subsets[subset].subset: 

554 raise ValueError(f"Label {label} does not appear within the pipeline") 

555 self._pipelineIR.labeled_subsets[subset].subset.remove(label) 

556 

557 def findSubsetsWithLabel(self, label: str) -> set[str]: 

558 """Find any subsets which may contain the specified label. 

559 

560 This function returns the names of subsets which contain the specified

561 label. It may return an empty set if there are no subsets, or no subsets

562 containing the specified label. 

563 

564 Parameters 

565 ---------- 

566 label : `str` 

567 The task label to use in membership check 

568 

569 Returns 

570 ------- 

571 subsets : `set` of `str` 

572 Returns a set (possibly empty) of subset names which contain the

573 specified label. 

574 

575 Raises 

576 ------ 

577 ValueError 

578 Raised if the specified label does not exist within this pipeline. 

579 """ 

580 results = set() 

581 if label not in self._pipelineIR.tasks: 

582 raise ValueError(f"Label {label} does not appear within the pipeline") 

583 for subset in self._pipelineIR.labeled_subsets.values(): 

584 if label in subset.subset: 

585 results.add(subset.label) 

586 return results 

587 

588 def addInstrument(self, instrument: Instrument | str) -> None: 

589 """Add an instrument to the pipeline, or replace an instrument that is 

590 already defined. 

591 

592 Parameters 

593 ---------- 

594 instrument : `~lsst.obs.base.Instrument` or `str`

595 Either an instance of a subclass of `lsst.obs.base.Instrument` or

596 a string corresponding to a fully qualified

597 `lsst.obs.base.Instrument` subclass name.

598 """ 

599 if isinstance(instrument, str): 

600 pass 

601 else: 

602 # TODO: assume that this is a subclass of Instrument, no type 

603 # checking 

604 instrument = get_full_type_name(instrument) 

605 self._pipelineIR.instrument = instrument 

606 

607 def getInstrument(self) -> str | None: 

608 """Get the instrument from the pipeline. 

609 

610 Returns 

611 ------- 

612 instrument : `str` or `None`

613 The fully qualified name of a `lsst.obs.base.Instrument` subclass,

614 or `None` if the pipeline does not have an instrument.

615 """ 

616 return self._pipelineIR.instrument 

617 

618 def get_data_id(self, universe: DimensionUniverse) -> DataCoordinate: 

619 """Return a data ID with all dimension constraints embedded in the 

620 pipeline. 

621 

622 Parameters 

623 ---------- 

624 universe : `lsst.daf.butler.DimensionUniverse` 

625 Object that defines all dimensions. 

626 

627 Returns 

628 ------- 

629 data_id : `lsst.daf.butler.DataCoordinate` 

630 Data ID with all dimension constraints embedded in the 

631 pipeline. 

632 """ 

633 instrument_class_name = self._pipelineIR.instrument 

634 if instrument_class_name is not None: 

635 instrument_class = cast(PipeBaseInstrument, doImportType(instrument_class_name)) 

636 if instrument_class is not None: 

637 return DataCoordinate.standardize(instrument=instrument_class.getName(), universe=universe) 

638 return DataCoordinate.makeEmpty(universe) 

639 

640 def addTask(self, task: type[PipelineTask] | str, label: str) -> None: 

641 """Add a new task to the pipeline, or replace a task that is already 

642 associated with the supplied label. 

643 

644 Parameters 

645 ---------- 

646 task: `PipelineTask` or `str` 

647 Either a derived class object of a `PipelineTask` or a string 

648 corresponding to a fully qualified `PipelineTask` name. 

649 label: `str` 

650 A label that is used to identify the `PipelineTask` being added 

651 """ 

652 if isinstance(task, str): 

653 taskName = task 

654 elif issubclass(task, PipelineTask): 

655 taskName = get_full_type_name(task) 

656 else: 

657 raise ValueError( 

658 "task must be either a child class of PipelineTask or a string containing" 

659 " a fully qualified name to one" 

660 ) 

661 if not label: 

662 # in some cases (with a command line-generated pipeline) tasks can

663 # be defined without a label, which is not acceptable; use the task's

664 # _DefaultName in that case

665 if isinstance(task, str): 

666 task_class = cast(PipelineTask, doImportType(task)) 

667 label = task_class._DefaultName 

668 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName) 

669 

670 def removeTask(self, label: str) -> None: 

671 """Remove a task from the pipeline. 

672 

673 Parameters 

674 ---------- 

675 label : `str` 

676 The label used to identify the task that is to be removed 

677 

678 Raises 

679 ------ 

680 KeyError 

681 If no task with that label exists in the pipeline 

682 

683 """ 

684 self._pipelineIR.tasks.pop(label) 

685 

686 def addConfigOverride(self, label: str, key: str, value: object) -> None: 

687 """Apply single config override. 

688 

689 Parameters 

690 ---------- 

691 label : `str` 

692 Label of the task. 

693 key: `str` 

694 Fully-qualified field name. 

695 value : object 

696 Value to be given to a field. 

697 """ 

698 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value})) 

699 

700 def addConfigFile(self, label: str, filename: str) -> None: 

701 """Add overrides from a specified file. 

702 

703 Parameters 

704 ---------- 

705 label : `str` 

706 The label used to identify the task associated with config to 

707 modify 

708 filename : `str` 

709 Path to the override file. 

710 """ 

711 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename])) 

712 

713 def addConfigPython(self, label: str, pythonString: str) -> None: 

714 """Add Overrides by running a snippet of python code against a config. 

715 

716 Parameters 

717 ---------- 

718 label : `str` 

719 The label used to identity the task associated with config to 

720 modify. 

721 pythonString: `str` 

722 A string of valid Python code to be executed. This is done

723 with ``config`` as the only locally accessible value.

724 """ 

725 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString)) 

726 
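# Example (illustrative sketch; the task class path, label, and config fields
# are hypothetical): building and configuring a pipeline programmatically with
# the methods above.
p = Pipeline("Example pipeline")
p.addTask("lsst.example.tasks.ExampleTask", "example")
p.addConfigOverride("example", "threshold", 5.0)
p.addConfigFile("example", "example_overrides.py")
p.addConfigPython("example", "config.doSomething = True")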

727 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None: 

728 if label == "parameters": 

729 self._pipelineIR.parameters.mapping.update(newConfig.rest) 

730 if newConfig.file: 

731 raise ValueError("Setting parameters section with config file is not supported") 

732 if newConfig.python: 

733 raise ValueError("Setting parameters section using python block is not supported")

734 return 

735 if label not in self._pipelineIR.tasks: 

736 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline") 

737 self._pipelineIR.tasks[label].add_or_update_config(newConfig) 

738 

739 def write_to_uri(self, uri: ResourcePathExpression) -> None: 

740 """Write the pipeline to a file or directory. 

741 

742 Parameters 

743 ---------- 

744 uri : convertible to `~lsst.resources.ResourcePath` 

745 URI to write to; may have any scheme with 

746 `~lsst.resources.ResourcePath` write support or no scheme for a 

747 local file/directory. Should have a ``.yaml`` extension. 

748 """ 

749 self._pipelineIR.write_to_uri(uri) 

750 

751 def to_graph(self, registry: Registry | None = None) -> pipeline_graph.PipelineGraph: 

752 """Construct a pipeline graph from this pipeline. 

753 

754 Constructing a graph applies all configuration overrides, freezes all 

755 configuration, checks all contracts, and checks for dataset type 

756 consistency between tasks (as much as possible without access to a data 

757 repository). It cannot be reversed. 

758 

759 Parameters 

760 ---------- 

761 registry : `lsst.daf.butler.Registry`, optional 

762 Data repository client. If provided, the graph's dataset types 

763 and dimensions will be resolved (see `PipelineGraph.resolve`). 

764 

765 Returns 

766 ------- 

767 graph : `pipeline_graph.PipelineGraph` 

768 Representation of the pipeline as a graph. 

769 """ 

770 instrument_class_name = self._pipelineIR.instrument 

771 data_id = {} 

772 if instrument_class_name is not None: 

773 instrument_class: type[Instrument] = doImportType(instrument_class_name) 

774 if instrument_class is not None: 

775 data_id["instrument"] = instrument_class.getName() 

776 graph = pipeline_graph.PipelineGraph(data_id=data_id) 

777 graph.description = self._pipelineIR.description 

778 for label in self._pipelineIR.tasks: 

779 self._add_task_to_graph(label, graph) 

780 if self._pipelineIR.contracts is not None: 

781 label_to_config = {x.label: x.config for x in graph.tasks.values()} 

782 for contract in self._pipelineIR.contracts: 

783 # execute this in its own line so it can raise a good error 

784 # message if there were problems with the eval

785 success = eval(contract.contract, None, label_to_config) 

786 if not success: 

787 extra_info = f": {contract.msg}" if contract.msg is not None else "" 

788 raise pipelineIR.ContractError( 

789 f"Contract(s) '{contract.contract}' were not satisfied{extra_info}" 

790 ) 

791 for label, subset in self._pipelineIR.labeled_subsets.items(): 

792 graph.add_task_subset( 

793 label, subset.subset, subset.description if subset.description is not None else "" 

794 ) 

795 graph.sort() 

796 if registry is not None: 

797 graph.resolve(registry) 

798 return graph 

799 
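# Example (illustrative sketch; the path is hypothetical and ``butler`` is an
# assumed, already-constructed lsst.daf.butler.Butler): expanding a pipeline
# into a PipelineGraph, optionally resolving dataset types against a registry.
pipeline = Pipeline.from_uri("pipelines/DRP.yaml")
graph = pipeline.to_graph(registry=butler.registry)
for task_node in graph.tasks.values():
    print(task_node.label)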

800 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]: 

801 r"""Return a generator of `TaskDef`\s which can be used to create 

802 quantum graphs. 

803 

804 Returns 

805 ------- 

806 generator : generator of `TaskDef` 

807 The generator yields tasks in sorted order, as they are to be

808 used in constructing a quantum graph.

809 

810 Raises 

811 ------ 

812 NotImplementedError 

813 If a dataId is supplied in a config block. This is in place for 

814 future use 

815 """ 

816 yield from self.to_graph()._iter_task_defs() 

817 

818 def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> None: 

819 """Add a single task from this pipeline to a pipeline graph that is 

820 under construction. 

821 

822 Parameters 

823 ---------- 

824 label : `str` 

825 Label for the task to be added. 

826 graph : `pipeline_graph.PipelineGraph` 

827 Graph to add the task to. 

828 """ 

829 if (taskIR := self._pipelineIR.tasks.get(label)) is None: 

830 raise NameError(f"Label {label} does not appear in this pipeline") 

831 taskClass: type[PipelineTask] = doImportType(taskIR.klass) 

832 config = taskClass.ConfigClass() 

833 instrument: PipeBaseInstrument | None = None 

834 if (instrumentName := self._pipelineIR.instrument) is not None: 

835 instrument_cls: type = doImportType(instrumentName) 

836 instrument = instrument_cls() 

837 config.applyConfigOverrides( 

838 instrument, 

839 getattr(taskClass, "_DefaultName", ""), 

840 taskIR.config, 

841 self._pipelineIR.parameters, 

842 label, 

843 ) 

844 graph.add_task(label, taskClass, config) 

845 

846 def __iter__(self) -> Generator[TaskDef, None, None]: 

847 return self.toExpandedPipeline() 

848 

849 def __getitem__(self, item: str) -> TaskDef: 

850 # Making a whole graph and then making a TaskDef from that is pretty 

851 # backwards, but I'm hoping to deprecate this method shortly in favor 

852 # of making the graph explicitly and working with its node objects. 

853 graph = pipeline_graph.PipelineGraph() 

854 self._add_task_to_graph(item, graph) 

855 (result,) = graph._iter_task_defs() 

856 return result 

857 

858 def __len__(self) -> int: 

859 return len(self._pipelineIR.tasks) 

860 

861 def __eq__(self, other: object) -> bool: 

862 if not isinstance(other, Pipeline): 

863 return False 

864 elif self._pipelineIR == other._pipelineIR: 

865 # Shortcut: if the IR is the same, the expanded pipeline must be 

866 # the same as well. But the converse is not true. 

867 return True 

868 else: 

869 self_expanded = {td.label: (td.taskClass,) for td in self} 

870 other_expanded = {td.label: (td.taskClass,) for td in other} 

871 if self_expanded != other_expanded: 

872 return False 

873 # After DM-27847, we should compare configuration here, or better, 

874 # delegate to TaskDef.__eq__ after making that compare configurations.

875 raise NotImplementedError( 

876 "Pipelines cannot be compared because config instances cannot be compared; see DM-27847." 

877 ) 

878 

879 
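# Example (illustrative sketch; the path and label are hypothetical): Pipeline
# supports len(), indexing by label, and iteration over expanded TaskDefs.
pipeline = Pipeline.from_uri("pipelines/DRP.yaml")
print(len(pipeline))        # number of task labels
isr_def = pipeline["isr"]   # TaskDef for a single label
for task_def in pipeline:   # sorted TaskDefs from toExpandedPipeline()
    print(task_def.label, task_def.taskName)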

880@dataclass(frozen=True) 

881class TaskDatasetTypes: 

882 """An immutable struct that extracts and classifies the dataset types used 

883 by a `PipelineTask` 

884 """ 

885 

886 initInputs: NamedValueSet[DatasetType] 

887 """Dataset types that are needed as inputs in order to construct this Task. 

888 

889 Task-level `initInputs` may be classified as either 

890 `~PipelineDatasetTypes.initInputs` or 

891 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level. 

892 """ 

893 

894 initOutputs: NamedValueSet[DatasetType] 

895 """Dataset types that may be written after constructing this Task. 

896 

897 Task-level `initOutputs` may be classified as either 

898 `~PipelineDatasetTypes.initOutputs` or 

899 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level. 

900 """ 

901 

902 inputs: NamedValueSet[DatasetType] 

903 """Dataset types that are regular inputs to this Task. 

904 

905 If an input dataset needed for a Quantum cannot be found in the input 

906 collection(s) or produced by another Task in the Pipeline, that Quantum 

907 (and all dependent Quanta) will not be produced. 

908 

909 Task-level `inputs` may be classified as either 

910 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates` 

911 at the Pipeline level. 

912 """ 

913 

914 queryConstraints: NamedValueSet[DatasetType] 

915 """Regular inputs that should not be used as constraints on the initial 

916 QuantumGraph generation data ID query, according to their tasks 

917 (`NamedValueSet`). 

918 """ 

919 

920 prerequisites: NamedValueSet[DatasetType] 

921 """Dataset types that are prerequisite inputs to this Task. 

922 

923 Prerequisite inputs must exist in the input collection(s) before the 

924 pipeline is run, but do not constrain the graph - if a prerequisite is 

925 missing for a Quantum, `PrerequisiteMissingError` is raised. 

926 

927 Prerequisite inputs are not resolved until the second stage of 

928 QuantumGraph generation. 

929 """ 

930 

931 outputs: NamedValueSet[DatasetType] 

932 """Dataset types that are produced by this Task. 

933 

934 Task-level `outputs` may be classified as either 

935 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates` 

936 at the Pipeline level. 

937 """ 

938 

939 @classmethod 

940 def fromTaskDef( 

941 cls, 

942 taskDef: TaskDef, 

943 *, 

944 registry: Registry, 

945 include_configs: bool = True, 

946 storage_class_mapping: Mapping[str, str] | None = None, 

947 ) -> TaskDatasetTypes: 

948 """Extract and classify the dataset types from a single `PipelineTask`. 

949 

950 Parameters 

951 ---------- 

952 taskDef: `TaskDef` 

953 An instance of a `TaskDef` class for a particular `PipelineTask`. 

954 registry: `Registry` 

955 Registry used to construct normalized 

956 `~lsst.daf.butler.DatasetType` objects and retrieve those that are 

957 incomplete. 

958 include_configs : `bool`, optional 

959 If `True` (default) include config dataset types as 

960 ``initOutputs``. 

961 storage_class_mapping : `~collections.abc.Mapping` of `str` to \ 

962 `~lsst.daf.butler.StorageClass`, optional 

963 If a taskdef contains a component dataset type that is unknown 

964 to the registry, its parent `~lsst.daf.butler.StorageClass` will 

965 be looked up in this mapping if it is supplied. If the mapping does 

966 not contain the composite dataset type, or the mapping is not 

967 supplied, an exception will be raised.

968 

969 Returns 

970 ------- 

971 types: `TaskDatasetTypes` 

972 The dataset types used by this task. 

973 

974 Raises 

975 ------ 

976 ValueError 

977 Raised if dataset type connection definition differs from 

978 registry definition. 

979 LookupError 

980 Raised if component parent StorageClass could not be determined 

981 and storage_class_mapping does not contain the composite type, or 

982 is set to None. 

983 """ 

984 

985 def makeDatasetTypesSet( 

986 connectionType: str, 

987 is_input: bool, 

988 freeze: bool = True, 

989 ) -> NamedValueSet[DatasetType]: 

990 """Construct a set of true `~lsst.daf.butler.DatasetType` objects. 

991 

992 Parameters 

993 ---------- 

994 connectionType : `str` 

995 Name of the connection type to produce a set for, corresponds 

996 to an attribute of type `list` on the connection class instance 

997 is_input : `bool` 

998 If `True`, these are input dataset types; otherwise they are

999 output dataset types.

1000 freeze : `bool`, optional 

1001 If `True`, call `NamedValueSet.freeze` on the object returned. 

1002 

1003 Returns 

1004 ------- 

1005 datasetTypes : `NamedValueSet` 

1006 A set of all datasetTypes which correspond to the input 

1007 connection type specified in the connection class of this 

1008 `PipelineTask` 

1009 

1010 Raises 

1011 ------ 

1012 ValueError 

1013 Raised if dataset type connection definition differs from 

1014 registry definition. 

1015 LookupError 

1016 Raised if component parent StorageClass could not be determined 

1017 and storage_class_mapping does not contain the composite type, 

1018 or is set to None. 

1019 

1020 Notes 

1021 ----- 

1022 This function is a closure over the variables ``registry``,

1023 ``taskDef``, and ``storage_class_mapping``.

1024 """ 

1025 datasetTypes = NamedValueSet[DatasetType]() 

1026 for c in iterConnections(taskDef.connections, connectionType): 

1027 dimensions = set(getattr(c, "dimensions", set())) 

1028 if "skypix" in dimensions: 

1029 try: 

1030 datasetType = registry.getDatasetType(c.name) 

1031 except LookupError as err: 

1032 raise LookupError( 

1033 f"DatasetType '{c.name}' referenced by " 

1034 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension " 

1035 "placeholder, but does not already exist in the registry. " 

1036 "Note that reference catalog names are now used as the dataset " 

1037 "type name instead of 'ref_cat'." 

1038 ) from err 

1039 rest1 = set(registry.dimensions.extract(dimensions - {"skypix"}).names) 

1040 rest2 = { 

1041 dim.name for dim in datasetType.dimensions if not isinstance(dim, SkyPixDimension) 

1042 } 

1043 if rest1 != rest2: 

1044 raise ValueError( 

1045 f"Non-skypix dimensions for dataset type {c.name} declared in " 

1046 f"connections ({rest1}) are inconsistent with those in " 

1047 f"registry's version of this dataset ({rest2})." 

1048 ) 

1049 else: 

1050 # Component dataset types are not explicitly in the 

1051 # registry. This complicates consistency checks with 

1052 # registry and requires we work out the composite storage 

1053 # class. 

1054 registryDatasetType = None 

1055 try: 

1056 registryDatasetType = registry.getDatasetType(c.name) 

1057 except KeyError: 

1058 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name) 

1059 if componentName: 

1060 if storage_class_mapping is None or compositeName not in storage_class_mapping: 

1061 raise LookupError( 

1062 "Component parent class cannot be determined, and " 

1063 "composite name was not in storage class mapping, or no " 

1064 "storage_class_mapping was supplied" 

1065 ) from None 

1066 else: 

1067 parentStorageClass = storage_class_mapping[compositeName] 

1068 else: 

1069 parentStorageClass = None 

1070 datasetType = c.makeDatasetType( 

1071 registry.dimensions, parentStorageClass=parentStorageClass 

1072 ) 

1073 registryDatasetType = datasetType 

1074 else: 

1075 datasetType = c.makeDatasetType( 

1076 registry.dimensions, parentStorageClass=registryDatasetType.parentStorageClass 

1077 ) 

1078 

1079 if registryDatasetType and datasetType != registryDatasetType: 

1080 # The dataset types differ but first check to see if 

1081 # they are compatible before raising. 

1082 if is_input: 

1083 # This DatasetType must be compatible on get. 

1084 is_compatible = datasetType.is_compatible_with(registryDatasetType) 

1085 else: 

1086 # Has to be able to be converted to the expected type

1087 # on put. 

1088 is_compatible = registryDatasetType.is_compatible_with(datasetType) 

1089 if is_compatible: 

1090 # For inputs we want the pipeline to use the 

1091 # pipeline definition; for outputs it should use

1092 # the registry definition. 

1093 if not is_input: 

1094 datasetType = registryDatasetType 

1095 _LOG.debug( 

1096 "Dataset types differ (task %s != registry %s) but are compatible" 

1097 " for %s in %s.", 

1098 datasetType, 

1099 registryDatasetType, 

1100 "input" if is_input else "output", 

1101 taskDef.label, 

1102 ) 

1103 else: 

1104 try: 

1105 # Explicitly check for storage class just to 

1106 # make more specific message. 

1107 _ = datasetType.storageClass 

1108 except KeyError: 

1109 raise ValueError( 

1110 "Storage class does not exist for supplied dataset type " 

1111 f"{datasetType} for {taskDef.label}." 

1112 ) from None 

1113 raise ValueError( 

1114 f"Supplied dataset type ({datasetType}) inconsistent with " 

1115 f"registry definition ({registryDatasetType}) " 

1116 f"for {taskDef.label}." 

1117 ) 

1118 datasetTypes.add(datasetType) 

1119 if freeze: 

1120 datasetTypes.freeze() 

1121 return datasetTypes 

1122 

1123 # optionally add initOutput dataset for config 

1124 initOutputs = makeDatasetTypesSet("initOutputs", is_input=False, freeze=False) 

1125 if include_configs: 

1126 initOutputs.add( 

1127 DatasetType( 

1128 taskDef.configDatasetName, 

1129 registry.dimensions.empty, 

1130 storageClass=acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, 

1131 ) 

1132 ) 

1133 initOutputs.freeze() 

1134 

1135 # optionally add output dataset for metadata 

1136 outputs = makeDatasetTypesSet("outputs", is_input=False, freeze=False) 

1137 

1138 # Metadata is supposed to be of the TaskMetadata type; its dimensions

1139 # correspond to a task quantum. 

1140 dimensions = registry.dimensions.extract(taskDef.connections.dimensions) 

1141 

1142 # Allow the storage class definition to be read from the existing 

1143 # dataset type definition if present. 

1144 try: 

1145 current = registry.getDatasetType(taskDef.metadataDatasetName) 

1146 except KeyError: 

1147 # No previous definition so use the default. 

1148 storageClass = acc.METADATA_OUTPUT_STORAGE_CLASS 

1149 else: 

1150 storageClass = current.storageClass.name 

1151 outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)}) 

1152 

1153 if taskDef.logOutputDatasetName is not None: 

1154 # Log output dimensions correspond to a task quantum. 

1155 dimensions = registry.dimensions.extract(taskDef.connections.dimensions) 

1156 outputs.update( 

1157 { 

1158 DatasetType( 

1159 taskDef.logOutputDatasetName, 

1160 dimensions, 

1161 acc.LOG_OUTPUT_STORAGE_CLASS, 

1162 ) 

1163 } 

1164 ) 

1165 

1166 outputs.freeze() 

1167 

1168 inputs = makeDatasetTypesSet("inputs", is_input=True) 

1169 queryConstraints = NamedValueSet( 

1170 inputs[c.name] 

1171 for c in cast(Iterable[Input], iterConnections(taskDef.connections, "inputs")) 

1172 if not c.deferGraphConstraint 

1173 ) 

1174 

1175 return cls( 

1176 initInputs=makeDatasetTypesSet("initInputs", is_input=True), 

1177 initOutputs=initOutputs, 

1178 inputs=inputs, 

1179 queryConstraints=queryConstraints, 

1180 prerequisites=makeDatasetTypesSet("prerequisiteInputs", is_input=True), 

1181 outputs=outputs, 

1182 ) 

1183 

1184 
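# Example (illustrative sketch; ``pipeline`` and the "isr" label are
# hypothetical, and ``butler`` is an assumed lsst.daf.butler.Butler):
# classifying the dataset types of a single task.
task_types = TaskDatasetTypes.fromTaskDef(pipeline["isr"], registry=butler.registry)
print(task_types.inputs.names)
print(task_types.outputs.names)
print(task_types.prerequisites.names)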

1185@dataclass(frozen=True) 

1186class PipelineDatasetTypes: 

1187 """An immutable struct that classifies the dataset types used in a 

1188 `Pipeline`. 

1189 """ 

1190 

1191 packagesDatasetName: ClassVar[str] = acc.PACKAGES_INIT_OUTPUT_NAME 

1192 """Name of a dataset type used to save package versions. 

1193 """ 

1194 

1195 initInputs: NamedValueSet[DatasetType] 

1196 """Dataset types that are needed as inputs in order to construct the Tasks 

1197 in this Pipeline. 

1198 

1199 This does not include dataset types that are produced when constructing 

1200 other Tasks in the Pipeline (these are classified as `initIntermediates`). 

1201 """ 

1202 

1203 initOutputs: NamedValueSet[DatasetType] 

1204 """Dataset types that may be written after constructing the Tasks in this 

1205 Pipeline. 

1206 

1207 This does not include dataset types that are also used as inputs when 

1208 constructing other Tasks in the Pipeline (these are classified as 

1209 `initIntermediates`). 

1210 """ 

1211 

1212 initIntermediates: NamedValueSet[DatasetType] 

1213 """Dataset types that are both used when constructing one or more Tasks 

1214 in the Pipeline and produced as a side-effect of constructing another 

1215 Task in the Pipeline. 

1216 """ 

1217 

1218 inputs: NamedValueSet[DatasetType] 

1219 """Dataset types that are regular inputs for the full pipeline. 

1220 

1221 If an input dataset needed for a Quantum cannot be found in the input 

1222 collection(s), that Quantum (and all dependent Quanta) will not be 

1223 produced. 

1224 """ 

1225 

1226 queryConstraints: NamedValueSet[DatasetType] 

1227 """Regular inputs that should be used as constraints on the initial 

1228 QuantumGraph generation data ID query, according to their tasks 

1229 (`NamedValueSet`). 

1230 """ 

1231 

1232 prerequisites: NamedValueSet[DatasetType] 

1233 """Dataset types that are prerequisite inputs for the full Pipeline. 

1234 

1235 Prerequisite inputs must exist in the input collection(s) before the 

1236 pipeline is run, but do not constrain the graph - if a prerequisite is 

1237 missing for a Quantum, `PrerequisiteMissingError` is raised. 

1238 

1239 Prerequisite inputs are not resolved until the second stage of 

1240 QuantumGraph generation. 

1241 """ 

1242 

1243 intermediates: NamedValueSet[DatasetType] 

1244 """Dataset types that are output by one Task in the Pipeline and consumed 

1245 as inputs by one or more other Tasks in the Pipeline. 

1246 """ 

1247 

1248 outputs: NamedValueSet[DatasetType] 

1249 """Dataset types that are output by a Task in the Pipeline and not consumed 

1250 by any other Task in the Pipeline. 

1251 """ 

1252 

1253 byTask: Mapping[str, TaskDatasetTypes] 

1254 """Per-Task dataset types, keyed by label in the `Pipeline`. 

1255 

1256 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming 

1257 neither has been modified since the dataset types were extracted, of 

1258 course). 

1259 """ 

1260 

1261 @classmethod 

1262 def fromPipeline( 

1263 cls, 

1264 pipeline: Pipeline | Iterable[TaskDef], 

1265 *, 

1266 registry: Registry, 

1267 include_configs: bool = True, 

1268 include_packages: bool = True, 

1269 ) -> PipelineDatasetTypes: 

1270 """Extract and classify the dataset types from all tasks in a 

1271 `Pipeline`. 

1272 

1273 Parameters 

1274 ---------- 

1275 pipeline: `Pipeline` or `~collections.abc.Iterable` [ `TaskDef` ] 

1276 A collection of tasks that can be run together. 

1277 registry: `Registry` 

1278 Registry used to construct normalized 

1279 `~lsst.daf.butler.DatasetType` objects and retrieve those that are 

1280 incomplete. 

1281 include_configs : `bool`, optional 

1282 If `True` (default) include config dataset types as 

1283 ``initOutputs``. 

1284 include_packages : `bool`, optional 

1285 If `True` (default) include the dataset type for software package 

1286 versions in ``initOutputs``. 

1287 

1288 Returns 

1289 ------- 

1290 types: `PipelineDatasetTypes` 

1291 The dataset types used by this `Pipeline`. 

1292 

1293 Raises 

1294 ------ 

1295 ValueError 

1296 Raised if Tasks are inconsistent about which datasets are marked 

1297 prerequisite. This indicates that the Tasks cannot be run as part 

1298 of the same `Pipeline`. 

1299 """ 

1300 allInputs = NamedValueSet[DatasetType]() 

1301 allOutputs = NamedValueSet[DatasetType]() 

1302 allInitInputs = NamedValueSet[DatasetType]() 

1303 allInitOutputs = NamedValueSet[DatasetType]() 

1304 prerequisites = NamedValueSet[DatasetType]() 

1305 queryConstraints = NamedValueSet[DatasetType]() 

1306 byTask = dict() 

1307 if include_packages: 

1308 allInitOutputs.add( 

1309 DatasetType( 

1310 cls.packagesDatasetName, 

1311 registry.dimensions.empty, 

1312 storageClass=acc.PACKAGES_INIT_OUTPUT_STORAGE_CLASS, 

1313 ) 

1314 ) 

1315 # create a list of TaskDefs in case the input is a generator 

1316 pipeline = list(pipeline) 

1317 

1318 # collect all the output dataset types 

1319 typeStorageclassMap: dict[str, str] = {} 

1320 for taskDef in pipeline: 

1321 for outConnection in iterConnections(taskDef.connections, "outputs"): 

1322 typeStorageclassMap[outConnection.name] = outConnection.storageClass 

1323 

1324 for taskDef in pipeline: 

1325 thisTask = TaskDatasetTypes.fromTaskDef( 

1326 taskDef, 

1327 registry=registry, 

1328 include_configs=include_configs, 

1329 storage_class_mapping=typeStorageclassMap, 

1330 ) 

1331 allInitInputs.update(thisTask.initInputs) 

1332 allInitOutputs.update(thisTask.initOutputs) 

1333 allInputs.update(thisTask.inputs) 

1334 # Inputs are query constraints if any task considers them a query 

1335 # constraint. 

1336 queryConstraints.update(thisTask.queryConstraints) 

1337 prerequisites.update(thisTask.prerequisites) 

1338 allOutputs.update(thisTask.outputs) 

1339 byTask[taskDef.label] = thisTask 

1340 if not prerequisites.isdisjoint(allInputs): 

1341 raise ValueError( 

1342 "{} marked as both prerequisites and regular inputs".format( 

1343 {dt.name for dt in allInputs & prerequisites} 

1344 ) 

1345 ) 

1346 if not prerequisites.isdisjoint(allOutputs): 

1347 raise ValueError( 

1348 "{} marked as both prerequisites and outputs".format( 

1349 {dt.name for dt in allOutputs & prerequisites} 

1350 ) 

1351 ) 

1352 # Make sure that components which are marked as inputs get treated as 

1353 # intermediates if there is an output which produces the composite 

1354 # containing the component 

1355 intermediateComponents = NamedValueSet[DatasetType]() 

1356 intermediateComposites = NamedValueSet[DatasetType]() 

1357 outputNameMapping = {dsType.name: dsType for dsType in allOutputs} 

1358 for dsType in allInputs: 

1359 # get the name of a possible component 

1360 name, component = dsType.nameAndComponent() 

1361 # if there is a component name, that means this is a component

1362 # DatasetType; if there is an output which produces the parent of

1363 # this component, treat this input as an intermediate

1364 if component is not None: 

1365 # This needs to be in this if block, because someone might have 

1366 # a composite that is a pure input from existing data 

1367 if name in outputNameMapping: 

1368 intermediateComponents.add(dsType) 

1369 intermediateComposites.add(outputNameMapping[name]) 

1370 

1371 def checkConsistency(a: NamedValueSet, b: NamedValueSet) -> None: 

1372 common = a.names & b.names 

1373 for name in common: 

1374 # Any compatibility is allowed. This function does not know 

1375 # if a dataset type is to be used for input or output. 

1376 if not (a[name].is_compatible_with(b[name]) or b[name].is_compatible_with(a[name])): 

1377 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.") 

1378 

1379 checkConsistency(allInitInputs, allInitOutputs) 

1380 checkConsistency(allInputs, allOutputs) 

1381 checkConsistency(allInputs, intermediateComposites) 

1382 checkConsistency(allOutputs, intermediateComposites) 

1383 

1384 def frozen(s: Set[DatasetType]) -> NamedValueSet[DatasetType]: 

1385 assert isinstance(s, NamedValueSet) 

1386 s.freeze() 

1387 return s 

1388 

1389 inputs = frozen(allInputs - allOutputs - intermediateComponents) 

1390 

1391 return cls( 

1392 initInputs=frozen(allInitInputs - allInitOutputs), 

1393 initIntermediates=frozen(allInitInputs & allInitOutputs), 

1394 initOutputs=frozen(allInitOutputs - allInitInputs), 

1395 inputs=inputs, 

1396 queryConstraints=frozen(queryConstraints & inputs), 

1397 # If there are storage class differences in inputs and outputs,

1398 # the intermediates have to choose a priority. Here choose that

1399 # inputs to tasks must match the requested storage class by

1400 # applying the inputs over the top of the outputs.

1401 intermediates=frozen(allOutputs & allInputs | intermediateComponents), 

1402 outputs=frozen(allOutputs - allInputs - intermediateComposites), 

1403 prerequisites=frozen(prerequisites), 

1404 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability 

1405 ) 

1406 
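# Example (illustrative sketch; ``pipeline`` is hypothetical and ``butler`` is
# an assumed lsst.daf.butler.Butler): classifying dataset types across a whole
# pipeline, distinguishing overall inputs from intermediates and outputs.
all_types = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
print(all_types.inputs.names)         # needed from the input collections
print(all_types.intermediates.names)  # produced and consumed within the pipeline
print(all_types.outputs.names)        # produced but not consumed in the pipeline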

1407 @classmethod 

1408 def initOutputNames( 

1409 cls, 

1410 pipeline: Pipeline | Iterable[TaskDef], 

1411 *, 

1412 include_configs: bool = True, 

1413 include_packages: bool = True, 

1414 ) -> Iterator[str]: 

1415 """Return the names of dataset types ot task initOutputs, Configs, 

1416 and package versions for a pipeline. 

1417 

1418 Parameters 

1419 ---------- 

1420 pipeline: `Pipeline` or `~collections.abc.Iterable` [ `TaskDef` ] 

1421 A `Pipeline` instance or collection of `TaskDef` instances. 

1422 include_configs : `bool`, optional 

1423 If `True` (default) include config dataset types. 

1424 include_packages : `bool`, optional 

1425 If `True` (default) include the dataset type for package versions. 

1426 

1427 Yields 

1428 ------ 

1429 datasetTypeName : `str` 

1430 Name of the dataset type. 

1431 """ 

1432 if include_packages: 

1433 # Package versions dataset type 

1434 yield cls.packagesDatasetName 

1435 

1436 if isinstance(pipeline, Pipeline): 

1437 pipeline = pipeline.toExpandedPipeline() 

1438 

1439 for taskDef in pipeline: 

1440 # all task InitOutputs 

1441 for name in taskDef.connections.initOutputs: 

1442 attribute = getattr(taskDef.connections, name) 

1443 yield attribute.name 

1444 

1445 # config dataset name 

1446 if include_configs: 

1447 yield taskDef.configDatasetName
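# Example (illustrative sketch; ``pipeline`` is hypothetical): listing
# init-output dataset type names, including configs and package versions,
# without needing a registry.
for name in PipelineDatasetTypes.initOutputNames(pipeline):
    print(name)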