Coverage for python/lsst/pipe/base/connections.py: 46%

225 statements  

coverage.py v6.5.0, created at 2023-04-15 02:30 -0700

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask."""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

import itertools
import string
import typing
from collections import UserDict
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, ClassVar, Dict, Iterable, List, Set, Union

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

from ._status import NoWorkFound
from .connectionTypes import (
    BaseConnection,
    BaseInput,
    InitInput,
    InitOutput,
    Input,
    Output,
    PrerequisiteInput,
)

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """This is a special dict class used by PipelineTaskConnectionsMetaclass.

    This dict is used in PipelineTaskConnections class creation, as the
    dictionary that is initially used as __dict__. It exists to
    intercept connection fields declared in a PipelineTaskConnections class
    and to record the names used to identify them. The names are then added
    to a class level list according to the connection type of the class
    attribute. The names are also used as keys in a class level dictionary
    associated with the corresponding class attribute. This information is a
    duplicate of what exists in __dict__, but provides a simple place to
    look up and iterate over only these variables.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection
        self.data["inputs"] = []
        self.data["prerequisiteInputs"] = []
        self.data["outputs"] = []
        self.data["initInputs"] = []
        self.data["initOutputs"] = []
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, Input):
            self.data["inputs"].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data["prerequisiteInputs"].append(name)
        elif isinstance(value, Output):
            self.data["outputs"].append(name)
        elif isinstance(value, InitInput):
            self.data["initInputs"].append(name)
        elif isinstance(value, InitOutput):
            self.data["initOutputs"].append(name)
        # This should not be an elif, as it needs to be tested for
        # everything that inherits from BaseConnection
        if isinstance(value, BaseConnection):
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
        # defer to the default behavior
        super().__setitem__(name, value)
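
    # Illustrative sketch (hypothetical connection; not part of this class's
    # API) of how a class-body assignment is intercepted:
    #
    #     dct = PipelineTaskConnectionDict()
    #     dct["calexp"] = Input(doc="example", name="calexp",
    #                           storageClass="ExposureF", dimensions=())
    #     dct["inputs"]          # -> ["calexp"]
    #     dct["allConnections"]  # -> {"calexp": <the Input instance>}
    #     dct["calexp"].varName  # -> "calexp"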


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes."""

    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in the class
            # declaration
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be iterable of dimensions, got str, "
                        "possibly omitted trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])
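
            # For reference, string.Formatter().parse yields tuples of
            # (literal_text, field_name, format_spec, conversion); a minimal
            # sketch of the extraction above:
            #
            #     list(string.Formatter().parse("{foo}Dataset"))
            #     # -> [("", "foo", "", None), ("Dataset", None, None, None)]
            #
            # so param[1] is the template key ("foo") or None.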

            # look up any templates from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])

            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnection class contains templated attribute names, but no "
                    "default templates were provided; add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with Class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with the
        # python documentation on metaclasses.
        super().__init__(name, bases, dct)


class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    PipelineTaskConnections class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]) -> None:
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef, typing.List[DatasetRef]]], None, None]:
        """Make an iterator for this QuantizedConnection.

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather than
        dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        QuantizedConnection class.
        """
        yield from self._attributes
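
    # Usage sketch (the attribute name and ref below are hypothetical):
    #
    #     qc = QuantizedConnection()
    #     qc.calexp = someDatasetRef
    #     len(qc)          # -> 1
    #     list(qc.keys())  # -> ["calexp"]
    #     for name, ref in qc:  # yields ("calexp", someDatasetRef)
    #         ...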


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `DatasetRef` that indicates that a `PipelineTask`
    should receive a `DeferredDatasetHandle` instead of an in-memory dataset.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
    """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId
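
    # Minimal sketch (``ref`` here is a hypothetical DatasetRef): the wrapper
    # simply forwards identity-related properties to the wrapped ref:
    #
    #     dref = DeferredDatasetRef(datasetRef=ref)
    #     dref.datasetType is ref.datasetType  # -> True
    #     dref.dataId is ref.dataId            # -> True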


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
      execution of a `PipelineTask`. The name used to declare the connection
      must correspond to an attribute of a `Struct` that is returned by a
      `PipelineTask` ``run`` method. E.g. if an output connection is
      defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)`` where
      X matches the ``storageClass`` type defined on the output connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings that
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name parameters
    on connections defined in a `PipelineTaskConnections` class contain a
    template, then a default template value must be specified in the
    ``defaultTemplates`` argument. This is done by passing a dictionary with
    keys corresponding to a template identifier, and values corresponding to
    the value to use as a default when formatting the string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding `PipelineTask`
    will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    dimensions: ClassVar[Set[str]]

    def __init__(self, *, config: "PipelineTaskConfig" | None = None):
        self.inputs: Set[str] = set(self.inputs)
        self.prerequisiteInputs: Set[str] = set(self.prerequisiteInputs)
        self.outputs: Set[str] = set(self.outputs)
        self.initInputs: Set[str] = set(self.initInputs)
        self.initOutputs: Set[str] = set(self.initOutputs)
        self.allConnections: Dict[str, BaseConnection] = dict(self.allConnections)

        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {
            name: getattr(config.connections, name) for name in getattr(self, "defaultTemplates").keys()
        }
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {
            name: getattr(config.connections, name).format(**templateValues)
            for name in self.allConnections.keys()
        }

        # connection.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}
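
        # Worked example with illustrative values: given a connection
        # attribute ``calexp`` declared as name="{coaddName}Coadd_calexp"
        # under defaultTemplates={"coaddName": "deep"}, and a user override
        # config.connections.coaddName = "goodSeeing", the comprehensions
        # above resolve
        #
        #     self._nameOverrides["calexp"] == "goodSeeingCoadd_calexp"
        #     self._typeNameToVarName["goodSeeingCoadd_calexp"] == "calexp"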

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> typing.Tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build QuantizedConnections corresponding to an input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
                `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip(
            (inputDatasetRefs, outputDatasetRefs),
            (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
        ):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    quantumInputRefs: Union[List[DatasetRef], List[DeferredDatasetRef]]
                    if attribute.deferLoad:
                        quantumInputRefs = [
                            DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                        ]
                    else:
                        quantumInputRefs = list(quantum.inputs[attribute.name])
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                "Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        setattr(refs, attributeName, quantumInputRefs[0])
                    else:
                        # Add to the QuantizedConnection identifier
                        setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        setattr(refs, attributeName, value[0])
                    else:
                        setattr(refs, attributeName, value)
                # The specified attribute is not in inputs or outputs; we
                # don't know how to handle it, so raise
                else:
                    raise ValueError(
                        f"Attribute with name {attributeName} has no counterpart in input quantum"
                    )
        return inputDatasetRefs, outputDatasetRefs
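
    # Execution-harness usage sketch (``quantum`` is a hypothetical
    # `lsst.daf.butler.Quantum` consistent with these connections):
    #
    #     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #     # Each attribute on inputRefs/outputRefs is named for a connection
    #     # and holds a DatasetRef, a DeferredDatasetRef, or a list of them.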

    def adjustQuantum(
        self,
        inputs: typing.Dict[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        outputs: typing.Dict[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> typing.Tuple[
        typing.Mapping[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        typing.Mapping[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
    ]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that are
            true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `DatasetRef` objects. Connections that are not changed
            should not be returned at all. Datasets may only be removed, not
            added. Nested collections may be of any multi-pass iterable type,
            and the order of iteration will set the order of iteration within
            `PipelineTask.runQuantum`.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the message
            included in this exception); not enough datasets were found for a
            `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.::

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={input_connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose
                    # inputs are either already present or should be created
                    # during execution. It can trigger during execution if
                    # the input wasn't actually created by an upstream task
                    # in the same graph.
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(
    connections: PipelineTaskConnections, connectionType: Union[str, Iterable[str]]
) -> typing.Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str`
        The type of connections to iterate over; valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `BaseConnection`.
    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)
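
# Usage sketch, assuming ``connections`` is a constructed
# PipelineTaskConnections instance:
#
#     for connection in iterConnections(connections, "inputs"):
#         print(connection.name)
#
#     # several connection types can be traversed in one pass:
#     for connection in iterConnections(connections, ("initInputs", "inputs")):
#         ...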


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `input` and `output` mappings in the form used by
    `Quantum` and execution harness code, i.e. with `DatasetType` keys,
    translating them to and from the connection-oriented mappings used inside
    `PipelineTaskConnections`.
    """

    inputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
    """Mapping of output datasets, grouped by `DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: typing.Dict[str, typing.Tuple[BaseInput, typing.Tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: typing.Dict[str, typing.Tuple[Output, typing.Tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in itertools.chain(connections.outputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior
        # we want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but it is not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = list(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                # Look up the dataset type name for this connection; without
                # this the check below would reuse a stale name from the
                # inputs loop above.
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = list(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False
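
# Usage sketch, assuming ``quantum`` is an `lsst.daf.butler.Quantum` and
# ``connections`` is a constructed PipelineTaskConnections instance for the
# task labeled "example":
#
#     helper = AdjustQuantumHelper(inputs=quantum.inputs, outputs=quantum.outputs)
#     helper.adjust_in_place(connections, label="example", data_id=quantum.dataId)
#     if helper.inputs_adjusted or helper.outputs_adjusted:
#         ...  # rebuild the Quantum from helper.inputs / helper.outputs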