Coverage for python / lsst / images / tests / _roundtrip.py: 28%
146 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-06 08:48 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14__all__ = ("RoundtripFits", "RoundtripJson", "TemporaryButler")
16import tempfile
17import unittest
18import uuid
19from abc import ABC, abstractmethod
20from contextlib import ExitStack
21from typing import Any, Self, TypeVar
23import astropy.io.fits
24from pydantic_core import from_json
26try:
27 from lsst.daf.butler import Butler, DataCoordinate, DatasetProvenance, DatasetRef, DatasetType
29 HAVE_BUTLER = True
30except ImportError:
31 HAVE_BUTLER = False
33from .. import fits, json
34from .._generalized_image import GeneralizedImage
35from ..serialization import ArchiveTree, MetadataValue, ReadResult
37# We need an old-style TypeVar for Sphinx.
38T = TypeVar("T")
class TemporaryButler:
    """Make a temporary butler repository.

    Parameters
    ----------
    run
        Name of a `~lsst.daf.butler.CollectionType.RUN` collection to
        register and use as the default run for the returned butler.
    **kwargs
        A mapping from a dataset type name to its storage class.  For each
        entry, a dataset type will be registered with empty dimensions, and a
        `~lsst.daf.butler.DatasetRef` will be created and added as an
        attribute of this class.

    Raises
    ------
    unittest.SkipTest
        Raised when the context manager is entered if `lsst.daf.butler` could
        not be imported.  This is typically handled by using this context
        manager within a `unittest.TestCase.subTest` context, which will skip
        just the butler-required tests in that context while allowing the rest
        of the test to continue.
    """

    def __init__(self, run: str = "test_run", **kwargs: str):
        self.run = run
        # Dataset type name -> storage class name; turned into registered
        # dataset types and dynamic DatasetRef attributes in __enter__.
        self._dataset_types = kwargs
        self._exit_stack = ExitStack()

    def __enter__(self) -> TemporaryButler:
        if not HAVE_BUTLER:
            raise unittest.SkipTest("lsst.daf.butler could not be imported.")
        self._exit_stack.__enter__()
        tmp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True, delete=True)
        repo_root = self._exit_stack.enter_context(tmp_dir)
        repo_config = Butler.makeRepo(repo_root)
        self.butler = self._exit_stack.enter_context(Butler.from_config(repo_config, run=self.run))
        data_id = DataCoordinate.make_empty(self.butler.dimensions)
        for type_name, storage_class_name in self._dataset_types.items():
            dataset_type = DatasetType(type_name, self.butler.dimensions.empty, storage_class_name)
            try:
                self.butler.registry.registerDatasetType(dataset_type)
            except KeyError as err:
                err.add_note(
                    "Storage class not configured in butler defaults. "
                    "A newer version of daf_butler may be needed."
                )
                raise
            # Expose a ready-to-use ref for each registered dataset type as an
            # attribute named after the dataset type.
            setattr(self, type_name, DatasetRef(dataset_type, data_id, self.run))
        return self

    def __exit__(self, *args: Any) -> bool | None:
        return self._exit_stack.__exit__(*args)

    # Just for typing, since this class uses dynamic attributes.
    def __getattr__(self, name: str) -> DatasetRef:
        raise AttributeError(name)
101class RoundtripBase[T](ABC):
102 """A context manager for testing serialization.
104 Parameters
105 ----------
106 tc
107 A test case object to used for internal checks.
108 original
109 The object to serialize.
110 storage_class
111 A butler storage class name to use. If not provided (or
112 `lsst.daf.butler` cannot be imported), the roundtrip will just use
113 a direct write to a temporary file.
114 format
115 Archive/file format to use when not using a butler (ignored when
116 using a butler).
118 Notes
119 -----
120 When entered, this context manager writes the object and reads it back in
121 to the ``result`` attribute. When exited, any temporary files or
122 directories are deleted, but the ``result`` attribute is still usable.
123 In between the `inspect` and `get` methods can be used to perform other
124 tests.
126 This helper internally tests that butler provenance and metadata are saved
127 with any `.GeneralizedImage` object.
128 """
130 def __init__(
131 self,
132 tc: unittest.TestCase,
133 original: T,
134 storage_class: str | None = None,
135 ):
136 self._original = original
137 self._storage_class = storage_class
138 self._serialized: Any = None
139 self._exit_stack = ExitStack()
140 self._filename: str | None = None
141 self._tc = tc
142 self.result: Any
143 self.butler: Butler | None = None
144 self.ref: DatasetRef | None = None
145 self._test_metadata: dict[str, MetadataValue] = {
146 "roundtrip_test_1": 1,
147 "roundtrip_test_2": 2.5,
148 "roundtrip_test_3": "three",
149 "roundtrip_test_4": True,
150 "roundtrip_test_5": None,
151 }
153 def __enter__(self) -> Self:
154 self._exit_stack.__enter__()
155 if isinstance(self._original, GeneralizedImage):
156 self._original.metadata.update(self._test_metadata)
157 if HAVE_BUTLER and self._storage_class is not None:
158 self._run_with_butler()
159 else:
160 self._run_without_butler()
161 if isinstance(self._original, GeneralizedImage):
162 assert isinstance(self.result, GeneralizedImage)
163 for k in self._test_metadata:
164 self._tc.assertEqual(self.result.metadata[k], self._test_metadata[k])
165 del self._original.metadata[k]
166 del self.result.metadata[k]
167 return self
169 def __exit__(self, *args: Any) -> bool | None:
170 return self._exit_stack.__exit__(*args)
172 @property
173 def filename(self) -> str:
174 """The name of the file the object was written to."""
175 if self._filename is None:
176 assert self.butler is not None and self.ref is not None
177 self._filename = self.butler.getURI(self.ref).ospath
178 return self._filename
180 @property
181 def serialized(self) -> Any:
182 """The serialization model for this object
183 (`.serialization.ArchiveTree`).
184 """
185 if self._serialized is None:
186 # The butler code path doesn't give us a way to inspect the
187 # serialized model, so we have to save it again directly to another
188 # file (which we then discard).
189 with tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False, delete=True) as tmp:
190 tmp.close()
191 self._serialized = fits.write(self._original, tmp.name)
192 return self._serialized
194 def get(self, component: str | None = None, storageClass: str | None = None, **kwargs: Any) -> Any:
195 """Perform a partial read.
197 Parameters
198 ----------
199 component
200 Component to read instead of the main object. This requires the
201 roundtrip to use a butler, raising `unittest.SkipTest` otherwise;
202 this generally means these tests should be nested within a
203 `~unittest.TestCase.subTest` context.
204 storageClass
205 Override storage class name to affect the type returned by
206 the get. Only used if a butler is active.
207 **kwargs
208 Keyword arguments either passed directly to `.fits.read` or used
209 as ``parameters`` for a `~lsst.daf.butler.Butler.get`.
211 Return
212 ------
213 object
214 Result of the partial read.
215 """
216 if self.butler is None:
217 if component is not None:
218 raise unittest.SkipTest("Cannot test component reads without a butler.")
219 if storageClass is not None:
220 raise unittest.SkipTest("Cannot test storage class override without a butler")
221 result = fits.read(type(self._original), self.filename, **kwargs).deserialized
222 else:
223 assert self.ref is not None, "butler and ref should be None or not together"
224 ref = self.ref
225 if component is not None:
226 ref = ref.makeComponentRef(component)
227 result = self.butler.get(ref, parameters=kwargs, storageClass=storageClass)
228 if isinstance(result, GeneralizedImage):
229 # The metadata the RoundtripFits object added for the test may or
230 # may not be present; strip it if it does so comparisons to the
231 # original are not messed up.
232 for k in self._test_metadata:
233 result.metadata.pop(k, None)
234 return result
236 def _run_with_butler(self) -> None:
237 assert self._storage_class is not None, "Should not use butler if no storage class"
238 butler_helper = self._exit_stack.enter_context(TemporaryButler(test_dataset=self._storage_class))
239 self.butler = butler_helper.butler
240 quantum_id = uuid.uuid4()
241 self.ref = self.butler.put(
242 self._original, butler_helper.test_dataset, provenance=DatasetProvenance(quantum_id=quantum_id)
243 )
244 self.result = self.butler.get(self.ref)
245 if isinstance(self._original, GeneralizedImage):
246 self._tc.assertEqual(
247 DatasetRef.from_simple(self.result.butler_dataset, universe=self.butler.dimensions), self.ref
248 )
249 self._tc.assertEqual(self.result.butler_provenance.quantum_id, quantum_id)
250 self._tc.assertTrue(self.filename.endswith(self._get_extension()))
252 def _run_without_butler(self) -> None:
253 tmp = self._exit_stack.enter_context(
254 tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False, delete=True)
255 )
256 tmp.close()
257 self._filename = tmp.name
258 self._serialized = self._write(self._original, tmp.name)
259 read_result = self._read(type(self._original), tmp.name)
260 self._tc.assertIsNone(read_result.butler_info)
261 self.result = read_result.deserialized
263 @abstractmethod
264 def _get_extension(self) -> str:
265 raise NotImplementedError()
267 @abstractmethod
268 def _write(self, obj: Any, filename: str) -> ArchiveTree:
269 raise NotImplementedError()
271 @abstractmethod
272 def _read(self, obj_type: Any, filename: str) -> ReadResult:
273 raise NotImplementedError()
276class RoundtripFits[T](RoundtripBase[T]):
277 def inspect(self) -> astropy.io.fits.HDUList:
278 """Open the FITS file with Astropy."""
279 return self._exit_stack.enter_context(
280 astropy.io.fits.open(self.filename, disable_image_compression=True)
281 )
283 def _get_extension(self) -> str:
284 return ".fits"
286 def _write(self, obj: Any, filename: str) -> ArchiveTree:
287 return fits.write(obj, filename)
289 def _read(self, obj_type: Any, filename: str) -> ReadResult:
290 return fits.read(obj_type, filename)
293class RoundtripJson[T](RoundtripBase[T]):
294 def inspect(self) -> dict[str, Any]:
295 """Read the JSON file as a dictionary."""
296 with open(self.filename, "rb") as stream:
297 return from_json(stream.read())
299 def _get_extension(self) -> str:
300 return ".json"
302 def _write(self, obj: Any, filename: str) -> ArchiveTree:
303 return json.write(obj, filename)
305 def _read(self, obj_type: Any, filename: str) -> ReadResult:
306 return json.read(obj_type, filename)