
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""

__all__ = ["PipelineTaskConnections", "InputQuantizedConnection", "OutputQuantizedConnection",
           "DeferredDatasetRef", "iterConnections"]

from collections import UserDict, namedtuple
from types import SimpleNamespace
import typing

import itertools
import string

from . import config as configMod
from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput,
                              Output, BaseConnection)
from lsst.daf.butler import DatasetRef, DatasetType, NamedKeyDict, Quantum

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig


class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """


class PipelineTaskConnectionDict(UserDict):
    """This is a special dict class used by PipelineTaskConnectionsMetaclass

    This dict is used in PipelineTaskConnection class creation, as the
    dictionary that is initially used as __dict__. It exists to intercept
    connection fields declared in a PipelineTaskConnection and to record the
    names used to identify them. The names are added to a class-level list
    according to the connection type of the class attribute. The names are
    also used as keys in a class-level dictionary associated with the
    corresponding class attribute. This information is a duplicate of what
    exists in __dict__, but provides a simple place to look up and iterate
    over only these variables.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize class level variables used to track any declared
        # class level variables that are instances of
        # connectionTypes.BaseConnection
        self.data['inputs'] = []
        self.data['prerequisiteInputs'] = []
        self.data['outputs'] = []
        self.data['initInputs'] = []
        self.data['initOutputs'] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        if isinstance(value, Input):
            self.data['inputs'].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data['prerequisiteInputs'].append(name)
        elif isinstance(value, Output):
            self.data['outputs'].append(name)
        elif isinstance(value, InitInput):
            self.data['initInputs'].append(name)
        elif isinstance(value, InitOutput):
            self.data['initOutputs'].append(name)
        # This should not be an elif, as it needs to be tested for
        # everything that inherits from BaseConnection
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # defer to the default behavior
        super().__setitem__(name, value)
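
# A minimal sketch of how PipelineTaskConnectionDict behaves during class-body
# execution (illustrative only; assumes connectionTypes.Input accepts the
# keyword arguments shown):
#
#     d = PipelineTaskConnectionDict()
#     d["calexp"] = Input(doc="input exposure", name="calexp",
#                         storageClass="ExposureF", dimensions=("visit",))
#     d["inputs"]                            # ['calexp']
#     d["allConnections"]["calexp"].varName  # 'calexp'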


class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes
    """
    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection
        # Copy any existing connections from a parent class
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for name, value in base.allConnections.items():
                    dct[name] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in the class
            # declaration
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                if isinstance(kwargs['dimensions'], str):
                    raise TypeError("Dimensions must be iterable of dimensions, got str, "
                                    "possibly omitted trailing comma")
                if not isinstance(kwargs['dimensions'], typing.Iterable):
                    raise TypeError("Dimensions must be iterable of dimensions")
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Look up any python string templates that may have been used
            # in the declaration of the name field of a class connection
            # attribute
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # look up any templates from base classes and merge them all
            # together
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "default templates were provided, add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection, this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(f"Template parameters cannot share names with Class attributes"
                                    f" (conflicts are {nameTemplateIntersection}).")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))
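
    # For reference, string.Formatter().parse is what extracts the template
    # fields above (standard-library behavior; a quick sketch):
    #
    #     >>> import string
    #     >>> [field for _, field, _, _ in
    #     ...  string.Formatter().parse("{foo}Coadd_{tract}")
    #     ...  if field is not None]
    #     ['foo', 'tract']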

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs are passed at
        # class construction time, but it should be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses
        super().__init__(name, bases, dct)


class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to their
    `lsst.daf.butler.DatasetRef`s

    This class maps the names used to define a connection on a
    PipelineTaskConnectionsClass to the corresponding
    `lsst.daf.butler.DatasetRef`s provided by a `lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnectionsClass`.
    """
    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                           typing.List[DatasetRef]]], None, None]:
        """Make an Iterator for this QuantizedConnection

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather than
        dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Returns an iterator over all the attributes added to a
        QuantizedConnection class
        """
        yield from self._attributes
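
# A quick sketch of the namespace behavior above (illustrative only; a real
# DatasetRef would come from a Quantum):
#
#     refs = QuantizedConnection()
#     refs.calexp = someDatasetRef             # name recorded in _attributes
#     {name: value for name, value in refs}    # {'calexp': someDatasetRef}
#     list(refs.keys())                        # ['calexp']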


class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass


class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Class which denotes that a datasetRef should be treated as deferred
    when interacting with the butler

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will eventually be used to
        resolve a dataset
    """
    __slots__ = ()
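
# Usage sketch (illustrative): wrapping a ref marks it for deferred loading in
# buildDatasetRefs below, and the original ref stays available as the named
# tuple field:
#
#     deferred = DeferredDatasetRef(datasetRef=ref)
#     deferred.datasetRef is ref  # True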


class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnectionsClass`

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be
      used in the ``run`` method of a `PipelineTask`. The name used to declare
      the class attribute must match a function argument name in the ``run``
      method of a `PipelineTask`. E.g. if the ``PipelineTaskConnections``
      defines an ``Input`` with the name ``calexp``, then the corresponding
      signature should be ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
      execution of a `PipelineTask`. The name used to declare the connection
      must correspond to an attribute of a `Struct` that is returned by a
      `PipelineTask` ``run`` method. E.g. if an output connection is
      defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X, ...)`` where
      X matches the ``storageClass`` type defined on the output connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions``, an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions
    that exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name parameters
    on connections defined in a `PipelineTaskConnections` class contain a
    template, then a default template value must be specified in the
    ``defaultTemplates`` argument. This is done by passing a dictionary with
    keys corresponding to a template identifier, and values corresponding to
    the value to use as a default when formatting the string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates = {'input': 'deep'}``.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding `PipelineTask`
    will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert connections.inputConnection.name == "ModifiedDataset"
    >>> assert connections.outputConnection.name == "TotallyDifferent"
    """

    def __init__(self, *, config: 'PipelineTaskConfig' = None):
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)
        self.allConnections = dict(self.allConnections)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name)
                          for name in getattr(self, 'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate an override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}

        # connections.name corresponds to a dataset type name, create a
        # reverse mapping that goes from dataset type name to attribute
        # identifier name (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}
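
    # Name-resolution sketch (assuming the ExampleConnections class from the
    # docstring above): with defaultTemplates={"foo": "Example"} and the
    # configured value connections.inputConnection == "{foo}Dataset", the
    # resolved dataset type name is
    #
    #     "{foo}Dataset".format(**{"foo": "Example"})  # -> 'ExampleDataset'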

    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Builds QuantizedConnections corresponding to input Quantum

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`, `OutputQuantizedConnection`)
            Namespaces mapping attribute names (identifiers of connections)
            to butler references defined in the input
            `lsst.daf.butler.Quantum`
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes
        for refs, names in zip((inputDatasetRefs, outputDatasetRefs),
                               (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.predictedInputs:
                    # Get the DatasetRefs
                    quantumInputRefs = quantum.predictedInputs[attribute.name]
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    if attribute.deferLoad:
                        quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(
                                f"Received multiple datasets "
                                f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
                                f"for scalar connection {attributeName} "
                                f"({quantumInputRefs[0].datasetType.name}) "
                                f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
                            )
                        if len(quantumInputRefs) == 0:
                            continue
                        quantumInputRefs = quantumInputRefs[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        value = value[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs; we don't
                # know how to handle it, so throw
                else:
                    raise ValueError(f"Attribute with name {attributeName} has no counterpart "
                                     "in input quantum")
        return inputDatasetRefs, outputDatasetRefs
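
    # Typical use by an executor (a sketch, not the framework's exact code;
    # assumes `quantum` is an lsst.daf.butler.Quantum built for this task,
    # and scalar, non-deferred input connections):
    #
    #     inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    #     inputs = {name: butler.get(ref) for name, ref in inputRefs}
    #     outputs = task.run(**inputs)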

    def adjustQuantum(self, datasetRefMap: NamedKeyDict[DatasetType, typing.Set[DatasetRef]]
                      ) -> NamedKeyDict[DatasetType, typing.Set[DatasetRef]]:
        """Override to make adjustments to `lsst.daf.butler.DatasetRef`
        objects in the `lsst.daf.butler.core.Quantum` during the graph
        generation stage of the activator.

        The base class implementation simply checks that input connections
        with ``multiple`` set to `False` have no more than one dataset.

        Parameters
        ----------
        datasetRefMap : `NamedKeyDict`
            Mapping from dataset type to a `set` of
            `lsst.daf.butler.DatasetRef` objects

        Returns
        -------
        datasetRefMap : `NamedKeyDict`
            Modified mapping of input with possibly adjusted
            `lsst.daf.butler.DatasetRef` objects.

        Raises
        ------
        ScalarError
            Raised if any `Input` or `PrerequisiteInput` connection has
            ``multiple`` set to `False`, but multiple datasets.
        Exception
            Overrides of this function have the option of raising an
            Exception if a field in the input does not satisfy a need for a
            corresponding pipelineTask, e.g. no reference catalogs are found.
        """
        for connection in itertools.chain(iterConnections(self, "inputs"),
                                          iterConnections(self, "prerequisiteInputs")):
            refs = datasetRefMap[connection.name]
            if not connection.multiple and len(refs) > 1:
                raise ScalarError(
                    f"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
                    f"for scalar connection {connection.name} ({refs[0].datasetType.name})."
                )
        return datasetRefMap
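

# A minimal override sketch (hypothetical subclass; assumes a prerequisite
# connection named refCat is defined): fail early when a required input is
# entirely missing, then defer to the base-class scalar checks:
#
#     class MyConnections(PipelineTaskConnections, dimensions=("visit",)):
#         ...
#         def adjustQuantum(self, datasetRefMap):
#             if not datasetRefMap["refCat"]:
#                 raise FileNotFoundError("No reference catalogs found")
#             return super().adjustQuantum(datasetRefMap)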


def iterConnections(connections: PipelineTaskConnections, connectionType: str) -> typing.Generator:
    """Creates an iterator over the selected connections type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be
        iterated over.
    connectionType : `str`
        The type of connections to iterate over, valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection : `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a type derived from
        `BaseConnection`.
    """
    for name in getattr(connections, connectionType):
        yield getattr(connections, name)
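
# Example use (this mirrors how adjustQuantum consumes it above):
#
#     for connection in iterConnections(connections, "inputs"):
#         print(connection.name, connection.storageClass)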