Coverage for python/lsst/pipe/base/tests/simpleQGraph.py : 28%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""
25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]
27import itertools
28import logging
29import numpy
31from lsst.daf.butler import (Butler, Config, DatasetType, CollectionSearch)
32import lsst.daf.butler.tests as butlerTests
33import lsst.pex.config as pexConfig
34from ... import base as pipeBase
35from .. import connectionTypes as cT
37_LOG = logging.getLogger(__name__)
40# SimpleInstrument has an instrument-like API as needed for unit testing, but
41# can not explicitly depend on Instrument because pipe_base does not explicitly
42# depend on obs_base.
class SimpleInstrument:
    """Minimal stand-in for an Instrument.

    Provides just enough of an instrument-like API for unit testing;
    pipe_base cannot depend on the real Instrument class because it does
    not explicitly depend on obs_base.
    """

    @staticmethod
    def getName():
        # Fixed name reported for this fake instrument.
        return "SimpleInstrument"

    def applyConfigOverrides(self, name, config):
        # Instrument API hook; tests need no config overrides.
        pass
class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connections for AddTask, has one input and two outputs,
    plus one init output.

    Dataset type names are parameterized by the ``in_tmpl``/``out_tmpl``
    templates so that chained AddTask instances can be connected by
    overriding the templates (see ``makeSimplePipeline``).
    """
    # Per-detector input array consumed by run().
    input = cT.Input(name="add_dataset{in_tmpl}",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    # Primary output: input + addend.
    output = cT.Output(name="add_dataset{out_tmpl}",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    # Secondary output: input + 2 * addend.
    output2 = cT.Output(name="add2_dataset{out_tmpl}",
                        dimensions=["instrument", "detector"],
                        storageClass="NumpyArray",
                        doc="Output dataset type for this task")
    # Init output written once at task construction, not per-quantum;
    # note it has no dimensions.
    initout = cT.InitOutput(name="add_init_output{out_tmpl}",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")
class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Config for AddTask.
    """
    # Constant added to the input array by AddTask.run().
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)
class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, has some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):
        factory = self.taskFactory
        if factory:
            # Bookkeeping for unit tests: fail deliberately when the
            # configured execution count is reached, otherwise count
            # this execution.
            if factory.stopAt == factory.countExec:
                raise RuntimeError("pretend something bad happened")
            factory.countExec += 1

        self.metadata.add("add", self.config.addend)

        # Two outputs: input shifted by addend once and twice.
        addend = self.config.addend
        output = input + addend
        output2 = output + addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)
class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """
    def __init__(self, stopAt=-1):
        # Incremented by AddTask.run() on each successful execution.
        self.countExec = 0
        # AddTask raises when countExec reaches this value (-1: never).
        self.stopAt = stopAt

    def loadTaskClass(self, taskName):
        # Only AddTask is known to this mock; anything else yields None.
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        # Fall back to the task's default config when none is supplied,
        # then apply any overrides before construction.
        cfg = config if config is not None else taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(cfg)
        instance = taskClass(config=cfg, initInputs=None)
        # Give the task a back-reference so it can report progress here.
        instance.taskFactory = self
        return instance
def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        toExpandedPipeline on a `~lsst.pipe.base.Pipeline` object
    """
    # The "packages" dataset type is identical for every task, so build it
    # once outside the loop instead of once per TaskDef.
    packagesDatasetType = DatasetType("packages", {},
                                      storageClass="Packages",
                                      universe=registry.dimensions)
    for taskDef in pipeline:
        # Per-task config dataset type (name depends on the task label).
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # registerDatasetType is a no-op if the type already exists and
            # is consistent, and raises if it is inconsistent.  Component
            # dataset types cannot be registered directly and must be
            # skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` function by calling this first and passing the result
    to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline or
        if no instrument should be added then an empty string or `None`, by
        default None

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Chain the tasks through their connection name templates so each
    # task's output dataset is the next task's input, forcing a
    # well-defined execution order via data dependencies.
    for level in range(nQuanta):
        label = f"task{level}"
        pipeline.addTask(AddTask, label)
        pipeline.addConfigOverride(label, "connections.in_tmpl", f"{level}")
        pipeline.addConfigOverride(label, "connections.out_tmpl", f"{level + 1}")
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True,
                     userQuery=""):
    """Make simple QuantumGraph for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory
    registry and butler, fills them with minimal data, and generates
    QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is None.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist; `False` by default.
    inMemory : `bool`, optional
        If true make in-memory repository.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance
    """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:

        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")

        config = Config()
        if not inMemory:
            # Back the repo with an on-disk SQLite registry and a POSIX
            # datastore instead of the default in-memory ones.
            config["registry", "db"] = f"sqlite:///{root}/gen3.sqlite"
            config["datastore", "cls"] = "lsst.daf.butler.datastores.posixDatastore.PosixDatastore"
        repo = butlerTests.makeTestRepo(root, {}, config=config)
        collection = "test"
        butler = Butler(butler=repo, run=collection)

        # Add dataset types to registry
        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())

        # Add all needed dimensions to registry
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", dict(instrument="INSTR", id=0, full_name="det0"))

        # Add inputs to butler; "add_dataset0" matches the first task's
        # input connection name produced by makeSimplePipeline.
        data = numpy.array([0., 1., 2., 5.])
        butler.put(data, "add_dataset0", instrument="INSTR", detector=0)

    # Make the graph, searching (and writing to) the butler's run collection.
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=CollectionSearch.fromExpression(butler.run),
        run=butler.run,
        userQuery=userQuery
    )

    return butler, qgraph