# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask."""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

import dataclasses
import itertools
import string
import warnings
from collections import UserDict
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum
from lsst.utils.introspection import find_outside_stacklevel

from ._status import NoWorkFound
from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput

if TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by `PipelineTaskConnectionsMetaclass`.

    This dict is used in `PipelineTaskConnections` class creation, as the
    dictionary that is initially used as ``__dict__``. It exists to
    intercept connection fields declared in a `PipelineTaskConnections`
    class and to record the names used to identify them. Each name is then
    added to a class-level set according to the connection type of the
    class attribute, and is also used as a key in a class-level dictionary
    associated with the corresponding class attribute. This information
    duplicates what exists in ``__dict__``, but provides a simple place to
    look up and iterate over only these variables.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class-level variables used to track any declared
        # class-level variables that are instances of
        # connectionTypes.BaseConnection
        self.data["inputs"] = set()
        self.data["prerequisiteInputs"] = set()
        self.data["outputs"] = set()
        self.data["initInputs"] = set()
        self.data["initOutputs"] = set()
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            if name in {
                "dimensions",
                "inputs",
                "prerequisiteInputs",
                "outputs",
                "initInputs",
                "initOutputs",
                "allConnections",
            }:
                # Guard against connections whose names are reserved.
                raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
            if (previous := self.data.get(name)) is not None:
                # Guard against changing the type of an inherited connection
                # by first removing it from the set it's currently in.
                self.data[previous._connection_type_set].discard(name)
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
            self.data[value._connection_type_set].add(name)
        # defer to the default behavior
        super().__setitem__(name, value)


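# For illustration (a hypothetical sketch, not part of this module): when a
# connections class body is executed, each assignment of a connection instance
# passes through PipelineTaskConnectionDict.__setitem__ above, so a
# declaration like
#
#     class MyConnections(PipelineTaskConnections, dimensions=("visit",)):
#         calexp = cT.Input(doc="Input exposure", name="calexp",
#                           storageClass="ExposureF",
#                           dimensions=("visit", "detector"))
#
# (with ``cT`` being ``lsst.pipe.base.connectionTypes``) leaves the class with
# ``inputs == frozenset({"calexp"})`` and ``allConnections`` mapping
# ``"calexp"`` to the Input instance.
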

class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes."""

    # We can annotate these attributes as `collections.abc.Set` to discourage
    # undesirable modifications in type-checked code, since the internal code
    # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
    # these annotations anyway.

    dimensions: Set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    inputs: Set[str]
    """Set with the names of all `~connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added. Note that
    this attribute is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    prerequisiteInputs: Set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: Set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: Set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: Set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping containing all connection attributes.

    See `inputs` for additional information.
    """

    def __prepare__(name, bases, **kwargs):  # noqa: N804
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in class
            # declaration
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be iterable of dimensions, got str; possibly an "
                        "omitted trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any Python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

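            # As an illustrative aside (not executed here): for a templated
            # name such as "{foo}Dataset",
            # list(string.Formatter().parse("{foo}Dataset")) yields
            # [("", "foo", "", None), ("Dataset", None, None, None)], so the
            # loop above collects the field name "foo" and skips literal
            # segments, whose field entry is None.
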

            # look up any templates from base classes and merge them all
            # together
            mergeDict = {}
            mergeDeprecationsDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
                if hasattr(base, "deprecatedTemplates"):
                    mergeDeprecationsDict.update(base.deprecatedTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])
            if "deprecatedTemplates" in kwargs:
                mergeDeprecationsDict.update(kwargs["deprecatedTemplates"])
            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict
            if len(mergeDeprecationsDict) > 0:
                kwargs["deprecatedTemplates"] = mergeDeprecationsDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnections class contains templated attribute names, but no "
                    "default templates were provided; add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})
            dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with the
        # Python documentation on metaclasses.
        super().__init__(name, bases, dct)

    def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
        # MyPy appears not to really understand metaclass.__call__ at all, so
        # we need to tell it to ignore __new__ and __init__ calls here.
        instance: PipelineTaskConnections = cls.__new__(cls)  # type: ignore

        # Make mutable copies of all set-like class attributes so derived
        # __init__ implementations can modify them in place.
        instance.dimensions = set(cls.dimensions)
        instance.inputs = set(cls.inputs)
        instance.prerequisiteInputs = set(cls.prerequisiteInputs)
        instance.outputs = set(cls.outputs)
        instance.initInputs = set(cls.initInputs)
        instance.initOutputs = set(cls.initOutputs)

        # Set self.config. It's a bit strange that we claim to accept None
        # but really just raise here; it's not worth changing now.
        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        instance.config = config

        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time.
        templateValues = {
            name: getattr(config.connections, name) for name in getattr(cls, "defaultTemplates").keys()
        }

        # We now assemble a mapping of all connection instances keyed by
        # internal name, applying the configuration and templates to make new
        # connection instances from the class-attribute defaults. This will
        # be private, but with a public read-only view. This mapping is what
        # the descriptor interface of the class-level attributes will return
        # when they are accessed on an instance. This is better than just
        # assigning regular instance attributes as it makes it so removed
        # connections cannot be accessed on instances, instead of having
        # access to them silently fall through to the not-removed class
        # connection instance.
        instance._allConnections = {}
        instance.allConnections = MappingProxyType(instance._allConnections)
        for internal_name, connection in cls.allConnections.items():
            dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
            instance_connection = dataclasses.replace(
                connection,
                name=dataset_type_name,
                doc=(
                    connection.doc
                    if connection.deprecated is None
                    else f"{connection.doc}\n{connection.deprecated}"
                ),
            )
            instance._allConnections[internal_name] = instance_connection

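        # Note (illustrative): dataclasses.replace returns a new instance
        # with only the given fields changed, so the class-level default
        # connection instances are never mutated here; each connections
        # instance gets its own configured copies.
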

        # Finally call __init__. The base class implementation does nothing;
        # we could have left some of the above implementation there (where it
        # originated), but putting it here instead makes it hard for derived
        # class implementors to get things into a weird state by delegating
        # to super().__init__ in the wrong place, or by forgetting to do that
        # entirely.
        instance.__init__(config=config)  # type: ignore

        # Derived-class implementations may have changed the contents of the
        # various kinds-of-connection sets; update allConnections to have
        # keys that are a union of all those. We get values for the new
        # allConnections from the attributes, since any dynamically added new
        # ones will not be present in the old allConnections. Typically those
        # getattrs will invoke the descriptors and get things from the old
        # allConnections anyway. After processing each set we replace it with
        # a frozenset.
        updated_all_connections = {}
        for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
            updated_connection_names = getattr(instance, attrName)
            updated_all_connections.update(
                {name: getattr(instance, name) for name in updated_connection_names}
            )
            # Setting these to frozenset is at odds with the type annotation,
            # but MyPy can't tell because we're using setattr, and we want to
            # lie to it anyway to get runtime guards against post-__init__
            # mutation.
            setattr(instance, attrName, frozenset(updated_connection_names))
        # Update the existing dict in place, since we already have a view of
        # that.
        instance._allConnections.clear()
        instance._allConnections.update(updated_all_connections)

        for obj in instance._allConnections.values():
            if obj.deprecated is not None:
                warnings.warn(
                    obj.deprecated, FutureWarning, stacklevel=find_outside_stacklevel("lsst.pipe.base")
                )

        # Freeze the connection instance dimensions now. This is at odds with
        # the type annotation, which says [mutable] `set`, just like the
        # connection type attributes (e.g. `inputs`, `outputs`, etc.), though
        # MyPy can't tell with those since we're using setattr for them.
        instance.dimensions = frozenset(instance.dimensions)  # type: ignore

        return instance


class QuantizedConnection(SimpleNamespace):
    r"""A namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `~lsst.daf.butler.DatasetRef`\s provided by a `~lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnections` class.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]:
        """Make an iterator for this `QuantizedConnection`.

        Iterating over a `QuantizedConnection` will yield a tuple with the
        name of an attribute and the value associated with that name. This is
        similar to `dict.items()`, but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        `QuantizedConnection` class.
        """
        yield from self._attributes


class InputQuantizedConnection(QuantizedConnection):
    """Input variant of a `QuantizedConnection`."""

    pass


class OutputQuantizedConnection(QuantizedConnection):
    """Output variant of a `QuantizedConnection`."""

    pass


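# Illustrative sketch (hypothetical ref object, not part of this module):
# these namespaces are populated by plain attribute assignment and support
# iteration over (name, value) pairs, e.g.
#
#     refs = InputQuantizedConnection()
#     refs.calexp = some_dataset_ref  # tracked via the _attributes set
#     for name, ref in refs:          # yields ("calexp", some_dataset_ref)
#         ...
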

@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that a
    `PipelineTask` should receive a `~lsst.daf.butler.DeferredDatasetHandle`
    instead of an in-memory dataset.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
    """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId


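# Illustrative sketch (``ref`` is hypothetical): wrapping a ref marks it for
# deferred loading while still exposing what execution code needs, e.g.
#
#     deferred = DeferredDatasetRef(datasetRef=ref)
#     deferred.datasetType  # forwards to ref.datasetType
#     deferred.dataId       # forwards to ref.dataId
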

class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnections`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class.
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with the
      name ``outputSchema`` in a ``PipelineTaskConnections`` class, then a
      `PipelineTask` instance should have an attribute ``self.outputSchema``
      defined. Its value is what will be saved by the activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``.
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)``,
      where X matches the ``storageClass`` type defined on the output
      connection.

    Attributes of these types can also be created, replaced, or deleted on
    the `PipelineTaskConnections` instance in the ``__init__`` method, if
    more than just the name depends on the configuration. It is preferred to
    define them in the class when possible (even if configuration may cause
    the connection to be removed from the instance).

    The process of declaring a ``PipelineTaskConnections`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`. The dimensions may also be modified in
    subclass ``__init__`` methods if they need to depend on configuration.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as Python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example if ``ConnectionsClass.calexp.name =
    '{input}Coadd_calexp'`` then ``defaultTemplates`` = {'input': 'deep'}.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    # We annotate these attributes as mutable sets because that's what they
    # are inside derived ``__init__`` implementations, and that's what
    # matters most. After that's done, the metaclass __call__ makes them into
    # frozensets, but relatively little code interacts with them then, and
    # that code knows not to try to modify them without having to be told
    # that by mypy.

    dimensions: set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later
    and need not be provided.

    This may be replaced or modified in ``__init__`` to change the dimensions
    of the task. After ``__init__`` it will be a `frozenset` and may not be
    replaced.
    """

    inputs: set[str]
    """Set with the names of all `connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added, removed, or
    replaced in ``__init__``. Removing entries from this set will cause those
    connections to be removed after ``__init__`` completes, but this is
    supported only for backwards compatibility; new code should instead just
    delete the connection attribute directly. After ``__init__`` this will be
    a `frozenset` and may not be replaced.
    """

    prerequisiteInputs: set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping holding all connection attributes.

    This is a read-only view that is automatically updated when connection
    attributes are added, removed, or replaced in ``__init__``. It is also
    updated after ``__init__`` completes to reflect changes in `inputs`,
    `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
    """

    _allConnections: dict[str, BaseConnection]

    def __init__(self, *, config: PipelineTaskConfig | None = None):
        pass

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            previous = self._allConnections.get(name)
            try:
                getattr(self, value._connection_type_set).add(name)
            except AttributeError:
                # Attempt to call add on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            if previous is not None and value._connection_type_set != previous._connection_type_set:
                # Connection has changed type, e.g. Input to
                # PrerequisiteInput; update the sets accordingly. To be extra
                # defensive about multiple assignments we use the type of the
                # previous instance instead of assuming it's the same as the
                # class-level default. Use discard instead of remove so
                # manually removing from these sets first is never an error.
                getattr(self, previous._connection_type_set).discard(name)
            self._allConnections[name] = value
            if hasattr(self.__class__, name):
                # Don't actually set the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # return the value we just added to allConnections.
                return
        # Actually add the attribute.
        super().__setattr__(name, value)

    def __delattr__(self, name):
        """Descriptor delete method."""
        previous = self._allConnections.get(name)
        if previous is not None:
            # Delete this connection's name from the appropriate set, which
            # we have to get from the previous instance instead of assuming
            # it's the same set that was appropriate for the class-level
            # default. Use discard instead of remove so manually removing
            # from these sets first is never an error.
            try:
                getattr(self, previous._connection_type_set).discard(name)
            except AttributeError:
                # Attempt to call discard on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            del self._allConnections[name]
            if hasattr(self.__class__, name):
                # Don't actually delete the attribute if this was a
                # connection declared in the class; in that case we let the
                # descriptor see that it's no longer present in
                # allConnections.
                return
        # Actually delete the attribute.
        super().__delattr__(name)

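    # Illustrative sketch of the dynamic behavior supported by __setattr__
    # and __delattr__ above (hypothetical subclass and config field, not part
    # of this module):
    #
    #     class MyConnections(PipelineTaskConnections, dimensions=("visit",)):
    #         background = cT.Input(doc="...", name="background",
    #                               storageClass="Background",
    #                               dimensions=("visit",))
    #
    #         def __init__(self, *, config=None):
    #             if not config.doUseBackground:  # hypothetical config field
    #                 del self.background  # also removed from self.inputs
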

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build `QuantizedConnection` objects corresponding to an input
        `~lsst.daf.butler.Quantum`.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`, \
                `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip(
            (inputDatasetRefs, outputDatasetRefs),
            (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
        ):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
                    if attribute.deferLoad:
                        quantumInputRefs = [
                            DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                        ]
                    else:
                        quantumInputRefs = list(quantum.inputs[attribute.name])
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                "Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        setattr(refs, attributeName, quantumInputRefs[0])
                    else:
                        # Add to the QuantizedConnection identifier
                        setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        setattr(refs, attributeName, value[0])
                    else:
                        setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; we don't
                # know how to handle it, so raise
                else:
                    raise ValueError(
                        f"Attribute with name {attributeName} has no counterpart in input quantum"
                    )
        return inputDatasetRefs, outputDatasetRefs

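    # Illustrative usage sketch (``connections`` and ``quantum`` are
    # hypothetical here): an execution harness would typically call
    #
    #     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #
    # and then pass the populated namespaces on to PipelineTask.runQuantum.
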

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
794 """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects 

795 in the `lsst.daf.butler.core.Quantum` during the graph generation stage 

796 of the activator. 

797 

798 Parameters 

799 ---------- 

800 inputs : `dict` 

801 Dictionary whose keys are an input (regular or prerequisite) 

802 connection name and whose values are a tuple of the connection 

803 instance and a collection of associated 

804 `~lsst.daf.butler.DatasetRef` objects. 

805 The exact type of the nested collections is unspecified; it can be 

806 assumed to be multi-pass iterable and support `len` and ``in``, but 

807 it should not be mutated in place. In contrast, the outer 

808 dictionaries are guaranteed to be temporary copies that are true 

809 `dict` instances, and hence may be modified and even returned; this 

810 is especially useful for delegating to `super` (see notes below). 

811 outputs : `~collections.abc.Mapping` 

812 Mapping of output datasets, with the same structure as ``inputs``. 

813 label : `str` 

814 Label for this task in the pipeline (should be used in all 

815 diagnostic messages). 

816 data_id : `lsst.daf.butler.DataCoordinate` 

817 Data ID for this quantum in the pipeline (should be used in all 

818 diagnostic messages). 

819 

820 Returns 

821 ------- 

822 adjusted_inputs : `~collections.abc.Mapping` 

823 Mapping of the same form as ``inputs`` with updated containers of 

824 input `~lsst.daf.butler.DatasetRef` objects. Connections that are 

825 not changed should not be returned at all. Datasets may only be 

826 removed, not added. Nested collections may be of any multi-pass 

827 iterable type, and the order of iteration will set the order of 

828 iteration within `PipelineTask.runQuantum`. 

829 adjusted_outputs : `~collections.abc.Mapping` 

830 Mapping of updated output datasets, with the same structure and 

831 interpretation as ``adjusted_inputs``. 

832 

833 Raises 

834 ------ 

835 ScalarError 

836 Raised if any `Input` or `PrerequisiteInput` connection has 

837 ``multiple`` set to `False`, but multiple datasets. 

838 NoWorkFound 

839 Raised to indicate that this quantum should not be run; not enough 

840 datasets were found for a regular `Input` connection, and the 

841 quantum should be pruned or skipped. 

842 FileNotFoundError 

843 Raised to cause QuantumGraph generation to fail (with the message 

844 included in this exception); not enough datasets were found for a 

845 `PrerequisiteInput` connection. 

846 

847 Notes 

848 ----- 

849 The base class implementation performs important checks. It always 

850 returns an empty mapping (i.e. makes no adjustments). It should 

851 always called be via `super` by custom implementations, ideally at the 

852 end of the custom implementation with already-adjusted mappings when 

853 any datasets are actually dropped, e.g.: 

854 

855 .. code-block:: python 

856 

857 def adjustQuantum(self, inputs, outputs, label, data_id): 

858 # Filter out some dataset refs for one connection. 

859 connection, old_refs = inputs["my_input"] 

860 new_refs = [ref for ref in old_refs if ...] 

861 adjusted_inputs = {"my_input", (connection, new_refs)} 

862 # Update the original inputs so we can pass them to super. 

863 inputs.update(adjusted_inputs) 

864 # Can ignore outputs from super because they are guaranteed 

865 # to be empty. 

866 super().adjustQuantum(inputs, outputs, label_data_id) 

867 # Return only the connections we modified. 

868 return adjusted_inputs, {} 

869 

870 Removing outputs here is guaranteed to affect what is actually 

871 passed to `PipelineTask.runQuantum`, but its effect on the larger 

872 graph may be deferred to execution, depending on the context in 

873 which `adjustQuantum` is being run: if one quantum removes an output 

874 that is needed by a second quantum as input, the second quantum may not 

875 be adjusted (and hence pruned or skipped) until that output is actually 

876 found to be missing at execution time. 

877 

878 Tasks that desire zip-iteration consistency between any combinations of 

879 connections that have the same data ID should generally implement 

880 `adjustQuantum` to achieve this, even if they could also run that 

881 logic during execution; this allows the system to see outputs that will 

882 not be produced because the corresponding input is missing as early as 

883 possible. 

884 """ 

        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={input_connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose
                    # inputs are either already present or should be created
                    # during execution. It can trigger during execution if
                    # the input wasn't actually created by an upstream task
                    # in the same graph.
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(
    connections: PipelineTaskConnections, connectionType: str | Iterable[str]
) -> Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type(s) which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or iterable of `str`
        The type(s) of connections to iterate over; valid values are
        ``inputs``, ``outputs``, ``prerequisiteInputs``, ``initInputs``, and
        ``initOutputs``.

    Yields
    ------
    connection : `~.connectionTypes.BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `~.connectionTypes.BaseConnection`.
    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)


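# Illustrative usage sketch (``connections`` is a hypothetical configured
# instance):
#
#     for connection in iterConnections(connections, ("inputs", "prerequisiteInputs")):
#         print(connection.name)  # the configured dataset type name
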

@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `input` and `output` mappings in the form used by
    `Quantum` and execution harness code, i.e. with
    `~lsst.daf.butler.DatasetType` keys, translating them to and from the
    connection-oriented mappings used inside `PipelineTaskConnections`.
    """

    inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `~lsst.daf.butler.DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in itertools.chain(connections.outputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior
        # we want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but it's not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = tuple(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = tuple(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False
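

# Illustrative usage sketch (hypothetical ``quantum``, ``connections``, and
# label values, not part of this module):
#
#     helper = AdjustQuantumHelper(inputs=quantum.inputs, outputs=quantum.outputs)
#     helper.adjust_in_place(connections, "taskLabel", quantum.dataId)
#     if helper.inputs_adjusted or helper.outputs_adjusted:
#         ...  # rebuild the Quantum from helper.inputs and helper.outputs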