Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

24 

25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"] 

26 

27import itertools 

28import logging 

29import numpy 

30 

31from lsst.base import Packages 

32from lsst.daf.base import PropertySet 

33from lsst.daf.butler import Butler, ButlerURI, Config, DatasetType 

34import lsst.daf.butler.tests as butlerTests 

35from lsst.daf.butler.core.logging import ButlerLogRecords 

36import lsst.pex.config as pexConfig 

37from lsst.utils import doImport 

38from ... import base as pipeBase 

39from .. import connectionTypes as cT 

40 

41_LOG = logging.getLogger(__name__) 

42 

43 

# Minimal stand-in for an Instrument as needed by unit tests.  It offers an
# instrument-like API without depending on Instrument itself, because
# pipe_base does not explicitly depend on obs_base.
class SimpleInstrument:
    """Instrument-like test double exposing the minimal instrument API."""

    @staticmethod
    def getName():
        """Return the fixed instrument name used by the tests."""
        return "INSTRU"

    def applyConfigOverrides(self, name, config):
        """Accept and ignore config overrides (intentional no-op)."""
        pass

55 

56 

class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connection definitions for `AddTask`: one regular input, two regular
    outputs, and one init-output.
    """

    # Per-(instrument, detector) numpy-array input.
    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    # First per-quantum output.
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    # Second per-quantum output.
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    # Init-output dataset (no dimensions).
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )

78 

79 

class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Configuration for `AddTask`."""

    # Constant that AddTask.run() adds to its input array.
    addend = pexConfig.Field(dtype=int, default=3, doc="amount to add")

85 

86 

class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask used by unit tests; carries extra hooks that
    specific tests depend on.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):
        if self.taskFactory:
            # Bookkeeping for the factory: optionally simulate a failure on
            # a predetermined invocation, then count this execution.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        first = input + self.config.addend
        second = first + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, first, second)
        return pipeBase.Struct(output=first, output2=second)

114 

115 

class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Task factory that instantiates `AddTask`.

    It also owns the bookkeeping counters that `AddTask.run` uses to report
    progress back to unit tests.
    """

    def __init__(self, stopAt=-1):
        # Number of completed run() calls; incremented by AddTask itself.
        self.countExec = 0
        # AddTask raises when countExec reaches this value (-1 means never).
        self.stopAt = stopAt

    def makeTask(self, taskClass, name, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        instance = taskClass(config=config, initInputs=None, name=name)
        # Give the task a handle back to this factory for bookkeeping.
        instance.taskFactory = self
        return instance

134 

135 

def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        toExpandedPipeline on a `~lsst.pipe.base.Pipeline` object
    """
    # The "packages" dataset type is identical for every task, so build it
    # once instead of recreating it on every loop iteration.
    packagesDatasetType = DatasetType("packages", {},
                                      storageClass="Packages",
                                      universe=registry.dimensions)
    for taskDef in pipeline:
        # Per-task config dataset; its name depends on the task, so it must
        # be built inside the loop.
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the dataset type already exists and
            # is consistent, and it raises if it is inconsistent.  Component
            # dataset types must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)

167 

168 

def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` function by calling this first and passing the result
    to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline or
        if no instrument should be added then an empty string or `None`, by
        default None

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Chain the tasks through their connection templates so that each task's
    # output is the next task's input, fixing the execution order via data
    # dependencies.
    for level in range(nQuanta):
        label = f"task{level}"
        pipeline.addTask(AddTask, label)
        pipeline.addConfigOverride(label, "connections.in_tmpl", level)
        pipeline.addConfigOverride(label, "connections.out_tmpl", level + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline

201 

202 

def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If true make in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    rootURI = ButlerURI(root, forceDirectory=True)
    if not rootURI.isLocal:
        raise ValueError(f"Only works with local root not {rootURI}")
    config = Config()
    if not inMemory:
        # Persist registry and datastore on disk under the repository root.
        config["registry", "db"] = f"sqlite:///{rootURI.ospath}/gen3.sqlite"
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(rootURI, {}, config=config)
    return Butler(butler=repo, run=run)

230 

231 

def populateButler(pipeline, butler, datasetTypes=None):
    """Populate data butler with data needed for test.

    Initializes data butler with a bunch of items:
    - registers dataset types which are defined by pipeline
    - create dimensions data for (instrument, detector)
    - adds datasets based on ``datasetTypes`` dictionary, if dictionary is
      missing then a single dataset with type "add_dataset0" is added

    All datasets added to butler have ``dataId={instrument=instrument,
    detector=0}`` where ``instrument`` is extracted from pipeline, "INSTR" is
    used if pipeline is missing instrument definition. Type of the dataset is
    guessed from dataset type name (assumes that pipeline is made of `AddTask`
    tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    """

    # Add dataset types to registry
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    # Resolve the instrument name: import it if given as a string, fall back
    # to "INSTR" when the pipeline defines no instrument.
    instrument = pipeline.getInstrument()
    if instrument is not None:
        if isinstance(instrument, str):
            instrument = doImport(instrument)
        instrumentName = instrument.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName,
                                                         id=0, full_name="det0"))

    taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
    # Add inputs to butler; a dataset payload is chosen for each requested
    # dataset type based on its name suffix.
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            # Explicit collection names must be registered before use; the
            # None key means "use butler's default run".
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dataset with empty dimensions, hence no dataId arguments.
                data = Packages({"python": "9.9.99"})
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # find a config from matching task name or make a new one
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = PropertySet()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    # Plain AddTask input/output: a small numpy array.
                    data = numpy.array([0., 1., 2., 5.])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)

304 

305 

def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, run="test",
                     skipExistingIn=None, inMemory=True, userQuery="",
                     datasetTypes=None):
    """Make simple QuantumGraph for tests.

    Builds a simple AddTask-based pipeline (unless one is supplied), sets up
    a registry and butler, fills them with minimal data, and generates a
    QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is None.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, if None then new data butler is created by
        calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is None.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if ``butler``
        is None.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. May be any of the types
        accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
    inMemory : `bool`, optional
        If true make in-memory repository, only used if ``butler`` is `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance
    """
    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Build the quantum graph from the populated repository.
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    graphBuilder = pipeBase.GraphBuilder(registry=butler.registry,
                                         skipExistingIn=skipExistingIn)
    outputRun = run or butler.run
    _LOG.debug("Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
               butler.collections, outputRun, userQuery)
    graph = graphBuilder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=outputRun,
        userQuery=userQuery
    )

    return butler, graph