Coverage for tests / test_visit_image.py: 14%

219 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 09:16 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14import os 

15import unittest 

16import warnings 

17from typing import Any 

18 

19import astropy.io.fits 

20import astropy.units as u 

21import astropy.wcs 

22import numpy as np 

23from astro_metadata_translator import ObservationInfo 

24 

25from lsst.images import ( 

26 Box, 

27 DetectorFrame, 

28 Image, 

29 MaskPlane, 

30 MaskSchema, 

31 ProjectionAstropyView, 

32 TractFrame, 

33 VisitImage, 

34 get_legacy_visit_image_mask_planes, 

35) 

36from lsst.images.fits import ExtensionKey, FitsOpaqueMetadata 

37from lsst.images.psfs import GaussianPointSpreadFunction, PointSpreadFunction 

38from lsst.images.tests import ( 

39 DP2_VISIT_DETECTOR_DATA_ID, 

40 RoundtripFits, 

41 TemporaryButler, 

42 assert_masked_images_equal, 

43 assert_projections_equal, 

44 compare_visit_image_to_legacy, 

45 make_random_projection, 

46) 

47 

48DATA_DIR = os.environ.get("TESTDATA_IMAGES_DIR", None) 

49 

50 

class VisitImageTestCase(unittest.TestCase):
    """Basic Tests for VisitImage that need no external test data."""

    @classmethod
    def setUpClass(cls) -> None:
        # Seeded RNG so the random projection is reproducible between runs.
        cls.rng = np.random.default_rng(500)
        det_frame = DetectorFrame(instrument="Inst", visit=1234, detector=1, bbox=Box.factory[1:4096, 1:4096])
        cls.projection = make_random_projection(cls.rng, det_frame, Box.factory[1:4096, 1:4096])
        cls.mask_schema = MaskSchema([MaskPlane("M1", "D1")])
        cls.obs_info = ObservationInfo(instrument="LSSTCam", detector_num=4)
        cls.gaussian_psf = GaussianPointSpreadFunction(2.5, stamp_size=33, bounds=Box.factory[-10:10, -12:13])

        # Opaque primary-header metadata built from a legacy-style header;
        # test_read_write checks that it survives a FITS round trip.
        opaque = FitsOpaqueMetadata()
        hdr = astropy.io.fits.Header()
        with warnings.catch_warnings():
            # Silence warnings about long keys becoming HIERARCH.
            warnings.simplefilter("ignore", category=astropy.io.fits.verify.VerifyWarning)
            hdr.update({"PLATFORM": "lsstcam", "LSST BUTLER ID": "123456789"})
            opaque.extract_legacy_primary_header(hdr)

        cls.image = Image(42, shape=(1024, 1024), unit=u.nJy)
        cls.variance = Image(5.0, shape=(1024, 1024), unit=u.nJy * u.nJy)
        # API signature suggests projection and obs_info can be None but they
        # are required.
        cls.visit_image = VisitImage(
            cls.image,
            variance=cls.variance,
            psf=GaussianPointSpreadFunction(2.5, stamp_size=33, bounds=Box.factory[-10:10, -12:13]),
            mask_schema=cls.mask_schema,
            projection=cls.projection,
            obs_info=cls.obs_info,
        )
        cls.visit_image._opaque_metadata = opaque
        # Same as above but without an explicit variance, to exercise the
        # default variance fill.
        cls.simplest_visit_image = VisitImage(
            cls.image,
            psf=GaussianPointSpreadFunction(2.5, stamp_size=33, bounds=Box.factory[-10:10, -12:13]),
            mask_schema=cls.mask_schema,
            projection=cls.projection,
            obs_info=cls.obs_info,
        )

    def test_basics(self) -> None:
        """Test basic constructor patterns, string forms, WCS views, deep
        copies, and constructor argument validation.
        """
        # Test default fill of variance.
        visit = self.simplest_visit_image
        self.assertEqual(visit.variance.array[0, 0], 1.0)
        self.assertIs(visit[...], visit)
        self.assertEqual(str(visit), "VisitImage(Image([y=0:1024, x=0:1024], int64), ['M1'])")
        self.assertEqual(
            repr(visit),
            "VisitImage(Image(..., bbox=Box(y=Interval(start=0, stop=1024), x=Interval(start=0, stop=1024)),"
            " dtype=dtype('int64')), mask_schema=MaskSchema([MaskPlane(name='M1', description='D1')],"
            " dtype=dtype('uint8')))",
        )

        astropy_wcs = visit.astropy_wcs
        self.assertIsInstance(astropy_wcs, ProjectionAstropyView)
        approx_wcs = visit.fits_wcs
        self.assertIsInstance(approx_wcs, astropy.wcs.WCS)

        # Check that it is a deep copy.
        copy = visit.copy()
        copy.image.array[0, 0] = 30.0
        self.assertEqual(visit.image.array[0, 0], 42.0)
        self.assertEqual(copy.image.array[0, 0], 30.0)

        with self.assertRaises(TypeError):
            # Requires a PSF.
            VisitImage(
                self.image,
                mask_schema=self.mask_schema,
                projection=self.projection,
                obs_info=self.obs_info,
            )

        with self.assertRaises(TypeError):
            # Requires ObservationInfo.
            VisitImage(
                self.image,
                psf=self.gaussian_psf,
                mask_schema=self.mask_schema,
                projection=self.projection,
            )

        with self.assertRaises(TypeError):
            # Requires a projection.
            VisitImage(
                self.image,
                psf=self.gaussian_psf,
                mask_schema=self.mask_schema,
                obs_info=self.obs_info,
            )

        with self.assertRaises(TypeError):
            # Requires some form of mask.
            VisitImage(
                self.image,
                psf=self.gaussian_psf,
                projection=self.projection,
                obs_info=self.obs_info,
            )

        with self.assertRaises(TypeError):
            # Image with no unit and a shape unlike the other fixtures.
            # NOTE(review): confirm which property triggers the TypeError.
            VisitImage(
                Image(42, shape=(5, 5)),
                psf=self.gaussian_psf,
                mask_schema=self.mask_schema,
                projection=self.projection,
                obs_info=self.obs_info,
            )

        # Requires a DetectorFrame.
        tract_frame = TractFrame(skymap="Skymap", tract=1, bbox=Box.factory[1:10, 1:10])
        tract_proj = make_random_projection(self.rng, tract_frame, Box.factory[1:4096, 1:4096])
        with self.assertRaises(TypeError):
            VisitImage(
                self.image,
                projection=tract_proj,
                psf=self.gaussian_psf,
                mask_schema=self.mask_schema,
                obs_info=self.obs_info,
            )

        # Variance unit mismatch (nJy image passed as the variance, which
        # has nJy**2 units in the valid fixture).
        with self.assertRaises(ValueError):
            VisitImage(
                self.image,
                variance=self.image,
                psf=self.gaussian_psf,
                mask_schema=self.mask_schema,
                projection=self.projection,
                obs_info=self.obs_info,
            )

    def test_obs_info(self) -> None:
        """Check that ObservationInfo has been constructed."""
        visit = self.visit_image
        self.assertIsNotNone(visit.obs_info)
        assert visit.obs_info is not None  # for mypy.
        self.assertEqual(visit.obs_info.instrument, "LSSTCam")

    def test_read_write(self) -> None:
        """Test that a visit can round trip through a FITS file."""
        with RoundtripFits(self, self.visit_image, "VisitImage") as roundtrip:
            # Check that we're still using the right compression, and that we
            # wrote WCSs.
            fits = roundtrip.inspect()
            self.assertEqual(fits[1].header["ZCMPTYPE"], "GZIP_2")
            self.assertEqual(fits[1].header["CTYPE1"], "RA---TAN")
            self.assertEqual(fits[2].header["ZCMPTYPE"], "GZIP_2")
            self.assertEqual(fits[2].header["CTYPE1"], "RA---TAN")
            self.assertEqual(fits[3].header["ZCMPTYPE"], "GZIP_2")
            self.assertEqual(fits[3].header["CTYPE1"], "RA---TAN")
            # Check a subimage read.
            subbox = Box.factory[8:13, 9:30]
            subimage = roundtrip.get(bbox=subbox)
            assert_masked_images_equal(self, subimage, self.visit_image[subbox], expect_view=False)
            with self.subTest():
                self.assertEqual(roundtrip.get("bbox"), self.visit_image.bbox)
            with self.subTest():
                obs_info = roundtrip.get("obs_info")
                self.assertIsInstance(obs_info, ObservationInfo)
                self.assertEqual(obs_info, self.visit_image.obs_info)

        assert_masked_images_equal(self, roundtrip.result, self.visit_image, expect_view=False)
        # Check that the round-tripped headers are the same (up to card order).
        self.assertEqual(len(roundtrip.result._opaque_metadata.headers[ExtensionKey()]), 1)
        self.assertEqual(
            dict(self.visit_image._opaque_metadata.headers[ExtensionKey()]),
            dict(roundtrip.result._opaque_metadata.headers[ExtensionKey()]),
        )
        # The pixel-data extensions carry no custom headers.
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("IMAGE")])
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("MASK")])
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("VARIANCE")])
        self.assertEqual(roundtrip.result.obs_info, self.visit_image.obs_info)

227 

228 

@unittest.skipUnless(DATA_DIR is not None, "TESTDATA_IMAGES_DIR is not in the environment.")
class VisitImageLegacyTestCase(unittest.TestCase):
    """Tests for the VisitImage class and the basics of the archive system.

    Requires legacy code. The legacy test file must be present under
    ``TESTDATA_IMAGES_DIR``. Tests that need an in-memory legacy
    ``Exposure`` are skipped when ``lsst.afw.image`` cannot be imported.
    """

    @classmethod
    def setUpClass(cls) -> None:
        assert DATA_DIR is not None, "Guaranteed by decorator."
        cls.filename = os.path.join(DATA_DIR, "dp2", "legacy", "visit_image.fits")
        cls.plane_map = plane_map = get_legacy_visit_image_mask_planes()
        cls.visit_image = VisitImage.read_legacy(
            cls.filename, preserve_quantization=True, plane_map=plane_map
        )
        # afw is optional. When it is missing, legacy_exposure stays None and
        # tests that need it skip themselves via _require_legacy_exposure.
        cls.legacy_exposure: Any = None
        try:
            from lsst.afw.image import ExposureFitsReader
        except ImportError:
            pass
        else:
            cls.legacy_exposure = ExposureFitsReader(cls.filename).read()

    def _require_legacy_exposure(self) -> Any:
        """Return the in-memory legacy Exposure, or skip the current test
        (or subtest) because ``lsst.afw.image`` could not be imported.
        """
        if self.legacy_exposure is None:
            raise unittest.SkipTest("'lsst.afw.image' could not be imported.")
        return self.legacy_exposure

    def test_legacy_errors(self) -> None:
        """Legacy read failure modes."""
        # File-based reads do not need afw, so check them first. They still
        # run when the in-memory checks below are skipped.
        with self.assertRaises(ValueError):
            VisitImage.read_legacy(self.filename, instrument="HSC")
        with self.assertRaises(ValueError):
            VisitImage.read_legacy(self.filename, visit=123456)

        legacy_exposure = self._require_legacy_exposure()
        # Overrides that do not match the legacy data are rejected.
        with self.assertRaises(ValueError):
            VisitImage.from_legacy(legacy_exposure, instrument="HSC")
        with self.assertRaises(ValueError):
            VisitImage.from_legacy(legacy_exposure, visit=123456)
        with self.assertRaises(ValueError):
            VisitImage.from_legacy(legacy_exposure, unit=u.mJy)
        # Overrides that match are accepted.
        visit = VisitImage.from_legacy(
            legacy_exposure, instrument="LSSTCam", unit=u.nJy, visit=2025052000177
        )
        self.assertEqual(visit.unit, u.nJy)

    def test_component_reads(self) -> None:
        """Test reads of components from legacy file."""
        visit = VisitImage.read_legacy(self.filename)
        proj = VisitImage.read_legacy(self.filename, component="projection")
        assert_projections_equal(self, proj, visit.projection, expect_identity=False)
        image = VisitImage.read_legacy(self.filename, component="image")
        self.assertEqual(image, visit.image)
        variance = VisitImage.read_legacy(self.filename, component="variance")
        self.assertEqual(variance, visit.variance)
        mask = VisitImage.read_legacy(self.filename, component="mask")
        self.assertEqual(mask, visit.mask)
        psf = VisitImage.read_legacy(self.filename, component="psf")
        self.assertIsInstance(psf, PointSpreadFunction)
        obs_info = VisitImage.read_legacy(self.filename, component="obs_info")
        self.assertIsInstance(obs_info, ObservationInfo)
        self.assertEqual(obs_info.instrument, "LSSTCam")
        self.assertEqual(obs_info.detector_num, 85, obs_info)
        self.assertEqual(obs_info.detector_unique_name, "R21_S11", obs_info)
        self.assertEqual(obs_info.physical_filter, "r_57", obs_info)

    def test_obs_info(self) -> None:
        """Check that ObservationInfo has been constructed."""
        legacy = VisitImage.from_legacy(self._require_legacy_exposure(), plane_map=self.plane_map)
        self.assertIsNotNone(legacy.obs_info)
        # Show the full diff if the ObservationInfo comparison fails.
        self.maxDiff = None
        self.assertEqual(legacy.obs_info, self.visit_image.obs_info)
        assert legacy.obs_info is not None  # for mypy.
        self.assertEqual(legacy.obs_info.instrument, "LSSTCam")
        self.assertEqual(legacy.obs_info.detector_num, 85, legacy.obs_info)
        self.assertEqual(legacy.obs_info.detector_unique_name, "R21_S11", legacy.obs_info)
        self.assertEqual(legacy.obs_info.physical_filter, "r_57", legacy.obs_info)

    def test_read_legacy_headers(self) -> None:
        """Test that headers were correctly stripped and interpreted in
        `VisitImage.read_legacy`.
        """
        # Check that we read the units from BUNIT.
        self.assertEqual(self.visit_image.unit, astropy.units.nJy)
        # Check that the primary header has the keys we want, and none of the
        # keys we don't want.
        header = self.visit_image._opaque_metadata.headers[ExtensionKey()]
        self.assertIn("EXPTIME", header)
        self.assertEqual(header["PLATFORM"], "lsstcam")
        self.assertNotIn("LSST BUTLER ID", header)
        self.assertNotIn("AR HDU", header)
        self.assertNotIn("A_ORDER", header)
        # Check that the extension HDUs do not have any custom headers.
        self.assertFalse(self.visit_image._opaque_metadata.headers[ExtensionKey("IMAGE")])
        self.assertFalse(self.visit_image._opaque_metadata.headers[ExtensionKey("MASK")])
        self.assertFalse(self.visit_image._opaque_metadata.headers[ExtensionKey("VARIANCE")])

    def test_from_legacy_headers(self) -> None:
        """Test that from_legacy handles headers properly."""
        legacy = VisitImage.from_legacy(self._require_legacy_exposure(), plane_map=self.plane_map)
        header = legacy._opaque_metadata.headers[ExtensionKey()]
        self.assertIn("EXPTIME", header)
        self.assertEqual(header["PLATFORM"], "lsstcam")
        self.assertNotIn("LSST BUTLER ID", header)
        self.assertNotIn("AR HDU", header)
        self.assertNotIn("A_ORDER", header)
        # Check that the extension HDUs do not have any custom headers.
        # These checks apply to the from_legacy result, not the file-read fixture.
        self.assertFalse(legacy._opaque_metadata.headers[ExtensionKey("IMAGE")])
        self.assertFalse(legacy._opaque_metadata.headers[ExtensionKey("MASK")])
        self.assertFalse(legacy._opaque_metadata.headers[ExtensionKey("VARIANCE")])

    def test_rewrite(self) -> None:
        """Test that we can rewrite the visit image and preserve both
        lossy-compressed pixel values and components exactly.
        """
        with RoundtripFits(self, self.visit_image, "VisitImage") as roundtrip:
            # Check that we're still using the right compression, and that we
            # wrote WCSs.
            fits = roundtrip.inspect()
            self.assertEqual(fits[1].header["ZCMPTYPE"], "RICE_1")
            self.assertEqual(fits[1].header["CTYPE1"], "RA---TAN-SIP")
            self.assertEqual(fits[2].header["ZCMPTYPE"], "GZIP_2")
            self.assertEqual(fits[2].header["CTYPE1"], "RA---TAN-SIP")
            self.assertEqual(fits[3].header["ZCMPTYPE"], "RICE_1")
            self.assertEqual(fits[3].header["CTYPE1"], "RA---TAN-SIP")
            # Check a subimage read.
            subbox = Box.factory[8:13, 9:30]
            subimage = roundtrip.get(bbox=subbox)
            assert_masked_images_equal(self, subimage, self.visit_image[subbox], expect_view=False)
            # Pre-initialized so the legacy comparison below can still run
            # (without alternates) if this subtest fails.
            alternates: dict[str, Any] = {}
            with self.subTest():
                self.assertEqual(roundtrip.get("bbox"), self.visit_image.bbox)
                alternates = {
                    k: roundtrip.get(k)
                    for k in ["projection", "image", "mask", "variance", "psf", "obs_info"]
                }
            # Try to do a butler get of a component with storage class
            # override.
            with self.subTest():
                if self.legacy_exposure is not None:
                    import lsst.afw.image

                    # We have VisitInfo available.
                    visit_info = roundtrip.get("obs_info", storageClass="VisitInfo")
                    self.assertIsInstance(visit_info, lsst.afw.image.VisitInfo)
                    self.assertEqual(visit_info.getInstrumentLabel(), "LSSTCam")
                else:
                    raise unittest.SkipTest("Can not test VisitInfo conversion without afw")

        assert_masked_images_equal(self, roundtrip.result, self.visit_image, expect_view=False)
        # Check that the round-tripped headers are the same (up to card order).
        self.assertEqual(
            dict(self.visit_image._opaque_metadata.headers[ExtensionKey()]),
            dict(roundtrip.result._opaque_metadata.headers[ExtensionKey()]),
        )
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("IMAGE")])
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("MASK")])
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("VARIANCE")])
        self.assertEqual(roundtrip.result._opaque_metadata.headers[ExtensionKey()]["PLATFORM"], "lsstcam")
        with self.subTest():
            legacy_exposure = self._require_legacy_exposure()
            compare_visit_image_to_legacy(
                self,
                roundtrip.result,
                legacy_exposure,
                expect_view=False,
                plane_map=self.plane_map,
                **DP2_VISIT_DETECTOR_DATA_ID,
                alternates=alternates,
            )
            # Check converting from the legacy object in-memory.
            compare_visit_image_to_legacy(
                self,
                VisitImage.from_legacy(legacy_exposure, plane_map=self.plane_map),
                legacy_exposure,
                expect_view=True,
                plane_map=self.plane_map,
                **DP2_VISIT_DETECTOR_DATA_ID,
            )

    def test_butler_converters(self) -> None:
        """Test that we can read a VisitImage and its components from a butler
        dataset written as an `lsst.afw.image.Exposure`.
        """
        legacy_exposure = self._require_legacy_exposure()
        with TemporaryButler(legacy="ExposureF") as helper:
            from lsst.daf.butler import FileDataset

            helper.butler.ingest(FileDataset(path=self.filename, refs=[helper.legacy]), transfer="symlink")
            # Read the legacy Exposure dataset back as a VisitImage.
            visit_image_ref = helper.legacy.overrideStorageClass("VisitImage")
            visit_image = helper.butler.get(visit_image_ref)
            bbox = helper.butler.get(visit_image_ref.makeComponentRef("bbox"))
            self.assertEqual(bbox, visit_image.bbox)
            alternates = {
                k: helper.butler.get(visit_image_ref.makeComponentRef(k))
                # TODO: including "projection" or "obs_info" here fails because
                # there's code in daf_butler that expects any component to be
                # valid for the *internal* storage class, not the requested
                # one, and that's difficult to fix because it's tied up with
                # the data ID standardization logic.
                for k in ["image", "mask", "variance", "psf"]
            }
            compare_visit_image_to_legacy(
                self,
                visit_image,
                legacy_exposure,
                expect_view=False,
                plane_map=self.plane_map,
                alternates=alternates,
                **DP2_VISIT_DETECTOR_DATA_ID,
            )

437 

438 

if __name__ == "__main__":
    # Allow running this module directly as a script; unittest discovers
    # the TestCase classes defined above.
    unittest.main()