Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 27%

155 statements  

coverage.py v6.5.0, created at 2022-12-15 02:04 -0800

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetType, Formatter
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

    from ..configOverrides import ConfigOverrides

_LOG = logging.getLogger(__name__)


class SimpleInstrument(Instrument):
    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> Type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input and two outputs,
    plus one init output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for
    specific unit tests.
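
    Examples
    --------
    A minimal sketch of a direct call to ``run`` (with the default
    ``addend`` of 3, input 1 yields outputs 4 and 7)::

        task = AddTask(config=AddTaskConfig(), initInputs=None, name="add_task")
        result = task.run(input=1)
        assert result.output == 4 and result.output2 == 7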

120 """ 

121 

122 ConfigClass = AddTaskConfig 

123 _DefaultName = "add_task" 

124 

125 initout = numpy.array([999]) 

126 """InitOutputs for this task""" 

127 

128 taskFactory: Optional[AddTaskFactoryMock] = None 

129 """Factory that makes instances""" 

130 

131 def run(self, input: int) -> Struct: # type: ignore 

132 

133 if self.taskFactory: 

134 # do some bookkeeping 

135 if self.taskFactory.stopAt == self.taskFactory.countExec: 

136 raise RuntimeError("pretend something bad happened") 

137 self.taskFactory.countExec += 1 

138 

139 self.config = cast(AddTaskConfig, self.config) 

140 self.metadata.add("add", self.config.addend) 

141 output = input + self.config.addend 

142 output2 = output + self.config.addend 

143 _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2) 

144 return Struct(output=output, output2=output2) 

145 

146 


class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
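
    Examples
    --------
    A sketch of how tests can trigger a failure: ``stopAt=2`` makes the
    third call to ``AddTask.run`` raise, since ``countExec`` starts at 0::

        taskFactory = AddTaskFactoryMock(stopAt=2)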

152 """ 

153 

154 def __init__(self, stopAt: int = -1): 

155 self.countExec = 0 # incremented by AddTask 

156 self.stopAt = stopAt # AddTask raises exception at this call to run() 

157 

158 def makeTask( 

159 self, 

160 taskClass: Type[PipelineTask], 

161 name: Optional[str], 

162 config: Optional[PipelineTaskConfig], 

163 overrides: Optional[ConfigOverrides], 

164 butler: Optional[Butler], 

165 ) -> PipelineTask: 

166 if config is None: 

167 config = taskClass.ConfigClass() 

168 if overrides: 

169 overrides.applyTo(config) 

170 task = taskClass(config=config, initInputs=None, name=name) 

171 task.taskFactory = self # type: ignore 

172 return task 

173 

174 


def registerDatasetTypes(registry: Registry, pipeline: Union[Pipeline, Iterable[TaskDef]]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `~lsst.pipe.base.Pipeline` or `~typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
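
    Examples
    --------
    Typical use in a test, with a butler and pipeline made by the helpers
    below (see `makeSimpleButler` and `makeSimplePipeline`)::

        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())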

187 """ 

188 for taskDef in pipeline: 

189 configDatasetType = DatasetType( 

190 taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions 

191 ) 

192 storageClass = "Packages" 

193 packagesDatasetType = DatasetType( 

194 "packages", {}, storageClass=storageClass, universe=registry.dimensions 

195 ) 

196 datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry) 

197 for datasetType in itertools.chain( 

198 datasetTypes.initInputs, 

199 datasetTypes.initOutputs, 

200 datasetTypes.inputs, 

201 datasetTypes.outputs, 

202 datasetTypes.prerequisites, 

203 [configDatasetType, packagesDatasetType], 

204 ): 

205 _LOG.info("Registering %s with registry", datasetType) 

206 # this is a no-op if it already exists and is consistent, 

207 # and it raises if it is inconsistent. But components must be 

208 # skipped 

209 if not datasetType.isComponent(): 

210 registry.registerDatasetType(datasetType) 

211 

212 


def makeSimplePipeline(nQuanta: int, instrument: Optional[str] = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to add to the pipeline. If no
        instrument should be added, pass an empty string or `None` (the
        default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
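
    Examples
    --------
    A three-task pipeline whose tasks chain through their data
    dependencies, reading ``add_dataset0`` and writing up to
    ``add_dataset3``::

        pipeline = makeSimplePipeline(3)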

234 """ 

235 pipeline = Pipeline("test pipeline") 

236 # make a bunch of tasks that execute in well defined order (via data 

237 # dependencies) 

238 for lvl in range(nQuanta): 

239 pipeline.addTask(AddTask, f"task{lvl}") 

240 pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl) 

241 pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1) 

242 if instrument: 

243 pipeline.addInstrument(instrument) 

244 return pipeline 

245 

246 


def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If true make in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
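
    Examples
    --------
    A throwaway in-memory repository rooted in a temporary directory
    (``tmpdir`` is assumed to be supplied by the test harness)::

        butler = makeSimpleButler(tmpdir, run="test", inMemory=True)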

263 """ 

264 root_path = ResourcePath(root, forceDirectory=True) 

265 if not root_path.isLocal: 

266 raise ValueError(f"Only works with local root not {root_path}") 

267 config = Config() 

268 if not inMemory: 

269 config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite" 

270 config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore" 

271 repo = butlerTests.makeTestRepo(str(root_path), {}, config=config) 

272 butler = Butler(butler=repo, run=run) 

273 return butler 

274 

275 


def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: Dict[Optional[str], List[str]] | None = None
) -> None:
    """Populate data butler with data needed for test.

    Initializes data butler with a bunch of items:

    - registers dataset types which are defined by pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing then a single dataset with type "add_dataset0"
      is added

    All datasets added to butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline has no instrument definition. The type
    of each dataset is guessed from its dataset type name (assumes that the
    pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names (`None` means the
        ``butler.run`` collection) and values are lists of dataset type
        names. By default a single dataset of type "add_dataset0" is added
        to the ``butler.run`` collection.
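
    Examples
    --------
    A sketch adding the default input plus a config dataset in an extra
    run (the "extra" collection name is illustrative; label "task0" exists
    for pipelines made by `makeSimplePipeline`)::

        populateButler(
            pipeline, butler, datasetTypes={None: ["add_dataset0"], "extra": ["task0_config"]}
        )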

303 """ 

304 

305 # Add dataset types to registry 

306 taskDefs = list(pipeline.toExpandedPipeline()) 

307 registerDatasetTypes(butler.registry, taskDefs) 

308 

309 instrument = pipeline.getInstrument() 

310 if instrument is not None: 

311 instrument_class = doImportType(instrument) 

312 instrumentName = instrument_class.getName() 

313 else: 

314 instrumentName = "INSTR" 

315 

316 # Add all needed dimensions to registry 

317 butler.registry.insertDimensionData("instrument", dict(name=instrumentName)) 

318 butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0")) 

319 

320 taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs) 

321 # Add inputs to butler 

322 if not datasetTypes: 

323 datasetTypes = {None: ["add_dataset0"]} 

324 for run, dsTypes in datasetTypes.items(): 

325 if run is not None: 

326 butler.registry.registerRun(run) 

327 for dsType in dsTypes: 

328 if dsType == "packages": 

329 # Version is intentionally inconsistent. 

330 # Dict is convertible to Packages if Packages is installed. 

331 data: Any = {"python": "9.9.99"} 

332 butler.put(data, dsType, run=run) 

333 else: 

334 if dsType.endswith("_config"): 

335 # find a confing from matching task name or make a new one 

336 taskLabel, _, _ = dsType.rpartition("_") 

337 taskDef = taskDefMap.get(taskLabel) 

338 if taskDef is not None: 

339 data = taskDef.config 

340 else: 

341 data = AddTaskConfig() 

342 elif dsType.endswith("_metadata"): 

343 data = _TASK_FULL_METADATA_TYPE() 

344 elif dsType.endswith("_log"): 

345 data = ButlerLogRecords.from_records([]) 

346 else: 

347 data = numpy.array([0.0, 1.0, 2.0, 5.0]) 

348 butler.put(data, dsType, run=run, instrument=instrumentName, detector=0) 

349 

350 


def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Optional[Pipeline] = None,
    butler: Optional[Butler] = None,
    root: Optional[str] = None,
    callPopulateButler: bool = True,
    run: str = "test",
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: Optional[Dict[Optional[str], List[str]]] = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    resolveRefs: bool = False,
    bind: Optional[Mapping[str, Any]] = None,
) -> Tuple[Butler, QuantumGraph]:
    """Make simple QuantumGraph for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory
    registry and butler, fills them with minimal data, and generates a
    QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is None.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, if None then new data butler is created by
        calling `makeSimpleButler`.
    callPopulateButler : `bool`, optional
        If True insert datasets into the butler prior to building a graph.
        If False the ``butler`` argument must not be None and must be
        pre-populated. Defaults to True.
    root : `str`, optional
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is None.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if
        ``butler`` is None.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If true make in-memory repository, only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty
        string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`, optional
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.
    resolveRefs : `bool`, optional
        If `True` then resolve all input references and generate random
        dataset IDs for all output and intermediate datasets.
    bind : `Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
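
    Examples
    --------
    Build a five-quantum graph in a fresh repository (``tmpdir`` is
    assumed to be supplied by the test harness)::

        butler, qgraph = makeSimpleQGraph(nQuanta=5, root=tmpdir)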

424 """ 

425 

426 if pipeline is None: 

427 pipeline = makeSimplePipeline(nQuanta=nQuanta) 

428 

429 if butler is None: 

430 if root is None: 

431 raise ValueError("Must provide `root` when `butler` is None") 

432 if callPopulateButler is False: 

433 raise ValueError("populateButler can only be False when butler is supplied as an argument") 

434 butler = makeSimpleButler(root, run=run, inMemory=inMemory) 

435 

436 if callPopulateButler: 

437 populateButler(pipeline, butler, datasetTypes=datasetTypes) 

438 

439 # Make the graph 

440 _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn) 

441 builder = GraphBuilder( 

442 registry=butler.registry, 

443 skipExistingIn=skipExistingIn, 

444 datastore=butler.datastore if makeDatastoreRecords else None, 

445 ) 

446 _LOG.debug( 

447 "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r bind=%s", 

448 butler.collections, 

449 run or butler.run, 

450 userQuery, 

451 bind, 

452 ) 

453 qgraph = builder.makeGraph( 

454 pipeline, 

455 collections=butler.collections, 

456 run=run or butler.run, 

457 userQuery=userQuery, 

458 datasetQueryConstraint=datasetQueryConstraint, 

459 resolveRefs=resolveRefs, 

460 bind=bind, 

461 ) 

462 

463 return butler, qgraph