Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 26% (161 statements)
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Bunch of common classes and methods for use in unit tests.
23"""
from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from collections.abc import Iterable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, cast

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

_LOG = logging.getLogger(__name__)


class SimpleInstrument(Instrument):
    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask; it has one input, two outputs,
    and one init output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for
    specific unit tests.
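
    Examples
    --------
    A minimal sketch; with the default ``addend`` of 3, `run` returns
    ``input + 3`` and ``input + 6``:

    >>> task = AddTask(config=AddTaskConfig())
    >>> result = task.run(7)
    >>> result.output, result.output2
    (10, 13)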
119 """
121 ConfigClass = AddTaskConfig
122 _DefaultName = "add_task"
124 initout = numpy.array([999])
125 """InitOutputs for this task"""
127 taskFactory: AddTaskFactoryMock | None = None
128 """Factory that makes instances"""
130 def run(self, input: int) -> Struct: # type: ignore
131 if self.taskFactory:
132 # do some bookkeeping
133 if self.taskFactory.stopAt == self.taskFactory.countExec:
134 raise RuntimeError("pretend something bad happened")
135 self.taskFactory.countExec += 1
137 self.config = cast(AddTaskConfig, self.config)
138 self.metadata.add("add", self.config.addend)
139 output = input + self.config.addend
140 output2 = output + self.config.addend
141 _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
142 return Struct(output=output, output2=output2)


class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
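
    Examples
    --------
    A minimal sketch; ``stopAt=2`` means the first two executions of
    `AddTask.run` succeed and the third raises `RuntimeError`:

    >>> factory = AddTaskFactoryMock(stopAt=2)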
150 """
152 def __init__(self, stopAt: int = -1):
153 self.countExec = 0 # incremented by AddTask
154 self.stopAt = stopAt # AddTask raises exception at this call to run()
156 def makeTask(
157 self, taskDef: TaskDef, butler: LimitedButler, initInputRefs: Iterable[DatasetRef] | None
158 ) -> PipelineTask:
159 taskClass = taskDef.taskClass
160 assert taskClass is not None
161 task = taskClass(config=taskDef.config, initInputs=None, name=taskDef.label)
162 task.taskFactory = self # type: ignore
163 return task


def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `~lsst.pipe.base.Pipeline` or `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `Pipeline.toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
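
    Examples
    --------
    A minimal sketch (assuming ``butler`` is a writable data butler, e.g.
    one returned by `makeSimpleButler`):

    >>> pipeline = makeSimplePipeline(nQuanta=2)
    >>> registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())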
178 """
179 for taskDef in pipeline:
180 configDatasetType = DatasetType(
181 taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
182 )
183 storageClass = "Packages"
184 packagesDatasetType = DatasetType(
185 "packages", {}, storageClass=storageClass, universe=registry.dimensions
186 )
187 datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
188 for datasetType in itertools.chain(
189 datasetTypes.initInputs,
190 datasetTypes.initOutputs,
191 datasetTypes.inputs,
192 datasetTypes.outputs,
193 datasetTypes.prerequisites,
194 [configDatasetType, packagesDatasetType],
195 ):
196 _LOG.info("Registering %s with registry", datasetType)
197 # this is a no-op if it already exists and is consistent,
198 # and it raises if it is inconsistent. But components must be
199 # skipped
200 if not datasetType.isComponent():
201 registry.registerDatasetType(datasetType)


def makeSimplePipeline(nQuanta: int, instrument: str | None = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by `makeSimpleQGraph()` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    `makeSimpleQGraph()` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
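
    Examples
    --------
    A minimal sketch; the template overrides chain the tasks so that
    ``task0`` writes ``add_dataset1``, which ``task1`` reads, and so on:

    >>> pipeline = makeSimplePipeline(
    ...     2, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
    ... )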
225 """
226 pipeline = Pipeline("test pipeline")
227 # make a bunch of tasks that execute in well defined order (via data
228 # dependencies)
229 for lvl in range(nQuanta):
230 pipeline.addTask(AddTask, f"task{lvl}")
231 pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
232 pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
233 if instrument:
234 pipeline.addInstrument(instrument)
235 return pipeline


def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.
    config : `~lsst.daf.butler.Config`, optional
        Configuration to use for the new Butler; if `None` then the default
        configuration is used. If ``inMemory`` is `False` then the
        configuration is updated to use a SQLite registry and a file
        datastore in ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
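
    Examples
    --------
    A minimal sketch (assuming a writable temporary directory):

    >>> import tempfile
    >>> butler = makeSimpleButler(tempfile.mkdtemp(), run="test")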
260 """
261 root_path = ResourcePath(root, forceDirectory=True)
262 if not root_path.isLocal:
263 raise ValueError(f"Only works with local root not {root_path}")
264 butler_config = Config()
265 if config:
266 butler_config.update(Config(config))
267 if not inMemory:
268 butler_config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
269 butler_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
270 repo = butlerTests.makeTestRepo(str(root_path), {}, config=butler_config)
271 butler = Butler(butler=repo, run=run)
272 return butler


def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: dict[str | None, list[str]] | None = None
) -> None:
    """Populate data butler with data needed for test.

    Initializes data butler with a number of items:
    - registers dataset types which are defined by pipeline
    - creates dimension records for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing then a single dataset with type "add_dataset0"
      is added

    All datasets added to butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline has no instrument definition. The type
    of each dataset is guessed from the dataset type name (assuming that
    the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str` or `None`, `list` [ `str` ] ], optional
        Dictionary whose keys are collection names (`None` meaning
        ``butler.run``) and values are lists of dataset type names. By
        default a single dataset of type "add_dataset0" is added to the
        ``butler.run`` collection.
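
    Examples
    --------
    A minimal sketch; the `None` key targets the butler's default run:

    >>> import tempfile
    >>> pipeline = makeSimplePipeline(2)
    >>> butler = makeSimpleButler(tempfile.mkdtemp(), run="test")
    >>> populateButler(pipeline, butler, datasetTypes={None: ["add_dataset0"]})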
302 """
303 # Add dataset types to registry
304 taskDefs = list(pipeline.toExpandedPipeline())
305 registerDatasetTypes(butler.registry, taskDefs)
307 instrument = pipeline.getInstrument()
308 if instrument is not None:
309 instrument_class = doImportType(instrument)
310 instrumentName = instrument_class.getName()
311 instrumentClass = get_full_type_name(instrument_class)
312 else:
313 instrumentName = "INSTR"
314 instrumentClass = None
316 # Add all needed dimensions to registry
317 butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass))
318 butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))
320 taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
321 # Add inputs to butler
322 if not datasetTypes:
323 datasetTypes = {None: ["add_dataset0"]}
324 for run, dsTypes in datasetTypes.items():
325 if run is not None:
326 butler.registry.registerRun(run)
327 for dsType in dsTypes:
328 if dsType == "packages":
329 # Version is intentionally inconsistent.
330 # Dict is convertible to Packages if Packages is installed.
331 data: Any = {"python": "9.9.99"}
332 butler.put(data, dsType, run=run)
333 else:
334 if dsType.endswith("_config"):
335 # find a config from matching task name or make a new one
336 taskLabel, _, _ = dsType.rpartition("_")
337 taskDef = taskDefMap.get(taskLabel)
338 if taskDef is not None:
339 data = taskDef.config
340 else:
341 data = AddTaskConfig()
342 elif dsType.endswith("_metadata"):
343 data = _TASK_FULL_METADATA_TYPE()
344 elif dsType.endswith("_log"):
345 data = ButlerLogRecords.from_records([])
346 else:
347 data = numpy.array([0.0, 1.0, 2.0, 5.0])
348 butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)


def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Pipeline | None = None,
    butler: Butler | None = None,
    root: str | None = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: str | None = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: dict[str | None, list[str]] | None = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Mapping[str, Any] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
) -> tuple[Butler, QuantumGraph]:
368 """Make simple `QuantumGraph` for tests.
370 Makes simple one-task pipeline with AddTask, sets up in-memory registry
371 and butler, fills them with minimal data, and generates QuantumGraph with
372 all of that.
374 Parameters
375 ----------
376 nQuanta : `int`
377 Number of quanta in a graph, only used if ``pipeline`` is None.
378 pipeline : `~lsst.pipe.base.Pipeline`
379 If `None` then a pipeline is made with `AddTask` and default
380 `AddTaskConfig`.
381 butler : `~lsst.daf.butler.Butler`, optional
382 Data butler instance, if None then new data butler is created by
383 calling `makeSimpleButler`.
384 callPopulateButler : `bool`, optional
385 If True insert datasets into the butler prior to building a graph.
386 If False butler argument must not be None, and must be pre-populated.
387 Defaults to True.
388 root : `str`
389 Path or URI to the root location of the new repository. Only used if
390 ``butler`` is None.
391 run : `str`, optional
392 Name of the RUN collection to add to butler, only used if ``butler``
393 is None.
394 instrument : `str` or `None`, optional
395 The importable name of an instrument to be added to the pipeline or
396 if no instrument should be added then an empty string or `None`, by
397 default `None`. Only used if ``pipeline`` is `None`.
398 skipExistingIn
399 Expressions representing the collections to search for existing
400 output datasets that should be skipped. See
401 :ref:`daf_butler_ordered_collection_searches`.
402 inMemory : `bool`, optional
403 If true make in-memory repository, only used if ``butler`` is `None`.
404 userQuery : `str`, optional
405 The user query to pass to ``makeGraph``, by default an empty string.
406 datasetTypes : `dict` [ `str`, `list` ], optional
407 Dictionary whose keys are collection names and values are lists of
408 dataset type names. By default a single dataset of type "add_dataset0"
409 is added to a ``butler.run`` collection.
410 datasetQueryQConstraint : `DatasetQueryConstraintVariant`
411 The query constraint variant that should be used to constrain the
412 query based on dataset existence, defaults to
413 `DatasetQueryConstraintVariant.ALL`.
414 makeDatastoreRecords : `bool`, optional
415 If `True` then add datstore records to generated quanta.
416 bind : `~collections.abc.Mapping`, optional
417 Mapping containing literal values that should be injected into the
418 ``userQuery`` expression, keyed by the identifiers they replace.
419 metadata : `~collections.abc.Mapping`, optional
420 Optional graph metadata.
422 Returns
423 -------
424 butler : `~lsst.daf.butler.Butler`
425 Butler instance
426 qgraph : `~lsst.pipe.base.QuantumGraph`
427 Quantum graph instance
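
    Examples
    --------
    A minimal sketch (assuming a writable temporary directory); with the
    default single detector, the graph has one quantum per task:

    >>> import tempfile
    >>> butler, qgraph = makeSimpleQGraph(nQuanta=3, root=tempfile.mkdtemp())
    >>> len(qgraph)
    3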
428 """
429 if pipeline is None:
430 pipeline = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument)
432 if butler is None:
433 if root is None:
434 raise ValueError("Must provide `root` when `butler` is None")
435 if callPopulateButler is False:
436 raise ValueError("populateButler can only be False when butler is supplied as an argument")
437 butler = makeSimpleButler(root, run=run, inMemory=inMemory)
439 if callPopulateButler:
440 populateButler(pipeline, butler, datasetTypes=datasetTypes)
442 # Make the graph
443 _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
444 builder = GraphBuilder(
445 registry=butler.registry,
446 skipExistingIn=skipExistingIn,
447 datastore=butler.datastore if makeDatastoreRecords else None,
448 )
449 if not run:
450 assert butler.run is not None, "Butler must have run defined"
451 run = butler.run
452 _LOG.debug(
453 "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r bind=%s",
454 butler.collections,
455 run,
456 userQuery,
457 bind,
458 )
459 if not metadata:
460 metadata = {}
461 metadata["output_run"] = run
463 qgraph = builder.makeGraph(
464 pipeline,
465 collections=butler.collections,
466 run=run,
467 userQuery=userQuery,
468 datasetQueryConstraint=datasetQueryConstraint,
469 bind=bind,
470 metadata=metadata,
471 )
473 return butler, qgraph