# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask."""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

import dataclasses
import itertools
import string
import warnings
from collections import UserDict
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

from ._status import NoWorkFound
from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput

if TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by `PipelineTaskConnectionMetaclass`.

    This dict is used in `PipelineTaskConnection` class creation, as the
    dictionary that is initially used as ``__dict__``. It exists to
    intercept connection fields declared in a `PipelineTaskConnection` and
    track the names used to identify them. The names are then added to a
    class level list according to the connection type of the class
    attribute. The names are also used as keys in a class level dictionary
    associated with the corresponding class attribute. This information is a
    duplicate of what exists in ``__dict__``, but provides a simple place to
    look up and iterate over only these variables.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection.
        self.data["inputs"] = set()
        self.data["prerequisiteInputs"] = set()
        self.data["outputs"] = set()
        self.data["initInputs"] = set()
        self.data["initOutputs"] = set()
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            if name in {
                "dimensions",
                "inputs",
                "prerequisiteInputs",
                "outputs",
                "initInputs",
                "initOutputs",
                "allConnections",
            }:
                # Guard against connections whose names are reserved.
                raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
            if (previous := self.data.get(name)) is not None:
                # Guard against changing the type of an inherited connection
                # by first removing it from the set it's currently in.
                self.data[previous._connection_type_set].discard(name)
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
            self.data[value._connection_type_set].add(name)
        # defer to the default behavior
        super().__setitem__(name, value)
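
# A minimal sketch of how the dict above routes declarations during class
# creation (the "calexp" connection and the ``cT`` alias are illustrative,
# not part of this module):
#
#     from lsst.pipe.base import connectionTypes as cT
#
#     dct = PipelineTaskConnectionDict()
#     dct["calexp"] = cT.Input(doc="", name="calexp", storageClass="ExposureF")
#     assert "calexp" in dct["inputs"]
#     assert dct["allConnections"]["calexp"].varName == "calexp"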


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes."""

    # We can annotate these attributes as `collections.abc.Set` to discourage
    # undesirable modifications in type-checked code, since the internal code
    # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
    # these annotations anyway.

    dimensions: Set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    inputs: Set[str]
    """Set with the names of all `~connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added. Note that
    this attribute is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    prerequisiteInputs: Set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: Set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: Set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: Set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping containing all connection attributes.

    See `inputs` for additional information.
    """

    def __prepare__(name, bases, **kwargs):  # noqa: N804
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in the class
            # declaration.
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be an iterable of dimensions, got str; "
                        "a trailing comma may have been omitted"
                    )
                if not isinstance(kwargs["dimensions"], Iterable):
                    raise TypeError("Dimensions must be an iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute.
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together.
            mergeDict = {}
            mergeDeprecationsDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
                if hasattr(base, "deprecatedTemplates"):
                    mergeDeprecationsDict.update(base.deprecatedTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])
            if "deprecatedTemplates" in kwargs:
                mergeDeprecationsDict.update(kwargs["deprecatedTemplates"])
            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict
            if len(mergeDeprecationsDict) > 0:
                kwargs["deprecatedTemplates"] = mergeDeprecationsDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class.
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnection class contains templated attribute names, but no "
                    "default templates were provided; add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not.
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with Class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})
            dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope.
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # Our custom dict type must be turned into an actual dict to be used
        # in type.__new__.
        return super().__new__(cls, name, bases, dict(dct))
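
    # A minimal sketch of the template extraction performed above (the
    # dataset type name "{coaddName}Coadd_meas" is illustrative):
    #
    #     >>> fmt = string.Formatter()
    #     >>> [p[1] for p in fmt.parse("{coaddName}Coadd_meas") if p[1] is not None]
    #     ['coaddName']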

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses.
        super().__init__(name, bases, dct)

    def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
        # MyPy appears not to really understand metaclass.__call__ at all, so
        # we need to tell it to ignore __new__ and __init__ calls here.
        instance: PipelineTaskConnections = cls.__new__(cls)  # type: ignore

        # Make mutable copies of all set-like class attributes so derived
        # __init__ implementations can modify them in place.
        instance.dimensions = set(cls.dimensions)
        instance.inputs = set(cls.inputs)
        instance.prerequisiteInputs = set(cls.prerequisiteInputs)
        instance.outputs = set(cls.outputs)
        instance.initInputs = set(cls.initInputs)
        instance.initOutputs = set(cls.initOutputs)

        # Set self.config. It's a bit strange that we claim to accept None but
        # really just raise here, but it's not worth changing now.
        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        instance.config = config

        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time.
        templateValues = {
            name: getattr(config.connections, name) for name in cls.defaultTemplates  # type: ignore
        }

        # We now assemble a mapping of all connection instances keyed by
        # internal name, applying the configuration and templates to make new
        # configurations from the class-attribute defaults. This will be
        # private, but with a public read-only view. This mapping is what the
        # descriptor interface of the class-level attributes will return when
        # they are accessed on an instance. This is better than just assigning
        # regular instance attributes as it makes it so removed connections
        # cannot be accessed on instances, instead of having access to them
        # silently fall through to the not-removed class connection instance.
        instance._allConnections = {}
        instance.allConnections = MappingProxyType(instance._allConnections)
        for internal_name, connection in cls.allConnections.items():
            dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
            instance_connection = dataclasses.replace(
                connection,
                name=dataset_type_name,
                doc=(
                    connection.doc
                    if connection.deprecated is None
                    else f"{connection.doc}\n{connection.deprecated}"
                ),
                _deprecation_context=connection._deprecation_context,
            )
            instance._allConnections[internal_name] = instance_connection

        # Finally call __init__. The base class implementation does nothing;
        # we could have left some of the above implementation there (where it
        # originated), but putting it here instead makes it hard for derived
        # class implementors to get things into a weird state by delegating to
        # super().__init__ in the wrong place, or by forgetting to do that
        # entirely.
        instance.__init__(config=config)  # type: ignore

        # Derived-class implementations may have changed the contents of the
        # various kinds-of-connection sets; update allConnections to have keys
        # that are a union of all those. We get values for the new
        # allConnections from the attributes, since any dynamically added new
        # ones will not be present in the old allConnections. Typically those
        # getattrs will invoke the descriptors and get things from the old
        # allConnections anyway. After processing each set we replace it with
        # a frozenset.
        updated_all_connections = {}
        for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
            updated_connection_names = getattr(instance, attrName)
            updated_all_connections.update(
                {name: getattr(instance, name) for name in updated_connection_names}
            )
            # Setting these to frozenset is at odds with the type annotation,
            # but MyPy can't tell because we're using setattr, and we want to
            # lie to it anyway to get runtime guards against post-__init__
            # mutation.
            setattr(instance, attrName, frozenset(updated_connection_names))
        # Update the existing dict in place, since we already have a view of
        # that.
        instance._allConnections.clear()
        instance._allConnections.update(updated_all_connections)

        for connection_name, obj in instance._allConnections.items():
            if obj.deprecated is not None:
                warnings.warn(
                    f"Connection {connection_name} with datasetType {obj.name} "
                    f"(from {obj._deprecation_context}): {obj.deprecated}",
                    FutureWarning,
                    stacklevel=1,  # Report from this location.
                )

        # Freeze the connection instance dimensions now. This is at odds with
        # the type annotation, which says [mutable] `set`, just like the
        # connection type attributes (e.g. `inputs`, `outputs`, etc.), though
        # MyPy can't tell with those since we're using setattr for them.
        instance.dimensions = frozenset(instance.dimensions)  # type: ignore

        return instance


class QuantizedConnection(SimpleNamespace):
    r"""A Namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `~lsst.daf.butler.DatasetRef`\s provided by a `~lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnections` class.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance.
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]:
        """Make an iterator for this `QuantizedConnection`.

        Iterating over a `QuantizedConnection` will yield a tuple with the
        name of an attribute and the value associated with that name. This is
        similar to `dict.items()` but is on the namespace attributes rather
        than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        `QuantizedConnection` instance.
        """
        yield from self._attributes
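
# A minimal usage sketch for QuantizedConnection (the attribute name and the
# `ref` object, a resolved DatasetRef, are illustrative):
#
#     qc = QuantizedConnection()
#     qc.calexp = ref            # one DatasetRef for a scalar connection
#     for name, value in qc:     # iterates like dict.items()
#         ...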


class InputQuantizedConnection(QuantizedConnection):
    """Input variant of a `QuantizedConnection`."""

    pass


class OutputQuantizedConnection(QuantizedConnection):
    """Output variant of a `QuantizedConnection`."""

    pass


@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that a
    `PipelineTask` should receive a `~lsst.daf.butler.DeferredDatasetHandle`
    instead of an in-memory dataset.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
    """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId
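
# A minimal sketch (the `ref` object stands in for a resolved
# lsst.daf.butler.DatasetRef; illustrative only):
#
#     deferred = DeferredDatasetRef(datasetRef=ref)
#     assert deferred.dataId == ref.dataId  # properties forward to the ref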


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to declare
      the class attribute must match a function argument name in the ``run``
      method of a `PipelineTask`. E.g. if the ``PipelineTaskConnections``
      defines an ``Input`` with the name ``calexp``, then the corresponding
      signature should be ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
      execution of a `PipelineTask`. The name used to declare the connection
      must correspond to an attribute of a `Struct` that is returned by a
      `PipelineTask` ``run`` method. E.g. if an output connection is
      defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)`` where
      X matches the ``storageClass`` type defined on the output connection.

    Attributes of these types can also be created, replaced, or deleted on the
    `PipelineTaskConnections` instance in the ``__init__`` method, if more
    than just the name depends on the configuration. It is preferred to define
    them in the class when possible (even if configuration may cause the
    connection to be removed from the instance).

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings that
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`. The dimensions may also be modified in
    subclass ``__init__`` methods if they need to depend on configuration.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name parameters
    on connections defined in a `PipelineTaskConnections` class contain a
    template, then a default template value must be specified in the
    ``defaultTemplates`` argument. This is done by passing a dictionary with
    keys corresponding to a template identifier, and values corresponding to
    the value to use as a default when formatting the string. For example if
    ``ConnectionsClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding `PipelineTask`
    will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    # We annotate these attributes as mutable sets because that's what they
    # are inside derived ``__init__`` implementations and that's what matters
    # most. After that's done, the metaclass __call__ makes them into
    # frozensets, but relatively little code interacts with them then, and
    # that code knows not to try to modify them without having to be told
    # that by mypy.

    dimensions: set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This may be replaced or modified in ``__init__`` to change the dimensions
    of the task. After ``__init__`` it will be a `frozenset` and may not be
    replaced.
    """

    inputs: set[str]
    """Set with the names of all `connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added, removed, or
    replaced in ``__init__``. Removing entries from this set will cause those
    connections to be removed after ``__init__`` completes, but this is
    supported only for backwards compatibility; new code should instead just
    delete the connection attribute directly. After ``__init__`` this will be
    a `frozenset` and may not be replaced.
    """

    prerequisiteInputs: set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping holding all connection attributes.

    This is a read-only view that is automatically updated when connection
    attributes are added, removed, or replaced in ``__init__``. It is also
    updated after ``__init__`` completes to reflect changes in `inputs`,
    `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
    """

    _allConnections: dict[str, BaseConnection]

    def __init__(self, *, config: PipelineTaskConfig | None = None):
        pass

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            previous = self._allConnections.get(name)
            try:
                getattr(self, value._connection_type_set).add(name)
            except AttributeError:
                # Attempt to call add on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            if previous is not None and value._connection_type_set != previous._connection_type_set:
                # Connection has changed type, e.g. Input to
                # PrerequisiteInput; update the sets accordingly. To be extra
                # defensive about multiple assignments we use the type of the
                # previous instance instead of assuming that's the same as
                # the type of self, which is just the default. Use discard
                # instead of remove so manually removing from these sets
                # first is never an error.
                getattr(self, previous._connection_type_set).discard(name)
            self._allConnections[name] = value
            if hasattr(self.__class__, name):
                # Don't actually set the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # return the value we just added to allConnections.
                return
        # Actually add the attribute.
        super().__setattr__(name, value)

    def __delattr__(self, name):
        """Descriptor delete method."""
        previous = self._allConnections.get(name)
        if previous is not None:
            # Delete this connection's name from the appropriate set, which
            # we have to get from the previous instance instead of assuming
            # it's the same set that was appropriate for the class-level
            # default. Use discard instead of remove so manually removing
            # from these sets first is never an error.
            try:
                getattr(self, previous._connection_type_set).discard(name)
            except AttributeError:
                # Attempt to call discard on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            del self._allConnections[name]
            if hasattr(self.__class__, name):
                # Don't actually delete the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # see that it's no longer present in allConnections.
                return
        # Actually delete the attribute.
        super().__delattr__(name)
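
    # A minimal sketch of the dynamic add/remove support implemented above
    # (class, config field, and connection names are illustrative):
    #
    #     class MyConnections(PipelineTaskConnections, dimensions=("visit",)):
    #         calexp = cT.Input(doc="", name="calexp", storageClass="ExposureF",
    #                           dimensions=("instrument", "visit", "detector"))
    #
    #         def __init__(self, *, config=None):
    #             if not config.useCalexp:  # hypothetical config field
    #                 del self.calexp  # also removes it from self.inputs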

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build `QuantizedConnection` objects corresponding to the input
        `~lsst.daf.butler.Quantum`.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
            `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip(
            (inputDatasetRefs, outputDatasetRefs),
            (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
            strict=True,
        ):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
                    if attribute.deferLoad:
                        quantumInputRefs = [
                            DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                        ]
                    else:
                        quantumInputRefs = list(quantum.inputs[attribute.name])
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                "Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        setattr(refs, attributeName, quantumInputRefs[0])
                    else:
                        # Add to the QuantizedConnection identifier
                        setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        setattr(refs, attributeName, value[0])
                    else:
                        setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; don't know
                # how to handle it, so throw.
                else:
                    raise ValueError(
                        f"Attribute with name {attributeName} has no counterpart in input quantum"
                    )
        return inputDatasetRefs, outputDatasetRefs
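
    # A minimal usage sketch for buildDatasetRefs (the `quantum` object and
    # connection name are illustrative):
    #
    #     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #     calexp_ref = inputRefs.calexp  # one attribute per connection name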

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated
            `~lsst.daf.butler.DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and ``in``,
            but it should not be mutated in place. In contrast, the outer
            dictionaries are guaranteed to be temporary copies that are true
            `dict` instances, and hence may be modified and even returned;
            this is especially useful for delegating to `super` (see notes
            below).
        outputs : `~collections.abc.Mapping`
            Mapping of output datasets, with the same structure as ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `~collections.abc.Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `~lsst.daf.butler.DatasetRef` objects. Connections that are
            not changed should not be returned at all. Datasets may only be
            removed, not added. Nested collections may be of any multi-pass
            iterable type, and the order of iteration will set the order of
            iteration within `PipelineTask.runQuantum`.
        adjusted_outputs : `~collections.abc.Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection, and
            the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the message
            included in this exception); not enough datasets were found for a
            `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.:

        .. code-block:: python

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={input_connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        """Return the names of regular input and output connections whose
        data IDs should be used to compute the spatial bounds of this task's
        quanta.

        The spatial bound for a quantum is defined as the union of the
        regions of all data IDs of all connections returned here, along with
        the region of the quantum data ID (if the task has spatial
        dimensions).

        Returns
        -------
        connection_names : `collections.abc.Iterable` [ `str` ]
            Names of connections with spatial dimensions. These are the
            task-internal connection names, not butler dataset type names.

        Notes
        -----
        The spatial bound is used to search for prerequisite inputs that have
        skypix dimensions. The default implementation returns an empty
        iterable, which is usually sufficient for tasks with spatial
        dimensions, but if a task's inputs or outputs are associated with
        spatial regions that extend beyond the quantum data ID's region, this
        method may need to be overridden to expand the set of prerequisite
        inputs found.

        Tasks that do not have spatial dimensions but do have skypix
        prerequisite inputs should always override this method, as the
        default spatial bounds would otherwise cover the full sky.
        """
        return ()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        """Return the names of regular input and output connections whose
        data IDs should be used to compute the temporal bounds of this task's
        quanta.

        The temporal bound for a quantum is defined as the union of the
        timespans of all data IDs of all connections returned here, along
        with the timespan of the quantum data ID (if the task has temporal
        dimensions).

        Returns
        -------
        connection_names : `collections.abc.Iterable` [ `str` ]
            Names of connections with temporal dimensions. These are the
            task-internal connection names, not butler dataset type names.

        Notes
        -----
        The temporal bound is used to search for prerequisite inputs that are
        calibration datasets. The default implementation returns an empty
        iterable, which is usually sufficient for tasks with temporal
        dimensions, but if a task's inputs or outputs are associated with
        timespans that extend beyond the quantum data ID's timespan, this
        method may need to be overridden to expand the set of prerequisite
        inputs found.

        Tasks that do not have temporal dimensions and do not implement this
        method will use an infinite timespan for any calibration lookups.
        """
        return ()
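
    # A minimal sketch of overriding these bounds hooks in a subclass (the
    # connection name is illustrative):
    #
    #     def getSpatialBoundsConnections(self):
    #         return ("inputCatalogs",)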


def iterConnections(
    connections: PipelineTaskConnections, connectionType: str | Iterable[str]
) -> Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connections type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or `~collections.abc.Iterable` [ `str` ]
        The type (or types) of connections to iterate over; valid values are
        ``inputs``, ``outputs``, ``prerequisiteInputs``, ``initInputs``, and
        ``initOutputs``.

    Yields
    ------
    connection : `~.connectionTypes.BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `~.connectionTypes.BaseConnection`.
    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)
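
# A minimal usage sketch (the `connections` instance is illustrative):
#
#     for connection in iterConnections(connections, ("inputs", "outputs")):
#         print(connection.name)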


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds ``inputs`` and ``outputs`` mappings in the form used by
    `Quantum` and execution harness code, i.e. with
    `~lsst.daf.butler.DatasetType` keys, translating them to and from the
    connection-oriented mappings used inside `PipelineTaskConnections`.
    """

    inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `~lsst.daf.butler.DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in itertools.chain(connections.outputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior
        # we want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = tuple(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = tuple(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False
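
    # A minimal usage sketch (the quantum, connections, and label values are
    # illustrative):
    #
    #     helper = AdjustQuantumHelper(inputs=quantum.inputs, outputs=quantum.outputs)
    #     helper.adjust_in_place(connections, label="makeWarp", data_id=quantum.dataId)
    #     if helper.inputs_adjusted or helper.outputs_adjusted:
    #         ...  # rebuild the Quantum from helper.inputs / helper.outputs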