Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining connection classes for PipelineTask. 

23""" 

24 

25__all__ = ["PipelineTaskConnections", "InputQuantizedConnection", "OutputQuantizedConnection", 

26 "DeferredDatasetRef", "iterConnections"] 

27 

28from collections import UserDict, namedtuple 

29from types import SimpleNamespace 

30import typing 

31 

32import itertools 

33import string 

34 

35from . import config as configMod 

36from .connectionTypes import (InitInput, InitOutput, Input, PrerequisiteInput, 

37 Output, BaseConnection) 

38from lsst.daf.butler import DatasetRef, Quantum 

39 

40if typing.TYPE_CHECKING: 40 ↛ 41line 40 didn't jump to line 41, because the condition on line 40 was never true

41 from .config import PipelineTaskConfig 

42 

43 

class ScalarError(TypeError):
    """Exception raised when a dataset type is configured as scalar
    but there are multiple DataIds in a Quantum for that dataset.

    Parameters
    ----------
    key : `str`
        Name of the configuration field for the dataset type.
        If ``numDataIds`` is not specified, this parameter is treated as the
        complete message to report rather than a field name.
    numDataIds : `int`, optional
        Actual number of DataIds in a Quantum for this dataset type.
    """
    def __init__(self, key, numDataIds=None):
        # With no count given we were handed a fully formed TypeError
        # message; otherwise build the scalar-violation message from the
        # field name and the offending number of DataIds.
        message = (key if numDataIds is None
                   else f"Expected scalar for output dataset field {key}, "
                        f"received {numDataIds} DataIds")
        super().__init__(message)

65 

66 

class PipelineTaskConnectionDict(UserDict):
    """Specialized namespace dict used by PipelineTaskConnectionsMetaclass.

    During creation of a PipelineTaskConnections class this dict serves as
    the class-body namespace (returned from ``__prepare__``). It intercepts
    every assignment of a connection-type instance, recording the attribute
    name in the appropriate class-level list (``inputs``, ``outputs``, etc.)
    and in the ``allConnections`` mapping. This duplicates information that
    also ends up in ``__dict__``, but gives the metaclass a simple place to
    look up and iterate over only the connection attributes.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Seed the bookkeeping containers that track declared connections:
        # one list per connection category plus a name -> connection map.
        for trackerName in ('inputs', 'prerequisiteInputs', 'outputs',
                            'initInputs', 'initOutputs'):
            self.data[trackerName] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        # Record the attribute name under the first matching connection
        # category; the ordering here mirrors the original if/elif chain so
        # subclass relationships between connection types resolve the same.
        for connectionClass, trackerName in ((Input, 'inputs'),
                                             (PrerequisiteInput, 'prerequisiteInputs'),
                                             (Output, 'outputs'),
                                             (InitInput, 'initInputs'),
                                             (InitOutput, 'initOutputs')):
            if isinstance(value, connectionClass):
                self.data[trackerName].append(name)
                break
        # Checked independently of the chain above, because every subclass
        # of BaseConnection must be registered in allConnections and learn
        # the variable name it was bound to.
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # defer to the default dict behavior
        super().__setitem__(name, value)

110 

111 

class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes

    Provides a custom ``__prepare__`` namespace (`PipelineTaskConnectionDict`)
    so connection attributes declared in a class body are captured and sorted
    by connection type, and validates at class-creation time that
    ``dimensions`` and any name templates were properly declared.
    """
    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection
        # Copy any existing connections from a parent class.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                # Loop variables deliberately do not reuse ``name`` so the
                # class-name parameter is not shadowed.
                for connectionName, connection in base.allConnections.items():
                    dct[connectionName] = connection
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        # The base PipelineTaskConnections class itself declares no
        # dimensions or connections, so it skips all validation below.
        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in the class
            # declaration, falling back to any base class value.
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc
            # Lookup any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute.
            allTemplates = set()
            stringFormatter = string.Formatter()
            # Loop over all connections
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                # add all the parameters to the set of templates; parse()
                # yields (literal, fieldName, spec, conversion) tuples and
                # fieldName is None for purely literal segments.
                for param in stringFormatter.parse(nameValue):
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # look up any template from base classes and merge them all
            # together; later bases and the class's own kwargs win.
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])

            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the connection
            # class.
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "default templates were provided, add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not.
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection, this is needed because of how
                # templates are specified in an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError("Template parameters cannot share names with Class attributes")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class scope.
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # our custom dict type must be turned into an actual dict to be used
        # in type.__new__
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # This overrides the default init to drop the kwargs argument. Python
        # metaclasses will have this argument set if any kwargs were passed at
        # class construction time, but they must be consumed before calling
        # __init__ on the type metaclass. This is in accordance with python
        # documentation on metaclasses.
        super().__init__(name, bases, dct)

204 

205 

class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to their
    `lsst.daf.butler.DatasetRef`s

    This class maps the names used to define a connection on a
    PipelineTaskConnectionsClass to the corresponding
    `lsst.daf.butler.DatasetRef`s provided by a `lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnectionsClass`.
    """
    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance
        object.__setattr__(self, "_attributes", set())
        # Route any keyword arguments through our own __setattr__ so they are
        # both set and tracked in _attributes. (Previously they were silently
        # discarded, breaking the SimpleNamespace(**kwargs) contract.)
        for name, value in kwargs.items():
            setattr(self, name, value)

    def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
        # Capture the attribute name as it is added to this object
        self._attributes.add(name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        # Remove the attribute itself first, then drop it from the tracking
        # set so the two always stay in sync.
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                           typing.List[DatasetRef]]], None, None]:
        """Make an Iterator for this QuantizedConnection

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather than
        dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Returns an iterator over all the attributes added to a
        QuantizedConnection class
        """
        yield from self._attributes

247 

248 

class InputQuantizedConnection(QuantizedConnection):
    """Namespace associating input (and prerequisite input) connection names
    with the `lsst.daf.butler.DatasetRef` objects a quantum supplies for
    them; populated by `PipelineTaskConnections.buildDatasetRefs`.
    """
    pass

251 

252 

class OutputQuantizedConnection(QuantizedConnection):
    """Namespace associating output connection names with the
    `lsst.daf.butler.DatasetRef` objects a quantum expects to be produced
    for them; populated by `PipelineTaskConnections.buildDatasetRefs`.
    """
    pass

255 

256 

class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", ["datasetRef"])):
    """Marker wrapper indicating that a datasetRef should be treated as
    deferred when interacting with the butler.

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will be eventually used to
        resolve a dataset
    """
    # No per-instance __dict__; keep the tuple-like memory footprint.
    __slots__ = ()

268 

269 

class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnectionsClass`

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class
    * ``InitOutput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be used
      in the ``run`` method of a `PipelineTask`. The name used to declare
      class attribute must match a function argument name in the ``run``
      method of a `PipelineTask`. E.g. If the ``PipelineTaskConnections``
      defines an ``Input`` with the name ``calexp``, then the corresponding
      signature should be ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
      execution of a `PipelineTask`. The name used to declare the connection
      must correspond to an attribute of a `Struct` that is returned by a
      `PipelineTask` ``run`` method. E.g. if an output connection is
      defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X,..)`` where
      X matches the ``storageClass`` type defined on the output connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions`` which is an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions that
    exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name parameters
    on connections defined in a `PipelineTaskConnections` class contain a
    template, then a default template value must be specified in the
    ``defaultTemplates`` argument. This is done by passing a dictionary with
    keys corresponding to a template identifier, and values corresponding to
    the value to use as a default when formatting the string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates`` = {'input': 'deep'}.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding `PipelineTask`
    will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass="Exposure",
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass="Exposure",
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert(connections.inputConnection.name == "ModifiedDataset")
    >>> assert(connections.outputConnection.name == "TotallyDifferent")
    """

    def __init__(self, *, config: typing.Optional['PipelineTaskConfig'] = None):
        # Rebind the class-level frozensets (installed by the metaclass) as
        # mutable per-instance sets, so instances or subclass __init__ methods
        # can adjust which connections are active without mutating the class.
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)

        # A config is mandatory despite the None default: the name overrides
        # below cannot be computed without one.
        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time
        templateValues = {name: getattr(config.connections, name) for name in getattr(self,
                          'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable. I.e. for each connection identifier, populate a override
        # for the connection.name attribute
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}

        # connections.name corresponds to a dataset type name, create a reverse
        # mapping that goes from dataset type name to attribute identifier name
        # (variable name) on the connection class
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}

    def buildDatasetRefs(self, quantum: Quantum) -> typing.Tuple[InputQuantizedConnection,
                                                                 OutputQuantizedConnection]:
        """Builds QuantizedConnections corresponding to input Quantum

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum object which defines the inputs and outputs for a given
            unit of processing

        Returns
        -------
        retVal : `tuple` of (`InputQuantizedConnection`,
            `OutputQuantizedConnection`) Namespaces mapping attribute names
            (identifiers of connections) to butler references defined in the
            input `lsst.daf.butler.Quantum`

        Raises
        ------
        ScalarError
            Raised if a connection is not marked ``multiple`` but more than
            one DataId is present for it in the quantum.
        ValueError
            Raised if a declared connection has no counterpart in the
            quantum's predicted inputs or outputs.
        """
        inputDatasetRefs = InputQuantizedConnection()
        outputDatasetRefs = OutputQuantizedConnection()
        # operate on a reference object and an iterable of names of class
        # connection attributes; inputs and prerequisiteInputs both feed the
        # input namespace, outputs feed the output namespace.
        for refs, names in zip((inputDatasetRefs, outputDatasetRefs),
                               (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs)):
            # get a name of a class connection attribute
            for attributeName in names:
                # get the attribute identified by name
                attribute = getattr(self, attributeName)
                # Branch if the attribute dataset type is an input
                if attribute.name in quantum.predictedInputs:
                    # Get the DatasetRefs
                    quantumInputRefs = quantum.predictedInputs[attribute.name]
                    # if the dataset is marked to load deferred, wrap it in a
                    # DeferredDatasetRef
                    if attribute.deferLoad:
                        quantumInputRefs = [DeferredDatasetRef(datasetRef=ref) for ref in quantumInputRefs]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        if len(quantumInputRefs) > 1:
                            raise ScalarError(attributeName, len(quantumInputRefs))
                        if len(quantumInputRefs) == 0:
                            # No ref present for a scalar input: leave the
                            # attribute off the namespace entirely.
                            continue
                        quantumInputRefs = quantumInputRefs[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, quantumInputRefs)
                # Branch if the attribute dataset type is an output
                elif attribute.name in quantum.outputs:
                    value = quantum.outputs[attribute.name]
                    # Unpack arguments that are not marked multiples (list of
                    # length one)
                    if not attribute.multiple:
                        value = value[0]
                    # Add to the QuantizedConnection identifier
                    setattr(refs, attributeName, value)
                # Specified attribute is not in inputs or outputs, don't know
                # how to handle; throw
                else:
                    raise ValueError(f"Attribute with name {attributeName} has no counterpoint "
                                     "in input quantum")
        return inputDatasetRefs, outputDatasetRefs

    def adjustQuantum(self, datasetRefMap: InputQuantizedConnection):
        """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
        in the `lsst.daf.butler.core.Quantum` during the graph generation stage
        of the activator.

        The base class implementation returns the mapping unchanged.

        Parameters
        ----------
        datasetRefMap : `dict`
            Mapping with keys of dataset type name to `list` of
            `lsst.daf.butler.DatasetRef` objects

        Returns
        -------
        datasetRefMap : `dict`
            Modified mapping of input with possible adjusted
            `lsst.daf.butler.DatasetRef` objects

        Raises
        ------
        Exception
            Overrides of this function have the option of raising an Exception
            if a field in the input does not satisfy a need for a corresponding
            pipelineTask, i.e. no reference catalogs are found.
        """
        return datasetRefMap

487 

488 

def iterConnections(connections: PipelineTaskConnections, connectionType: str) -> typing.Generator:
    """Iterate over all defined connections of a selected type.

    Parameters
    ----------
    connections : `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object whose connections
        are to be iterated over.
    connectionType : `str`
        Which category of connections to yield; valid values are ``inputs``,
        ``outputs``, ``prerequisiteInputs``, ``initInputs``, ``initOutputs``.

    Yields
    ------
    connection : `BaseConnection`
        Each connection of the requested type defined on the supplied
        connections object; always an instance of a `BaseConnection`
        subclass.
    """
    attributeNames = getattr(connections, connectionType)
    yield from (getattr(connections, attributeName) for attributeName in attributeNames)