# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""
from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from collections.abc import Iterable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, cast

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

_LOG = logging.getLogger(__name__)


class SimpleInstrument(Instrument):
    """Simple instrument class suitable for testing."""

    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input, two outputs, and one init
    output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Second output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init output dataset type for this task",
    )


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)
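

# Illustrative sketch (not part of the tested module's API): connection
# dataset type names are resolved by template substitution when a connections
# instance is constructed, so with the defaults above ``add_dataset{in_tmpl}``
# becomes ``add_dataset_in``.
def _example_connection_names() -> None:
    connections = AddTaskConnections(config=AddTaskConfig())
    assert connections.input.name == "add_dataset_in"
    assert connections.output.name == "add_dataset_out"
    assert connections.output2.name == "add2_dataset_out"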


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: AddTaskFactoryMock | None = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:  # type: ignore
        if self.taskFactory:
            # Do some bookkeeping for the factory so unit tests can observe
            # progress and trigger a simulated failure.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.config = cast(AddTaskConfig, self.config)
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)
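

# Minimal sketch of what ``AddTask.run`` computes with the default ``addend``
# of 3 (illustrative only; real tests execute the task through a QuantumGraph):
def _example_add_task_run() -> None:
    task = AddTask(config=AddTaskConfig())
    result = task.run(input=1)
    assert result.output == 4  # 1 + 3
    assert result.output2 == 7  # (1 + 3) + 3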


class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises an exception at this call to run()

    def makeTask(
        self, taskDef: TaskDef, butler: LimitedButler, initInputRefs: Iterable[DatasetRef] | None
    ) -> PipelineTask:
        taskClass = taskDef.taskClass
        assert taskClass is not None
        task = taskClass(config=taskDef.config, initInputs=None, name=taskDef.label)
        task.taskFactory = self  # type: ignore
        return task
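

# Sketch of the stop-at bookkeeping (illustrative only): with ``stopAt=1`` the
# first call to ``run`` succeeds and the second raises, which is how unit
# tests simulate a failure partway through execution.
def _example_factory_stop() -> None:
    factory = AddTaskFactoryMock(stopAt=1)
    task = AddTask(config=AddTaskConfig())
    task.taskFactory = factory  # type: ignore
    task.run(input=0)  # countExec goes 0 -> 1
    try:
        task.run(input=0)  # countExec == stopAt == 1, so this raises
    except RuntimeError:
        pass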


def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `Pipeline.toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        storageClass = "Packages"
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass=storageClass, universe=registry.dimensions
        )
        datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the dataset type already exists and
            # is consistent, and it raises if it is inconsistent; components
            # must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
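

# Sketch of direct use (`populateButler` below calls this internally):
# register every dataset type of an expanded pipeline before inserting test
# datasets, assuming ``butler`` points at a writable test repository.
def _example_register(butler: Butler) -> None:
    pipeline = makeSimplePipeline(2)
    registerDatasetTypes(butler.registry, list(pipeline.toExpandedPipeline()))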


def makeSimplePipeline(nQuanta: int, instrument: str | None = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by `makeSimpleQGraph()` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    `makeSimpleQGraph()` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = Pipeline("test pipeline")
    # Make a chain of tasks that execute in a well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
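

# Sketch of the chaining produced by the overrides above: task0 reads
# ``add_dataset0`` and writes ``add_dataset1``, task1 reads ``add_dataset1``
# and writes ``add_dataset2``, and so on, giving a linear execution order.
def _example_simple_pipeline() -> None:
    pipeline = makeSimplePipeline(3)
    labels = [taskDef.label for taskDef in pipeline.toExpandedPipeline()]
    assert labels == ["task0", "task1", "task2"]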


def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create a new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.
    config : `~lsst.daf.butler.Config`, optional
        Configuration to use for the new Butler; if `None` then the default
        configuration is used. If ``inMemory`` is `False` then the
        configuration is updated to use a SQLite registry and a file
        datastore in ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root_path = ResourcePath(root, forceDirectory=True)
    if not root_path.isLocal:
        raise ValueError(f"Only works with local root, not {root_path}")
    butler_config = Config()
    if config:
        butler_config.update(Config(config))
    if not inMemory:
        butler_config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
        butler_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(str(root_path), {}, config=butler_config)
    butler = Butler(butler=repo, run=run)
    return butler
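

# Typical use from a unit test (sketch; ``tmp_dir`` stands for any writable
# scratch directory, e.g. one created with `tempfile.mkdtemp`):
def _example_make_butler(tmp_dir: str) -> Butler:
    butler = makeSimpleButler(tmp_dir, run="test", inMemory=True)
    assert butler.run == "test"
    return butler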


def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: dict[str | None, list[str]] | None = None
) -> None:
    """Populate data butler with data needed for test.

    Initializes data butler with a number of items:

    - registers dataset types which are defined by pipeline
    - creates dimension records for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing then a single dataset with type "add_dataset0"
      is added

    All datasets added to butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline has no instrument definition. The type of
    each dataset is guessed from the dataset type name (assuming that the
    pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to the ``butler.run`` collection.
    """
    # Add dataset types to registry
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = cast(Instrument, instrument_class).getName()
        instrumentClass = get_full_type_name(instrument_class)
    else:
        instrumentName = "INSTR"
        instrumentClass = None

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = {taskDef.label: taskDef for taskDef in taskDefs}
    # Add inputs to butler
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task name or make a
                    # new one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
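

# Sketch of populating more than the default input: the default run gets the
# usual ``add_dataset0`` while a hypothetical "existing" run collection is
# seeded with task0's metadata, e.g. for skip-existing tests.
def _example_populate(butler: Butler) -> None:
    pipeline = makeSimplePipeline(2)
    populateButler(
        pipeline,
        butler,
        datasetTypes={None: ["add_dataset0"], "existing": ["task0_metadata"]},
    )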


def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Pipeline | None = None,
    butler: Butler | None = None,
    root: str | None = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: str | None = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: dict[str | None, list[str]] | None = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Mapping[str, Any] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
) -> tuple[Butler, QuantumGraph]:
    """Make a simple `QuantumGraph` for tests.

    Makes a simple pipeline built from `AddTask` tasks, sets up an in-memory
    registry and butler, fills them with minimal data, and generates a
    QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, if `None` then a new data butler is created by
        calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a graph.
        If `False`, the ``butler`` argument must not be `None` and must be
        pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if ``butler``
        is `None`.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default). Only used if ``pipeline`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository, only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to the ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.
    bind : `~collections.abc.Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.
    metadata : `~collections.abc.Mapping`, optional
        Optional graph metadata.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = GraphBuilder(
        registry=butler.registry,
        skipExistingIn=skipExistingIn,
        datastore=butler._datastore if makeDatastoreRecords else None,
    )
    if not run:
        assert butler.run is not None, "Butler must have run defined"
        run = butler.run
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r, bind=%s",
        butler.collections,
        run,
        userQuery,
        bind,
    )
    if not metadata:
        metadata = {}
    metadata["output_run"] = run
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
        bind=bind,
        metadata=metadata,
    )

    return butler, qgraph
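

# End-to-end sketch of the typical call from a unit test (``tmp_dir`` is any
# writable scratch directory): build a linear three-quantum graph backed by
# an in-memory repository.
def _example_make_qgraph(tmp_dir: str) -> None:
    butler, qgraph = makeSimpleQGraph(nQuanta=3, root=tmp_dir)
    assert len(qgraph) == 3  # one quantum per task for a single detector
    assert butler.run == "test"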