Coverage for python/lsst/pipe/base/tests/simpleQGraph.py : 27%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""
25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]
27import itertools
28import logging
29import numpy
31from lsst.daf.butler import Butler, Config, DatasetType
32import lsst.daf.butler.tests as butlerTests
33import lsst.pex.config as pexConfig
34from lsst.utils import doImport
35from ... import base as pipeBase
36from .. import connectionTypes as cT
38_LOG = logging.getLogger(__name__)
41# SimpleInstrument has an instrument-like API as needed for unit testing, but
42# can not explicitly depend on Instrument because pipe_base does not explicitly
43# depend on obs_base.
class SimpleInstrument:
    """Minimal instrument-like stand-in for unit tests.

    Provides just the two entry points the tests exercise, without
    introducing a dependency on obs_base's real ``Instrument`` class.
    """

    @staticmethod
    def getName():
        """Return the fixed instrument label used by the tests."""
        return "INSTRU"

    def applyConfigOverrides(self, name, config):
        """Deliberate no-op: this fake instrument has no overrides."""
        pass
class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connections for AddTask, has one input and two outputs,
    plus one init output.

    Dataset type names are parameterized by the ``in_tmpl``/``out_tmpl``
    templates so tasks can be chained (one task's output name becomes the
    next task's input name).
    """
    # Per-quantum input; name expands via the in_tmpl template.
    input = cT.Input(name="add_dataset{in_tmpl}",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    # Primary per-quantum output; name expands via the out_tmpl template.
    output = cT.Output(name="add_dataset{out_tmpl}",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    # Secondary per-quantum output, written alongside ``output``.
    output2 = cT.Output(name="add2_dataset{out_tmpl}",
                        dimensions=["instrument", "detector"],
                        storageClass="NumpyArray",
                        doc="Output dataset type for this task")
    # Dimensionless dataset written once at task initialization.
    initout = cT.InitOutput(name="add_init_output{out_tmpl}",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")
class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Config for AddTask.

    The single ``addend`` field controls how much `AddTask.run` adds to
    its input array.
    """
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)
class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, has some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):
        """Add ``config.addend`` to ``input`` once and twice.

        Parameters
        ----------
        input : `numpy.ndarray`
            Input array (anything supporting ``+ int`` works).

        Returns
        -------
        result : `~lsst.pipe.base.Struct`
            Struct with ``output`` (input + addend) and ``output2``
            (input + 2 * addend).

        Raises
        ------
        RuntimeError
            Raised deliberately when the attached factory's ``stopAt``
            equals its current execution count, to let tests simulate a
            mid-graph failure.
        """
        if self.taskFactory:
            # do some bookkeeping
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)
class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.

    Parameters
    ----------
    stopAt : `int`, optional
        Execution count at which `AddTask.run` raises; the default of -1
        never matches, so execution is never stopped.
    """
    def __init__(self, stopAt=-1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def loadTaskClass(self, taskName):
        # Only "AddTask" is recognized; any other name implicitly
        # returns None.
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, name, config, overrides, butler):
        # Build a default config when none is given, then apply overrides
        # before constructing the task.
        if config is None:
            config = taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        # Back-reference so the task can update this factory's bookkeeping.
        task.taskFactory = self
        return task
def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        toExpandedPipeline on a `~lsst.pipe.base.Pipeline` object.
    """
    # The "packages" dataset type does not depend on the task, so build it
    # once instead of once per TaskDef.
    packagesDatasetType = DatasetType("packages", {},
                                      storageClass="Packages",
                                      universe=registry.dimensions)
    for taskDef in pipeline:
        # Per-task config dataset; its name is derived from the task label.
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the type already exists and is
            # consistent, and raises if it is inconsistent.  Component
            # dataset types cannot be registered and must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` function by calling this first and passing the result
    to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline or
        if no instrument should be added then an empty string or `None`, by
        default None

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Chain the tasks through the dataset-name templates: task N writes the
    # dataset that task N+1 reads, which fixes the execution order via data
    # dependencies.
    for index in range(nQuanta):
        label = f"task{index}"
        pipeline.addTask(AddTask, label)
        pipeline.addConfigOverride(label, "connections.in_tmpl", index)
        pipeline.addConfigOverride(label, "connections.out_tmpl", index + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True,
                     userQuery=""):
    """Make simple QuantumGraph for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory
    registry and butler, fills them with minimal data, and generates
    QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is None.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist; `False` by default.
    inMemory : `bool`, optional
        If true make in-memory repository.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance
    """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:

        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")

        config = Config()
        if not inMemory:
            # Persist registry and datastore on disk instead of the default
            # in-memory test-repo configuration.
            config["registry", "db"] = f"sqlite:///{root}/gen3.sqlite"
            config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
        repo = butlerTests.makeTestRepo(root, {}, config=config)
        collection = "test"
        butler = Butler(butler=repo, run=collection)

    # Add dataset types to registry
    registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())

    instrument = pipeline.getInstrument()
    if instrument is not None:
        # The pipeline may carry the instrument as an importable dotted name.
        if isinstance(instrument, str):
            instrument = doImport(instrument)
        instrumentName = instrument.getName()
    else:
        # Placeholder name used when the pipeline declares no instrument.
        instrumentName = "INSTR"

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0,
                                                         full_name="det0"))

    # Add inputs to butler: the dataset read by the first task in the chain.
    data = numpy.array([0., 1., 2., 5.])
    butler.put(data, "add_dataset0", instrument=instrumentName, detector=0)

    # Make the graph
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=[butler.run],
        run=butler.run,
        userQuery=userQuery
    )

    return butler, qgraph