Coverage for tests/testUtil.py : 28%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""
25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]
27import itertools
28import logging
29import numpy
31from lsst.daf.butler import (Butler, Config, DatasetType, CollectionSearch)
32import lsst.daf.butler.tests as butlerTests
33import lsst.pex.config as pexConfig
34import lsst.pipe.base as pipeBase
35from lsst.pipe.base import connectionTypes as cT
37_LOG = logging.getLogger(__name__)
class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connections for AddTask, has one input and two outputs,
    plus one init output.

    Dataset type names are completed from the ``in_tmpl``/``out_tmpl``
    templates, so tests can chain several AddTask instances by overriding
    the templates (see `makeSimpleQGraph`).
    """
    # Per-detector input array.
    input = cT.Input(name="add_dataset{in_tmpl}",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    # Primary output: input plus the configured addend.
    output = cT.Output(name="add_dataset{out_tmpl}",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    # Secondary output: input plus twice the configured addend.
    output2 = cT.Output(name="add2_dataset{out_tmpl}",
                        dimensions=["instrument", "detector"],
                        storageClass="NumpyArray",
                        doc="Output dataset type for this task")
    # Init-output with no dimensions; written once per task instance.
    initout = cT.InitOutput(name="add_init_output{out_tmpl}",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")
class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Configuration for AddTask.

    The single ``addend`` field controls how much `AddTask.run` adds to
    its input array.
    """
    addend = pexConfig.Field(
        dtype=int,
        default=3,
        doc="amount to add",
    )
class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, has some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    # InitOutputs for this task
    initout = numpy.array([999])

    # Factory that makes instances; when set, run() reports progress to it.
    taskFactory = None

    def run(self, input):
        """Add the configured addend to ``input`` once and twice.

        Returns a Struct with ``output`` (input + addend) and ``output2``
        (input + 2 * addend).
        """
        factory = self.taskFactory
        if factory:
            # do some bookkeeping: optionally simulate a failure at a
            # preset execution count, otherwise count this call.
            if factory.stopAt == factory.countExec:
                raise RuntimeError("pretend something bad happened")
            factory.countExec += 1

        self.metadata.add("add", self.config.addend)
        addend = self.config.addend
        output = input + addend
        output2 = output + addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)
class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """

    def __init__(self, stopAt=-1):
        # Incremented by AddTask on each successful run() call.
        self.countExec = 0
        # AddTask raises an exception at this call to run().
        self.stopAt = stopAt

    def loadTaskClass(self, taskName):
        # Only "AddTask" is known to this mock; anything else yields None.
        if taskName != "AddTask":
            return None
        return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        # Fall back to the task's default config, then layer any overrides.
        cfg = taskClass.ConfigClass() if config is None else config
        if overrides:
            overrides.applyTo(cfg)
        task = taskClass(config=cfg, initInputs=None)
        # Link the task back to this factory for bookkeeping in run().
        task.taskFactory = self
        return task
def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `~lsst.pipe.base.Pipeline`
        Iterable of TaskDef instances.
    """
    # The "packages" dataset type is identical for every task, so build it
    # once instead of once per loop iteration.
    packagesDatasetType = DatasetType("packages", {},
                                      storageClass="Packages",
                                      universe=registry.dimensions)
    for taskDef in pipeline:
        # Per-task config dataset type (name depends on the task label).
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # this is a no-op if it already exists and is consistent,
            # and it raises if it is inconsistent. But components must be
            # skipped
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True):
    """Make simple QuantumGraph for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory
    registry and butler, fills them with minimal data, and generates
    QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is None.
    skipExisting : `bool`, optional
        If `True` (`False` by default), a Quantum is not created if all its
        outputs already exist.
    inMemory : `bool`, optional
        If true make in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance
    """
    if pipeline is None:
        pipeline = pipeBase.Pipeline("test pipeline")
        # make a bunch of tasks that execute in well defined order (via data
        # dependencies)
        for lvl in range(nQuanta):
            pipeline.addTask(AddTask, f"task{lvl}")
            pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", f"{lvl}")
            pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", f"{lvl+1}")

    pipeline = list(pipeline.toExpandedPipeline())

    if butler is None:

        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")

        config = Config()
        if not inMemory:
            # Persistent SQLite registry + POSIX datastore under ``root``.
            config["registry", "db"] = f"sqlite:///{root}/gen3.sqlite"
            config["datastore", "cls"] = "lsst.daf.butler.datastores.posixDatastore.PosixDatastore"
        repo = butlerTests.makeTestRepo(root, {}, config=config)
        collection = "test"
        butler = Butler(butler=repo, run=collection)

        # Add dataset types to registry
        registerDatasetTypes(butler.registry, pipeline)

        # Add all needed dimensions to registry
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", dict(instrument="INSTR", id=0, full_name="det0"))

        # Add inputs to butler
        data = numpy.array([0., 1., 2., 5.])
        butler.put(data, "add_dataset0", instrument="INSTR", detector=0)

    # Make the graph
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=CollectionSearch.fromExpression(butler.run),
        run=butler.run,
        userQuery=""
    )

    return butler, qgraph