Coverage for python/lsst/pipe/base/pipeline.py: 23%

454 statements  


1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28"""Module defining Pipeline class and related methods. 

29""" 

30 

31from __future__ import annotations 

32 

33__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"] 

34 

35import copy 

36import logging 

37import re 

38import urllib.parse 

39 

40# ------------------------------- 

41# Imports of standard modules -- 

42# ------------------------------- 

43from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Set 

44from dataclasses import dataclass 

45from types import MappingProxyType 

46from typing import TYPE_CHECKING, ClassVar, cast 

47 

48# ----------------------------- 

49# Imports for other modules -- 

50from lsst.daf.butler import ( 

51 DataCoordinate, 

52 DatasetType, 

53 DimensionUniverse, 

54 NamedValueSet, 

55 Registry, 

56 SkyPixDimension, 

57) 

58from lsst.resources import ResourcePath, ResourcePathExpression 

59from lsst.utils import doImportType 

60from lsst.utils.introspection import get_full_type_name 

61 

62from . import automatic_connection_constants as acc 

63from . import pipeline_graph, pipelineIR 

64from ._instrument import Instrument as PipeBaseInstrument 

65from .config import PipelineTaskConfig 

66from .connections import PipelineTaskConnections, iterConnections 

67from .connectionTypes import Input 

68from .pipelineTask import PipelineTask 

69 

70if TYPE_CHECKING: # Imports needed only for type annotations; may be circular. 

71 from lsst.obs.base import Instrument 

72 from lsst.pex.config import Config 

73 

74# ---------------------------------- 

75# Local non-exported definitions -- 

76# ---------------------------------- 

77 

78_LOG = logging.getLogger(__name__) 

79 

80# ------------------------ 

81# Exported definitions -- 

82# ------------------------ 

83 

84 

85@dataclass 

86class LabelSpecifier: 

87 """A structure to specify a subset of labels to load 

88 

89 This structure may contain a set of labels to be used in subsetting a 

90 pipeline, or a beginning and end point. Beginning or end may be empty, 

91 in which case the range will be a half open interval. Unlike python 

92 iteration bounds, end bounds are *INCLUDED*. Note that range based 

93 selection is not well defined for pipelines that are not linear in nature, 

94 and correct behavior is not guaranteed, or may vary from run to run. 

95 """ 

96 

97 labels: set[str] | None = None 

98 begin: str | None = None 

99 end: str | None = None 

100 

101 def __post_init__(self) -> None: 

102 if self.labels is not None and (self.begin or self.end): 

103 raise ValueError( 

104 "This struct can only be initialized with a labels set or a begin (and/or) end specifier" 

105 ) 
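
# A minimal usage sketch of `LabelSpecifier` (the task labels below are
# hypothetical, for illustration only):
#
#     LabelSpecifier(labels={"isr", "calibrate"})     # explicit set of labels
#     LabelSpecifier(begin="isr", end="calibrate")    # inclusive label range
#     LabelSpecifier(labels={"isr"}, begin="isr")     # raises ValueError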

106 

107 

108class TaskDef: 

109 """TaskDef is a collection of information about task needed by Pipeline. 

110 

111 The information includes the task name, the configuration object, and an

112 optional task class. This class is just a collection of attributes, all of

113 which are exposed so that they can be modified in place (e.g. if the

114 configuration needs extra overrides).

115 

116 Attributes 

117 ---------- 

118 taskName : `str`, optional 

119 The fully-qualified `PipelineTask` class name. If not provided, 

120 ``taskClass`` must be. 

121 config : `lsst.pipe.base.config.PipelineTaskConfig`, optional 

122 Instance of the configuration class corresponding to this task class, 

123 usually with all overrides applied. This config will be frozen. If 

124 not provided, ``taskClass`` must be provided and 

125 ``taskClass.ConfigClass()`` will be used. 

126 taskClass : `type`, optional 

127 `PipelineTask` class object; if provided and ``taskName`` is as well, 

128 the caller guarantees that they are consistent. If not provided, 

129 ``taskName`` is used to import the type. 

130 label : `str`, optional 

131 Task label, usually a short string unique in a pipeline. If not 

132 provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will 

133 be used. 

134 connections : `PipelineTaskConnections`, optional 

135 Object that describes the dataset types used by the task. If not 

136 provided, one will be constructed from the given configuration. If 

137 provided, it is assumed that ``config`` has already been validated 

138 and frozen. 

139 """ 

140 

141 def __init__( 

142 self, 

143 taskName: str | None = None, 

144 config: PipelineTaskConfig | None = None, 

145 taskClass: type[PipelineTask] | None = None, 

146 label: str | None = None, 

147 connections: PipelineTaskConnections | None = None, 

148 ): 

149 if taskName is None: 

150 if taskClass is None: 

151 raise ValueError("At least one of `taskName` and `taskClass` must be provided.") 

152 taskName = get_full_type_name(taskClass) 

153 elif taskClass is None: 

154 taskClass = doImportType(taskName) 

155 if config is None: 

156 if taskClass is None: 

157 raise ValueError("`taskClass` must be provided if `config` is not.") 

158 config = taskClass.ConfigClass() 

159 if label is None: 

160 if taskClass is None: 

161 raise ValueError("`taskClass` must be provided if `label` is not.") 

162 label = taskClass._DefaultName 

163 self.taskName = taskName 

164 if connections is None: 

165 # If we don't have connections yet, assume the config hasn't been 

166 # validated yet. 

167 try: 

168 config.validate() 

169 except Exception: 

170 _LOG.error("Configuration validation failed for task %s (%s)", label, taskName) 

171 raise 

172 config.freeze() 

173 connections = config.connections.ConnectionsClass(config=config) 

174 self.config = config 

175 self.taskClass = taskClass 

176 self.label = label 

177 self.connections = connections 

178 

179 @property 

180 def configDatasetName(self) -> str: 

181 """Name of a dataset type for configuration of this task (`str`)""" 

182 return acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=self.label) 

183 

184 @property 

185 def metadataDatasetName(self) -> str: 

186 """Name of a dataset type for metadata of this task (`str`)""" 

187 return self.makeMetadataDatasetName(self.label) 

188 

189 @classmethod 

190 def makeMetadataDatasetName(cls, label: str) -> str: 

191 """Construct the name of the dataset type for metadata for a task. 

192 

193 Parameters 

194 ---------- 

195 label : `str` 

196 Label for the task within its pipeline. 

197 

198 Returns 

199 ------- 

200 name : `str` 

201 Name of the task's metadata dataset type. 

202 """ 

203 return acc.METADATA_OUTPUT_TEMPLATE.format(label=label) 

204 

205 @property 

206 def logOutputDatasetName(self) -> str | None: 

207 """Name of a dataset type for log output from this task, `None` if 

208 logs are not to be saved (`str` or `None`).

209 """ 

210 if self.config.saveLogOutput: 

211 return acc.LOG_OUTPUT_TEMPLATE.format(label=self.label) 

212 else: 

213 return None 

214 

215 def __str__(self) -> str: 

216 rep = "TaskDef(" + self.taskName 

217 if self.label: 

218 rep += ", label=" + self.label 

219 rep += ")" 

220 return rep 

221 

222 def __eq__(self, other: object) -> bool: 

223 if not isinstance(other, TaskDef): 

224 return False 

225 # This does not consider equality of configs when determining equality 

226 # as config equality is a difficult thing to define. Should be updated 

227 # after DM-27847 

228 return self.taskClass == other.taskClass and self.label == other.label 

229 

230 def __hash__(self) -> int: 

231 return hash((self.taskClass, self.label)) 

232 

233 @classmethod 

234 def _unreduce(cls, taskName: str, config: PipelineTaskConfig, label: str) -> TaskDef: 

235 """Unpickle pickle. Custom callable for unpickling. 

236 

237 All arguments are forwarded directly to the constructor; this 

238 trampoline is only needed because ``__reduce__`` callables can't be 

239 called with keyword arguments. 

240 """ 

241 return cls(taskName=taskName, config=config, label=label) 

242 

243 def __reduce__(self) -> tuple[Callable[[str, PipelineTaskConfig, str], TaskDef], tuple[str, Config, str]]: 

244 return (self._unreduce, (self.taskName, self.config, self.label)) 
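
# Example construction of a `TaskDef` (``MyTask`` is a hypothetical
# `PipelineTask` subclass; any one of ``taskName``/``taskClass`` suffices):
#
#     task_def = TaskDef(taskClass=MyTask, label="myTask")
#     task_def.configDatasetName       # dataset type name for this task's config
#     task_def.metadataDatasetName     # dataset type name for this task's metadata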

245 

246 

247class Pipeline: 

248 """A `Pipeline` is a representation of a series of tasks to run, and the 

249 configuration for those tasks. 

250 

251 Parameters 

252 ---------- 

253 description : `str` 

254 A description of what this pipeline does.

255 """ 

256 

257 def __init__(self, description: str): 

258 pipeline_dict = {"description": description, "tasks": {}} 

259 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict) 

260 

261 @classmethod 

262 def fromFile(cls, filename: str) -> Pipeline: 

263 """Load a pipeline defined in a pipeline yaml file. 

264 

265 Parameters 

266 ---------- 

267 filename: `str` 

268 A path that points to a pipeline defined in yaml format. This 

269 filename may also supply additional labels to be used in 

270 subsetting the loaded Pipeline. These labels are separated from 

271 the path by a ``#``, and may be specified as a comma separated 

272 list, or a range denoted as beginning..end. Beginning or end may 

273 be empty, in which case the range will be a half open interval. 

274 Unlike python iteration bounds, end bounds are *INCLUDED*. Note 

275 that range based selection is not well defined for pipelines that 

276 are not linear in nature, and correct behavior is not guaranteed, 

277 or may vary from run to run. 

278 

279 Returns 

280 ------- 

281 pipeline: `Pipeline` 

282 The pipeline loaded from specified location with appropriate (if 

283 any) subsetting. 

284 

285 Notes 

286 ----- 

287 This method attempts to prune any contracts that contain labels which 

288 are not in the declared subset of labels. This pruning is done using a 

289 string based matching due to the nature of contracts and may prune more 

290 than it should. 

291 """ 

292 return cls.from_uri(filename) 

293 

294 @classmethod 

295 def from_uri(cls, uri: ResourcePathExpression) -> Pipeline: 

296 """Load a pipeline defined in a pipeline yaml file at a location 

297 specified by a URI. 

298 

299 Parameters 

300 ---------- 

301 uri : convertible to `~lsst.resources.ResourcePath` 

302 If a string is supplied this should be a URI path that points to a 

303 pipeline defined in yaml format, either as a direct path to the 

304 yaml file, or as a directory containing a ``pipeline.yaml`` file 

305 (the form used by `write_to_uri` with ``expand=True``). This URI may 

306 also supply additional labels to be used in subsetting the loaded 

307 `Pipeline`. These labels are separated from the path by a ``#``, 

308 and may be specified as a comma separated list, or a range denoted 

309 as beginning..end. Beginning or end may be empty, in which case the 

310 range will be a half open interval. Unlike python iteration bounds, 

311 end bounds are *INCLUDED*. Note that range based selection is not 

312 well defined for pipelines that are not linear in nature, and 

313 correct behavior is not guaranteed, or may vary from run to run. 

314 The same specifiers can be used with a 

315 `~lsst.resources.ResourcePath` object, by being the sole contents 

316 of the fragment attribute. 

317 

318 Returns 

319 ------- 

320 pipeline : `Pipeline` 

321 The pipeline loaded from specified location with appropriate (if 

322 any) subsetting. 

323 

324 Notes 

325 ----- 

326 This method attempts to prune any contracts that contain labels which 

327 are not in the declared subset of labels. This pruning is done using a 

328 string based matching due to the nature of contracts and may prune more 

329 than it should. 

330 """ 

331 # Split up the uri and any labels that were supplied 

332 uri, label_specifier = cls._parse_file_specifier(uri) 

333 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri)) 

334 

335 # If there are labels supplied, only keep those 

336 if label_specifier is not None: 

337 pipeline = pipeline.subsetFromLabels(label_specifier) 

338 return pipeline 
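
# Illustrative loads (the file path and task labels are hypothetical):
#
#     Pipeline.fromFile("pipelines/DRP.yaml")                   # full pipeline
#     Pipeline.from_uri("pipelines/DRP.yaml#isr,calibrate")     # subset by label list
#     Pipeline.from_uri("pipelines/DRP.yaml#isr..calibrate")    # subset by inclusive range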

339 

340 def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline: 

341 """Subset a pipeline to contain only labels specified in labelSpecifier 

342 

343 Parameters 

344 ---------- 

345 labelSpecifier : `LabelSpecifier` 

346 Object containing labels that describe how to subset a pipeline. 

347 

348 Returns 

349 ------- 

350 pipeline : `Pipeline` 

351 A new pipeline object that is a subset of the old pipeline. 

352 

353 Raises 

354 ------ 

355 ValueError 

356 Raised if there is an issue with the specified labels. 

357 

358 Notes 

359 ----- 

360 This method attempts to prune any contracts that contain labels which 

361 are not in the declared subset of labels. This pruning is done using a 

362 string based matching due to the nature of contracts and may prune more 

363 than it should. 

364 """ 

365 # Labels supplied as a set 

366 if labelSpecifier.labels: 

367 labelSet = labelSpecifier.labels 

368 # Labels supplied as a range: first create a list of all the labels 

369 # in the pipeline sorted according to task dependency. Then only 

370 # keep labels that lie between the supplied bounds. 

371 else: 

372 # Create a copy of the pipeline to use when assessing the label 

373 # ordering. Use a dict for fast searching while preserving order. 

374 # Remove contracts so they do not fail in the expansion step. This 

375 # is needed because a user may only configure the tasks they intend 

376 # to run, which may cause some contracts to fail if they will later 

377 # be dropped 

378 pipeline = copy.deepcopy(self) 

379 pipeline._pipelineIR.contracts = [] 

380 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()} 

381 

382 # Verify the bounds are in the labels 

383 if labelSpecifier.begin is not None: 

384 if labelSpecifier.begin not in labels: 

385 raise ValueError( 

386 f"Beginning of range subset, {labelSpecifier.begin}, not found in pipeline definition" 

387 ) 

388 if labelSpecifier.end is not None: 

389 if labelSpecifier.end not in labels: 

390 raise ValueError( 

391 f"End of range subset, {labelSpecifier.end}, not found in pipeline definition" 

392 ) 

393 

394 labelSet = set() 

395 for label in labels: 

396 if labelSpecifier.begin is not None: 

397 if label != labelSpecifier.begin: 

398 continue 

399 else: 

400 labelSpecifier.begin = None 

401 labelSet.add(label) 

402 if labelSpecifier.end is not None and label == labelSpecifier.end: 

403 break 

404 return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet)) 

405 

406 @staticmethod 

407 def _parse_file_specifier(uri: ResourcePathExpression) -> tuple[ResourcePath, LabelSpecifier | None]: 

408 """Split appart a uri and any possible label subsets""" 

409 if isinstance(uri, str): 

410 # This is to support legacy pipelines during transition 

411 uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri) 

412 if num_replace: 

413 raise ValueError( 

414 f"The pipeline file {uri} seems to use the legacy :" 

415 " to separate labels, please use # instead." 

416 ) 

417 if uri.count("#") > 1: 

418 raise ValueError("Only one set of labels is allowed when specifying a pipeline to load") 

419 # Everything else can be converted directly to ResourcePath. 

420 uri = ResourcePath(uri) 

421 label_subset = uri.fragment or None 

422 

423 specifier: LabelSpecifier | None 

424 if label_subset is not None: 

425 label_subset = urllib.parse.unquote(label_subset) 

426 args: dict[str, set[str] | str | None] 

427 # labels supplied as a list 

428 if "," in label_subset: 

429 if ".." in label_subset: 

430 raise ValueError( 

431 "Can only specify a list of labels or a rangewhen loading a Pipline not both" 

432 ) 

433 args = {"labels": set(label_subset.split(","))} 

434 # labels supplied as a range 

435 elif ".." in label_subset: 

436 # Try to de-structure the labelSubset, this will fail if more 

437 # than one range is specified 

438 begin, end, *rest = label_subset.split("..") 

439 if rest: 

440 raise ValueError("Only one range can be specified when loading a pipeline") 

441 args = {"begin": begin if begin else None, "end": end if end else None} 

442 # Assume anything else is a single label 

443 else: 

444 args = {"labels": {label_subset}} 

445 

446 # MyPy doesn't like how cavalier kwarg construction is with types. 

447 specifier = LabelSpecifier(**args) # type: ignore 

448 else: 

449 specifier = None 

450 

451 return uri, specifier 
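
# Sketch of how the fragment syntax is interpreted (the labels and path are
# hypothetical):
#
#     uri, spec = Pipeline._parse_file_specifier("DRP.yaml#isr,calibrate")
#     spec.labels                  # {"isr", "calibrate"}
#     uri, spec = Pipeline._parse_file_specifier("DRP.yaml#isr..calibrate")
#     spec.begin, spec.end         # ("isr", "calibrate")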

452 

453 @classmethod 

454 def fromString(cls, pipeline_string: str) -> Pipeline: 

455 """Create a pipeline from string formatted as a pipeline document. 

456 

457 Parameters 

458 ---------- 

459 pipeline_string : `str` 

460 A string formatted like a pipeline document. 

461 

462 Returns 

463 ------- 

464 pipeline: `Pipeline` 

465 """ 

466 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string)) 

467 return pipeline 

468 

469 @classmethod 

470 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline: 

471 """Create a pipeline from an already created `PipelineIR` object. 

472 

473 Parameters 

474 ---------- 

475 deserialized_pipeline: `PipelineIR` 

476 An already created pipeline intermediate representation object 

477 

478 Returns 

479 ------- 

480 pipeline: `Pipeline` 

481 """ 

482 pipeline = cls.__new__(cls) 

483 pipeline._pipelineIR = deserialized_pipeline 

484 return pipeline 

485 

486 @classmethod 

487 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline: 

488 """Create a new pipeline by copying an already existing `Pipeline`. 

489 

490 Parameters 

491 ---------- 

492 pipeline : `Pipeline` 

493 An already created `Pipeline` object to copy. 

494 

495 Returns 

496 ------- 

497 pipeline: `Pipeline` 

498 """ 

499 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR)) 

500 

501 def __str__(self) -> str: 

502 return str(self._pipelineIR) 

503 

504 def mergePipeline(self, pipeline: Pipeline) -> None: 

505 """Merge another in-memory `Pipeline` object into this one. 

506 

507 This merges another pipeline into this object, as if it were declared 

508 in the import block of the yaml definition of this pipeline. This 

509 modifies this pipeline in place. 

510 

511 Parameters 

512 ---------- 

513 pipeline : `Pipeline` 

514 The `Pipeline` object that is to be merged into this object. 

515 """ 

516 self._pipelineIR.merge_pipelines((pipeline._pipelineIR,)) 

517 

518 def addLabelToSubset(self, subset: str, label: str) -> None: 

519 """Add a task label from the specified subset. 

520 

521 Parameters 

522 ---------- 

523 subset : `str` 

524 The labeled subset to modify 

525 label : `str` 

526 The task label to add to the specified subset. 

527 

528 Raises 

529 ------ 

530 ValueError 

531 Raised if the specified subset does not exist within the pipeline. 

532 Raised if the specified label does not exist within the pipeline. 

533 """ 

534 if label not in self._pipelineIR.tasks: 

535 raise ValueError(f"Label {label} does not appear within the pipeline") 

536 if subset not in self._pipelineIR.labeled_subsets: 

537 raise ValueError(f"Subset {subset} does not appear within the pipeline") 

538 self._pipelineIR.labeled_subsets[subset].subset.add(label) 

539 

540 def removeLabelFromSubset(self, subset: str, label: str) -> None: 

541 """Remove a task label from the specified subset. 

542 

543 Parameters 

544 ---------- 

545 subset : `str` 

546 The labeled subset to modify 

547 label : `str` 

548 The task label to remove from the specified subset. 

549 

550 Raises 

551 ------ 

552 ValueError 

553 Raised if the specified subset does not exist in the pipeline. 

554 Raised if the specified label does not exist within the specified 

555 subset. 

556 """ 

557 if subset not in self._pipelineIR.labeled_subsets: 

558 raise ValueError(f"Subset {subset} does not appear within the pipeline") 

559 if label not in self._pipelineIR.labeled_subsets[subset].subset: 

560 raise ValueError(f"Label {label} does not appear within the pipeline") 

561 self._pipelineIR.labeled_subsets[subset].subset.remove(label) 

562 

563 def findSubsetsWithLabel(self, label: str) -> set[str]: 

564 """Find any subsets which may contain the specified label. 

565 

566 This function returns the names of subsets which contain the specified 

567 label. It may return an empty set if there are no subsets, or no 

568 subsets contain the specified label. 

569 

570 Parameters 

571 ---------- 

572 label : `str` 

573 The task label to use in the membership check. 

574 

575 Returns 

576 ------- 

577 subsets : `set` of `str` 

578 Returns a set (possibly empty) of subset names which contain the 

579 specified label. 

580 

581 Raises 

582 ------ 

583 ValueError 

584 Raised if the specified label does not exist within this pipeline. 

585 """ 

586 results = set() 

587 if label not in self._pipelineIR.tasks: 

588 raise ValueError(f"Label {label} does not appear within the pipeline") 

589 for subset in self._pipelineIR.labeled_subsets.values(): 

590 if label in subset.subset: 

591 results.add(subset.label) 

592 return results 

593 

594 def addInstrument(self, instrument: Instrument | str) -> None: 

595 """Add an instrument to the pipeline, or replace an instrument that is 

596 already defined. 

597 

598 Parameters 

599 ---------- 

600 instrument : `~lsst.obs.base.Instrument` or `str` 

601 Either an instance of an `lsst.obs.base.Instrument` subclass or 

602 a string corresponding to a fully qualified 

603 `lsst.obs.base.Instrument` subclass name. 

604 """ 

605 if isinstance(instrument, str): 

606 pass 

607 else: 

608 # TODO: assume that this is a subclass of Instrument, no type 

609 # checking 

610 instrument = get_full_type_name(instrument) 

611 self._pipelineIR.instrument = instrument 

612 

613 def getInstrument(self) -> str | None: 

614 """Get the instrument from the pipeline. 

615 

616 Returns 

617 ------- 

618 instrument : `str` or `None` 

619 The fully qualified name of a `lsst.obs.base.Instrument` subclass, 

620 or `None` if the pipeline does not have an instrument. 

621 """ 

622 return self._pipelineIR.instrument 

623 

624 def get_data_id(self, universe: DimensionUniverse) -> DataCoordinate: 

625 """Return a data ID with all dimension constraints embedded in the 

626 pipeline. 

627 

628 Parameters 

629 ---------- 

630 universe : `lsst.daf.butler.DimensionUniverse` 

631 Object that defines all dimensions. 

632 

633 Returns 

634 ------- 

635 data_id : `lsst.daf.butler.DataCoordinate` 

636 Data ID with all dimension constraints embedded in the 

637 pipeline. 

638 """ 

639 instrument_class_name = self._pipelineIR.instrument 

640 if instrument_class_name is not None: 

641 instrument_class = cast(PipeBaseInstrument, doImportType(instrument_class_name)) 

642 if instrument_class is not None: 

643 return DataCoordinate.standardize(instrument=instrument_class.getName(), universe=universe) 

644 return DataCoordinate.makeEmpty(universe) 

645 

646 def addTask(self, task: type[PipelineTask] | str, label: str) -> None: 

647 """Add a new task to the pipeline, or replace a task that is already 

648 associated with the supplied label. 

649 

650 Parameters 

651 ---------- 

652 task: `PipelineTask` or `str` 

653 Either a derived class object of a `PipelineTask` or a string 

654 corresponding to a fully qualified `PipelineTask` name. 

655 label: `str` 

656 A label that is used to identify the `PipelineTask` being added 

657 """ 

658 if isinstance(task, str): 

659 taskName = task 

660 elif issubclass(task, PipelineTask): 

661 taskName = get_full_type_name(task) 

662 else: 

663 raise ValueError( 

664 "task must be either a child class of PipelineTask or a string containing" 

665 " a fully qualified name to one" 

666 ) 

667 if not label: 

668 # In some cases (with a command line-generated pipeline) tasks can 

669 # be defined without a label, which is not acceptable; use the 

670 # task's _DefaultName in that case. 

671 if isinstance(task, str): 

672 task_class = cast(PipelineTask, doImportType(task)) 

673 label = task_class._DefaultName 

674 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName) 
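
# A small in-memory construction sketch (the instrument and task names are
# hypothetical placeholders, not real classes):
#
#     pipeline = Pipeline("Demonstration pipeline")
#     pipeline.addInstrument("lsst.obs.example.ExampleInstrument")
#     pipeline.addTask("example.pipelines.ExampleTask", "exampleTask")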

675 

676 def removeTask(self, label: str) -> None: 

677 """Remove a task from the pipeline. 

678 

679 Parameters 

680 ---------- 

681 label : `str` 

682 The label used to identify the task that is to be removed 

683 

684 Raises 

685 ------ 

686 KeyError 

687 Raised if no task with that label exists in the pipeline. 

688 

689 """ 

690 self._pipelineIR.tasks.pop(label) 

691 

692 def addConfigOverride(self, label: str, key: str, value: object) -> None: 

693 """Apply single config override. 

694 

695 Parameters 

696 ---------- 

697 label : `str` 

698 Label of the task. 

699 key: `str` 

700 Fully-qualified field name. 

701 value : object 

702 Value to be given to a field. 

703 """ 

704 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value})) 

705 

706 def addConfigFile(self, label: str, filename: str) -> None: 

707 """Add overrides from a specified file. 

708 

709 Parameters 

710 ---------- 

711 label : `str` 

712 The label used to identify the task associated with the config to 

713 modify. 

714 filename : `str` 

715 Path to the override file. 

716 """ 

717 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename])) 

718 

719 def addConfigPython(self, label: str, pythonString: str) -> None: 

720 """Add Overrides by running a snippet of python code against a config. 

721 

722 Parameters 

723 ---------- 

724 label : `str` 

725 The label used to identify the task associated with the config to 

726 modify. 

727 pythonString: `str` 

728 A string which is valid python code to be executed. This is done 

729 with ``config`` as the only locally accessible value. 

730 """ 

731 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString)) 
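
# The three override flavours side by side (the label and field names below
# are hypothetical):
#
#     pipeline.addConfigOverride("exampleTask", "threshold", 5.0)
#     pipeline.addConfigFile("exampleTask", "overrides/exampleTask.py")
#     pipeline.addConfigPython("exampleTask", "config.threshold *= 2")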

732 

733 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None: 

734 if label == "parameters": 

735 self._pipelineIR.parameters.mapping.update(newConfig.rest) 

736 if newConfig.file: 

737 raise ValueError("Setting parameters section with config file is not supported") 

738 if newConfig.python: 

739 raise ValueError("Setting parameters section using python block is unsupported") 

740 return 

741 if label not in self._pipelineIR.tasks: 

742 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline") 

743 self._pipelineIR.tasks[label].add_or_update_config(newConfig) 

744 

745 def write_to_uri(self, uri: ResourcePathExpression) -> None: 

746 """Write the pipeline to a file or directory. 

747 

748 Parameters 

749 ---------- 

750 uri : convertible to `~lsst.resources.ResourcePath` 

751 URI to write to; may have any scheme with 

752 `~lsst.resources.ResourcePath` write support or no scheme for a 

753 local file/directory. Should have a ``.yaml`` extension. 

754 """ 

755 self._pipelineIR.write_to_uri(uri) 

756 

757 def to_graph(self, registry: Registry | None = None) -> pipeline_graph.PipelineGraph: 

758 """Construct a pipeline graph from this pipeline. 

759 

760 Constructing a graph applies all configuration overrides, freezes all 

761 configuration, checks all contracts, and checks for dataset type 

762 consistency between tasks (as much as possible without access to a data 

763 repository). It cannot be reversed. 

764 

765 Parameters 

766 ---------- 

767 registry : `lsst.daf.butler.Registry`, optional 

768 Data repository client. If provided, the graph's dataset types 

769 and dimensions will be resolved (see `PipelineGraph.resolve`). 

770 

771 Returns 

772 ------- 

773 graph : `pipeline_graph.PipelineGraph` 

774 Representation of the pipeline as a graph. 

775 """ 

776 instrument_class_name = self._pipelineIR.instrument 

777 data_id = {} 

778 if instrument_class_name is not None: 

779 instrument_class: type[Instrument] = doImportType(instrument_class_name) 

780 if instrument_class is not None: 

781 data_id["instrument"] = instrument_class.getName() 

782 graph = pipeline_graph.PipelineGraph(data_id=data_id) 

783 graph.description = self._pipelineIR.description 

784 for label in self._pipelineIR.tasks: 

785 self._add_task_to_graph(label, graph) 

786 if self._pipelineIR.contracts is not None: 

787 label_to_config = {x.label: x.config for x in graph.tasks.values()} 

788 for contract in self._pipelineIR.contracts: 

789 # Execute this on its own line so it can raise a good error 

790 # message if there were problems with the eval. 

791 success = eval(contract.contract, None, label_to_config) 

792 if not success: 

793 extra_info = f": {contract.msg}" if contract.msg is not None else "" 

794 raise pipelineIR.ContractError( 

795 f"Contract(s) '{contract.contract}' were not satisfied{extra_info}" 

796 ) 

797 for label, subset in self._pipelineIR.labeled_subsets.items(): 

798 graph.add_task_subset( 

799 label, subset.subset, subset.description if subset.description is not None else "" 

800 ) 

801 graph.sort() 

802 if registry is not None: 

803 graph.resolve(registry) 

804 return graph 
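
# Typical use, assuming ``pipeline`` is a `Pipeline` and ``butler`` is an
# existing `lsst.daf.butler.Butler` (both names are hypothetical here):
#
#     graph = pipeline.to_graph()                           # unresolved graph
#     graph = pipeline.to_graph(registry=butler.registry)   # resolved dataset types and dimensions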

805 

806 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]: 

807 r"""Return a generator of `TaskDef`\s which can be used to create 

808 quantum graphs. 

809 

810 Returns 

811 ------- 

812 generator : generator of `TaskDef` 

813 The generator returned will be the sorted iterator of tasks which 

814 are to be used in constructing a quantum graph. 

815 

816 Raises 

817 ------ 

818 NotImplementedError 

819 Raised if a dataId is supplied in a config block. This is in place 

820 for future use. 

821 """ 

822 yield from self.to_graph()._iter_task_defs() 

823 

824 def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> None: 

825 """Add a single task from this pipeline to a pipeline graph that is 

826 under construction. 

827 

828 Parameters 

829 ---------- 

830 label : `str` 

831 Label for the task to be added. 

832 graph : `pipeline_graph.PipelineGraph` 

833 Graph to add the task to. 

834 """ 

835 if (taskIR := self._pipelineIR.tasks.get(label)) is None: 

836 raise NameError(f"Label {label} does not appear in this pipeline") 

837 taskClass: type[PipelineTask] = doImportType(taskIR.klass) 

838 config = taskClass.ConfigClass() 

839 instrument: PipeBaseInstrument | None = None 

840 if (instrumentName := self._pipelineIR.instrument) is not None: 

841 instrument_cls: type = doImportType(instrumentName) 

842 instrument = instrument_cls() 

843 config.applyConfigOverrides( 

844 instrument, 

845 getattr(taskClass, "_DefaultName", ""), 

846 taskIR.config, 

847 self._pipelineIR.parameters, 

848 label, 

849 ) 

850 graph.add_task(label, taskClass, config) 

851 

852 def __iter__(self) -> Generator[TaskDef, None, None]: 

853 return self.toExpandedPipeline() 

854 

855 def __getitem__(self, item: str) -> TaskDef: 

856 # Making a whole graph and then making a TaskDef from that is pretty 

857 # backwards, but I'm hoping to deprecate this method shortly in favor 

858 # of making the graph explicitly and working with its node objects. 

859 graph = pipeline_graph.PipelineGraph() 

860 self._add_task_to_graph(item, graph) 

861 (result,) = graph._iter_task_defs() 

862 return result 

863 

864 def __len__(self) -> int: 

865 return len(self._pipelineIR.tasks) 

866 

867 def __eq__(self, other: object) -> bool: 

868 if not isinstance(other, Pipeline): 

869 return False 

870 elif self._pipelineIR == other._pipelineIR: 

871 # Shortcut: if the IR is the same, the expanded pipeline must be 

872 # the same as well. But the converse is not true. 

873 return True 

874 else: 

875 self_expanded = {td.label: (td.taskClass,) for td in self} 

876 other_expanded = {td.label: (td.taskClass,) for td in other} 

877 if self_expanded != other_expanded: 

878 return False 

879 # After DM-27847, we should compare configuration here, or better, 

880 # delegate to TaskDef.__eq__ after making that compare configurations. 

881 raise NotImplementedError( 

882 "Pipelines cannot be compared because config instances cannot be compared; see DM-27847." 

883 ) 

884 

885 

886@dataclass(frozen=True) 

887class TaskDatasetTypes: 

888 """An immutable struct that extracts and classifies the dataset types used 

889 by a `PipelineTask`. 

890 """ 

891 

892 initInputs: NamedValueSet[DatasetType] 

893 """Dataset types that are needed as inputs in order to construct this Task. 

894 

895 Task-level `initInputs` may be classified as either 

896 `~PipelineDatasetTypes.initInputs` or 

897 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level. 

898 """ 

899 

900 initOutputs: NamedValueSet[DatasetType] 

901 """Dataset types that may be written after constructing this Task. 

902 

903 Task-level `initOutputs` may be classified as either 

904 `~PipelineDatasetTypes.initOutputs` or 

905 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level. 

906 """ 

907 

908 inputs: NamedValueSet[DatasetType] 

909 """Dataset types that are regular inputs to this Task. 

910 

911 If an input dataset needed for a Quantum cannot be found in the input 

912 collection(s) or produced by another Task in the Pipeline, that Quantum 

913 (and all dependent Quanta) will not be produced. 

914 

915 Task-level `inputs` may be classified as either 

916 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates` 

917 at the Pipeline level. 

918 """ 

919 

920 queryConstraints: NamedValueSet[DatasetType] 

921 """Regular inputs that should not be used as constraints on the initial 

922 QuantumGraph generation data ID query, according to their tasks 

923 (`NamedValueSet`). 

924 """ 

925 

926 prerequisites: NamedValueSet[DatasetType] 

927 """Dataset types that are prerequisite inputs to this Task. 

928 

929 Prerequisite inputs must exist in the input collection(s) before the 

930 pipeline is run, but do not constrain the graph - if a prerequisite is 

931 missing for a Quantum, `PrerequisiteMissingError` is raised. 

932 

933 Prerequisite inputs are not resolved until the second stage of 

934 QuantumGraph generation. 

935 """ 

936 

937 outputs: NamedValueSet[DatasetType] 

938 """Dataset types that are produced by this Task. 

939 

940 Task-level `outputs` may be classified as either 

941 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates` 

942 at the Pipeline level. 

943 """ 

944 

945 @classmethod 

946 def fromTaskDef( 

947 cls, 

948 taskDef: TaskDef, 

949 *, 

950 registry: Registry, 

951 include_configs: bool = True, 

952 storage_class_mapping: Mapping[str, str] | None = None, 

953 ) -> TaskDatasetTypes: 

954 """Extract and classify the dataset types from a single `PipelineTask`. 

955 

956 Parameters 

957 ---------- 

958 taskDef: `TaskDef` 

959 An instance of a `TaskDef` class for a particular `PipelineTask`. 

960 registry: `Registry` 

961 Registry used to construct normalized 

962 `~lsst.daf.butler.DatasetType` objects and retrieve those that are 

963 incomplete. 

964 include_configs : `bool`, optional 

965 If `True` (default) include config dataset types as 

966 ``initOutputs``. 

967 storage_class_mapping : `~collections.abc.Mapping` of `str` to \ 

968 `~lsst.daf.butler.StorageClass`, optional 

969 If a taskdef contains a component dataset type that is unknown 

970 to the registry, its parent `~lsst.daf.butler.StorageClass` will 

971 be looked up in this mapping if it is supplied. If the mapping does 

972 not contain the composite dataset type, or the mapping is not 

973 supplied, an exception will be raised. 

974 

975 Returns 

976 ------- 

977 types: `TaskDatasetTypes` 

978 The dataset types used by this task. 

979 

980 Raises 

981 ------ 

982 ValueError 

983 Raised if dataset type connection definition differs from 

984 registry definition. 

985 LookupError 

986 Raised if component parent StorageClass could not be determined 

987 and storage_class_mapping does not contain the composite type, or 

988 is set to None. 

989 """ 

990 

991 def makeDatasetTypesSet( 

992 connectionType: str, 

993 is_input: bool, 

994 freeze: bool = True, 

995 ) -> NamedValueSet[DatasetType]: 

996 """Construct a set of true `~lsst.daf.butler.DatasetType` objects. 

997 

998 Parameters 

999 ---------- 

1000 connectionType : `str` 

1001 Name of the connection type to produce a set for, corresponds 

1002 to an attribute of type `list` on the connection class instance 

1003 is_input : `bool` 

1004 If `True`, these are input dataset types; otherwise they are 

1005 output dataset types. 

1006 freeze : `bool`, optional 

1007 If `True`, call `NamedValueSet.freeze` on the object returned. 

1008 

1009 Returns 

1010 ------- 

1011 datasetTypes : `NamedValueSet` 

1012 A set of all datasetTypes which correspond to the input 

1013 connection type specified in the connection class of this 

1014 `PipelineTask` 

1015 

1016 Raises 

1017 ------ 

1018 ValueError 

1019 Raised if dataset type connection definition differs from 

1020 registry definition. 

1021 LookupError 

1022 Raised if component parent StorageClass could not be determined 

1023 and storage_class_mapping does not contain the composite type, 

1024 or is set to None. 

1025 

1026 Notes 

1027 ----- 

1028 This function is a closure over the variables ``registry``, 

1029 ``taskDef``, and ``storage_class_mapping``. 

1030 """ 

1031 datasetTypes = NamedValueSet[DatasetType]() 

1032 for c in iterConnections(taskDef.connections, connectionType): 

1033 dimensions = set(getattr(c, "dimensions", set())) 

1034 if "skypix" in dimensions: 

1035 try: 

1036 datasetType = registry.getDatasetType(c.name) 

1037 except LookupError as err: 

1038 raise LookupError( 

1039 f"DatasetType '{c.name}' referenced by " 

1040 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension " 

1041 "placeholder, but does not already exist in the registry. " 

1042 "Note that reference catalog names are now used as the dataset " 

1043 "type name instead of 'ref_cat'." 

1044 ) from err 

1045 rest1 = set(registry.dimensions.extract(dimensions - {"skypix"}).names) 

1046 rest2 = { 

1047 dim.name for dim in datasetType.dimensions if not isinstance(dim, SkyPixDimension) 

1048 } 

1049 if rest1 != rest2: 

1050 raise ValueError( 

1051 f"Non-skypix dimensions for dataset type {c.name} declared in " 

1052 f"connections ({rest1}) are inconsistent with those in " 

1053 f"registry's version of this dataset ({rest2})." 

1054 ) 

1055 else: 

1056 # Component dataset types are not explicitly in the 

1057 # registry. This complicates consistency checks with 

1058 # registry and requires we work out the composite storage 

1059 # class. 

1060 registryDatasetType = None 

1061 try: 

1062 registryDatasetType = registry.getDatasetType(c.name) 

1063 except KeyError: 

1064 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name) 

1065 if componentName: 

1066 if storage_class_mapping is None or compositeName not in storage_class_mapping: 

1067 raise LookupError( 

1068 "Component parent class cannot be determined, and " 

1069 "composite name was not in storage class mapping, or no " 

1070 "storage_class_mapping was supplied" 

1071 ) from None 

1072 else: 

1073 parentStorageClass = storage_class_mapping[compositeName] 

1074 else: 

1075 parentStorageClass = None 

1076 datasetType = c.makeDatasetType( 

1077 registry.dimensions, parentStorageClass=parentStorageClass 

1078 ) 

1079 registryDatasetType = datasetType 

1080 else: 

1081 datasetType = c.makeDatasetType( 

1082 registry.dimensions, parentStorageClass=registryDatasetType.parentStorageClass 

1083 ) 

1084 

1085 if registryDatasetType and datasetType != registryDatasetType: 

1086 # The dataset types differ but first check to see if 

1087 # they are compatible before raising. 

1088 if is_input: 

1089 # This DatasetType must be compatible on get. 

1090 is_compatible = datasetType.is_compatible_with(registryDatasetType) 

1091 else: 

1092 # Has to be able to be converted to the expected type 

1093 # on put. 

1094 is_compatible = registryDatasetType.is_compatible_with(datasetType) 

1095 if is_compatible: 

1096 # For inputs we want the pipeline to use the 

1097 # pipeline definition, for outputs it should use 

1098 # the registry definition. 

1099 if not is_input: 

1100 datasetType = registryDatasetType 

1101 _LOG.debug( 

1102 "Dataset types differ (task %s != registry %s) but are compatible" 

1103 " for %s in %s.", 

1104 datasetType, 

1105 registryDatasetType, 

1106 "input" if is_input else "output", 

1107 taskDef.label, 

1108 ) 

1109 else: 

1110 try: 

1111 # Explicitly check for storage class just to 

1112 # make a more specific message. 

1113 _ = datasetType.storageClass 

1114 except KeyError: 

1115 raise ValueError( 

1116 "Storage class does not exist for supplied dataset type " 

1117 f"{datasetType} for {taskDef.label}." 

1118 ) from None 

1119 raise ValueError( 

1120 f"Supplied dataset type ({datasetType}) inconsistent with " 

1121 f"registry definition ({registryDatasetType}) " 

1122 f"for {taskDef.label}." 

1123 ) 

1124 datasetTypes.add(datasetType) 

1125 if freeze: 

1126 datasetTypes.freeze() 

1127 return datasetTypes 

1128 

1129 # optionally add initOutput dataset for config 

1130 initOutputs = makeDatasetTypesSet("initOutputs", is_input=False, freeze=False) 

1131 if include_configs: 

1132 initOutputs.add( 

1133 DatasetType( 

1134 taskDef.configDatasetName, 

1135 registry.dimensions.empty, 

1136 storageClass=acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, 

1137 ) 

1138 ) 

1139 initOutputs.freeze() 

1140 

1141 # optionally add output dataset for metadata 

1142 outputs = makeDatasetTypesSet("outputs", is_input=False, freeze=False) 

1143 

1144 # Metadata is supposed to be of the TaskMetadata type; its dimensions 

1145 # correspond to a task quantum. 

1146 dimensions = registry.dimensions.extract(taskDef.connections.dimensions) 

1147 

1148 # Allow the storage class definition to be read from the existing 

1149 # dataset type definition if present. 

1150 try: 

1151 current = registry.getDatasetType(taskDef.metadataDatasetName) 

1152 except KeyError: 

1153 # No previous definition so use the default. 

1154 storageClass = acc.METADATA_OUTPUT_STORAGE_CLASS 

1155 else: 

1156 storageClass = current.storageClass.name 

1157 outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)}) 

1158 

1159 if taskDef.logOutputDatasetName is not None: 

1160 # Log output dimensions correspond to a task quantum. 

1161 dimensions = registry.dimensions.extract(taskDef.connections.dimensions) 

1162 outputs.update( 

1163 { 

1164 DatasetType( 

1165 taskDef.logOutputDatasetName, 

1166 dimensions, 

1167 acc.LOG_OUTPUT_STORAGE_CLASS, 

1168 ) 

1169 } 

1170 ) 

1171 

1172 outputs.freeze() 

1173 

1174 inputs = makeDatasetTypesSet("inputs", is_input=True) 

1175 queryConstraints = NamedValueSet( 

1176 inputs[c.name] 

1177 for c in cast(Iterable[Input], iterConnections(taskDef.connections, "inputs")) 

1178 if not c.deferGraphConstraint 

1179 ) 

1180 

1181 return cls( 

1182 initInputs=makeDatasetTypesSet("initInputs", is_input=True), 

1183 initOutputs=initOutputs, 

1184 inputs=inputs, 

1185 queryConstraints=queryConstraints, 

1186 prerequisites=makeDatasetTypesSet("prerequisiteInputs", is_input=True), 

1187 outputs=outputs, 

1188 ) 
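
# A sketch of extracting one task's dataset types, assuming ``task_def`` and
# ``butler`` already exist (hypothetical names):
#
#     types = TaskDatasetTypes.fromTaskDef(task_def, registry=butler.registry)
#     types.inputs, types.outputs, types.prerequisites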

1189 

1190 

1191@dataclass(frozen=True) 

1192class PipelineDatasetTypes: 

1193 """An immutable struct that classifies the dataset types used in a 

1194 `Pipeline`. 

1195 """ 

1196 

1197 packagesDatasetName: ClassVar[str] = acc.PACKAGES_INIT_OUTPUT_NAME 

1198 """Name of a dataset type used to save package versions. 

1199 """ 

1200 

1201 initInputs: NamedValueSet[DatasetType] 

1202 """Dataset types that are needed as inputs in order to construct the Tasks 

1203 in this Pipeline. 

1204 

1205 This does not include dataset types that are produced when constructing 

1206 other Tasks in the Pipeline (these are classified as `initIntermediates`). 

1207 """ 

1208 

1209 initOutputs: NamedValueSet[DatasetType] 

1210 """Dataset types that may be written after constructing the Tasks in this 

1211 Pipeline. 

1212 

1213 This does not include dataset types that are also used as inputs when 

1214 constructing other Tasks in the Pipeline (these are classified as 

1215 `initIntermediates`). 

1216 """ 

1217 

1218 initIntermediates: NamedValueSet[DatasetType] 

1219 """Dataset types that are both used when constructing one or more Tasks 

1220 in the Pipeline and produced as a side-effect of constructing another 

1221 Task in the Pipeline. 

1222 """ 

1223 

1224 inputs: NamedValueSet[DatasetType] 

1225 """Dataset types that are regular inputs for the full pipeline. 

1226 

1227 If an input dataset needed for a Quantum cannot be found in the input 

1228 collection(s), that Quantum (and all dependent Quanta) will not be 

1229 produced. 

1230 """ 

1231 

1232 queryConstraints: NamedValueSet[DatasetType] 

1233 """Regular inputs that should be used as constraints on the initial 

1234 QuantumGraph generation data ID query, according to their tasks 

1235 (`NamedValueSet`). 

1236 """ 

1237 

1238 prerequisites: NamedValueSet[DatasetType] 

1239 """Dataset types that are prerequisite inputs for the full Pipeline. 

1240 

1241 Prerequisite inputs must exist in the input collection(s) before the 

1242 pipeline is run, but do not constrain the graph - if a prerequisite is 

1243 missing for a Quantum, `PrerequisiteMissingError` is raised. 

1244 

1245 Prerequisite inputs are not resolved until the second stage of 

1246 QuantumGraph generation. 

1247 """ 

1248 

1249 intermediates: NamedValueSet[DatasetType] 

1250 """Dataset types that are output by one Task in the Pipeline and consumed 

1251 as inputs by one or more other Tasks in the Pipeline. 

1252 """ 

1253 

1254 outputs: NamedValueSet[DatasetType] 

1255 """Dataset types that are output by a Task in the Pipeline and not consumed 

1256 by any other Task in the Pipeline. 

1257 """ 

1258 

1259 byTask: Mapping[str, TaskDatasetTypes] 

1260 """Per-Task dataset types, keyed by label in the `Pipeline`. 

1261 

1262 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming 

1263 neither has been modified since the dataset types were extracted, of 

1264 course). 

1265 """ 

1266 

1267 @classmethod 

1268 def fromPipeline( 

1269 cls, 

1270 pipeline: Pipeline | Iterable[TaskDef], 

1271 *, 

1272 registry: Registry, 

1273 include_configs: bool = True, 

1274 include_packages: bool = True, 

1275 ) -> PipelineDatasetTypes: 

1276 """Extract and classify the dataset types from all tasks in a 

1277 `Pipeline`. 

1278 

1279 Parameters 

1280 ---------- 

1281 pipeline: `Pipeline` or `~collections.abc.Iterable` [ `TaskDef` ] 

1282 A collection of tasks that can be run together. 

1283 registry: `Registry` 

1284 Registry used to construct normalized 

1285 `~lsst.daf.butler.DatasetType` objects and retrieve those that are 

1286 incomplete. 

1287 include_configs : `bool`, optional 

1288 If `True` (default) include config dataset types as 

1289 ``initOutputs``. 

1290 include_packages : `bool`, optional 

1291 If `True` (default) include the dataset type for software package 

1292 versions in ``initOutputs``. 

1293 

1294 Returns 

1295 ------- 

1296 types: `PipelineDatasetTypes` 

1297 The dataset types used by this `Pipeline`. 

1298 

1299 Raises 

1300 ------ 

1301 ValueError 

1302 Raised if Tasks are inconsistent about which datasets are marked 

1303 prerequisite. This indicates that the Tasks cannot be run as part 

1304 of the same `Pipeline`. 

1305 """ 

1306 allInputs = NamedValueSet[DatasetType]() 

1307 allOutputs = NamedValueSet[DatasetType]() 

1308 allInitInputs = NamedValueSet[DatasetType]() 

1309 allInitOutputs = NamedValueSet[DatasetType]() 

1310 prerequisites = NamedValueSet[DatasetType]() 

1311 queryConstraints = NamedValueSet[DatasetType]() 

1312 byTask = dict() 

1313 if include_packages: 

1314 allInitOutputs.add( 

1315 DatasetType( 

1316 cls.packagesDatasetName, 

1317 registry.dimensions.empty, 

1318 storageClass=acc.PACKAGES_INIT_OUTPUT_STORAGE_CLASS, 

1319 ) 

1320 ) 

1321 # create a list of TaskDefs in case the input is a generator 

1322 pipeline = list(pipeline) 

1323 

1324 # collect all the output dataset types 

1325 typeStorageclassMap: dict[str, str] = {} 

1326 for taskDef in pipeline: 

1327 for outConnection in iterConnections(taskDef.connections, "outputs"): 

1328 typeStorageclassMap[outConnection.name] = outConnection.storageClass 

1329 

1330 for taskDef in pipeline: 

1331 thisTask = TaskDatasetTypes.fromTaskDef( 

1332 taskDef, 

1333 registry=registry, 

1334 include_configs=include_configs, 

1335 storage_class_mapping=typeStorageclassMap, 

1336 ) 

1337 allInitInputs.update(thisTask.initInputs) 

1338 allInitOutputs.update(thisTask.initOutputs) 

1339 allInputs.update(thisTask.inputs) 

1340 # Inputs are query constraints if any task considers them a query 

1341 # constraint. 

1342 queryConstraints.update(thisTask.queryConstraints) 

1343 prerequisites.update(thisTask.prerequisites) 

1344 allOutputs.update(thisTask.outputs) 

1345 byTask[taskDef.label] = thisTask 

1346 if not prerequisites.isdisjoint(allInputs): 

1347 raise ValueError( 

1348 "{} marked as both prerequisites and regular inputs".format( 

1349 {dt.name for dt in allInputs & prerequisites} 

1350 ) 

1351 ) 

1352 if not prerequisites.isdisjoint(allOutputs): 

1353 raise ValueError( 

1354 "{} marked as both prerequisites and outputs".format( 

1355 {dt.name for dt in allOutputs & prerequisites} 

1356 ) 

1357 ) 

1358 # Make sure that components which are marked as inputs get treated as 

1359 # intermediates if there is an output which produces the composite 

1360 # containing the component 

1361 intermediateComponents = NamedValueSet[DatasetType]() 

1362 intermediateComposites = NamedValueSet[DatasetType]() 

1363 outputNameMapping = {dsType.name: dsType for dsType in allOutputs} 

1364 for dsType in allInputs: 

1365 # get the name of a possible component 

1366 name, component = dsType.nameAndComponent() 

1367 # if there is a component name, that means this is a component 

1368 # DatasetType, if there is an output which produces the parent of 

1369 # this component, treat this input as an intermediate 

1370 if component is not None: 

1371 # This needs to be in this if block, because someone might have 

1372 # a composite that is a pure input from existing data 

1373 if name in outputNameMapping: 

1374 intermediateComponents.add(dsType) 

1375 intermediateComposites.add(outputNameMapping[name]) 

1376 

1377 def checkConsistency(a: NamedValueSet, b: NamedValueSet) -> None: 

1378 common = a.names & b.names 

1379 for name in common: 

1380 # Any compatibility is allowed. This function does not know 

1381 # if a dataset type is to be used for input or output. 

1382 if not (a[name].is_compatible_with(b[name]) or b[name].is_compatible_with(a[name])): 

1383 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.") 

1384 

1385 checkConsistency(allInitInputs, allInitOutputs) 

1386 checkConsistency(allInputs, allOutputs) 

1387 checkConsistency(allInputs, intermediateComposites) 

1388 checkConsistency(allOutputs, intermediateComposites) 

1389 

1390 def frozen(s: Set[DatasetType]) -> NamedValueSet[DatasetType]: 

1391 assert isinstance(s, NamedValueSet) 

1392 s.freeze() 

1393 return s 

1394 

1395 inputs = frozen(allInputs - allOutputs - intermediateComponents) 

1396 

1397 return cls( 

1398 initInputs=frozen(allInitInputs - allInitOutputs), 

1399 initIntermediates=frozen(allInitInputs & allInitOutputs), 

1400 initOutputs=frozen(allInitOutputs - allInitInputs), 

1401 inputs=inputs, 

1402 queryConstraints=frozen(queryConstraints & inputs), 

1403 # If there are storage class differences in inputs and outputs, 

1404 # the intermediates have to choose priority. Here choose that 

1405 # inputs to tasks must match the requested storage class by 

1406 # applying the inputs over the top of the outputs. 

1407 intermediates=frozen(allOutputs & allInputs | intermediateComponents), 

1408 outputs=frozen(allOutputs - allInputs - intermediateComposites), 

1409 prerequisites=frozen(prerequisites), 

1410 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability 

1411 ) 
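
# A sketch of classifying a whole pipeline's dataset types, assuming
# ``pipeline`` and ``butler`` exist (hypothetical names):
#
#     dataset_types = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
#     dataset_types.intermediates                  # produced and consumed within the pipeline
#     dataset_types.byTask["exampleTask"].outputs  # per-task classification (label is hypothetical)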

1412 

1413 @classmethod 

1414 def initOutputNames( 

1415 cls, 

1416 pipeline: Pipeline | Iterable[TaskDef], 

1417 *, 

1418 include_configs: bool = True, 

1419 include_packages: bool = True, 

1420 ) -> Iterator[str]: 

1421 """Return the names of dataset types ot task initOutputs, Configs, 

1422 and package versions for a pipeline. 

1423 

1424 Parameters 

1425 ---------- 

1426 pipeline: `Pipeline` or `~collections.abc.Iterable` [ `TaskDef` ] 

1427 A `Pipeline` instance or collection of `TaskDef` instances. 

1428 include_configs : `bool`, optional 

1429 If `True` (default) include config dataset types. 

1430 include_packages : `bool`, optional 

1431 If `True` (default) include the dataset type for package versions. 

1432 

1433 Yields 

1434 ------ 

1435 datasetTypeName : `str` 

1436 Name of the dataset type. 

1437 """ 

1438 if include_packages: 

1439 # Package versions dataset type 

1440 yield cls.packagesDatasetName 

1441 

1442 if isinstance(pipeline, Pipeline): 

1443 pipeline = pipeline.toExpandedPipeline() 

1444 

1445 for taskDef in pipeline: 

1446 # all task InitOutputs 

1447 for name in taskDef.connections.initOutputs: 

1448 attribute = getattr(taskDef.connections, name) 

1449 yield attribute.name 

1450 

1451 # config dataset name 

1452 if include_configs: 

1453 yield taskDef.configDatasetName
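
# A sketch of listing init-output dataset type names, assuming ``pipeline``
# exists (hypothetical name); config and package dataset types are included
# by default:
#
#     names = set(PipelineDatasetTypes.initOutputNames(pipeline))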