Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 26%
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""
__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging

import numpy

from lsst.base import Packages
from lsst.daf.butler import Butler, ButlerURI, Config, DatasetType
import lsst.daf.butler.tests as butlerTests
from lsst.daf.butler.core.logging import ButlerLogRecords
import lsst.pex.config as pexConfig
from lsst.utils import doImport

from ... import base as pipeBase
from .. import connectionTypes as cT
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..task import _TASK_FULL_METADATA_TYPE

_LOG = logging.getLogger(__name__)
# SimpleInstrument has an instrument-like API as needed for unit testing, but
# can not explicitly depend on Instrument because pipe_base does not
# explicitly depend on obs_base.
class SimpleInstrument:

    def __init__(self, *args, **kwargs):
        pass

    @staticmethod
    def getName():
        return "INSTRU"

    def applyConfigOverrides(self, name, config):
        pass
class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connections for AddTask: one input, two outputs,
    and one init output.
    """
    input = cT.Input(name="add_dataset{in_tmpl}",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    output = cT.Output(name="add_dataset{out_tmpl}",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    output2 = cT.Output(name="add2_dataset{out_tmpl}",
                        dimensions=["instrument", "detector"],
                        storageClass="NumpyArray",
                        doc="Output dataset type for this task")
    initout = cT.InitOutput(name="add_init_output{out_tmpl}",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")
class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Config for AddTask.
    """
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)
class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for
    specific unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):

        if self.taskFactory:
            # do some bookkeeping
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)
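
# A minimal usage sketch for AddTask (illustrative only; the input values
# below are assumptions, not part of this module's tests):
#
#     task = AddTask(config=AddTaskConfig())
#     result = task.run(numpy.array([1, 2]))
#     # With the default addend of 3: result.output == [4, 5] and
#     # result.output2 == [7, 8].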
class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """
    def __init__(self, stopAt=-1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def makeTask(self, taskClass, name, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        task.taskFactory = self
        return task
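
# Usage sketch: a factory whose AddTask instances share an execution count
# and raise once ``countExec`` reaches ``stopAt`` (with stopAt=3, the fourth
# call to run() fails; wiring the factory into an executor is left to the
# test):
#
#     taskFactory = AddTaskFactoryMock(stopAt=3)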
def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, typically the output of the
        `toExpandedPipeline` method on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        packagesDatasetType = DatasetType("packages", {},
                                          storageClass="Packages",
                                          universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # This is a no-op if the dataset type already exists and is
            # consistent; it raises if it is inconsistent. Components must
            # be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
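
# Usage sketch (this mirrors how populateButler below calls this function;
# ``pipeline`` and ``butler`` are assumed to exist already):
#
#     taskDefs = list(pipeline.toExpandedPipeline())
#     registerDatasetTypes(butler.registry, taskDefs)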
def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to add to the pipeline, or
        `None` (the default) or an empty string if no instrument should be
        added.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # make a bunch of tasks that execute in well defined order (via data
    # dependencies)
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
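
# Example: makeSimplePipeline(3) yields tasks task0..task2 whose connection
# templates chain the dataset names, so task0 reads "add_dataset0" and
# writes "add_dataset1", which task1 then reads, and so on up to
# "add_dataset3":
#
#     pipeline = makeSimplePipeline(3)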
def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root = ButlerURI(root, forceDirectory=True)
    if not root.isLocal:
        raise ValueError(f"Only works with local root not {root}")
    config = Config()
    if not inMemory:
        config["registry", "db"] = f"sqlite:///{root.ospath}/gen3.sqlite"
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(root, {}, config=config)
    butler = Butler(butler=repo, run=run)
    return butler
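
# Usage sketch (``tmpdir`` is a hypothetical local directory created by the
# caller, e.g. with tempfile.mkdtemp()):
#
#     butler = makeSimpleButler(tmpdir, run="test", inMemory=True)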
def populateButler(pipeline, butler, datasetTypes=None):
    """Populate data butler with data needed for test.

    Initializes data butler with a number of items:
    - registers dataset types which are defined by pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset of type "add_dataset0" is
      added

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline
    ("INSTR" is used if the pipeline has no instrument definition). The type
    of each dataset is guessed from the dataset type name (this assumes that
    the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to the ``butler.run`` collection.
    """
    # Add dataset types to registry
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        if isinstance(instrument, str):
            instrument = doImport(instrument)
        instrumentName = instrument.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName,
                                                         id=0, full_name="det0"))

    taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
    # Add inputs to butler
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent
                data = Packages({"python": "9.9.99"})
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # find a config from matching task name or make a new one
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0., 1., 2., 5.])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
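
# Usage sketch: pre-load the default input dataset into the run collection
# before building a graph (``tmpdir`` is a hypothetical local directory;
# the datasetTypes value shown is the same as the default):
#
#     pipeline = makeSimplePipeline(2)
#     butler = makeSimpleButler(tmpdir)
#     populateButler(pipeline, butler, datasetTypes={None: ["add_dataset0"]})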
def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, callPopulateButler=True,
                     run="test", skipExistingIn=None, inMemory=True, userQuery="",
                     datasetTypes=None,
                     datasetQueryConstraint: DSQVariant = DSQVariant.ALL):
    """Make simple QuantumGraph for tests.

    Makes a simple pipeline of AddTask tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a QuantumGraph
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created
        by calling `makeSimpleButler`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a
        graph. If `False`, the ``butler`` argument must not be `None` and
        must be pre-populated. Defaults to `True`.
    root : `str`
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is `None`.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if
        ``butler`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. May be any of the types
        accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler``
        is `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExistingIn=skipExistingIn)
    _LOG.debug("Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
               butler.collections, run or butler.run, userQuery)
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run or butler.run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint
    )

    return butler, qgraph
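
# End-to-end usage sketch (``tmpdir`` is a hypothetical local directory; one
# quantum is expected per AddTask instance in the pipeline):
#
#     butler, qgraph = makeSimpleQGraph(nQuanta=3, root=tmpdir)
#     assert len(qgraph) == 3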