Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

24 

25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock", "ButlerMock"] 

26 

27import itertools 

28import logging 

29import numpy 

30import functools 

31import os 

32from collections import defaultdict 

33from types import SimpleNamespace 

34 

35from lsst.daf.butler import (ButlerConfig, DatasetRef, DimensionUniverse, 

36 DatasetType, Registry) 

37import lsst.pex.config as pexConfig 

38import lsst.pipe.base as pipeBase 

39from lsst.pipe.base import connectionTypes as cT 

40 

41_LOG = logging.getLogger(__name__) 

42 

43 

class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector")):
    """Connections for `AddTask`: one per-detector input, one per-detector
    output, and one init-output, all using the ``NumpyArray`` storage class.
    """

    input = cT.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="NumpyArray",
        dimensions=["instrument", "detector"],
    )
    output = cT.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="NumpyArray",
        dimensions=["instrument", "detector"],
    )
    # Init-outputs carry no dimensions.
    initout = cT.InitOutput(
        doc="Init Output dataset type for this task",
        name="add_init_output",
        storageClass="NumpyArray",
    )

57 

58 

class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Configuration for `AddTask`."""

    # Constant added to every element of the input array by AddTask.run().
    addend = pexConfig.Field(dtype=int, default=3, doc="amount to add")

62 

63 

# Example task which overrides the run() method.
class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask that adds a configured constant to each element
    of its input array.

    Class-level counters (`countExec`, `stopAt`) are shared across instances
    so tests can observe execution counts and inject failures.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    # InitOutputs for this task.
    initout = numpy.array([999])

    # Number of times run() was called, across all instances of this class.
    countExec = 0

    # When countExec reaches this value run() raises instead of executing;
    # the default of -1 disables failure injection.
    stopAt = -1

    def run(self, input):
        # Failure check happens before the counter is incremented.
        if AddTask.countExec == AddTask.stopAt:
            raise RuntimeError("pretend something bad happened")
        AddTask.countExec += 1
        self.metadata.add("add", self.config.addend)
        addend = self.config.addend
        output = [value + addend for value in input]
        _LOG.info("input = %s, output = %s", input, output)
        return pipeBase.Struct(output=output)

86 

87 

class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Minimal task factory that only knows how to make `AddTask`."""

    def loadTaskClass(self, taskName):
        # Any name other than "AddTask" falls through, returning None.
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        # A caller-supplied config is used as-is; overrides only apply when
        # we build the default config ourselves.
        if config is not None:
            return taskClass(config=config, initInputs=None)
        config = taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(config)
        return taskClass(config=config, initInputs=None)

99 

100 

class ButlerMock:
    """Mock version of butler, only usable for testing.

    Datasets are stored in a plain nested dict keyed by dataset type name
    and then by a frozen dataId (see `key`).

    Parameters
    ----------
    fullRegistry : `bool`, optional
        If True then instantiate SQLite registry with default configuration.
        If False then registry is just a namespace with a ``dimensions``
        attribute containing the DimensionUniverse from the default
        configuration.
    collection : `str`, optional
        Name of the run collection used by `put` and `remove`.
    """
    def __init__(self, fullRegistry=False, collection="TestColl"):
        # Maps dataset type name -> {frozen dataId -> stored object}.
        self.datasets = {}
        self.fullRegistry = fullRegistry
        self.run = collection
        if fullRegistry:
            configPath = os.path.join(os.path.dirname(__file__), "config/butler.yaml")
            self.registry = Registry.fromConfig(ButlerConfig(configPath), create=True)
            self.registry.registerRun(collection)
        else:
            self.registry = SimpleNamespace(dimensions=DimensionUniverse.fromConfig())

    def _standardizeArgs(self, datasetRefOrType, dataId=None, **kwds):
        """Copied from real Butler.

        Resolve a DatasetRef, DatasetType, or dataset type name into a
        ``(datasetType, dataId)`` pair.
        """
        if isinstance(datasetRefOrType, DatasetRef):
            if dataId is not None or kwds:
                raise ValueError("DatasetRef given, cannot use dataId as well")
            return datasetRefOrType.datasetType, datasetRefOrType.dataId
        # Don't check whether dataId is provided, because Registry APIs
        # can usually construct a better error message when it wasn't.
        if isinstance(datasetRefOrType, DatasetType):
            datasetType = datasetRefOrType
        else:
            datasetType = self.registry.getDatasetType(datasetRefOrType)
        return datasetType, dataId

    @staticmethod
    def key(dataId):
        """Make a hashable dict key out of a dataId."""
        return frozenset(dataId.items())

    def get(self, datasetRefOrType, dataId=None, parameters=None, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.get: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsdata = self.datasets.get(datasetType.name)
        if dsdata:
            return dsdata.get(self.key(dataId))
        # Unknown dataset type (or empty store) yields None.
        return None

    def put(self, obj, datasetRefOrType, dataId=None, producer=None, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.put: datasetType=%s dataId=%s obj=%r", datasetType.name, dataId, obj)
        self.datasets.setdefault(datasetType.name, {})[self.key(dataId)] = obj
        if self.fullRegistry:
            return self.registry.insertDatasets(datasetType, dataIds=[dataId], run=self.run,
                                                producer=producer, recursive=False, **kwds)
        # We should return DatasetRef with a reasonable unique ID; use the
        # total count of stored datasets.
        refId = sum(len(val) for val in self.datasets.values())
        return DatasetRef(datasetType, dataId, id=refId)

    def remove(self, datasetRefOrType, dataId=None, *, delete=True, remember=True, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.remove: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsdata = self.datasets.get(datasetType.name)
        del dsdata[self.key(dataId)]
        ref = self.registry.find(self.run, datasetType, dataId, **kwds)
        if remember:
            self.registry.disassociate(self.run, [ref])
        else:
            self.registry.removeDataset(ref)

186 

187 

def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `~lsst.pipe.base.Pipeline`
        Iterable of TaskDef instances.
    """
    for taskDef in pipeline:
        taskTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        allTypes = itertools.chain(taskTypes.initInputs, taskTypes.initOutputs,
                                   taskTypes.inputs, taskTypes.outputs,
                                   taskTypes.prerequisites)
        for datasetType in allTypes:
            _LOG.info("Registering %s with registry", datasetType)
            # registerDatasetType is a no-op if the type already exists and
            # is consistent, and raises if it is inconsistent.
            registry.registerDatasetType(datasetType)

209 

210 

def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, skipExisting=False, clobberExisting=False):
    """Make simple QuantumGraph for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory
    registry and butler, fills them with minimal data, and generates
    QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs
        already exist.
    clobberExisting : `bool`, optional
        If `True`, overwrite any outputs that already exist.  Cannot be
        `True` if ``skipExisting`` is.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance
    """

    if pipeline is None:
        pipeline = pipeBase.Pipeline("test pipeline")
        pipeline.addTask(AddTask, "task1")
        pipeline = list(pipeline.toExpandedPipeline())

    # Registry population happens only for a freshly created butler; one
    # passed in from a previous call is assumed to be already populated.
    if butler is None:

        butler = ButlerMock(fullRegistry=True)

        # Add dataset types to registry
        registerDatasetTypes(butler.registry, pipeline)

        # Small set of DataIds included in QGraph; one detector per quantum.
        records = [dict(instrument="INSTR", id=i, full_name=str(i)) for i in range(nQuanta)]
        dataIds = [dict(instrument="INSTR", detector=detector) for detector in range(nQuanta)]

        # Add all needed dimensions to registry
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", *records)

        # Add inputs to butler
        for i, dataId in enumerate(dataIds):
            data = numpy.array([i, 10*i])
            butler.put(data, "add_input", dataId)

    # Make the graph, task factory is not needed here
    builder = pipeBase.GraphBuilder(registry=butler.registry,
                                    skipExisting=skipExisting, clobberExisting=clobberExisting)
    qgraph = builder.makeGraph(
        pipeline,
        # defaultdict whose factory returns a fresh [butler.run] list, so
        # every dataset type name maps to the same single input collection.
        inputCollections=defaultdict(functools.partial(list, [butler.run])),
        outputCollection=butler.run,
        userQuery=""
    )

    return butler, qgraph