# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

27 

28"""Module defining connection classes for PipelineTask. 

29""" 

30 

31from __future__ import annotations 

32 

33__all__ = [ 

34 "AdjustQuantumHelper", 

35 "DeferredDatasetRef", 

36 "InputQuantizedConnection", 

37 "OutputQuantizedConnection", 

38 "PipelineTaskConnections", 

39 "ScalarError", 

40 "iterConnections", 

41 "ScalarError", 

42 "QuantizedConnection", 

43] 

44 

45import dataclasses 

46import itertools 

47import string 

48import warnings 

49from collections import UserDict 

50from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set 

51from dataclasses import dataclass 

52from types import MappingProxyType, SimpleNamespace 

53from typing import TYPE_CHECKING, Any 

54 

55from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum 

56 

57from ._status import NoWorkFound 

58from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput 

59 

60if TYPE_CHECKING: 

61 from .config import PipelineTaskConfig 

62 

63 


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by `PipelineTaskConnectionsMetaclass`.

    This dict is used in `PipelineTaskConnection` class creation, as the
    dictionary that is initially used as ``__dict__``. It exists to
    intercept connection fields declared in a `PipelineTaskConnection` and
    record the names used to identify them. The names are then added to a
    class-level list according to the connection type of the class
    attribute. The names are also used as keys in a class-level dictionary
    associated with the corresponding class attribute. This information is
    a duplicate of what exists in ``__dict__``, but provides a simple place
    to look up and iterate over only these variables.

    Parameters
    ----------
    *args : `~typing.Any`
        Passed to `dict` constructor.
    **kwargs : `~typing.Any`
        Passed to `dict` constructor.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection
        self.data["inputs"] = set()
        self.data["prerequisiteInputs"] = set()
        self.data["outputs"] = set()
        self.data["initInputs"] = set()
        self.data["initOutputs"] = set()
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            if name in {
                "dimensions",
                "inputs",
                "prerequisiteInputs",
                "outputs",
                "initInputs",
                "initOutputs",
                "allConnections",
            }:
                # Guard against connections whose names are reserved.
                raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
            if (previous := self.data.get(name)) is not None:
                # Guard against changing the type of an inherited connection
                # by first removing it from the set it is currently in.
                self.data[previous._connection_type_set].discard(name)
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
            self.data[value._connection_type_set].add(name)
        # defer to the default behavior
        super().__setitem__(name, value)
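
# Example (illustrative sketch; ``cT`` stands for
# ``lsst.pipe.base.connectionTypes``): when the metaclass below uses this
# dict as a class namespace, a declaration such as
#
#     class MyConnections(PipelineTaskConnections, dimensions=("visit",)):
#         calexp = cT.Input(doc="...", name="calexp",
#                           storageClass="ExposureF",
#                           dimensions=("visit", "detector"))
#
# records "calexp" both in ``MyConnections.inputs`` and as a key in
# ``MyConnections.allConnections``.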


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes.

    Parameters
    ----------
    name : `str`
        Name of the connections class being declared.
    bases : `~collections.abc.Collection`
        Base classes.
    dct : `~collections.abc.Mapping`
        Connections dict.
    **kwargs : `~typing.Any`
        Additional parameters.
    """

    # We can annotate these attributes as `collections.abc.Set` to discourage
    # undesirable modifications in type-checked code, since the internal code
    # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
    # these annotations anyway.

    dimensions: Set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    inputs: Set[str]
    """Set with the names of all `~connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added. Note that
    this attribute is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    prerequisiteInputs: Set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: Set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: Set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: Set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping containing all connection attributes.

    See `inputs` for additional information.
    """

    def __prepare__(name, bases, **kwargs):  # noqa: N804
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in class
            # declaration.
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be iterable of dimensions, got str; possibly omitted "
                        "trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute.
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together.
            mergeDict = {}
            mergeDeprecationsDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
                if hasattr(base, "deprecatedTemplates"):
                    mergeDeprecationsDict.update(base.deprecatedTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])
            if "deprecatedTemplates" in kwargs:
                mergeDeprecationsDict.update(kwargs["deprecatedTemplates"])
            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict
            if len(mergeDeprecationsDict) > 0:
                kwargs["deprecatedTemplates"] = mergeDeprecationsDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class.
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnection class contains templated attribute names, but no "
                    "default templates were provided; add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not.
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with Class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})
            dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))
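
    # Example (illustrative sketch; ``cT`` stands for
    # ``lsst.pipe.base.connectionTypes``): the template machinery above means
    # a declaration with a templated connection name must supply a matching
    # default, e.g.
    #
    #     class MyConnections(PipelineTaskConnections,
    #                         dimensions=("tract", "patch"),
    #                         defaultTemplates={"coaddName": "deep"}):
    #         meas = cT.Input(doc="...", name="{coaddName}Coadd_meas",
    #                         storageClass="SourceCatalog",
    #                         dimensions=("tract", "patch"))
    #
    # Omitting ``defaultTemplates`` here would raise the TypeError above.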

301 

302 def __init__(cls, name, bases, dct, **kwargs): 

303 # This overrides the default init to drop the kwargs argument. Python 

304 # metaclasses will have this argument set if any kwargs are passes at 

305 # class construction time, but should be consumed before calling 

306 # __init__ on the type metaclass. This is in accordance with python 

307 # documentation on metaclasses 

308 super().__init__(name, bases, dct) 

309 

    def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
        # MyPy appears not to really understand metaclass.__call__ at all, so
        # we need to tell it to ignore __new__ and __init__ calls here.
        instance: PipelineTaskConnections = cls.__new__(cls)  # type: ignore

        # Make mutable copies of all set-like class attributes so derived
        # __init__ implementations can modify them in place.
        instance.dimensions = set(cls.dimensions)
        instance.inputs = set(cls.inputs)
        instance.prerequisiteInputs = set(cls.prerequisiteInputs)
        instance.outputs = set(cls.outputs)
        instance.initInputs = set(cls.initInputs)
        instance.initOutputs = set(cls.initOutputs)

        # Set self.config. It's a bit strange that we claim to accept None but
        # really just raise here, but it's not worth changing now.
        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        instance.config = config

        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time.
        templateValues = {
            name: getattr(config.connections, name) for name in cls.defaultTemplates  # type: ignore
        }

        # We now assemble a mapping of all connection instances keyed by
        # internal name, applying the configuration and templates to make new
        # configurations from the class-attribute defaults. This will be
        # private, but with a public read-only view. This mapping is what the
        # descriptor interface of the class-level attributes will return when
        # they are accessed on an instance. This is better than just assigning
        # regular instance attributes, as it ensures that removed connections
        # cannot be accessed on instances, instead of having access to them
        # silently fall through to the not-removed class connection instance.
        instance._allConnections = {}
        instance.allConnections = MappingProxyType(instance._allConnections)
        for internal_name, connection in cls.allConnections.items():
            dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
            instance_connection = dataclasses.replace(
                connection,
                name=dataset_type_name,
                doc=(
                    connection.doc
                    if connection.deprecated is None
                    else f"{connection.doc}\n{connection.deprecated}"
                ),
                _deprecation_context=connection._deprecation_context,
            )
            instance._allConnections[internal_name] = instance_connection

        # Finally call __init__. The base class implementation does nothing;
        # we could have left some of the above implementation there (where it
        # originated), but putting it here instead makes it hard for derived
        # class implementors to get things into a weird state by delegating to
        # super().__init__ in the wrong place, or by forgetting to do that
        # entirely.
        instance.__init__(config=config)  # type: ignore

        # Derived-class implementations may have changed the contents of the
        # various kinds-of-connection sets; update allConnections to have keys
        # that are a union of all those. We get values for the new
        # allConnections from the attributes, since any dynamically added new
        # ones will not be present in the old allConnections. Typically those
        # getattrs will invoke the descriptors and get things from the old
        # allConnections anyway. After processing each set we replace it with
        # a frozenset.
        updated_all_connections = {}
        for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
            updated_connection_names = getattr(instance, attrName)
            updated_all_connections.update(
                {name: getattr(instance, name) for name in updated_connection_names}
            )
            # Setting these to frozenset is at odds with the type annotation,
            # but MyPy can't tell because we're using setattr, and we want to
            # lie to it anyway to get runtime guards against post-__init__
            # mutation.
            setattr(instance, attrName, frozenset(updated_connection_names))
        # Update the existing dict in place, since we already have a view of
        # that.
        instance._allConnections.clear()
        instance._allConnections.update(updated_all_connections)

        for connection_name, obj in instance._allConnections.items():
            if obj.deprecated is not None:
                warnings.warn(
                    f"Connection {connection_name} with datasetType {obj.name} "
                    f"(from {obj._deprecation_context}): {obj.deprecated}",
                    FutureWarning,
                    stacklevel=1,  # Report from this location.
                )

        # Freeze the connection instance dimensions now. This is at odds with
        # the type annotation, which says [mutable] `set`, just like the
        # connection type attributes (e.g. `inputs`, `outputs`, etc.), though
        # MyPy can't tell with those since we're using setattr for them.
        instance.dimensions = frozenset(instance.dimensions)  # type: ignore

        return instance


class QuantizedConnection(SimpleNamespace):
    r"""A Namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `~lsst.daf.butler.DatasetRef`\s provided by a `~lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnections` class.

    Parameters
    ----------
    **kwargs : `~typing.Any`
        Not used.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance.
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]:
        """Make an iterator for this `QuantizedConnection`.

        Iterating over a `QuantizedConnection` will yield a tuple with the
        name of an attribute and the value associated with that name. This is
        similar to `dict.items()` but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        `QuantizedConnection` class.
        """
        yield from self._attributes
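
# Example (illustrative sketch; ``some_ref`` stands in for a real
# lsst.daf.butler.DatasetRef): attributes set on a QuantizedConnection are
# tracked, so the namespace supports len() and iteration over (name, value)
# pairs.
#
#     refs = InputQuantizedConnection()
#     refs.calexp = some_ref
#     for name, ref in refs:
#         ...  # yields ("calexp", some_ref)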


class InputQuantizedConnection(QuantizedConnection):
    """Input variant of a `QuantizedConnection`."""

    pass


class OutputQuantizedConnection(QuantizedConnection):
    """Output variant of a `QuantizedConnection`."""

    pass

480 

481@dataclass(frozen=True) 

482class DeferredDatasetRef: 

483 """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that a 

484 `PipelineTask` should receive a `~lsst.daf.butler.DeferredDatasetHandle` 

485 instead of an in-memory dataset. 

486 

487 Attributes 

488 ---------- 

489 datasetRef : `lsst.daf.butler.DatasetRef` 

490 The `lsst.daf.butler.DatasetRef` that will be eventually used to 

491 resolve a dataset. 

492 """ 

493 

494 datasetRef: DatasetRef 

495 

496 @property 

497 def datasetType(self) -> DatasetType: 

498 """The dataset type for this dataset.""" 

499 return self.datasetRef.datasetType 

500 

501 @property 

502 def dataId(self) -> DataCoordinate: 

503 """The data ID for this dataset.""" 

504 return self.datasetRef.dataId 
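
# Example (illustrative sketch; ``ref`` stands in for a real
# lsst.daf.butler.DatasetRef): wrapping a ref marks it for deferred loading
# while still exposing its dataset type and data ID.
#
#     deferred = DeferredDatasetRef(datasetRef=ref)
#     deferred.datasetType  # same as ref.datasetType
#     deferred.dataId       # same as ref.dataId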


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections : Iterator over selected connections.

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class.
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to declare
      the class attribute must match a function argument name in the ``run``
      method of a `PipelineTask`. E.g. if the ``PipelineTaskConnections``
      defines an ``Input`` with the name ``calexp``, then the corresponding
      signature should be ``PipelineTask.run(calexp, ...)``.
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
      execution of a `PipelineTask`. The name used to declare the connection
      must correspond to an attribute of a `Struct` that is returned by a
      `PipelineTask` ``run`` method. E.g. if an output connection is
      defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)`` where
      X matches the ``storageClass`` type defined on the output connection.

    Attributes of these types can also be created, replaced, or deleted on the
    `PipelineTaskConnections` instance in the ``__init__`` method, if more
    than just the name depends on the configuration. It is preferred to define
    them in the class when possible (even if configuration may cause the
    connection to be removed from the instance).

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`. The dimensions may also be modified in
    subclass ``__init__`` methods if they need to depend on configuration.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name parameters
    on connections defined in a `PipelineTaskConnections` class contain a
    template, then a default template value must be specified in the
    ``defaultTemplates`` argument. This is done by passing a dictionary with
    keys corresponding to a template identifier, and values corresponding to
    the value to use as a default when formatting the string. For example if
    ``ConnectionsClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates`` = {'input': 'deep'}.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding `PipelineTask`
    will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """


    # We annotate these attributes as mutable sets because that's what they
    # are inside derived ``__init__`` implementations, and that's what matters
    # most. After that's done, the metaclass __call__ makes them into
    # frozensets, but relatively little code interacts with them then, and
    # that code knows not to try to modify them without having to be told that
    # by mypy.

    dimensions: set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This may be replaced or modified in ``__init__`` to change the dimensions
    of the task. After ``__init__`` it will be a `frozenset` and may not be
    replaced.
    """

    inputs: set[str]
    """Set with the names of all `connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added, removed, or
    replaced in ``__init__``. Removing entries from this set will cause those
    connections to be removed after ``__init__`` completes, but this is
    supported only for backwards compatibility; new code should instead just
    delete the connection attribute directly. After ``__init__`` this will be
    a `frozenset` and may not be replaced.
    """

    prerequisiteInputs: set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping holding all connection attributes.

    This is a read-only view that is automatically updated when connection
    attributes are added, removed, or replaced in ``__init__``. It is also
    updated after ``__init__`` completes to reflect changes in `inputs`,
    `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
    """

    _allConnections: dict[str, BaseConnection]

    def __init__(self, *, config: PipelineTaskConfig | None = None):
        pass

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            previous = self._allConnections.get(name)
            try:
                getattr(self, value._connection_type_set).add(name)
            except AttributeError:
                # Attempt to call add on a frozenset, which is what these sets
                # are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            if previous is not None and value._connection_type_set != previous._connection_type_set:
                # Connection has changed type, e.g. Input to
                # PrerequisiteInput; update the sets accordingly. To be extra
                # defensive about multiple assignments we use the type of the
                # previous instance instead of assuming it's the same as the
                # type of the class attribute, which is just the default. Use
                # discard instead of remove so manually removing from these
                # sets first is never an error.
                getattr(self, previous._connection_type_set).discard(name)
            self._allConnections[name] = value
            if hasattr(self.__class__, name):
                # Don't actually set the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # return the value we just added to allConnections.
                return
        # Actually add the attribute.
        super().__setattr__(name, value)

    def __delattr__(self, name):
        """Descriptor delete method."""
        previous = self._allConnections.get(name)
        if previous is not None:
            # Delete this connection's name from the appropriate set, which we
            # have to get from the previous instance instead of assuming it's
            # the same set that was appropriate for the class-level default.
            # Use discard instead of remove so manually removing from these
            # sets first is never an error.
            try:
                getattr(self, previous._connection_type_set).discard(name)
            except AttributeError:
                # Attempt to call discard on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            del self._allConnections[name]
            if hasattr(self.__class__, name):
                # Don't actually delete the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # see that it's no longer present in allConnections.
                return
        # Actually delete the attribute.
        super().__delattr__(name)
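
    # Example (illustrative sketch; ``doWriteExtras`` and ``extras`` are
    # hypothetical config field and connection names): a derived class may
    # remove a connection in ``__init__`` based on configuration; the
    # ``__setattr__``/``__delattr__`` overrides above keep the type-specific
    # sets and ``allConnections`` consistent.
    #
    #     def __init__(self, *, config=None):
    #         super().__init__(config=config)
    #         if not config.doWriteExtras:
    #             del self.extras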

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build `QuantizedConnection` objects corresponding to the input
        `~lsst.daf.butler.Quantum`.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`, \
                `OutputQuantizedConnection`)
            Namespaces mapping attribute names
            (identifiers of connections) to butler references defined in the
            input `~lsst.daf.butler.Quantum`.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()

        # populate inputDatasetRefs from quantum inputs
        for attributeName in itertools.chain(self.inputs, self.prerequisiteInputs):
            # get the attribute identified by name
            attribute = getattr(self, attributeName)
            # if the dataset is marked to load deferred, wrap it in a
            # DeferredDatasetRef
            quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
            if attribute.deferLoad:
                quantumInputRefs = [
                    DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                ]
            else:
                quantumInputRefs = list(quantum.inputs[attribute.name])
            # Unpack arguments that are not marked multiples (list of
            # length one)
            if not attribute.multiple:
                if len(quantumInputRefs) > 1:
                    raise ScalarError(
                        "Received multiple datasets "
                        f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                        f"for scalar connection {attributeName} "
                        f"({quantumInputRefs[0].datasetType.name}) "
                        f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                    )
                if len(quantumInputRefs) == 0:
                    continue
                setattr(inputDatasetRefs, attributeName, quantumInputRefs[0])
            else:
                # Add to the QuantizedConnection identifier
                setattr(inputDatasetRefs, attributeName, quantumInputRefs)

        # populate outputDatasetRefs from quantum outputs
        for attributeName in self.outputs:
            # get the attribute identified by name
            attribute = getattr(self, attributeName)
            value = quantum.outputs[attribute.name]
            # Unpack arguments that are not marked multiples (list of
            # length one)
            if not attribute.multiple:
                setattr(outputDatasetRefs, attributeName, value[0])
            else:
                setattr(outputDatasetRefs, attributeName, value)

        return inputDatasetRefs, outputDatasetRefs
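
    # Example usage (illustrative sketch; ``connections`` is an instance of a
    # subclass with an ``Input`` named ``calexp``, and ``quantum`` is a
    # matching lsst.daf.butler.Quantum):
    #
    #     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #     inputRefs.calexp  # a DatasetRef, or a list if multiple=True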

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.Quantum` during the graph generation
        stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated
            `~lsst.daf.butler.DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can be
            assumed to be multi-pass iterable and support `len` and ``in``,
            but it should not be mutated in place. In contrast, the outer
            dictionaries are guaranteed to be temporary copies that are true
            `dict` instances, and hence may be modified and even returned;
            this is especially useful for delegating to `super` (see notes
            below).
        outputs : `~collections.abc.Mapping`
            Mapping of output datasets, with the same structure as ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `~collections.abc.Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `~lsst.daf.butler.DatasetRef` objects. Connections that are
            not changed should not be returned at all. Datasets may only be
            removed, not added. Nested collections may be of any multi-pass
            iterable type, and the order of iteration will set the order of
            iteration within `PipelineTask.runQuantum`.
        adjusted_outputs : `~collections.abc.Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not enough
            datasets were found for a regular `Input` connection, and the
            quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the message
            included in this exception); not enough datasets were found for a
            `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.:

        .. code-block:: python

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={input_connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        """Return the names of regular input and output connections whose
        data IDs should be used to compute the spatial bounds of this task's
        quanta.

        The spatial bound for a quantum is defined as the union of the
        regions of all data IDs of all connections returned here, along with
        the region of the quantum data ID (if the task has spatial
        dimensions).

        Returns
        -------
        connection_names : `collections.abc.Iterable` [ `str` ]
            Names of connections with spatial dimensions. These are the
            task-internal connection names, not butler dataset type names.

        Notes
        -----
        The spatial bound is used to search for prerequisite inputs that have
        skypix dimensions. The default implementation returns an empty
        iterable, which is usually sufficient for tasks with spatial
        dimensions, but if a task's inputs or outputs are associated with
        spatial regions that extend beyond the quantum data ID's region, this
        method may need to be overridden to expand the set of prerequisite
        inputs found.

        Tasks that do not have spatial dimensions but do have skypix
        prerequisite inputs should always override this method, as the
        default spatial bounds otherwise cover the full sky.
        """
        return ()
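
    # Example override in a subclass (illustrative sketch; "inputCatalogs" is
    # a hypothetical connection name with spatial dimensions):
    #
    #     def getSpatialBoundsConnections(self):
    #         return ("inputCatalogs",)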

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        """Return the names of regular input and output connections whose
        data IDs should be used to compute the temporal bounds of this task's
        quanta.

        The temporal bound for a quantum is defined as the union of the
        timespans of all data IDs of all connections returned here, along
        with the timespan of the quantum data ID (if the task has temporal
        dimensions).

        Returns
        -------
        connection_names : `collections.abc.Iterable` [ `str` ]
            Names of connections with temporal dimensions. These are the
            task-internal connection names, not butler dataset type names.

        Notes
        -----
        The temporal bound is used to search for prerequisite inputs that are
        calibration datasets. The default implementation returns an empty
        iterable, which is usually sufficient for tasks with temporal
        dimensions, but if a task's inputs or outputs are associated with
        timespans that extend beyond the quantum data ID's timespan, this
        method may need to be overridden to expand the set of prerequisite
        inputs found.

        Tasks that do not have temporal dimensions and do not implement this
        method will use an infinite timespan for any calibration lookups.
        """
        return ()


def iterConnections(
    connections: PipelineTaskConnections, connectionType: str | Iterable[str]
) -> Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connections type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or `~collections.abc.Iterable` [ `str` ]
        The type of connections to iterate over; valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection : `~.connectionTypes.BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `~.connectionTypes.BaseConnection`.
    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)
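
# Example usage (illustrative sketch; ``connections`` is a
# PipelineTaskConnections instance): iterate over all regular and
# prerequisite input connections in one pass.
#
#     for connection in iterConnections(
#             connections, ("inputs", "prerequisiteInputs")):
#         print(connection.name)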


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `inputs` and `outputs` mappings in the form used by
    `Quantum` and execution harness code, i.e. with
    `~lsst.daf.butler.DatasetType` keys, translating them to and from the
    connection-oriented mappings used inside `PipelineTaskConnections`.
    """

    inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `~lsst.daf.butler.DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in itertools.chain(connections.outputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior we
        # want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = tuple(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = tuple(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False
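
# Example usage (illustrative sketch; assumes ``quantum`` is an
# lsst.daf.butler.Quantum and ``connections`` is the PipelineTaskConnections
# instance for the task labeled "exampleTask"):
#
#     helper = AdjustQuantumHelper(inputs=quantum.inputs,
#                                  outputs=quantum.outputs)
#     helper.adjust_in_place(connections, "exampleTask", quantum.dataId)
#     if helper.inputs_adjusted or helper.outputs_adjusted:
#         ...  # rebuild the Quantum from helper.inputs / helper.outputs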