Coverage for python/lsst/pipe/base/connections.py: 46% (225 statements)
coverage.py v6.5.0, created at 2023-01-12 02:06 -0800

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask."""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

import itertools
import string
import typing
from collections import UserDict
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, ClassVar, Dict, Iterable, List, Set, Union

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

from ._status import NoWorkFound
from .connectionTypes import (
    BaseConnection,
    BaseInput,
    InitInput,
    InitOutput,
    Input,
    Output,
    PrerequisiteInput,
)

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """



class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by PipelineTaskConnectionsMetaclass.

    This dict is used during PipelineTaskConnections class creation, as the
    dictionary that is initially used as __dict__. It exists to intercept
    connection fields declared in a PipelineTaskConnections class and to
    record the names used to identify them. The names are added to a
    class-level list according to the connection type of the class attribute.
    The names are also used as keys in a class-level dictionary associated
    with the corresponding class attribute. This information duplicates what
    exists in __dict__, but provides a simple place to look up and iterate
    over only these variables.
    """
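
    # For illustration (connection name and storage class hypothetical): in a
    # class body backed by this dict, a declaration such as
    #
    #     calexp = Input(doc="...", name="calexp",
    #                    storageClass="ExposureF", dimensions=())
    #
    # is routed by __setitem__ below into both self.data["inputs"] and
    # self.data["allConnections"]["calexp"].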

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class-level variables used to track any declared
        # class-level variables that are instances of
        # connectionTypes.BaseConnection
        self.data["inputs"] = []
        self.data["prerequisiteInputs"] = []
        self.data["outputs"] = []
        self.data["initInputs"] = []
        self.data["initOutputs"] = []
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, Input):
            self.data["inputs"].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data["prerequisiteInputs"].append(name)
        elif isinstance(value, Output):
            self.data["outputs"].append(name)
        elif isinstance(value, InitInput):
            self.data["initInputs"].append(name)
        elif isinstance(value, InitOutput):
            self.data["initOutputs"].append(name)
        # This should not be an elif, as it needs to be tested against
        # everything that inherits from BaseConnection
        if isinstance(value, BaseConnection):
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
        # Defer to the default behavior
        super().__setitem__(name, value)



class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes"""

    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in class
            # declaration
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be iterable of dimensions, got str, "
                        "possibly omitted trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
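
            # For example (dimension names hypothetical), a subclass declared
            # as
            #
            #     class MyConnections(PipelineTaskConnections,
            #                         dimensions=("instrument", "visit")):
            #         ...
            #
            # supplies the required dimensions keyword, while
            # dimensions=("visit") (a bare str, missing the trailing comma)
            # is rejected above.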

            # Lookup any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])
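
                # (string.Formatter().parse("{foo}Dataset"), for example,
                # yields ("", "foo", "", None) followed by
                # ("Dataset", None, None, None), so only "foo" is added to
                # allTemplates.)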

            # look up any template from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])

            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnection class contains templated attribute names, but no "
                    "default templates were provided, add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})

        # Convert all the connection containers into frozensets so they
        # cannot be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with the
        # Python documentation on metaclasses.
        super().__init__(name, bases, dct)



class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    PipelineTaskConnections class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the PipelineTaskConnections class.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]) -> None:
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef, typing.List[DatasetRef]]], None, None]:
        """Make an iterator for this QuantizedConnection.

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items(), but iterates over the namespace attributes
        rather than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        QuantizedConnection class.
        """
        yield from self._attributes
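
    # For example (attribute name hypothetical), after
    #
    #     qc = InputQuantizedConnection()
    #     qc.calexp = someDatasetRef
    #
    # iterating with ``for name, refs in qc:`` yields
    # ("calexp", someDatasetRef), and list(qc.keys()) == ["calexp"].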


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `DatasetRef` that indicates that a `PipelineTask`
    should receive a `DeferredDatasetHandle` instead of an in-memory dataset.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
    """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId
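
# Note: connections declared with deferLoad=True are wrapped in
# DeferredDatasetRef by PipelineTaskConnections.buildDatasetRefs below; this
# is how the execution harness knows to hand the task a
# DeferredDatasetHandle rather than the loaded dataset.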


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnections`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class.
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``.
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)``
      where X matches the ``storageClass`` type defined on the output
      connection.

    The process of declaring a ``PipelineTaskConnections`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``: an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example if ``ConnectionClass.calexp.name =
    '{input}Coadd_calexp'`` then ``defaultTemplates`` should be
    ``{'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    dimensions: ClassVar[Set[str]]

    def __init__(self, *, config: "PipelineTaskConfig" | None = None):
        self.inputs: Set[str] = set(self.inputs)
        self.prerequisiteInputs: Set[str] = set(self.prerequisiteInputs)
        self.outputs: Set[str] = set(self.outputs)
        self.initInputs: Set[str] = set(self.initInputs)
        self.initOutputs: Set[str] = set(self.initOutputs)
        self.allConnections: Dict[str, BaseConnection] = dict(self.allConnections)

        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {
            name: getattr(config.connections, name) for name in getattr(self, "defaultTemplates").keys()
        }
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {
            name: getattr(config.connections, name).format(**templateValues)
            for name in self.allConnections.keys()
        }
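
        # For example, with defaultTemplates={"foo": "Example"} (as in the
        # class docstring above) and an unmodified config, a connection
        # declared with name="{foo}Dataset" resolves here to the dataset
        # type name "ExampleDataset".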

        # connection.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> typing.Tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build QuantizedConnections corresponding to an input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
            `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
        """
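        # Note (attribute name hypothetical): a connection declared with
        # multiple=False is set below as a single DatasetRef attribute, e.g.
        # inputDatasetRefs.calexp, while multiple=True produces a list of
        # refs under the attribute name.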

        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip(
            (inputDatasetRefs, outputDatasetRefs),
            (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
        ):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    quantumInputRefs: Union[List[DatasetRef], List[DeferredDatasetRef]]
                    if attribute.deferLoad:
                        quantumInputRefs = [
                            DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                        ]
                    else:
                        quantumInputRefs = list(quantum.inputs[attribute.name])
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                "Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        setattr(refs, attributeName, quantumInputRefs[0])
                    else:
                        # Add to the QuantizedConnection identifier
                        setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        setattr(refs, attributeName, value[0])
                    else:
                        setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; we don't
                # know how to handle it, so raise
                else:
                    raise ValueError(
                        f"Attribute with name {attributeName} has no counterpart in input quantum"
                    )
        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(
        self,
        inputs: typing.Dict[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        outputs: typing.Dict[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> typing.Tuple[
        typing.Mapping[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        typing.Mapping[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
    ]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.Quantum` during the graph generation
        stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that are
            true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `DatasetRef` objects. Connections that are not changed
            should not be returned at all. Datasets may only be removed, not
            added. Nested collections may be of any multi-pass iterable type,
            and the order of iteration will set the order of iteration within
            `PipelineTask.runQuantum`.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets are present.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the message
            included in this exception); not enough datasets were found for a
            `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.::

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """

        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={input_connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose
                    # inputs are either already present or should be created
                    # during execution. It can trigger during execution if
                    # the input wasn't actually created by an upstream task
                    # in the same graph.
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(
    connections: PipelineTaskConnections, connectionType: Union[str, Iterable[str]]
) -> typing.Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or iterable of `str`
        The type of connections to iterate over; valid values are
        ``inputs``, ``outputs``, ``prerequisiteInputs``, ``initInputs``, and
        ``initOutputs``.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `BaseConnection`.
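
    Examples
    --------
    A minimal usage sketch (``connections`` is assumed to be an instantiated
    `PipelineTaskConnections` subclass)::

        for connection in iterConnections(connections, "inputs"):
            print(connection.name)

        # Several connection types can be iterated together.
        for connection in iterConnections(connections, ("inputs", "outputs")):
            print(connection.name)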

    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `input` and `output` mappings in the form used by
    `Quantum` and execution harness code, i.e. with `DatasetType` keys,
    translating them to and from the connection-oriented mappings used inside
    `PipelineTaskConnections`.
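
    A minimal usage sketch (``quantum`` is assumed to be an
    `lsst.daf.butler.Quantum`, ``connections`` an instantiated connections
    class, and the label hypothetical)::

        helper = AdjustQuantumHelper(inputs=quantum.inputs,
                                     outputs=quantum.outputs)
        helper.adjust_in_place(connections, label="isr", data_id=quantum.dataId)
        if helper.inputs_adjusted or helper.outputs_adjusted:
            ...  # rebuild the Quantum from helper.inputs / helper.outputs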

    """

    inputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
    """Mapping of output datasets, grouped by `DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """

        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: typing.Dict[str, typing.Tuple[BaseInput, typing.Tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: typing.Dict[str, typing.Tuple[Output, typing.Tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in itertools.chain(connections.outputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior
        # we want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = list(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                # Look up this connection's dataset type name; without this
                # the check below would reuse a stale name from the inputs
                # loop above.
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = list(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False