Coverage for tests/test_image.py: 6%
371 statements
coverage.py v7.5.1, created at 2024-05-16 02:46 -0700
# This file is part of lsst.scarlet.lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import operator
24import numpy as np
25from lsst.scarlet.lite import Box, Image
26from lsst.scarlet.lite.image import MismatchedBandsError, MismatchedBoxError
27from numpy.testing import assert_almost_equal, assert_array_equal
28from utils import ScarletTestCase
class TestImage(ScarletTestCase):
    """Tests for `lsst.scarlet.lite.Image`.

    Covers construction, arithmetic and comparison operators (including
    images with mismatched bands and bounding boxes), slicing, overlap
    detection, insertion, projection and band repetition.
    """

    # Names of the `operator` comparison functions. They are tested without
    # reflected (``__r<op>__``) calls, and the `operator` module has no
    # in-place variant of them (there is no ``operator.ieq``).
    _COMPARISON_OPS = ("eq", "ne", "ge", "le", "gt", "lt")

    def test_constructors(self):
        """Test the constructors, ``from_box``, band validation and ``str``."""
        # Default constructor
        data = np.arange(12).reshape(3, 4)  # type: ignore
        image = Image(data)
        self.assertEqual(image.dtype, int)
        self.assertTupleEqual(image.bands, ())
        self.assertEqual(image.n_bands, 0)
        assert_array_equal(image.shape, (3, 4))
        self.assertEqual(image.height, 3)
        self.assertEqual(image.width, 4)
        assert_array_equal(image.yx0, (0, 0))
        self.assertEqual(image.y0, 0)
        self.assertEqual(image.x0, 0)
        self.assertBoxEqual(image.bbox, Box((3, 4), (0, 0)))
        assert_array_equal(image.data, data)
        self.assertIsInstance(image.data, np.ndarray)
        self.assertNotIsInstance(image.data, Image)

        # Test constructor with all parameters
        data = np.arange(24, dtype=float).reshape(2, 3, 4)  # type: ignore
        bands = ("g", "i")
        y0, x0 = 10, 15
        image = Image(
            data,
            bands=bands,
            yx0=(y0, x0),
        )
        self.assertEqual(image.dtype, float)
        assert_array_equal(image.bands, bands)
        self.assertEqual(image.n_bands, 2)
        assert_array_equal(image.shape, (2, 3, 4))
        self.assertEqual(image.height, 3)
        self.assertEqual(image.width, 4)
        assert_array_equal(image.yx0, (10, 15))
        self.assertEqual(image.y0, 10)
        self.assertEqual(image.x0, 15)
        self.assertBoxEqual(image.bbox, Box((3, 4), (10, 15)))
        assert_array_equal(image.data, data)
        self.assertIsInstance(image.data, np.ndarray)
        self.assertNotIsInstance(image.data, Image)

        # test initializing an empty image from a bounding box
        image = Image.from_box(Box((10, 10), (13, 50)))
        self.assertImageEqual(image, Image(np.zeros((10, 10), dtype=float), bands=(), yx0=(13, 50)))
        bands = ("g", "r", "i")
        image = Image.from_box(Box((10, 10), (13, 50)), bands=bands)
        self.assertImageEqual(image, Image(np.zeros((3, 10, 10), dtype=float), bands=bands, yx0=(13, 50)))

        # Three data planes with only two band names must be rejected
        with self.assertRaises(ValueError):
            Image(np.zeros((3, 4, 5)), bands=tuple("gr"))

        # String representation
        truth = "Image:\n [[[0 1 2]\n [3 4 5]]]\n bands=('g',)\n bbox=Box(shape=(2, 3), origin=(3, 2))"
        data = np.arange(6).reshape(1, 2, 3)
        bands = tuple("g")
        yx0 = (3, 2)
        image = Image(data, bands=bands, yx0=yx0)
        self.assertEqual(str(image), truth)

    def _binary_operation_test(
        self,
        lower_data: np.ndarray,
        higher_data: np.ndarray,
        lower_image: Image,
        higher_image: Image,
        op_name: str,
    ) -> None:
        """Compare an `Image` binary operation with the same numpy operation.

        Three forms are checked: image-with-constant (forward and, where
        applicable, reflected), image-with-image (forward and, for
        non-comparisons, reflected), and in-place image-with-image.

        Parameters
        ----------
        lower_data, higher_data:
            The raw arrays wrapped by ``lower_image`` and ``higher_image``.
        lower_image, higher_image:
            The images under test. Both are copied before the in-place
            checks. The final ``lower_image op= higher_image`` call operates
            on the originals and is expected to raise `ValueError`.
        op_name:
            Name of the function in the `operator` module to test,
            e.g. ``"add"``, ``"truediv"`` or ``"eq"``.
        """
        lower = lower_image.copy()
        higher = higher_image.copy()
        op = getattr(operator, op_name)
        is_comparison = op_name in self._COMPARISON_OPS

        # Test operation with constants
        for constant in (3, 3.14, 3.14 + 3j):
            if op_name in ("floordiv", "mod", "rshift", "lshift") and constant != 3:
                # Cannot use floats or complex numbers for some operations,
                # so skip them
                continue
            truth = op(lower_data, constant)
            truth_image = Image(truth, bands=lower.bands)
            result = op(lower, constant)
            assert_array_equal(result.data, truth)
            self.assertImageEqual(result, truth_image)

            # The reflected power is only checked with the float constant.
            if not is_comparison and (op_name != "pow" or constant == 3.14):
                truth = op(constant, lower_data)
                truth_image = Image(truth, bands=lower.bands)
                result = getattr(lower, f"__r{op_name}__")(constant)
                assert_array_equal(result.data, truth)
                self.assertImageEqual(result, truth_image)

        if op_name in ("rshift", "lshift"):
            # Shifting cannot be done with non-integer arrays
            return

        # Test lower * higher
        truth = op(lower_data, higher_data)
        truth_image = Image(truth, bands=higher_image.bands)
        result = op(lower, higher)
        assert_array_equal(result.data, truth)
        self.assertImageEqual(result, truth_image)

        if is_comparison:
            return

        # The reflected call on ``higher`` must give the same result.
        result = getattr(higher, f"__r{op_name}__")(lower)
        assert_array_equal(result.data, truth)
        self.assertImageEqual(result, truth_image)

        # In-place operation on a copy of ``higher``
        truth = op(higher_data, lower_data)
        truth_image = Image(truth, bands=higher_image.bands)
        iop = getattr(operator, "i" + op_name)
        iop(higher, lower)
        assert_array_equal(higher.data, truth)
        self.assertImageEqual(higher, truth_image)

        with self.assertRaises(ValueError):
            iop(lower_image, higher_image)

    def check_simple_arithmetic(self, data_bool, data_int, data_float, bands):
        """Run the arithmetic/comparison checks on same-box, same-band images.

        ``data_bool``, ``data_int`` and ``data_float`` must have bool, int
        and float dtypes, respectively; ``bands`` is passed to every `Image`.
        Also checks unary ``-``/``+`` and that ``@``/``@=`` raise `TypeError`.
        """
        image_bool = Image(data_bool, bands=bands)
        image_int = Image(data_int, bands=bands)
        image_float = Image(data_float, bands=bands)

        self.assertEqual(data_bool.dtype, bool)
        self.assertEqual(data_int.dtype, int)
        self.assertEqual(data_float.dtype, float)

        # test casting for bool + int
        self._binary_operation_test(
            data_bool,
            data_int,
            image_bool,
            image_int,
            "add",
        )

        # Test binary operations
        binary_operations = (
            "add",
            "sub",
            "mul",
            "truediv",
            "floordiv",
            "pow",
            "mod",
            "eq",
            "ne",
            "ge",
            "le",
            "gt",
            "lt",
            "rshift",
            "lshift",
        )
        for op_name in binary_operations:
            if op_name == "pow":
                # Use strictly positive operands (absolute values, with zeros
                # replaced by 1) for the power tests.
                _data_int = np.abs(data_int)
                _data_int[_data_int == 0] = 1
                _image_int = image_int.copy()
                _image_int.data[:] = _data_int
                _data_float = np.abs(data_float)
                _data_float[_data_float == 0] = 1
                _image_float = image_float.copy()
                _image_float.data[:] = _data_float
            else:
                _data_float = data_float
                _image_float = image_float
                _data_int = data_int
                _image_int = image_int
            self._binary_operation_test(
                _data_int,
                _data_float,
                _image_int,
                _image_float,
                op_name,
            )

        # Test negation
        self.assertImageEqual(-image_float, Image(-data_float, bands=bands))  # type: ignore
        # Test unary positive operator
        self.assertImageEqual(+image_float, image_float)

        # Test that matrix multiplication is not supported
        with self.assertRaises(TypeError):
            image_int @ image_float

        with self.assertRaises(TypeError):
            image_int @= image_float

    def test_simple_3d_arithmetic(self):
        """Arithmetic on multi-band (2, 3, 4) images."""
        np.random.seed(1)
        data_bool = np.random.choice((True, False), size=(2, 3, 4))
        data_int = np.random.randint(-10, 10, (2, 3, 4))
        # Avoid integer division by zero
        data_int[data_int == 0] = 1
        data_float = (np.random.random((2, 3, 4)) - 0.5) * 10
        self.check_simple_arithmetic(data_bool, data_int, data_float, bands=("g", "r"))

    def test_simple_2d_arithmetic(self):
        """Arithmetic on single-plane (3, 4) images without bands."""
        np.random.seed(1)
        data_bool = np.random.choice((True, False), size=(3, 4))
        data_int = np.random.randint(-10, 10, (3, 4))
        # Avoid integer division by zero
        data_int[data_int == 0] = 1
        data_float = (np.random.random((3, 4)) - 0.5) * 10
        self.check_simple_arithmetic(data_bool, data_int, data_float, bands=None)

    def test_3d_image_equality(self):
        """Check the exceptions raised by ``==``/``!=`` on multi-band images."""
        # Note: equality of the arrays is tested in other tests.
        # This just checks that comparing non-images to images,
        # or images with different bounding boxes or bands raises
        # the appropriate exception.
        np.random.seed(1)
        bands = ("g", "r")
        data1 = np.random.randint(-10, 10, (2, 3, 4))
        data2 = data1.astype(float)
        data3 = np.random.randint(-10, 10, (2, 3, 4)).astype(float)

        image1 = Image(data1, bands=bands)
        image2 = Image(data2, bands=bands)
        image3 = Image(data3, bands=bands)

        for op in (operator.eq, operator.ne):
            with self.assertRaises(TypeError):
                op(image1, data1)
            with self.assertRaises(MismatchedBandsError):
                op(image1, image2.copy_with(bands=("g", "i")))
            with self.assertRaises(MismatchedBandsError):
                op(image1, image3.copy_with(bands=("g", "i")))
            with self.assertRaises(MismatchedBoxError):
                op(image1, image2.copy_with(yx0=(30, 35)))
            with self.assertRaises(MismatchedBoxError):
                op(image1, image3.copy_with(yx0=(30, 35)))

    def test_2d_image_equality(self):
        """Check the exceptions raised by ``==``/``!=`` on 2D images."""
        # Note: equality of the arrays is tested in other tests.
        # This just checks that comparing non-images to images,
        # or images with different bounding boxes or bands raises
        # the appropriate exception.
        np.random.seed(1)
        data1 = np.random.randint(-10, 10, (3, 4))
        data2 = data1.astype(float)
        data3 = np.random.randint(-10, 10, (3, 4)).astype(float)

        image1 = Image(data1)
        image2 = Image(data2)
        image3 = Image(data3)

        for op in (operator.eq, operator.ne):
            with self.assertRaises(TypeError):
                op(image1, data1)
            with self.assertRaises(MismatchedBoxError):
                op(image1, image2.copy_with(yx0=(30, 35)))
            with self.assertRaises(MismatchedBoxError):
                op(image1, image3.copy_with(yx0=(30, 35)))

    def test_simple_boolean_arithmetic(self):
        """Check ``&``, ``|``, ``^`` (forward, reflected, in-place) and ``~``."""
        np.random.seed(1)
        # Test boolean operations
        boolean_operations = (
            "and_",
            "or_",
            "xor",
        )
        data1 = np.random.choice((True, False), size=(2, 3, 4))
        data2 = np.random.choice((True, False), size=(2, 3, 4))
        _image1 = Image(data1, bands=("g", "i"))
        _image2 = Image(data2, bands=("g", "i"))
        for op_name in boolean_operations:
            image1 = _image1.copy()
            image2 = _image2.copy()
            op = getattr(operator, op_name)
            result = op(image1, image2)
            data_result = op(data1, data2)
            self.assertImageEqual(result, Image(data_result, bands=("g", "i")))

            # `operator.and_`/`or_` carry a trailing underscore that the
            # dunder (``__rand__``) and in-place (``iand``) names do not.
            base_name = op_name.rstrip("_")

            result = getattr(image2, f"__r{base_name}__")(image1)
            self.assertImageEqual(result, Image(data_result, bands=("g", "i")))

            iop = getattr(operator, "i" + base_name)
            iop(image1, image2)
            self.assertImageEqual(image1, Image(data_result, bands=("g", "i")))

        # Test inversion
        self.assertImageEqual(~_image1, Image(~data1, bands=("g", "i")))

    def _3d_mismatched_images_test(
        self,
        op_name: str,
    ):
        """Check ``op_name`` on multi-band images with mismatched bands/boxes.

        ``op_name`` is the name of a function in the `operator` module.
        """
        np.random.seed(1)
        op = getattr(operator, op_name)
        grizy = ("g", "r", "i", "z", "y")
        gir = ("g", "i", "r")
        igy = ("i", "g", "y")

        # Test band insert: image2's bands are a subset of image1's, so the
        # result keeps image1's bands and only the (g, i, r) planes change.
        # NOTE: `operator` spells subtraction "sub"; a check against
        # "subtract" never matched, silently skipping this case for `sub`.
        if op_name in ("add", "sub"):
            data1 = (np.random.random((5, 3, 4)) - 0.5) * 10
            data2 = (np.random.random((3, 3, 4)) - 0.5) * 10
            image1 = Image(data1, bands=grizy)
            image2 = Image(data2, bands=gir)
            result = op(image1, image2)
            truth = np.zeros((5, 3, 4), dtype=float)
            truth += data1
            # Indices of g, i, r within grizy
            truth[(0, 2, 1), :, :] = op(truth[(0, 2, 1), :, :], data2)
            assert_almost_equal(result.data, truth)
            self.assertImageEqual(result, Image(truth, bands=grizy))

        # Test band mixture
        if op_name == "pow":
            data1 = np.random.random((3, 3, 4)) + 1
            data2 = np.random.random((3, 3, 4)) + 1
        else:
            data1 = (np.random.random((3, 3, 4)) - 0.5) * 10
            data2 = (np.random.random((3, 3, 4)) - 0.5) * 10
        image1 = Image(data1, bands=gir)
        image2 = Image(data2, bands=igy)
        result = op(image1, image2)
        # Expected band order of the result is (g, i, r, y)
        truth = np.zeros((4, 3, 4), dtype=float)
        truth[(0, 1, 2), :, :] = data1
        truth[(1, 0, 3), :, :] = op(truth[(1, 0, 3), :, :], data2)
        assert_almost_equal(result.data, truth)
        self.assertImageEqual(result, Image(truth, bands=("g", "i", "r", "y")))

        # Test spatial offsets
        if op_name == "pow":
            data1 = np.random.random((3, 3, 4)) + 1
            data2 = np.random.random((3, 3, 4)) + 1
        else:
            data1 = (np.random.random((3, 3, 4)) - 0.5) * 10
            data2 = (np.random.random((3, 3, 4)) - 0.5) * 10
        image1 = Image(data1, bands=gir, yx0=(10, 20))
        image2 = Image(data2, bands=gir, yx0=(11, 17))
        result = op(image1, image2)

        # Both inputs zero-padded onto the union box: origin (10, 17), shape (4, 7)
        _data1 = np.zeros((3, 4, 7), dtype=float)
        _data2 = np.zeros((3, 4, 7), dtype=float)
        _data1[:, :3, 3:] = data1
        _data2[:, 1:, :4] = data2
        # The zero padding makes division/modulo produce inf/nan off-overlap
        with np.errstate(divide="ignore", invalid="ignore"):
            truth = op(_data1, _data2)
        assert_almost_equal(result.data, truth)
        self.assertImageEqual(result, Image(truth, bands=gir, yx0=(10, 17)))

    def _2d_mismatched_images_test(
        self,
        op_name: str,
    ):
        """Check ``op_name`` on 2D images with offset bounding boxes.

        ``op_name`` is the name of a function in the `operator` module.
        """
        np.random.seed(1)
        op = getattr(operator, op_name)

        # Test spatial offsets
        if op_name == "pow":
            data1 = np.random.random((3, 4)) + 1
            data2 = np.random.random((3, 4)) + 1
        else:
            data1 = (np.random.random((3, 4)) - 0.5) * 10
            data2 = (np.random.random((3, 4)) - 0.5) * 10
        image1 = Image(data1, yx0=(10, 20))
        image2 = Image(data2, yx0=(11, 17))
        result = op(image1, image2)

        # Both inputs zero-padded onto the union box: origin (10, 17), shape (4, 7)
        _data1 = np.zeros((4, 7), dtype=float)
        _data2 = np.zeros((4, 7), dtype=float)
        _data1[:3, 3:] = data1
        _data2[1:, :4] = data2
        # The zero padding makes division/modulo produce inf/nan off-overlap
        with np.errstate(divide="ignore", invalid="ignore"):
            truth = op(_data1, _data2)
        assert_almost_equal(result.data, truth)
        self.assertImageEqual(result, Image(truth, yx0=(10, 17)))

    def test_mismatchd_arithmetic(self):
        """Arithmetic between images with mismatched bands and/or boxes."""
        binary_operations = (
            "add",
            "sub",
            "mul",
            "truediv",
            "floordiv",
            "pow",
            "mod",
        )

        for op_name in binary_operations:
            self._3d_mismatched_images_test(op_name)
            self._2d_mismatched_images_test(op_name)

    def test_scalar_arithmetic(self):
        """Reflected bitwise ops with a scalar and in-place shifts."""
        data = np.arange(6).reshape(1, 2, 3)
        bands = tuple("g")
        yx0 = (3, 2)
        image = Image(data, bands=bands, yx0=yx0)
        self.assertImageEqual(2 & image, Image(2 & data, bands=bands, yx0=yx0))
        self.assertImageEqual(2 | image, Image(2 | data, bands=bands, yx0=yx0))
        self.assertImageEqual(2 ^ image, Image(2 ^ data, bands=bands, yx0=yx0))

        # Shifting by a float is not allowed
        with self.assertRaises(TypeError):
            image << 2.0
        with self.assertRaises(TypeError):
            image >> 2.0

        image2 = image.copy()
        image2 <<= 2
        self.assertImageEqual(image2, Image(data << 2, bands=bands, yx0=yx0))

        image2 = image.copy()
        image2 >>= 2
        self.assertImageEqual(image2, Image(data >> 2, bands=bands, yx0=yx0))

    def test_slicing(self):
        """Band and bounding-box indexing, and rejected index forms."""
        # Seed for reproducibility, consistent with the other tests
        np.random.seed(1)
        bands = ("g", "r", "i", "z", "y")
        yx0 = (27, 82)
        data = (np.random.random((5, 30, 40)) - 0.5) * 10
        image = Image(data, bands=bands, yx0=yx0)
        image_2d = Image(data[0], yx0=yx0)

        # test band slicing
        sub_img = image["g"]
        self.assertImageEqual(sub_img, Image(data[0], yx0=yx0))

        sub_img = image[:"g"]
        self.assertImageEqual(sub_img, Image(data[:1], bands=("g",), yx0=yx0))

        sub_img = image["g":"r"]
        self.assertImageEqual(sub_img, Image(data[0:2], bands=("g", "r"), yx0=yx0))

        sub_img = image["r":"z"]
        self.assertImageEqual(sub_img, Image(data[1:4], bands=("r", "i", "z"), yx0=yx0))

        sub_img = image["z":]
        self.assertImageEqual(sub_img, Image(data[-2:], bands=("z", "y"), yx0=yx0))

        sub_img = image[("z", "i", "y")]
        self.assertImageEqual(sub_img, Image(data[(3, 2, 4), :, :], bands=("z", "i", "y"), yx0=yx0))

        self.assertImageEqual(image[:], image)

        # Test bounding box slicing
        sub_img = image[:, Box((10, 5), (37, 87))]
        self.assertImageEqual(sub_img, Image(data[:, 10:20, 5:10], bands=bands, yx0=(37, 87)))

        sub_img = image[Box((10, 5), (37, 87))]
        self.assertImageEqual(sub_img, Image(data[:, 10:20, 5:10], bands=bands, yx0=(37, 87)))

        sub_img = image_2d[Box((10, 5), (37, 87))]
        self.assertImageEqual(sub_img, Image(data[0, 10:20, 5:10], yx0=(37, 87)))

        with self.assertRaises(IndexError):
            # Cannot index a single row, since it would not return an image
            _ = image["g", 0]

        with self.assertRaises(IndexError):
            # Cannot index a single column, since it would not return an image
            _ = image[:, :, 0]

        with self.assertRaises(IndexError):
            # Cannot use a tuple to select rows/columns
            _ = image[("r", "i"), (1, 2)]

        with self.assertRaises(IndexError):
            # Cannot use a bounding box outside of the image
            _ = image[:, Box((10, 10), (0, 0))]

        with self.assertRaises(IndexError):
            # Cannot use a bounding box partially outside of the image
            _ = image[:, Box((40, 40), (20, 80))]

        with self.assertRaises(IndexError):
            # Too many spatial indices
            _ = image[:, :, :, :]

        truth = (
            (0, 1, 2, 3, 4),
            slice(27, 57),
            slice(82, 122),
        )
        self.assertTupleEqual(image.multiband_slices, truth)

    def test_overlap_detection(self):
        """Check ``overlapped_slices`` for 2D, 3D and non-overlapping boxes."""
        # Test 2D image
        image = Image(np.zeros((5, 6)), yx0=(10, 15))
        slices = image.overlapped_slices(Box((8, 9), (7, 18)))
        truth = ((slice(0, 5), slice(3, 6)), (slice(3, 8), slice(0, 3)))
        self.assertTupleEqual(slices, truth)

        # Test 3D image
        image = Image(np.zeros((3, 10, 12)), bands=("g", "r", "i"), yx0=(13, 21))
        slices = image.overlapped_slices(Box((8, 9), (15, 18)))
        truth = (
            (slice(None), slice(2, 10), slice(0, 6)),
            (slice(None), slice(0, 8), slice(3, 9)),
        )
        self.assertTupleEqual(slices, truth)

        # Test no overlap
        slices = image.overlapped_slices(Box((8, 9), (115, 118)))
        truth = (
            (slice(None), slice(0, 0), slice(0, 0)),
            (slice(None), slice(0, 0), slice(0, 0)),
        )
        self.assertTupleEqual(slices, truth)

    def test_insertion(self):
        """Insert a smaller image with a subset of bands into a larger one."""
        img1 = Image.from_box(Box((20, 20)), bands=tuple("gri"))
        img2 = Image.from_box(Box((5, 5), (11, 12)), bands=tuple("gi"))
        # Fill band g with 1 and band i with 2
        img2.data[:] = np.arange(1, 3)[:, None, None]
        img1.insert(img2)

        truth = img1.copy()
        truth.data[0, 11:16, 12:17] = 1
        truth.data[2, 11:16, 12:17] = 2
        self.assertImageEqual(img1, truth)

    def test_matched_spectral_indices(self):
        """Band-index matching between band-less and multi-band images."""
        img1 = Image.from_box(Box((5, 5)))
        img2 = Image.from_box(Box((5, 5)))
        indices = img1.matched_spectral_indices(img2)
        self.assertTupleEqual(indices, ((), ()))

        # Mixing a band-less image with a multi-band image is an error
        img3 = Image.from_box(Box((5, 5)), bands=tuple("gri"))
        with self.assertRaises(ValueError):
            img1.matched_spectral_indices(img3)

        with self.assertRaises(ValueError):
            img3.matched_spectral_indices(img1)

    def test_project(self):
        """Project an image onto a larger box and onto a subset of bands."""
        data = np.arange(30).reshape(5, 6)
        img = Image(data, yx0=(11, 15))

        result = img.project(bbox=Box((20, 20), (2, 3)))
        truth = np.zeros((20, 20))
        # Offset of (11, 15) relative to the new origin (2, 3)
        truth[9:14, 12:18] = data
        truth = Image(truth, yx0=(2, 3))
        self.assertImageEqual(result, truth)

        data = np.arange(60).reshape(3, 4, 5)
        img = Image(data, bands=tuple("gri"))
        result = img.project(tuple("gi"))
        truth = data[(0, 2), :]
        self.assertImageEqual(result, Image(truth, bands=tuple("gi")))

    def test_repeat(self):
        """Repeat a band-less image over several bands."""
        data = np.arange(18).reshape(3, 6)
        image = Image(data, yx0=(15, 32))
        result = image.repeat(tuple("grizy"))
        truth = np.array([data, data, data, data, data])
        truth = Image(truth, bands=tuple("grizy"), yx0=(15, 32))
        self.assertImageEqual(result, truth)

        # Repeating an image that already has bands is an error
        with self.assertRaises(ValueError):
            result.repeat(tuple("ubv"))