Coverage for python/lsst/pipe/base/connections.py: 44%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

from collections import UserDict, namedtuple
from dataclasses import dataclass
from types import SimpleNamespace
import typing
from typing import Union, Iterable

import itertools
import string

from . import config as configMod
from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput,
                              Output, BaseConnection, BaseInput)
from ._status import NoWorkFound
from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by PipelineTaskConnectionsMetaclass.

    This dict is used in PipelineTaskConnections class creation, as the
    dictionary that is initially used as __dict__. It exists to intercept
    connection fields declared in a PipelineTaskConnections class and record
    the names used to identify them. The names are then added to a class
    level list according to the connection type of the class attribute. The
    names are also used as keys in a class level dictionary associated with
    the corresponding class attribute. This information is a duplicate of
    what exists in __dict__, but provides a simple place to look up and
    iterate over only these variables.
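
    Examples
    --------
    A minimal sketch of the interception behavior; the connection name and
    storage class here are purely illustrative::

        dct = PipelineTaskConnectionDict()
        dct["calexp"] = Input(doc="Example input", name="calexp",
                              storageClass="ExposureF", dimensions=())
        assert "calexp" in dct["inputs"]
        assert dct["allConnections"]["calexp"].varName == "calexp"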

75 """ 

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection
        self.data['inputs'] = []
        self.data['prerequisiteInputs'] = []
        self.data['outputs'] = []
        self.data['initInputs'] = []
        self.data['initOutputs'] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        if isinstance(value, Input):
            self.data['inputs'].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data['prerequisiteInputs'].append(name)
        elif isinstance(value, Output):
            self.data['outputs'].append(name)
        elif isinstance(value, InitInput):
            self.data['initInputs'].append(name)
        elif isinstance(value, InitOutput):
            self.data['initOutputs'].append(name)
        # This should not be an elif, as it needs to be tested for
        # everything that inherits from BaseConnection
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # defer to the default behavior
        super().__setitem__(name, value)


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes.

110 """ 

    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection
        # Copy any existing connections from a parent class
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in class
            # declaration
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs['dimensions'], str):
                    raise TypeError("Dimensions must be iterable of dimensions, got str, "
                                    "possibly omitted trailing comma")
                if not isinstance(kwargs['dimensions'], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnections class contains templated attribute names, but no "
                                "default templates were provided, add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection, this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError("Template parameters cannot share names with Class attributes"
                                    f" (conflicts are {nameTemplateIntersection}).")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they
        # cannot be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses.
        super().__init__(name, bases, dct)


class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
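
    Examples
    --------
    Attribute names are tracked as they are set, so instances iterate like
    ``dict.items()``; a sketch with a hypothetical ``ref``::

        quantized = QuantizedConnection()
        quantized.calexp = ref  # a lsst.daf.butler.DatasetRef
        for name, value in quantized:
            print(name, value)  # "calexp", ref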

218 """ 

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                           typing.List[DatasetRef]]], None, None]:
        """Make an iterator for this QuantizedConnection.

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        QuantizedConnection class.
        """
        yield from self._attributes


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass

258 

259class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")): 

260 """Class which denotes that a datasetRef should be treated as deferred when 

261 interacting with the butler 

262 

263 Parameters 

264 ---------- 

265 datasetRef : `lsst.daf.butler.DatasetRef` 

266 The `lsst.daf.butler.DatasetRef` that will be eventually used to 

267 resolve a dataset 
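
    Examples
    --------
    A sketch, given a hypothetical `lsst.daf.butler.DatasetRef` ``ref``::

        deferred = DeferredDatasetRef(datasetRef=ref)
        assert deferred.datasetRef is ref  # plain namedtuple field access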

268 """ 

269 __slots__ = () 


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnections`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the
      corresponding ``PipelineTask.run`` method must return
      ``Struct(measCat=X, ...)`` where X matches the ``storageClass`` type
      defined on the output connection.

    Declaring a ``PipelineTaskConnections`` class involves parameters passed
    in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings that
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    def __init__(self, *, config: typing.Optional['PipelineTaskConfig'] = None):
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)
        self.allConnections = dict(self.allConnections)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name)
                          for name in getattr(self, 'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}

        # connections.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}

    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Build QuantizedConnections corresponding to the input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
            `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
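
        Examples
        --------
        A sketch of typical use by an execution harness, given a
        hypothetical ``quantum``::

            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            for name, ref in inputRefs:
                print(name, ref)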

424 """ 

        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip((inputDatasetRefs, outputDatasetRefs),
                               (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # Get the DatasetRefs
                    quantumInputRefs = quantum.inputs[attribute.name]
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    if attribute.deferLoad:
                        quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                f"Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        quantumInputRefs = quantumInputRefs[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        value = value[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; we don't
                # know how to handle it, so throw
                else:
                    raise ValueError(f"Attribute with name {attributeName} has no counterpart "
                                     "in input quantum")
        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(
        self,
        inputs: typing.Dict[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        outputs: typing.Dict[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> typing.Tuple[typing.Mapping[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
                      typing.Mapping[str, typing.Tuple[Output, typing.Collection[DatasetRef]]]]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that
            are true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers
            of input `DatasetRef` objects. Connections that are not changed
            should not be returned at all. Datasets may only be removed, not
            added. Nested collections may be of any multi-pass iterable
            type, and the order of iteration will set the order of iteration
            within `PipelineTask.runQuantum`.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets are present.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the
            message included in this exception); not enough datasets were
            found for a `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.::

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """

        for name, (connection, refs) in inputs.items():
            dataset_type_name = connection.name
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < connection.minimum:
                if isinstance(connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose
                    # inputs are either already present or should be created
                    # during execution. It can trigger during execution if
                    # the input wasn't actually created by an upstream task
                    # in the same graph.
                    raise NoWorkFound(label, name, connection)
        for name, (connection, refs) in outputs.items():
            dataset_type_name = connection.name
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(connections: PipelineTaskConnections,
                    connectionType: Union[str, Iterable[str]]
                    ) -> typing.Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or iterable of `str`
        The type of connections to iterate over; valid values are
        "inputs", "outputs", "prerequisiteInputs", "initInputs", and
        "initOutputs".

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `BaseConnection`.
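
    Examples
    --------
    A sketch iterating over both regular and prerequisite inputs::

        for connection in iterConnections(connections,
                                          ("inputs", "prerequisiteInputs")):
            print(connection.name)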

629 """ 

630 if isinstance(connectionType, str): 

631 connectionType = (connectionType,) 

632 for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType): 

633 yield getattr(connections, name) 


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds ``inputs`` and ``outputs`` mappings in the form used
    by `Quantum` and execution harness code, i.e. with `DatasetType` keys,
    translating them to and from the connection-oriented mappings used
    inside `PipelineTaskConnections`.
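
    Examples
    --------
    A sketch of harness-side use, given a hypothetical ``quantum`` and a
    ``connections`` instance for the task::

        helper = AdjustQuantumHelper(inputs=quantum.inputs,
                                     outputs=quantum.outputs)
        helper.adjust_in_place(connections, label="task",
                               data_id=quantum.dataId)
        if helper.inputs_adjusted or helper.outputs_adjusted:
            ...  # rebuild the Quantum from helper.inputs / helper.outputs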

644 """ 

645 

646 inputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]] 

647 """Mapping of regular input and prerequisite input datasets, grouped by 

648 `DatasetType`. 

649 """ 

650 

651 outputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]] 

652 """Mapping of output datasets, grouped by `DatasetType`. 

653 """ 

654 

655 inputs_adjusted: bool = False 

656 """Whether any inputs were removed in the last call to `adjust_in_place`. 

657 """ 

658 

659 outputs_adjusted: bool = False 

660 """Whether any outputs were removed in the last call to `adjust_in_place`. 

661 """ 

662 

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """

        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: typing.Dict[str, typing.Tuple[BaseInput, typing.Tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: typing.Dict[str, typing.Tuple[Output, typing.Tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (
                connection,
                tuple(self.inputs.get(dataset_type_name, ()))
            )
        for name in connections.outputs:
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (
                connection,
                tuple(self.outputs.get(dataset_type_name, ()))
            )
        # Actually call adjustQuantum.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,
            outputs_by_connection,
            label,
            data_id,
        )

        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = list(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = list(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False