# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""

__all__ = ["PipelineTaskConnections", "InputQuantizedConnection", "OutputQuantizedConnection",
           "DeferredDatasetRef", "iterConnections"]

from collections import UserDict, namedtuple
from types import SimpleNamespace
import typing

import itertools
import string

from . import config as configMod
from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput,
                              Output, BaseConnection)
from lsst.daf.butler import DatasetRef, Quantum

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """This is a special dict class used by PipelineTaskConnectionsMetaclass.

    This dict is used in PipelineTaskConnection class creation, as the
    dictionary that is initially used as __dict__. It exists to intercept
    connection fields declared in a PipelineTaskConnection and to record
    the names used to identify them. Each name is then added to a
    class-level list according to the connection type of the class
    attribute, and is also used as a key in a class-level dictionary
    associated with the corresponding class attribute. This information
    duplicates what exists in __dict__, but provides a simple place to
    look up and iterate over only these variables.
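
    Examples
    --------
    A minimal sketch of the tracking behavior; the ``Input`` arguments
    below are illustrative placeholders, not connections from a real
    pipeline::

        dct = PipelineTaskConnectionDict()
        dct["calexp"] = Input(doc="Example input",
                              name="calexp",
                              storageClass="ExposureF",
                              dimensions=("instrument", "visit", "detector"))
        # The name is recorded in the per-type list and in the mapping of
        # all declared connections
        assert dct["inputs"] == ["calexp"]
        assert dct["allConnections"]["calexp"].varName == "calexp"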

    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection
        self.data['inputs'] = []
        self.data['prerequisiteInputs'] = []
        self.data['outputs'] = []
        self.data['initInputs'] = []
        self.data['initOutputs'] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        if isinstance(value, Input):
            self.data['inputs'].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data['prerequisiteInputs'].append(name)
        elif isinstance(value, Output):
            self.data['outputs'].append(name)
        elif isinstance(value, InitInput):
            self.data['initInputs'].append(name)
        elif isinstance(value, InitOutput):
            self.data['initOutputs'].append(name)
        # This should not be an elif, as it needs to be tested for
        # everything that inherits from BaseConnection
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # defer to the default behavior
        super().__setitem__(name, value)


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes.
    """
    def __prepare__(name, bases, **kwargs):  # noqa: 805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections from a parent class
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in class
            # declaration
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # look up any templates from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "default templates were provided; add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable
                # names used for a connection; this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError("Template parameters cannot share names with class attributes")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they
        # cannot be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument.
        # Python metaclasses will have this argument set if any kwargs are
        # passed at class construction time, but it should be consumed
        # before calling __init__ on the type metaclass. This is in
        # accordance with python documentation on metaclasses.
        super().__init__(name, bases, dct)


class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to their
    `lsst.daf.butler.DatasetRef`s.

    This class maps the names used to define a connection on a
    `PipelineTaskConnections` class to the corresponding
    `lsst.daf.butler.DatasetRef`s provided by a `lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnections` class.
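
    Examples
    --------
    A short sketch of the namespace behavior; ``ref1`` and ``ref2`` stand
    in for real `lsst.daf.butler.DatasetRef` instances::

        quantizedConnection = QuantizedConnection()
        quantizedConnection.calexp = ref1
        quantizedConnection.sources = ref2
        # Iteration yields (name, value) pairs, in no guaranteed order
        dict(quantizedConnection)           # {'calexp': ref1, 'sources': ref2}
        sorted(quantizedConnection.keys())  # ['calexp', 'sources']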

    """
    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                           typing.List[DatasetRef]]], None, None]:
        """Make an iterator for this QuantizedConnection.

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but iterates over the namespace attributes
        rather than dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Return an iterator over all the attributes added to a
        QuantizedConnection class.
        """
        yield from self._attributes


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Class which denotes that a datasetRef should be treated as deferred
    when interacting with the butler.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset.
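
    Examples
    --------
    A small usage sketch; ``ref`` is assumed to be an existing
    `lsst.daf.butler.DatasetRef`::

        deferred = DeferredDatasetRef(datasetRef=ref)
        deferred.datasetRef is ref  # True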

    """
    __slots__ = ()


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnections` class.

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to
      declare the class attribute must match a function argument name in the
      ``run`` method of a `PipelineTask`. E.g. if the
      ``PipelineTaskConnections`` defines an ``Input`` with the name
      ``calexp``, then the corresponding signature should be
      ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by
      an execution of a `PipelineTask`. The name used to declare the
      connection must correspond to an attribute of a `Struct` that is
      returned by a `PipelineTask` ``run`` method. E.g. if an output
      connection is defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)``
      where ``X`` matches the ``storageClass`` type defined on the output
      connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``: an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name
    parameters on connections defined in a `PipelineTaskConnections` class
    contain a template, then a default template value must be specified in
    the ``defaultTemplates`` argument. This is done by passing a dictionary
    with keys corresponding to a template identifier, and values
    corresponding to the value to use as a default when formatting the
    string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding
    `PipelineTask` will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    def __init__(self, *, config: 'PipelineTaskConfig' = None):
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name)
                          for name in getattr(self, 'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}

        # connections.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}

    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Build QuantizedConnections corresponding to the input Quantum.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing.

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
            `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`.
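
        Examples
        --------
        A sketch of typical use by an executor; ``quantum`` is assumed to be
        an `lsst.daf.butler.Quantum` for this task, and ``connections`` an
        instance of this class::

            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            for attributeName, refs in inputRefs:
                # refs is a single DatasetRef, or a list of DatasetRefs for
                # connections declared with multiple=True
                print(attributeName, refs)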

        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip((inputDatasetRefs, outputDatasetRefs),
                               (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.predictedInputs:
                    # Get the DatasetRefs
                    quantumInputRefs = quantum.predictedInputs[attribute.name]
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    if attribute.deferLoad:
                        quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                f"Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        quantumInputRefs = quantumInputRefs[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        value = value[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, value)
                # The specified attribute is in neither inputs nor outputs;
                # we don't know how to handle it, so throw
                else:
                    raise ValueError(f"Attribute with name {attributeName} has no counterpart "
                                     "in input quantum")
        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(self, datasetRefMap: InputQuantizedConnection):
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        The base class implementation simply checks that input connections
        with ``multiple`` set to `False` have no more than one dataset.

        Parameters
        ----------
        datasetRefMap : `dict`
            Mapping from dataset type name to `list` of
            `lsst.daf.butler.DatasetRef` objects

        Returns
        -------
        datasetRefMap : `dict`
            Modified mapping of input with possibly adjusted
            `lsst.daf.butler.DatasetRef` objects.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets.
        Exception
            Overrides of this function have the option of raising an
            Exception if a field in the input does not satisfy a need for a
            corresponding `PipelineTask`, e.g. if no reference catalogs are
            found.
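
        Examples
        --------
        A sketch of a subclass override that rejects quanta missing a
        prerequisite input; ``refCat`` is a hypothetical dataset type name,
        not one defined by this module::

            def adjustQuantum(self, datasetRefMap):
                datasetRefMap = super().adjustQuantum(datasetRefMap)
                if not datasetRefMap["refCat"]:
                    raise FileNotFoundError("No reference catalogs found for "
                                            "this quantum")
                return datasetRefMap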

        """
        for connection in itertools.chain(iterConnections(self, "inputs"),
                                          iterConnections(self, "prerequisiteInputs")):
            refs = datasetRefMap[connection.name]
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for scalar connection {connection.name} ({refs[0].datasetType.name})."
                )
        return datasetRefMap


def iterConnections(connections: PipelineTaskConnections, connectionType: str) -> typing.Generator:
    """Create an iterator over the selected connection type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str`
        The type of connections to iterate over; valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a type derived from
        `BaseConnection`.
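
    Examples
    --------
    Printing the dataset type name of each input connection of some
    ``connections`` instance (assumed to already exist)::

        for connection in iterConnections(connections, "inputs"):
            print(connection.name)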

    """
    for name in getattr(connections, connectionType):
        yield getattr(connections, name)