# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Common classes and methods for use in unit tests.
"""

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock", "ButlerMock"]

import contextlib
import itertools
import logging
import os
from types import SimpleNamespace

import numpy

from lsst.daf.butler import (ButlerConfig, CollectionSearch, DatasetRef,
                             DatasetType, DimensionUniverse, Registry)
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.pipe.base import connectionTypes as cT

_LOG = logging.getLogger(__name__)

class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector")):
    input = cT.Input(name="add_input",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    output = cT.Output(name="add_output",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    initout = cT.InitOutput(name="add_init_output",
                            storageClass="NumpyArray",
                            doc="Init output dataset type for this task")


class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    addend = pexConfig.Field(doc="Amount to add to the input", dtype=int, default=3)
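
# Usage sketch (illustrative, not executed by the tests): AddTaskConfig exposes
# both the `addend` field and the connections defined above; dataset type names
# can be renamed via the `connections` sub-config, assuming standard
# PipelineTaskConfig behavior:
#
#     config = AddTaskConfig()
#     config.addend = 5
#     config.connections.input = "my_add_input"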

# Example task which overrides the run() method.
class AddTask(pipeBase.PipelineTask):
    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    countExec = 0
    """Number of times the run() method was called for this class"""

    stopAt = -1
    """Call count at which run() raises an exception; -1 disables this."""

    def run(self, input):
        if AddTask.stopAt == AddTask.countExec:
            raise RuntimeError("pretend something bad happened")
        AddTask.countExec += 1
        self.metadata.add("add", self.config.addend)
        output = [val + self.config.addend for val in input]
        _LOG.info("input = %s, output = %s", input, output)
        return pipeBase.Struct(output=output)
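
# Usage sketch (illustrative, not executed by the tests): with the default
# addend of 3, run() shifts every element of its input and records the addend
# in task metadata:
#
#     task = AddTask(config=AddTaskConfig())
#     result = task.run(numpy.array([1, 2, 3]))
#     assert list(result.output) == [4, 5, 6]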

class AddTaskFactoryMock(pipeBase.TaskFactory):
    def loadTaskClass(self, taskName):
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(config)
        return taskClass(config=config, initInputs=None)
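
# Usage sketch (illustrative): the factory builds a task from an optional
# config, applying overrides when given; the butler argument is unused here:
#
#     factory = AddTaskFactoryMock()
#     task = factory.makeTask(AddTask, None, None, butler=None)
#     assert isinstance(task, AddTask)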

class ButlerMock:
    """Mock version of butler, usable only for testing.

    Parameters
    ----------
    fullRegistry : `bool`, optional
        If `True`, instantiate a SQLite registry with the default
        configuration. If `False` (default), the registry is just a namespace
        with a `dimensions` attribute containing the `DimensionUniverse` from
        the default configuration.
    collection : `str`, optional
        Name of the run collection used for datasets, default "TestColl".
    """
    def __init__(self, fullRegistry=False, collection="TestColl"):
        self.datasets = {}
        self.fullRegistry = fullRegistry
        if self.fullRegistry:
            testDir = os.path.dirname(__file__)
            configFile = os.path.join(testDir, "config/butler.yaml")
            butlerConfig = ButlerConfig(configFile)
            self.registry = Registry.fromConfig(butlerConfig, create=True)
            self.registry.registerRun(collection)
        else:
            self.registry = SimpleNamespace(dimensions=DimensionUniverse.fromConfig())
        self.run = collection

    def _standardizeArgs(self, datasetRefOrType, dataId=None, **kwds):
        """Normalize arguments into a (datasetType, dataId) pair.

        Copied from the real Butler.
        """
        if isinstance(datasetRefOrType, DatasetRef):
            if dataId is not None or kwds:
                raise ValueError("DatasetRef given, cannot use dataId as well")
            datasetType = datasetRefOrType.datasetType
            dataId = datasetRefOrType.dataId
        else:
            # Don't check whether DataId is provided, because Registry APIs
            # can usually construct a better error message when it wasn't.
            if isinstance(datasetRefOrType, DatasetType):
                datasetType = datasetRefOrType
            else:
                datasetType = self.registry.getDatasetType(datasetRefOrType)
        return datasetType, dataId

    @staticmethod
    def key(dataId):
        """Make a hashable dict key out of a dataId.
        """
        return frozenset(dataId.items())

    @contextlib.contextmanager
    def transaction(self):
        # No-op transaction; the mock has nothing to roll back.
        yield
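
    # Note (illustrative): key() makes dict-like dataIds usable as dict keys,
    # e.g. key({"instrument": "INSTR", "detector": 0}) returns
    # frozenset({("instrument", "INSTR"), ("detector", 0)}).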

    def get(self, datasetRefOrType, dataId=None, parameters=None, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.get: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.get(dsTypeName)
        if dsdata:
            return dsdata.get(key)
        return None

    def put(self, obj, datasetRefOrType, dataId=None, producer=None, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.put: datasetType=%s dataId=%s obj=%r", datasetType.name, dataId, obj)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.setdefault(dsTypeName, {})
        dsdata[key] = obj
        if self.fullRegistry:
            # insertDatasets() returns a list; unpack the single new ref so
            # both branches return a DatasetRef.
            ref, = self.registry.insertDatasets(datasetType, dataIds=[dataId], run=self.run,
                                                producer=producer, **kwds)
        else:
            # Return a DatasetRef with a reasonable unique ID; the running
            # count of stored datasets serves as that ID.
            refId = sum(len(val) for val in self.datasets.values())
            ref = DatasetRef(datasetType, dataId, id=refId)
        return ref

    def remove(self, datasetRefOrType, dataId=None, *, delete=True, remember=True, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.remove: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        # Index directly so an unknown dataset type raises KeyError rather
        # than TypeError on None.
        dsdata = self.datasets[dsTypeName]
        del dsdata[key]
        ref = self.registry.find(self.run, datasetType, dataId, **kwds)
        if remember:
            self.registry.disassociate(self.run, [ref])
        else:
            self.registry.removeDatasets([ref])
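
# Usage sketch (illustrative): in the lightweight mode, datasets live in a
# plain dict; a DatasetType instance must be passed because the namespace
# registry cannot look names up (constructor arguments follow the usage in
# registerDatasetTypes below and are assumptions here):
#
#     butler = ButlerMock()
#     dsType = DatasetType("add_input", ("instrument", "detector"),
#                          storageClass="NumpyArray",
#                          universe=butler.registry.dimensions)
#     dataId = dict(instrument="INSTR", detector=0)
#     butler.put(numpy.array([1, 2]), dsType, dataId)
#     assert butler.get(dsType, dataId) is not None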

def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `~lsst.pipe.base.Pipeline`
        Iterable of TaskDef instances.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        packagesDatasetType = DatasetType("packages", {},
                                          storageClass="Packages",
                                          universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # This is a no-op if the dataset type already exists and is
            # consistent, and it raises if it is inconsistent.
            registry.registerDatasetType(datasetType)
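
# Usage sketch (illustrative): register everything a one-task pipeline needs
# before building a graph; this mirrors what makeSimpleQGraph does below:
#
#     butler = ButlerMock(fullRegistry=True)
#     pipeline = pipeBase.Pipeline("example")
#     pipeline.addTask(AddTask, "task1")
#     registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())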

def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, skipExisting=False):
    """Make a simple QuantumGraph for tests.

    Makes a simple one-task pipeline with AddTask, sets up an in-memory
    registry and butler, fills them with minimal data, and generates a
    QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in the graph.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None`, a one-task pipeline is made with `AddTask` and the
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; this should be an instance returned from a
        previous call to this method.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist. Default is `False`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """

    if pipeline is None:
        pipeline = pipeBase.Pipeline("test pipeline")
        pipeline.addTask(AddTask, "task1")
    pipeline = list(pipeline.toExpandedPipeline())

    if butler is None:
        butler = ButlerMock(fullRegistry=True)

        # Add dataset types to registry.
        registerDatasetTypes(butler.registry, pipeline)

        # Small set of DataIds included in QGraph.
        records = [dict(instrument="INSTR", id=i, full_name=str(i)) for i in range(nQuanta)]
        dataIds = [dict(instrument="INSTR", detector=detector) for detector in range(nQuanta)]

        # Add all needed dimensions to registry.
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", *records)

        # Add inputs to butler.
        for i, dataId in enumerate(dataIds):
            data = numpy.array([i, 10*i])
            butler.put(data, "add_input", dataId)

    # Make the graph; a task factory is not needed here.
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=CollectionSearch.fromExpression(butler.run),
        run=butler.run,
        userQuery=""
    )

    return butler, qgraph
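
# Usage sketch (illustrative): build a three-quantum graph; per the docstring,
# the returned butler can be fed back in to reuse the same registry and stored
# datasets on a second call:
#
#     butler, qgraph = makeSimpleQGraph(nQuanta=3)
#     butler, qgraph2 = makeSimpleQGraph(nQuanta=3, butler=butler,
#                                        skipExisting=True)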