Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 26%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.base import Packages
from lsst.daf.butler import Butler, Config, DatasetType
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImport

from ... import base as pipeBase
from .. import connectionTypes as cT
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..task import _TASK_FULL_METADATA_TYPE

_LOG = logging.getLogger(__name__)


# SimpleInstrument has an instrument-like API as needed for unit testing,
# but it cannot explicitly depend on Instrument because pipe_base does not
# explicitly depend on obs_base.
class SimpleInstrument:
    def __init__(self, *args, **kwargs):
        pass

    @staticmethod
    def getName():
        return "INSTRU"

    def applyConfigOverrides(self, name, config):
        pass


class AddTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input, two outputs, plus one init
    output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Second output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for
    specific unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):
        if self.taskFactory:
            # Update the factory's bookkeeping, and simulate a failure if
            # this is the designated call to run().
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)
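

# A minimal usage sketch: AddTask can be exercised directly with an
# in-memory array, with no butler or quantum graph involved. This assumes
# the default AddTaskConfig (addend=3); the function below is illustrative
# and not used by the unit tests themselves.
def _runAddTaskExample():
    task = AddTask(config=AddTaskConfig())
    result = task.run(numpy.array([0, 1, 2]))
    # With addend=3: output == input + 3 and output2 == input + 6.
    assert numpy.array_equal(result.output, numpy.array([3, 4, 5]))
    assert numpy.array_equal(result.output2, numpy.array([6, 7, 8]))
    return result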


class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """

    def __init__(self, stopAt=-1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises an exception at this call to run()

    def makeTask(self, taskClass, name, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        task.taskFactory = self
        return task
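

# A minimal sketch of the failure bookkeeping above: with stopAt=0 the very
# first call to run() raises. Passing None for the unused makeTask arguments
# (overrides, butler) is an assumption of this illustration.
def _stopAtExample():
    factory = AddTaskFactoryMock(stopAt=0)
    task = factory.makeTask(AddTask, "task0", None, None, None)
    try:
        task.run(numpy.array([0]))
    except RuntimeError:
        # Expected: countExec (0) matched stopAt (0).
        pass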


def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of `TaskDef` instances, likely the output of the
        `~lsst.pipe.base.Pipeline.toExpandedPipeline` method.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass="Packages", universe=registry.dimensions
        )
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the dataset type already exists and
            # is consistent, and it raises if it is inconsistent; component
            # dataset types must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)


def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Make a chain of tasks that execute in a well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
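

# A minimal sketch of customizing the pipeline before graph generation:
# build it with makeSimplePipeline first, tweak a task's config, and pass
# the result to makeSimpleQGraph (defined below). The "addend" override and
# the root argument (a path to a new repository location) are illustrative.
def _customPipelineExample(root):
    pipeline = makeSimplePipeline(nQuanta=3)
    pipeline.addConfigOverride("task1", "addend", 100)
    return makeSimpleQGraph(pipeline=pipeline, root=root)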


def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create a new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root = ResourcePath(root, forceDirectory=True)
    if not root.isLocal:
        raise ValueError(f"Only works with local root, not {root}")
    config = Config()
    if not inMemory:
        config["registry", "db"] = f"sqlite:///{root.ospath}/gen3.sqlite"
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(root, {}, config=config)
    butler = Butler(butler=repo, run=run)
    return butler


def populateButler(pipeline, butler, datasetTypes=None):
    """Populate data butler with data needed for test.

    Initializes the data butler with several items:

    - registers the dataset types defined by the pipeline
    - creates dimension records for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0"
      is added

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline has no instrument definition. The type
    of each dataset is guessed from its dataset type name (this assumes the
    pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to the ``butler.run`` collection.
    """
    # Add dataset types to registry
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        if isinstance(instrument, str):
            instrument = doImport(instrument)
        instrumentName = instrument.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
    # Add inputs to butler
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent
                data = Packages({"python": "9.9.99"})
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task label or make a
                    # new one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
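

# A minimal sketch of the datasetTypes argument: keys are RUN collection
# names (None means butler.run) and values list the dataset types to insert.
# The collection name "existing_outputs" and the chosen dataset types are
# illustrative; they rely on the naming conventions of AddTaskConnections.
def _populateButlerExample(pipeline, butler):
    populateButler(
        pipeline,
        butler,
        datasetTypes={
            None: ["add_dataset0"],
            "existing_outputs": ["add_dataset1", "task0_metadata", "task0_log"],
        },
    )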


def makeSimpleQGraph(
    nQuanta=5,
    pipeline=None,
    butler=None,
    root=None,
    callPopulateButler=True,
    run="test",
    skipExistingIn=None,
    inMemory=True,
    userQuery="",
    datasetTypes=None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
):
    """Make a simple QuantumGraph for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory
    registry and butler, fills them with minimal data, and generates a
    QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, if `None` then a new data butler is created
        by calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a
        graph. If `False`, the ``butler`` argument must not be `None` and
        must be pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if
        ``butler`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. May be any of the types
        accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository, only used if ``butler``
        is `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty
        string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to the ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`, optional
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExistingIn=skipExistingIn)
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
        butler.collections,
        run or butler.run,
        userQuery,
    )
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run or butler.run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
    )

    return butler, qgraph
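

# A minimal end-to-end sketch: create a repository under a temporary
# directory, build a three-quantum graph, and check the quantum count (one
# AddTask quantum per pipeline level). The temporary-directory handling is
# illustrative.
def _makeSimpleQGraphExample():
    import tempfile

    with tempfile.TemporaryDirectory() as root:
        butler, qgraph = makeSimpleQGraph(nQuanta=3, root=root)
        assert len(qgraph) == 3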