Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 23%
188 statements
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Bunch of common classes and methods for use in unit tests.
29"""
from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import logging
import warnings
from collections.abc import Iterable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, cast

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler
from lsst.daf.butler.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType
from lsst.utils.introspection import find_outside_stacklevel, get_full_type_name

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..all_dimensions_quantum_graph_builder import AllDimensionsQuantumGraphBuilder
from ..all_dimensions_quantum_graph_builder import DatasetQueryConstraintVariant as DSQVariant
from ..automatic_connection_constants import PACKAGES_INIT_OUTPUT_NAME, PACKAGES_INIT_OUTPUT_STORAGE_CLASS
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..pipeline import Pipeline, TaskDef
from ..pipeline_graph import PipelineGraph, TaskNode
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

_LOG = logging.getLogger(__name__)


class SimpleInstrument(Instrument):
    """Simple instrument class suitable for testing.

    Parameters
    ----------
    *args : `~typing.Any`
        Ignored positional arguments.
    **kwargs : `~typing.Any`
        Unused keyword arguments.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask; it has one input, two outputs, and one init
    output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )
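
# How the templates above expand (illustrative summary): the ``{in_tmpl}`` and
# ``{out_tmpl}`` placeholders are filled from ``defaultTemplates`` or from
# per-task config overrides, so with the class defaults the dataset type names
# become
#
#     input   -> "add_dataset_in"
#     output  -> "add_dataset_out"
#     output2 -> "add2_dataset_out"
#
# `makeSimplePipeline` below overrides the templates per task so that "task0"
# writes "add_dataset1", which "task1" reads, and so on.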


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing; it has some extras useful for
    specific unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: AddTaskFactoryMock | None = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:
        if self.taskFactory:
            # Do some bookkeeping.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.config = cast(AddTaskConfig, self.config)
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)
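
# A minimal sketch of the arithmetic in `AddTask.run`, assuming the task can
# be constructed standalone the way `AddTaskFactoryMock.makeTask` does below;
# with the default ``addend = 3``:
#
# >>> task = AddTask(config=AddTaskConfig(), initInputs=None, name="add")
# >>> result = task.run(2)
# >>> (result.output, result.output2)
# (5, 8)
#
# ``output`` is ``input + addend`` and ``output2`` adds the addend once more.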


class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.

    Parameters
    ----------
    stopAt : `int`, optional
        Execution count at which `AddTask.run` raises an exception; the
        default of -1 means it never raises.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # Incremented by AddTask on each successful run.
        self.stopAt = stopAt  # AddTask raises an exception at this call to run().

    def makeTask(
        self,
        task_node: TaskDef | TaskNode,
        /,
        butler: LimitedButler,
        initInputRefs: Iterable[DatasetRef] | None,
    ) -> PipelineTask:
        if isinstance(task_node, TaskDef):
            # TODO: remove support on DM-40443.
            warnings.warn(
                "Passing TaskDef to TaskFactory is deprecated and will not be supported after v27.",
                FutureWarning,
                find_outside_stacklevel("lsst.pipe.base"),
            )
            task_class = task_node.taskClass
            assert task_class is not None
        else:
            task_class = task_node.task_class
        task = task_class(config=task_node.config, initInputs=None, name=task_node.label)
        task.taskFactory = self  # type: ignore
        return task
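
# Illustrative sketch of the bookkeeping (with a hypothetical ``stopAt``):
#
# >>> factory = AddTaskFactoryMock(stopAt=2)
# >>> factory.countExec
# 0
#
# ``countExec`` is incremented by each successful `AddTask.run`, so here the
# third quantum to execute (when ``countExec == 2``) raises `RuntimeError`.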


def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef] | PipelineGraph) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `.Pipeline`, `~collections.abc.Iterable` of `.TaskDef`, or \
            `.pipeline_graph.PipelineGraph`
        The pipeline whose dataset types should be registered, as a `.Pipeline`
        instance, `.PipelineGraph` instance, or iterable of `.TaskDef`
        instances. Support for `.TaskDef` is deprecated and will be removed
        after v27.
    """
    match pipeline:
        case PipelineGraph() as pipeline_graph:
            pass
        case Pipeline():
            pipeline_graph = pipeline.to_graph()
        case _:
            warnings.warn(
                "Passing TaskDefs is deprecated and will not be supported after v27.",
                category=FutureWarning,
                stacklevel=find_outside_stacklevel("lsst.pipe.base"),
            )
            pipeline_graph = PipelineGraph()
            for task_def in pipeline:
                pipeline_graph.add_task(
                    task_def.label, task_def.taskClass, task_def.config, connections=task_def.connections
                )
    pipeline_graph.resolve(registry)
    dataset_types = [node.dataset_type for node in pipeline_graph.dataset_types.values()]
    dataset_types.append(
        DatasetType(
            PACKAGES_INIT_OUTPUT_NAME,
            {},
            storageClass=PACKAGES_INIT_OUTPUT_STORAGE_CLASS,
            universe=registry.dimensions,
        )
    )
    for dataset_type in dataset_types:
        _LOG.info("Registering %s with registry", dataset_type)
        registry.registerDatasetType(dataset_type)
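
# A minimal usage sketch (the repository path is hypothetical):
#
# >>> butler = makeSimpleButler("/tmp/test-repo", run="test")
# >>> registerDatasetTypes(butler.registry, makeSimplePipeline(2))
#
# This registers every dataset type produced or consumed by the two tasks,
# plus the "packages" init-output dataset type appended above.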


def makeSimplePipeline(nQuanta: int, instrument: str | None = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by `makeSimpleQGraph()` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    `makeSimpleQGraph()` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline, or
        an empty string or `None` (the default) if no instrument should be
        added.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = Pipeline("test pipeline")
    # Make a bunch of tasks that execute in well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
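
# A minimal usage sketch:
#
# >>> pipeline = makeSimplePipeline(3)
#
# This yields tasks "task0".."task2" chained by the connection templates:
# "task0" reads "add_dataset0" and writes "add_dataset1", which "task1"
# reads, and so on through "add_dataset3".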


def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.
    config : `~lsst.daf.butler.Config`, optional
        Configuration to use for the new Butler; if `None`, the default
        configuration is used. If ``inMemory`` is `False` then the
        configuration is updated to use a SQLite registry and a file
        datastore in ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root_path = ResourcePath(root, forceDirectory=True)
    if not root_path.isLocal:
        raise ValueError(f"Only works with local root, not {root_path}")
    butler_config = Config()
    if config:
        butler_config.update(Config(config))
    if not inMemory:
        butler_config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
        butler_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(str(root_path), {}, config=butler_config)
    butler = Butler.from_config(butler=repo, run=run)
    return butler
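
# A minimal usage sketch (the repository path is hypothetical):
#
# >>> butler = makeSimpleButler("/tmp/test-repo", run="test")
#
# With the default ``inMemory=True`` the registry lives in memory; pass
# ``inMemory=False`` to persist a SQLite registry and a file datastore under
# ``root``.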


def populateButler(
    pipeline: Pipeline | PipelineGraph,
    butler: Butler,
    datasetTypes: dict[str | None, list[str]] | None = None,
    instrument: str | None = None,
) -> None:
    """Populate data butler with data needed for test.

    Initializes data butler with a bunch of items:

    - registers dataset types which are defined by pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0" is
      added

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline has no instrument definition. The type of
    each dataset is guessed from the dataset type name (this assumes that the
    pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `.Pipeline` or `.pipeline_graph.PipelineGraph`
        Pipeline or pipeline graph instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    instrument : `str`, optional
        Fully-qualified name of the instrument class (as it appears in a
        pipeline). This is needed to propagate the instrument from the
        original pipeline if a pipeline graph is passed instead.
    """
    # Add dataset types to registry.
    match pipeline:
        case PipelineGraph() as pipeline_graph:
            pass
        case Pipeline():
            pipeline_graph = pipeline.to_graph()
            if instrument is None:
                instrument = pipeline.getInstrument()
        case _:
            raise TypeError(f"Unexpected pipeline object: {pipeline!r}.")

    registerDatasetTypes(butler.registry, pipeline_graph)

    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = cast(Instrument, instrument_class).getName()
        instrumentClass = get_full_type_name(instrument_class)
    else:
        instrumentName = "INSTR"
        instrumentClass = None

    # Add all needed dimensions to registry.
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    # Add inputs to butler.
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task name or make a
                    # new one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    task_node = pipeline_graph.tasks.get(taskLabel)
                    if task_node is not None:
                        data = task_node.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
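
# A minimal usage sketch (the repository path is hypothetical):
#
# >>> pipeline = makeSimplePipeline(3)
# >>> butler = makeSimpleButler("/tmp/test-repo", run="test")
# >>> populateButler(pipeline, butler)
#
# With no ``datasetTypes`` argument this puts a single "add_dataset0" array
# with dataId ``{instrument: "INSTR", detector: 0}`` into ``butler.run``,
# which is the input the pipeline's first task needs.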


def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Pipeline | PipelineGraph | None = None,
    butler: Butler | None = None,
    root: str | None = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: str | None = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: dict[str | None, list[str]] | None = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Mapping[str, Any] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
) -> tuple[Butler, QuantumGraph]:
    """Make a simple `QuantumGraph` for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a `QuantumGraph`
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph; only used if ``pipeline`` is `None`.
    pipeline : `.Pipeline` or `.pipeline_graph.PipelineGraph`, optional
        Pipeline or pipeline graph to build the graph from. If `None`, a
        pipeline is made with `AddTask` and default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None`, a new data butler is created by
        calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a graph.
        If `False`, the ``butler`` argument must not be `None` and must be
        pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler; only used if ``butler``
        is `None`.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline, or
        an empty string or `None` (the default) if no instrument should be
        added. Only used if ``pipeline`` is `None`.
    skipExistingIn : `~typing.Any`
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.
    bind : `~collections.abc.Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.
    metadata : `~collections.abc.MutableMapping`, optional
        Optional graph metadata.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    match pipeline:
        case PipelineGraph() as pipeline_graph:
            pass
        case Pipeline():
            pipeline_graph = pipeline.to_graph()
            if instrument is None:
                instrument = pipeline.getInstrument()
        case None:
            pipeline_graph = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument).to_graph()
        case _:
            raise TypeError(f"Unexpected pipeline object: {pipeline!r}.")

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline_graph, butler, datasetTypes=datasetTypes, instrument=instrument)

    # Make the graph.
    _LOG.debug(
        "Instantiating QuantumGraphBuilder, "
        "skip_existing_in=%s, input_collections=%r, output_run=%r, where=%r, bind=%s.",
        skipExistingIn,
        butler.collections,
        run,
        userQuery,
        bind,
    )
    if not run:
        assert butler.run is not None, "Butler must have run defined"
        run = butler.run
    builder = AllDimensionsQuantumGraphBuilder(
        pipeline_graph,
        butler,
        skip_existing_in=skipExistingIn if skipExistingIn is not None else [],
        input_collections=butler.collections if butler.collections is not None else [run],
        output_run=run,
        where=userQuery,
        bind=bind,
        dataset_query_constraint=datasetQueryConstraint,
    )
    _LOG.debug("Calling QuantumGraphBuilder.build.")
    if not metadata:
        metadata = {}
    metadata["output_run"] = run

    qgraph = builder.build(metadata=metadata, attach_datastore_records=makeDatastoreRecords)

    return butler, qgraph
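
# A minimal end-to-end sketch (the repository path is hypothetical):
#
# >>> butler, qgraph = makeSimpleQGraph(nQuanta=3, root="/tmp/test-repo")
#
# This builds a three-task AddTask pipeline, creates and populates a butler
# repository under the given root, and returns the butler together with the
# generated `QuantumGraph` (one quantum per task).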