Coverage for python / lsst / images / tests / _roundtrip.py: 28%

146 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-07 08:34 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("RoundtripFits", "RoundtripJson", "TemporaryButler") 

15 

16import tempfile 

17import unittest 

18import uuid 

19from abc import ABC, abstractmethod 

20from contextlib import ExitStack 

21from typing import Any, Self, TypeVar 

22 

23import astropy.io.fits 

24from pydantic_core import from_json 

25 

26try: 

27 from lsst.daf.butler import Butler, DataCoordinate, DatasetProvenance, DatasetRef, DatasetType 

28 

29 HAVE_BUTLER = True 

30except ImportError: 

31 HAVE_BUTLER = False 

32 

33from .. import fits, json 

34from .._generalized_image import GeneralizedImage 

35from ..serialization import ArchiveTree, MetadataValue, ReadResult 

36 

# An old-style TypeVar is still needed for Sphinx, even though the classes
# below declare their type parameter with ``class Name[T]`` syntax.
T = TypeVar("T")

39 

40 

class TemporaryButler:
    """Make a temporary butler repository.

    Parameters
    ----------
    run
        Name of a `~lsst.daf.butler.CollectionType.RUN` collection to
        register and use as the default run for the returned butler.
    **kwargs
        A mapping from a dataset type name to its storage class.  For each
        entry, a dataset type will be registered with empty dimensions, and a
        `~lsst.daf.butler.DatasetRef` will be created and added as an
        attribute of this class.

    Raises
    ------
    unittest.SkipTest
        Raised when the context manager is entered if `lsst.daf.butler` could
        not be imported.  This is typically handled by using this context
        manager within a `unittest.TestCase.subTest` context, which will skip
        just the butler-required tests in that context while allowing the rest
        of the test to continue.
    KeyError
        Re-raised (with an explanatory note attached) when registering one of
        the requested dataset types fails with `KeyError`, which this code
        interprets as a storage class missing from the butler defaults.

    Notes
    -----
    If entering fails part-way through, the butler and temporary directory
    that were already created are cleaned up before the exception propagates.
    """

    def __init__(self, run: str = "test_run", **kwargs: str):
        self.run = run
        self._kwargs = kwargs
        self._exit_stack = ExitStack()

    def __enter__(self) -> TemporaryButler:
        if not HAVE_BUTLER:
            raise unittest.SkipTest("lsst.daf.butler could not be imported.")
        self._exit_stack.__enter__()
        try:
            self._setup()
        except BaseException:
            # ``__exit__`` is never called when ``__enter__`` raises, so release
            # the temporary directory and butler here instead of leaking them.
            self._exit_stack.close()
            raise
        return self

    def _setup(self) -> None:
        """Create the repository, butler, dataset types, and dataset refs."""
        root = self._exit_stack.enter_context(
            tempfile.TemporaryDirectory(ignore_cleanup_errors=True, delete=True)
        )
        butler_config = Butler.makeRepo(root)
        self.butler = self._exit_stack.enter_context(Butler.from_config(butler_config, run=self.run))
        empty_data_id = DataCoordinate.make_empty(self.butler.dimensions)
        for name, storage_class in self._kwargs.items():
            dataset_type = DatasetType(name, self.butler.dimensions.empty, storage_class)
            try:
                self.butler.registry.registerDatasetType(dataset_type)
            except KeyError as err:
                err.add_note(
                    "Storage class not configured in butler defaults. "
                    "A newer version of daf_butler may be needed."
                )
                raise
            # Expose the ref as a dynamic attribute named after the dataset type.
            setattr(self, name, DatasetRef(dataset_type, empty_data_id, self.run))

    def __exit__(self, *args: Any) -> bool | None:
        return self._exit_stack.__exit__(*args)

    # Just for typing, since this class uses dynamic attributes: this is only
    # reached for attributes not found normally, and always fails.
    def __getattr__(self, name: str) -> DatasetRef:
        raise AttributeError(name)

99 

100 

101class RoundtripBase[T](ABC): 

102 """A context manager for testing serialization. 

103 

104 Parameters 

105 ---------- 

106 tc 

107 A test case object to used for internal checks. 

108 original 

109 The object to serialize. 

110 storage_class 

111 A butler storage class name to use. If not provided (or 

112 `lsst.daf.butler` cannot be imported), the roundtrip will just use 

113 a direct write to a temporary file. 

114 format 

115 Archive/file format to use when not using a butler (ignored when 

116 using a butler). 

117 

118 Notes 

119 ----- 

120 When entered, this context manager writes the object and reads it back in 

121 to the ``result`` attribute. When exited, any temporary files or 

122 directories are deleted, but the ``result`` attribute is still usable. 

123 In between the `inspect` and `get` methods can be used to perform other 

124 tests. 

125 

126 This helper internally tests that butler provenance and metadata are saved 

127 with any `.GeneralizedImage` object. 

128 """ 

129 

130 def __init__( 

131 self, 

132 tc: unittest.TestCase, 

133 original: T, 

134 storage_class: str | None = None, 

135 ): 

136 self._original = original 

137 self._storage_class = storage_class 

138 self._serialized: Any = None 

139 self._exit_stack = ExitStack() 

140 self._filename: str | None = None 

141 self._tc = tc 

142 self.result: Any 

143 self.butler: Butler | None = None 

144 self.ref: DatasetRef | None = None 

145 self._test_metadata: dict[str, MetadataValue] = { 

146 "roundtrip_test_1": 1, 

147 "roundtrip_test_2": 2.5, 

148 "roundtrip_test_3": "three", 

149 "roundtrip_test_4": True, 

150 "roundtrip_test_5": None, 

151 } 

152 

153 def __enter__(self) -> Self: 

154 self._exit_stack.__enter__() 

155 if isinstance(self._original, GeneralizedImage): 

156 self._original.metadata.update(self._test_metadata) 

157 if HAVE_BUTLER and self._storage_class is not None: 

158 self._run_with_butler() 

159 else: 

160 self._run_without_butler() 

161 if isinstance(self._original, GeneralizedImage): 

162 assert isinstance(self.result, GeneralizedImage) 

163 for k in self._test_metadata: 

164 self._tc.assertEqual(self.result.metadata[k], self._test_metadata[k]) 

165 del self._original.metadata[k] 

166 del self.result.metadata[k] 

167 return self 

168 

169 def __exit__(self, *args: Any) -> bool | None: 

170 return self._exit_stack.__exit__(*args) 

171 

172 @property 

173 def filename(self) -> str: 

174 """The name of the file the object was written to.""" 

175 if self._filename is None: 

176 assert self.butler is not None and self.ref is not None 

177 self._filename = self.butler.getURI(self.ref).ospath 

178 return self._filename 

179 

180 @property 

181 def serialized(self) -> Any: 

182 """The serialization model for this object 

183 (`.serialization.ArchiveTree`). 

184 """ 

185 if self._serialized is None: 

186 # The butler code path doesn't give us a way to inspect the 

187 # serialized model, so we have to save it again directly to another 

188 # file (which we then discard). 

189 with tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False, delete=True) as tmp: 

190 tmp.close() 

191 self._serialized = fits.write(self._original, tmp.name) 

192 return self._serialized 

193 

194 def get(self, component: str | None = None, storageClass: str | None = None, **kwargs: Any) -> Any: 

195 """Perform a partial read. 

196 

197 Parameters 

198 ---------- 

199 component 

200 Component to read instead of the main object. This requires the 

201 roundtrip to use a butler, raising `unittest.SkipTest` otherwise; 

202 this generally means these tests should be nested within a 

203 `~unittest.TestCase.subTest` context. 

204 storageClass 

205 Override storage class name to affect the type returned by 

206 the get. Only used if a butler is active. 

207 **kwargs 

208 Keyword arguments either passed directly to `.fits.read` or used 

209 as ``parameters`` for a `~lsst.daf.butler.Butler.get`. 

210 

211 Return 

212 ------ 

213 object 

214 Result of the partial read. 

215 """ 

216 if self.butler is None: 

217 if component is not None: 

218 raise unittest.SkipTest("Cannot test component reads without a butler.") 

219 if storageClass is not None: 

220 raise unittest.SkipTest("Cannot test storage class override without a butler") 

221 result = fits.read(type(self._original), self.filename, **kwargs).deserialized 

222 else: 

223 assert self.ref is not None, "butler and ref should be None or not together" 

224 ref = self.ref 

225 if component is not None: 

226 ref = ref.makeComponentRef(component) 

227 result = self.butler.get(ref, parameters=kwargs, storageClass=storageClass) 

228 if isinstance(result, GeneralizedImage): 

229 # The metadata the RoundtripFits object added for the test may or 

230 # may not be present; strip it if it does so comparisons to the 

231 # original are not messed up. 

232 for k in self._test_metadata: 

233 result.metadata.pop(k, None) 

234 return result 

235 

236 def _run_with_butler(self) -> None: 

237 assert self._storage_class is not None, "Should not use butler if no storage class" 

238 butler_helper = self._exit_stack.enter_context(TemporaryButler(test_dataset=self._storage_class)) 

239 self.butler = butler_helper.butler 

240 quantum_id = uuid.uuid4() 

241 self.ref = self.butler.put( 

242 self._original, butler_helper.test_dataset, provenance=DatasetProvenance(quantum_id=quantum_id) 

243 ) 

244 self.result = self.butler.get(self.ref) 

245 if isinstance(self._original, GeneralizedImage): 

246 self._tc.assertEqual( 

247 DatasetRef.from_simple(self.result.butler_dataset, universe=self.butler.dimensions), self.ref 

248 ) 

249 self._tc.assertEqual(self.result.butler_provenance.quantum_id, quantum_id) 

250 self._tc.assertTrue(self.filename.endswith(self._get_extension())) 

251 

252 def _run_without_butler(self) -> None: 

253 tmp = self._exit_stack.enter_context( 

254 tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False, delete=True) 

255 ) 

256 tmp.close() 

257 self._filename = tmp.name 

258 self._serialized = self._write(self._original, tmp.name) 

259 read_result = self._read(type(self._original), tmp.name) 

260 self._tc.assertIsNone(read_result.butler_info) 

261 self.result = read_result.deserialized 

262 

263 @abstractmethod 

264 def _get_extension(self) -> str: 

265 raise NotImplementedError() 

266 

267 @abstractmethod 

268 def _write(self, obj: Any, filename: str) -> ArchiveTree: 

269 raise NotImplementedError() 

270 

271 @abstractmethod 

272 def _read(self, obj_type: Any, filename: str) -> ReadResult: 

273 raise NotImplementedError() 

274 

275 

276class RoundtripFits[T](RoundtripBase[T]): 

277 def inspect(self) -> astropy.io.fits.HDUList: 

278 """Open the FITS file with Astropy.""" 

279 return self._exit_stack.enter_context( 

280 astropy.io.fits.open(self.filename, disable_image_compression=True) 

281 ) 

282 

283 def _get_extension(self) -> str: 

284 return ".fits" 

285 

286 def _write(self, obj: Any, filename: str) -> ArchiveTree: 

287 return fits.write(obj, filename) 

288 

289 def _read(self, obj_type: Any, filename: str) -> ReadResult: 

290 return fits.read(obj_type, filename) 

291 

292 

293class RoundtripJson[T](RoundtripBase[T]): 

294 def inspect(self) -> dict[str, Any]: 

295 """Read the JSON file as a dictionary.""" 

296 with open(self.filename, "rb") as stream: 

297 return from_json(stream.read()) 

298 

299 def _get_extension(self) -> str: 

300 return ".json" 

301 

302 def _write(self, obj: Any, filename: str) -> ArchiveTree: 

303 return json.write(obj, filename) 

304 

305 def _read(self, obj_type: Any, filename: str) -> ReadResult: 

306 return json.read(obj_type, filename)