Coverage for python/lsst/pipe/base/connections.py: 42%

288 statements  

coverage.py v7.2.7, created at 2023-06-11 02:00 -0700

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask."""

from __future__ import annotations

__all__ = [
    "AdjustQuantumHelper",
    "DeferredDatasetRef",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "PipelineTaskConnections",
    "ScalarError",
    "iterConnections",
]

import dataclasses
import itertools
import string
from collections import UserDict
from collections.abc import Collection, Generator, Iterable, Mapping, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum

from ._status import NoWorkFound
from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput

if TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """A special dict class used by `PipelineTaskConnectionsMetaclass`.

    This dict is used in `PipelineTaskConnections` class creation as the
    dictionary that is initially used as ``__dict__``. It exists to
    intercept connection fields declared in a `PipelineTaskConnections`
    class and to record the names used to identify them. The names are
    added to a class-level set according to the connection type of the
    class attribute, and are also used as keys in a class-level dictionary
    associated with the corresponding class attribute. This information
    duplicates what exists in ``__dict__``, but provides a simple place to
    look up and iterate over only these variables.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection.
        self.data["inputs"] = set()
        self.data["prerequisiteInputs"] = set()
        self.data["outputs"] = set()
        self.data["initInputs"] = set()
        self.data["initOutputs"] = set()
        self.data["allConnections"] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            if name in {
                "dimensions",
                "inputs",
                "prerequisiteInputs",
                "outputs",
                "initInputs",
                "initOutputs",
                "allConnections",
            }:
                # Guard against connections whose names are reserved.
                raise AttributeError(f"Connection name {name!r} is reserved for internal use.")
            if (previous := self.data.get(name)) is not None:
                # Guard against changing the type of an inherited connection
                # by first removing it from the set it's currently in.
                self.data[previous._connection_type_set].discard(name)
            object.__setattr__(value, "varName", name)
            self.data["allConnections"][name] = value
            self.data[value._connection_type_set].add(name)
        # Defer to the default behavior.
        super().__setitem__(name, value)
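
# A minimal sketch of what the tracking above does (editor's illustration;
# in real use the metaclass supplies this dict via ``__prepare__``, so it is
# never constructed by hand):
#
#     >>> from lsst.pipe.base import connectionTypes as cT
#     >>> dct = PipelineTaskConnectionDict()
#     >>> dct["calexp"] = cT.Input(
#     ...     doc="example", name="calexp", storageClass="ExposureF"
#     ... )
#     >>> dct["inputs"]
#     {'calexp'}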


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes."""

    # We can annotate these attributes as `collections.abc.Set` to discourage
    # undesirable modifications in type-checked code, since the internal code
    # modifying them is in `PipelineTaskConnectionDict` and that doesn't see
    # these annotations anyway.

    dimensions: Set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later and
    need not be provided.

    This is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    inputs: Set[str]
    """Set with the names of all `~connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added. Note that
    this attribute is shadowed by an instance-level attribute on
    `PipelineTaskConnections` instances.
    """

    prerequisiteInputs: Set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: Set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: Set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: Set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping containing all connection attributes.

    See `inputs` for additional information.
    """

    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError(
            "PipelineTaskConnections class must be created with a dimensions "
            "attribute which is an iterable of dimension names"
        )

        if name != "PipelineTaskConnections":
            # Verify that dimensions are passed as a keyword in the class
            # declaration.
            if "dimensions" not in kwargs:
                for base in bases:
                    if hasattr(base, "dimensions"):
                        kwargs["dimensions"] = base.dimensions
                        break
                if "dimensions" not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs["dimensions"], str):
                    raise TypeError(
                        "Dimensions must be an iterable of dimensions, got str; "
                        "possibly an omitted trailing comma"
                    )
                if not isinstance(kwargs["dimensions"], Iterable):
                    raise TypeError("Dimensions must be an iterable of dimensions")
                dct["dimensions"] = set(kwargs["dimensions"])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute.
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections.
            for obj in dct["allConnections"].values():
                nameValue = obj.name
                # Add all the parameters to the set of templates.
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Look up any templates from base classes and merge them all
            # together.
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, "defaultTemplates"):
                    mergeDict.update(base.defaultTemplates)
            if "defaultTemplates" in kwargs:
                mergeDict.update(kwargs["defaultTemplates"])

            if len(mergeDict) > 0:
                kwargs["defaultTemplates"] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class.
            if len(allTemplates) > 0 and "defaultTemplates" not in kwargs:
                raise TypeError(
                    "PipelineTaskConnection class contains templated attribute names, but no "
                    "default templates were provided; add a dictionary attribute named "
                    "defaultTemplates which contains the mapping between template key and value"
                )
            if len(allTemplates) > 0:
                # Verify all templates have a default, and raise if they do
                # not.
                defaultTemplateKeys = set(kwargs["defaultTemplates"].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct["allConnections"].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(
                        "Template parameters cannot share names with class attributes"
                        f" (conflicts are {nameTemplateIntersection})."
                    )
            dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope.
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # Our custom dict type must be turned into an actual dict to be used
        # in type.__new__.
        return super().__new__(cls, name, bases, dict(dct))
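
    # For reference, a sketch of the template extraction performed in
    # __new__ above: ``string.Formatter.parse`` yields tuples of
    # (literal_text, field_name, format_spec, conversion), and the non-None
    # field names are what end up in ``allTemplates``:
    #
    #     >>> import string
    #     >>> fmt = string.Formatter()
    #     >>> [p[1] for p in fmt.parse("{input}Coadd_calexp") if p[1] is not None]
    #     ['input']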

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with the
        # Python documentation on metaclasses.
        super().__init__(name, bases, dct)

    def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskConnections:
        # MyPy appears not to really understand metaclass.__call__ at all, so
        # we need to tell it to ignore __new__ and __init__ calls here.
        instance: PipelineTaskConnections = cls.__new__(cls)  # type: ignore

        # Make mutable copies of all set-like class attributes so derived
        # __init__ implementations can modify them in place.
        instance.dimensions = set(cls.dimensions)
        instance.inputs = set(cls.inputs)
        instance.prerequisiteInputs = set(cls.prerequisiteInputs)
        instance.outputs = set(cls.outputs)
        instance.initInputs = set(cls.initInputs)
        instance.initOutputs = set(cls.initOutputs)

        # Set self.config. It's a bit strange that we claim to accept None
        # but really just raise here, but it's not worth changing now.
        from .config import PipelineTaskConfig  # local import to avoid cycle

        if config is None or not isinstance(config, PipelineTaskConfig):
            raise ValueError(
                "PipelineTaskConnections must be instantiated with a PipelineTaskConfig instance"
            )
        instance.config = config

        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time.
        templateValues = {
            name: getattr(config.connections, name) for name in getattr(cls, "defaultTemplates").keys()
        }

        # We now assemble a mapping of all connection instances keyed by
        # internal name, applying the configuration and templates to make new
        # connections from the class-attribute defaults. This will be
        # private, but with a public read-only view. This mapping is what the
        # descriptor interface of the class-level attributes will return when
        # they are accessed on an instance. This is better than just
        # assigning regular instance attributes, as it ensures that removed
        # connections cannot be accessed on instances, instead of having
        # access to them silently fall through to the not-removed class
        # connection instance.
        instance._allConnections = {}
        instance.allConnections = MappingProxyType(instance._allConnections)
        for internal_name, connection in cls.allConnections.items():
            dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
            instance_connection = dataclasses.replace(connection, name=dataset_type_name)
            instance._allConnections[internal_name] = instance_connection

        # Finally call __init__. The base class implementation does nothing;
        # we could have left some of the above implementation there (where it
        # originated), but putting it here instead makes it hard for derived
        # class implementors to get things into a weird state by delegating
        # to super().__init__ in the wrong place, or by forgetting to do that
        # entirely.
        instance.__init__(config=config)  # type: ignore

        # Derived-class implementations may have changed the contents of the
        # various kinds-of-connection sets; update allConnections to have
        # keys that are a union of all those. We get values for the new
        # allConnections from the attributes, since any dynamically added new
        # ones will not be present in the old allConnections. Typically those
        # getattrs will invoke the descriptors and get things from the old
        # allConnections anyway. After processing each set we replace it with
        # a frozenset.
        updated_all_connections = {}
        for attrName in ("initInputs", "prerequisiteInputs", "inputs", "initOutputs", "outputs"):
            updated_connection_names = getattr(instance, attrName)
            updated_all_connections.update(
                {name: getattr(instance, name) for name in updated_connection_names}
            )
            # Setting these to frozenset is at odds with the type annotation,
            # but MyPy can't tell because we're using setattr, and we want to
            # lie to it anyway to get runtime guards against post-__init__
            # mutation.
            setattr(instance, attrName, frozenset(updated_connection_names))
        # Update the existing dict in place, since we already have a view of
        # that.
        instance._allConnections.clear()
        instance._allConnections.update(updated_all_connections)

        # Freeze the instance's dimensions now. This is at odds with the type
        # annotation, which says [mutable] `set`, just like the connection
        # type attributes (e.g. `inputs`, `outputs`, etc.), though MyPy can't
        # tell with those since we're using setattr for them.
        instance.dimensions = frozenset(instance.dimensions)  # type: ignore

        return instance


class QuantizedConnection(SimpleNamespace):
    """A namespace to map defined variable names of connections to the
    associated `lsst.daf.butler.DatasetRef` objects.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `lsst.daf.butler.DatasetRef` objects provided by a
    `lsst.daf.butler.Quantum` instance. This will be a quantum of execution
    based on the graph created by examining all the connections defined on
    the `PipelineTaskConnections` class.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance.
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: DatasetRef | list[DatasetRef]) -> None:
        # Capture the attribute name as it is added to this object.
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __len__(self) -> int:
        return len(self._attributes)

    def __iter__(
        self,
    ) -> Generator[tuple[str, DatasetRef | list[DatasetRef]], None, None]:
        """Make an iterator for this `QuantizedConnection`.

        Iterating over a `QuantizedConnection` will yield a tuple with the
        name of an attribute and the value associated with that name. This
        is similar to `dict.items()`, but is on the namespace attributes
        rather than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        `QuantizedConnection` class.
        """
        yield from self._attributes
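
# A minimal usage sketch (editor's illustration; ``ref`` stands for any
# `lsst.daf.butler.DatasetRef`):
#
#     >>> qc = QuantizedConnection()
#     >>> qc.calexp = ref
#     >>> list(qc)  # iteration yields (attribute name, value) pairs
#     [('calexp', ref)]
#     >>> list(qc.keys())
#     ['calexp']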


class InputQuantizedConnection(QuantizedConnection):
    """Input variant of a `QuantizedConnection`."""

    pass


class OutputQuantizedConnection(QuantizedConnection):
    """Output variant of a `QuantizedConnection`."""

    pass


@dataclass(frozen=True)
class DeferredDatasetRef:
    """A wrapper class for `~lsst.daf.butler.DatasetRef` that indicates that
    a `PipelineTask` should receive a
    `~lsst.daf.butler.DeferredDatasetHandle` instead of an in-memory dataset.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
    """

    datasetRef: DatasetRef

    @property
    def datasetType(self) -> DatasetType:
        """The dataset type for this dataset."""
        return self.datasetRef.datasetType

    @property
    def dataId(self) -> DataCoordinate:
        """The data ID for this dataset."""
        return self.datasetRef.dataId
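
# A minimal sketch (editor's illustration; ``ref`` stands for any
# `lsst.daf.butler.DatasetRef`):
#
#     >>> deferred = DeferredDatasetRef(datasetRef=ref)
#     >>> deferred.datasetType == ref.datasetType
#     True
#     >>> deferred.dataId == ref.dataId
#     True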


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    See Also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used
      as inputs to the ``__init__`` function of the `PipelineTask`
      corresponding to this class.
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``.
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the
      corresponding ``PipelineTask.run`` method must return
      ``Struct(measCat=X, ...)``, where X matches the ``storageClass`` type
      defined on the output connection.

    Attributes of these types can also be created, replaced, or deleted on
    the `PipelineTaskConnections` instance in the ``__init__`` method, if
    more than just the name depends on the configuration. It is preferred to
    define them in the class when possible (even if configuration may cause
    the connection to be removed from the instance).

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`. The dimensions may also be modified in
    subclass ``__init__`` methods if they need to depend on configuration.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example, if
    ``ConnectionsClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="ExposureF",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="ExposureF",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    # We annotate these attributes as mutable sets because that's what they
    # are inside derived ``__init__`` implementations and that's what matters
    # most. After that's done, the metaclass __call__ makes them into
    # frozensets, but relatively little code interacts with them then, and
    # that code knows not to try to modify them without having to be told
    # that by mypy.

    dimensions: set[str]
    """Set of dimension names that define the unit of work for this task.

    Required and implied dependencies will automatically be expanded later
    and need not be provided.

    This may be replaced or modified in ``__init__`` to change the
    dimensions of the task. After ``__init__`` it will be a `frozenset` and
    may not be replaced.
    """

    inputs: set[str]
    """Set with the names of all `connectionTypes.Input` connection
    attributes.

    This is updated automatically as class attributes are added, removed, or
    replaced in ``__init__``. Removing entries from this set will cause
    those connections to be removed after ``__init__`` completes, but this
    is supported only for backwards compatibility; new code should instead
    just delete the connection attribute directly. After ``__init__`` this
    will be a `frozenset` and may not be replaced.
    """

    prerequisiteInputs: set[str]
    """Set with the names of all `~connectionTypes.PrerequisiteInput`
    connection attributes.

    See `inputs` for additional information.
    """

    outputs: set[str]
    """Set with the names of all `~connectionTypes.Output` connection
    attributes.

    See `inputs` for additional information.
    """

    initInputs: set[str]
    """Set with the names of all `~connectionTypes.InitInput` connection
    attributes.

    See `inputs` for additional information.
    """

    initOutputs: set[str]
    """Set with the names of all `~connectionTypes.InitOutput` connection
    attributes.

    See `inputs` for additional information.
    """

    allConnections: Mapping[str, BaseConnection]
    """Mapping holding all connection attributes.

    This is a read-only view that is automatically updated when connection
    attributes are added, removed, or replaced in ``__init__``. It is also
    updated after ``__init__`` completes to reflect changes in `inputs`,
    `prerequisiteInputs`, `outputs`, `initInputs`, and `initOutputs`.
    """

    _allConnections: dict[str, BaseConnection]

    def __init__(self, *, config: PipelineTaskConfig | None = None):
        pass

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, BaseConnection):
            previous = self._allConnections.get(name)
            try:
                getattr(self, value._connection_type_set).add(name)
            except AttributeError:
                # Attempt to call add on a frozenset, which is what these
                # sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            if previous is not None and value._connection_type_set != previous._connection_type_set:
                # Connection has changed type, e.g. Input to
                # PrerequisiteInput; update the sets accordingly. To be extra
                # defensive about multiple assignments we use the type of the
                # previous instance instead of assuming it's the same as the
                # type of self, which is just the default. Use discard
                # instead of remove so manually removing from these sets
                # first is never an error.
                getattr(self, previous._connection_type_set).discard(name)
            self._allConnections[name] = value
            if hasattr(self.__class__, name):
                # Don't actually set the attribute if this was a connection
                # declared in the class; in that case we let the descriptor
                # return the value we just added to allConnections.
                return
        # Actually add the attribute.
        super().__setattr__(name, value)

    def __delattr__(self, name):
        """Descriptor delete method."""
        previous = self._allConnections.get(name)
        if previous is not None:
            # Delete this connection's name from the appropriate set, which
            # we have to get from the previous instance instead of assuming
            # it's the same set that was appropriate for the class-level
            # default. Use discard instead of remove so manually removing
            # from these sets first is never an error.
            try:
                getattr(self, previous._connection_type_set).discard(name)
            except AttributeError:
                # Attempt to call discard on a frozenset, which is what
                # these sets are after __init__ is done.
                raise TypeError("Connections objects are frozen after construction.") from None
            del self._allConnections[name]
            if hasattr(self.__class__, name):
                # Don't actually delete the attribute if this was a
                # connection declared in the class; in that case we let the
                # descriptor see that it's no longer present in
                # allConnections.
                return
        # Actually delete the attribute.
        super().__delattr__(name)
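
    # A sketch of how a subclass __init__ might remove a connection
    # dynamically, as the ``inputs`` docstring above recommends (editor's
    # illustration; ``doWriteBackground`` and ``background`` are hypothetical
    # config field and connection names):
    #
    #     def __init__(self, *, config=None):
    #         super().__init__(config=config)
    #         if not config.doWriteBackground:
    #             del self.background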

    def buildDatasetRefs(
        self, quantum: Quantum
    ) -> tuple[InputQuantizedConnection, OutputQuantizedConnection]:
        """Build `QuantizedConnection` objects corresponding to the input
        `~lsst.daf.butler.Quantum`.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
                `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # Operate on a reference object and an iterable of names of class
        # connection attributes.
        for refs, names in zip(
            (inputDatasetRefs, outputDatasetRefs),
            (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
        ):
            # Get a name of a class connection attribute.
            for attributeName in names:
                # Get the attribute identified by name.
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input.
                if attribute.name in quantum.inputs:
                    # If the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef.
                    quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
                    if attribute.deferLoad:
                        quantumInputRefs = [
                            DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
                        ]
                    else:
                        quantumInputRefs = list(quantum.inputs[attribute.name])
                    # Unpack arguments that are not marked multiples (list of
                    # length one).
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                "Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        setattr(refs, attributeName, quantumInputRefs[0])
                    else:
                        # Add to the QuantizedConnection identifier.
                        setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output.
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one).
                    if not attribute.multiple:
                        setattr(refs, attributeName, value[0])
                    else:
                        setattr(refs, attributeName, value)
                # The specified attribute is in neither inputs nor outputs;
                # we don't know how to handle it, so raise.
                else:
                    raise ValueError(
                        f"Attribute with name {attributeName} has no counterpart in input quantum"
                    )
        return inputDatasetRefs, outputDatasetRefs
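
    # A minimal usage sketch (editor's illustration; assumes ``connections``
    # is an instantiated PipelineTaskConnections and ``quantum`` is a
    # populated `lsst.daf.butler.Quantum` for the same task):
    #
    #     >>> inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #     >>> sorted(inputRefs.keys())  # connection attribute names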

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.Quantum` during the graph generation
        stage of the activator.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated
            `~lsst.daf.butler.DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can
            be assumed to be multi-pass iterable and support `len` and
            ``in``, but it should not be mutated in place. In contrast, the
            outer dictionaries are guaranteed to be temporary copies that
            are true `dict` instances, and hence may be modified and even
            returned; this is especially useful for delegating to `super`
            (see notes below).
        outputs : `~collections.abc.Mapping`
            Mapping of output datasets, with the same structure as
            ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `~collections.abc.Mapping`
            Mapping of the same form as ``inputs`` with updated containers
            of input `~lsst.daf.butler.DatasetRef` objects. Connections
            that are not changed should not be returned at all. Datasets
            may only be removed, not added. Nested collections may be of
            any multi-pass iterable type, and the order of iteration will
            set the order of iteration within `PipelineTask.runQuantum`.
        adjusted_outputs : `~collections.abc.Mapping`
            Mapping of updated output datasets, with the same structure and
            interpretation as ``adjusted_inputs``.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets are present.
        NoWorkFound
            Raised to indicate that this quantum should not be run; not
            enough datasets were found for a regular `Input` connection,
            and the quantum should be pruned or skipped.
        FileNotFoundError
            Raised to cause QuantumGraph generation to fail (with the
            message included in this exception); not enough datasets were
            found for a `PrerequisiteInput` connection.

        Notes
        -----
        The base class implementation performs important checks. It always
        returns an empty mapping (i.e. makes no adjustments). It should
        always be called via `super` by custom implementations, ideally at
        the end of the custom implementation with already-adjusted mappings
        when any datasets are actually dropped, e.g.:

        .. code-block:: python

            def adjustQuantum(self, inputs, outputs, label, data_id):
                # Filter out some dataset refs for one connection.
                connection, old_refs = inputs["my_input"]
                new_refs = [ref for ref in old_refs if ...]
                adjusted_inputs = {"my_input": (connection, new_refs)}
                # Update the original inputs so we can pass them to super.
                inputs.update(adjusted_inputs)
                # Can ignore outputs from super because they are guaranteed
                # to be empty.
                super().adjustQuantum(inputs, outputs, label, data_id)
                # Return only the connections we modified.
                return adjusted_inputs, {}

        Removing outputs here is guaranteed to affect what is actually
        passed to `PipelineTask.runQuantum`, but its effect on the larger
        graph may be deferred to execution, depending on the context in
        which `adjustQuantum` is being run: if one quantum removes an output
        that is needed by a second quantum as input, the second quantum may
        not be adjusted (and hence pruned or skipped) until that output is
        actually found to be missing at execution time.

        Tasks that desire zip-iteration consistency between any combinations
        of connections that have the same data ID should generally implement
        `adjustQuantum` to achieve this, even if they could also run that
        logic during execution; this allows the system to see outputs that
        will not be produced because the corresponding input is missing as
        early as possible.
        """
        for name, (input_connection, refs) in inputs.items():
            dataset_type_name = input_connection.name
            if not input_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple input connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
            if len(refs) < input_connection.minimum:
                if isinstance(input_connection, PrerequisiteInput):
                    # This branch should only be possible during QG
                    # generation, or if someone deleted the dataset between
                    # making the QG and trying to run it. Either one should
                    # be a hard error.
                    raise FileNotFoundError(
                        f"Not enough datasets ({len(refs)}) found for non-optional connection "
                        f"{label}.{name} ({dataset_type_name}) with minimum={input_connection.minimum} "
                        f"for quantum data ID {data_id}."
                    )
                else:
                    # This branch should be impossible during QG generation,
                    # because that algorithm can only make quanta whose
                    # inputs are either already present or should be created
                    # during execution. It can trigger during execution if
                    # the input wasn't actually created by an upstream task
                    # in the same graph.
                    raise NoWorkFound(label, name, input_connection)
        for name, (output_connection, refs) in outputs.items():
            dataset_type_name = output_connection.name
            if not output_connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for non-multiple output connection {label}.{name} ({dataset_type_name}) "
                    f"for quantum data ID {data_id}."
                )
        return {}, {}


def iterConnections(
    connections: PipelineTaskConnections, connectionType: str | Iterable[str]
) -> Generator[BaseConnection, None, None]:
    """Create an iterator over the selected connection type(s) which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str` or iterable of `str`
        The type(s) of connections to iterate over; valid values are
        ``inputs``, ``outputs``, ``prerequisiteInputs``, ``initInputs``, and
        ``initOutputs``.

    Yields
    ------
    connection : `~.connectionTypes.BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be of a type derived from
        `~.connectionTypes.BaseConnection`.
    """
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)
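
# A minimal usage sketch (editor's illustration; ``connections`` is assumed
# to be an instantiated PipelineTaskConnections, e.g. the ExampleConnections
# from the class docstring above):
#
#     >>> for connection in iterConnections(connections, ("inputs", "outputs")):
#     ...     print(connection.name)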


@dataclass
class AdjustQuantumHelper:
    """Helper class for calling `PipelineTaskConnections.adjustQuantum`.

    This class holds `input` and `output` mappings in the form used by
    `Quantum` and execution harness code, i.e. with
    `~lsst.daf.butler.DatasetType` keys, translating them to and from the
    connection-oriented mappings used inside `PipelineTaskConnections`.
    """

    inputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `~lsst.daf.butler.DatasetType`.
    """

    outputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
    """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
    """

    inputs_adjusted: bool = False
    """Whether any inputs were removed in the last call to `adjust_in_place`.
    """

    outputs_adjusted: bool = False
    """Whether any outputs were removed in the last call to
    `adjust_in_place`.
    """

    def adjust_in_place(
        self,
        connections: PipelineTaskConnections,
        label: str,
        data_id: DataCoordinate,
    ) -> None:
        """Call `~PipelineTaskConnections.adjustQuantum` and update ``self``
        with its results.

        Parameters
        ----------
        connections : `PipelineTaskConnections`
            Instance on which to call
            `~PipelineTaskConnections.adjustQuantum`.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).
        """
        # Translate self's DatasetType-keyed, Quantum-oriented mappings into
        # connection-keyed, PipelineTask-oriented mappings.
        inputs_by_connection: dict[str, tuple[BaseInput, tuple[DatasetRef, ...]]] = {}
        outputs_by_connection: dict[str, tuple[Output, tuple[DatasetRef, ...]]] = {}
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            inputs_by_connection[name] = (connection, tuple(self.inputs.get(dataset_type_name, ())))
        for name in itertools.chain(connections.outputs):
            connection = getattr(connections, name)
            dataset_type_name = connection.name
            outputs_by_connection[name] = (connection, tuple(self.outputs.get(dataset_type_name, ())))
        # Actually call adjustQuantum.
        # MyPy correctly complains that this call is not quite legal, but the
        # method docs explain exactly what's expected and it's the behavior
        # we want. It'd be nice to avoid this if we ever have to change the
        # interface anyway, but it's not an immediate problem.
        adjusted_inputs_by_connection, adjusted_outputs_by_connection = connections.adjustQuantum(
            inputs_by_connection,  # type: ignore
            outputs_by_connection,  # type: ignore
            label,
            data_id,
        )
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
            adjusted_inputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_inputs[dataset_type_name] = list(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
            adjusted_outputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
                adjusted_outputs[dataset_type_name] = list(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
            self.outputs_adjusted = False
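
# A minimal sketch of how execution-harness code might use this helper
# (editor's illustration; ``quantum``, ``connections``, ``label``, and
# ``data_id`` are assumed to come from the surrounding harness):
#
#     >>> helper = AdjustQuantumHelper(inputs=quantum.inputs, outputs=quantum.outputs)
#     >>> helper.adjust_in_place(connections, label, data_id)
#     >>> if helper.inputs_adjusted or helper.outputs_adjusted:
#     ...     pass  # rebuild the Quantum from helper.inputs / helper.outputs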