Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 26% (136 statements)


# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Common classes and methods for use in unit tests."""

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.base import Packages
from lsst.daf.butler import Butler, ButlerURI, Config, DatasetType
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.utils import doImport

from ... import base as pipeBase
from .. import connectionTypes as cT
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..task import _TASK_FULL_METADATA_TYPE

_LOG = logging.getLogger(__name__)


# SimpleInstrument has an instrument-like API as needed for unit testing, but
# can not explicitly depend on Instrument because pipe_base does not
# explicitly depend on obs_base.
class SimpleInstrument:
    def __init__(self, *args, **kwargs):
        pass

    @staticmethod
    def getName():
        return "INSTRU"

    def applyConfigOverrides(self, name, config):
        pass


class AddTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input, two outputs, and one init
    output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):
        if self.taskFactory:
            # Report progress to the factory, and fail on request at a
            # predetermined invocation.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)

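
# A minimal sketch of exercising AddTask.run() directly, assuming the default
# AddTaskConfig addend of 3 (the construction pattern mirrors
# AddTaskFactoryMock.makeTask below):
#
#     task = AddTask(config=AddTaskConfig(), name="add_task")
#     result = task.run(numpy.array([1, 2, 3]))
#     # result.output  -> array([4, 5, 6])
#     # result.output2 -> array([7, 8, 9])
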

class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """

    def __init__(self, stopAt=-1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def makeTask(self, taskClass, name, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        task.taskFactory = self
        return task

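
# A minimal sketch of the factory's failure injection: countExec starts at 0
# and is incremented after the check in AddTask.run(), so stopAt=3 makes the
# fourth call to run() raise:
#
#     factory = AddTaskFactoryMock(stopAt=3)
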

def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of `TaskDef` instances, likely the output of the
        `toExpandedPipeline` method on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass="Packages", universe=registry.dimensions
        )
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the dataset type already exists and
            # is consistent, and it raises if it is inconsistent. Components
            # must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)

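
# A minimal sketch of registering a pipeline's dataset types, as
# populateButler() does below (``pipeline`` and ``butler`` are assumed to
# exist already):
#
#     registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())
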

def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Make a chain of tasks that execute in a well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline

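
# A minimal sketch of the dataset-name chaining produced by the template
# overrides above (names follow from AddTaskConnections):
#
#     pipeline = makeSimplePipeline(3)
#     # task0: add_dataset0 -> add_dataset1 (and add2_dataset1)
#     # task1: add_dataset1 -> add_dataset2 (and add2_dataset2)
#     # task2: add_dataset2 -> add_dataset3 (and add2_dataset3)
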

def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create a new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root = ButlerURI(root, forceDirectory=True)
    if not root.isLocal:
        raise ValueError(f"Only works with local root, not {root}")
    config = Config()
    if not inMemory:
        config["registry", "db"] = f"sqlite:///{root.ospath}/gen3.sqlite"
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(root, {}, config=config)
    butler = Butler(butler=repo, run=run)
    return butler

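
# A minimal sketch of creating a throwaway repository, assuming write access
# to a temporary directory:
#
#     import tempfile
#
#     with tempfile.TemporaryDirectory() as tmpdir:
#         butler = makeSimpleButler(tmpdir, run="test")
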

def populateButler(pipeline, butler, datasetTypes=None):
    """Populate a data butler with the data needed for a test.

    Initializes the data butler with several items:

    - registers the dataset types defined by the pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0"
      is added

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline is missing an instrument definition. The
    type of each dataset is guessed from the dataset type name (this assumes
    that the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
    """

    # Add dataset types to registry
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        if isinstance(instrument, str):
            instrument = doImport(instrument)
        instrumentName = instrument.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
    # Add inputs to butler
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent
                data = Packages({"python": "9.9.99"})
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task name, or make a
                    # new one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)

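
# A minimal sketch of populating a named collection (the collection name
# "extra_run" is illustrative; dataset contents follow the suffix rules
# above):
#
#     populateButler(
#         pipeline,
#         butler,
#         datasetTypes={"extra_run": ["add_dataset0"]},
#     )
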

def makeSimpleQGraph(
    nQuanta=5,
    pipeline=None,
    butler=None,
    root=None,
    callPopulateButler=True,
    run="test",
    skipExistingIn=None,
    inMemory=True,
    userQuery="",
    datasetTypes=None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
):
    """Make a simple QuantumGraph for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a QuantumGraph
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph; only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and the default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created by
        calling `makeSimpleButler`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a graph.
        If `False`, the ``butler`` argument must not be `None` and must be
        pre-populated. Defaults to `True`.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is `None`.
    run : `str`, optional
        Name of the RUN collection to add to the butler; only used if
        ``butler`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. May be any of the types
        accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``; by default an empty string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence; defaults to
        `DatasetQueryConstraintVariant.ALL`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExistingIn=skipExistingIn)
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
        butler.collections,
        run or butler.run,
        userQuery,
    )
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run or butler.run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
    )

    return butler, qgraph
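
# A minimal sketch of building a graph end to end, assuming write access to a
# temporary directory:
#
#     import tempfile
#
#     with tempfile.TemporaryDirectory() as tmpdir:
#         butler, qgraph = makeSimpleQGraph(nQuanta=5, root=tmpdir)
#         # qgraph contains one quantum per AddTask in the pipeline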