Coverage for python/lsst/pipe/base/connections.py: 43% (305 statements)

coverage.py v7.3.2, created at 2023-11-17 10:52 +0000

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask."""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

import dataclasses
import itertools
import string
import warnings
from collections import UserDict
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

from ._status import NoWorkFound
from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput

if TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by `PipelineTaskConnectionsMetaclass`.

    This dict is used in `PipelineTaskConnections` class creation, as the
    dictionary that is initially used as ``__dict__``. It exists to
    intercept connection fields declared in a `PipelineTaskConnections`
    class and record the names used to identify them. Each name is added to
    a class-level set according to the connection type of the attribute, and
    is also used as a key in a class-level dictionary associated with the
    corresponding attribute. This information duplicates what exists in
    ``__dict__``, but provides a simple place to look up and iterate over
    only these variables.
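
    Examples
    --------
    A minimal sketch (with a hypothetical ``DemoConnections`` class, assuming
    the standard `lsst.pipe.base` imports) of the bookkeeping this dict
    performs while a connections class body is being executed::

        from lsst.pipe.base import PipelineTaskConnections
        from lsst.pipe.base import connectionTypes as cT

        class DemoConnections(PipelineTaskConnections, dimensions=("visit",)):
            calexp = cT.Input(doc="Example input", name="calexp",
                              storageClass="ExposureF", dimensions=("visit",))

        # The declaration above was intercepted and recorded at class scope.
        assert "calexp" in DemoConnections.inputs
        assert "calexp" in DemoConnections.allConnections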

    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection.
        self.data["inputs"] = set()
        self.data["prerequisiteInputs"] = set()
        self.data["outputs"] = set()
        self.data["initInputs"] = set()
        self.data["initOutputs"] = set()
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            if name in {
                "dimensions",
                "inputs",
                "prerequisiteInputs",
                "outputs",
                "initInputs",
                "initOutputs",
                "allConnections",
            }:
                # Guard against connections whose names are reserved.
                raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
            if (previous := self.data.get(name)) is not None:
                # Guard against changing the type of an inherited connection
                # by first removing it from the set it's currently in.
                self.data[previous._connection_type_set].discard(name)
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
            self.data[value._connection_type_set].add(name)
        # Defer to the default behavior.
        super().__setitem__(name, value)


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes."""

    # We can annotate these attributes as `collections.abc.Set` to discourage
    # undesirable modifications in type-checked code, since the internal code
    # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
    # these annotations anyway.

    dimensions: Set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    inputs: Set[str]
    """Set with the names of all `~connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added. Note that
    this attribute is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    prerequisiteInputs: Set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: Set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: Set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: Set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping containing all connection attributes.

    See `inputs` for additional information.
    """

    def __prepare__(name, bases, **kwargs):  # noqa: N804
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in the class
            # declaration.
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be iterable of dimensions, got str; possibly omitted trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute.
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections.
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # Add all the parameters to the set of templates.
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together.
            mergeDict = {}
            mergeDeprecationsDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
                if hasattr(base, "deprecatedTemplates"):
                    mergeDeprecationsDict.update(base.deprecatedTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])
            if "deprecatedTemplates" in kwargs:
                mergeDeprecationsDict.update(kwargs["deprecatedTemplates"])
            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict
            if len(mergeDeprecationsDict) > 0:
                kwargs["deprecatedTemplates"] = mergeDeprecationsDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class.
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnection class contains templated attribute names, but no "
                    "default templates were provided; add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not.
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with Class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})
            dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope.
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # Our custom dict type must be turned into an actual dict to be used
        # in type.__new__.
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses.
        super().__init__(name, bases, dct)

    def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
        # MyPy appears not to really understand metaclass.__call__ at all, so
        # we need to tell it to ignore __new__ and __init__ calls here.
        instance: PipelineTaskConnections = cls.__new__(cls)  # type: ignore

        # Make mutable copies of all set-like class attributes so derived
        # __init__ implementations can modify them in place.
        instance.dimensions = set(cls.dimensions)
        instance.inputs = set(cls.inputs)
        instance.prerequisiteInputs = set(cls.prerequisiteInputs)
        instance.outputs = set(cls.outputs)
        instance.initInputs = set(cls.initInputs)
        instance.initOutputs = set(cls.initOutputs)

        # Set self.config. It's a bit strange that we claim to accept None but
        # really just raise here, but it's not worth changing now.
        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        instance.config = config

        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time.
        templateValues = {
            name: getattr(config.connections, name) for name in cls.defaultTemplates  # type: ignore
        }

        # We now assemble a mapping of all connection instances keyed by
        # internal name, applying the configuration and templates to make new
        # configurations from the class-attribute defaults. This will be
        # private, but with a public read-only view. This mapping is what the
        # descriptor interface of the class-level attributes will return when
        # they are accessed on an instance. This is better than just assigning
        # regular instance attributes, as it ensures that removed connections
        # cannot be accessed on instances, instead of having access to them
        # silently fall through to the not-removed class connection instance.
        instance._allConnections = {}
        instance.allConnections = MappingProxyType(instance._allConnections)
        for internal_name, connection in cls.allConnections.items():
            dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
            instance_connection = dataclasses.replace(
                connection,
                name=dataset_type_name,
                doc=(
                    connection.doc
                    if connection.deprecated is None
                    else f"{connection.doc}\n{connection.deprecated}"
                ),
                _deprecation_context=connection._deprecation_context,
            )
            instance._allConnections[internal_name] = instance_connection

        # Finally call __init__. The base class implementation does nothing;
        # we could have left some of the above implementation there (where it
        # originated), but putting it here instead makes it hard for derived
        # class implementors to get things into a weird state by delegating to
        # super().__init__ in the wrong place, or by forgetting to do that
        # entirely.
        instance.__init__(config=config)  # type: ignore

        # Derived-class implementations may have changed the contents of the
        # various kinds-of-connection sets; update allConnections to have keys
        # that are a union of all those. We get values for the new
        # allConnections from the attributes, since any dynamically added new
        # ones will not be present in the old allConnections. Typically those
        # getattrs will invoke the descriptors and get things from the old
        # allConnections anyway. After processing each set we replace it with
        # a frozenset.
        updated_all_connections = {}
        for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
            updated_connection_names = getattr(instance, attrName)
            updated_all_connections.update(
                {name: getattr(instance, name) for name in updated_connection_names}
            )
            # Setting these to frozenset is at odds with the type annotation,
            # but MyPy can't tell because we're using setattr, and we want to
            # lie to it anyway to get runtime guards against post-__init__
            # mutation.
            setattr(instance, attrName, frozenset(updated_connection_names))
        # Update the existing dict in place, since we already have a view of
        # that.
        instance._allConnections.clear()
        instance._allConnections.update(updated_all_connections)

        for connection_name, obj in instance._allConnections.items():
            if obj.deprecated is not None:
                warnings.warn(
                    f"Connection {connection_name} with datasetType {obj.name} "
                    f"(from {obj._deprecation_context}): {obj.deprecated}",
                    FutureWarning,
                    stacklevel=1,  # Report from this location.
                )

        # Freeze the connection instance dimensions now. This is at odds with
        # the type annotation, which says [mutable] `set`, just like the
        # connection type attributes (e.g. `inputs`, `outputs`, etc.), though
        # MyPy can't tell with those since we're using setattr for them.
        instance.dimensions = frozenset(instance.dimensions)  # type: ignore

        return instance


class QuantizedConnection(SimpleNamespace):
    r"""A Namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `~lsst.daf.butler.DatasetRef`\s provided by a `~lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnections` class.
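
    Examples
    --------
    A minimal sketch (with a hypothetical ``ref`` standing in for a resolved
    `~lsst.daf.butler.DatasetRef`) of how attributes are tracked and
    iterated::

        qc = InputQuantizedConnection()
        qc.calexp = ref
        # Iteration yields (attribute name, value) pairs, like dict.items().
        for name, value in qc:
            print(name, value)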

    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance.
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
        # Capture the attribute name as it is added to this object.
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]:
        """Make an iterator for this `QuantizedConnection`.

        Iterating over a `QuantizedConnection` will yield a tuple with the
        name of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        `QuantizedConnection` instance.
        """
        yield from self._attributes


class InputQuantizedConnection(QuantizedConnection):
    """Input variant of a `QuantizedConnection`."""

    pass


class OutputQuantizedConnection(QuantizedConnection):
    """Output variant of a `QuantizedConnection`."""

    pass


@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that a
    `PipelineTask` should receive a `~lsst.daf.butler.DeferredDatasetHandle`
    instead of an in-memory dataset.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
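
    Examples
    --------
    A minimal sketch (with a hypothetical resolved ``ref``), showing that the
    wrapper simply forwards the dataset type and data ID of the wrapped
    reference::

        deferred = DeferredDatasetRef(datasetRef=ref)
        assert deferred.datasetType == ref.datasetType
        assert deferred.dataId == ref.dataId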

    """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes`, which
    are listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class.
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to declare
      the class attribute must match a function argument name in the ``run``
      method of a `PipelineTask`. E.g. if the ``PipelineTaskConnections``
      defines an ``Input`` with the name ``calexp``, then the corresponding
      signature should be ``PipelineTask.run(calexp, ...)``.
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
      execution of a `PipelineTask`. The name used to declare the connection
      must correspond to an attribute of a `Struct` that is returned by a
      `PipelineTask` ``run`` method. E.g. if an output connection is
      defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)`` where
      X matches the ``storageClass`` type defined on the output connection.

    Attributes of these types can also be created, replaced, or deleted on the
    `PipelineTaskConnections` instance in the ``__init__`` method, if more
    than just the name depends on the configuration. It is preferred to define
    them in the class when possible (even if configuration may cause the
    connection to be removed from the instance), as in the sketch below.
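
    For example, a minimal sketch (using the ``cT`` import alias shown in the
    Examples section below, and a hypothetical ``doWriteCat`` config field) of
    removing an output connection in ``__init__`` when configuration disables
    it:

    .. code-block:: python

        class ExampleConnections(PipelineTaskConnections, dimensions=("A",)):
            measCat = cT.Output(doc="Example catalog", name="measCat",
                                storageClass="SourceCatalog", dimensions=("A",))

            def __init__(self, *, config=None):
                super().__init__(config=config)
                if not config.doWriteCat:  # hypothetical config field
                    del self.measCat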

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings that
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`. The dimensions may also be modified in
    subclass ``__init__`` methods if they need to depend on configuration.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name parameters
    on connections defined in a `PipelineTaskConnections` class contain a
    template, then a default template value must be specified in the
    ``defaultTemplates`` argument. This is done by passing a dictionary with
    keys corresponding to a template identifier, and values corresponding to
    the value to use as a default when formatting the string. For example if
    ``ConnectionsClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding `PipelineTask`
    will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    # We annotate these attributes as mutable sets because that's what they
    # are inside derived ``__init__`` implementations, and that's what matters
    # most. After that's done, the metaclass __call__ makes them into
    # frozensets, but relatively little code interacts with them then, and
    # that code knows not to try to modify them without having to be told
    # that by mypy.

    dimensions: set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This may be replaced or modified in ``__init__`` to change the dimensions
    of the task. After ``__init__`` it will be a `frozenset` and may not be
    replaced.
    """

    inputs: set[str]
    """Set with the names of all `connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added, removed, or
    replaced in ``__init__``. Removing entries from this set will cause those
    connections to be removed after ``__init__`` completes, but this is
    supported only for backwards compatibility; new code should instead just
    delete the connection attribute directly. After ``__init__`` this will be
    a `frozenset` and may not be replaced.
    """

    prerequisiteInputs: set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping holding all connection attributes.

    This is a read-only view that is automatically updated when connection
    attributes are added, removed, or replaced in ``__init__``. It is also
    updated after ``__init__`` completes to reflect changes in `inputs`,
    `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
    """

    _allConnections: dict[str, BaseConnection]

    def __init__(self, *, config: PipelineTaskConfig | None = None):
        pass

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            previous = self._allConnections.get(name)
            try:
                getattr(self, value._connection_type_set).add(name)
            except AttributeError:
                # Attempt to call add on a frozenset, which is what these sets
                # are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            if previous is not None and value._connection_type_set != previous._connection_type_set:
                # Connection has changed type, e.g. Input to
                # PrerequisiteInput; update the sets accordingly. To be extra
                # defensive about multiple assignments we use the type of the
                # previous instance instead of assuming it's the same as the
                # class-level default. Use discard instead of remove so
                # manually removing from these sets first is never an error.
                getattr(self, previous._connection_type_set).discard(name)
            self._allConnections[name] = value
            if hasattr(self.__class__, name):
                # Don't actually set the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # return the value we just added to allConnections.
                return
        # Actually add the attribute.
        super().__setattr__(name, value)

    def __delattr__(self, name):
        """Descriptor delete method."""
        previous = self._allConnections.get(name)
        if previous is not None:
            # Delete this connection's name from the appropriate set, which we
            # have to get from the previous instance instead of assuming it's
            # the same set that was appropriate for the class-level default.
            # Use discard instead of remove so manually removing from these
            # sets first is never an error.
            try:
                getattr(self, previous._connection_type_set).discard(name)
            except AttributeError:
                # Attempt to call discard on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            del self._allConnections[name]
            if hasattr(self.__class__, name):
                # Don't actually delete the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # see that it's no longer present in allConnections.
                return
        # Actually delete the attribute.
        super().__delattr__(name)

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build `QuantizedConnection` objects corresponding to the input
        `~lsst.daf.butler.Quantum`.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
                `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # Operate on a reference object and an iterable of names of class
        # connection attributes.
        for refs, names in zip(
            (inputDatasetRefs, outputDatasetRefs),
            (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
            strict=True,
        ):
            # Get a name of a class connection attribute.
            for attributeName in names:
                # Get the attribute identified by name.
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input.
                if attribute.name in quantum.inputs:
                    # If the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef.
                    quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
                    if attribute.deferLoad:
                        quantumInputRefs = [
                            DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                        ]
                    else:
                        quantumInputRefs = list(quantum.inputs[attribute.name])
                    # Unpack arguments that are not marked multiples (list of
                    # length one).
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                "Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        setattr(refs, attributeName, quantumInputRefs[0])
                    else:
                        # Add to the QuantizedConnection identifier.
                        setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output.
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one).
                    if not attribute.multiple:
                        setattr(refs, attributeName, value[0])
                    else:
                        setattr(refs, attributeName, value)
                # The specified attribute is in neither inputs nor outputs;
                # we don't know how to handle it, so raise.
                else:
                    raise ValueError(
                        f"Attribute with name {attributeName} has no counterpart in input quantum"
                    )
        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.Quantum` during the graph generation
        stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated
            `~lsst.daf.butler.DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and ``in``,
            but it should not be mutated in place. In contrast, the outer
            dictionaries are guaranteed to be temporary copies that are true
            `dict` instances, and hence may be modified and even returned;
            this is especially useful for delegating to `super` (see notes
            below).
        outputs : `~collections.abc.Mapping`
            Mapping of output datasets, with the same structure as ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `~collections.abc.Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `~lsst.daf.butler.DatasetRef` objects. Connections that are
            not changed should not be returned at all. Datasets may only be
            removed, not added. Nested collections may be of any multi-pass
            iterable type, and the order of iteration will set the order of
            iteration within `PipelineTask.runQuantum`.
        adjusted_outputs : `~collections.abc.Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets are present.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the message
            included in this exception); not enough datasets were found for a
            `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.:

        .. code-block:: python

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with "
                        f"minimum={input_connection.minimum} for quantum data ID {data_id}."
                    )
                else:
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        """Return the names of regular input and output connections whose
        data IDs should be used to compute the spatial bounds of this task's
        quanta.

        The spatial bound for a quantum is defined as the union of the
        regions of all data IDs of all connections returned here, along with
        the region of the quantum data ID (if the task has spatial
        dimensions).

        Returns
        -------
        connection_names : `collections.abc.Iterable` [ `str` ]
            Names of connections with spatial dimensions. These are the
            task-internal connection names, not butler dataset type names.

        Notes
        -----
        The spatial bound is used to search for prerequisite inputs that have
        skypix dimensions. The default implementation returns an empty
        iterable, which is usually sufficient for tasks with spatial
        dimensions, but if a task's inputs or outputs are associated with
        spatial regions that extend beyond the quantum data ID's region, this
        method may need to be overridden to expand the set of prerequisite
        inputs found.

        Tasks that do not have spatial dimensions but do have skypix
        prerequisite inputs should always override this method, as the
        default spatial bounds otherwise cover the full sky.
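
        For example, a minimal sketch (with a hypothetical connection name)
        of an override for a task whose inputs carry the relevant regions:

        .. code-block:: python

            def getSpatialBoundsConnections(self):
                return ("inputCatalogs",)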

        """
        return ()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        """Return the names of regular input and output connections whose
        data IDs should be used to compute the temporal bounds of this task's
        quanta.

        The temporal bound for a quantum is defined as the union of the
        timespans of all data IDs of all connections returned here, along
        with the timespan of the quantum data ID (if the task has temporal
        dimensions).

        Returns
        -------
        connection_names : `collections.abc.Iterable` [ `str` ]
            Names of connections with temporal dimensions. These are the
            task-internal connection names, not butler dataset type names.

        Notes
        -----
        The temporal bound is used to search for prerequisite inputs that are
        calibration datasets. The default implementation returns an empty
        iterable, which is usually sufficient for tasks with temporal
        dimensions, but if a task's inputs or outputs are associated with
        timespans that extend beyond the quantum data ID's timespan, this
        method may need to be overridden to expand the set of prerequisite
        inputs found.

        Tasks that do not have temporal dimensions and do not implement this
        method will use an infinite timespan for any calibration lookups.
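
        For example, a minimal sketch (with a hypothetical connection name)
        of an override for a task whose input exposures set the timespan for
        calibration lookups:

        .. code-block:: python

            def getTemporalBoundsConnections(self):
                return ("inputExposures",)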

        """
        return ()


def iterConnections(
    connections: PipelineTaskConnections, connectionType: str | Iterable[str]
) -> Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connections type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or `~collections.abc.Iterable` [ `str` ]
        The type (or types) of connections to iterate over; valid values are
        "inputs", "outputs", "prerequisiteInputs", "initInputs", and
        "initOutputs".

    Yields
    ------
    connection : `~.connectionTypes.BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be of a type derived from
        `~.connectionTypes.BaseConnection`.
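
    Examples
    --------
    A minimal sketch (assuming ``connections`` is an instance of some
    `PipelineTaskConnections` subclass)::

        for connection in iterConnections(connections, "inputs"):
            print(connection.name)

        for connection in iterConnections(connections, ("inputs", "outputs")):
            print(connection.name)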

    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `input` and `output` mappings in the form used by
    `Quantum` and execution harness code, i.e. with
    `~lsst.daf.butler.DatasetType` keys, translating them to and from the
    connection-oriented mappings used inside `PipelineTaskConnections`.
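
    Examples
    --------
    A minimal sketch (assuming ``quantum`` is an `lsst.daf.butler.Quantum`
    and ``connections`` is a constructed `PipelineTaskConnections`
    instance)::

        helper = AdjustQuantumHelper(inputs=quantum.inputs, outputs=quantum.outputs)
        helper.adjust_in_place(connections, label="exampleTask", data_id=quantum.dataId)
        if helper.inputs_adjusted or helper.outputs_adjusted:
            # Rebuild the quantum from helper.inputs / helper.outputs.
            ...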

    """

    inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `~lsst.daf.butler.DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in itertools.chain(connections.outputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior we
        # want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = tuple(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = tuple(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False