Coverage for tests/testUtil.py: 23%

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock", "ButlerMock"]

import contextlib
import itertools
import logging
import os
from types import SimpleNamespace

import numpy

from lsst.daf.butler import (ButlerConfig, CollectionSearch, DatasetRef,
                             DatasetType, DimensionUniverse, Registry)
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.pipe.base import connectionTypes as cT

_LOG = logging.getLogger(__name__)


class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector")):
    """Connections for AddTask."""

    input = cT.Input(name="add_input",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    output = cT.Output(name="add_output",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    initout = cT.InitOutput(name="add_init_output",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")


class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)
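
# A minimal sketch of how a test can tweak AddTaskConfig; the field name comes
# from the class above, the value 5 is just an example:
#
#     config = AddTaskConfig()
#     config.addend = 5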


# Example task which overrides the run() method.
class AddTask(pipeBase.PipelineTask):
    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    countExec = 0
    """Number of times the run() method has been called for this class"""

    stopAt = -1
    """Call count at which run() raises an exception (-1 means never)"""

    def run(self, input):
        if AddTask.stopAt == AddTask.countExec:
            raise RuntimeError("pretend something bad happened")
        AddTask.countExec += 1
        self.metadata.add("add", self.config.addend)
        output = [val + self.config.addend for val in input]
        _LOG.info("input = %s, output = %s", input, output)
        return pipeBase.Struct(output=output)
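
# A minimal sketch of exercising AddTask directly in a test; the input list is
# illustrative, and with the default addend=3 each element is incremented by 3:
#
#     task = AddTask(config=AddTaskConfig(), initInputs=None)
#     result = task.run([1, 2, 3])          # result.output == [4, 5, 6]
#     AddTask.stopAt = AddTask.countExec    # make the next run() call raise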


class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Task factory mock which only knows how to make AddTask."""

    def loadTaskClass(self, taskName):
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(config)
        return taskClass(config=config, initInputs=None)
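
# A minimal sketch of how an executor would drive the factory mock; passing
# None for overrides and butler is fine since neither is used here:
#
#     factory = AddTaskFactoryMock()
#     taskClass, taskName = factory.loadTaskClass("AddTask")
#     task = factory.makeTask(taskClass, config=None, overrides=None, butler=None)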


class ButlerMock:
    """Mock version of butler, usable only for testing.

    Parameters
    ----------
    fullRegistry : `bool`, optional
        If `True` then instantiate a SQLite registry with the default
        configuration. If `False` then the registry is just a namespace with
        a `dimensions` attribute containing the `DimensionUniverse` from the
        default configuration.
    collection : `str`, optional
        Name of the run collection used to store datasets.
    """

    def __init__(self, fullRegistry=False, collection="TestColl"):
        self.datasets = {}
        self.fullRegistry = fullRegistry
        self.run = collection
        if self.fullRegistry:
            testDir = os.path.dirname(__file__)
            configFile = os.path.join(testDir, "config/butler.yaml")
            butlerConfig = ButlerConfig(configFile)
            self.registry = Registry.fromConfig(butlerConfig, create=True)
            self.registry.registerRun(collection)
        else:
            self.registry = SimpleNamespace(dimensions=DimensionUniverse.fromConfig())

    def _standardizeArgs(self, datasetRefOrType, dataId=None, **kwds):
        """Copied from the real Butler.
        """
        if isinstance(datasetRefOrType, DatasetRef):
            if dataId is not None or kwds:
                raise ValueError("DatasetRef given, cannot use dataId as well")
            datasetType = datasetRefOrType.datasetType
            dataId = datasetRefOrType.dataId
        else:
            # Don't check whether DataId is provided, because Registry APIs
            # can usually construct a better error message when it wasn't.
            if isinstance(datasetRefOrType, DatasetType):
                datasetType = datasetRefOrType
            else:
                datasetType = self.registry.getDatasetType(datasetRefOrType)
        return datasetType, dataId

    @staticmethod
    def key(dataId):
        """Make a dict key out of dataId.
        """
        return frozenset(dataId.items())

    @contextlib.contextmanager
    def transaction(self):
        """No-op transaction context manager, mimicking Butler.transaction."""
        yield

    def get(self, datasetRefOrType, dataId=None, parameters=None, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.get: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.get(dsTypeName, {})
        if key in dsdata:
            return dsdata[key]
        raise LookupError(f"No dataset {dsTypeName} with dataId {dataId}")

    def put(self, obj, datasetRefOrType, dataId=None, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.put: datasetType=%s dataId=%s obj=%r", datasetType.name, dataId, obj)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets.setdefault(dsTypeName, {})
        dsdata[key] = obj
        if self.fullRegistry:
            # insertDatasets returns a list; exactly one dataset was inserted
            ref, = self.registry.insertDatasets(datasetType, dataIds=[dataId], run=self.run, **kwds)
        else:
            # Return a DatasetRef with a reasonable ID; IDs are supposed to be
            # unique, so use the total count of stored datasets.
            refId = sum(len(val) for val in self.datasets.values())
            ref = DatasetRef(datasetType, dataId, id=refId)
        return ref

    def remove(self, datasetRefOrType, dataId=None, *, delete=True, remember=True, **kwds):
        datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
        _LOG.info("butler.remove: datasetType=%s dataId=%s", datasetType.name, dataId)
        dsTypeName = datasetType.name
        key = self.key(dataId)
        dsdata = self.datasets[dsTypeName]
        del dsdata[key]
        ref = self.registry.find(self.run, datasetType, dataId, **kwds)
        if remember:
            self.registry.disassociate(self.run, [ref])
        else:
            self.registry.removeDatasets([ref])
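
# A minimal sketch of a ButlerMock put/get round trip without a full registry;
# the dataset type and dataId are illustrative, and the DatasetType
# construction mirrors registerDatasetTypes below:
#
#     butler = ButlerMock()
#     dsType = DatasetType("add_input", ["instrument", "detector"],
#                          storageClass="NumpyArray",
#                          universe=butler.registry.dimensions)
#     dataId = dict(instrument="INSTR", detector=0)
#     ref = butler.put(numpy.array([1, 2]), dsType, dataId)
#     data = butler.get(dsType, dataId)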


def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : iterable of `~lsst.pipe.base.TaskDef`
        TaskDef instances for all tasks in a pipeline.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        packagesDatasetType = DatasetType("packages", {},
                                          storageClass="Packages",
                                          universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # This is a no-op if the dataset type already exists and is
            # consistent, and it raises if it is inconsistent.
            registry.registerDatasetType(datasetType)
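
# A minimal sketch of registering the dataset types for a one-task pipeline,
# mirroring what makeSimpleQGraph below does internally:
#
#     butler = ButlerMock(fullRegistry=True)
#     pipeline = pipeBase.Pipeline("test pipeline")
#     pipeline.addTask(AddTask, "task1")
#     registerDatasetTypes(butler.registry, list(pipeline.toExpandedPipeline()))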


def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, skipExisting=False):
    """Make a simple QuantumGraph for tests.

    Makes a simple one-task pipeline with AddTask, sets up an in-memory
    registry and butler, fills them with minimal data, and generates a
    QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`, optional
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None` then a one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist; default is `False`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    if pipeline is None:
        pipeline = pipeBase.Pipeline("test pipeline")
        pipeline.addTask(AddTask, "task1")
        pipeline = list(pipeline.toExpandedPipeline())

    if butler is None:
        butler = ButlerMock(fullRegistry=True)

    # Add dataset types to registry
    registerDatasetTypes(butler.registry, pipeline)

    # Small set of dataIds included in the QGraph
    records = [dict(instrument="INSTR", id=i, full_name=str(i)) for i in range(nQuanta)]
    dataIds = [dict(instrument="INSTR", detector=detector) for detector in range(nQuanta)]

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
    butler.registry.insertDimensionData("detector", *records)

    # Add inputs to butler
    for i, dataId in enumerate(dataIds):
        data = numpy.array([i, 10*i])
        butler.put(data, "add_input", dataId)

    # Make the graph; a task factory is not needed here
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=CollectionSearch.fromExpression(butler.run),
        run=butler.run,
        userQuery=""
    )

    return butler, qgraph
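
# A minimal sketch of how a test would call the helper above; one quantum is
# generated per detector dataId:
#
#     butler, qgraph = makeSimpleQGraph(nQuanta=3)
#     # Per the docstring, the returned butler may be passed back in to build
#     # another graph over the same data, e.g. with skipExisting=True.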