Coverage for python/lsst/pipe/base/connections.py: 43%

302 statements  

coverage.py v7.2.7, created at 2023-08-06 02:28 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining connection classes for PipelineTask. 

23""" 

24 

25from __future__ import annotations 

26 

27__all__ = [ 

28 "AdjustQuantumHelper", 

29 "DeferredDatasetRef", 

30 "InputQuantizedConnection", 

31 "OutputQuantizedConnection", 

32 "PipelineTaskConnections", 

33 "ScalarError", 

34 "iterConnections", 

35 "ScalarError", 

36] 

37 

38import dataclasses 

39import itertools 

40import string 

41import warnings 

42from collections import UserDict 

43from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set 

44from dataclasses import dataclass 

45from types import MappingProxyType, SimpleNamespace 

46from typing import TYPE_CHECKING, Any 

47 

48from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum 

49from lsst.utils.introspection import find_outside_stacklevel 

50 

51from ._status import NoWorkFound 

52from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput 

53 

54if TYPE_CHECKING: 

55 from .config import PipelineTaskConfig 

56 

57 

58class ScalarError(TypeError): 

59 """Exception raised when dataset type is configured as scalar 

60 but there are multiple data IDs in a Quantum for that dataset. 

61 """ 

62 
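# [Editor's example] A minimal sketch of the condition this exception
# guards against: a connection declared with ``multiple=False`` receiving
# more than one dataset. The helper below is illustrative only and not
# part of the module's API; ``refs`` stands in for what a Quantum provides.
def _example_scalar_guard(refs: list, multiple: bool = False) -> None:
    if not multiple and len(refs) > 1:
        raise ScalarError(f"Expected a single dataset for a scalar connection, got {len(refs)}.")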

63 

64class PipelineTaskConnectionDict(UserDict): 

65 """A special dict class used by `PipelineTaskConnectionMetaclass`. 

66 

67 This dict is used in `PipelineTaskConnection` class creation, as the 

68 dictionary that is initially used as ``__dict__``. It exists to 

69 intercept connection fields declared in a `PipelineTaskConnection`, and 

70 what name is used to identify them. The names are then added to class 

71 level list according to the connection type of the class attribute. The 

72 names are also used as keys in a class level dictionary associated with 

73 the corresponding class attribute. This information is a duplicate of 

74 what exists in ``__dict__``, but provides a simple place to lookup and 

75 iterate on only these variables. 

76 """ 

77 

78 def __init__(self, *args: Any, **kwargs: Any): 

79 super().__init__(*args, **kwargs) 

80 # Initialize class level variables used to track any declared 

81 # class level variables that are instances of 

82 # connectionTypes.BaseConnection 

83 self.data["inputs"] = set() 

84 self.data["prerequisiteInputs"] = set() 

85 self.data["outputs"] = set() 

86 self.data["initInputs"] = set() 

87 self.data["initOutputs"] = set() 

88 self.data["allConnections"] = {} 

89 

90 def __setitem__(self, name: str, value: Any) -> None: 

91 if isinstance(value, BaseConnection): 

92 if name in { 

93 "dimensions", 

94 "inputs", 

95 "prerequisiteInputs", 

96 "outputs", 

97 "initInputs", 

98 "initOutputs", 

99 "allConnections", 

100 }: 

101 # Guard against connections whose names are reserved. 

102 raise AttributeError(f"Connection name {name!r} is reserved for internal use.") 

103 if (previous := self.data.get(name)) is not None: 

104 # Guard against changing the type of an inherited connection 

105 # by first removing it from the set it's currently in. 

106 self.data[previous._connection_type_set].discard(name) 

107 object.__setattr__(value, "varName", name) 

108 self.data["allConnections"][name] = value 

109 self.data[value._connection_type_set].add(name) 

110 # defer to the default behavior 

111 super().__setitem__(name, value) 

112 
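# [Editor's example] A rough illustration of how assignments into this dict
# route connection names into the category sets. The Output arguments
# mirror the doctest in PipelineTaskConnections below; the dataset type and
# storage class names are assumptions for illustration.
def _example_connection_dict() -> None:
    d = PipelineTaskConnectionDict()
    d["measCat"] = Output(
        doc="Example output",
        dimensions=("A", "B"),
        storageClass="SourceCatalog",
        name="measCat",
    )
    # __setitem__ recorded the name both as a category-set member and as a
    # key in the allConnections mapping.
    assert "measCat" in d["outputs"]
    assert "measCat" in d["allConnections"]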

113 

114class PipelineTaskConnectionsMetaclass(type): 

115 """Metaclass used in the declaration of PipelineTaskConnections classes""" 

116 

117 # We can annotate these attributes as `collections.abc.Set` to discourage 

118 # undesirable modifications in type-checked code, since the internal code 

119 # modifying them is in `PipelineTaskConnectionDict` and that doesn't see 

120 # these annotations anyway. 

121 

122 dimensions: Set[str] 

123 """Set of dimension names that define the unit of work for this task. 

124 

125 Required and implied dependencies will automatically be expanded later and 

126 need not be provided. 

127 

128 This is shadowed by an instance-level attribute on 

129 `PipelineTaskConnections` instances. 

130 """ 

131 

132 inputs: Set[str] 

133 """Set with the names of all `~connectionTypes.Input` connection 

134 attributes. 

135 

136 This is updated automatically as class attributes are added. Note that 

137 this attribute is shadowed by an instance-level attribute on 

138 `PipelineTaskConnections` instances. 

139 """ 

140 

141 prerequisiteInputs: Set[str] 

142 """Set with the names of all `~connectionTypes.PrerequisiteInput` 

143 connection attributes. 

144 

145 See `inputs` for additional information. 

146 """ 

147 

148 outputs: Set[str] 

149 """Set with the names of all `~connectionTypes.Output` connection 

150 attributes. 

151 

152 See `inputs` for additional information. 

153 """ 

154 

155 initInputs: Set[str] 

156 """Set with the names of all `~connectionTypes.InitInput` connection 

157 attributes. 

158 

159 See `inputs` for additional information. 

160 """ 

161 

162 initOutputs: Set[str] 

163 """Set with the names of all `~connectionTypes.InitOutput` connection 

164 attributes. 

165 

166 See `inputs` for additional information. 

167 """ 

168 

169 allConnections: Mapping[str, BaseConnection] 

170 """Mapping containing all connection attributes. 

171 

172 See `inputs` for additional information. 

173 """ 

174 

175 def __prepare__(name, bases, **kwargs): # noqa: N804 

176 # Create an instance of our special dict to catch and track all 

177 # variables that are instances of connectionTypes.BaseConnection 

178 # Copy any existing connections from a parent class 

179 dct = PipelineTaskConnectionDict() 

180 for base in bases: 

181 if isinstance(base, PipelineTaskConnectionsMetaclass): 

182 for name, value in base.allConnections.items(): 

183 dct[name] = value 

184 return dct 

185 

186 def __new__(cls, name, bases, dct, **kwargs): 

187 dimensionsValueError = TypeError( 

188 "PipelineTaskConnections class must be created with a dimensions " 

189 "attribute which is an iterable of dimension names" 

190 ) 

191 

192 if name != "PipelineTaskConnections": 

193 # Verify that dimensions are passed as a keyword in class 

194 # declaration 

195 if "dimensions" not in kwargs: 195 ↛ 196line 195 didn't jump to line 196, because the condition on line 195 was never true

196 for base in bases: 

197 if hasattr(base, "dimensions"): 

198 kwargs["dimensions"] = base.dimensions 

199 break 

200 if "dimensions" not in kwargs: 

201 raise dimensionsValueError 

202 try: 

203 if isinstance(kwargs["dimensions"], str): 

204 raise TypeError( 

205 "Dimensions must be an iterable of dimensions, got str; possibly omitted trailing comma" 

206 ) 

207 if not isinstance(kwargs["dimensions"], Iterable): 

208 raise TypeError("Dimensions must be an iterable of dimensions") 

209 dct["dimensions"] = set(kwargs["dimensions"]) 

210 except TypeError as exc: 

211 raise dimensionsValueError from exc 

212 # Look up any Python string templates that may have been used in the 

213 # declaration of the name field of a class connection attribute. 

214 allTemplates = set() 

215 stringFormatter = string.Formatter() 

216 # Loop over all connections 

217 for obj in dct["allConnections"].values(): 

218 nameValue = obj.name 

219 # add all the parameters to the set of templates 

220 for param in stringFormatter.parse(nameValue): 

221 if param[1] is not None: 

222 allTemplates.add(param[1]) 

223 

224 # Look up any templates from base classes and merge them all 

225 # together. 

226 mergeDict = {} 

227 mergeDeprecationsDict = {} 

228 for base in bases[::-1]: 

229 if hasattr(base, "defaultTemplates"): 

230 mergeDict.update(base.defaultTemplates) 

231 if hasattr(base, "deprecatedTemplates"): 

232 mergeDeprecationsDict.update(base.deprecatedTemplates) 

233 if "defaultTemplates" in kwargs: 

234 mergeDict.update(kwargs["defaultTemplates"]) 

235 if "deprecatedTemplates" in kwargs: 235 ↛ 236line 235 didn't jump to line 236, because the condition on line 235 was never true

236 mergeDeprecationsDict.update(kwargs["deprecatedTemplates"]) 

237 if len(mergeDict) > 0: 

238 kwargs["defaultTemplates"] = mergeDict 

239 if len(mergeDeprecationsDict) > 0: 

240 kwargs["deprecatedTemplates"] = mergeDeprecationsDict 

241 

242 # Verify that if templated strings were used, defaults were 

243 # supplied as an argument in the declaration of the connection 

244 # class 

245 if len(allTemplates) > 0 and "defaultTemplates" not in kwargs: 

246 raise TypeError( 

247 "PipelineTaskConnection class contains templated attribute names, but no " 

248 "default templates were provided; add a dictionary attribute named " 

249 "defaultTemplates which contains the mapping between template key and value" 

250 ) 

251 if len(allTemplates) > 0: 

252 # Verify all templates have a default, and throw if they do not 

253 defaultTemplateKeys = set(kwargs["defaultTemplates"].keys()) 

254 templateDifference = allTemplates.difference(defaultTemplateKeys) 

255 if templateDifference: 

256 raise TypeError(f"Default template keys were not provided for {templateDifference}") 

257 # Verify that templates do not share names with variable names 

258 # used for a connection; this is needed because of how 

259 # templates are specified in an associated config class. 

260 nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys())) 

261 if len(nameTemplateIntersection) > 0: 

262 raise TypeError( 

263 "Template parameters cannot share names with Class attributes" 

264 f" (conflicts are {nameTemplateIntersection})." 

265 ) 

266 dct["defaultTemplates"] = kwargs.get("defaultTemplates", {}) 

267 dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {}) 

268 

269 # Convert all the connection containers into frozensets so they cannot 

270 # be modified at the class scope 

271 for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"): 

272 dct[connectionName] = frozenset(dct[connectionName]) 

273 # our custom dict type must be turned into an actual dict to be used in 

274 # type.__new__ 

275 return super().__new__(cls, name, bases, dict(dct)) 

276 
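# [Editor's example] How the keyword arguments above are consumed at class
# declaration time. PipelineTaskConnections is defined later in this
# module, so this sketch is left as a comment; the dimension and template
# names are placeholders.
#
#     class ExampleConnections(
#         PipelineTaskConnections,
#         dimensions=("A", "B"),
#         defaultTemplates={"foo": "Example"},
#     ):
#         ...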

277 def __init__(cls, name, bases, dct, **kwargs): 

278 # This overrides the default init to drop the kwargs argument. Python 

279 # metaclasses will have this argument set if any kwargs are passed at 

280 # class construction time, but should be consumed before calling 

281 # __init__ on the type metaclass. This is in accordance with python 

282 # documentation on metaclasses 

283 super().__init__(name, bases, dct) 

284 

285 def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections: 

286 # MyPy appears not to really understand metaclass.__call__ at all, so 

287 # we need to tell it to ignore __new__ and __init__ calls here. 

288 instance: PipelineTaskConnections = cls.__new__(cls) # type: ignore 

289 

290 # Make mutable copies of all set-like class attributes so derived 

291 # __init__ implementations can modify them in place. 

292 instance.dimensions = set(cls.dimensions) 

293 instance.inputs = set(cls.inputs) 

294 instance.prerequisiteInputs = set(cls.prerequisiteInputs) 

295 instance.outputs = set(cls.outputs) 

296 instance.initInputs = set(cls.initInputs) 

297 instance.initOutputs = set(cls.initOutputs) 

298 

299 # Set self.config. It's a bit strange that we claim to accept None but 

300 # really just raise here, but it's not worth changing now. 

301 from .config import PipelineTaskConfig # local import to avoid cycle 

302 

303 if config is None or not isinstance(config, PipelineTaskConfig): 

304 raise ValueError( 

305 "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance" 

306 ) 

307 instance.config = config 

308 

309 # Extract the template names that were defined in the config instance 

310 # by looping over the keys of the defaultTemplates dict specified at 

311 # class declaration time. 

312 templateValues = { 

313 name: getattr(config.connections, name) for name in cls.defaultTemplates # type: ignore 

314 } 

315 

316 # We now assemble a mapping of all connection instances keyed by 

317 # internal name, applying the configuration and templates to make new 

318 # configurations from the class-attribute defaults. This will be 

319 # private, but with a public read-only view. This mapping is what the 

320 # descriptor interface of the class-level attributes will return when 

321 # they are accessed on an instance. This is better than just assigning 

322 regular instance attributes, as it ensures that removed connections 

323 cannot be accessed on instances, instead of having access to them 

324 silently fall through to the not-removed class connection instance. 

325 instance._allConnections = {} 

326 instance.allConnections = MappingProxyType(instance._allConnections) 

327 for internal_name, connection in cls.allConnections.items(): 

328 dataset_type_name = getattr(config.connections, internal_name).format(**templateValues) 

329 instance_connection = dataclasses.replace( 

330 connection, 

331 name=dataset_type_name, 

332 doc=( 

333 connection.doc 

334 if connection.deprecated is None 

335 else f"{connection.doc}\n{connection.deprecated}" 

336 ), 

337 ) 

338 instance._allConnections[internal_name] = instance_connection 

339 

340 # Finally call __init__. The base class implementation does nothing; 

341 # we could have left some of the above implementation there (where it 

342 # originated), but putting it here instead makes it hard for derived 

343 # class implementors to get things into a weird state by delegating to 

344 # super().__init__ in the wrong place, or by forgetting to do that 

345 # entirely. 

346 instance.__init__(config=config) # type: ignore 

347 

348 # Derived-class implementations may have changed the contents of the 

349 # various kinds-of-connection sets; update allConnections to have keys 

350 # that are a union of all those. We get values for the new 

351 # allConnections from the attributes, since any dynamically added new 

352 # ones will not be present in the old allConnections. Typically those 

353 # getattrs will invoke the descriptors and get things from the old 

354 # allConnections anyway. After processing each set we replace it with 

355 # a frozenset. 

356 updated_all_connections = {} 

357 for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"): 

358 updated_connection_names = getattr(instance, attrName) 

359 updated_all_connections.update( 

360 {name: getattr(instance, name) for name in updated_connection_names} 

361 ) 

362 # Setting these to frozenset is at odds with the type annotation, 

363 # but MyPy can't tell because we're using setattr, and we want to 

364 # lie to it anyway to get runtime guards against post-__init__ 

365 # mutation. 

366 setattr(instance, attrName, frozenset(updated_connection_names)) 

367 # Update the existing dict in place, since we already have a view of 

368 # that. 

369 instance._allConnections.clear() 

370 instance._allConnections.update(updated_all_connections) 

371 

372 for obj in instance._allConnections.values(): 

373 if obj.deprecated is not None: 

374 warnings.warn( 

375 obj.deprecated, FutureWarning, stacklevel=find_outside_stacklevel("lsst.pipe.base") 

376 ) 

377 

378 # Freeze the connection instance dimensions now. This is at odds with 

379 # the type annotation, which says [mutable] `set`, just like the 

380 # connection type attributes (e.g. `inputs`, `outputs`, etc.), though 

381 # MyPy can't tell with those since we're using setattr for them. 

382 instance.dimensions = frozenset(instance.dimensions) # type: ignore 

383 

384 return instance 

385 
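# [Editor's example] After __call__ finishes, the per-kind name sets and
# ``dimensions`` are frozensets, so later mutation fails loudly. A comment
# sketch, since the classes involved are defined further down:
#
#     connections = ExampleConnections(config=config)
#     connections.dimensions.add("C")         # AttributeError: frozenset
#     connections.newOutput = someConnection  # TypeError: frozen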

386 

387class QuantizedConnection(SimpleNamespace): 

388 r"""A Namespace to map defined variable names of connections to the 

389 associated `lsst.daf.butler.DatasetRef` objects. 

390 

391 This class maps the names used to define a connection on a 

392 `PipelineTaskConnections` class to the corresponding 

393 `~lsst.daf.butler.DatasetRef`\s provided by a `~lsst.daf.butler.Quantum` 

394 instance. This will be a quantum of execution based on the graph created 

395 by examining all the connections defined on the 

396 `PipelineTaskConnections` class. 

397 """ 

398 

399 def __init__(self, **kwargs): 

400 # Create a variable to track what attributes are added. This is used 

401 # later when iterating over this QuantizedConnection instance 

402 object.__setattr__(self, "_attributes", set()) 

403 

404 def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None: 

405 # Capture the attribute name as it is added to this object 

406 self._attributes.add(name) 

407 super().__setattr__(name, value) 

408 

409 def __delattr__(self, name): 

410 object.__delattr__(self, name) 

411 self._attributes.remove(name) 

412 

413 def __len__(self) -> int: 

414 return len(self._attributes) 

415 

416 def __iter__( 

417 self, 

418 ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]: 

419 """Make an iterator for this `QuantizedConnection`. 

420 

421 Iterating over a `QuantizedConnection` will yield a tuple with the name 

422 of an attribute and the value associated with that name. This is 

423 similar to dict.items() but is on the namespace attributes rather than 

424 dict keys. 

425 """ 

426 yield from ((name, getattr(self, name)) for name in self._attributes) 

427 

428 def keys(self) -> Generator[str, None, None]: 

429 """Return an iterator over all the attributes added to a 

430 `QuantizedConnection` class 

431 """ 

432 yield from self._attributes 

433 
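# [Editor's example] Iteration yields (name, value) pairs, analogous to
# dict.items(); None below is a stand-in for a real DatasetRef.
def _example_quantized_connection() -> None:
    qc = QuantizedConnection()
    qc.calexp = None  # type: ignore[assignment]
    assert list(qc) == [("calexp", None)]
    assert list(qc.keys()) == ["calexp"]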

434 

435class InputQuantizedConnection(QuantizedConnection): 

436 """Input variant of a `QuantizedConnection`.""" 

437 

438 pass 

439 

440 

441class OutputQuantizedConnection(QuantizedConnection): 

442 """Output variant of a `QuantizedConnection`.""" 

443 

444 pass 

445 

446 

447@dataclass(frozen=True) 

448class DeferredDatasetRef: 

449 """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that a 

450 `PipelineTask` should receive a `~lsst.daf.butler.DeferredDatasetHandle` 

451 instead of an in-memory dataset. 

452 

453 Parameters 

454 ---------- 

455 datasetRef : `lsst.daf.butler.DatasetRef` 

456 The `lsst.daf.butler.DatasetRef` that will eventually be used to 

457 resolve a dataset. 

458 """ 

459 

460 datasetRef: DatasetRef 

461 

462 @property 

463 def datasetType(self) -> DatasetType: 

464 """The dataset type for this dataset.""" 

465 return self.datasetRef.datasetType 

466 

467 @property 

468 def dataId(self) -> DataCoordinate: 

469 """The data ID for this dataset.""" 

470 return self.datasetRef.dataId 

471 
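# [Editor's example] Wrapping a resolved ref so that the executor hands the
# task a DeferredDatasetHandle instead of loading the dataset eagerly; a
# sketch, with ``ref`` supplied by the caller.
def _example_defer(ref: DatasetRef) -> DeferredDatasetRef:
    deferred = DeferredDatasetRef(datasetRef=ref)
    # The wrapper forwards the dataset type and data ID of the wrapped ref.
    assert deferred.datasetType == ref.datasetType
    return deferred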

472 

473class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass): 

474 """PipelineTaskConnections is a class used to declare desired IO when a 

475 PipelineTask is run by an activator 

476 

477 Parameters 

478 ---------- 

479 config : `PipelineTaskConfig` 

480 A `PipelineTaskConfig` class instance whose class has been configured 

481 to use this `PipelineTaskConnections` class. 

482 

483 See Also 

484 -------- 

485 iterConnections 

486 

487 Notes 

488 ----- 

489 ``PipelineTaskConnection`` classes are created by declaring class 

490 attributes of types defined in `lsst.pipe.base.connectionTypes` and are 

491 listed as follows: 

492 

493 * ``InitInput`` - Defines connections in a quantum graph which are used as 

494 inputs to the ``__init__`` function of the `PipelineTask` corresponding 

495 to this class 

496 * ``InitOutput`` - Defines connections in a quantum graph which are to be 

497 persisted using a butler at the end of the ``__init__`` function of the 

498 `PipelineTask` corresponding to this class. The variable name used to 

499 define this connection should be the same as an attribute name on the 

500 `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with 

501 the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then 

502 a `PipelineTask` instance should have an attribute 

503 ``self.outputSchema`` defined. Its value is what will be saved by the 

504 activator framework. 

505 * ``PrerequisiteInput`` - An input connection type that defines a 

506 `lsst.daf.butler.DatasetType` that must be present at execution time, 

507 but that will not be used during the course of creating the quantum 

508 graph to be executed. These most often are things produced outside the 

509 processing pipeline, such as reference catalogs. 

510 * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be used 

511 in the ``run`` method of a `PipelineTask`. The name used to declare the 

512 class attribute must match a function argument name in the ``run`` 

513 method of a `PipelineTask`. E.g. If the ``PipelineTaskConnections`` 

514 defines an ``Input`` with the name ``calexp``, then the corresponding 

515 signature should be ``PipelineTask.run(calexp, ...)`` 

516 * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an 

517 execution of a `PipelineTask`. The name used to declare the connection 

518 must correspond to an attribute of a `Struct` that is returned by a 

519 `PipelineTask` ``run`` method. E.g. if an output connection is 

520 defined with the name ``measCat``, then the corresponding 

521 ``PipelineTask.run`` method must return ``Struct(measCat=X,..)`` where 

522 X matches the ``storageClass`` type defined on the output connection. 

523 

524 Attributes of these types can also be created, replaced, or deleted on the 

525 `PipelineTaskConnections` instance in the ``__init__`` method, if more than 

526 just the name depends on the configuration. It is preferred to define them 

527 in the class when possible (even if configuration may cause the connection 

528 to be removed from the instance). 

529 

530 The process of declaring a ``PipelineTaskConnections`` class involves 

531 parameters passed in the declaration statement. 

532 

533 The first parameter is ``dimensions``, an iterable of strings that 

534 defines the unit of processing the run method of a corresponding 

535 `PipelineTask` will operate on. These dimensions must match dimensions that 

536 exist in the butler registry which will be used in executing the 

537 corresponding `PipelineTask`. The dimensions may be also modified in 

538 subclass ``__init__`` methods if they need to depend on configuration. 

539 

540 The second parameter is labeled ``defaultTemplates`` and is conditionally 

541 optional. The name attributes of connections can be specified as python 

542 format strings, with named format arguments. If any of the name parameters 

543 on connections defined in a `PipelineTaskConnections` class contain a 

544 template, then a default template value must be specified in the 

545 ``defaultTemplates`` argument. This is done by passing a dictionary with 

546 keys corresponding to a template identifier, and values corresponding to 

547 the value to use as a default when formatting the string. For example if 

548 ``ConnectionsClass.calexp.name = '{input}Coadd_calexp'`` then 

549 ``defaultTemplates = {'input': 'deep'}``. 

550 

551 Once a `PipelineTaskConnections` class is created, it is used in the 

552 creation of a `PipelineTaskConfig`. This is further documented in the 

553 documentation of `PipelineTaskConfig`. For the purposes of this 

554 documentation, the relevant information is that the config class allows 

555 configuration of connection names by users when running a pipeline. 

556 

557 Instances of a `PipelineTaskConnections` class are used by the pipeline 

558 task execution framework to introspect what a corresponding `PipelineTask` 

559 will require, and what it will produce. 

560 

561 Examples 

562 -------- 

563 >>> from lsst.pipe.base import connectionTypes as cT 

564 >>> from lsst.pipe.base import PipelineTaskConnections 

565 >>> from lsst.pipe.base import PipelineTaskConfig 

566 >>> class ExampleConnections(PipelineTaskConnections, 

567 ... dimensions=("A", "B"), 

568 ... defaultTemplates={"foo": "Example"}): 

569 ... inputConnection = cT.Input(doc="Example input", 

570 ... dimensions=("A", "B"), 

571 ... storageClass="Exposure", 

572 ... name="{foo}Dataset") 

573 ... outputConnection = cT.Output(doc="Example output", 

574 ... dimensions=("A", "B"), 

575 ... storageClass="Exposure", 

576 ... name="{foo}output") 

577 >>> class ExampleConfig(PipelineTaskConfig, 

578 ... pipelineConnections=ExampleConnections): 

579 ... pass 

580 >>> config = ExampleConfig() 

581 >>> config.connections.foo = "Modified" 

582 >>> config.connections.outputConnection = "TotallyDifferent" 

583 >>> connections = ExampleConnections(config=config) 

584 >>> assert(connections.inputConnection.name == "ModifiedDataset") 

585 >>> assert(connections.outputConnection.name == "TotallyDifferent") 

586 """ 

587 

588 # We annotate these attributes as mutable sets because that's what they are 

589 # inside derived ``__init__`` implementations, and that's what matters most. 

590 # After that's done, the metaclass __call__ makes them into frozensets, but 

591 # relatively little code interacts with them then, and that code knows not 

592 # to try to modify them without having to be told that by mypy. 

593 

594 dimensions: set[str] 

595 """Set of dimension names that define the unit of work for this task. 

596 

597 Required and implied dependencies will automatically be expanded later and 

598 need not be provided. 

599 

600 This may be replaced or modified in ``__init__`` to change the dimensions 

601 of the task. After ``__init__`` it will be a `frozenset` and may not be 

602 replaced. 

603 """ 

604 

605 inputs: set[str] 

606 """Set with the names of all `connectionTypes.Input` connection attributes. 

607 

608 This is updated automatically as class attributes are added, removed, or 

609 replaced in ``__init__``. Removing entries from this set will cause those 

610 connections to be removed after ``__init__`` completes, but this is 

611 supported only for backwards compatibility; new code should instead just 

612 delete the connection attribute directly. After ``__init__`` this will be 

613 a `frozenset` and may not be replaced. 

614 """ 

615 

616 prerequisiteInputs: set[str] 

617 """Set with the names of all `~connectionTypes.PrerequisiteInput` 

618 connection attributes. 

619 

620 See `inputs` for additional information. 

621 """ 

622 

623 outputs: set[str] 

624 """Set with the names of all `~connectionTypes.Output` connection 

625 attributes. 

626 

627 See `inputs` for additional information. 

628 """ 

629 

630 initInputs: set[str] 

631 """Set with the names of all `~connectionTypes.InitInput` connection 

632 attributes. 

633 

634 See `inputs` for additional information. 

635 """ 

636 

637 initOutputs: set[str] 

638 """Set with the names of all `~connectionTypes.InitOutput` connection 

639 attributes. 

640 

641 See `inputs` for additional information. 

642 """ 

643 

644 allConnections: Mapping[str, BaseConnection] 

645 """Mapping holding all connection attributes. 

646 

647 This is a read-only view that is automatically updated when connection 

648 attributes are added, removed, or replaced in ``__init__``. It is also 

649 updated after ``__init__`` completes to reflect changes in `inputs`, 

650 `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`. 

651 """ 

652 

653 _allConnections: dict[str, BaseConnection] 

654 

655 def __init__(self, *, config: PipelineTaskConfig | None = None): 

656 pass 

657 

658 def __setattr__(self, name: str, value: Any) -> None: 

659 if isinstance(value, BaseConnection): 

660 previous = self._allConnections.get(name) 

661 try: 

662 getattr(self, value._connection_type_set).add(name) 

663 except AttributeError: 

664 # Attempt to call add on a frozenset, which is what these sets 

665 # are after __init__ is done. 

666 raise TypeError("Connections objects are frozen after construction.") from None 

667 if previous is not None and value._connection_type_set != previous._connection_type_set: 

668 # Connection has changed type, e.g. Input to PrerequisiteInput; 

669 # update the sets accordingly. To be extra defensive about 

670 # multiple assignments we use the type of the previous instance 

671 # instead of assuming that's the same as the type of the class 

672 # attribute, which is just the default. Use discard instead of remove so 

673 # manually removing from these sets first is never an error. 

674 getattr(self, previous._connection_type_set).discard(name) 

675 self._allConnections[name] = value 

676 if hasattr(self.__class__, name): 

677 # Don't actually set the attribute if this was a connection 

678 # declared in the class; in that case we let the descriptor 

679 # return the value we just added to allConnections. 

680 return 

681 # Actually add the attribute. 

682 super().__setattr__(name, value) 

683 

684 def __delattr__(self, name): 

685 """Descriptor delete method.""" 

686 previous = self._allConnections.get(name) 

687 if previous is not None: 

688 # Delete this connection's name from the appropriate set, which we 

689 # have to get from the previous instance instead of assuming it's 

690 # the same set that was appropriate for the class-level default. 

691 # Use discard instead of remove so manually removing from these 

692 # sets first is never an error. 

693 try: 

694 getattr(self, previous._connection_type_set).discard(name) 

695 except AttributeError: 

696 # Attempt to call discard on a frozenset, which is what these 

697 # sets are after __init__ is done. 

698 raise TypeError("Connections objects are frozen after construction.") from None 

699 del self._allConnections[name] 

700 if hasattr(self.__class__, name): 

701 # Don't actually delete the attribute if this was a connection 

702 # declared in the class; in that case we let the descriptor 

703 # see that it's no longer present in allConnections. 

704 return 

705 # Actually delete the attribute. 

706 super().__delattr__(name) 

707 

708 def buildDatasetRefs( 

709 self, quantum: Quantum 

710 ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]: 

711 """Build `QuantizedConnection` corresponding to input 

712 `~lsst.daf.butler.Quantum`. 

713 

714 Parameters 

715 ---------- 

716 quantum : `lsst.daf.butler.Quantum` 

717 Quantum object which defines the inputs and outputs for a given 

718 unit of processing. 

719 

720 Returns 

721 ------- 

722 retVal : `tuple` of (`InputQuantizedConnection`, 

723 `OutputQuantizedConnection`) 

724 Namespaces mapping attribute names (identifiers of connections) to 

725 butler references defined in the input `lsst.daf.butler.Quantum`. 

726 """ 

727 inputDatasetRefs = InputQuantizedConnection() 

728 outputDatasetRefs = OutputQuantizedConnection() 

729 # operate on a reference object and an iterable of names of class 

730 # connection attributes 

731 for refs, names in zip( 

732 (inputDatasetRefs, outputDatasetRefs), 

733 (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs), 

734 strict=True, 

735 ): 

736 # get a name of a class connection attribute 

737 for attributeName in names: 

738 # get the attribute identified by name 

739 attribute = getattr(self, attributeName) 

740 # Branch if the attribute dataset type is an input 

741 if attribute.name in quantum.inputs: 

742 # if the dataset is marked to load deferred, wrap it in a 

743 # DeferredDatasetRef 

744 quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef] 

745 if attribute.deferLoad: 

746 quantumInputRefs = [ 

747 DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name] 

748 ] 

749 else: 

750 quantumInputRefs = list(quantum.inputs[attribute.name]) 

751 # Unpack arguments that are not marked multiples (list of 

752 # length one) 

753 if not attribute.multiple: 

754 if len(quantumInputRefs) > 1: 

755 raise ScalarError( 

756 "Received multiple datasets " 

757 f"{', '.join(str(r.dataId) for r in quantumInputRefs)} " 

758 f"for scalar connection {attributeName} " 

759 f"({quantumInputRefs[0].datasetType.name}) " 

760 f"of quantum for {quantum.taskName} with data ID {quantum.dataId}." 

761 ) 

762 if len(quantumInputRefs) == 0: 

763 continue 

764 setattr(refs, attributeName, quantumInputRefs[0]) 

765 else: 

766 # Add to the QuantizedConnection identifier 

767 setattr(refs, attributeName, quantumInputRefs) 

768 # Branch if the attribute dataset type is an output 

769 elif attribute.name in quantum.outputs: 

770 value = quantum.outputs[attribute.name] 

771 # Unpack arguments that are not marked multiples (list of 

772 # length one) 

773 if not attribute.multiple: 

774 setattr(refs, attributeName, value[0]) 

775 else: 

776 setattr(refs, attributeName, value) 

777 # Specified attribute is not in inputs or outputs; don't know how 

778 # to handle it, so raise 

779 else: 

780 raise ValueError( 

781 f"Attribute with name {attributeName} has no counterpart in input quantum" 

782 ) 

783 return inputDatasetRefs, outputDatasetRefs 

784 
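# [Editor's example] Typical executor-side use, as a sketch; ``quantum``
# is the Quantum being executed and ``connections`` an instance of this
# class:
#
#     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
#     for attrName, ref in inputRefs:
#         ...  # fetch each input through the butler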

785 def adjustQuantum( 

786 self, 

787 inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]], 

788 outputs: dict[str, tuple[Output, Collection[DatasetRef]]], 

789 label: str, 

790 data_id: DataCoordinate, 

791 ) -> tuple[ 

792 Mapping[str, tuple[BaseInput, Collection[DatasetRef]]], 

793 Mapping[str, tuple[Output, Collection[DatasetRef]]], 

794 ]: 

795 """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects 

796 in the `lsst.daf.butler.Quantum` during the graph generation stage 

797 of the activator. 

798 

799 Parameters 

800 ---------- 

801 inputs : `dict` 

802 Dictionary whose keys are an input (regular or prerequisite) 

803 connection name and whose values are a tuple of the connection 

804 instance and a collection of associated 

805 `~lsst.daf.butler.DatasetRef` objects. 

806 The exact type of the nested collections is unspecified; it can be 

807 assumed to be multi-pass iterable and support `len` and ``in``, but 

808 it should not be mutated in place. In contrast, the outer 

809 dictionaries are guaranteed to be temporary copies that are true 

810 `dict` instances, and hence may be modified and even returned; this 

811 is especially useful for delegating to `super` (see notes below). 

812 outputs : `~collections.abc.Mapping` 

813 Mapping of output datasets, with the same structure as ``inputs``. 

814 label : `str` 

815 Label for this task in the pipeline (should be used in all 

816 diagnostic messages). 

817 data_id : `lsst.daf.butler.DataCoordinate` 

818 Data ID for this quantum in the pipeline (should be used in all 

819 diagnostic messages). 

820 

821 Returns 

822 ------- 

823 adjusted_inputs : `~collections.abc.Mapping` 

824 Mapping of the same form as ``inputs`` with updated containers of 

825 input `~lsst.daf.butler.DatasetRef` objects. Connections that are 

826 not changed should not be returned at all. Datasets may only be 

827 removed, not added. Nested collections may be of any multi-pass 

828 iterable type, and the order of iteration will set the order of 

829 iteration within `PipelineTask.runQuantum`. 

830 adjusted_outputs : `~collections.abc.Mapping` 

831 Mapping of updated output datasets, with the same structure and 

832 interpretation as ``adjusted_inputs``. 

833 

834 Raises 

835 ------ 

836 ScalarError 

837 Raised if any `Input` or `PrerequisiteInput` connection has 

838 ``multiple`` set to `False`, but multiple datasets. 

839 NoWorkFound 

840 Raised to indicate that this quantum should not be run; not enough 

841 datasets were found for a regular `Input` connection, and the 

842 quantum should be pruned or skipped. 

843 FileNotFoundError 

844 Raised to cause QuantumGraph generation to fail (with the message 

845 included in this exception); not enough datasets were found for a 

846 `PrerequisiteInput` connection. 

847 

848 Notes 

849 ----- 

850 The base class implementation performs important checks. It always 

851 returns an empty mapping (i.e. makes no adjustments). It should 

852 always be called via `super` by custom implementations, ideally at the 

853 end of the custom implementation with already-adjusted mappings when 

854 any datasets are actually dropped, e.g.: 

855 

856 .. code-block:: python 

857 

858 def adjustQuantum(self, inputs, outputs, label, data_id): 

859 # Filter out some dataset refs for one connection. 

860 connection, old_refs = inputs["my_input"] 

861 new_refs = [ref for ref in old_refs if ...] 

862 adjusted_inputs = {"my_input", (connection, new_refs)} 

863 # Update the original inputs so we can pass them to super. 

864 inputs.update(adjusted_inputs) 

865 # Can ignore outputs from super because they are guaranteed 

866 # to be empty. 

867 super().adjustQuantum(inputs, outputs, label, data_id) 

868 # Return only the connections we modified. 

869 return adjusted_inputs, {} 

870 

871 Removing outputs here is guaranteed to affect what is actually 

872 passed to `PipelineTask.runQuantum`, but its effect on the larger 

873 graph may be deferred to execution, depending on the context in 

874 which `adjustQuantum` is being run: if one quantum removes an output 

875 that is needed by a second quantum as input, the second quantum may not 

876 be adjusted (and hence pruned or skipped) until that output is actually 

877 found to be missing at execution time. 

878 

879 Tasks that desire zip-iteration consistency between any combinations of 

880 connections that have the same data ID should generally implement 

881 `adjustQuantum` to achieve this, even if they could also run that 

882 logic during execution; this allows the system to see outputs that will 

883 not be produced because the corresponding input is missing as early as 

884 possible. 

885 """ 

886 for name, (input_connection, refs) in inputs.items(): 

887 dataset_type_name = input_connection.name 

888 if not input_connection.multiple and len(refs) > 1: 

889 raise ScalarError( 

890 f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} " 

891 f"for non-multiple input connection {label}.{name} ({dataset_type_name}) " 

892 f"for quantum data ID {data_id}." 

893 ) 

894 if len(refs) < input_connection.minimum: 

895 if isinstance(input_connection, PrerequisiteInput): 

896 # This branch should only be possible during QG generation, 

897 # or if someone deleted the dataset between making the QG 

898 # and trying to run it. Either one should be a hard error. 

899 raise FileNotFoundError( 

900 f"Not enough datasets ({len(refs)}) found for non-optional connection {label}.{name} " 

901 f"({dataset_type_name}) with minimum={input_connection.minimum} for quantum data ID " 

902 f"{data_id}." 

903 ) 

904 else: 

905 # This branch should be impossible during QG generation, 

906 # because that algorithm can only make quanta whose inputs 

907 # are either already present or should be created during 

908 # execution. It can trigger during execution if the input 

909 # wasn't actually created by an upstream task in the same 

910 # graph. 

911 raise NoWorkFound(label, name, input_connection) 

912 for name, (output_connection, refs) in outputs.items(): 

913 dataset_type_name = output_connection.name 

914 if not output_connection.multiple and len(refs) > 1: 

915 raise ScalarError( 

916 f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} " 

917 f"for non-multiple output connection {label}.{name} ({dataset_type_name}) " 

918 f"for quantum data ID {data_id}." 

919 ) 

920 return {}, {} 

921 

922 
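# [Editor's example] A sketch of a subclass using adjustQuantum to skip a
# quantum when an optional input comes up empty; the connection name and
# empty dimensions are invented for illustration.
class _ExampleAdjustConnections(PipelineTaskConnections, dimensions=()):
    def adjustQuantum(self, inputs, outputs, label, data_id):
        connection, refs = inputs["maybe_input"]
        if not refs:
            # Mirror the base class: signal that this quantum should be
            # pruned or skipped rather than run with nothing to do.
            raise NoWorkFound(label, "maybe_input", connection)
        return super().adjustQuantum(inputs, outputs, label, data_id)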

923def iterConnections( 

924 connections: PipelineTaskConnections, connectionType: str | Iterable[str] 

925) -> Generator[BaseConnection, None, None]: 

926 """Create an iterator over the selected connections type which yields 

927 all the defined connections of that type. 

928 

929 Parameters 

930 ---------- 

931 connections : `PipelineTaskConnections` 

932 An instance of a `PipelineTaskConnections` object that will be iterated 

933 over. 

934 connectionType : `str` or `~collections.abc.Iterable` of `str` 

935 The types of connections to iterate over; valid values are inputs, 

936 outputs, prerequisiteInputs, initInputs, and initOutputs. 

937 

938 Yields 

939 ------ 

940 connection : `~.connectionTypes.BaseConnection` 

941 A connection defined on the input connections object of the type 

942 supplied. The yielded value will be a derived type of 

943 `~.connectionTypes.BaseConnection`. 

944 """ 

945 if isinstance(connectionType, str): 

946 connectionType = (connectionType,) 

947 for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType): 

948 yield getattr(connections, name) 

949 
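# [Editor's example] Collecting the dataset type names of all input-like
# connections; ``connections`` is any PipelineTaskConnections instance.
def _example_input_names(connections: PipelineTaskConnections) -> list[str]:
    return [c.name for c in iterConnections(connections, ("inputs", "prerequisiteInputs"))]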

950 

951@dataclass 

952class AdjustQuantumHelper: 

953 """Helper class for calling `PipelineTaskConnections.adjustQuantum`. 

954 

955 This class holds `inputs` and `outputs` mappings in the form used by 

956 `Quantum` and execution harness code, i.e. with 

957 `~lsst.daf.butler.DatasetType` keys, translating them to and from the 

958 connection-oriented mappings used inside `PipelineTaskConnections`. 

959 """ 

960 

961 inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]] 

962 """Mapping of regular input and prerequisite input datasets, grouped by 

963 `~lsst.daf.butler.DatasetType`. 

964 """ 

965 

966 outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]] 

967 """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`. 

968 """ 

969 

970 inputs_adjusted: bool = False 

971 """Whether any inputs were removed in the last call to `adjust_in_place`. 

972 """ 

973 

974 outputs_adjusted: bool = False 

975 """Whether any outputs were removed in the last call to `adjust_in_place`. 

976 """ 

977 

978 def adjust_in_place( 

979 self, 

980 connections: PipelineTaskConnections, 

981 label: str, 

982 data_id: DataCoordinate, 

983 ) -> None: 

984 """Call `~PipelineTaskConnections.adjustQuantum` and update ``self`` 

985 with its results. 

986 

987 Parameters 

988 ---------- 

989 connections : `PipelineTaskConnections` 

990 Instance on which to call `~PipelineTaskConnections.adjustQuantum`. 

991 label : `str` 

992 Label for this task in the pipeline (should be used in all 

993 diagnostic messages). 

994 data_id : `lsst.daf.butler.DataCoordinate` 

995 Data ID for this quantum in the pipeline (should be used in all 

996 diagnostic messages). 

997 """ 

998 # Translate self's DatasetType-keyed, Quantum-oriented mappings into 

999 # connection-keyed, PipelineTask-oriented mappings. 

1000 inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {} 

1001 outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {} 

1002 for name in itertools.chain(connections.inputs, connections.prerequisiteInputs): 

1003 connection = getattr(connections, name) 

1004 dataset_type_name = connection.name 

1005 inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ()))) 

1006 for name in connections.outputs: 

1007 connection = getattr(connections, name) 

1008 dataset_type_name = connection.name 

1009 outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ()))) 

1010 # Actually call adjustQuantum. 

1011 # MyPy correctly complains that this call is not quite legal, but the 

1012 # method docs explain exactly what's expected and it's the behavior we 

1013 # want. It'd be nice to avoid this if we ever have to change the 

1014 # interface anyway, but not an immediate problem. 

1015 adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum( 

1016 inputs_by_connection, # type: ignore 

1017 outputs_by_connection, # type: ignore 

1018 label, 

1019 data_id, 

1020 ) 

1021 # Translate adjustments to DatasetType-keyed, Quantum-oriented form, 

1022 # installing new mappings in self if necessary. 

1023 if adjusted_inputs_by_connection: 

1024 adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs) 

1025 for name, (connection, updated_refs) in adjusted_inputs_by_connection.items(): 

1026 dataset_type_name = connection.name 

1027 if not set(updated_refs).issubset(self.inputs[dataset_type_name]): 

1028 raise RuntimeError( 

1029 f"adjustQuantum implementation for task with label {label} returned {name} " 

1030 f"({dataset_type_name}) input datasets that are not a subset of those " 

1031 f"it was given for data ID {data_id}." 

1032 ) 

1033 adjusted_inputs[dataset_type_name] = tuple(updated_refs) 

1034 self.inputs = adjusted_inputs.freeze() 

1035 self.inputs_adjusted = True 

1036 else: 

1037 self.inputs_adjusted = False 

1038 if adjusted_outputs_by_connection: 

1039 adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs) 

1040 for name, (connection, updated_refs) in adjusted_outputs_by_connection.items(): 

1041 dataset_type_name = connection.name 

1042 if not set(updated_refs).issubset(self.outputs[dataset_type_name]): 

1043 raise RuntimeError( 

1044 f"adjustQuantum implementation for task with label {label} returned {name} " 

1045 f"({dataset_type_name}) output datasets that are not a subset of those " 

1046 f"it was given for data ID {data_id}." 

1047 ) 

1048 adjusted_outputs[dataset_type_name] = tuple(updated_refs) 

1049 self.outputs = adjusted_outputs.freeze() 

1050 self.outputs_adjusted = True 

1051 else: 

1052 self.outputs_adjusted = False
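# [Editor's example] Harness-side sketch: seed the helper from a quantum's
# DatasetType-keyed mappings, then let the task's connections adjust them
# in place. ``quantum.dataId`` is assumed to be resolved (not None) here.
def _example_adjust_helper(
    quantum: Quantum, connections: PipelineTaskConnections, label: str
) -> AdjustQuantumHelper:
    helper = AdjustQuantumHelper(inputs=quantum.inputs, outputs=quantum.outputs)
    helper.adjust_in_place(connections, label, quantum.dataId)
    return helper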