Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 26%
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""
25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]
27import itertools
28import logging
30import lsst.daf.butler.tests as butlerTests
31import lsst.pex.config as pexConfig
32import numpy
33from lsst.base import Packages
34from lsst.daf.butler import Butler, ButlerURI, Config, DatasetType
35from lsst.daf.butler.core.logging import ButlerLogRecords
36from lsst.utils import doImport
38from ... import base as pipeBase
39from .. import connectionTypes as cT
40from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
41from ..task import _TASK_FULL_METADATA_TYPE
43_LOG = logging.getLogger(__name__)


# SimpleInstrument has an instrument-like API as needed for unit testing, but
# can not explicitly depend on Instrument because pipe_base does not
# explicitly depend on obs_base.
class SimpleInstrument:
    def __init__(self, *args, **kwargs):
        pass

    @staticmethod
    def getName():
        return "INSTRU"

    def applyConfigOverrides(self, name, config):
        pass


class AddTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input, two outputs, plus one init
    output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for
    specific unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):

        if self.taskFactory:
            # do some bookkeeping
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)
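

# Example usage (a minimal sketch, not part of the original module): with the
# default ``addend`` of 3 the task adds 3 once and twice to its input:
#
#     task = AddTask(config=AddTaskConfig(), name="add_task")
#     result = task.run(numpy.array([1, 2]))
#     # result.output == [4, 5]; result.output2 == [7, 8]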


class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """

    def __init__(self, stopAt=-1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def makeTask(self, taskClass, name, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        task.taskFactory = self
        return task
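

# Example usage (a minimal sketch, not part of the original module): the
# factory lets a test force a failure on a chosen call to AddTask.run,
# because run() raises when countExec reaches stopAt:
#
#     factory = AddTaskFactoryMock(stopAt=2)
#     task = factory.makeTask(AddTask, "add_task", None, None, None)
#     task.run(numpy.array([1]))  # OK, countExec -> 1
#     task.run(numpy.array([1]))  # OK, countExec -> 2
#     task.run(numpy.array([1]))  # raises RuntimeError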


def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of `TaskDef` instances, likely the output of the
        `toExpandedPipeline` method on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass="Packages", universe=registry.dimensions
        )
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the dataset type already exists and
            # is consistent, and raises if it is inconsistent; component
            # dataset types must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
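

# Example usage (a minimal sketch, not part of the original module, assuming
# an existing ``butler``): register everything a one-task pipeline needs
# before putting any datasets, using ``makeSimplePipeline`` defined below:
#
#     pipeline = makeSimplePipeline(1)
#     registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())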


def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to add to the pipeline. If no
        instrument should be added, pass an empty string or `None` (the
        default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # make a bunch of tasks that execute in a well-defined order (via data
    # dependencies)
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
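

# Example usage (a minimal sketch, not part of the original module): a
# two-quantum pipeline chains its tasks through the name templates, so
# task0 reads "add_dataset0" and writes "add_dataset1", which task1 then
# reads to write "add_dataset2":
#
#     pipeline = makeSimplePipeline(2)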


def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root = ButlerURI(root, forceDirectory=True)
    if not root.isLocal:
        raise ValueError(f"Only works with local root, not {root}")
    config = Config()
    if not inMemory:
        config["registry", "db"] = f"sqlite:///{root.ospath}/gen3.sqlite"
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(root, {}, config=config)
    butler = Butler(butler=repo, run=run)
    return butler
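

# Example usage (a minimal sketch, not part of the original module; the path
# is hypothetical and assumed writable): passing inMemory=False creates an
# SQLite-backed registry and a file datastore on disk instead:
#
#     butler = makeSimpleButler("/tmp/test_repo", run="test", inMemory=False)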


def populateButler(pipeline, butler, datasetTypes=None):
    """Populate data butler with data needed for test.

    Initializes data butler with a bunch of items:

    - registers dataset types which are defined by pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing then a single dataset with type "add_dataset0"
      is added

    All datasets added to butler have ``dataId={instrument=instrument,
    detector=0}`` where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline is missing an instrument definition.
    The type of each dataset is guessed from the dataset type name (this
    assumes that the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
    """

    # Add dataset types to registry
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        if isinstance(instrument, str):
            instrument = doImport(instrument)
        instrumentName = instrument.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
    # Add inputs to butler
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent
                data = Packages({"python": "9.9.99"})
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # find a config from the matching task name or make a new one
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
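

# Example usage (a minimal sketch, not part of the original module, assuming
# an existing ``butler``): put the default input dataset plus an extra input
# in a second, hypothetical RUN collection:
#
#     pipeline = makeSimplePipeline(2)
#     populateButler(
#         pipeline,
#         butler,
#         datasetTypes={None: ["add_dataset0"], "extra_run": ["add_dataset1"]},
#     )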


def makeSimpleQGraph(
    nQuanta=5,
    pipeline=None,
    butler=None,
    root=None,
    callPopulateButler=True,
    run="test",
    skipExistingIn=None,
    inMemory=True,
    userQuery="",
    datasetTypes=None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
):
337 """Make simple QuantumGraph for tests.
339 Makes simple one-task pipeline with AddTask, sets up in-memory registry
340 and butler, fills them with minimal data, and generates QuantumGraph with
341 all of that.
343 Parameters
344 ----------
345 nQuanta : `int`
346 Number of quanta in a graph, only used if ``pipeline`` is None.
347 pipeline : `~lsst.pipe.base.Pipeline`
348 If `None` then a pipeline is made with `AddTask` and default
349 `AddTaskConfig`.
350 butler : `~lsst.daf.butler.Butler`, optional
351 Data butler instance, if None then new data butler is created by
352 calling `makeSimpleButler`.
353 callPopulateButler : `bool`, optional
354 If True insert datasets into the butler prior to building a graph.
355 If False butler argument must not be None, and must be pre-populated.
356 Defaults to True.
357 root : `str`
358 Path or URI to the root location of the new repository. Only used if
359 ``butler`` is None.
360 run : `str`, optional
361 Name of the RUN collection to add to butler, only used if ``butler``
362 is None.
363 skipExistingIn
364 Expressions representing the collections to search for existing
365 output datasets that should be skipped. May be any of the types
366 accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
367 inMemory : `bool`, optional
368 If true make in-memory repository, only used if ``butler`` is `None`.
369 userQuery : `str`, optional
370 The user query to pass to ``makeGraph``, by default an empty string.
371 datasetTypes : `dict` [ `str`, `list` ], optional
372 Dictionary whose keys are collection names and values are lists of
373 dataset type names. By default a single dataset of type "add_dataset0"
374 is added to a ``butler.run`` collection.
375 datasetQueryQConstraint : `DatasetQueryConstraintVariant`
376 The query constraint variant that should be used to constrain the
377 query based on dataset existence, defaults to
378 `DatasetQueryConstraintVariant.ALL`.
380 Returns
381 -------
382 butler : `~lsst.daf.butler.Butler`
383 Butler instance
384 qgraph : `~lsst.pipe.base.QuantumGraph`
385 Quantum graph instance
386 """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExistingIn=skipExistingIn)
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
        butler.collections,
        run or butler.run,
        userQuery,
    )
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run or butler.run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
    )

    return butler, qgraph
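

# Example usage (a minimal sketch, not part of the original module, assuming
# a writable temporary directory): build a three-quantum graph in a fresh
# repository and check that one quantum was generated per task:
#
#     import tempfile
#
#     with tempfile.TemporaryDirectory() as tmpdir:
#         butler, qgraph = makeSimpleQGraph(nQuanta=3, root=tmpdir)
#         assert len(qgraph) == 3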