Coverage for python / lsst / pipe / base / tests / simpleQGraph.py: 23%
186 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-17 08:59 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Bunch of common classes and methods for use in unit tests."""
30from __future__ import annotations
32__all__ = ["AddTask", "AddTaskConfig", "AddTaskFactoryMock"]
34import logging
35from collections.abc import Iterable, Mapping, MutableMapping
36from typing import TYPE_CHECKING, Any, cast
38import numpy
40import lsst.daf.butler.tests as butlerTests
41import lsst.pex.config as pexConfig
42from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler
43from lsst.daf.butler.logging import ButlerLogRecords
44from lsst.daf.butler.registry import RegistryDefaults
45from lsst.resources import ResourcePath
46from lsst.utils import doImportType
47from lsst.utils.introspection import get_full_type_name
49from .. import connectionTypes as cT
50from .._instrument import Instrument
51from ..all_dimensions_quantum_graph_builder import AllDimensionsQuantumGraphBuilder
52from ..all_dimensions_quantum_graph_builder import DatasetQueryConstraintVariant as DSQVariant
53from ..automatic_connection_constants import (
54 CONFIG_INIT_OUTPUT_CONNECTION_NAME,
55 LOG_OUTPUT_CONNECTION_NAME,
56 METADATA_OUTPUT_CONNECTION_NAME,
57 PACKAGES_INIT_OUTPUT_NAME,
58 PACKAGES_INIT_OUTPUT_STORAGE_CLASS,
59)
60from ..config import PipelineTaskConfig
61from ..connections import PipelineTaskConnections
62from ..graph import QuantumGraph
63from ..pipeline import Pipeline
64from ..pipeline_graph import PipelineGraph, TaskNode
65from ..pipelineTask import PipelineTask
66from ..struct import Struct
67from ..task import _TASK_FULL_METADATA_TYPE
68from ..taskFactory import TaskFactory
70if TYPE_CHECKING:
71 from lsst.daf.butler import Registry
73_LOG = logging.getLogger(__name__)
class SimpleInstrument(Instrument):
    """Minimal `Instrument` stub suitable for unit tests.

    Parameters
    ----------
    *args : `~typing.Any`
        Ignored positional arguments.
    **kwargs : `~typing.Any`
        Ignored keyword arguments.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        # All constructor arguments are accepted and discarded.
        pass

    @staticmethod
    def getName() -> str:
        # Fixed name used for this fake instrument.
        return "INSTRU"

    def register(self, registry: Registry, *, update: bool = False) -> None:
        # Nothing needs to be written to the registry for this stub.
        pass

    def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
        # Any formatter type satisfies the interface here; use the base class.
        return Formatter
class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask, has one input and two outputs,
    plus one init output.

    The ``in_tmpl``/``out_tmpl`` name templates are overridden per task by
    `makeSimplePipeline` so that one task's output dataset name matches the
    next task's input, chaining the tasks via data dependencies.
    """

    # Per-detector input; suffix filled in from the "in_tmpl" template.
    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    # Primary per-detector output; suffix from the "out_tmpl" template.
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    # Secondary output, written alongside ``output``.
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    # Dimensionless init-output (see AddTask.initout).
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )
class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    # Value added to the input by AddTask.run (added once for ``output``,
    # twice for ``output2``).
    addend = pexConfig.Field[int](doc="amount to add", default=3)
class AddTask(PipelineTask):
    """Trivial arithmetic PipelineTask used by unit tests.

    Adds a configurable constant to its input and reports progress back to
    an optional `AddTaskFactoryMock` for test bookkeeping.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    # InitOutputs for this task.
    initout = numpy.array([999])

    # Factory that makes instances; set by AddTaskFactoryMock.makeTask so
    # the task can update the factory's execution counters.
    taskFactory: AddTaskFactoryMock | None = None

    def run(self, input: int) -> Struct:
        """Add the configured addend to ``input`` and return both sums."""
        factory = self.taskFactory
        if factory:
            # Test bookkeeping: optionally fail on the configured call count.
            if factory.stopAt == factory.countExec:
                raise RuntimeError("pretend something bad happened")
            factory.countExec += 1

        self.config = cast(AddTaskConfig, self.config)
        addend = self.config.addend
        self.metadata.add("add", addend)
        output = input + addend
        output2 = output + addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)
class AddTaskFactoryMock(TaskFactory):
    """Task factory that instantiates `AddTask`.

    It also carries the bookkeeping counters that `AddTask` updates so unit
    tests can observe execution progress.

    Parameters
    ----------
    stopAt : `int`, optional
        Number of times to call `run` before stopping.
    """

    def __init__(self, stopAt: int = -1):
        # Incremented by AddTask.run on every successful call.
        self.countExec = 0
        # AddTask raises once countExec reaches this value (-1 disables).
        self.stopAt = stopAt

    def makeTask(
        self,
        task_node: TaskNode,
        /,
        butler: LimitedButler,
        initInputRefs: Iterable[DatasetRef] | None,
    ) -> PipelineTask:
        """Construct the task and give it a back-reference to this factory."""
        instance = task_node.task_class(config=task_node.config, initInputs=None, name=task_node.label)
        instance.taskFactory = self  # type: ignore
        return instance
def registerDatasetTypes(registry: Registry, pipeline: Pipeline | PipelineGraph) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `.Pipeline`, or `.pipeline_graph.PipelineGraph`
        The pipeline whose dataset types should be registered, as a `.Pipeline`
        instance or `.PipelineGraph` instance.
    """
    if isinstance(pipeline, PipelineGraph):
        pipeline_graph = pipeline
    elif isinstance(pipeline, Pipeline):
        pipeline_graph = pipeline.to_graph()
    else:
        raise TypeError(f"Unexpected pipeline argument value: {pipeline}.")
    pipeline_graph.resolve(registry)
    dataset_types = [node.dataset_type for node in pipeline_graph.dataset_types.values()]
    # The "packages" init-output is not part of the pipeline graph, so it is
    # appended explicitly.
    dataset_types.append(
        DatasetType(
            PACKAGES_INIT_OUTPUT_NAME,
            {},
            storageClass=PACKAGES_INIT_OUTPUT_STORAGE_CLASS,
            universe=registry.dimensions,
        )
    )
    for dataset_type in dataset_types:
        _LOG.info("Registering %s with registry", dataset_type)
        registry.registerDatasetType(dataset_type)
def makeSimplePipeline(nQuanta: int, instrument: str | None = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by `makeSimpleQGraph()` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    `makeSimpleQGraph()` function by calling this first and passing the result
    to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline or
        if no instrument should be added then an empty string or `None`, by
        default `None`.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = Pipeline("test pipeline")
    # Chain the tasks via matching in/out name templates so that they execute
    # in a well defined order (via data dependencies).
    for index in range(nQuanta):
        label = f"task{index}"
        pipeline.addTask(AddTask, label)
        pipeline.addConfigOverride(label, "connections.in_tmpl", index)
        pipeline.addConfigOverride(label, "connections.out_tmpl", index + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create a new data butler instance for tests.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If true make in-memory repository.
    config : `~lsst.daf.butler.Config`, optional
        Configuration to use for new Butler; if `None` then default
        configuration is used. If ``inMemory`` is `True` then configuration
        is updated to use SQLite registry and file datastore in ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    location = ResourcePath(root, forceDirectory=True)
    if not location.isLocal:
        raise ValueError(f"Only works with local root not {location}")
    repo_config = Config()
    if config:
        repo_config.update(Config(config))
    if not inMemory:
        # Persist registry and datastore on disk under the repository root.
        repo_config["registry", "db"] = f"sqlite:///{location.ospath}/gen3.sqlite"
        repo_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    new_butler = butlerTests.makeTestRepo(str(location), {}, config=repo_config)
    new_butler.registry.defaults = RegistryDefaults(run=run)
    return new_butler
def populateButler(
    pipeline: Pipeline | PipelineGraph,
    butler: Butler,
    datasetTypes: dict[str | None, list[str]] | None = None,
    instrument: str | None = None,
) -> None:
    """Populate data butler with data needed for test.

    Initializes data butler with a bunch of items:
    - registers dataset types which are defined by pipeline
    - create dimensions data for (instrument, detector)
    - adds datasets based on ``datasetTypes`` dictionary, if dictionary is
    missing then a single dataset with type "add_dataset0" is added.

    All datasets added to butler have ``dataId={instrument=instrument,
    detector=0}`` where ``instrument`` is extracted from pipeline, "INSTR" is
    used if pipeline is missing instrument definition. Type of the dataset is
    guessed from dataset type name (assumes that pipeline is made of `AddTask`
    tasks).

    Parameters
    ----------
    pipeline : `.Pipeline` or `.pipeline_graph.PipelineGraph`
        Pipeline or pipeline graph instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    instrument : `str`, optional
        Fully-qualified name of the instrument class (as it appears in a
        pipeline). This is needed to propagate the instrument from the
        original pipeline if a pipeline graph is passed instead.
    """
    # Normalize the pipeline argument; a Pipeline may also carry the
    # instrument definition.
    if isinstance(pipeline, PipelineGraph):
        pipeline_graph = pipeline
    elif isinstance(pipeline, Pipeline):
        pipeline_graph = pipeline.to_graph()
        if instrument is None:
            instrument = pipeline.getInstrument()
    else:
        raise TypeError(f"Unexpected pipeline object: {pipeline!r}.")

    # Add dataset types to registry.
    registerDatasetTypes(butler.registry, pipeline_graph)

    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = cast(Instrument, instrument_class).getName()
        instrumentClass = get_full_type_name(instrument_class)
    else:
        instrumentName = "INSTR"
        instrumentClass = None

    # Add all needed dimension records to the registry.
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    # Add inputs to butler.
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == PACKAGES_INIT_OUTPUT_NAME:
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                # The packages dataset is dimensionless, so no dataId here.
                butler.put(data, dsType, run=run)
                continue
            if dsType.endswith(CONFIG_INIT_OUTPUT_CONNECTION_NAME):
                # Find a config from matching task name or make a new one.
                taskLabel, _, _ = dsType.rpartition("_")
                task_node = pipeline_graph.tasks.get(taskLabel)
                data = task_node.config if task_node is not None else AddTaskConfig()
            elif dsType.endswith(METADATA_OUTPUT_CONNECTION_NAME):
                data = _TASK_FULL_METADATA_TYPE()
            elif dsType.endswith(LOG_OUTPUT_CONNECTION_NAME):
                data = ButlerLogRecords.from_records([])
            else:
                data = numpy.array([0.0, 1.0, 2.0, 5.0])
            butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Pipeline | PipelineGraph | None = None,
    butler: Butler | None = None,
    root: str | None = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: str | None = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: dict[str | None, list[str]] | None = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Mapping[str, Any] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
) -> tuple[Butler, QuantumGraph]:
    """Make simple `QuantumGraph` for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory registry
    and butler, fills them with minimal data, and generates QuantumGraph with
    all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is None.
    pipeline : `.Pipeline` or `.pipeline_graph.PipelineGraph`
        Pipeline or pipeline graph to build the graph from. If `None`, a
        pipeline is made with `AddTask` and default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, if None then new data butler is created by
        calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is None.
    callPopulateButler : `bool`, optional
        If True insert datasets into the butler prior to building a graph.
        If False butler argument must not be None, and must be pre-populated.
        Defaults to True.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if ``butler``
        is None.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline or
        if no instrument should be added then an empty string or `None`, by
        default `None`. Only used if ``pipeline`` is `None`.
    skipExistingIn : `~typing.Any`
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If true make in-memory repository, only used if ``butler`` is `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.
    bind : `~collections.abc.Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.
    metadata : `~collections.abc.Mapping`, optional
        Optional graph metadata.  If a non-empty mapping is passed, it is
        updated in place with an ``output_run`` key.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.

    Raises
    ------
    ValueError
        Raised if ``butler`` is `None` and either ``root`` is `None` or
        ``callPopulateButler`` is `False`.
    TypeError
        Raised if ``pipeline`` is not a `.Pipeline`, `.PipelineGraph`,
        or `None`.
    """
    # Normalize ``pipeline`` to a PipelineGraph; a Pipeline may also supply
    # the instrument when the caller did not.
    match pipeline:
        case PipelineGraph() as pipeline_graph:
            pass
        case Pipeline():
            pipeline_graph = pipeline.to_graph()
            if instrument is None:
                instrument = pipeline.getInstrument()
        case None:
            pipeline_graph = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument).to_graph()
        case _:
            raise TypeError(f"Unexpected pipeline object: {pipeline!r}.")

    # Track whether we own the butler so the error path below only closes
    # butlers created by this function, never one supplied by the caller.
    butler_created = False
    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("populateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)
        butler_created = True

    try:
        if callPopulateButler:
            populateButler(pipeline_graph, butler, datasetTypes=datasetTypes, instrument=instrument)

        # Make the graph
        _LOG.debug(
            "Instantiating QuantumGraphBuilder, "
            "skip_existing_in=%s, input_collections=%r, output_run=%r, where=%r, bind=%s.",
            skipExistingIn,
            butler.collections.defaults,
            run,
            userQuery,
            bind,
        )
        # Fall back to the butler's default run when none was given.
        if not run:
            assert butler.run is not None, "Butler must have run defined"
            run = butler.run
        builder = AllDimensionsQuantumGraphBuilder(
            pipeline_graph,
            butler,
            skip_existing_in=skipExistingIn if skipExistingIn is not None else [],
            input_collections=butler.collections.defaults,
            output_run=run,
            where=userQuery,
            bind=bind,
            dataset_query_constraint=datasetQueryConstraint,
        )
        _LOG.debug("Calling QuantumGraphBuilder.build.")
        # NOTE: a caller-provided non-empty ``metadata`` mapping is mutated
        # in place here.
        if not metadata:
            metadata = {}
        metadata["output_run"] = run

        qgraph = builder.build(metadata=metadata, attach_datastore_records=makeDatastoreRecords)
    except Exception:
        if butler_created:
            butler.close()
        raise

    return butler, qgraph