Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

24 

25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock", "ButlerMock"] 

26 

27import contextlib 

28import itertools 

29import logging 

30import numpy 

31import os 

32from types import SimpleNamespace 

33 

34from lsst.daf.butler import (ButlerConfig, DatasetRef, DimensionUniverse, 

35 DatasetType, Registry, CollectionSearch) 

36import lsst.pex.config as pexConfig 

37import lsst.pipe.base as pipeBase 

38from lsst.pipe.base import connectionTypes as cT 

39 

40_LOG = logging.getLogger(__name__) 

41 

42 

class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector")):
    """Connections for `AddTask`: one per-detector input, one per-detector
    output, and one dimensionless init-output, all stored as NumpyArray.
    """
    input = cT.Input(name="add_input",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    output = cT.Output(name="add_output",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    initout = cT.InitOutput(name="add_init_output",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")

56 

57 

class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Configuration for `AddTask`; ``addend`` is the integer constant
    added to every input value in `AddTask.run`.
    """
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)

61 

62 

# example task which overrides run() method
class AddTask(pipeBase.PipelineTask):
    """Trivial task adding a configured constant to each input value.

    Class-level attributes ``countExec`` and ``stopAt`` are shared across
    all instances so tests can count executions and force a failure at a
    chosen call.
    """
    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    countExec = 0
    """Number of times run() method was called for this class"""

    stopAt = -1
    """Raises exception at this call to run()"""

    def run(self, input):
        """Add ``config.addend`` to every element of ``input``."""
        # Simulate a failure when the configured execution count is reached.
        if AddTask.countExec == AddTask.stopAt:
            raise RuntimeError("pretend something bad happened")
        AddTask.countExec += 1
        addend = self.config.addend
        self.metadata.add("add", addend)
        output = []
        for value in input:
            output.append(value + addend)
        _LOG.info("input = %s, output = %s", input, output)
        return pipeBase.Struct(output=output)

85 

86 

class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Task factory mock that only knows how to build `AddTask`."""

    def loadTaskClass(self, taskName):
        """Return ``(AddTask, "AddTask")`` for name "AddTask", else `None`."""
        if taskName != "AddTask":
            return None
        return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        """Instantiate the task with a default config if none is given,
        applying ``overrides`` when provided.
        """
        cfg = taskClass.ConfigClass() if config is None else config
        if overrides:
            overrides.applyTo(cfg)
        return taskClass(config=cfg, initInputs=None)

98 

99 

class ButlerMock:
    """Mock version of butler, only usable for testing.

    Datasets are stored in a plain dict keyed by dataset type name and
    frozen dataId; an optional real SQLite registry mirrors the bookkeeping.

    Parameters
    ----------
    fullRegistry : `bool`, optional
        If True then instantiate SQLite registry with default configuration.
        If False then registry is just a namespace with `dimensions` attribute
        containing DimensionUniverse from default configuration.
    collection : `str`, optional
        Name of the run collection used for all stored datasets.
    """
    def __init__(self, fullRegistry=False, collection="TestColl"):
        # dataset type name -> {frozenset(dataId.items()) -> stored object}
        self.datasets = {}
        self.fullRegistry = fullRegistry
        if self.fullRegistry:
            testDir = os.path.dirname(__file__)
            configFile = os.path.join(testDir, "config/butler.yaml")
            butlerConfig = ButlerConfig(configFile)
            self.registry = Registry.fromConfig(butlerConfig, create=True)
            self.registry.registerRun(collection)
        else:
            # lightweight stand-in: only `dimensions` attribute is usable
            self.registry = SimpleNamespace(dimensions=DimensionUniverse.fromConfig())
        self.run = collection

    def _standardizeArgs(self, datasetRefOrType, dataId=None, **kwds):
        """Copied from real Butler: resolve a DatasetRef, DatasetType, or
        dataset type name (plus optional dataId) into (datasetType, dataId).

        Raises
        ------
        ValueError
            If a DatasetRef is given together with an explicit dataId.
        """
        if isinstance(datasetRefOrType, DatasetRef):
            if dataId is not None or kwds:
                raise ValueError("DatasetRef given, cannot use dataId as well")
            datasetType = datasetRefOrType.datasetType
            dataId = datasetRefOrType.dataId
        else:
            # Don't check whether DataId is provided, because Registry APIs
            # can usually construct a better error message when it wasn't.
            if isinstance(datasetRefOrType, DatasetType):
                datasetType = datasetRefOrType
            else:
                datasetType = self.registry.getDatasetType(datasetRefOrType)
        return datasetType, dataId

    @staticmethod
    def key(dataId):
        """Make a hashable dict key out of dataId (frozenset of items)."""
        return frozenset(dataId.items())

    @contextlib.contextmanager
    def transaction(self):
        """No-op transaction context manager (mock never rolls back)."""
        yield

    def get(self, datasetRefOrType, dataId=None, parameters=None, **kwds):
        """Return the stored dataset.

        Raises
        ------
        LookupError
            If the dataset does not exist.
        """
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.get: datasetType=%s dataId=%s", datasetType.name, dataId)
        key = self.key(dataId)
        try:
            return self.datasets[datasetType.name][key]
        except KeyError:
            # Bug fix: a missing dataId inside an existing dataset type
            # previously returned None silently; any miss now raises.
            raise LookupError(f"Dataset does not exist: {datasetType.name} {dataId}") from None

    def put(self, obj, datasetRefOrType, dataId=None, **kwds):
        """Store ``obj`` and return a DatasetRef for it."""
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.put: datasetType=%s dataId=%s obj=%r", datasetType.name, dataId, obj)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.setdefault(dsTypeName, {})
        dsdata[key] = obj
        if self.fullRegistry:
            # NOTE(review): insertDatasets returns a list of refs; the
            # original returned it unchanged — preserved here.
            ref = self.registry.insertDatasets(datasetType, dataIds=[dataId], run=self.run, **kwds)
        else:
            # we should return DatasetRef with reasonable ID, ID is supposed to be unique
            refId = sum(len(val) for val in self.datasets.values())
            ref = DatasetRef(datasetType, dataId, id=refId)
        return ref

    def remove(self, datasetRefOrType, dataId=None, *, delete=True, remember=True, **kwds):
        """Remove a stored dataset; ``delete`` is accepted for interface
        compatibility but unused by this mock.

        Raises
        ------
        LookupError
            If the dataset does not exist.
        """
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.remove: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.get(dsTypeName)
        if dsdata is None or key not in dsdata:
            # Bug fix: unknown dataset type previously crashed with
            # TypeError (indexing None); raise LookupError instead.
            raise LookupError(f"Dataset does not exist: {dsTypeName} {dataId}")
        del dsdata[key]
        if self.fullRegistry:
            # Registry bookkeeping is only possible with a real registry;
            # consistent with put() which also checks fullRegistry.
            ref = self.registry.find(self.run, datasetType, dataId, **kwds)
            if remember:
                self.registry.disassociate(self.run, [ref])
            else:
                self.registry.removeDatasets([ref])

188 

189 

def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `~lsst.pipe.base.Pipeline`
        Iterable of TaskDef instances.
    """
    # "packages" dataset type is identical for every task; build it once
    # instead of once per taskDef (loop-invariant hoist).
    packagesDatasetType = DatasetType("packages", {},
                                      storageClass="Packages",
                                      universe=registry.dimensions)
    for taskDef in pipeline:
        # per-task config dataset type
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # this is a no-op if it already exists and is consistent,
            # and it raises if it is inconsistent. But components must be
            # skipped
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)

220 

221 

def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, skipExisting=False):
    """Make simple QuantumGraph for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory
    registry and butler, fills them with minimal data, and generates
    QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`, optional
        Number of quanta in a graph, default is 5.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None` then one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist.  Default is `False`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance
    """

    if pipeline is None:
        # default pipeline: single AddTask; expand into a list of TaskDefs
        pipeline = pipeBase.Pipeline("test pipeline")
        pipeline.addTask(AddTask, "task1")
        pipeline = list(pipeline.toExpandedPipeline())

    if butler is None:
        # fresh butler: register everything and populate input datasets;
        # a caller-provided butler is assumed to be populated already.
        butler = ButlerMock(fullRegistry=True)

        # Add dataset types to registry
        registerDatasetTypes(butler.registry, pipeline)

        # Small set of DataIds included in QGraph
        records = [dict(instrument="INSTR", id=i, full_name=str(i)) for i in range(nQuanta)]
        dataIds = [dict(instrument="INSTR", detector=detector) for detector in range(nQuanta)]

        # Add all needed dimensions to registry
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", *records)

        # Add inputs to butler
        for i, dataId in enumerate(dataIds):
            data = numpy.array([i, 10*i])
            butler.put(data, "add_input", dataId)

    # Make the graph, task factory is not needed here
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=CollectionSearch.fromExpression(butler.run),
        run=butler.run,
        userQuery=""
    )

    return butler, qgraph