# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from collections.abc import Iterable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

_LOG = logging.getLogger(__name__)


class SimpleInstrument(Instrument):
    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> Type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input, two outputs, and one init
    output.
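
    With the default templates the connection dataset type names resolve
    to ``add_dataset_in``, ``add_dataset_out``, ``add2_dataset_out``, and
    ``add_init_output_out``.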

    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Second output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for
    specific unit tests.
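
    Examples
    --------
    A minimal sketch of running the task by hand, constructing it the same
    way `AddTaskFactoryMock.makeTask` does (for illustration only)::

        task = AddTask(config=AddTaskConfig(), initInputs=None, name="add_task")
        result = task.run(input=numpy.array([1]))
        # result.output == input + addend; result.output2 == input + 2 * addend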

    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: Optional[AddTaskFactoryMock] = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:  # type: ignore
        if self.taskFactory:
            # Update the factory's bookkeeping counter; optionally simulate
            # a failure for tests that exercise error handling.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.config = cast(AddTaskConfig, self.config)
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)


class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
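
    Examples
    --------
    A sketch of the failure bookkeeping (the value is illustrative)::

        factory = AddTaskFactoryMock(stopAt=3)
        # countExec is compared to stopAt before being incremented, so the
        # fourth call to AddTask.run() raises RuntimeError.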

    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def makeTask(
        self, taskDef: TaskDef, butler: LimitedButler, initInputRefs: Iterable[DatasetRef] | None
    ) -> PipelineTask:
        taskClass = taskDef.taskClass
        assert taskClass is not None
        task = taskClass(config=taskDef.config, initInputs=None, name=taskDef.label)
        task.taskFactory = self  # type: ignore
        return task


def registerDatasetTypes(registry: Registry, pipeline: Union[Pipeline, Iterable[TaskDef]]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `Pipeline` or `~collections.abc.Iterable` of `TaskDef`
        A pipeline, or an iterable of TaskDef instances such as the output
        of `~lsst.pipe.base.Pipeline.toExpandedPipeline`.
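
    Examples
    --------
    A minimal sketch using the helpers in this module (the repository root
    path is illustrative)::

        butler = makeSimpleButler("/tmp/repo", run="test")
        pipeline = makeSimplePipeline(nQuanta=2)
        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())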

    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        storageClass = "Packages"
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass=storageClass, universe=registry.dimensions
        )
        datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the dataset type already exists and
            # is consistent, and it raises if it is inconsistent; components
            # must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)


def makeSimplePipeline(nQuanta: int, instrument: Optional[str] = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to
    it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
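
    Examples
    --------
    A minimal sketch (the repository root passed to ``makeSimpleQGraph``
    is illustrative)::

        pipeline = makeSimplePipeline(nQuanta=3)
        butler, qgraph = makeSimpleQGraph(pipeline=pipeline, root="/tmp/repo")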

    """
    pipeline = Pipeline("test pipeline")
    # Make a bunch of tasks that execute in well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline


def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create a new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.
    config : `~lsst.daf.butler.Config`, optional
        Configuration to use for the new Butler; if `None` then the
        default configuration is used. If ``inMemory`` is `False` then the
        configuration is updated to use a SQLite registry and a file
        datastore in ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
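
    Examples
    --------
    A minimal sketch (the root path is illustrative)::

        butler = makeSimpleButler("/tmp/repo", run="test", inMemory=True)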

    """
    root_path = ResourcePath(root, forceDirectory=True)
    if not root_path.isLocal:
        raise ValueError(f"Only works with local root not {root_path}")
    butler_config = Config()
    if config:
        butler_config.update(Config(config))
    if not inMemory:
        butler_config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
        butler_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(str(root_path), {}, config=butler_config)
    butler = Butler(butler=repo, run=run)
    return butler


def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: Dict[Optional[str], List[str]] | None = None
) -> None:
    """Populate the data butler with data needed for tests.

    Initializes the data butler with a number of items:

    - registers dataset types which are defined by the pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing then a single dataset with type "add_dataset0"
      is added

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline has no instrument definition. The type
    of each dataset is guessed from the dataset type name (this assumes
    the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str` or `None`, `list` [ `str` ] ], optional
        Dictionary whose keys are collection names (a `None` key stands
        for ``butler.run``) and values are lists of dataset type names. By
        default a single dataset of type "add_dataset0" is added to the
        ``butler.run`` collection.
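
    Examples
    --------
    A sketch that puts inputs into the default run and one extra
    collection (the collection name is illustrative)::

        populateButler(
            pipeline,
            butler,
            datasetTypes={None: ["add_dataset0"], "extra_run": ["add_dataset1"]},
        )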

    """

    # Add dataset types to registry.
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = instrument_class.getName()
        instrumentClass = get_full_type_name(instrument_class)
    else:
        instrumentName = "INSTR"
        instrumentClass = None

    # Add all needed dimensions to registry.
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = {taskDef.label: taskDef for taskDef in taskDefs}
    # Add inputs to butler.
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task name or make a new one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)


def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Optional[Pipeline] = None,
    butler: Optional[Butler] = None,
    root: Optional[str] = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: Optional[str] = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: Optional[Dict[Optional[str], List[str]]] = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Optional[Mapping[str, Any]] = None,
    metadata: Optional[MutableMapping[str, Any]] = None,
) -> Tuple[Butler, QuantumGraph]:
    """Make a simple QuantumGraph for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory
    registry and butler, fills them with minimal data, and generates a
    QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created
        by calling `makeSimpleButler`.
    callPopulateButler : `bool`, optional
        If `True` (default), insert datasets into the butler prior to
        building a graph. If `False`, the ``butler`` argument must not be
        `None` and must be pre-populated.
    root : `str`
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is `None`.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if
        ``butler`` is `None`.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default). Only used if ``pipeline`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler``
        is `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty
        string.
    datasetTypes : `dict` [ `str` or `None`, `list` [ `str` ] ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to the ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence; defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.
    bind : `Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.
    metadata : `Mapping`, optional
        Optional graph metadata.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
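
    Examples
    --------
    A minimal sketch (the repository root is illustrative)::

        butler, qgraph = makeSimpleQGraph(nQuanta=3, root="/tmp/repo")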

    """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph.
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = GraphBuilder(
        registry=butler.registry,
        skipExistingIn=skipExistingIn,
        datastore=butler.datastore if makeDatastoreRecords else None,
    )
    if not run:
        assert butler.run is not None, "Butler must have run defined"
        run = butler.run
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r bind=%s",
        butler.collections,
        run,
        userQuery,
        bind,
    )
    if not metadata:
        metadata = {}
    metadata["output_run"] = run

    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
        bind=bind,
        metadata=metadata,
    )

    return butler, qgraph