Coverage for python / lsst / pipe / base / quantum_graph / _common.py: 66%

195 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-18 08:44 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

__all__ = (
    "FORMAT_VERSION",
    "BaseQuantumGraph",
    "BaseQuantumGraphReader",
    # Exported alongside the reader for symmetry: both halves of the file
    # format helpers are public APIs for subclasses in this subpackage.
    "BaseQuantumGraphWriter",
    "BipartiteEdgeInfo",
    "DatasetInfo",
    "HeaderModel",
    # Public exception type defined in this module; callers need to be able
    # to import it in order to catch it.
    "IncompleteQuantumGraphError",
    "QuantumInfo",
)

39import dataclasses 

40import datetime 

41import getpass 

42import os 

43import sys 

44import uuid 

45import zipfile 

46from abc import ABC, abstractmethod 

47from collections.abc import Iterator, Mapping 

48from contextlib import contextmanager 

49from typing import ( 

50 TYPE_CHECKING, 

51 Any, 

52 Self, 

53 TypedDict, 

54) 

55 

56import networkx 

57import networkx.algorithms.bipartite 

58import pydantic 

59import zstandard 

60 

61from lsst.daf.butler import DataCoordinate, DataIdValue 

62from lsst.daf.butler._rubin import generate_uuidv7 

63from lsst.resources import ResourcePath, ResourcePathExpression 

64 

65from ..pipeline_graph import DatasetTypeNode, Edge, PipelineGraph, TaskImportMode, TaskNode 

66from ..pipeline_graph.io import SerializedPipelineGraph 

67from ._multiblock import ( 

68 DEFAULT_PAGE_SIZE, 

69 AddressReader, 

70 AddressWriter, 

71 Compressor, 

72 Decompressor, 

73) 

74 

75if TYPE_CHECKING: 

76 from ..graph import QuantumGraph 

77 

78 

# These aliases make it a lot easier to see how the various pydantic models
# are structured, but they're too verbose to be worth exporting to code
# outside the quantum_graph subpackage.
type TaskLabel = str
type DatasetTypeName = str
type ConnectionName = str
type DatasetIndex = int
type QuantumIndex = int
type DatastoreName = str
type DimensionElementName = str
type DataCoordinateValues = list[DataIdValue]


FORMAT_VERSION: int = 1
"""
File format version number for new files.

This applies to both predicted and provenance QGs, since they usually change
in concert.

CHANGELOG:

- 0: Initial version.
- 1: Switched from internal integer IDs to UUIDs in all models.
"""

104 

105 

class IncompleteQuantumGraphError(RuntimeError):
    """Exception for operations that require parts of a quantum graph that
    are not available.

    NOTE(review): no raise sites are visible in this module; presumably this
    is raised when a partial load omitted required content — confirm at the
    call sites in subclasses.
    """

    pass

108 

109 

class HeaderModel(pydantic.BaseModel):
    """Data model for the header of a quantum graph file."""

    version: int = FORMAT_VERSION
    """File format / data model version number."""

    graph_type: str = ""
    """Type of quantum graph stored in this file."""

    inputs: list[str] = pydantic.Field(default_factory=list)
    """List of input collections used to build the quantum graph."""

    output: str | None = ""
    """Output CHAINED collection provided when building the quantum graph."""

    output_run: str = ""
    """Output RUN collection for all output datasets in this graph."""

    user: str = pydantic.Field(default_factory=getpass.getuser)
    """Username of the process that built this quantum graph."""

    timestamp: datetime.datetime = pydantic.Field(default_factory=datetime.datetime.now)
    """Timestamp for when this quantum graph was built.

    It is unspecified exactly which point during quantum-graph generation this
    timestamp is recorded.
    """

    command: str = pydantic.Field(default_factory=lambda: " ".join(sys.argv))
    """Command-line invocation that created this graph."""

    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
    """Free-form metadata associated with this quantum graph at build time."""

    int_size: int = 8
    """Number of bytes in the integers used in this file's multi-block and
    address files.
    """

    n_quanta: int = 0
    """Total number of quanta in this graph.

    This does not include special "init" quanta, but it does include quanta
    that were not loaded in a partial read (except when reading from an old
    quantum graph file).
    """

    n_datasets: int = 0
    """Total number of distinct datasets in the full graph. This includes
    datasets whose related quanta were not loaded in a partial read (except
    when reading from an old quantum graph file).
    """

    n_task_quanta: dict[TaskLabel, int] = pydantic.Field(default_factory=dict)
    """Number of quanta for each task label.

    This does not include special "init" quanta, but it does include quanta
    that were not loaded in a partial read (except when reading from an old
    quantum graph file).
    """

    provenance_dataset_id: uuid.UUID = pydantic.Field(default_factory=generate_uuidv7)
    """The dataset ID for provenance quantum graph when it is ingested into
    a butler repository.
    """

    @classmethod
    def from_old_quantum_graph(cls, old_quantum_graph: QuantumGraph) -> HeaderModel:
        """Extract a header from an old `QuantumGraph` instance.

        Parameters
        ----------
        old_quantum_graph : `QuantumGraph`
            Quantum graph to extract a header from.

        Returns
        -------
        header : `HeaderModel`
            Header populated from the old graph's metadata.
        """
        metadata = dict(old_quantum_graph.metadata)
        # Package versions are stored elsewhere in the new format.
        metadata.pop("packages", None)
        if (time_str := metadata.pop("time", None)) is not None:
            timestamp = datetime.datetime.fromisoformat(time_str)
        else:
            timestamp = datetime.datetime.now()
        return cls(
            inputs=list(metadata.pop("input", []) or []),  # Guard against explicit None and missing key.
            output=metadata.pop("output", None),
            output_run=metadata.pop("output_run", ""),
            user=metadata.pop("user", ""),
            command=metadata.pop("full_command", ""),
            timestamp=timestamp,
            metadata=metadata,
        )

    def to_old_metadata(self) -> dict[str, Any]:
        """Return a dictionary using the key conventions used in old quantum
        graph files.
        """
        result = self.metadata.copy()
        result["input"] = self.inputs
        result["output"] = self.output
        result["output_run"] = self.output_run
        result["full_command"] = self.command
        result["user"] = self.user
        result["time"] = str(self.timestamp)
        return result

    # Work around the fact that Sphinx chokes on Pydantic docstring formatting,
    # when we inherit those docstrings in our public classes.
    if "sphinx" in sys.modules and not TYPE_CHECKING:

        def copy(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.copy`."""
            return super().copy(*args, **kwargs)

        def model_dump(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_dump`."""
            return super().model_dump(*args, **kwargs)

        def model_dump_json(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_dump_json`."""
            # Fixed: this previously delegated to ``model_dump``, which
            # returns a dict rather than a JSON string.
            return super().model_dump_json(*args, **kwargs)

        def model_copy(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_copy`."""
            return super().model_copy(*args, **kwargs)

        @classmethod
        def model_construct(cls, *args: Any, **kwargs: Any) -> Any:  # type: ignore[misc, override]
            """See `pydantic.BaseModel.model_construct`."""
            return super().model_construct(*args, **kwargs)

        @classmethod
        def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_json_schema`."""
            return super().model_json_schema(*args, **kwargs)

        @classmethod
        def model_validate(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_validate`."""
            return super().model_validate(*args, **kwargs)

        @classmethod
        def model_validate_json(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_validate_json`."""
            return super().model_validate_json(*args, **kwargs)

        @classmethod
        def model_validate_strings(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_validate_strings`."""
            return super().model_validate_strings(*args, **kwargs)

263 

264 

class QuantumInfo(TypedDict):
    """A typed dictionary that annotates the attributes of the NetworkX graph
    node data for a quantum.

    Since NetworkX types are not generic over their node mapping type, this has
    to be used explicitly, e.g.::

        node_data: QuantumInfo = xgraph.nodes[quantum_id]

    where ``xgraph`` can be either `BaseQuantumGraph.quantum_only_xgraph`
    or `BaseQuantumGraph.bipartite_xgraph`.
    """

    data_id: DataCoordinate
    """Data ID of the quantum."""

    task_label: str
    """Label of the task for this quantum."""

    pipeline_node: TaskNode
    """Node in the pipeline graph for this quantum's task."""

286 

287 

class DatasetInfo(TypedDict):
    """A typed dictionary that annotates the attributes of the NetworkX graph
    node data for a dataset.

    Since NetworkX types are not generic over their node mapping type, this has
    to be used explicitly, e.g.::

        node_data: DatasetInfo = xgraph.nodes[dataset_id]

    where ``xgraph`` is from the `BaseQuantumGraph.bipartite_xgraph` property.
    """

    data_id: DataCoordinate
    """Data ID of the dataset."""

    dataset_type_name: DatasetTypeName
    """Name of the type of this dataset.

    This is always the general dataset type that matches the data repository
    storage class, which may differ from any particular task-adapted dataset
    type whose storage class has been overridden to match the task connections.
    This means it is never a component.
    """

    run: str
    """Name of the `~lsst.daf.butler.CollectionType.RUN` collection that holds
    or will hold this dataset.
    """

    pipeline_node: DatasetTypeNode
    """Node in the pipeline graph for this dataset's type."""

319 

320 

class BipartiteEdgeInfo(TypedDict):
    """A typed dictionary that annotates the attributes of the NetworkX graph
    edge data in a bipartite graph.
    """

    is_read: bool
    """`True` if this is a dataset -> quantum edge; `False` if it is a
    quantum -> dataset edge.
    """

    pipeline_edges: list[Edge]
    """Corresponding edges in the pipeline graph.

    Note that there may be more than one pipeline edge since a quantum can
    consume a particular dataset via multiple connections.
    """

337 

338 

class BaseQuantumGraph(ABC):
    """An abstract base for quantum graphs.

    Parameters
    ----------
    header : `HeaderModel`
        Structured metadata for the graph.
    pipeline_graph : `.pipeline_graph.PipelineGraph`
        Graph of tasks and dataset types.  This may hold a superset of the
        tasks and dataset types for which the quantum graph actually has
        quanta and datasets.
    """

    def __init__(self, header: HeaderModel, pipeline_graph: PipelineGraph):
        # Subclasses are responsible for providing the graph structures
        # exposed via the abstract properties below.
        self.header = header
        self.pipeline_graph = pipeline_graph

    @property
    @abstractmethod
    def quanta_by_task(self) -> Mapping[str, Mapping[DataCoordinate, uuid.UUID]]:
        """All quanta, as a nested mapping keyed first by task name and then
        by data ID.

        Notes
        -----
        This mapping is always accessible, though a partial load may leave it
        incompletely populated.

        The returned object may be an internal dictionary; as the type
        annotation indicates, it must not be modified in place.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def datasets_by_type(self) -> Mapping[str, Mapping[DataCoordinate, uuid.UUID]]:
        """All datasets, as a nested mapping keyed first by dataset type name
        and then by data ID.

        Notes
        -----
        This mapping is always accessible, though a partial load may leave it
        incompletely populated.

        The returned object may be an internal dictionary; as the type
        annotation indicates, it must not be modified in place.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def quantum_only_xgraph(self) -> networkx.DiGraph:
        """A directed acyclic graph whose nodes are quanta, with datasets
        elided.

        Notes
        -----
        This graph is always accessible, though a partial load may leave it
        incompletely populated.

        Node state dictionaries are described by the `QuantumInfo` type (or a
        subtype thereof).

        The returned object is a read-only view of an internal one.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def bipartite_xgraph(self) -> networkx.DiGraph:
        """A directed acyclic graph whose nodes are both quanta and datasets.

        Notes
        -----
        This graph is always accessible, though a partial load may leave it
        incompletely populated.

        Node state dictionaries are described by the `QuantumInfo` and
        `DatasetInfo` types (or a subtypes thereof).  Edges have state
        dictionaries described by `BipartiteEdgeInfo`.

        The returned object is a read-only view of an internal one.
        """
        raise NotImplementedError()

422 

423 

@dataclasses.dataclass
class BaseQuantumGraphWriter:
    """A helper class for writing quantum graphs."""

    zf: zipfile.ZipFile
    """The zip archive that represents the quantum graph on disk."""

    compressor: Compressor
    """A compressor for all compressed JSON blocks."""

    address_writer: AddressWriter
    """A helper object for writing addresses into the multi-block files."""

    int_size: int
    """Size (in bytes) used to write integers to binary files."""

    @classmethod
    @contextmanager
    def open(
        cls,
        uri: ResourcePathExpression,
        header: HeaderModel,
        pipeline_graph: PipelineGraph,
        *,
        address_filename: str,
        cdict_data: bytes | None = None,
        zstd_level: int = 10,
    ) -> Iterator[Self]:
        """Open a new quantum graph file for writing.

        Parameters
        ----------
        uri : convertible to `lsst.resources.ResourcePath`
            URI of the file to create.
        header : `HeaderModel`
            Header to write to the file immediately.
        pipeline_graph : `.pipeline_graph.PipelineGraph`
            Pipeline graph to write to the file immediately.
        address_filename : `str`
            Base filename for the address file, written when the context
            exits.
        cdict_data : `bytes`, optional
            Serialized zstandard compression dictionary to train the
            compressor with and embed in the file.
        zstd_level : `int`, optional
            Compression level for zstandard.
        """
        path = ResourcePath(uri, forceDirectory=False)
        if path.isLocal:
            # Make sure the parent directory exists before opening the file.
            os.makedirs(path.dirname().ospath, exist_ok=True)
        cdict = None if cdict_data is None else zstandard.ZstdCompressionDict(cdict_data)
        compressor = zstandard.ZstdCompressor(level=zstd_level, dict_data=cdict)
        address_writer = AddressWriter()
        with path.open(mode="wb") as stream:
            with zipfile.ZipFile(stream, mode="w", compression=zipfile.ZIP_STORED) as zf:
                writer = cls(zf, compressor, address_writer, header.int_size)
                writer.write_single_model("header", header)
                if cdict_data is not None:
                    zf.writestr("compression_dict", cdict_data)
                writer.write_single_model("pipeline_graph", SerializedPipelineGraph.serialize(pipeline_graph))
                yield writer
                # Addresses are flushed last, after the caller has written
                # all of its multi-block content.
                address_writer.write_to_zip(zf, address_filename, int_size=writer.int_size)

    def write_single_model(self, name: str, model: pydantic.BaseModel) -> None:
        """Serialize a pydantic model and write it as a single compressed
        JSON 'file' in the zip archive.

        Parameters
        ----------
        name : `str`
            Base name of the file.  An extension will be added.
        model : `pydantic.BaseModel`
            Pydantic model to convert to JSON.
        """
        self.write_single_block(name, model.model_dump_json().encode())

    def write_single_block(self, name: str, json_data: bytes) -> None:
        """Compress raw JSON and write it as a single 'file' in the zip
        archive.

        Parameters
        ----------
        name : `str`
            Base name of the file.  An extension will be added.
        json_data : `bytes`
            Raw JSON to compress and write.
        """
        compressed = self.compressor.compress(json_data)
        self.zf.writestr(f"{name}.json.zst", compressed)

493 

494 

495@dataclasses.dataclass 

496class BaseQuantumGraphReader: 

497 """A helper class for reading quantum graphs.""" 

498 

499 header: HeaderModel 

500 """Header metadata for the quantum graph.""" 

501 

502 pipeline_graph: PipelineGraph 

503 """Graph of tasks and dataset type names that appear in the quantum 

504 graph. 

505 """ 

506 

507 zf: zipfile.ZipFile 

508 """The zip archive that represents the quantum graph on disk.""" 

509 

510 decompressor: Decompressor 

511 """A decompressor for all compressed JSON blocks.""" 

512 

513 address_reader: AddressReader 

514 """A helper object for reading addresses into the multi-block files.""" 

515 

516 page_size: int 

517 """Approximate number of bytes to read at a time. 

518 

519 Note that this does not set a page size for *all* reads, but it 

520 does affect the smallest, most numerous reads. 

521 """ 

522 

523 @classmethod 

524 @contextmanager 

525 def _open( 

526 cls, 

527 uri: ResourcePathExpression, 

528 *, 

529 address_filename: str, 

530 graph_type: str, 

531 n_addresses: int, 

532 page_size: int | None = None, 

533 import_mode: TaskImportMode = TaskImportMode.ASSUME_CONSISTENT_EDGES, 

534 ) -> Iterator[Self]: 

535 """Construct a reader from a URI. 

536 

537 Parameters 

538 ---------- 

539 uri : convertible to `lsst.resources.ResourcePath` 

540 URI to open. Should have a ``.qg`` extension. 

541 address_filename : `str` 

542 Base filename for the address file. 

543 graph_type : `str` 

544 Value to expect for `HeaderModel.graph_type`. 

545 n_addresses : `int` 

546 Number of addresses to expect per row in the address file. 

547 page_size : `int`, optional 

548 Approximate number of bytes to read at once from address files. 

549 Note that this does not set a page size for *all* reads, but it 

550 does affect the smallest, most numerous reads. When `None`, the 

551 ``LSST_QG_PAGE_SIZE`` environment variable is checked before 

552 falling back to a default of 5MB. 

553 import_mode : `..pipeline_graph.TaskImportMode`, optional 

554 How to handle importing the task classes referenced in the pipeline 

555 graph. 

556 

557 Returns 

558 ------- 

559 reader : `contextlib.AbstractContextManager` [ \ 

560 `PredictedQuantumGraphReader` ] 

561 A context manager that returns the reader when entered. 

562 """ 

563 if page_size is None: 

564 page_size = int(os.environ.get("LSST_QG_PAGE_SIZE", DEFAULT_PAGE_SIZE)) 

565 uri = ResourcePath(uri) 

566 cdict: zstandard.ZstdCompressionDict | None = None 

567 with uri.open(mode="rb") as zf_stream: 

568 with zipfile.ZipFile(zf_stream, "r") as zf: 

569 if (cdict_path := zipfile.Path(zf, "compression_dict")).exists(): 

570 cdict = zstandard.ZstdCompressionDict(cdict_path.read_bytes()) 

571 decompressor = zstandard.ZstdDecompressor(cdict) 

572 header = cls._read_single_block_static("header", HeaderModel, zf, decompressor) 

573 if not header.graph_type == graph_type: 

574 raise TypeError(f"Header is for a {header.graph_type!r} graph, not {graph_type!r} graph.") 

575 serialized_pipeline_graph = cls._read_single_block_static( 

576 "pipeline_graph", SerializedPipelineGraph, zf, decompressor 

577 ) 

578 pipeline_graph = serialized_pipeline_graph.deserialize(import_mode) 

579 with AddressReader.open_in_zip( 

580 zf, 

581 address_filename, 

582 page_size=page_size, 

583 int_size=header.int_size, 

584 n_addresses=n_addresses, 

585 ) as address_reader: 

586 yield cls( 

587 header=header, 

588 pipeline_graph=pipeline_graph, 

589 zf=zf, 

590 decompressor=decompressor, 

591 address_reader=address_reader, 

592 page_size=page_size, 

593 ) 

594 

595 @staticmethod 

596 def _read_single_block_static[T: pydantic.BaseModel]( 

597 name: str, model_type: type[T], zf: zipfile.ZipFile, decompressor: Decompressor 

598 ) -> T: 

599 """Read a single compressed JSON block from a 'file' in a zip archive. 

600 

601 Parameters 

602 ---------- 

603 zf : `zipfile.ZipFile` 

604 Zip archive to read the file from. 

605 name : `str` 

606 Base name of the file. An extension will be added. 

607 model_type : `type` [ `pydantic.BaseModel` ] 

608 Pydantic model to validate JSON with. 

609 decompressor : `Decompressor` 

610 Object with a `decompress` method that takes and returns `bytes`. 

611 

612 Returns 

613 ------- 

614 model : `pydantic.BaseModel` 

615 Validated model. 

616 """ 

617 compressed_data = zf.read(f"{name}.json.zst") 

618 json_data = decompressor.decompress(compressed_data) 

619 return model_type.model_validate_json(json_data) 

620 

621 def _read_single_block[T: pydantic.BaseModel](self, name: str, model_type: type[T]) -> T: 

622 """Read a single compressed JSON block from a 'file' in a zip archive. 

623 

624 Parameters 

625 ---------- 

626 name : `str` 

627 Base name of the file. An extension will be added. 

628 model_type : `type` [ `pydantic.BaseModel` ] 

629 Pydantic model to validate JSON with. 

630 

631 Returns 

632 ------- 

633 model : `pydantic.BaseModel` 

634 Validated model. 

635 """ 

636 return self._read_single_block_static(name, model_type, self.zf, self.decompressor) 

637 

638 def _read_single_block_raw(self, name: str) -> bytes: 

639 """Read a single compressed block from a 'file' in a zip archive. 

640 

641 Parameters 

642 ---------- 

643 name : `str` 

644 Base name of the file. An extension will be added. 

645 

646 Returns 

647 ------- 

648 data : `bytes` 

649 Decompressed bytes. 

650 """ 

651 compressed_data = self.zf.read(f"{name}.json.zst") 

652 return self.decompressor.decompress(compressed_data)