# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Bunch of common classes and methods for use in unit tests.
29"""
30from __future__ import annotations
32__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from collections.abc import Iterable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, cast

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler
from lsst.daf.butler.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

_LOG = logging.getLogger(__name__)


class SimpleInstrument(Instrument):
    """Simple instrument class suitable for testing.

    Parameters
    ----------
    *args : `~typing.Any`
        Ignored positional arguments.
    **kwargs : `~typing.Any`
        Ignored keyword arguments.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for `AddTask`: one input, two outputs,
    and one init output.
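
    With the default templates the connected dataset types are named
    ``add_dataset_in``, ``add_dataset_out``, ``add2_dataset_out``, and
    ``add_init_output_out``; `makeSimplePipeline` overrides the templates
    with integer suffixes to chain consecutive tasks.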
99 """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for
    specific unit tests.
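
    Examples
    --------
    A minimal sketch of the arithmetic performed by `run` (with the default
    ``addend`` of 3; not exercised as a doctest)::

        task = AddTask(config=AddTaskConfig())
        result = task.run(input=100)
        # result.output == 103; result.output2 == 106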
135 """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: AddTaskFactoryMock | None = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:
        if self.taskFactory:
            # do some bookkeeping
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.config = cast(AddTaskConfig, self.config)
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)


class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.

    Parameters
    ----------
    stopAt : `int`, optional
        The number of successful task executions after which `AddTask.run`
        raises an exception; the default of -1 means it never raises.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def makeTask(
        self, taskDef: TaskDef, butler: LimitedButler, initInputRefs: Iterable[DatasetRef] | None
    ) -> PipelineTask:
        taskClass = taskDef.taskClass
        assert taskClass is not None
        task = taskClass(config=taskDef.config, initInputs=None, name=taskDef.label)
        task.taskFactory = self  # type: ignore
        return task


def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `~lsst.pipe.base.Pipeline` or `~typing.Iterable` of `TaskDef`
        A pipeline or an iterable of TaskDef instances, such as the output
        of `Pipeline.toExpandedPipeline` on a `~lsst.pipe.base.Pipeline`
        object.
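
    Examples
    --------
    A sketch of typical usage with the other helpers in this module
    (``root`` is assumed to be a writable local directory)::

        pipeline = makeSimplePipeline(nQuanta=2)
        butler = makeSimpleButler(root, run="test")
        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())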
199 """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        storageClass = "Packages"
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass=storageClass, universe=registry.dimensions
        )
        datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the dataset type already exists and
            # is consistent, and it raises if it is inconsistent; component
            # dataset types must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)


def makeSimplePipeline(nQuanta: int, instrument: str | None = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by `makeSimpleQGraph()` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    `makeSimpleQGraph()` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to add to the pipeline, or an
        empty string or `None` (the default) if no instrument should be
        added.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
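
    Examples
    --------
    A sketch of building a three-task pipeline whose tasks are chained via
    the connection name templates (not exercised as a doctest)::

        pipeline = makeSimplePipeline(nQuanta=3)
        # task0 reads add_dataset0 and writes add_dataset1, task1 reads
        # add_dataset1 and writes add_dataset2, and so on.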
246 """
    pipeline = Pipeline("test pipeline")
    # Make a bunch of tasks that execute in well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline


def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.
    config : `~lsst.daf.butler.Config` or `str`, optional
        Configuration to use for the new Butler; if `None`, the default
        configuration is used. If ``inMemory`` is `False`, the configuration
        is updated to use a SQLite registry and a file datastore in
        ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
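
    Examples
    --------
    A sketch (``root`` is assumed to be a writable local directory, e.g.
    one created with `tempfile.mkdtemp`)::

        butler = makeSimpleButler(root, run="test", inMemory=True)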
281 """
    root_path = ResourcePath(root, forceDirectory=True)
    if not root_path.isLocal:
        raise ValueError(f"Only works with local root, not {root_path}")
    butler_config = Config()
    if config:
        butler_config.update(Config(config))
    if not inMemory:
        butler_config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
        butler_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(str(root_path), {}, config=butler_config)
    butler = Butler.from_config(butler=repo, run=run)
    return butler


def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: dict[str | None, list[str]] | None = None
) -> None:
    """Populate data butler with data needed for test.

    Initializes the data butler with several items:

    - registers the dataset types defined by the pipeline;
    - creates dimension records for (instrument, detector);
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0" is
      added.

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline has no instrument definition. The type
    of each dataset is guessed from the dataset type name (this assumes the
    pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str` or `None`, `list` [ `str` ] ], optional
        Dictionary whose keys are collection names (`None` stands for the
        ``butler.run`` collection) and values are lists of dataset type
        names. By default a single dataset of type "add_dataset0" is added
        to the ``butler.run`` collection.
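
    Examples
    --------
    A sketch of preparing a butler for graph building (``root`` is assumed
    to be a writable local directory)::

        pipeline = makeSimplePipeline(nQuanta=2)
        butler = makeSimpleButler(root, run="test")
        populateButler(pipeline, butler)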
323 """
    # Add dataset types to registry.
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = cast(Instrument, instrument_class).getName()
        instrumentClass = get_full_type_name(instrument_class)
    else:
        instrumentName = "INSTR"
        instrumentClass = None

    # Add all needed dimensions to registry.
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = {taskDef.label: taskDef for taskDef in taskDefs}
    # Add inputs to butler.
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from a matching task name or make a new
                    # one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)


def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Pipeline | None = None,
    butler: Butler | None = None,
    root: str | None = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: str | None = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: dict[str | None, list[str]] | None = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Mapping[str, Any] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
) -> tuple[Butler, QuantumGraph]:
389 """Make simple `QuantumGraph` for tests.
391 Makes simple one-task pipeline with AddTask, sets up in-memory registry
392 and butler, fills them with minimal data, and generates QuantumGraph with
393 all of that.
395 Parameters
396 ----------
397 nQuanta : `int`
398 Number of quanta in a graph, only used if ``pipeline`` is None.
399 pipeline : `~lsst.pipe.base.Pipeline`
400 If `None` then a pipeline is made with `AddTask` and default
401 `AddTaskConfig`.
402 butler : `~lsst.daf.butler.Butler`, optional
403 Data butler instance, if None then new data butler is created by
404 calling `makeSimpleButler`.
405 root : `str`
406 Path or URI to the root location of the new repository. Only used if
407 ``butler`` is None.
408 callPopulateButler : `bool`, optional
409 If True insert datasets into the butler prior to building a graph.
410 If False butler argument must not be None, and must be pre-populated.
411 Defaults to True.
412 run : `str`, optional
413 Name of the RUN collection to add to butler, only used if ``butler``
414 is None.
415 instrument : `str` or `None`, optional
416 The importable name of an instrument to be added to the pipeline or
417 if no instrument should be added then an empty string or `None`, by
418 default `None`. Only used if ``pipeline`` is `None`.
419 skipExistingIn : `~typing.Any`
420 Expressions representing the collections to search for existing
421 output datasets that should be skipped. See
422 :ref:`daf_butler_ordered_collection_searches`.
423 inMemory : `bool`, optional
424 If true make in-memory repository, only used if ``butler`` is `None`.
425 userQuery : `str`, optional
426 The user query to pass to ``makeGraph``, by default an empty string.
427 datasetTypes : `dict` [ `str`, `list` ], optional
428 Dictionary whose keys are collection names and values are lists of
429 dataset type names. By default a single dataset of type "add_dataset0"
430 is added to a ``butler.run`` collection.
431 datasetQueryConstraint : `DatasetQueryConstraintVariant`
432 The query constraint variant that should be used to constrain the
433 query based on dataset existence, defaults to
434 `DatasetQueryConstraintVariant.ALL`.
435 makeDatastoreRecords : `bool`, optional
436 If `True` then add datastore records to generated quanta.
437 bind : `~collections.abc.Mapping`, optional
438 Mapping containing literal values that should be injected into the
439 ``userQuery`` expression, keyed by the identifiers they replace.
440 metadata : `~collections.abc.Mapping`, optional
441 Optional graph metadata.
443 Returns
444 -------
445 butler : `~lsst.daf.butler.Butler`
446 Butler instance.
447 qgraph : `~lsst.pipe.base.QuantumGraph`
448 Quantum graph instance.
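
    Examples
    --------
    A sketch of typical usage in a unit test (``root`` is assumed to be a
    writable local temporary directory)::

        butler, qgraph = makeSimpleQGraph(nQuanta=3, root=root)
        assert len(qgraph) == 3  # one quantum per task for one detector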
449 """
    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph.
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = GraphBuilder(
        registry=butler.registry,
        skipExistingIn=skipExistingIn,
        datastore=butler._datastore if makeDatastoreRecords else None,
    )
    if not run:
        assert butler.run is not None, "Butler must have run defined"
        run = butler.run
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r, bind=%s",
        butler.collections,
        run,
        userQuery,
        bind,
    )
    if not metadata:
        metadata = {}
    metadata["output_run"] = run

    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
        bind=bind,
        metadata=metadata,
    )

    return butler, qgraph