Coverage for tests/test_visit_image.py: 14%
228 statements
coverage.py v7.13.5, created at 2026-04-28 09:01 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14import os
15import unittest
16import warnings
17from typing import Any
19import astropy.io.fits
20import astropy.units as u
21import astropy.wcs
22import numpy as np
23from astro_metadata_translator import ObservationInfo
25from lsst.images import (
26 Box,
27 DetectorFrame,
28 Image,
29 MaskPlane,
30 MaskSchema,
31 ObservationSummaryStats,
32 ProjectionAstropyView,
33 TractFrame,
34 VisitImage,
35 get_legacy_visit_image_mask_planes,
36)
37from lsst.images.fits import ExtensionKey, FitsOpaqueMetadata
38from lsst.images.psfs import GaussianPointSpreadFunction, PointSpreadFunction
39from lsst.images.tests import (
40 DP2_VISIT_DETECTOR_DATA_ID,
41 RoundtripFits,
42 TemporaryButler,
43 assert_masked_images_equal,
44 assert_projections_equal,
45 compare_visit_image_to_legacy,
46 make_random_projection,
47)
# Root of the optional test-data checkout; None when the variable is unset,
# in which case tests that need on-disk data are skipped.
DATA_DIR = os.environ.get("TESTDATA_IMAGES_DIR")
class VisitImageTestCase(unittest.TestCase):
    """Basic tests for VisitImage that need no external test data."""

    @classmethod
    def setUpClass(cls) -> None:
        """Build the shared fixtures.

        The fixtures are a random projection, a mask schema, observation
        metadata, summary statistics, a Gaussian PSF, and two `VisitImage`
        instances. ``visit_image`` is fully populated and carries opaque
        FITS metadata. ``simplest_visit_image`` has only the required
        arguments.
        """
        cls.rng = np.random.default_rng(500)
        det_frame = DetectorFrame(instrument="Inst", visit=1234, detector=1, bbox=Box.factory[1:4096, 1:4096])
        cls.projection = make_random_projection(cls.rng, det_frame, Box.factory[1:4096, 1:4096])
        cls.mask_schema = MaskSchema([MaskPlane("M1", "D1")])
        cls.obs_info = ObservationInfo(instrument="LSSTCam", detector_num=4)
        cls.summary_stats = ObservationSummaryStats(psfSigma=2.5, zeroPoint=31.4)
        # One PSF instance is shared by every VisitImage built in this class;
        # the tests below already pass it to several constructors.
        cls.gaussian_psf = GaussianPointSpreadFunction(2.5, stamp_size=33, bounds=Box.factory[-10:10, -12:13])

        opaque = FitsOpaqueMetadata()
        hdr = astropy.io.fits.Header()
        with warnings.catch_warnings():
            # Silence warnings about long keys becoming HIERARCH.
            warnings.simplefilter("ignore", category=astropy.io.fits.verify.VerifyWarning)
            hdr.update({"PLATFORM": "lsstcam", "LSST BUTLER ID": "123456789"})
        opaque.extract_legacy_primary_header(hdr)

        cls.image = Image(42, shape=(1024, 1024), unit=u.nJy)
        cls.variance = Image(5.0, shape=(1024, 1024), unit=u.nJy * u.nJy)
        # API signature suggests projection and obs_info can be None but they
        # are required.
        cls.visit_image = VisitImage(
            cls.image,
            variance=cls.variance,
            psf=cls.gaussian_psf,
            mask_schema=cls.mask_schema,
            projection=cls.projection,
            obs_info=cls.obs_info,
            summary_stats=cls.summary_stats,
        )
        cls.visit_image._opaque_metadata = opaque
        cls.simplest_visit_image = VisitImage(
            cls.image,
            psf=cls.gaussian_psf,
            mask_schema=cls.mask_schema,
            projection=cls.projection,
            obs_info=cls.obs_info,
        )

    def test_basics(self) -> None:
        """Test basic constructor patterns and the argument checks."""
        # Test default fill of variance.
        visit = self.simplest_visit_image
        self.assertEqual(visit.variance.array[0, 0], 1.0)
        self.assertIs(visit[...], visit)
        self.assertEqual(str(visit), "VisitImage(Image([y=0:1024, x=0:1024], int64), ['M1'])")
        self.assertEqual(
            repr(visit),
            "VisitImage(Image(..., bbox=Box(y=Interval(start=0, stop=1024), x=Interval(start=0, stop=1024)),"
            " dtype=dtype('int64')), mask_schema=MaskSchema([MaskPlane(name='M1', description='D1')],"
            " dtype=dtype('uint8')))",
        )

        astropy_wcs = visit.astropy_wcs
        self.assertIsInstance(astropy_wcs, ProjectionAstropyView)
        approx_wcs = visit.fits_wcs
        self.assertIsInstance(approx_wcs, astropy.wcs.WCS)

        # Check that it is a deep copy.
        copy = visit.copy()
        copy.image.array[0, 0] = 30.0
        self.assertEqual(visit.image.array[0, 0], 42.0)
        self.assertEqual(copy.image.array[0, 0], 30.0)
        # Check that summary stats survives a slice and a copy.
        self.assertEqual(copy.summary_stats, visit.summary_stats)
        self.assertEqual(visit[Box.factory[0:5, 0:5]].summary_stats, visit.summary_stats)

        with self.assertRaises(TypeError):
            # Requires a PSF.
            VisitImage(
                self.image,
                mask_schema=self.mask_schema,
                projection=self.projection,
                obs_info=self.obs_info,
            )

        with self.assertRaises(TypeError):
            # Requires ObservationInfo.
            VisitImage(
                self.image,
                psf=self.gaussian_psf,
                mask_schema=self.mask_schema,
                projection=self.projection,
            )

        with self.assertRaises(TypeError):
            # Requires a projection.
            VisitImage(
                self.image,
                psf=self.gaussian_psf,
                mask_schema=self.mask_schema,
                obs_info=self.obs_info,
            )

        with self.assertRaises(TypeError):
            # Requires some form of mask.
            VisitImage(
                self.image,
                psf=self.gaussian_psf,
                projection=self.projection,
                obs_info=self.obs_info,
            )

        with self.assertRaises(TypeError):
            # NOTE(review): this Image has no unit and a small 5x5 shape;
            # presumably the missing unit is what is rejected -- confirm
            # which check raises.
            VisitImage(
                Image(42, shape=(5, 5)),
                psf=self.gaussian_psf,
                mask_schema=self.mask_schema,
                projection=self.projection,
                obs_info=self.obs_info,
            )

        # Requires a DetectorFrame.
        tract_frame = TractFrame(skymap="Skymap", tract=1, bbox=Box.factory[1:10, 1:10])
        tract_proj = make_random_projection(self.rng, tract_frame, Box.factory[1:4096, 1:4096])
        with self.assertRaises(TypeError):
            VisitImage(
                self.image,
                projection=tract_proj,
                psf=self.gaussian_psf,
                mask_schema=self.mask_schema,
                obs_info=self.obs_info,
            )

        # Variance unit mismatch: the image (nJy) is passed as the variance,
        # whose unit must be nJy**2.
        with self.assertRaises(ValueError):
            VisitImage(
                self.image,
                variance=self.image,
                psf=self.gaussian_psf,
                mask_schema=self.mask_schema,
                projection=self.projection,
                obs_info=self.obs_info,
            )

    def test_obs_info(self) -> None:
        """Check that ObservationInfo has been constructed."""
        visit = self.visit_image
        self.assertIsNotNone(visit.obs_info)
        self.maxDiff = None
        assert visit.obs_info is not None  # for mypy.
        self.assertEqual(visit.obs_info.instrument, "LSSTCam")

    def test_read_write(self) -> None:
        """Test that a visit can round trip through a FITS file."""
        with RoundtripFits(self, self.visit_image, "VisitImage") as roundtrip:
            # Check that we're still using the right compression, and that we
            # wrote WCSs.
            fits = roundtrip.inspect()
            self.assertEqual(fits[1].header["ZCMPTYPE"], "GZIP_2")
            self.assertEqual(fits[1].header["CTYPE1"], "RA---TAN")
            self.assertEqual(fits[2].header["ZCMPTYPE"], "GZIP_2")
            self.assertEqual(fits[2].header["CTYPE1"], "RA---TAN")
            self.assertEqual(fits[3].header["ZCMPTYPE"], "GZIP_2")
            self.assertEqual(fits[3].header["CTYPE1"], "RA---TAN")
            # Check a subimage read.
            subbox = Box.factory[8:13, 9:30]
            subimage = roundtrip.get(bbox=subbox)
            assert_masked_images_equal(self, subimage, self.visit_image[subbox], expect_view=False)
            with self.subTest():
                self.assertEqual(roundtrip.get("bbox"), self.visit_image.bbox)
            with self.subTest():
                obs_info = roundtrip.get("obs_info")
                self.assertIsInstance(obs_info, ObservationInfo)
                self.assertEqual(obs_info, self.visit_image.obs_info)

        assert_masked_images_equal(self, roundtrip.result, self.visit_image, expect_view=False)
        # Check that the round-tripped headers are the same (up to card order).
        # Only one card is expected in the primary header: the fixture's
        # "LSST BUTLER ID" key is not expected to be kept.
        self.assertEqual(len(roundtrip.result._opaque_metadata.headers[ExtensionKey()]), 1)
        self.assertEqual(
            dict(self.visit_image._opaque_metadata.headers[ExtensionKey()]),
            dict(roundtrip.result._opaque_metadata.headers[ExtensionKey()]),
        )
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("IMAGE")])
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("MASK")])
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("VARIANCE")])
        self.assertEqual(roundtrip.result.obs_info, self.visit_image.obs_info)
        self.assertIsNotNone(roundtrip.result.summary_stats)
        self.assertEqual(
            roundtrip.result.summary_stats.psfSigma,
            self.visit_image.summary_stats.psfSigma,
        )
        self.assertEqual(
            roundtrip.result.summary_stats.zeroPoint,
            self.visit_image.summary_stats.zeroPoint,
        )
@unittest.skipUnless(DATA_DIR is not None, "TESTDATA_IMAGES_DIR is not in the environment.")
class VisitImageLegacyTestCase(unittest.TestCase):
    """Tests for the VisitImage class and the basics of the archive system.

    Requires the legacy test data under ``TESTDATA_IMAGES_DIR``. Tests, or
    subtests, that need an in-memory `lsst.afw.image.Exposure` are skipped
    when `lsst.afw.image` cannot be imported.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Read the legacy visit image with the new API.

        The same file is also read as a legacy Exposure when afw is
        importable.
        """
        assert DATA_DIR is not None, "Guaranteed by decorator."
        cls.filename = os.path.join(DATA_DIR, "dp2", "legacy", "visit_image.fits")
        cls.plane_map = plane_map = get_legacy_visit_image_mask_planes()
        cls.visit_image = VisitImage.read_legacy(
            cls.filename, preserve_quantization=True, plane_map=plane_map
        )
        # Stays None when afw is unavailable; tests that need it skip.
        cls.legacy_exposure: Any = None
        try:
            from lsst.afw.image import ExposureFitsReader
        except ImportError:
            pass
        else:
            # Keep the try narrow: only a failed import is tolerated; read
            # errors from a valid afw install should propagate.
            cls.legacy_exposure = ExposureFitsReader(cls.filename).read()

    def _require_legacy_exposure(self) -> Any:
        """Return the legacy Exposure, skipping the test if afw is missing.

        Raises
        ------
        unittest.SkipTest
            Raised if `lsst.afw.image` could not be imported in
            `setUpClass`.
        """
        if self.legacy_exposure is None:
            raise unittest.SkipTest("'lsst.afw.image' could not be imported.")
        return self.legacy_exposure

    def test_legacy_errors(self) -> None:
        """Legacy read failure modes."""
        # Reading directly from the file does not need afw.
        with self.assertRaises(ValueError):
            VisitImage.read_legacy(self.filename, instrument="HSC")
        with self.assertRaises(ValueError):
            VisitImage.read_legacy(self.filename, visit=123456)

        # Converting the in-memory legacy object does need afw.
        legacy_exposure = self._require_legacy_exposure()
        with self.assertRaises(ValueError):
            VisitImage.from_legacy(legacy_exposure, instrument="HSC")
        with self.assertRaises(ValueError):
            VisitImage.from_legacy(legacy_exposure, visit=123456)
        with self.assertRaises(ValueError):
            VisitImage.from_legacy(legacy_exposure, unit=u.mJy)
        visit = VisitImage.from_legacy(
            legacy_exposure, instrument="LSSTCam", unit=u.nJy, visit=2025052000177
        )
        self.assertEqual(visit.unit, u.nJy)

    def test_component_reads(self) -> None:
        """Test reads of components from legacy file."""
        visit = VisitImage.read_legacy(self.filename)
        proj = VisitImage.read_legacy(self.filename, component="projection")
        assert_projections_equal(self, proj, visit.projection, expect_identity=False)
        image = VisitImage.read_legacy(self.filename, component="image")
        self.assertEqual(image, visit.image)
        variance = VisitImage.read_legacy(self.filename, component="variance")
        self.assertEqual(variance, visit.variance)
        mask = VisitImage.read_legacy(self.filename, component="mask")
        self.assertEqual(mask, visit.mask)
        psf = VisitImage.read_legacy(self.filename, component="psf")
        self.assertIsInstance(psf, PointSpreadFunction)
        obs_info = VisitImage.read_legacy(self.filename, component="obs_info")
        self.assertIsInstance(obs_info, ObservationInfo)
        self.assertEqual(obs_info.instrument, "LSSTCam")
        self.assertEqual(obs_info.detector_num, 85, obs_info)
        self.assertEqual(obs_info.detector_unique_name, "R21_S11", obs_info)
        self.assertEqual(obs_info.physical_filter, "r_57", obs_info)
        summary_stats = VisitImage.read_legacy(self.filename, component="summary_stats")
        self.assertIsInstance(summary_stats, ObservationSummaryStats)
        self.assertEqual(summary_stats.nPsfStar, 93)

    def test_obs_info(self) -> None:
        """Check that ObservationInfo has been constructed."""
        legacy_exposure = self._require_legacy_exposure()
        legacy = VisitImage.from_legacy(legacy_exposure, plane_map=self.plane_map)
        self.assertIsNotNone(legacy.obs_info)
        self.maxDiff = None
        self.assertEqual(legacy.obs_info, self.visit_image.obs_info)
        assert legacy.obs_info is not None  # for mypy.
        self.assertEqual(legacy.obs_info.instrument, "LSSTCam")
        self.assertEqual(legacy.obs_info.detector_num, 85, legacy.obs_info)
        self.assertEqual(legacy.obs_info.detector_unique_name, "R21_S11", legacy.obs_info)
        self.assertEqual(legacy.obs_info.physical_filter, "r_57", legacy.obs_info)

    def test_read_legacy_headers(self) -> None:
        """Test that headers were correctly stripped and interpreted in
        `VisitImage.read_legacy`.
        """
        # Check that we read the units from BUNIT.
        self.assertEqual(self.visit_image.unit, astropy.units.nJy)
        # Check that the primary header has the keys we want, and none of the
        # keys we don't want.
        header = self.visit_image._opaque_metadata.headers[ExtensionKey()]
        self.assertIn("EXPTIME", header)
        self.assertEqual(header["PLATFORM"], "lsstcam")
        self.assertNotIn("LSST BUTLER ID", header)
        self.assertNotIn("AR HDU", header)
        self.assertNotIn("A_ORDER", header)
        # Check that the extension HDUs do not have any custom headers.
        self.assertFalse(self.visit_image._opaque_metadata.headers[ExtensionKey("IMAGE")])
        self.assertFalse(self.visit_image._opaque_metadata.headers[ExtensionKey("MASK")])
        self.assertFalse(self.visit_image._opaque_metadata.headers[ExtensionKey("VARIANCE")])

    def test_from_legacy_headers(self) -> None:
        """Test that from_legacy handles headers properly."""
        legacy_exposure = self._require_legacy_exposure()
        legacy = VisitImage.from_legacy(legacy_exposure, plane_map=self.plane_map)
        header = legacy._opaque_metadata.headers[ExtensionKey()]
        self.assertIn("EXPTIME", header)
        self.assertEqual(header["PLATFORM"], "lsstcam")
        self.assertNotIn("LSST BUTLER ID", header)
        self.assertNotIn("AR HDU", header)
        self.assertNotIn("A_ORDER", header)
        # Check that the extension HDUs of the converted object (not the
        # file-read fixture) do not have any custom headers.
        self.assertFalse(legacy._opaque_metadata.headers[ExtensionKey("IMAGE")])
        self.assertFalse(legacy._opaque_metadata.headers[ExtensionKey("MASK")])
        self.assertFalse(legacy._opaque_metadata.headers[ExtensionKey("VARIANCE")])

    def test_rewrite(self) -> None:
        """Test that we can rewrite the visit image and preserve both
        lossy-compressed pixel values and components exactly.
        """
        with RoundtripFits(self, self.visit_image, "VisitImage") as roundtrip:
            # Check that we're still using the right compression, and that we
            # wrote WCSs.
            fits = roundtrip.inspect()
            self.assertEqual(fits[1].header["ZCMPTYPE"], "RICE_1")
            self.assertEqual(fits[1].header["CTYPE1"], "RA---TAN-SIP")
            self.assertEqual(fits[2].header["ZCMPTYPE"], "GZIP_2")
            self.assertEqual(fits[2].header["CTYPE1"], "RA---TAN-SIP")
            self.assertEqual(fits[3].header["ZCMPTYPE"], "RICE_1")
            self.assertEqual(fits[3].header["CTYPE1"], "RA---TAN-SIP")
            # Check a subimage read.
            subbox = Box.factory[8:13, 9:30]
            subimage = roundtrip.get(bbox=subbox)
            assert_masked_images_equal(self, subimage, self.visit_image[subbox], expect_view=False)
            # Component reads, later compared against the legacy object; left
            # empty if this subtest fails.
            alternates: dict[str, Any] = {}
            with self.subTest():
                self.assertEqual(roundtrip.get("bbox"), self.visit_image.bbox)
                alternates = {
                    k: roundtrip.get(k)
                    for k in ["projection", "image", "mask", "variance", "psf", "obs_info", "summary_stats"]
                }
            # Try to do a butler get of a component with storage class
            # override.
            with self.subTest():
                if self.legacy_exposure is None:
                    raise unittest.SkipTest("Can not test VisitInfo conversion without afw")
                import lsst.afw.image

                # We have VisitInfo available.
                visit_info = roundtrip.get("obs_info", storageClass="VisitInfo")
                self.assertIsInstance(visit_info, lsst.afw.image.VisitInfo)
                self.assertEqual(visit_info.getInstrumentLabel(), "LSSTCam")

        assert_masked_images_equal(self, roundtrip.result, self.visit_image, expect_view=False)
        # Check that the round-tripped headers are the same (up to card order).
        self.assertEqual(
            dict(self.visit_image._opaque_metadata.headers[ExtensionKey()]),
            dict(roundtrip.result._opaque_metadata.headers[ExtensionKey()]),
        )
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("IMAGE")])
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("MASK")])
        self.assertFalse(roundtrip.result._opaque_metadata.headers[ExtensionKey("VARIANCE")])
        self.assertEqual(roundtrip.result._opaque_metadata.headers[ExtensionKey()]["PLATFORM"], "lsstcam")
        with self.subTest():
            legacy_exposure = self._require_legacy_exposure()
            compare_visit_image_to_legacy(
                self,
                roundtrip.result,
                legacy_exposure,
                expect_view=False,
                plane_map=self.plane_map,
                **DP2_VISIT_DETECTOR_DATA_ID,
                alternates=alternates,
            )
            # Check converting from the legacy object in-memory.
            compare_visit_image_to_legacy(
                self,
                VisitImage.from_legacy(legacy_exposure, plane_map=self.plane_map),
                legacy_exposure,
                expect_view=True,
                plane_map=self.plane_map,
                **DP2_VISIT_DETECTOR_DATA_ID,
            )

    def test_butler_converters(self) -> None:
        """Test that we can read a VisitImage and its components from a butler
        dataset written as an `lsst.afw.image.Exposure`.
        """
        legacy_exposure = self._require_legacy_exposure()
        with TemporaryButler(legacy="ExposureF") as helper:
            from lsst.daf.butler import FileDataset

            helper.butler.ingest(FileDataset(path=self.filename, refs=[helper.legacy]), transfer="symlink")
            visit_image_ref = helper.legacy.overrideStorageClass("VisitImage")
            visit_image = helper.butler.get(visit_image_ref)
            bbox = helper.butler.get(visit_image_ref.makeComponentRef("bbox"))
            self.assertEqual(bbox, visit_image.bbox)
            alternates = {
                k: helper.butler.get(visit_image_ref.makeComponentRef(k))
                # TODO: including "projection" or "obs_info" here fails because
                # there's code in daf_butler that expects any component to be
                # valid for the *internal* storage class, not the requested
                # one, and that's difficult to fix because it's tied up with
                # the data ID standardization logic.
                for k in ["image", "mask", "variance", "psf"]
            }
            compare_visit_image_to_legacy(
                self,
                visit_image,
                legacy_exposure,
                expect_view=False,
                plane_map=self.plane_map,
                alternates=alternates,
                **DP2_VISIT_DETECTOR_DATA_ID,
            )
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()