Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 26%

137 statements  

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Common classes and methods for use in unit tests."""

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.base import Packages
from lsst.daf.butler import Butler, Config, DatasetType
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImport

from ... import base as pipeBase
from .. import connectionTypes as cT
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..task import _TASK_FULL_METADATA_TYPE

_LOG = logging.getLogger(__name__)



# SimpleInstrument has an instrument-like API as needed for unit testing,
# but cannot explicitly depend on Instrument because pipe_base does not
# depend on obs_base.
class SimpleInstrument:
    def __init__(self, *args, **kwargs):
        pass

    @staticmethod
    def getName():
        return "INSTRU"

    def applyConfigOverrides(self, name, config):
        pass



class AddTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask; it has one input and two outputs,
    plus one init output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


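# Illustrative note (not part of the original module): the ``{in_tmpl}`` and
# ``{out_tmpl}`` placeholders are filled from the connection templates, so
# with the defaults above the task reads "add_dataset_in" and writes
# "add_dataset_out" and "add2_dataset_out". ``makeSimplePipeline`` below
# overrides the templates with integers, e.g.
#
#     pipeline.addConfigOverride("task0", "connections.in_tmpl", 0)
#     pipeline.addConfigOverride("task0", "connections.out_tmpl", 1)
#
# which makes task0 read "add_dataset0" and write "add_dataset1".

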

class AddTaskConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)



class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for
    specific unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):
        if self.taskFactory:
            # Do some bookkeeping.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)


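# Illustrative example (not part of the original module): with the default
# ``addend`` of 3, running the task on ``numpy.array([1])`` adds the addend
# once for ``output`` and twice for ``output2``:
#
#     task = AddTask(config=AddTaskConfig())
#     result = task.run(numpy.array([1]))
#     # result.output == array([4]); result.output2 == array([7])

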

class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """

    def __init__(self, stopAt=-1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def makeTask(self, taskClass, name, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        task.taskFactory = self
        return task


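# Illustrative note (not part of the original module): ``countExec`` starts
# at 0 and ``AddTask.run`` raises when ``stopAt == countExec``, so with
#
#     factory = AddTaskFactoryMock(stopAt=2)
#
# the first two run() calls succeed and the third raises RuntimeError.

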

def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass="Packages", universe=registry.dimensions
        )
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the dataset type already exists and
            # is consistent, and it raises if it is inconsistent. But
            # components must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)


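# Illustrative usage (not part of the original module), mirroring the call
# made by ``populateButler`` below:
#
#     taskDefs = list(pipeline.toExpandedPipeline())
#     registerDatasetTypes(butler.registry, taskDefs)

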

def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to add to the pipeline, or
        `None` (or an empty string) if no instrument should be added.
        Defaults to `None`.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Make a bunch of tasks that execute in well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline


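# Illustrative example (not part of the original module):
# ``makeSimplePipeline(3)`` produces task0, task1, and task2 chained through
# their dataset templates:
#
#     add_dataset0 -> task0 -> add_dataset1 -> task1 -> add_dataset2
#         -> task2 -> add_dataset3
#
# so only "add_dataset0" must exist in the butler before graph generation.

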

def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create a new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root = ResourcePath(root, forceDirectory=True)
    if not root.isLocal:
        raise ValueError(f"Only works with local root, not {root}")
    config = Config()
    if not inMemory:
        config["registry", "db"] = f"sqlite:///{root.ospath}/gen3.sqlite"
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(root, {}, config=config)
    butler = Butler(butler=repo, run=run)
    return butler


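# Illustrative usage (not part of the original module):
#
#     butler = makeSimpleButler("/tmp/test_repo", run="test", inMemory=False)
#
# With ``inMemory=False`` the repository is backed by a SQLite registry and
# a file datastore under the given root.

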

def populateButler(pipeline, butler, datasetTypes=None):
    """Populate a data butler with the data needed for a test.

    Initializes the data butler with a number of items:

    - registers the dataset types defined by the pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0" is
      added

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline and
    "INSTR" is used if the pipeline has no instrument definition. The type
    of each dataset is guessed from the dataset type name (this assumes the
    pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to the ``butler.run`` collection.
    """

    # Add dataset types to registry.
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        if isinstance(instrument, str):
            instrument = doImport(instrument)
        instrumentName = instrument.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry.
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = {taskDef.label: taskDef for taskDef in taskDefs}
    # Add inputs to butler.
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                data = Packages({"python": "9.9.99"})
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task name or make a
                    # new one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)


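# Illustrative example (not part of the original module): pre-populating an
# extra collection before graph generation, using dataset type names the
# code above recognizes by suffix:
#
#     populateButler(
#         pipeline,
#         butler,
#         datasetTypes={
#             None: ["add_dataset0"],  # goes to butler.run
#             "existing": ["task0_config", "task0_metadata", "task0_log"],
#         },
#     )

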

def makeSimpleQGraph(
    nQuanta=5,
    pipeline=None,
    butler=None,
    root=None,
    callPopulateButler=True,
    run="test",
    skipExistingIn=None,
    inMemory=True,
    userQuery="",
    datasetTypes=None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
):
    """Make a simple QuantumGraph for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory
    registry and butler, fills them with minimal data, and generates a
    QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created
        by calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a
        graph. If `False`, the ``butler`` argument must not be `None` and
        must be pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if
        ``butler`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. May be any of the types
        accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler``
        is `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty
        string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to the ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph.
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExistingIn=skipExistingIn)
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
        butler.collections,
        run or butler.run,
        userQuery,
    )
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run or butler.run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
    )

    return butler, qgraph

419 return butler, qgraph