# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Common classes and methods for use in unit tests."""

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock", "ButlerMock"]

import contextlib
import itertools
import logging
import os
from types import SimpleNamespace

import numpy

from lsst.daf.butler import (ButlerConfig, DatasetRef, DimensionUniverse,
                             DatasetType, Registry, CollectionSearch)
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.pipe.base import connectionTypes as cT

_LOG = logging.getLogger(__name__)


class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector")):
    input = cT.Input(name="add_input",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    output = cT.Output(name="add_output",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    initout = cT.InitOutput(name="add_init_output",
                            storageClass="NumpyArray",
                            doc="Init output dataset type for this task")


class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


# Example task which overrides the run() method.
class AddTask(pipeBase.PipelineTask):
    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    countExec = 0
    """Number of times the run() method was called for this class"""

    stopAt = -1
    """Raise an exception at this (zero-based) call to run()"""

    def run(self, input):
        if AddTask.stopAt == AddTask.countExec:
            raise RuntimeError("pretend something bad happened")
        AddTask.countExec += 1
        self.metadata.add("add", self.config.addend)
        output = [val + self.config.addend for val in input]
        _LOG.info("input = %s, output = %s", input, output)
        return pipeBase.Struct(output=output)
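

# Illustrative sketch (not exercised by the tests themselves): how AddTask
# behaves with its default configuration.  Assumes AddTask can be constructed
# directly with a default AddTaskConfig, the same way AddTaskFactoryMock
# below builds tasks.  Note that every call bumps the class-level counter
# AddTask.countExec.
def _demoAddTaskRun():
    task = AddTask(config=AddTaskConfig(), initInputs=None)
    # With the default addend of 3, [0, 10] maps to [3, 13].
    result = task.run(input=numpy.array([0, 10]))
    return result.output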


class AddTaskFactoryMock(pipeBase.TaskFactory):
    def loadTaskClass(self, taskName):
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(config)
        return taskClass(config=config, initInputs=None)
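

# Illustrative sketch: AddTaskFactoryMock.makeTask() falls back to the task's
# default ConfigClass when no config is given; `overrides` and `butler` are
# then unused, so `None` is safe here.
def _demoTaskFactory():
    factory = AddTaskFactoryMock()
    task = factory.makeTask(AddTask, config=None, overrides=None, butler=None)
    return task.config.addend  # default addend is 3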


class ButlerMock:
    """Mock version of butler, usable only for testing.

    Parameters
    ----------
    fullRegistry : `bool`, optional
        If `True` then instantiate a SQLite registry with the default
        configuration. If `False` then the registry is just a namespace
        with a `dimensions` attribute containing the `DimensionUniverse`
        from the default configuration.
    collection : `str`, optional
        Name of the run collection that datasets are stored in.
    """
    def __init__(self, fullRegistry=False, collection="TestColl"):
        self.datasets = {}
        self.fullRegistry = fullRegistry
        if self.fullRegistry:
            testDir = os.path.dirname(__file__)
            configFile = os.path.join(testDir, "config/butler.yaml")
            butlerConfig = ButlerConfig(configFile)
            self.registry = Registry.fromConfig(butlerConfig, create=True)
            self.registry.registerRun(collection)
        else:
            self.registry = SimpleNamespace(dimensions=DimensionUniverse.fromConfig())
        self.run = collection

    def _standardizeArgs(self, datasetRefOrType, dataId=None, **kwds):
        """Copied from the real Butler."""
        if isinstance(datasetRefOrType, DatasetRef):
            if dataId is not None or kwds:
                raise ValueError("DatasetRef given, cannot use dataId as well")
            datasetType = datasetRefOrType.datasetType
            dataId = datasetRefOrType.dataId
        else:
            # Don't check whether DataId is provided, because Registry APIs
            # can usually construct a better error message when it wasn't.
            if isinstance(datasetRefOrType, DatasetType):
                datasetType = datasetRefOrType
            else:
                datasetType = self.registry.getDatasetType(datasetRefOrType)
        return datasetType, dataId

    @staticmethod
    def key(dataId):
        """Make a dict key out of a dataId."""
        return frozenset(dataId.items())

    @contextlib.contextmanager
    def transaction(self):
        yield

    def get(self, datasetRefOrType, dataId=None, parameters=None, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.get: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.get(dsTypeName)
        if dsdata is not None and key in dsdata:
            return dsdata[key]
        raise LookupError(f"No dataset {dsTypeName} with dataId {dataId}")

    def put(self, obj, datasetRefOrType, dataId=None, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.put: datasetType=%s dataId=%s obj=%r", datasetType.name, dataId, obj)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.setdefault(dsTypeName, {})
        dsdata[key] = obj
        if self.fullRegistry:
            ref = self.registry.insertDatasets(datasetType, dataIds=[dataId], run=self.run, **kwds)
        else:
            # Return a DatasetRef with a reasonable ID; the ID is supposed
            # to be unique, so use the total number of stored datasets.
            refId = sum(len(val) for val in self.datasets.values())
            ref = DatasetRef(datasetType, dataId, id=refId)
        return ref

    def remove(self, datasetRefOrType, dataId=None, *, delete=True, remember=True, **kwds):
        # `delete` is accepted for API compatibility but ignored by this mock.
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.remove: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.get(dsTypeName)
        del dsdata[key]
        ref = self.registry.find(self.run, datasetType, dataId, **kwds)
        if remember:
            self.registry.disassociate(self.run, [ref])
        else:
            self.registry.removeDatasets([ref])
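

# Illustrative sketch: a put/get round trip with the lightweight ButlerMock.
# Without fullRegistry the mock cannot look dataset types up by name, so a
# DatasetType is built explicitly here; "NumpyArray" matches the storage
# class declared by AddTaskConnections above.
def _demoButlerMock():
    butler = ButlerMock()
    datasetType = DatasetType("add_input", ("instrument", "detector"),
                              storageClass="NumpyArray",
                              universe=butler.registry.dimensions)
    dataId = dict(instrument="INSTR", detector=0)
    butler.put(numpy.array([1, 2]), datasetType, dataId)
    return butler.get(datasetType, dataId)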


def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `~lsst.pipe.base.Pipeline`
        Iterable of TaskDef instances.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        packagesDatasetType = DatasetType("packages", {},
                                          storageClass="Packages",
                                          universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # this is a no-op if it already exists and is consistent,
            # and it raises if it is inconsistent.
            registry.registerDatasetType(datasetType)


def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, skipExisting=False):
    """Make a simple QuantumGraph for tests.

    Makes a simple one-task pipeline with `AddTask`, sets up an in-memory
    registry and butler, fills them with minimal data, and generates a
    QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`, optional
        Number of quanta in the graph, default is 5.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None` then a one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; this should be an instance returned from a
        previous call to this method.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist; default is `False`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """

    if pipeline is None:
        pipeline = pipeBase.Pipeline("test pipeline")
        pipeline.addTask(AddTask, "task1")
    pipeline = list(pipeline.toExpandedPipeline())

    if butler is None:
        butler = ButlerMock(fullRegistry=True)

        # Add dataset types to registry
        registerDatasetTypes(butler.registry, pipeline)

        # Small set of DataIds included in QGraph
        records = [dict(instrument="INSTR", id=i, full_name=str(i)) for i in range(nQuanta)]
        dataIds = [dict(instrument="INSTR", detector=detector) for detector in range(nQuanta)]

        # Add all needed dimensions to registry
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", *records)

        # Add inputs to butler
        for i, dataId in enumerate(dataIds):
            data = numpy.array([i, 10*i])
            butler.put(data, "add_input", dataId)

    # Make the graph; a task factory is not needed here.
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=CollectionSearch.fromExpression(butler.run),
        run=butler.run,
        userQuery=""
    )

    return butler, qgraph
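

# Illustrative sketch: typical use of makeSimpleQGraph in a unit test.  A
# second call that passes the returned butler back in, together with
# skipExisting=True, would only create quanta whose outputs are missing.
def _demoMakeSimpleQGraph():
    butler, qgraph = makeSimpleQGraph(nQuanta=3)
    # qgraph now covers detectors 0-2 for the single AddTask in the
    # default pipeline.
    return butler, qgraph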