
1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23"""Module defining Pipeline class and related methods. 

24""" 

25 

26__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes"] 

27 

28# ------------------------------- 

29# Imports of standard modules -- 

30# ------------------------------- 

31from dataclasses import dataclass 

32from types import MappingProxyType 

33from typing import FrozenSet, Mapping, Union, Generator, TYPE_CHECKING 

34 

35import copy 

36 

37# ----------------------------- 

38# Imports for other modules -- 

39from lsst.daf.butler import DatasetType, Registry, SkyPixDimension 

40from lsst.utils import doImport 

41from .configOverrides import ConfigOverrides 

42from .connections import iterConnections 

43from .pipelineTask import PipelineTask 

44 

45from . import pipelineIR 

46from . import pipeTools 

47 

48if TYPE_CHECKING: # Imports needed only for type annotations; may be circular.

49 from lsst.obs.base.instrument import Instrument 

50 

51# ---------------------------------- 

52# Local non-exported definitions -- 

53# ---------------------------------- 

54 

55# ------------------------ 

56# Exported definitions -- 

57# ------------------------ 

58 

59 

60class TaskDef: 

61 """TaskDef is a collection of information about a task needed by a Pipeline.

62

63 The information includes the task name, the configuration object, and an

64 optional task class. This class is just a collection of attributes, and it

65 exposes all of them so that they can be modified in place (e.g. if the

66 configuration needs extra overrides).

67 

68 Attributes 

69 ---------- 

70 taskName : `str` 

71 `PipelineTask` class name; it is currently not specified whether this

72 is a fully-qualified name or a partial name (e.g. ``module.TaskClass``).

73 The framework should be prepared to handle both cases.

74 config : `lsst.pex.config.Config` 

75 Instance of the configuration class corresponding to this task class, 

76 usually with all overrides applied. 

77 taskClass : `type` or ``None`` 

78 `PipelineTask` class object; may be ``None``. If ``None``, the

79 framework will have to locate and load the class.

80 label : `str`, optional 

81 Task label, usually a short string unique in a pipeline. 

82 """ 

83 def __init__(self, taskName, config, taskClass=None, label=""): 

84 self.taskName = taskName 

85 self.config = config 

86 self.taskClass = taskClass 

87 self.label = label 

88 self.connections = config.connections.ConnectionsClass(config=config) 

89 

90 @property 

91 def metadataDatasetName(self): 

92 """Name of a dataset type for metadata of this task, `None` if 

93 metadata is not to be saved (`str`) 

94 """ 

95 if self.config.saveMetadata: 

96 return self.label + "_metadata" 

97 else: 

98 return None 

99 

100 def __str__(self): 

101 rep = "TaskDef(" + self.taskName 

102 if self.label: 

103 rep += ", label=" + self.label 

104 rep += ")" 

105 return rep 

106 
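# A minimal sketch of building a TaskDef by hand (normally Pipeline.toExpandedPipeline
# constructs these for you).  ``ExampleTask`` is a hypothetical PipelineTask subclass
# used purely for illustration; any concrete PipelineTask would work the same way:
#
#     config = ExampleTask.ConfigClass()
#     taskDef = TaskDef(taskName="mypackage.ExampleTask", config=config,
#                       taskClass=ExampleTask, label="example")
#     taskDef.metadataDatasetName  # "example_metadata" when config.saveMetadata is True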

107 

108class Pipeline: 

109 """A `Pipeline` is a representation of a series of tasks to run, and the 

110 configuration for those tasks. 

111 

112 Parameters 

113 ---------- 

114 description : `str` 

115 A description of what this pipeline does.

116 """ 

117 def __init__(self, description: str) -> None:

118 pipeline_dict = {"description": description, "tasks": {}} 

119 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict) 

120 

121 @classmethod 

122 def fromFile(cls, filename: str) -> Pipeline: 

123 """Load a pipeline defined in a pipeline yaml file. 

124 

125 Parameters 

126 ---------- 

127 filename: `str` 

128 A path that points to a pipeline defined in yaml format 

129 

130 Returns 

131 ------- 

132 pipeline: `Pipeline` 

133 """ 

134 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename)) 

135 return pipeline 

136 

137 @classmethod 

138 def fromString(cls, pipeline_string: str) -> Pipeline: 

139 """Create a pipeline from a string formatted as a pipeline document.

140 

141 Parameters 

142 ---------- 

143 pipeline_string : `str` 

144 A string that is formatted like a pipeline document.

145 

146 Returns 

147 ------- 

148 pipeline: `Pipeline` 

149 """ 

150 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string)) 

151 return pipeline 

152 
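# A minimal sketch of a pipeline document for fromString, assuming the standard YAML
# layout of a description plus labeled tasks (the task class path is hypothetical):
#
#     pipeline = Pipeline.fromString("""
#     description: An example pipeline
#     tasks:
#       example:
#         class: mypackage.ExampleTask
#     """)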

153 @classmethod 

154 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline: 

155 """Create a pipeline from an already created `PipelineIR` object. 

156 

157 Parameters 

158 ---------- 

159 deserialized_pipeline: `PipelineIR` 

160 An already created pipeline intermediate representation object 

161 

162 Returns 

163 ------- 

164 pipeline: `Pipeline` 

165 """ 

166 pipeline = cls.__new__(cls) 

167 pipeline._pipelineIR = deserialized_pipeline 

168 return pipeline 

169 

170 @classmethod 

171 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:

172 """Create a new pipeline by copying an already existing `Pipeline`. 

173 

174 Parameters 

175 ---------- 

176 pipeline: `Pipeline` 

177 The existing `Pipeline` to copy.

178 

179 Returns 

180 ------- 

181 pipeline: `Pipeline` 

182 """ 

183 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

184 

185 def __str__(self) -> str: 

186 return str(self._pipelineIR) 

187 

188 def addInstrument(self, instrument: Union[Instrument, str]): 

189 """Add an instrument to the pipeline, or replace an instrument that is 

190 already defined. 

191 

192 Parameters 

193 ---------- 

194 instrument : `~lsst.obs.base.instrument.Instrument` or `str`

195 Either a subclass of `~lsst.obs.base.instrument.Instrument` or a

196 string corresponding to the fully qualified name of such a

197 subclass.

198 """ 

199 if isinstance(instrument, str): 

200 pass 

201 else: 

202 # TODO: assume that this is a subclass of Instrument, no type checking 

203 instrument = f"{instrument.__module__}.{instrument.__qualname__}" 

204 self._pipelineIR.instrument = instrument 

205 

206 def addTask(self, task: Union[PipelineTask, str], label: str): 

207 """Add a new task to the pipeline, or replace a task that is already 

208 associated with the supplied label. 

209 

210 Parameters 

211 ---------- 

212 task: `PipelineTask` or `str` 

213 Either a derived class object of a `PipelineTask` or a string 

214 corresponding to a fully qualified `PipelineTask` name. 

215 label: `str` 

216 A label that is used to identify the `PipelineTask` being added 

217 """ 

218 if isinstance(task, str): 

219 taskName = task 

220 elif issubclass(task, PipelineTask): 

221 taskName = f"{task.__module__}.{task.__qualname__}" 

222 else: 

223 raise ValueError("task must be either a child class of PipelineTask or a string containing" 

224 " a fully qualified name to one") 

225 if not label: 

226 # in some cases (with command line-generated pipelines) tasks can

227 # be defined without a label, which is not acceptable; use the task's

228 # _DefaultName in that case

229 if isinstance(task, str): 

230 task = doImport(task) 

231 label = task._DefaultName 

232 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName) 

233 

234 def removeTask(self, label: str): 

235 """Remove a task from the pipeline. 

236 

237 Parameters 

238 ---------- 

239 label : `str` 

240 The label used to identify the task that is to be removed 

241 

242 Raises 

243 ------ 

244 KeyError 

245 If no task with that label exists in the pipeline 

246 

247 """ 

248 self._pipelineIR.tasks.pop(label) 

249 

250 def addConfigOverride(self, label: str, key: str, value: object): 

251 """Apply single config override. 

252 

253 Parameters 

254 ---------- 

255 label : `str` 

256 Label of the task. 

257 key: `str` 

258 Fully-qualified field name. 

259 value : object 

260 Value to be given to a field. 

261 """ 

262 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value})) 

263 

264 def addConfigFile(self, label: str, filename: str): 

265 """Add overrides from a specified file. 

266 

267 Parameters 

268 ---------- 

269 label : `str` 

270 The label used to identify the task associated with the config to

271 modify.

272 filename : `str` 

273 Path to the override file. 

274 """ 

275 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename])) 

276 

277 def addConfigPython(self, label: str, pythonString: str): 

278 """Add overrides by running a snippet of Python code against a config.

279 

280 Parameters 

281 ---------- 

282 label : `str` 

283 The label used to identify the task associated with the config to

284 modify.

285 pythonString: `str` 

286 A string which is valid Python code to be executed. This is done

287 with ``config`` as the only locally accessible value.

288 """ 

289 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString)) 

290 

291 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR): 

292 if label not in self._pipelineIR.tasks: 

293 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline") 

294 self._pipelineIR.tasks[label].add_or_update_config(newConfig) 

295 

296 def toFile(self, filename: str): 

297 self._pipelineIR.to_file(filename) 

298 

299 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:

300 """Returns a generator of TaskDefs which can be used to create quantum 

301 graphs. 

302 

303 Returns 

304 ------- 

305 generator : generator of `TaskDef` 

306 The generator returned will be the sorted iterator of tasks which 

307 are to be used in constructing a quantum graph. 

308 

309 Raises 

310 ------ 

311 NotImplementedError 

312 If a dataId is supplied in a config block. This is in place for 

313 future use.

314 """ 

315 taskDefs = [] 

316 for label, taskIR in self._pipelineIR.tasks.items(): 

317 taskClass = doImport(taskIR.klass) 

318 taskName = taskClass.__qualname__ 

319 config = taskClass.ConfigClass() 

320 overrides = ConfigOverrides() 

321 if self._pipelineIR.instrument is not None: 

322 overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName) 

323 if taskIR.config is not None: 

324 for configIR in taskIR.config: 

325 if configIR.dataId is not None: 

326 raise NotImplementedError("Specializing a config on a partial data id is not yet " 

327 "supported in Pipeline definition") 

328 # only apply override if it applies to everything 

329 if configIR.dataId is None: 

330 if configIR.file: 

331 for configFile in configIR.file: 

332 overrides.addFileOverride(configFile) 

333 if configIR.python is not None: 

334 overrides.addPythonOverride(configIR.python) 

335 for key, value in configIR.rest.items(): 

336 overrides.addValueOverride(key, value) 

337 overrides.applyTo(config) 

338 # This may need to be revisited 

339 config.validate() 

340 taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label)) 

341 

342 # let's evaluate the contracts

343 if self._pipelineIR.contracts is not None: 

344 label_to_config = {x.label: x.config for x in taskDefs} 
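# A contract is a Python expression evaluated over the task configs keyed by label,
# e.g. "taskA.someField == taskB.someField" (the labels and field name here are
# illustrative only).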

345 for contract in self._pipelineIR.contracts: 

346 # execute this on its own line so it can raise a good error message if there were problems

347 # with the eval

348 success = eval(contract.contract, None, label_to_config) 

349 if not success: 

350 extra_info = f": {contract.msg}" if contract.msg is not None else "" 

351 raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not " 

352 f"satisfied{extra_info}") 

353 

354 yield from pipeTools.orderPipeline(taskDefs) 

355 

356 def __len__(self): 

357 return len(self._pipelineIR.tasks) 

358 

359 def __eq__(self, other: "Pipeline"): 

360 if not isinstance(other, Pipeline): 

361 return False 

362 return self._pipelineIR == other._pipelineIR 

363 
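# A minimal sketch of building and customizing a Pipeline programmatically.  The
# instrument and task class paths, the label, and the config field names below are
# hypothetical placeholders:
#
#     pipeline = Pipeline("An example pipeline")
#     pipeline.addInstrument("mypackage.ExampleInstrument")
#     pipeline.addTask("mypackage.ExampleTask", "example")
#     pipeline.addConfigOverride("example", "someField", 42)
#     pipeline.addConfigFile("example", "exampleOverrides.py")
#     pipeline.addConfigPython("example", "config.someField = 43")
#     pipeline.toFile("example_pipeline.yaml")
#
# None of the overrides are applied (and the task class is not even imported) until
# toExpandedPipeline() is called to produce the ordered TaskDefs.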

364 

365@dataclass(frozen=True) 

366class TaskDatasetTypes: 

367 """An immutable struct that extracts and classifies the dataset types used 

368 by a `PipelineTask`.

369 """ 

370 

371 initInputs: FrozenSet[DatasetType] 

372 """Dataset types that are needed as inputs in order to construct this Task. 

373 

374 Task-level `initInputs` may be classified as either 

375 `~PipelineDatasetTypes.initInputs` or 

376 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level. 

377 """ 

378 

379 initOutputs: FrozenSet[DatasetType] 

380 """Dataset types that may be written after constructing this Task. 

381 

382 Task-level `initOutputs` may be classified as either 

383 `~PipelineDatasetTypes.initOutputs` or 

384 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level. 

385 """ 

386 

387 inputs: FrozenSet[DatasetType] 

388 """Dataset types that are regular inputs to this Task. 

389 

390 If an input dataset needed for a Quantum cannot be found in the input 

391 collection(s) or produced by another Task in the Pipeline, that Quantum 

392 (and all dependent Quanta) will not be produced. 

393 

394 Task-level `inputs` may be classified as either 

395 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates` 

396 at the Pipeline level. 

397 """ 

398 

399 prerequisites: FrozenSet[DatasetType] 

400 """Dataset types that are prerequisite inputs to this Task. 

401 

402 Prerequisite inputs must exist in the input collection(s) before the 

403 pipeline is run, but do not constrain the graph - if a prerequisite is 

404 missing for a Quantum, `PrerequisiteMissingError` is raised. 

405 

406 Prerequisite inputs are not resolved until the second stage of 

407 QuantumGraph generation. 

408 """ 

409 

410 outputs: FrozenSet[DatasetType] 

411 """Dataset types that are produced by this Task. 

412 

413 Task-level `outputs` may be classified as either 

414 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates` 

415 at the Pipeline level. 

416 """ 

417 

418 @classmethod 

419 def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes: 

420 """Extract and classify the dataset types from a single `PipelineTask`. 

421 

422 Parameters 

423 ---------- 

424 taskDef: `TaskDef` 

425 An instance of a `TaskDef` class for a particular `PipelineTask`. 

426 registry: `Registry` 

427 Registry used to construct normalized `DatasetType` objects and 

428 retrieve those that are incomplete. 

429 

430 Returns 

431 ------- 

432 types: `TaskDatasetTypes` 

433 The dataset types used by this task. 

434 """ 

435 def makeDatasetTypesSet(connectionType): 

436 """Constructs a set of true `DatasetType` objects 

437 

438 Parameters 

439 ---------- 

440 connectionType : `str` 

441 Name of the connection type to produce a set for, corresponds 

442 to an attribute of type `list` on the connection class instance 

443 

444 Returns 

445 ------- 

446 datasetTypes : `frozenset` 

447 A set of all datasetTypes which correspond to the input 

448 connection type specified in the connection class of this 

449 `PipelineTask` 

450 

451 Notes 

452 ----- 

453 This function is a closure over the variables ``registry`` and 

454 ``taskDef``. 

455 """ 

456 datasetTypes = [] 

457 for c in iterConnections(taskDef.connections, connectionType): 

458 dimensions = set(getattr(c, 'dimensions', set())) 
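# "skypix" is a placeholder rather than a concrete dimension; the actual skypix
# dimension (e.g. htm7) is only known to the registry, so the dataset type must
# already be registered and is looked up instead of being constructed here.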

459 if "skypix" in dimensions: 

460 try: 

461 datasetType = registry.getDatasetType(c.name) 

462 except LookupError as err: 

463 raise LookupError( 

464 f"DatasetType '{c.name}' referenced by " 

465 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension " 

466 f"placeholder, but does not already exist in the registry. " 

467 f"Note that reference catalog names are now used as the dataset " 

468 f"type name instead of 'ref_cat'." 

469 ) from err 

470 rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names) 

471 rest2 = set(dim.name for dim in datasetType.dimensions 

472 if not isinstance(dim, SkyPixDimension)) 

473 if rest1 != rest2: 

474 raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in " 

475 f"connections ({rest1}) are inconsistent with those in " 

476 f"registry's version of this dataset ({rest2}).") 

477 else: 

478 datasetType = DatasetType(c.name, registry.dimensions.extract(dimensions), 

479 c.storageClass) 

480 datasetTypes.append(datasetType) 

481 return frozenset(datasetTypes) 

482 

483 # optionally add output dataset for metadata 

484 outputs = makeDatasetTypesSet("outputs") 

485 if taskDef.metadataDatasetName is not None: 

486 # Metadata is supposed to be of the PropertyList type; its dimensions

487 # correspond to those of a task quantum

488 dimensions = registry.dimensions.extract(taskDef.connections.dimensions) 

489 outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertyList")} 

490 

491 return cls( 

492 initInputs=makeDatasetTypesSet("initInputs"), 

493 initOutputs=makeDatasetTypesSet("initOutputs"), 

494 inputs=makeDatasetTypesSet("inputs"), 

495 prerequisites=makeDatasetTypesSet("prerequisiteInputs"), 

496 outputs=outputs, 

497 ) 

498 
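# A minimal sketch of extracting dataset types for each task in a pipeline.  ``butler``
# is a hypothetical, pre-existing `lsst.daf.butler.Butler`; ``pipeline`` is a `Pipeline`:
#
#     for taskDef in pipeline.toExpandedPipeline():
#         types = TaskDatasetTypes.fromTaskDef(taskDef, registry=butler.registry)
#         print(taskDef.label, sorted(d.name for d in types.inputs))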

499 

500@dataclass(frozen=True) 

501class PipelineDatasetTypes: 

502 """An immutable struct that classifies the dataset types used in a 

503 `Pipeline`. 

504 """ 

505 

506 initInputs: FrozenSet[DatasetType] 

507 """Dataset types that are needed as inputs in order to construct the Tasks 

508 in this Pipeline. 

509 

510 This does not include dataset types that are produced when constructing 

511 other Tasks in the Pipeline (these are classified as `initIntermediates`). 

512 """ 

513 

514 initOutputs: FrozenSet[DatasetType] 

515 """Dataset types that may be written after constructing the Tasks in this 

516 Pipeline. 

517 

518 This does not include dataset types that are also used as inputs when 

519 constructing other Tasks in the Pipeline (these are classified as 

520 `initIntermediates`). 

521 """ 

522 

523 initIntermediates: FrozenSet[DatasetType] 

524 """Dataset types that are both used when constructing one or more Tasks 

525 in the Pipeline and produced as a side-effect of constructing another 

526 Task in the Pipeline. 

527 """ 

528 

529 inputs: FrozenSet[DatasetType] 

530 """Dataset types that are regular inputs for the full pipeline. 

531 

532 If an input dataset needed for a Quantum cannot be found in the input 

533 collection(s), that Quantum (and all dependent Quanta) will not be 

534 produced. 

535 """ 

536 

537 prerequisites: FrozenSet[DatasetType] 

538 """Dataset types that are prerequisite inputs for the full Pipeline. 

539 

540 Prerequisite inputs must exist in the input collection(s) before the 

541 pipeline is run, but do not constrain the graph - if a prerequisite is 

542 missing for a Quantum, `PrerequisiteMissingError` is raised. 

543 

544 Prerequisite inputs are not resolved until the second stage of 

545 QuantumGraph generation. 

546 """ 

547 

548 intermediates: FrozenSet[DatasetType] 

549 """Dataset types that are output by one Task in the Pipeline and consumed 

550 as inputs by one or more other Tasks in the Pipeline. 

551 """ 

552 

553 outputs: FrozenSet[DatasetType] 

554 """Dataset types that are output by a Task in the Pipeline and not consumed 

555 by any other Task in the Pipeline. 

556 """ 

557 

558 byTask: Mapping[str, TaskDatasetTypes] 

559 """Per-Task dataset types, keyed by label in the `Pipeline`. 

560 

561 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming 

562 neither has been modified since the dataset types were extracted, of 

563 course). 

564 """ 

565 

566 @classmethod 

567 def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes: 

568 """Extract and classify the dataset types from all tasks in a 

569 `Pipeline`. 

570 

571 Parameters 

572 ---------- 

573 pipeline: `Pipeline` or iterable of `TaskDef`

574 An ordered collection of tasks that can be run together. 

575 registry: `Registry` 

576 Registry used to construct normalized `DatasetType` objects and 

577 retrieve those that are incomplete. 

578 

579 Returns 

580 ------- 

581 types: `PipelineDatasetTypes` 

582 The dataset types used by this `Pipeline`. 

583 

584 Raises 

585 ------ 

586 ValueError 

587 Raised if Tasks are inconsistent about which datasets are marked 

588 prerequisite. This indicates that the Tasks cannot be run as part 

589 of the same `Pipeline`. 

590 """ 

591 allInputs = set() 

592 allOutputs = set() 

593 allInitInputs = set() 

594 allInitOutputs = set() 

595 prerequisites = set() 

596 byTask = dict() 

597 if isinstance(pipeline, Pipeline): 

598 pipeline = pipeline.toExpandedPipeline() 

599 for taskDef in pipeline: 

600 thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry) 

601 allInitInputs.update(thisTask.initInputs) 

602 allInitOutputs.update(thisTask.initOutputs) 

603 allInputs.update(thisTask.inputs) 

604 prerequisites.update(thisTask.prerequisites) 

605 allOutputs.update(thisTask.outputs) 

606 byTask[taskDef.label] = thisTask 

607 if not prerequisites.isdisjoint(allInputs): 

608 raise ValueError("{} marked as both prerequisites and regular inputs".format( 

609 {dt.name for dt in allInputs & prerequisites} 

610 )) 

611 if not prerequisites.isdisjoint(allOutputs): 

612 raise ValueError("{} marked as both prerequisites and outputs".format( 

613 {dt.name for dt in allOutputs & prerequisites} 

614 )) 

615 # Make sure that components which are marked as inputs get treated as 

616 # intermediates if there is an output which produces the composite 

617 # containing the component 
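# For example (with illustrative dataset type names): if one task writes "calexp"
# and another declares "calexp.psf" as an input, the component input is treated as
# an intermediate and the parent composite is not reported as a pure output.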

618 intermediateComponents = set() 

619 intermediateComposites = set() 

620 outputNameMapping = {dsType.name: dsType for dsType in allOutputs} 

621 for dsType in allInputs: 

622 # get the name of a possible component 

623 name, component = dsType.nameAndComponent() 

624 # if there is a component name, that means this is a component

625 # DatasetType; if there is an output which produces the parent of

626 # this component, treat this input as an intermediate

627 if component is not None: 

628 if name in outputNameMapping and outputNameMapping[name].dimensions == dsType.dimensions: 

629 composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass, 

630 universe=registry.dimensions) 

631 intermediateComponents.add(dsType) 

632 intermediateComposites.add(composite) 

633 return cls( 

634 initInputs=frozenset(allInitInputs - allInitOutputs), 

635 initIntermediates=frozenset(allInitInputs & allInitOutputs), 

636 initOutputs=frozenset(allInitOutputs - allInitInputs), 

637 inputs=frozenset(allInputs - allOutputs - intermediateComponents), 

638 intermediates=frozenset(allInputs & allOutputs | intermediateComponents), 

639 outputs=frozenset(allOutputs - allInputs - intermediateComposites), 

640 prerequisites=frozenset(prerequisites), 

641 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability 

642 )
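
# A minimal sketch of classifying every dataset type used by a pipeline.  ``pipeline``
# is a `Pipeline` and ``butler`` is a hypothetical, pre-existing `lsst.daf.butler.Butler`:
#
#     datasetTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
#     print("inputs:       ", sorted(d.name for d in datasetTypes.inputs))
#     print("intermediates:", sorted(d.name for d in datasetTypes.intermediates))
#     print("outputs:      ", sorted(d.name for d in datasetTypes.outputs))
#
# Anything produced by one task and consumed by another ends up in ``intermediates``,
# so ``inputs``, ``intermediates`` and ``outputs`` are pairwise disjoint.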