Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

21 

22"""Module defining connection classes for PipelineTask. 

23""" 

24 

25__all__ = ["PipelineTaskConnections", "InputQuantizedConnection", "OutputQuantizedConnection", 

26 "DeferredDatasetRef", "iterConnections"] 

27 

28from collections import UserDict, namedtuple 

29from types import SimpleNamespace 

30import typing 

31from typing import Union, Iterable 

32 

33import itertools 

34import string 

35 

36from . import config as configMod 

37from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput, 

38 Output, BaseConnection) 

39from lsst.daf.butler import DatasetRef, DatasetType, NamedKeyDict, Quantum 

40 

41if typing.TYPE_CHECKING: 41 ↛ 42line 41 didn't jump to line 42, because the condition on line 41 was never true

42 from .config import PipelineTaskConfig 

43 

44 

class ScalarError(TypeError):
    """Raised when a connection is declared as scalar (``multiple=False``)
    but the Quantum supplies more than one data ID for that dataset type.
    """

49 

50 

class PipelineTaskConnectionDict(UserDict):
    """A specialized dict used by PipelineTaskConnectionsMetaclass.

    This dict serves as the initial ``__dict__`` during creation of a
    PipelineTaskConnection class.  Its job is to intercept connection
    fields declared in a PipelineTaskConnection and record the name each
    one was bound to.  Each name is appended to a class-level list chosen
    by the connection's type, and every connection is also entered into a
    class-level ``allConnections`` dictionary keyed by name.  This
    duplicates information already present in ``__dict__``, but gives a
    single convenient place to look up and iterate over just these
    variables.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Seed the class-level bookkeeping containers that will collect
        # any declared connectionTypes.BaseConnection attributes.
        for key in ('inputs', 'prerequisiteInputs', 'outputs',
                    'initInputs', 'initOutputs'):
            self.data[key] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        # Classify the connection into exactly one tracking list; the
        # first matching type wins, mirroring an if/elif chain.
        for connectionClass, trackerKey in ((Input, 'inputs'),
                                            (PrerequisiteInput, 'prerequisiteInputs'),
                                            (Output, 'outputs'),
                                            (InitInput, 'initInputs'),
                                            (InitOutput, 'initOutputs')):
            if isinstance(value, connectionClass):
                self.data[trackerKey].append(name)
                break
        # Every BaseConnection (including all of the subclasses above)
        # additionally learns its variable name and is recorded in the
        # allConnections mapping.
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # Defer to the normal dict behavior for the actual assignment.
        super().__setitem__(name, value)

94 

95 

class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes

    It supplies a `PipelineTaskConnectionDict` as the class namespace (so
    connection attributes are tracked as they are declared), validates the
    ``dimensions`` and ``defaultTemplates`` keyword arguments given in the
    class declaration, and freezes the connection-name containers.
    """
    def __prepare__(name, bases, **kwargs):  # noqa: 805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection
        # Copy any existing connections from a parent class
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        # Shared error raised whenever the dimensions argument is missing or
        # malformed.
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        # The validation below applies only to subclasses, not to the
        # PipelineTaskConnections base class itself.
        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in class
            # declaration; fall back to a base class's dimensions if absent.
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                # A bare str is iterable, so catch the common mistake of
                # writing dimensions="visit" instead of ("visit",).
                if isinstance(kwargs['dimensions'], str):
                    raise TypeError("Dimensions must be iterable of dimensions, got str,"
                                    "possibly omitted trailing comma")
                if not isinstance(kwargs['dimensions'], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Lookup any python string templates that may have been used in the
            # declaration of the name field of a class connection attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # look up any template from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "defaut templates were provided, add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do not
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection, this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(f"Template parameters cannot share names with Class attributes"
                                    f" (conflicts are {nameTemplateIntersection}).")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used in
        # type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but they should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses.
        super().__init__(name, bases, dct)

194 

195 

class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to their
    `lsst.daf.butler.DatasetRef`\\ s

    This class maps the names used to define a connection on a
    PipelineTaskConnectionsClass to the corresponding
    `lsst.daf.butler.DatasetRef`\\ s provided by a `lsst.daf.butler.Quantum`
    instance.  This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnectionsClass`.
    """
    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())
        # Bug fix: previously any keyword arguments were accepted but
        # silently discarded.  Route them through __setattr__ so they are
        # both set as attributes and tracked in _attributes.
        for name, value in kwargs.items():
            setattr(self, name, value)

    def __setattr__(self, name: str, value: "typing.Union[DatasetRef, typing.List[DatasetRef]]"):
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        # Remove the attribute itself first, then drop it from the tracking
        # set so the two always stay consistent.
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> ("typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,"
                           " typing.List[DatasetRef]]], None, None]"):
        """Make an Iterator for this QuantizedConnection

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather than
        dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Returns an iterator over all the attributes added to a
        QuantizedConnection class
        """
        yield from self._attributes

237 

238 

class InputQuantizedConnection(QuantizedConnection):
    """Namespace mapping input connection names to the dataset references a
    Quantum provides for them.
    """

241 

242 

class OutputQuantizedConnection(QuantizedConnection):
    """Namespace mapping output connection names to the dataset references a
    Quantum provides for them.
    """

245 

246 

class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Marker wrapper denoting that a datasetRef should be treated as
    deferred when interacting with the butler.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset
    """
    # No per-instance __dict__; instances are pure single-field tuples.
    __slots__ = ()

258 

259 

class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnectionsClass`

    See also
    --------
    iterConnections

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be used
      in the ``run`` method of a `PipelineTask`. The name used to declare
      class attribute must match a function argument name in the ``run``
      method of a `PipelineTask`. E.g. If the ``PipelineTaskConnections``
      defines an ``Input`` with the name ``calexp``, then the corresponding
      signature should be ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
      execution of a `PipelineTask`. The name used to declare the connection
      must correspond to an attribute of a `Struct` that is returned by a
      `PipelineTask` ``run`` method. E.g. if an output connection is
      defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X,..)`` where
      X matches the ``storageClass`` type defined on the output connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions`` which is an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions that
    exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name parameters
    on connections defined in a `PipelineTaskConnections` class contain a
    template, then a default template value must be specified in the
    ``defaultTemplates`` argument. This is done by passing a dictionary with
    keys corresponding to a template identifier, and values corresponding to
    the value to use as a default when formatting the string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates`` = {'input': 'deep'}.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding `PipelineTask`
    will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass=Exposure,
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass=Exposure,
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert(connections.inputConnection.name == "ModifiedDataset")
    >>> assert(connections.outputConnection.name == "TotallyDifferent")
    """

    def __init__(self, *, config: 'PipelineTaskConfig' = None):
        # Re-bind the (frozen) class-level containers created by the
        # metaclass as mutable per-instance copies, so subclasses may adjust
        # them in their own __init__ without affecting the class.
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)
        self.allConnections = dict(self.allConnections)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name) for name in getattr(self,
                          'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}

        # connections.name corresponds to a dataset type name, create a reverse
        # mapping that goes from dataset type name to attribute identifier name
        # (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}

    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Builds QuantizedConnections corresponding to input Quantum

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
            `OutputQuantizedConnection`) Namespaces mapping attribute names
            (identifiers of connections) to butler references defined in the
            input `lsst.daf.butler.Quantum`

        Raises
        ------
        ScalarError
            Raised if a connection with ``multiple`` set to `False` receives
            more than one dataset ref from the quantum.
        ValueError
            Raised if a connection's dataset type appears in neither the
            quantum's inputs nor its outputs.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip((inputDatasetRefs, outputDatasetRefs),
                               (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.inputs:
                    # Get the DatasetRefs
                    quantumInputRefs = quantum.inputs[attribute.name]
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    if attribute.deferLoad:
                        quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                f"Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        # A scalar connection with zero refs is simply left
                        # unset on the namespace.
                        if len(quantumInputRefs) == 0:
                            continue
                        quantumInputRefs = quantumInputRefs[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        value = value[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; don't know
                # how to handle, throw
                else:
                    raise ValueError(f"Attribute with name {attributeName} has no counterpoint "
                                     "in input quantum")
        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(self, datasetRefMap: NamedKeyDict[DatasetType, typing.Set[DatasetRef]]
                      ) -> NamedKeyDict[DatasetType, typing.Set[DatasetRef]]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
        in the `lsst.daf.butler.core.Quantum` during the graph generation stage
        of the activator.

        The base class implementation simply checks that input connections with
        ``multiple`` set to `False` have no more than one dataset.

        Parameters
        ----------
        datasetRefMap : `NamedKeyDict`
            Mapping from dataset type to a `set` of
            `lsst.daf.butler.DatasetRef` objects

        Returns
        -------
        datasetRefMap : `NamedKeyDict`
            Modified mapping of input with possibly adjusted
            `lsst.daf.butler.DatasetRef` objects.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets.
        Exception
            Overrides of this function have the option of raising an Exception
            if a field in the input does not satisfy a need for a corresponding
            pipelineTask, i.e. no reference catalogs are found.
        """
        # Only regular and prerequisite inputs are checked; outputs are
        # produced by the task and need no adjustment here.
        for connection in itertools.chain(iterConnections(self, "inputs"),
                                          iterConnections(self, "prerequisiteInputs")):
            refs = datasetRefMap[connection.name]
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for scalar connection {connection.name} ({refs[0].datasetType.name})."
                )
        return datasetRefMap

503 

504 

def iterConnections(connections: "PipelineTaskConnections",
                    connectionType: Union[str, Iterable[str]]
                    ) -> typing.Generator["BaseConnection", None, None]:
    """Creates an iterator over the selected connections type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be iterated
        over.
    connectionType : `str` or iterable of `str`
        The type(s) of connections to iterate over; valid values are
        ``inputs``, ``outputs``, ``prerequisiteInputs``, ``initInputs``, and
        ``initOutputs``.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied.  The yielded value will be a derived type of
        `BaseConnection`.
    """
    # Accept a single category name as a convenience shorthand for a
    # one-element iterable.
    if isinstance(connectionType, str):
        connectionType = (connectionType,)
    # Each category attribute holds connection *names*; resolve each name to
    # the actual connection object on the instance.
    for name in itertools.chain.from_iterable(getattr(connections, ct) for ct in connectionType):
        yield getattr(connections, name)