# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""

from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, Union

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy

try:
    from lsst.base import Packages
except ImportError:
    Packages = None
from lsst.daf.butler import Butler, Config, DatasetType
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType

from .. import connectionTypes as cT
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

    from ..configOverrides import ConfigOverrides

_LOG = logging.getLogger(__name__)


# SimpleInstrument has an instrument-like API as needed for unit testing, but
# can not explicitly depend on Instrument because pipe_base does not explicitly
# depend on obs_base.
class SimpleInstrument:
    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def applyConfigOverrides(self, name: str, config: pexConfig.Config) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input, two outputs,
    plus one init output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing; it has some extras useful for
    specific unit tests.
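
    Examples
    --------
    A minimal sketch of the arithmetic this task performs, assuming the
    default ``addend`` of 3::

        task = AddTask(config=AddTaskConfig())
        result = task.run(3)  # result.output == 6, result.output2 == 9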
123 """
125 ConfigClass = AddTaskConfig
126 _DefaultName = "add_task"
128 initout = numpy.array([999])
129 """InitOutputs for this task"""
131 taskFactory: Optional[AddTaskFactoryMock] = None
132 """Factory that makes instances"""

    def run(self, input: int) -> Struct:  # type: ignore
        if self.taskFactory:
            # do some bookkeeping
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)


class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
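
    Examples
    --------
    A minimal usage sketch; ``stopAt=3`` is a hypothetical value that makes
    `AddTask.run` raise once ``countExec`` reaches 3::

        taskFactory = AddTaskFactoryMock(stopAt=3)
        task = taskFactory.makeTask(AddTask, "add_task", None, None, None)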
154 """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def makeTask(
        self,
        taskClass: Type[PipelineTask],
        name: Optional[str],
        config: Optional[PipelineTaskConfig],
        overrides: Optional[ConfigOverrides],
        butler: Optional[Butler],
    ) -> PipelineTask:
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        task.taskFactory = self  # type: ignore
        return task


def registerDatasetTypes(registry: Registry, pipeline: Union[Pipeline, Iterable[TaskDef]]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of `TaskDef` instances, likely the output of the
        `~lsst.pipe.base.Pipeline.toExpandedPipeline` method on a
        `~lsst.pipe.base.Pipeline` object.
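
    Examples
    --------
    A minimal usage sketch, assuming an existing ``butler`` and
    ``pipeline`` (both hypothetical here)::

        taskDefs = list(pipeline.toExpandedPipeline())
        registerDatasetTypes(butler.registry, taskDefs)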
189 """
190 for taskDef in pipeline:
191 configDatasetType = DatasetType(
192 taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
193 )
194 storageClass = "Packages" if Packages is not None else "StructuredDataDict"
195 packagesDatasetType = DatasetType(
196 "packages", {}, storageClass=storageClass, universe=registry.dimensions
197 )
198 datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
199 for datasetType in itertools.chain(
200 datasetTypes.initInputs,
201 datasetTypes.initOutputs,
202 datasetTypes.inputs,
203 datasetTypes.outputs,
204 datasetTypes.prerequisites,
205 [configDatasetType, packagesDatasetType],
206 ):
207 _LOG.info("Registering %s with registry", datasetType)
208 # this is a no-op if it already exists and is consistent,
209 # and it raises if it is inconsistent. But components must be
210 # skipped
211 if not datasetType.isComponent():
212 registry.registerDatasetType(datasetType)


def makeSimplePipeline(nQuanta: int, instrument: Optional[str] = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline,
        or an empty string or `None` (the default) if no instrument should
        be added.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
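
    Examples
    --------
    A minimal usage sketch; three AddTask instances chained through their
    dataset name templates::

        pipeline = makeSimplePipeline(3)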
236 """
237 pipeline = Pipeline("test pipeline")
238 # make a bunch of tasks that execute in well defined order (via data
239 # dependencies)
240 for lvl in range(nQuanta):
241 pipeline.addTask(AddTask, f"task{lvl}")
242 pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
243 pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
244 if instrument:
245 pipeline.addInstrument(instrument)
246 return pipeline


def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create a new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
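
    Examples
    --------
    A minimal usage sketch; the repository path is hypothetical::

        butler = makeSimpleButler("/tmp/test_repo", run="test")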
265 """
266 root_path = ResourcePath(root, forceDirectory=True)
267 if not root_path.isLocal:
268 raise ValueError(f"Only works with local root not {root_path}")
269 config = Config()
270 if not inMemory:
271 config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
272 config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
273 repo = butlerTests.makeTestRepo(str(root_path), {}, config=config)
274 butler = Butler(butler=repo, run=run)
275 return butler


def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: Optional[Dict[Optional[str], List[str]]] = None
) -> None:
    """Populate data butler with data needed for test.

    Initializes data butler with a number of items:
    - registers dataset types defined by the pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0"
      is added

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline
    ("INSTR" is used if the pipeline has no instrument definition). The
    Python type of each dataset is guessed from the dataset type name
    (this assumes that the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
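
    Examples
    --------
    A minimal usage sketch; the repository path is hypothetical::

        pipeline = makeSimplePipeline(3)
        butler = makeSimpleButler("/tmp/test_repo")
        populateButler(pipeline, butler)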
305 """
307 # Add dataset types to registry
308 taskDefs = list(pipeline.toExpandedPipeline())
309 registerDatasetTypes(butler.registry, taskDefs)
311 instrument = pipeline.getInstrument()
312 if instrument is not None:
313 instrument_class = doImportType(instrument)
314 instrumentName = instrument_class.getName()
315 else:
316 instrumentName = "INSTR"
318 # Add all needed dimensions to registry
319 butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
320 butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))
322 taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
323 # Add inputs to butler
324 if not datasetTypes:
325 datasetTypes = {None: ["add_dataset0"]}
326 for run, dsTypes in datasetTypes.items():
327 if run is not None:
328 butler.registry.registerRun(run)
329 for dsType in dsTypes:
330 if dsType == "packages":
331 # Version is intentionally inconsistent.
332 # Dict is convertible to Packages if Packages is installed.
333 data: Any = {"python": "9.9.99"}
334 butler.put(data, dsType, run=run)
335 else:
336 if dsType.endswith("_config"):
337 # find a confing from matching task name or make a new one
338 taskLabel, _, _ = dsType.rpartition("_")
339 taskDef = taskDefMap.get(taskLabel)
340 if taskDef is not None:
341 data = taskDef.config
342 else:
343 data = AddTaskConfig()
344 elif dsType.endswith("_metadata"):
345 data = _TASK_FULL_METADATA_TYPE()
346 elif dsType.endswith("_log"):
347 data = ButlerLogRecords.from_records([])
348 else:
349 data = numpy.array([0.0, 1.0, 2.0, 5.0])
350 butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)


def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Optional[Pipeline] = None,
    butler: Optional[Butler] = None,
    root: Optional[str] = None,
    callPopulateButler: bool = True,
    run: str = "test",
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: Optional[Dict[Optional[str], List[str]]] = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
) -> Tuple[Butler, QuantumGraph]:
    """Make simple QuantumGraph for tests.

    Makes a simple pipeline of ``nQuanta`` `AddTask` tasks, sets up an
    in-memory registry and butler, fills them with minimal data, and
    generates a QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, if `None` then a new data butler is created
        by calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a
        graph. If `False`, the ``butler`` argument must not be `None` and
        must be pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if
        ``butler`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. May be any of the types
        accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository, only used if ``butler``
        is `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty
        string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
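
    Examples
    --------
    A minimal usage sketch; the repository root is hypothetical::

        butler, qgraph = makeSimpleQGraph(nQuanta=3, root="/tmp/test_repo")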
415 """
417 if pipeline is None:
418 pipeline = makeSimplePipeline(nQuanta=nQuanta)
420 if butler is None:
421 if root is None:
422 raise ValueError("Must provide `root` when `butler` is None")
423 if callPopulateButler is False:
424 raise ValueError("populateButler can only be False when butler is supplied as an argument")
425 butler = makeSimpleButler(root, run=run, inMemory=inMemory)
427 if callPopulateButler:
428 populateButler(pipeline, butler, datasetTypes=datasetTypes)
430 # Make the graph
431 _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
432 builder = GraphBuilder(registry=butler.registry, skipExistingIn=skipExistingIn)
433 _LOG.debug(
434 "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
435 butler.collections,
436 run or butler.run,
437 userQuery,
438 )
439 qgraph = builder.makeGraph(
440 pipeline,
441 collections=butler.collections,
442 run=run or butler.run,
443 userQuery=userQuery,
444 datasetQueryConstraint=datasetQueryConstraint,
445 )
447 return butler, qgraph