
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""


from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "iterConnections",
]

from collections import UserDict, namedtuple
from dataclasses import dataclass
from types import SimpleNamespace
import typing
from typing import Union, Iterable

import itertools
import string

from . import config as configMod
from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput,
                              Output, BaseConnection, BaseInput)
from ._status import NoWorkFound
from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum


if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """



class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by PipelineTaskConnectionsMetaclass.

    This dict is used in PipelineTaskConnections class creation, as the
    dictionary that is initially used as __dict__. It exists to intercept
    connection fields declared in a PipelineTaskConnections class and to
    record the names used to identify them. The names are added to a
    class-level list according to the connection type of the class
    attribute. The names are also used as keys in a class-level dictionary
    associated with the corresponding class attribute. This information
    duplicates what exists in __dict__, but provides a simple place to look
    up and iterate over only these variables.
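
    Notes
    -----
    A minimal sketch of the interception behavior (the connection name
    ``calexp`` and the ``Input`` constructor arguments are illustrative
    assumptions)::

        dct = PipelineTaskConnectionDict()
        # Assigning a connection records its name by connection type ...
        dct['calexp'] = Input(doc="example", name="calexp",
                              storageClass="Exposure", dimensions=())
        # ... in the bookkeeping lists as well as in the mapping itself.
        assert dct['inputs'] == ['calexp']
        assert 'calexp' in dct['allConnections']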

73 """ 

74 def __init__(self, *args, **kwargs): 

75 super().__init__(*args, **kwargs) 

76 # Initialize class level variables used to track any declared 

77 # class level variables that are instances of 

78 # connectionTypes.BaseConnection 

79 self.data['inputs'] = [] 

80 self.data['prerequisiteInputs'] = [] 

81 self.data['outputs'] = [] 

82 self.data['initInputs'] = [] 

83 self.data['initOutputs'] = [] 

84 self.data['allConnections'] = {} 

85 

86 def __setitem__(self, name, value): 

87 if isinstance(value, Input): 

88 self.data['inputs'].append(name) 

89 elif isinstance(value, PrerequisiteInput): 

90 self.data['prerequisiteInputs'].append(name) 

91 elif isinstance(value, Output): 

92 self.data['outputs'].append(name) 

93 elif isinstance(value, InitInput): 

94 self.data['initInputs'].append(name) 

95 elif isinstance(value, InitOutput): 

96 self.data['initOutputs'].append(name) 

97 # This should not be an elif, as it needs tested for 

98 # everything that inherits from BaseConnection 

99 if isinstance(value, BaseConnection): 

100 object.__setattr__(value, 'varName', name) 

101 self.data['allConnections'][name] = value 

102 # defer to the default behavior 

103 super().__setitem__(name, value) 

104 

105 

class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes.

108 """ 

109 def __prepare__(name, bases, **kwargs): # noqa: 805 

110 # Create an instance of our special dict to catch and track all 

111 # variables that are instances of connectionTypes.BaseConnection 

112 # Copy any existing connections from a parent class 

113 dct = PipelineTaskConnectionDict() 

114 for base in bases: 

115 if isinstance(base, PipelineTaskConnectionsMetaclass): 115 ↛ 114line 115 didn't jump to line 114, because the condition on line 115 was never false

116 for name, value in base.allConnections.items(): 116 ↛ 117line 116 didn't jump to line 117, because the loop on line 116 never started

117 dct[name] = value 

118 return dct 


    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in class
            # declaration
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs['dimensions'], str):
                    raise TypeError("Dimensions must be iterable of dimensions, got str, "
                                    "possibly omitted trailing comma")
                if not isinstance(kwargs['dimensions'], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnections class contains templated attribute names, but no "
                                "default templates were provided; add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(f"Template parameters cannot share names with Class attributes"
                                    f" (conflicts are {nameTemplateIntersection}).")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))


    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses.
        super().__init__(name, bases, dct)



class QuantizedConnection(SimpleNamespace):
    """A namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    PipelineTaskConnections class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
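
    Notes
    -----
    A minimal sketch of the namespace behavior (the attribute name
    ``calexp`` and the value ``ref`` are illustrative only)::

        refs = InputQuantizedConnection()
        refs.calexp = ref             # attribute name is tracked internally
        for name, value in refs:      # yields ("calexp", ref)
            ...
        list(refs.keys())             # ["calexp"]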

216 """ 

217 def __init__(self, **kwargs): 

218 # Create a variable to track what attributes are added. This is used 

219 # later when iterating over this QuantizedConnection instance 

220 object.__setattr__(self, "_attributes", set()) 

221 

222 def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]): 

223 # Capture the attribute name as it is added to this object 

224 self._attributes.add(name) 

225 super().__setattr__(name, value) 

226 

227 def __delattr__(self, name): 

228 object.__delattr__(self, name) 

229 self._attributes.remove(name) 

230 

231 def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef, 

232 typing.List[DatasetRef]]], None, None]: 

233 """Make an Iterator for this QuantizedConnection 

234 

235 Iterating over a QuantizedConnection will yield a tuple with the name 

236 of an attribute and the value associated with that name. This is 

237 similar to dict.items() but is on the namespace attributes rather than 

238 dict keys. 

239 """ 

240 yield from ((name, getattr(self, name)) for name in self._attributes) 

241 

242 def keys(self) -> typing.Generator[str, None, None]: 

243 """Returns an iterator over all the attributes added to a 

244 QuantizedConnection class 

245 """ 

246 yield from self._attributes 



class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass



class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Class which denotes that a datasetRef should be treated as deferred
    when interacting with the butler.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
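
    Notes
    -----
    A short sketch of how `buildDatasetRefs` wraps a deferred-load input
    (``someRef`` is an illustrative `lsst.daf.butler.DatasetRef`)::

        deferred = DeferredDatasetRef(datasetRef=someRef)
        deferred.datasetRef  # the wrapped reference, to be resolved later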

266 """ 

267 __slots__ = () 

268 

269 

class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the
      corresponding ``PipelineTask.run`` method must return
      ``Struct(measCat=X, ...)`` where X matches the ``storageClass`` type
      defined on the output connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example if ``ConnectionClass.calexp.name =
    '{input}Coadd_calexp'`` then ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert(connections.inputConnection.name == "ModifiedDataset")
    >>> assert(connections.outputConnection.name == "TotallyDifferent")
    """


    def __init__(self, *, config: 'PipelineTaskConfig' = None):
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)
        self.allConnections = dict(self.allConnections)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name) for name in getattr(self,
                          'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}

        # connections.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}


    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Build QuantizedConnections corresponding to an input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
                `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
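
        Examples
        --------
        A hedged sketch of typical activator-side use (``quantum`` is
        assumed to be a populated `lsst.daf.butler.Quantum` and
        ``connections`` an instance of this class)::

            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            for attributeName, ref in inputRefs:
                ...  # ref is a DatasetRef or a list of DatasetRefs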

422 """ 

423 inputDatasetRefs = InputQuantizedConnection() 

424 outputDatasetRefs = OutputQuantizedConnection() 

425 # operate on a reference object and an interable of names of class 

426 # connection attributes 

427 for refs, names in zip((inputDatasetRefs, outputDatasetRefs), 

428 (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)): 

429 # get a name of a class connection attribute 

430 for attributeName in names: 

431 # get the attribute identified by name 

432 attribute = getattr(self, attributeName) 

433 # Branch if the attribute dataset type is an input 

434 if attribute.name in quantum.inputs: 

435 # Get the DatasetRefs 

436 quantumInputRefs = quantum.inputs[attribute.name] 

437 # if the dataset is marked to load deferred, wrap it in a 

438 # DeferredDatasetRef 

439 if attribute.deferLoad: 

440 quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs] 

441 # Unpack arguments that are not marked multiples (list of 

442 # length one) 

443 if not attribute.multiple: 

444 if len(quantumInputRefs) > 1: 

445 raise ScalarError( 

446 f"Received multiple datasets " 

447 f"{', '.join(str(r.dataId) for r in quantumInputRefs)} " 

448 f"for scalar connection {attributeName} " 

449 f"({quantumInputRefs[0].datasetType.name}) " 

450 f"of quantum for {quantum.taskName} with data ID {quantum.dataId}." 

451 ) 

452 if len(quantumInputRefs) == 0: 

453 continue 

454 quantumInputRefs = quantumInputRefs[0] 

455 # Add to the QuantizedConnection identifier 

456 setattr(refs, attributeName, quantumInputRefs) 

457 # Branch if the attribute dataset type is an output 

458 elif attribute.name in quantum.outputs: 

459 value = quantum.outputs[attribute.name] 

460 # Unpack arguments that are not marked multiples (list of 

461 # length one) 

462 if not attribute.multiple: 

463 value = value[0] 

464 # Add to the QuantizedConnection identifier 

465 setattr(refs, attributeName, value) 

466 # Specified attribute is not in inputs or outputs dont know how 

467 # to handle, throw 

468 else: 

469 raise ValueError(f"Attribute with name {attributeName} has no counterpoint " 

470 "in input quantum") 

471 return inputDatasetRefs, outputDatasetRefs 


    def adjustQuantum(
        self,
        inputs: typing.Dict[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        outputs: typing.Dict[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> typing.Tuple[typing.Mapping[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
                      typing.Mapping[str, typing.Tuple[Output, typing.Collection[DatasetRef]]]]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that
            are true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers
            of input `DatasetRef` objects. Connections that are not changed
            should not be returned at all. Datasets may only be removed, not
            added. Nested collections may be of any multi-pass iterable
            type, and the order of iteration will set the order of iteration
            within `PipelineTask.runQuantum`.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets are present.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the
            message included in this exception); not enough datasets were
            found for a `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.::

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """

        for name, (connection, refs) in inputs.items():
            dataset_type_name = connection.name
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < connection.minimum:
                if isinstance(connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose
                    # inputs are either already present or should be created
                    # during execution. It can trigger during execution if
                    # the input wasn't actually created by an upstream task
                    # in the same graph.
                    raise NoWorkFound(label, name, connection)
        for name, (connection, refs) in outputs.items():
            dataset_type_name = connection.name
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}



def iterConnections(connections: PipelineTaskConnections,
                    connectionType: Union[str, Iterable[str]]
                    ) -> typing.Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or iterable of `str`
        The type of connections to iterate over; valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a type derived from
        `BaseConnection`.
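
    Examples
    --------
    A short sketch, assuming ``connections`` is an instance of some
    ``PipelineTaskConnections`` subclass::

        for connection in iterConnections(connections, "inputs"):
            print(connection.name)
        # Multiple connection types can be iterated over together:
        for connection in iterConnections(connections, ("inputs", "outputs")):
            print(connection.name)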

627 """ 

628 if isinstance(connectionType, str): 

629 connectionType = (connectionType,) 

630 for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType): 

631 yield getattr(connections, name) 

632 

633 

@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds ``inputs`` and ``outputs`` mappings in the form used by
    `Quantum` and execution harness code, i.e. with `DatasetType` keys,
    translating them to and from the connection-oriented mappings used
    inside `PipelineTaskConnections`.
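
    Notes
    -----
    A hedged usage sketch (``quantum``, ``connections``, ``label``, and
    ``data_id`` are assumed to be provided by the execution harness)::

        helper = AdjustQuantumHelper(inputs=quantum.inputs,
                                     outputs=quantum.outputs)
        helper.adjust_in_place(connections, label, data_id)
        if helper.inputs_adjusted or helper.outputs_adjusted:
            ...  # rebuild the Quantum from helper.inputs / helper.outputs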

642 """ 

643 

644 inputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]] 

645 """Mapping of regular input and prerequisite input datasets, grouped by 

646 `DatasetType`. 

647 """ 

648 

649 outputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]] 

650 """Mapping of output datasets, grouped by `DatasetType`. 

651 """ 

652 

653 inputs_adjusted: bool = False 

654 """Whether any inputs were removed in the last call to `adjust_in_place`. 

655 """ 

656 

657 outputs_adjusted: bool = False 

658 """Whether any outputs were removed in the last call to `adjust_in_place`. 

659 """ 

660 

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: typing.Dict[str, typing.Tuple[BaseInput, typing.Tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: typing.Dict[str, typing.Tuple[Output, typing.Tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (
                connection,
                tuple(self.inputs.get(dataset_type_name, ()))
            )
        for name in connections.outputs:
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (
                connection,
                tuple(self.outputs.get(dataset_type_name, ()))
            )
        # Actually call adjustQuantum.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,
            outputs_by_connection,
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = list(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = list(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False