
# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

24 

25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock", "ButlerMock"] 

26 

27import contextlib 

28import itertools 

29import logging 

30import numpy 

31import os 

32from types import SimpleNamespace 

33 

34from lsst.daf.butler import (ButlerConfig, DatasetRef, DimensionUniverse, 

35 DatasetType, Registry, CollectionSearch) 

36import lsst.pex.config as pexConfig 

37import lsst.pipe.base as pipeBase 

38from lsst.pipe.base import connectionTypes as cT 

39 

40_LOG = logging.getLogger(__name__) 

41 

42 

class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector")):
    input = cT.Input(name="add_input",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    output = cT.Output(name="add_output",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    initout = cT.InitOutput(name="add_init_output",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")


class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


# Example task which overrides the run() method.
class AddTask(pipeBase.PipelineTask):
    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    countExec = 0
    """Number of times run() method was called for this class"""

    stopAt = -1
    """Call number at which run() raises an exception; negative means never"""

    def run(self, input):
        if AddTask.stopAt == AddTask.countExec:
            raise RuntimeError("pretend something bad happened")
        AddTask.countExec += 1
        self.metadata.add("add", self.config.addend)
        output = [val + self.config.addend for val in input]
        _LOG.info("input = %s, output = %s", input, output)
        return pipeBase.Struct(output=output)
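

# A minimal sketch of exercising AddTask.run() directly, assuming the default
# AddTaskConfig (addend=3). ``_exampleAddTaskRun`` is an illustrative helper
# added for this write-up, not part of the original test utilities.
def _exampleAddTaskRun():
    task = AddTask(config=AddTaskConfig())
    result = task.run(input=[1, 2, 3])
    # Each element is incremented by the configured addend; note that run()
    # also bumps the class-level AddTask.countExec counter as a side effect.
    assert result.output == [4, 5, 6]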


class AddTaskFactoryMock(pipeBase.TaskFactory):
    def loadTaskClass(self, taskName):
        # Returns None implicitly for any other task name
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(config)
        return taskClass(config=config, initInputs=None)
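

# Sketch of how AddTaskFactoryMock can be driven directly, assuming callers
# that pass `None` for config, overrides and butler. ``_exampleMakeTask`` is
# an illustrative helper, not part of the original test utilities.
def _exampleMakeTask():
    factory = AddTaskFactoryMock()
    taskClass, taskName = factory.loadTaskClass("AddTask")
    task = factory.makeTask(taskClass, None, None, None)
    # With no config given, makeTask falls back to the default AddTaskConfig.
    assert task.config.addend == 3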


class ButlerMock:
    """Mock version of butler, only usable for testing.

    Parameters
    ----------
    fullRegistry : `bool`, optional
        If True then instantiate SQLite registry with default configuration.
        If False then registry is just a namespace with `dimensions` attribute
        containing DimensionUniverse from default configuration.
    collection : `str`, optional
        Name of the run collection used for datasets.
    """
    def __init__(self, fullRegistry=False, collection="TestColl"):
        self.datasets = {}
        self.fullRegistry = fullRegistry
        if self.fullRegistry:
            testDir = os.path.dirname(__file__)
            configFile = os.path.join(testDir, "config/butler.yaml")
            butlerConfig = ButlerConfig(configFile)
            self.registry = Registry.fromConfig(butlerConfig, create=True)
            self.registry.registerRun(collection)
        else:
            self.registry = SimpleNamespace(dimensions=DimensionUniverse.fromConfig())
        self.run = collection

    def _standardizeArgs(self, datasetRefOrType, dataId=None, **kwds):
        """Copied from real Butler.
        """
        if isinstance(datasetRefOrType, DatasetRef):
            if dataId is not None or kwds:
                raise ValueError("DatasetRef given, cannot use dataId as well")
            datasetType = datasetRefOrType.datasetType
            dataId = datasetRefOrType.dataId
        else:
            # Don't check whether DataId is provided, because Registry APIs
            # can usually construct a better error message when it wasn't.
            if isinstance(datasetRefOrType, DatasetType):
                datasetType = datasetRefOrType
            else:
                datasetType = self.registry.getDatasetType(datasetRefOrType)
        return datasetType, dataId

    @classmethod
    def _unpickle(cls, fullRegistry, run):
        return cls(fullRegistry, run)

    def __reduce__(self):
        """Support pickling.
        """
        return (ButlerMock._unpickle, (self.fullRegistry, self.run))

    @staticmethod
    def key(dataId):
        """Make a dict key out of dataId.
        """
        return frozenset(dataId.items())

    @contextlib.contextmanager
    def transaction(self):
        """No-op transaction context manager, mimicking Butler.transaction.
        """
        yield

    def get(self, datasetRefOrType, dataId=None, parameters=None, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.get: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.get(dsTypeName)
        if dsdata and key in dsdata:
            return dsdata[key]
        raise LookupError(f"No dataset {dsTypeName} with dataId {dataId}")

    def put(self, obj, datasetRefOrType, dataId=None, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.put: datasetType=%s dataId=%s obj=%r", datasetType.name, dataId, obj)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.setdefault(dsTypeName, {})
        dsdata[key] = obj
        if self.fullRegistry:
            # insertDatasets returns a list; unpack the single new ref
            ref, = self.registry.insertDatasets(datasetType, dataIds=[dataId], run=self.run, **kwds)
        else:
            # We should return a DatasetRef with a reasonable ID; IDs are
            # supposed to be unique, so use the total dataset count.
            refId = sum(len(val) for val in self.datasets.values())
            ref = DatasetRef(datasetType, dataId, id=refId)
        return ref

    def remove(self, datasetRefOrType, dataId=None, *, delete=True, remember=True, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.remove: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        # Raises KeyError if the dataset type or dataId is unknown
        dsdata = self.datasets[dsTypeName]
        del dsdata[key]
        ref = self.registry.find(self.run, datasetType, dataId, **kwds)
        if remember:
            self.registry.disassociate(self.run, [ref])
        else:
            self.registry.removeDatasets([ref])
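

# Sketch of the lightweight ButlerMock mode (fullRegistry=False), where put()
# and get() only touch the in-memory ``datasets`` dict and dataIds can be
# plain dicts. ``_exampleButlerMock`` is an illustrative helper, not part of
# the original test utilities.
def _exampleButlerMock():
    butler = ButlerMock()
    dsType = DatasetType("add_input", {}, storageClass="NumpyArray",
                         universe=butler.registry.dimensions)
    butler.put(numpy.array([1, 2]), dsType, {})
    assert list(butler.get(dsType, {})) == [1, 2]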


def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : iterable of `~lsst.pipe.base.TaskDef`
        Iterable of TaskDef instances.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        packagesDatasetType = DatasetType("packages", {},
                                          storageClass="Packages",
                                          universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # This is a no-op if the dataset type already exists and is
            # consistent, and it raises if it is inconsistent; components
            # must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
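

# Sketch of registering the AddTask pipeline's dataset types with a
# full-registry ButlerMock, mirroring the first steps of makeSimpleQGraph()
# below. ``_exampleRegisterDatasetTypes`` is an illustrative helper, not part
# of the original test utilities.
def _exampleRegisterDatasetTypes():
    butler = ButlerMock(fullRegistry=True)
    pipeline = pipeBase.Pipeline("example pipeline")
    pipeline.addTask(AddTask, "task1")
    registerDatasetTypes(butler.registry, list(pipeline.toExpandedPipeline()))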


def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, skipExisting=False):
    """Make simple QuantumGraph for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory
    registry and butler, fills them with minimal data, and generates
    QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`, optional
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None` then one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist; default is `False`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    if pipeline is None:
        pipeline = pipeBase.Pipeline("test pipeline")
        pipeline.addTask(AddTask, "task1")
        pipeline = list(pipeline.toExpandedPipeline())

    if butler is None:

        butler = ButlerMock(fullRegistry=True)

        # Add dataset types to registry
        registerDatasetTypes(butler.registry, pipeline)

        # Small set of DataIds included in QGraph
        records = [dict(instrument="INSTR", id=i, full_name=str(i)) for i in range(nQuanta)]
        dataIds = [dict(instrument="INSTR", detector=detector) for detector in range(nQuanta)]

        # Add all needed dimensions to registry
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", *records)

        # Add inputs to butler
        for i, dataId in enumerate(dataIds):
            data = numpy.array([i, 10*i])
            butler.put(data, "add_input", dataId)

    # Make the graph; task factory is not needed here
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=CollectionSearch.fromExpression(butler.run),
        run=butler.run,
        userQuery=""
    )

    return butler, qgraph
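

# Typical use in a unit test, sketched under the assumption that the returned
# butler can be fed back in to build a second graph over the same data.
# ``_exampleMakeSimpleQGraph`` is an illustrative helper, not part of the
# original test utilities.
def _exampleMakeSimpleQGraph():
    butler, qgraph = makeSimpleQGraph(nQuanta=3)
    # Reuse the same butler; with skipExisting=True, quanta whose outputs
    # already exist in the registry would not be recreated.
    butler, qgraph2 = makeSimpleQGraph(butler=butler, skipExisting=True)
    return qgraph, qgraph2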