Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 26%
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""
25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy

try:
    from lsst.base import Packages
except ImportError:
    Packages = None
from lsst.daf.butler import Butler, Config, DatasetType
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImport

from ... import base as pipeBase
from .. import connectionTypes as cT
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..task import _TASK_FULL_METADATA_TYPE

_LOG = logging.getLogger(__name__)


# SimpleInstrument has an instrument-like API as needed for unit testing, but
# can not explicitly depend on Instrument because pipe_base does not
# explicitly depend on obs_base.
class SimpleInstrument:
    def __init__(self, *args, **kwargs):
        pass

    @staticmethod
    def getName():
        return "INSTRU"

    def applyConfigOverrides(self, name, config):
        pass


class AddTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input and two outputs, plus one init
    output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing; it has some extras useful for
    specific unit tests.
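
    Examples
    --------
    A minimal sketch of running the task directly, outside any butler or
    executor (the input value is illustrative)::

        task = AddTask(config=AddTaskConfig())
        result = task.run(numpy.array([1]))
        # result.output == input + addend; result.output2 adds addend again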
109 """
111 ConfigClass = AddTaskConfig
112 _DefaultName = "add_task"
114 initout = numpy.array([999])
115 """InitOutputs for this task"""
117 taskFactory = None
118 """Factory that makes instances"""
120 def run(self, input):
122 if self.taskFactory:
123 # do some bookkeeping
124 if self.taskFactory.stopAt == self.taskFactory.countExec:
125 raise RuntimeError("pretend something bad happened")
126 self.taskFactory.countExec += 1
128 self.metadata.add("add", self.config.addend)
129 output = input + self.config.addend
130 output2 = output + self.config.addend
131 _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
132 return pipeBase.Struct(output=output, output2=output2)


class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
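
    Examples
    --------
    A minimal sketch of the factory protocol as used in this module; the
    `None` arguments stand in for config, overrides, and a butler, all of
    which `makeTask` tolerates here::

        factory = AddTaskFactoryMock(stopAt=2)
        task = factory.makeTask(AddTask, "add_task", None, None, None)
        # task.taskFactory is factory; the third call to task.run() raises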
140 """
142 def __init__(self, stopAt=-1):
143 self.countExec = 0 # incremented by AddTask
144 self.stopAt = stopAt # AddTask raises exception at this call to run()
146 def makeTask(self, taskClass, name, config, overrides, butler):
147 if config is None:
148 config = taskClass.ConfigClass()
149 if overrides:
150 overrides.applyTo(config)
151 task = taskClass(config=config, initInputs=None, name=name)
152 task.taskFactory = self
153 return task


def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        toExpandedPipeline on a `~lsst.pipe.base.Pipeline` object.
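
    Examples
    --------
    A minimal sketch using the helpers defined later in this module (the
    repository path is illustrative)::

        butler = makeSimpleButler("/tmp/repo", run="test")
        pipeline = makeSimplePipeline(nQuanta=3)
        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())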
168 """
169 for taskDef in pipeline:
170 configDatasetType = DatasetType(
171 taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
172 )
173 storageClass = "Packages" if Packages is not None else "StructuredDataDict"
174 packagesDatasetType = DatasetType(
175 "packages", {}, storageClass=storageClass, universe=registry.dimensions
176 )
177 datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
178 for datasetType in itertools.chain(
179 datasetTypes.initInputs,
180 datasetTypes.initOutputs,
181 datasetTypes.inputs,
182 datasetTypes.outputs,
183 datasetTypes.prerequisites,
184 [configDatasetType, packagesDatasetType],
185 ):
186 _LOG.info("Registering %s with registry", datasetType)
187 # this is a no-op if it already exists and is consistent,
188 # and it raises if it is inconsistent. But components must be
189 # skipped
190 if not datasetType.isComponent():
191 registry.registerDatasetType(datasetType)


def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline that
    ``makeSimpleQGraph`` uses: call this first and pass the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to add to the pipeline, or
        `None` (or an empty string) if no instrument should be added.
        Default is `None`.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
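
    Examples
    --------
    A minimal sketch; the task labels follow the ``task{N}`` pattern used
    below::

        pipeline = makeSimplePipeline(nQuanta=3)
        # Tasks "task0" .. "task2" are chained via their in/out templates.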
215 """
216 pipeline = pipeBase.Pipeline("test pipeline")
217 # make a bunch of tasks that execute in well defined order (via data
218 # dependencies)
219 for lvl in range(nQuanta):
220 pipeline.addTask(AddTask, f"task{lvl}")
221 pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
222 pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
223 if instrument:
224 pipeline.addInstrument(instrument)
225 return pipeline


def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
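
    Examples
    --------
    A minimal sketch; the root must be a local path (illustrative here)::

        butler = makeSimpleButler("/tmp/repo", run="test", inMemory=True)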
244 """
245 root = ResourcePath(root, forceDirectory=True)
246 if not root.isLocal:
247 raise ValueError(f"Only works with local root not {root}")
248 config = Config()
249 if not inMemory:
250 config["registry", "db"] = f"sqlite:///{root.ospath}/gen3.sqlite"
251 config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
252 repo = butlerTests.makeTestRepo(root, {}, config=config)
253 butler = Butler(butler=repo, run=run)
254 return butler


def populateButler(pipeline, butler, datasetTypes=None):
    """Populate data butler with data needed for test.

    Initializes data butler with a number of items:

    - registers dataset types which are defined by pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0" is
      added

    All datasets added to butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline has no instrument definition. The type
    of each dataset is guessed from the dataset type name (this assumes the
    pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
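
    Examples
    --------
    A minimal sketch, reusing the helpers above; with no ``datasetTypes``
    argument a single "add_dataset0" dataset is written to ``butler.run``
    (the repository path is illustrative)::

        pipeline = makeSimplePipeline(nQuanta=3)
        butler = makeSimpleButler("/tmp/repo", run="test")
        populateButler(pipeline, butler)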
282 """
284 # Add dataset types to registry
285 taskDefs = list(pipeline.toExpandedPipeline())
286 registerDatasetTypes(butler.registry, taskDefs)
288 instrument = pipeline.getInstrument()
289 if instrument is not None:
290 if isinstance(instrument, str):
291 instrument = doImport(instrument)
292 instrumentName = instrument.getName()
293 else:
294 instrumentName = "INSTR"
296 # Add all needed dimensions to registry
297 butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
298 butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))
300 taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
301 # Add inputs to butler
302 if not datasetTypes:
303 datasetTypes = {None: ["add_dataset0"]}
304 for run, dsTypes in datasetTypes.items():
305 if run is not None:
306 butler.registry.registerRun(run)
307 for dsType in dsTypes:
308 if dsType == "packages":
309 # Version is intentionally inconsistent.
310 # Dict is convertible to Packages if Packages is installed.
311 data = {"python": "9.9.99"}
312 butler.put(data, dsType, run=run)
313 else:
314 if dsType.endswith("_config"):
315 # find a confing from matching task name or make a new one
316 taskLabel, _, _ = dsType.rpartition("_")
317 taskDef = taskDefMap.get(taskLabel)
318 if taskDef is not None:
319 data = taskDef.config
320 else:
321 data = AddTaskConfig()
322 elif dsType.endswith("_metadata"):
323 data = _TASK_FULL_METADATA_TYPE()
324 elif dsType.endswith("_log"):
325 data = ButlerLogRecords.from_records([])
326 else:
327 data = numpy.array([0.0, 1.0, 2.0, 5.0])
328 butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)


def makeSimpleQGraph(
    nQuanta=5,
    pipeline=None,
    butler=None,
    root=None,
    callPopulateButler=True,
    run="test",
    skipExistingIn=None,
    inMemory=True,
    userQuery="",
    datasetTypes=None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
):
    """Make simple QuantumGraph for tests.

    Makes a simple pipeline of AddTask tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a QuantumGraph
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is None.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, if None then new data butler is created by
        calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is None.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a
        graph. If `False`, the ``butler`` argument must not be None and
        must be pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if
        ``butler`` is None.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. May be any of the types
        accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
    inMemory : `bool`, optional
        If true make in-memory repository, only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty
        string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance
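
    Examples
    --------
    A minimal sketch building a three-quantum graph in a fresh repository
    (the root path is illustrative)::

        butler, qgraph = makeSimpleQGraph(nQuanta=3, root="/tmp/repo")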
393 """
395 if pipeline is None:
396 pipeline = makeSimplePipeline(nQuanta=nQuanta)
398 if butler is None:
399 if root is None:
400 raise ValueError("Must provide `root` when `butler` is None")
401 if callPopulateButler is False:
402 raise ValueError("populateButler can only be False when butler is supplied as an argument")
403 butler = makeSimpleButler(root, run=run, inMemory=inMemory)
405 if callPopulateButler:
406 populateButler(pipeline, butler, datasetTypes=datasetTypes)
408 # Make the graph
409 _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
410 builder = pipeBase.GraphBuilder(registry=butler.registry, skipExistingIn=skipExistingIn)
411 _LOG.debug(
412 "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
413 butler.collections,
414 run or butler.run,
415 userQuery,
416 )
417 qgraph = builder.makeGraph(
418 pipeline,
419 collections=butler.collections,
420 run=run or butler.run,
421 userQuery=userQuery,
422 datasetQueryConstraint=datasetQueryConstraint,
423 )
425 return butler, qgraph