# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

from collections import UserDict, namedtuple
from dataclasses import dataclass
from types import SimpleNamespace
import typing
from typing import Union, Iterable

import itertools
import string

from . import config as configMod
from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput,
                              Output, BaseConnection, BaseInput)
from ._status import NoWorkFound
from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """This is a special dict class used by PipelineTaskConnectionsMetaclass.

    This dict is used in PipelineTaskConnection class creation, as the
    dictionary that is initially used as __dict__. It exists to intercept
    connection fields declared in a PipelineTaskConnection and to record the
    names used to identify them. The names are then added to a class-level
    list according to the connection type of the class attribute. The names
    are also used as keys in a class-level dictionary associated with the
    corresponding class attribute. This information is a duplicate of what
    exists in __dict__, but provides a simple place to look up and iterate
    over only these variables.
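
    Examples
    --------
    A minimal sketch of how a connection declaration is intercepted and
    categorized. The connection here is illustrative, and it assumes
    `connectionTypes.Input` can be constructed with only ``doc``, ``name``,
    and ``storageClass``:

    >>> from lsst.pipe.base import connectionTypes as cT
    >>> d = PipelineTaskConnectionDict()
    >>> d['calexp'] = cT.Input(doc="Example input", name="calexp",
    ...                        storageClass="ExposureF")
    >>> 'calexp' in d['inputs'] and 'calexp' in d['allConnections']
    True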

    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection
        self.data['inputs'] = []
        self.data['prerequisiteInputs'] = []
        self.data['outputs'] = []
        self.data['initInputs'] = []
        self.data['initOutputs'] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        if isinstance(value, Input):
            self.data['inputs'].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data['prerequisiteInputs'].append(name)
        elif isinstance(value, Output):
            self.data['outputs'].append(name)
        elif isinstance(value, InitInput):
            self.data['initInputs'].append(name)
        elif isinstance(value, InitOutput):
            self.data['initOutputs'].append(name)
        # This should not be an elif, as it needs to be tested for
        # everything that inherits from BaseConnection
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # defer to the default behavior
        super().__setitem__(name, value)


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes
    """
    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in the class
            # declaration
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs['dimensions'], str):
                    raise TypeError("Dimensions must be iterable of dimensions, got str, "
                                    "possibly omitted trailing comma")
                if not isinstance(kwargs['dimensions'], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "default templates were provided; add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(f"Template parameters cannot share names with Class attributes"
                                    f" (conflicts are {nameTemplateIntersection}).")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with the
        # python documentation on metaclasses.
        super().__init__(name, bases, dct)


class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
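
    Examples
    --------
    A minimal sketch of the namespace behavior; the attribute name and the
    string standing in for a `lsst.daf.butler.DatasetRef` are illustrative:

    >>> qc = QuantizedConnection()
    >>> qc.calexp = "a DatasetRef would normally go here"
    >>> list(qc.keys())
    ['calexp']
    >>> dict(qc)
    {'calexp': 'a DatasetRef would normally go here'}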

    """
    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                                                           typing.List[DatasetRef]]],
                                           None, None]:
        """Make an iterator for this QuantizedConnection.

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but iterates over the namespace attributes
        rather than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        QuantizedConnection class.
        """
        yield from self._attributes


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Class which denotes that a datasetRef should be treated as deferred
    when interacting with the butler.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
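
    Examples
    --------
    A sketch; ``ref`` stands for an existing `lsst.daf.butler.DatasetRef`
    and is not constructed here:

    >>> deferred = DeferredDatasetRef(datasetRef=ref)
    >>> deferred.datasetRef is ref
    True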

    """
    __slots__ = ()


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the
      corresponding ``PipelineTask.run`` method must return
      ``Struct(measCat=X, ...)`` where X matches the ``storageClass`` type
      defined on the output connection.

    Declaring a ``PipelineTaskConnection`` class involves parameters passed
    in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    whose keys correspond to a template identifier, and whose values
    correspond to the value to use as a default when formatting the string.
    For example if ``ConnectionClass.calexp.name = '{input}Coadd_calexp'``
    then ``defaultTemplates`` = {'input': 'deep'}.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert(connections.inputConnection.name == "ModifiedDataset")
    >>> assert(connections.outputConnection.name == "TotallyDifferent")
    """

    def __init__(self, *, config: 'PipelineTaskConfig' = None):
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)
        self.allConnections = dict(self.allConnections)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name)
                          for name in getattr(self, 'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}

        # connections.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}

    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Build QuantizedConnections corresponding to the input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
                `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
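
        Examples
        --------
        A sketch of typical use by execution harness code, assuming
        ``connections`` is an instance of this class and ``quantum`` is an
        `lsst.daf.butler.Quantum` built for the corresponding task:

        >>> inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        >>> for name, refs in inputRefs:
        ...     print(name, refs)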

        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip((inputDatasetRefs, outputDatasetRefs),
                               (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # Get the DatasetRefs
                    quantumInputRefs = quantum.inputs[attribute.name]
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    if attribute.deferLoad:
                        quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                f"Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        quantumInputRefs = quantumInputRefs[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        value = value[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; we don't
                # know how to handle it, so throw
                else:
                    raise ValueError(f"Attribute with name {attributeName} has no counterpart "
                                     "in input quantum")
        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(
        self,
        inputs: typing.Dict[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
        outputs: typing.Dict[str, typing.Tuple[Output, typing.Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> typing.Tuple[typing.Mapping[str, typing.Tuple[BaseInput, typing.Collection[DatasetRef]]],
                      typing.Mapping[str, typing.Tuple[Output, typing.Collection[DatasetRef]]]]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that
            are true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers
            of input `DatasetRef` objects. Connections that are not changed
            should not be returned at all. Datasets may only be removed, not
            added. Nested collections may be of any multi-pass iterable
            type, and the order of iteration will set the order of iteration
            within `PipelineTask.runQuantum`.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets are present.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the
            message included in this exception); not enough datasets were
            found for a `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.::

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (connection, refs) in inputs.items():
            dataset_type_name = connection.name
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < connection.minimum:
                if isinstance(connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose
                    # inputs are either already present or should be created
                    # during execution. It can trigger during execution if
                    # the input wasn't actually created by an upstream task
                    # in the same graph.
                    raise NoWorkFound(label, name, connection)
        for name, (connection, refs) in outputs.items():
            dataset_type_name = connection.name
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(connections: PipelineTaskConnections,
                    connectionType: Union[str, Iterable[str]]
                    ) -> typing.Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or iterable of `str`
        The type of connections to iterate over; valid values are "inputs",
        "outputs", "prerequisiteInputs", "initInputs", and "initOutputs".

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a subclass of `BaseConnection`.
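
    Examples
    --------
    A sketch of typical use, assuming ``connections`` is an instance of a
    `PipelineTaskConnections` subclass:

    >>> for connection in iterConnections(connections, "inputs"):
    ...     print(connection.name)
    >>> for connection in iterConnections(connections,
    ...                                   ("inputs", "outputs")):
    ...     print(connection.name)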

    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `input` and `output` mappings in the form used by
    `Quantum` and execution harness code, i.e. with `DatasetType` keys,
    translating them to and from the connection-oriented mappings used
    inside `PipelineTaskConnections`.
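
    Examples
    --------
    A sketch of the intended call pattern; ``quantum``, ``connections``,
    ``label``, and ``data_id`` stand for the objects documented on
    `adjust_in_place` and are not constructed here:

    >>> helper = AdjustQuantumHelper(inputs=quantum.inputs,
    ...                              outputs=quantum.outputs)
    >>> helper.adjust_in_place(connections, label, data_id)
    >>> if helper.inputs_adjusted or helper.outputs_adjusted:
    ...     pass  # rebuild the Quantum from helper.inputs / helper.outputs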

    """

    inputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, typing.List[DatasetRef]]
    """Mapping of output datasets, grouped by `DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: typing.Dict[str, typing.Tuple[BaseInput, typing.Tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: typing.Dict[str, typing.Tuple[Output, typing.Tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (
                connection,
                tuple(self.inputs.get(dataset_type_name, ()))
            )
        for name in connections.outputs:
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (
                connection,
                tuple(self.outputs.get(dataset_type_name, ()))
            )
        # Actually call adjustQuantum.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,
            outputs_by_connection,
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = list(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, typing.List[DatasetRef]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = list(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False