# Coverage for python/lsst/pipe/base/quantum_graph/_common.py: 66% (195 statements)

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = (
    "FORMAT_VERSION",
    "BaseQuantumGraph",
    "BaseQuantumGraphReader",
    "BipartiteEdgeInfo",
    "DatasetInfo",
    "HeaderModel",
    "QuantumInfo",
)

import dataclasses
import datetime
import getpass
import os
import sys
import uuid
import zipfile
from abc import ABC, abstractmethod
from collections.abc import Iterator, Mapping
from contextlib import contextmanager
from typing import (
    TYPE_CHECKING,
    Any,
    Self,
    TypedDict,
)

import networkx
import networkx.algorithms.bipartite
import pydantic
import zstandard

from lsst.daf.butler import DataCoordinate, DataIdValue
from lsst.daf.butler._rubin import generate_uuidv7
from lsst.resources import ResourcePath, ResourcePathExpression

from ..pipeline_graph import DatasetTypeNode, Edge, PipelineGraph, TaskImportMode, TaskNode
from ..pipeline_graph.io import SerializedPipelineGraph
from ._multiblock import (
    DEFAULT_PAGE_SIZE,
    AddressReader,
    AddressWriter,
    Compressor,
    Decompressor,
)

if TYPE_CHECKING:
    from ..graph import QuantumGraph


# These aliases make it a lot easier to see how the various pydantic models
# are structured, but they're too verbose to be worth exporting to code
# outside the quantum_graph subpackage.
type TaskLabel = str
type DatasetTypeName = str
type ConnectionName = str
type DatasetIndex = int
type QuantumIndex = int
type DatastoreName = str
type DimensionElementName = str
type DataCoordinateValues = list[DataIdValue]
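
# For example (purely illustrative; this annotation is not part of the
# module), a nested mapping typed as
#
#     dict[TaskLabel, dict[QuantumIndex, DataCoordinateValues]]
#
# reads much more clearly than the structurally identical
# dict[str, dict[int, list[DataIdValue]]].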


FORMAT_VERSION: int = 1
"""
File format version number for new files.

This applies to both predicted and provenance QGs, since they usually change
in concert.

CHANGELOG:

- 0: Initial version.
- 1: Switched from internal integer IDs to UUIDs in all models.
"""


class IncompleteQuantumGraphError(RuntimeError):
    """Exception raised when an operation requires parts of a quantum graph
    that were not loaded (e.g. in a partial read).
    """


class HeaderModel(pydantic.BaseModel):
    """Data model for the header of a quantum graph file."""

    version: int = FORMAT_VERSION
    """File format / data model version number."""

    graph_type: str = ""
    """Type of quantum graph stored in this file."""

    inputs: list[str] = pydantic.Field(default_factory=list)
    """List of input collections used to build the quantum graph."""

    output: str | None = ""
    """Output CHAINED collection provided when building the quantum graph."""

    output_run: str = ""
    """Output RUN collection for all output datasets in this graph."""

    user: str = pydantic.Field(default_factory=getpass.getuser)
    """Username of the process that built this quantum graph."""

    timestamp: datetime.datetime = pydantic.Field(default_factory=datetime.datetime.now)
    """Timestamp for when this quantum graph was built.

    It is unspecified at exactly which point during quantum-graph generation
    this timestamp is recorded.
    """

    command: str = pydantic.Field(default_factory=lambda: " ".join(sys.argv))
    """Command-line invocation that created this graph."""

    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
    """Free-form metadata associated with this quantum graph at build time."""

    int_size: int = 8
    """Number of bytes in the integers used in this file's multi-block and
    address files.
    """

    n_quanta: int = 0
    """Total number of quanta in this graph.

    This does not include special "init" quanta, but it does include quanta
    that were not loaded in a partial read (except when reading from an old
    quantum graph file).
    """

    n_datasets: int = 0
    """Total number of distinct datasets in the full graph.

    This includes datasets whose related quanta were not loaded in a partial
    read (except when reading from an old quantum graph file).
    """

    n_task_quanta: dict[TaskLabel, int] = pydantic.Field(default_factory=dict)
    """Number of quanta for each task label.

    This does not include special "init" quanta, but it does include quanta
    that were not loaded in a partial read (except when reading from an old
    quantum graph file).
    """

    provenance_dataset_id: uuid.UUID = pydantic.Field(default_factory=generate_uuidv7)
    """The dataset ID for the provenance quantum graph when it is ingested
    into a butler repository.
    """

    @classmethod
    def from_old_quantum_graph(cls, old_quantum_graph: QuantumGraph) -> HeaderModel:
        """Extract a header from an old `QuantumGraph` instance.

        Parameters
        ----------
        old_quantum_graph : `QuantumGraph`
            Quantum graph to extract a header from.

        Returns
        -------
        header : `HeaderModel`
            Header for a new predicted quantum graph.
        """
        metadata = dict(old_quantum_graph.metadata)
        metadata.pop("packages", None)
        if (time_str := metadata.pop("time", None)) is not None:
            timestamp = datetime.datetime.fromisoformat(time_str)
        else:
            timestamp = datetime.datetime.now()
        return cls(
            inputs=list(metadata.pop("input", []) or []),  # Guard against explicit None and missing key.
            output=metadata.pop("output", None),
            output_run=metadata.pop("output_run", ""),
            user=metadata.pop("user", ""),
            command=metadata.pop("full_command", ""),
            timestamp=timestamp,
            metadata=metadata,
        )

    def to_old_metadata(self) -> dict[str, Any]:
        """Return a dictionary using the key conventions of old quantum
        graph files.
        """
        result = self.metadata.copy()
        result["input"] = self.inputs
        result["output"] = self.output
        result["output_run"] = self.output_run
        result["full_command"] = self.command
        result["user"] = self.user
        result["time"] = str(self.timestamp)
        return result
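
    # Round-trip sketch (assumes ``old_qg`` is a legacy `QuantumGraph`):
    #
    #     header = HeaderModel.from_old_quantum_graph(old_qg)
    #     legacy = header.to_old_metadata()
    #     legacy["full_command"] == header.command  # True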

    # Work around the fact that Sphinx chokes on Pydantic docstring formatting
    # when we inherit those docstrings in our public classes.
    if "sphinx" in sys.modules and not TYPE_CHECKING:

        def copy(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.copy`."""
            return super().copy(*args, **kwargs)

        def model_dump(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_dump`."""
            return super().model_dump(*args, **kwargs)

        def model_dump_json(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_dump_json`."""
            return super().model_dump_json(*args, **kwargs)

        def model_copy(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_copy`."""
            return super().model_copy(*args, **kwargs)

        @classmethod
        def model_construct(cls, *args: Any, **kwargs: Any) -> Any:  # type: ignore[misc, override]
            """See `pydantic.BaseModel.model_construct`."""
            return super().model_construct(*args, **kwargs)

        @classmethod
        def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_json_schema`."""
            return super().model_json_schema(*args, **kwargs)

        @classmethod
        def model_validate(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_validate`."""
            return super().model_validate(*args, **kwargs)

        @classmethod
        def model_validate_json(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_validate_json`."""
            return super().model_validate_json(*args, **kwargs)

        @classmethod
        def model_validate_strings(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_validate_strings`."""
            return super().model_validate_strings(*args, **kwargs)


class QuantumInfo(TypedDict):
    """A typed dictionary that annotates the attributes of the NetworkX graph
    node data for a quantum.

    Since NetworkX types are not generic over their node mapping type, this
    has to be used explicitly, e.g.::

        node_data: QuantumInfo = xgraph.nodes[quantum_id]

    where ``xgraph`` can be either `BaseQuantumGraph.quantum_only_xgraph`
    or `BaseQuantumGraph.bipartite_xgraph`.
    """

    data_id: DataCoordinate
    """Data ID of the quantum."""

    task_label: str
    """Label of the task for this quantum."""

    pipeline_node: TaskNode
    """Node in the pipeline graph for this quantum's task."""


class DatasetInfo(TypedDict):
    """A typed dictionary that annotates the attributes of the NetworkX graph
    node data for a dataset.

    Since NetworkX types are not generic over their node mapping type, this
    has to be used explicitly, e.g.::

        node_data: DatasetInfo = xgraph.nodes[dataset_id]

    where ``xgraph`` is from the `BaseQuantumGraph.bipartite_xgraph` property.
    """

    data_id: DataCoordinate
    """Data ID of the dataset."""

    dataset_type_name: DatasetTypeName
    """Name of the type of this dataset.

    This is always the general dataset type that matches the data repository
    storage class, which may differ from any particular task-adapted dataset
    type whose storage class has been overridden to match the task
    connections.  This means it is never a component.
    """

    run: str
    """Name of the `~lsst.daf.butler.CollectionType.RUN` collection that holds
    or will hold this dataset.
    """

    pipeline_node: DatasetTypeNode
    """Node in the pipeline graph for this dataset's type."""


class BipartiteEdgeInfo(TypedDict):
    """A typed dictionary that annotates the attributes of the NetworkX graph
    edge data in a bipartite graph.
    """

    is_read: bool
    """`True` if this is a dataset -> quantum edge; `False` if it is a
    quantum -> dataset edge.
    """

    pipeline_edges: list[Edge]
    """Corresponding edges in the pipeline graph.

    Note that there may be more than one pipeline edge since a quantum can
    consume a particular dataset via multiple connections.
    """


class BaseQuantumGraph(ABC):
    """An abstract base for quantum graphs.

    Parameters
    ----------
    header : `HeaderModel`
        Structured metadata for the graph.
    pipeline_graph : `.pipeline_graph.PipelineGraph`
        Graph of tasks and dataset types.  May contain a superset of the
        tasks and dataset types that actually have quanta and datasets in the
        quantum graph.
    """

    def __init__(self, header: HeaderModel, pipeline_graph: PipelineGraph):
        self.header = header
        self.pipeline_graph = pipeline_graph

    @property
    @abstractmethod
    def quanta_by_task(self) -> Mapping[str, Mapping[DataCoordinate, uuid.UUID]]:
        """A nested mapping of all quanta, keyed first by task name and then
        by data ID.

        Notes
        -----
        Partial loads may not fully populate this mapping, but it can always
        be accessed.

        The returned object may be an internal dictionary; as the type
        annotation indicates, it should not be modified in place.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def datasets_by_type(self) -> Mapping[str, Mapping[DataCoordinate, uuid.UUID]]:
        """A nested mapping of all datasets, keyed first by dataset type name
        and then by data ID.

        Notes
        -----
        Partial loads may not fully populate this mapping, but it can always
        be accessed.

        The returned object may be an internal dictionary; as the type
        annotation indicates, it should not be modified in place.
        """
        raise NotImplementedError()
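
    # Lookup sketch (the task label, dataset type name, and data IDs are
    # illustrative and must exist in the loaded portion of the graph):
    #
    #     quantum_id = qg.quanta_by_task["isr"][quantum_data_id]
    #     dataset_id = qg.datasets_by_type["raw"][dataset_data_id]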

    @property
    @abstractmethod
    def quantum_only_xgraph(self) -> networkx.DiGraph:
        """A directed acyclic graph with quanta as nodes and datasets elided.

        Notes
        -----
        Partial loads may not fully populate this graph, but it can always be
        accessed.

        Node state dictionaries are described by the `QuantumInfo` type
        (or a subtype thereof).

        The returned object is a read-only view of an internal one.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def bipartite_xgraph(self) -> networkx.DiGraph:
        """A directed acyclic graph with quantum and dataset nodes.

        Notes
        -----
        Partial loads may not fully populate this graph, but it can always be
        accessed.

        Node state dictionaries are described by the `QuantumInfo` and
        `DatasetInfo` types (or subtypes thereof).  Edges have state
        dictionaries described by `BipartiteEdgeInfo`.

        The returned object is a read-only view of an internal one.
        """
        raise NotImplementedError()


@dataclasses.dataclass
class BaseQuantumGraphWriter:
    """A helper class for writing quantum graphs."""

    zf: zipfile.ZipFile
    """The zip archive that represents the quantum graph on disk."""

    compressor: Compressor
    """A compressor for all compressed JSON blocks."""

    address_writer: AddressWriter
    """A helper object for writing addresses into the multi-block files."""

    int_size: int
    """Size (in bytes) used to write integers to binary files."""

    @classmethod
    @contextmanager
    def open(
        cls,
        uri: ResourcePathExpression,
        header: HeaderModel,
        pipeline_graph: PipelineGraph,
        *,
        address_filename: str,
        cdict_data: bytes | None = None,
        zstd_level: int = 10,
    ) -> Iterator[Self]:
        """Open a quantum graph file for writing.

        Parameters
        ----------
        uri : convertible to `lsst.resources.ResourcePath`
            URI of the file to write.
        header : `HeaderModel`
            Header to write to the file.
        pipeline_graph : `PipelineGraph`
            Pipeline graph to serialize and write to the file.
        address_filename : `str`
            Base filename for the address file, which is written when the
            context manager exits.
        cdict_data : `bytes`, optional
            Serialized zstandard compression dictionary, written to the file
            verbatim when provided.
        zstd_level : `int`, optional
            Compression level for zstandard compression.

        Returns
        -------
        writer : `contextlib.AbstractContextManager` [ `BaseQuantumGraphWriter` ]
            A context manager that returns the writer when entered.
        """
        uri = ResourcePath(uri, forceDirectory=False)
        address_writer = AddressWriter()
        if uri.isLocal:
            os.makedirs(uri.dirname().ospath, exist_ok=True)
        cdict = zstandard.ZstdCompressionDict(cdict_data) if cdict_data is not None else None
        compressor = zstandard.ZstdCompressor(level=zstd_level, dict_data=cdict)
        with uri.open(mode="wb") as stream:
            with zipfile.ZipFile(stream, mode="w", compression=zipfile.ZIP_STORED) as zf:
                self = cls(zf, compressor, address_writer, header.int_size)
                self.write_single_model("header", header)
                if cdict_data is not None:
                    zf.writestr("compression_dict", cdict_data)
                self.write_single_model("pipeline_graph", SerializedPipelineGraph.serialize(pipeline_graph))
                yield self
                address_writer.write_to_zip(zf, address_filename, int_size=self.int_size)
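
    # Usage sketch (``MyQuantumGraphWriter`` is a hypothetical subclass; the
    # file and block names are illustrative):
    #
    #     with MyQuantumGraphWriter.open(
    #         "graph.qg", header, pipeline_graph, address_filename="quanta"
    #     ) as writer:
    #         writer.write_single_model("extra_info", extra_info_model)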

    def write_single_model(self, name: str, model: pydantic.BaseModel) -> None:
        """Serialize a pydantic model and write it as a single compressed
        JSON block in a 'file' in the zip archive.

        Parameters
        ----------
        name : `str`
            Base name of the file.  An extension will be added.
        model : `pydantic.BaseModel`
            Pydantic model to convert to JSON.
        """
        json_data = model.model_dump_json().encode()
        self.write_single_block(name, json_data)

    def write_single_block(self, name: str, json_data: bytes) -> None:
        """Write a single compressed JSON block as a 'file' in the zip
        archive.

        Parameters
        ----------
        name : `str`
            Base name of the file.  An extension will be added.
        json_data : `bytes`
            Raw JSON to compress and write.
        """
        json_data = self.compressor.compress(json_data)
        self.zf.writestr(f"{name}.json.zst", json_data)


@dataclasses.dataclass
class BaseQuantumGraphReader:
    """A helper class for reading quantum graphs."""

    header: HeaderModel
    """Header metadata for the quantum graph."""

    pipeline_graph: PipelineGraph
    """Graph of tasks and dataset type names that appear in the quantum
    graph.
    """

    zf: zipfile.ZipFile
    """The zip archive that represents the quantum graph on disk."""

    decompressor: Decompressor
    """A decompressor for all compressed JSON blocks."""

    address_reader: AddressReader
    """A helper object for reading addresses into the multi-block files."""

    page_size: int
    """Approximate number of bytes to read at a time.

    Note that this does not set a page size for *all* reads, but it does
    affect the smallest, most numerous reads.
    """

    @classmethod
    @contextmanager
    def _open(
        cls,
        uri: ResourcePathExpression,
        *,
        address_filename: str,
        graph_type: str,
        n_addresses: int,
        page_size: int | None = None,
        import_mode: TaskImportMode = TaskImportMode.ASSUME_CONSISTENT_EDGES,
    ) -> Iterator[Self]:
        """Construct a reader from a URI.

        Parameters
        ----------
        uri : convertible to `lsst.resources.ResourcePath`
            URI to open.  Should have a ``.qg`` extension.
        address_filename : `str`
            Base filename for the address file.
        graph_type : `str`
            Value to expect for `HeaderModel.graph_type`.
        n_addresses : `int`
            Number of addresses to expect per row in the address file.
        page_size : `int`, optional
            Approximate number of bytes to read at once from address files.
            Note that this does not set a page size for *all* reads, but it
            does affect the smallest, most numerous reads.  When `None`, the
            ``LSST_QG_PAGE_SIZE`` environment variable is checked before
            falling back to a default of 5MB.
        import_mode : `..pipeline_graph.TaskImportMode`, optional
            How to handle importing the task classes referenced in the
            pipeline graph.

        Returns
        -------
        reader : `contextlib.AbstractContextManager` [ \
                `BaseQuantumGraphReader` ]
            A context manager that returns the reader when entered.
        """
        if page_size is None:
            page_size = int(os.environ.get("LSST_QG_PAGE_SIZE", DEFAULT_PAGE_SIZE))
        uri = ResourcePath(uri)
        cdict: zstandard.ZstdCompressionDict | None = None
        with uri.open(mode="rb") as zf_stream:
            with zipfile.ZipFile(zf_stream, "r") as zf:
                if (cdict_path := zipfile.Path(zf, "compression_dict")).exists():
                    cdict = zstandard.ZstdCompressionDict(cdict_path.read_bytes())
                decompressor = zstandard.ZstdDecompressor(cdict)
                header = cls._read_single_block_static("header", HeaderModel, zf, decompressor)
                if header.graph_type != graph_type:
                    raise TypeError(f"Header is for a {header.graph_type!r} graph, not a {graph_type!r} graph.")
                serialized_pipeline_graph = cls._read_single_block_static(
                    "pipeline_graph", SerializedPipelineGraph, zf, decompressor
                )
                pipeline_graph = serialized_pipeline_graph.deserialize(import_mode)
                with AddressReader.open_in_zip(
                    zf,
                    address_filename,
                    page_size=page_size,
                    int_size=header.int_size,
                    n_addresses=n_addresses,
                ) as address_reader:
                    yield cls(
                        header=header,
                        pipeline_graph=pipeline_graph,
                        zf=zf,
                        decompressor=decompressor,
                        address_reader=address_reader,
                        page_size=page_size,
                    )
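
    # The zip-archive layout this reader expects, as written by
    # BaseQuantumGraphWriter (entry names inferred from the code in this
    # module):
    #
    #     header.json.zst          compressed HeaderModel JSON
    #     pipeline_graph.json.zst  compressed SerializedPipelineGraph JSON
    #     compression_dict         optional raw zstandard dictionary
    #     <address_filename>       address data read via AddressReader.open_in_zip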

    @staticmethod
    def _read_single_block_static[T: pydantic.BaseModel](
        name: str, model_type: type[T], zf: zipfile.ZipFile, decompressor: Decompressor
    ) -> T:
        """Read a single compressed JSON block from a 'file' in a zip archive.

        Parameters
        ----------
        name : `str`
            Base name of the file.  An extension will be added.
        model_type : `type` [ `pydantic.BaseModel` ]
            Pydantic model to validate JSON with.
        zf : `zipfile.ZipFile`
            Zip archive to read the file from.
        decompressor : `Decompressor`
            Object with a `decompress` method that takes and returns `bytes`.

        Returns
        -------
        model : `pydantic.BaseModel`
            Validated model.
        """
        compressed_data = zf.read(f"{name}.json.zst")
        json_data = decompressor.decompress(compressed_data)
        return model_type.model_validate_json(json_data)

    def _read_single_block[T: pydantic.BaseModel](self, name: str, model_type: type[T]) -> T:
        """Read a single compressed JSON block from a 'file' in a zip archive.

        Parameters
        ----------
        name : `str`
            Base name of the file.  An extension will be added.
        model_type : `type` [ `pydantic.BaseModel` ]
            Pydantic model to validate JSON with.

        Returns
        -------
        model : `pydantic.BaseModel`
            Validated model.
        """
        return self._read_single_block_static(name, model_type, self.zf, self.decompressor)

    def _read_single_block_raw(self, name: str) -> bytes:
        """Read a single compressed block from a 'file' in a zip archive.

        Parameters
        ----------
        name : `str`
            Base name of the file.  An extension will be added.

        Returns
        -------
        data : `bytes`
            Decompressed bytes.
        """
        compressed_data = self.zf.read(f"{name}.json.zst")
        return self.decompressor.decompress(compressed_data)