Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 31%

153 statements

coverage.py v6.4.1, created at 2022-06-11 02:51 -0700

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

24from __future__ import annotations 

25 

26__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"] 

27 

import itertools
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, Union

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetType, Formatter
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

    from ..configOverrides import ConfigOverrides

_LOG = logging.getLogger(__name__)


class SimpleInstrument(Instrument):
    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> Type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input, two outputs,
    and one init output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )
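
# With the default templates above, the connection dataset type names expand
# to "add_dataset_in", "add_dataset_out", "add2_dataset_out", and
# "add_init_output_out".  makeSimplePipeline (below) overrides the templates
# with integers instead, so task N reads "add_dataset{N}" and writes
# "add_dataset{N+1}", chaining the tasks through their data dependencies.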


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: Optional[AddTaskFactoryMock] = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:  # type: ignore
        if self.taskFactory:
            # Do some bookkeeping for the unit tests.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)
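
# Worked example of the arithmetic: with the default addend of 3,
# run(input=0) returns Struct(output=3, output2=6).  In the chained pipeline
# built by makeSimplePipeline each task consumes the previous task's
# ``output``, so values grow by 3 per quantum; ``output2`` and ``initout``
# are written but never consumed downstream.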


class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises an exception at this call to run()

    def makeTask(
        self,
        taskClass: Type[PipelineTask],
        name: Optional[str],
        config: Optional[PipelineTaskConfig],
        overrides: Optional[ConfigOverrides],
        butler: Optional[Butler],
    ) -> PipelineTask:
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        task.taskFactory = self  # type: ignore
        return task
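
# Sketch of the intended use in tests: construct the factory with a stop
# point and let the executor run the graph, e.g.
#
#     taskFactory = AddTaskFactoryMock(stopAt=3)
#
# The first three calls to AddTask.run() succeed (countExec 0, 1, 2) and the
# fourth raises RuntimeError, letting tests exercise error handling.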


def registerDatasetTypes(registry: Registry, pipeline: Union[Pipeline, Iterable[TaskDef]]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `Pipeline` or `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        storageClass = "Packages"
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass=storageClass, universe=registry.dimensions
        )
        datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the dataset type already exists and
            # is consistent, and raises if it is inconsistent; component
            # dataset types must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
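
# Usage sketch:
#
#     registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())
#
# registers every init-input/output, input, output, and prerequisite dataset
# type of the pipeline, plus each task's "<label>_config" dataset and the
# global "packages" dataset.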


def makeSimplePipeline(nQuanta: int, instrument: Optional[str] = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to add to the pipeline; if no
        instrument should be added, an empty string or `None` (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = Pipeline("test pipeline")
    # Make a bunch of tasks that execute in a well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
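
# For example, makeSimplePipeline(3) creates labels task0..task2 where task0
# reads "add_dataset0" and writes "add_dataset1", task1 reads "add_dataset1"
# and writes "add_dataset2", and so on: a strictly linear chain of quanta.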


def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root_path = ResourcePath(root, forceDirectory=True)
    if not root_path.isLocal:
        raise ValueError(f"Only works with a local root, not {root_path}")
    config = Config()
    if not inMemory:
        config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(str(root_path), {}, config=config)
    butler = Butler(butler=repo, run=run)
    return butler
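
# Usage sketch (the path is illustrative):
#
#     butler = makeSimpleButler("/tmp/test_repo", run="test", inMemory=True)
#
# With inMemory=False the config above explicitly selects an on-disk
# gen3.sqlite registry and a FileDatastore; otherwise the makeTestRepo
# defaults (an in-memory test repository) are used.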


def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: Optional[Dict[Optional[str], List[str]]] = None
) -> None:
    """Populate data butler with data needed for test.

    Initializes data butler with a bunch of items:

    - registers dataset types which are defined by pipeline
    - creates dimension records for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0" is
      added

    All datasets added to butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline has no instrument definition. The content
    of each dataset is guessed from the dataset type name (this assumes that
    the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str` or `None`, `list` [ `str` ] ], optional
        Dictionary whose keys are collection names (`None` means
        ``butler.run``) and values are lists of dataset type names. By
        default a single dataset of type "add_dataset0" is added to a
        ``butler.run`` collection.
    """
    # Add dataset types to registry
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = instrument_class.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
    # Add inputs to butler
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task name or make a new
                    # one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
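
# Usage sketch: the minimal call is
#
#     populateButler(pipeline, butler)
#
# which registers the dataset types and puts a single "add_dataset0" array
# into ``butler.run``.  Tests exercising skip-existing logic can instead pass
# e.g. datasetTypes={"existing_run": ["add_dataset1"]} (collection name
# illustrative) to pre-populate intermediate outputs.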


def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Optional[Pipeline] = None,
    butler: Optional[Butler] = None,
    root: Optional[str] = None,
    callPopulateButler: bool = True,
    run: str = "test",
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: Optional[Dict[Optional[str], List[str]]] = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
) -> Tuple[Butler, QuantumGraph]:
    """Make simple QuantumGraph for tests.

    Makes a simple pipeline of AddTask tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a QuantumGraph
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph; only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created by
        calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a graph.
        If `False`, the ``butler`` argument must not be `None` and must be
        pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler; only used if ``butler``
        is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. May be any of the types
        accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``; by default an empty string.
    datasetTypes : `dict` [ `str` or `None`, `list` [ `str` ] ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence; defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = GraphBuilder(
        registry=butler.registry,
        skipExistingIn=skipExistingIn,
        datastore=butler.datastore if makeDatastoreRecords else None,
    )
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
        butler.collections,
        run or butler.run,
        userQuery,
    )
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run or butler.run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
    )

    return butler, qgraph
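
# End-to-end usage sketch (path illustrative; assumes QuantumGraph supports
# len() for counting quanta):
#
#     butler, qgraph = makeSimpleQGraph(nQuanta=3, root="/tmp/test_repo")
#     assert len(qgraph) == 3
#
# This builds a three-task pipeline, creates and populates an in-memory
# butler under the given root, and returns one quantum per task.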