Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 27%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetType, Formatter
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

    from ..configOverrides import ConfigOverrides

_LOG = logging.getLogger(__name__)

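# A trivial Instrument for tests: construction and registration are no-ops,
# and only ``getName()`` is meaningful (``populateButler`` uses it to name
# the instrument dimension record).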
class SimpleInstrument(Instrument):
    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> Type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass

class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask; it has one input, two outputs, and one
    init output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )
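# The ``{in_tmpl}``/``{out_tmpl}`` placeholders in the connection names above
# are filled from ``defaultTemplates`` or from per-task config overrides (see
# ``makeSimplePipeline``), so by default the input dataset type is
# "add_dataset_in" and the outputs are "add_dataset_out" and "add2_dataset_out".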

class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)

class AddTask(PipelineTask):
    """Trivial PipelineTask for testing; it has some extras useful for
    specific unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: Optional[AddTaskFactoryMock] = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:  # type: ignore
        if self.taskFactory:
            # do some bookkeeping
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.config = cast(AddTaskConfig, self.config)
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)
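# For reference: with the default addend of 3, ``run(input=0)`` returns
# ``Struct(output=3, output2=6)`` and records ``"add": 3`` in the task
# metadata.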

class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises an exception at this call to run()

    def makeTask(
        self,
        taskClass: Type[PipelineTask],
        name: Optional[str],
        config: Optional[PipelineTaskConfig],
        overrides: Optional[ConfigOverrides],
        butler: Optional[Butler],
    ) -> PipelineTask:
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        task.taskFactory = self  # type: ignore
        return task
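# A sketch of the intended test pattern: ``AddTaskFactoryMock(stopAt=3)``
# yields AddTask instances that raise RuntimeError after three successful
# executions, i.e. on the call where ``countExec`` has reached ``stopAt``.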

def registerDatasetTypes(registry: Registry, pipeline: Union[Pipeline, Iterable[TaskDef]]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
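
    Examples
    --------
    A minimal sketch of typical use, assuming an existing data butler
    ``butler`` and a `~lsst.pipe.base.Pipeline` ``pipeline``::

        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())
    """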

    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        storageClass = "Packages"
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass=storageClass, universe=registry.dimensions
        )
        datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # This is a no-op if the dataset type already exists and is
            # consistent, and it raises if it is inconsistent. But components
            # must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)

def makeSimplePipeline(nQuanta: int, instrument: Optional[str] = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
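
    Examples
    --------
    A sketch of customizing the pipeline before graph generation; the
    repository root below is a hypothetical temporary directory::

        pipeline = makeSimplePipeline(nQuanta=3)
        pipeline.addConfigOverride("task1", "addend", 100)
        butler, qgraph = makeSimpleQGraph(pipeline=pipeline, root="/tmp/repo")
    """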

    pipeline = Pipeline("test pipeline")
    # make a bunch of tasks that execute in well-defined order (via data
    # dependencies)
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline

def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create a new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
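
    Examples
    --------
    A minimal sketch, with a hypothetical repository root::

        butler = makeSimpleButler("/tmp/repo", run="test")
    """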

    root_path = ResourcePath(root, forceDirectory=True)
    if not root_path.isLocal:
        raise ValueError(f"Only works with local root, not {root_path}")
    config = Config()
    if not inMemory:
        config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(str(root_path), {}, config=config)
    butler = Butler(butler=repo, run=run)
    return butler

def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: Dict[Optional[str], List[str]] | None = None
) -> None:
    """Populate data butler with data needed for test.

    Initializes the data butler with a number of items:
    - registers dataset types which are defined by the pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing then a single dataset with type "add_dataset0"
      is added

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline and
    "INSTR" is used if the pipeline has no instrument definition. The type of
    each dataset is guessed from the dataset type name (this assumes that the
    pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
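
    Examples
    --------
    A sketch that also pre-populates a task config dataset in a separate run
    collection (the collection name "existing" is hypothetical)::

        populateButler(
            pipeline,
            butler,
            datasetTypes={"existing": ["task0_config"], None: ["add_dataset0"]},
        )
    """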

    # Add dataset types to registry
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = instrument_class.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
    # Add inputs to butler
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # find a config from matching task name or make a new one
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)

def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Optional[Pipeline] = None,
    butler: Optional[Butler] = None,
    root: Optional[str] = None,
    callPopulateButler: bool = True,
    run: str = "test",
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: Optional[Dict[Optional[str], List[str]]] = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    resolveRefs: bool = False,
    bind: Optional[Mapping[str, Any]] = None,
) -> Tuple[Butler, QuantumGraph]:
    """Make a simple QuantumGraph for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a QuantumGraph
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, if `None` then a new data butler is created by
        calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a graph.
        If `False`, the ``butler`` argument must not be `None` and must be
        pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if ``butler``
        is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository, only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.
    resolveRefs : `bool`, optional
        If `True` then resolve all input references and generate random
        dataset IDs for all output and intermediate datasets.
    bind : `Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
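
    Examples
    --------
    Typical test usage, sketched with a hypothetical temporary directory as
    the repository root::

        butler, qgraph = makeSimpleQGraph(nQuanta=3, root="/tmp/repo")
    """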

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = GraphBuilder(
        registry=butler.registry,
        skipExistingIn=skipExistingIn,
        datastore=butler.datastore if makeDatastoreRecords else None,
    )
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r bind=%s",
        butler.collections,
        run or butler.run,
        userQuery,
        bind,
    )
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run or butler.run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
        resolveRefs=resolveRefs,
        bind=bind,
    )

    return butler, qgraph

462 return butler, qgraph