Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

21 

"""Common classes and helpers shared by the unit tests."""

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock", "ButlerMock"]

26 

27import itertools 

28import logging 

29import numpy 

30import os 

31from types import SimpleNamespace 

32 

33from lsst.daf.butler import (ButlerConfig, DatasetRef, DimensionUniverse, 

34 DatasetType, Registry, CollectionSearch) 

35import lsst.pex.config as pexConfig 

36import lsst.pipe.base as pipeBase 

37from lsst.pipe.base import connectionTypes as cT 

38 

# Module-level logger shared by all helpers in this file.
_LOG = logging.getLogger(__name__)

40 

41 

class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector")):
    """Connections for AddTask: a per-detector input and output plus a
    single init-output dataset, all stored as NumPy arrays.
    """
    input = cT.Input(
        name="add_input",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_output",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )

55 

56 

class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Configuration for AddTask; ``addend`` is the constant added to the
    input array.
    """
    addend = pexConfig.Field(dtype=int, default=3, doc="amount to add")

60 

61 

class AddTask(pipeBase.PipelineTask):
    """Example task which overrides the run() method: adds a configured
    constant to every element of its input array.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    # InitOutputs for this task
    initout = numpy.array([999])

    # Number of times run() method was called for this class (shared
    # across all instances).
    countExec = 0

    # Raises exception at this call to run(); -1 disables the failure.
    stopAt = -1

    def run(self, input):
        # Simulate a failure at a predetermined invocation so tests can
        # exercise error handling.
        if AddTask.stopAt == AddTask.countExec:
            raise RuntimeError("pretend something bad happened")
        AddTask.countExec += 1
        addend = self.config.addend
        self.metadata.add("add", addend)
        output = [val + addend for val in input]
        _LOG.info("input = %s, output = %s", input, output)
        return pipeBase.Struct(output=output)

84 

85 

class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Minimal TaskFactory that only knows how to build AddTask."""

    def loadTaskClass(self, taskName):
        # Only "AddTask" is recognized; any other name yields None
        # implicitly, matching the mock's original behavior.
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        cfg = taskClass.ConfigClass() if config is None else config
        if overrides:
            overrides.applyTo(cfg)
        return taskClass(config=cfg, initInputs=None)

97 

98 

class ButlerMock:
    """Mock version of butler, only usable for testing.

    Datasets are stored in an in-memory dict keyed by dataset type name and
    a frozen dataId, optionally backed by a real SQLite registry.

    Parameters
    ----------
    fullRegistry : `boolean`, optional
        If True then instantiate SQLite registry with default configuration.
        If False then registry is just a namespace with `dimensions` attribute
        containing DimensionUniverse from default configuration.
    collection : `str`, optional
        Name of the run collection used when storing/removing datasets.
    """
    def __init__(self, fullRegistry=False, collection="TestColl"):
        # Maps dataset type name -> {frozenset(dataId.items()) -> object}
        self.datasets = {}
        self.fullRegistry = fullRegistry
        if self.fullRegistry:
            testDir = os.path.dirname(__file__)
            configFile = os.path.join(testDir, "config/butler.yaml")
            butlerConfig = ButlerConfig(configFile)
            self.registry = Registry.fromConfig(butlerConfig, create=True)
            self.registry.registerRun(collection)
        else:
            self.registry = SimpleNamespace(dimensions=DimensionUniverse.fromConfig())
        # Identical in both branches of the original; hoisted here.
        self.run = collection

    def _standardizeArgs(self, datasetRefOrType, dataId=None, **kwds):
        """Copied from real Butler.

        Resolve a DatasetRef, DatasetType, or dataset type name (plus an
        optional dataId) into a ``(datasetType, dataId)`` pair.
        """
        if isinstance(datasetRefOrType, DatasetRef):
            if dataId is not None or kwds:
                raise ValueError("DatasetRef given, cannot use dataId as well")
            datasetType = datasetRefOrType.datasetType
            dataId = datasetRefOrType.dataId
        else:
            # Don't check whether DataId is provided, because Registry APIs
            # can usually construct a better error message when it wasn't.
            if isinstance(datasetRefOrType, DatasetType):
                datasetType = datasetRefOrType
            else:
                datasetType = self.registry.getDatasetType(datasetRefOrType)
        return datasetType, dataId

    @staticmethod
    def key(dataId):
        """Make a hashable dict key out of dataId.
        """
        return frozenset(dataId.items())

    def get(self, datasetRefOrType, dataId=None, parameters=None, **kwds):
        """Return the stored object for this dataset, or `None` if absent."""
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.get: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.get(dsTypeName)
        if dsdata:
            return dsdata.get(key)
        return None

    def put(self, obj, datasetRefOrType, dataId=None, producer=None, **kwds):
        """Store ``obj`` and return a DatasetRef for it."""
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.put: datasetType=%s dataId=%s obj=%r", datasetType.name, dataId, obj)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.setdefault(dsTypeName, {})
        dsdata[key] = obj
        if self.fullRegistry:
            ref = self.registry.insertDatasets(datasetType, dataIds=[dataId], run=self.run,
                                               producer=producer, recursive=False, **kwds)
        else:
            # we should return DatasetRef with reasonable ID, ID is supposed
            # to be unique; the running count of stored datasets (including
            # the one just added) serves that purpose.
            refId = sum(len(val) for val in self.datasets.values())
            ref = DatasetRef(datasetType, dataId, id=refId)
        return ref

    def remove(self, datasetRefOrType, dataId=None, *, delete=True, remember=True, **kwds):
        """Remove a stored dataset and update the registry accordingly."""
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.remove: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        # Index directly so that removing a dataset type that was never
        # stored raises a meaningful KeyError; the original called
        # `self.datasets.get(...)` and then `del dsdata[key]`, producing a
        # confusing TypeError on None.
        dsdata = self.datasets[dsTypeName]
        del dsdata[key]
        ref = self.registry.find(self.run, datasetType, dataId, **kwds)
        if remember:
            self.registry.disassociate(self.run, [ref])
        else:
            self.registry.removeDatasets([ref])

184 

185 

def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `~lsst.pipe.base.Pipeline`
        Iterable of TaskDef instances.
    """
    for taskDef in pipeline:
        taskTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        allTypes = itertools.chain(
            taskTypes.initInputs,
            taskTypes.initOutputs,
            taskTypes.inputs,
            taskTypes.outputs,
            taskTypes.prerequisites,
        )
        for datasetType in allTypes:
            _LOG.info("Registering %s with registry", datasetType)
            # registerDatasetType is a no-op for an existing, consistent
            # definition and raises if the definition conflicts.
            registry.registerDatasetType(datasetType)

207 

208 

def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, skipExisting=False):
    """Make simple QuantumGraph for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory
    registry and butler, fills them with minimal data, and generates
    QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method. If `None` a new butler is created
        and populated with ``nQuanta`` input datasets.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist; default is `False`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance
    """

    if pipeline is None:
        pipeline = pipeBase.Pipeline("test pipeline")
        pipeline.addTask(AddTask, "task1")
        pipeline = list(pipeline.toExpandedPipeline())

    if butler is None:

        butler = ButlerMock(fullRegistry=True)

        # Add dataset types to registry
        registerDatasetTypes(butler.registry, pipeline)

        # Small set of DataIds included in QGraph
        records = [dict(instrument="INSTR", id=i, full_name=str(i)) for i in range(nQuanta)]
        dataIds = [dict(instrument="INSTR", detector=detector) for detector in range(nQuanta)]

        # Add all needed dimensions to registry
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", *records)

        # Add inputs to butler
        for i, dataId in enumerate(dataIds):
            data = numpy.array([i, 10*i])
            butler.put(data, "add_input", dataId)

    # Make the graph, task factory is not needed here
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=CollectionSearch.fromExpression(butler.run),
        run=butler.run,
        userQuery=""
    )

    return butler, qgraph