Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 27% (161 statements), coverage.py v7.3.1, created at 2023-09-28 09:35 +0000
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Bunch of common classes and methods for use in unit tests.
29"""
30from __future__ import annotations
32__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from collections.abc import Iterable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, cast

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

_LOG = logging.getLogger(__name__)


class SimpleInstrument(Instrument):
    """Simple instrument class suitable for testing."""

    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input, two outputs, and one init
    output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )
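
# For reference, how the connection name templates expand (derived from the
# defaults above): with in_tmpl="_in" and out_tmpl="_out" the dataset type
# names resolve to:
#
#     input   -> "add_dataset_in"
#     output  -> "add_dataset_out"
#     output2 -> "add2_dataset_out"
#     initout -> "add_init_output_out"
#
# makeSimplePipeline() below overrides the templates with integers, so task N
# reads "add_dataset<N>" and writes "add_dataset<N+1>", chaining the tasks
# through their data dependencies.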


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing; it has some extras useful for
    specific unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: AddTaskFactoryMock | None = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:
        if self.taskFactory:
            # Do some bookkeeping.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.config = cast(AddTaskConfig, self.config)
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)
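
# A minimal usage sketch (illustrative, not part of the original module):
# with the default AddTaskConfig (addend=3), run(2) returns output=5 and
# output2=8.
#
#     >>> task = AddTask(config=AddTaskConfig())
#     >>> result = task.run(2)
#     >>> (result.output, result.output2)
#     (5, 8)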


class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises an exception at this call to run()

    def makeTask(
        self, taskDef: TaskDef, butler: LimitedButler, initInputRefs: Iterable[DatasetRef] | None
    ) -> PipelineTask:
        taskClass = taskDef.taskClass
        assert taskClass is not None
        task = taskClass(config=taskDef.config, initInputs=None, name=taskDef.label)
        task.taskFactory = self  # type: ignore
        return task
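
# Sketch of the bookkeeping contract (derived from AddTask.run above): each
# call to run() first compares stopAt with countExec and then increments
# countExec, so stopAt=N makes the (N+1)-th execution raise.
#
#     >>> factory = AddTaskFactoryMock(stopAt=0)
#     >>> task = AddTask(config=AddTaskConfig())
#     >>> task.taskFactory = factory
#     >>> task.run(1)
#     Traceback (most recent call last):
#         ...
#     RuntimeError: pretend something bad happened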


def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `Pipeline.toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        storageClass = "Packages"
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass=storageClass, universe=registry.dimensions
        )
        datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # Registering a dataset type is a no-op if it already exists and
            # is consistent, and it raises if it is inconsistent; components,
            # however, must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
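
# Typical usage sketch (populateButler() below is this module's real call
# site); ``pipeline`` and ``butler`` are assumed to exist:
#
#     >>> taskDefs = list(pipeline.toExpandedPipeline())
#     >>> registerDatasetTypes(butler.registry, taskDefs)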


def makeSimplePipeline(nQuanta: int, instrument: str | None = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by `makeSimpleQGraph()` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    `makeSimpleQGraph()` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = Pipeline("test pipeline")
    # Make a chain of tasks that execute in a well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
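
# Example (a sketch): makeSimplePipeline(3) yields tasks task0..task2 whose
# template overrides chain the dataset types add_dataset0 -> add_dataset1 ->
# add_dataset2 -> add_dataset3.
#
#     >>> pipeline = makeSimplePipeline(3)
#     >>> [taskDef.label for taskDef in pipeline.toExpandedPipeline()]
#     ['task0', 'task1', 'task2']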


def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create a new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.
    config : `~lsst.daf.butler.Config`, optional
        Configuration to use for the new Butler; if `None` then the default
        configuration is used. If ``inMemory`` is `False` then the
        configuration is updated to use a SQLite registry and a file
        datastore in ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root_path = ResourcePath(root, forceDirectory=True)
    if not root_path.isLocal:
        raise ValueError(f"Only works with local root, not {root_path}")
    butler_config = Config()
    if config:
        butler_config.update(Config(config))
    if not inMemory:
        butler_config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
        butler_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(str(root_path), {}, config=butler_config)
    butler = Butler(butler=repo, run=run)
    return butler
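
# Example (a sketch; assumes a writable temporary directory):
#
#     >>> import tempfile
#     >>> butler = makeSimpleButler(tempfile.mkdtemp(), run="test")
#     >>> butler.run
#     'test'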


def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: dict[str | None, list[str]] | None = None
) -> None:
    """Populate a data butler with the data needed for a test.

    Initializes the data butler with a number of items:
    - registers the dataset types defined by the pipeline
    - creates dimension records for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0" is
      added

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline
    ("INSTR" is used if the pipeline is missing an instrument definition).
    The in-memory type of each dataset is guessed from its dataset type name
    (this assumes that the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    """
    # Add dataset types to registry.
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = cast(Instrument, instrument_class).getName()
        instrumentClass = get_full_type_name(instrument_class)
    else:
        instrumentName = "INSTR"
        instrumentClass = None

    # Add all needed dimensions to registry.
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = {taskDef.label: taskDef for taskDef in taskDefs}
    # Add inputs to butler.
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task name, or make a
                    # new one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
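
# Example (a sketch building on the helpers above; ``butler`` comes from
# makeSimpleButler): after this call the default run contains a single
# "add_dataset0" NumPy array for detector 0.
#
#     >>> pipeline = makeSimplePipeline(3)
#     >>> populateButler(pipeline, butler)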


def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Pipeline | None = None,
    butler: Butler | None = None,
    root: str | None = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: str | None = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: dict[str | None, list[str]] | None = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Mapping[str, Any] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
) -> tuple[Butler, QuantumGraph]:
    """Make a simple `QuantumGraph` for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a `QuantumGraph`
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created by
        calling `makeSimpleButler`.
    root : `str`, optional
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a graph.
        If `False`, the ``butler`` argument must not be `None` and must be
        pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to the butler, only used if
        ``butler`` is `None`. Also used as the output run of the generated
        graph.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default). Only used if ``pipeline`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`, optional
        The query constraint variant that should be used to constrain the
        query based on dataset existence; defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.
    bind : `~collections.abc.Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.
    metadata : `~collections.abc.Mapping`, optional
        Optional graph metadata.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph.
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = GraphBuilder(
        registry=butler.registry,
        skipExistingIn=skipExistingIn,
        datastore=butler._datastore if makeDatastoreRecords else None,
    )
    if not run:
        assert butler.run is not None, "Butler must have run defined"
        run = butler.run
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r bind=%s",
        butler.collections,
        run,
        userQuery,
        bind,
    )
    if not metadata:
        metadata = {}
    metadata["output_run"] = run
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
        bind=bind,
        metadata=metadata,
    )

    return butler, qgraph
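
# End-to-end usage sketch (assumes a writable temporary directory; the
# defaults give an in-memory repository and a "test" output run):
#
#     >>> import tempfile
#     >>> with tempfile.TemporaryDirectory() as root:
#     ...     butler, qgraph = makeSimpleQGraph(nQuanta=3, root=root)
#
# The resulting graph has one quantum per task (three here), and the metadata
# passed to makeGraph records the output run ("test" by default).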