Coverage for tests/test_interval.py : 12%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# Developed for the LSST Data Management System.
3# This product includes software developed by the LSST Project
4# (https://www.lsst.org).
5# See the COPYRIGHT file at the top-level directory of this distribution
6# for details of code ownership.
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the GNU General Public License
19# along with this program. If not, see <https://www.gnu.org/licenses/>.
20#
21from __future__ import annotations
23import unittest
24import itertools
25from typing import ClassVar
27import numpy as np
29import lsst.utils.tests
30from lsst.pex.exceptions import InvalidParameterError
31from lsst.geom import IntervalI, IntervalD
class IntervalTestData:
    """Test helper that constructs and organizes intervals to be tested.

    Parameters
    ----------
    IntervalClass : `type`
        Type object that specifies the interval class to be tested (either
        `IntervalI` or `IntervalD`).
    points : `list`
        List of scalar values to use as endpoints, sorted from smallest to
        largest.
    n : `int`, optional
        If not `None`, the number of intervals to retain in each category
        (a random subset, drawn without replacement).

    Notes
    -----
    Every ordered pair of positions ``(i, j)`` in ``points`` yields at most
    one interval with ``min=points[i]`` and ``max=points[j]``:

    - ``concrete``: both endpoints finite and ``i <= j`` (this includes
      single-point intervals);
    - ``infinite``: at least one endpoint non-finite and ``i < j``;
    - ``empty``: ``i > j``, plus one default-constructed interval.

    Pairs with ``i == j`` at a non-finite point (e.g. ``[inf, inf]``) are
    skipped, because they do not describe a valid interval.
    """

    def __init__(self, IntervalClass, points, n=None):
        self.concrete = []
        self.infinite = []
        self.empty = [IntervalClass()]
        for i, lhs in enumerate(points):
            for j, rhs in enumerate(points):
                if i > j:
                    # Points are sorted, so a later position as ``min`` means
                    # the bounds are reversed: always an empty interval.
                    self.empty.append(IntervalClass(min=lhs, max=rhs))
                elif np.isfinite(lhs) and np.isfinite(rhs):
                    self.concrete.append(IntervalClass(min=lhs, max=rhs))
                elif i < j:
                    self.infinite.append(IntervalClass(min=lhs, max=rhs))
                # Remaining case: i == j at a non-finite point; not valid.
        if n is not None:
            self.concrete = self.subset(self.concrete, n)
            self.infinite = self.subset(self.infinite, n)
            self.empty = self.subset(self.empty, n)

    @staticmethod
    def subset(seq, n):
        """Return ``n`` distinct random elements from the given sequence.

        If ``seq`` has ``n`` or fewer elements, it is returned unchanged.
        """
        if len(seq) > n:
            # Sample without replacement so the result is a true subset with
            # no repeated intervals.
            indices = np.random.choice(len(seq), n, replace=False)
            return [seq[i] for i in indices]
        return seq

    @property
    def nonempty(self):
        """Iterate over all test intervals that are not empty.
        """
        return itertools.chain(self.concrete, self.infinite)

    @property
    def finite(self):
        """Iterate over all test intervals that have finite size.
        """
        return itertools.chain(self.concrete, self.empty)

    @property
    def all(self):
        """Iterate over all test intervals.
        """
        return itertools.chain(self.concrete, self.infinite, self.empty)
class IntervalTests:
    """Base class that provides common tests for IntervalI and IntervalD.

    This is a mixin rather than a `unittest.TestCase` subclass; its tests run
    only through concrete subclasses that also inherit from
    `unittest.TestCase`. Those subclasses must additionally define a
    ``checkEmptyIntervalInvariants(interval)`` method that asserts the
    type-specific invariants of an empty interval.
    """

    IntervalClass: ClassVar[type]
    """Interval class object to be tested.

    Subclasses must define this as a class attribute.
    """

    intervals: IntervalTestData
    """Example intervals to be tested, categorized.

    Subclasses must define this, typically as an instance attribute during
    `setUp`.
    """

    points: list
    """List of points to use when testing interval operations that take
    scalars of the appropriate type.

    These points must be finite.

    Subclasses must define this, typically as an instance attribute during
    `setUp`.
    """

    nonfinitePoints: list
    """Additional points representing non-finite values.

    Subclasses must define this, typically as an instance attribute during
    `setUp`. An empty list should be used for intervals whose bound type does
    not support non-finite values.
    """

    def assertAllTrue(self, iterable):
        """Assert that every element of ``iterable`` compares equal to `True`."""
        seq = list(iterable)
        self.assertEqual(seq, [True]*len(seq))

    def assertAllFalse(self, iterable):
        """Assert that every element of ``iterable`` compares equal to `False`."""
        seq = list(iterable)
        self.assertEqual(seq, [False]*len(seq))

    def testEmpty(self):
        self.assertTrue(self.IntervalClass().isEmpty())
        self.assertAllFalse(s.isEmpty() for s in self.intervals.nonempty)
        for interval in self.intervals.empty:
            with self.subTest(interval=interval):
                self.checkEmptyIntervalInvariants(interval)

    def testConstructors(self):
        # Finite intervals (including empty ones) must round-trip through
        # every constructor overload.
        for i in self.intervals.finite:
            with self.subTest(i=i):
                self.assertEqual(i, self.IntervalClass(min=i.min, max=i.max))
                self.assertEqual(i, self.IntervalClass(min=i.min, size=i.size))
                self.assertEqual(i, self.IntervalClass(max=i.max, size=i.size))
        # Infinite intervals round-trip through explicit bounds only;
        # size-based construction must be rejected.
        for i in self.intervals.infinite:
            with self.subTest(i=i):
                self.assertEqual(i, self.IntervalClass(min=i.min, max=i.max))
                with self.assertRaises(InvalidParameterError):
                    self.IntervalClass(min=i.min, size=i.size)
                with self.assertRaises(InvalidParameterError):
                    self.IntervalClass(max=i.max, size=i.size)

    def testFromSpannedPoints(self):
        for n1, p1 in enumerate(self.points):
            for n2, p2 in enumerate(self.points):
                with self.subTest(n1=n1, p1=p1, n2=n2, p2=p2):
                    # Empty when n1 > n2, matching the reversed-bounds
                    # interval built below.
                    seq = list(self.points[n1:n2+1])
                    # p1 is the overall min and p2 is the overall max because
                    # self.points is sorted.
                    i = self.IntervalClass(min=p1, max=p2)
                    self.assertEqual(i, self.IntervalClass.fromSpannedPoints(seq))
                    # The result must not depend on the order of the points.
                    np.random.shuffle(seq)
                    self.assertEqual(i, self.IntervalClass.fromSpannedPoints(seq))
                    seq.reverse()
                    self.assertEqual(i, self.IntervalClass.fromSpannedPoints(seq))

    def testContains(self):
        # Built once: the vectorized overload is checked against every lhs.
        array = np.array(self.points)
        for lhs in self.intervals.nonempty:
            for rhs in self.intervals.nonempty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertEqual(lhs.contains(rhs), lhs.min <= rhs.min and lhs.max >= rhs.max)
            for rhs in self.intervals.empty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertTrue(lhs.contains(rhs))
            for rhs in self.points:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertEqual(lhs.contains(rhs), lhs.min <= rhs and lhs.max >= rhs)
            with self.subTest(lhs=lhs, rhs=array):
                np.testing.assert_array_equal(lhs.contains(array),
                                              np.logical_and(lhs.min <= array, lhs.max >= array))
        for lhs in self.intervals.empty:
            for rhs in self.intervals.nonempty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertFalse(lhs.contains(rhs))
            for rhs in self.intervals.empty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertTrue(lhs.contains(rhs))
            for rhs in self.points:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertFalse(lhs.contains(rhs))

    def testOverlaps(self):
        for lhs in self.intervals.nonempty:
            for rhs in self.intervals.nonempty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    # Two non-empty intervals overlap iff one contains an
                    # endpoint of the other.
                    self.assertEqual(lhs.overlaps(rhs),
                                     lhs.contains(rhs.min) or lhs.contains(rhs.max)
                                     or rhs.contains(lhs.min) or rhs.contains(lhs.max))
            for rhs in self.intervals.empty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertFalse(lhs.overlaps(rhs))
        for lhs in self.intervals.empty:
            for rhs in self.intervals.all:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertFalse(lhs.overlaps(rhs))

    def testEquality(self):
        for lhs in self.intervals.all:
            for rhs in self.intervals.all:
                with self.subTest(lhs=lhs, rhs=rhs):
                    # Equality is mutual containment; all empty intervals
                    # are therefore equal to each other.
                    shouldBeEqual = lhs.contains(rhs) and rhs.contains(lhs)
                    self.assertIs(lhs == rhs, shouldBeEqual)
                    self.assertIs(lhs != rhs, not shouldBeEqual)

    def testClippedTo(self):
        for lhs in self.intervals.all:
            for rhs in self.intervals.all:
                with self.subTest(lhs=lhs, rhs=rhs):
                    clipped = lhs.clippedTo(rhs)
                    self.assertTrue(lhs.contains(clipped))
                    self.assertTrue(rhs.contains(clipped))
                    self.assertIs(
                        clipped.isEmpty(),
                        lhs.isEmpty() or rhs.isEmpty() or not lhs.overlaps(rhs)
                    )
                    self.assertIs(clipped == rhs, lhs.contains(rhs))
                    self.assertIs(clipped == lhs, rhs.contains(lhs))

    def testShiftedBy(self):
        for original in self.intervals.nonempty:
            for offset in self.points:
                with self.subTest(original=original, offset=offset):
                    shifted = original.shiftedBy(offset)
                    self.assertEqual(original.size, shifted.size)
                    self.assertEqual(original.min + offset, shifted.min)
                    self.assertEqual(original.max + offset, shifted.max)
        for original in self.intervals.empty:
            for offset in self.points:
                with self.subTest(original=original, offset=offset):
                    self.checkEmptyIntervalInvariants(original.shiftedBy(offset))
        for original in self.intervals.all:
            for offset in self.nonfinitePoints:
                with self.subTest(original=original, offset=offset):
                    with self.assertRaises(InvalidParameterError):
                        original.shiftedBy(offset)

    def testExpandedTo(self):
        for lhs in self.intervals.all:
            for rhs in self.intervals.all:
                with self.subTest(lhs=lhs, rhs=rhs):
                    expanded = lhs.expandedTo(rhs)
                    self.assertTrue(expanded.contains(lhs))
                    self.assertTrue(expanded.contains(rhs))
                    self.assertIs(
                        expanded.isEmpty(),
                        lhs.isEmpty() and rhs.isEmpty()
                    )
                    self.assertIs(expanded == rhs, rhs.contains(lhs))
                    self.assertIs(expanded == lhs, lhs.contains(rhs))
            # Expanding to a scalar must match expanding to the
            # single-point interval at that scalar.
            for rhs in self.points:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertEqual(lhs.expandedTo(rhs),
                                     lhs.expandedTo(self.IntervalClass(min=rhs, max=rhs)))
            for rhs in self.nonfinitePoints:
                with self.subTest(lhs=lhs, rhs=rhs):
                    with self.assertRaises(InvalidParameterError):
                        lhs.expandedTo(rhs)

    def testDilatedBy(self):
        for original in self.intervals.nonempty:
            for buffer in self.points:
                with self.subTest(original=original, buffer=buffer):
                    dilated = original.dilatedBy(buffer)
                    expectedMin = original.min - buffer
                    expectedMax = original.max + buffer
                    if expectedMax < expectedMin:
                        # A negative buffer can erode the interval away
                        # entirely; the result must then be a proper empty
                        # interval rather than being silently unchecked.
                        self.checkEmptyIntervalInvariants(dilated)
                    else:
                        self.assertEqual(expectedMin, dilated.min)
                        self.assertEqual(expectedMax, dilated.max)
        for original in self.intervals.empty:
            for buffer in self.points:
                with self.subTest(original=original, buffer=buffer):
                    self.checkEmptyIntervalInvariants(original.dilatedBy(buffer))
        for original in self.intervals.all:
            for buffer in self.nonfinitePoints:
                with self.subTest(original=original, buffer=buffer):
                    with self.assertRaises(InvalidParameterError):
                        original.dilatedBy(buffer)

    def testErodedBy(self):
        for original in self.intervals.all:
            for buffer in self.points:
                with self.subTest(original=original, buffer=buffer):
                    # Erosion is defined as dilation by the negated buffer.
                    self.assertEqual(original.erodedBy(buffer), original.dilatedBy(-buffer))
            for buffer in self.nonfinitePoints:
                with self.subTest(original=original, buffer=buffer):
                    with self.assertRaises(InvalidParameterError):
                        original.erodedBy(buffer)

    def testReflectedAbout(self):
        for original in self.intervals.nonempty:
            for point in self.points:
                with self.subTest(original=original, point=point):
                    # Computed inside the subTest so an exception is reported
                    # for this case instead of aborting the whole loop.
                    reflected = original.reflectedAbout(point)
                    self.assertEqual(point - original.min, -(point - reflected.max))
                    self.assertEqual(point - original.max, -(point - reflected.min))
        for original in self.intervals.empty:
            for point in self.points:
                with self.subTest(original=original, point=point):
                    self.checkEmptyIntervalInvariants(original.reflectedAbout(point))
        for original in self.intervals.all:
            for point in self.nonfinitePoints:
                with self.subTest(original=original, point=point):
                    with self.assertRaises(InvalidParameterError):
                        original.reflectedAbout(point)
class IntervalDTestCase(unittest.TestCase, IntervalTests):
    """Tests for `IntervalD`: the shared `IntervalTests` suite plus checks
    specific to floating-point bounds (infinite intervals, NaN-valued empty
    intervals, and centers).
    """

    IntervalClass = IntervalD

    def setUp(self):
        inf = float("inf")
        # NOTE(review): these values look chosen to be exactly representable
        # in binary floating point, keeping the arithmetic in the shared
        # tests exact — confirm before changing them.
        self.points = [-1.5, 5.0, 6.75, 8.625]
        # Infinite endpoints are used only to generate test intervals;
        # self.points itself must stay finite.
        self.intervals = IntervalTestData(self.IntervalClass, [-inf] + self.points + [inf], n=3)
        self.nonfinitePoints = [np.nan, -np.inf, np.inf]

    def checkEmptyIntervalInvariants(self, interval):
        """Assert that ``interval`` is empty, has zero size, and NaN bounds."""
        self.assertTrue(interval.isEmpty())
        self.assertEqual(interval.size, 0.0)
        self.assertTrue(np.isnan(interval.min))
        self.assertTrue(np.isnan(interval.max))

    def testBadConstruction(self):
        # Degenerate intervals at infinity, or size-based construction
        # anchored at infinity, must be rejected.
        with self.assertRaises(InvalidParameterError):
            IntervalD(min=np.inf, max=np.inf)
        with self.assertRaises(InvalidParameterError):
            IntervalD(min=-np.inf, max=-np.inf)
        with self.assertRaises(InvalidParameterError):
            IntervalD(min=np.inf, size=2.0)
        with self.assertRaises(InvalidParameterError):
            IntervalD(max=-np.inf, size=2.0)

    def testCenter(self):
        for interval in self.intervals.concrete:
            with self.subTest(interval=interval):
                self.assertEqual(interval.center, 0.5*(interval.min + interval.max))
        # Finite intervals (including empty ones) must round-trip through
        # the center/size constructor.
        for interval in self.intervals.finite:
            with self.subTest(interval=interval):
                self.assertEqual(interval, self.IntervalClass(center=interval.center, size=interval.size))

    def testInfinite(self):
        for interval in self.intervals.finite:
            with self.subTest(interval=interval):
                self.assertTrue(interval.isFinite())
                self.assertTrue(np.isfinite(interval.size))
        for interval in self.intervals.infinite:
            with self.subTest(interval=interval):
                self.assertFalse(interval.isEmpty())
                self.assertFalse(interval.isFinite())
                self.assertEqual(interval.size, np.inf)
class IntervalITestCase(unittest.TestCase, IntervalTests):
    """Tests for `IntervalI`: the shared `IntervalTests` suite plus
    integer-specific behavior (Python range/slice/arange helpers, conversion
    from `IntervalD`, and int32 overflow detection).
    """

    IntervalClass = IntervalI

    def setUp(self):
        self.points = [-2, 4, 7, 11]
        self.intervals = IntervalTestData(self.IntervalClass, self.points, n=3)
        # Integer bounds cannot be non-finite, so there is nothing to reject.
        self.nonfinitePoints = []

    def checkEmptyIntervalInvariants(self, interval):
        """Assert that ``interval`` is empty, has zero size, and max < min."""
        self.assertTrue(interval.isEmpty())
        self.assertEqual(interval.size, 0)
        self.assertLess(interval.max, interval.min)
        # The concrete min/max of an empty interval are deliberately
        # unspecified; the implementation tries to keep them consistent, but
        # nothing (tests included) may rely on their exact values.

    def testExtensions(self):
        values = list(range(10))
        interval = IntervalI(min=3, max=8)
        expected = list(interval.range())
        self.assertEqual(values[interval.slice()], expected)
        self.assertEqual(len(interval.range()), interval.size)
        # arange() defaults to int32 and honors an explicit dtype.
        np.testing.assert_array_equal(np.array(expected, dtype=np.int32), interval.arange())
        np.testing.assert_array_equal(np.array(expected, dtype=np.int64), interval.arange(dtype=np.int64))

    def testConversions(self):
        EXPAND = IntervalI.EdgeHandlingEnum.EXPAND
        SHRINK = IntervalI.EdgeHandlingEnum.SHRINK
        cases = [
            (IntervalD(min=0.5, max=0.5), EXPAND, IntervalI(min=0, max=1)),
            (IntervalD(min=0.5, max=0.5), SHRINK, IntervalI()),
            (IntervalD(min=0.3, max=0.8), SHRINK, IntervalI()),
            (IntervalD(min=0.3, max=1.8), SHRINK, IntervalI(min=1, max=1)),
            (IntervalD(min=0.3, max=1.3), SHRINK, IntervalI()),
            (IntervalD(min=0.0, max=0.0), EXPAND, IntervalI(min=0, max=0)),
            (IntervalD(min=0.0, max=0.1), EXPAND, IntervalI(min=0, max=0)),
            (IntervalD(min=0.9, max=1.0), EXPAND, IntervalI(min=1, max=1)),
            (IntervalD(min=0.0, max=1.0), SHRINK, IntervalI()),
            (IntervalD(min=-0.1, max=1.1), EXPAND, IntervalI(min=0, max=1)),
            (IntervalD(min=-0.1, max=1.1), SHRINK, IntervalI()),
        ]
        for source, edgeHandling, expected in cases:
            with self.subTest(intervalD=source, edgeHandling=edgeHandling, intervalI=expected):
                self.assertFalse(source.isEmpty())
                converted = IntervalI(source, edgeHandling)
                self.assertEqual(converted, expected)
                if expected.isEmpty():
                    self.checkEmptyIntervalInvariants(converted)

    def testOverflow(self):
        # Comfortably within int32.
        smallValue = 1 << 16
        # Exactly the largest int32 value.
        int32Max = (1 << 31) - 1
        # Well outside int32.
        tooLarge = 1 << 33

        # A single out-of-range bound. pybind11 rejects these during argument
        # conversion, so only the presence of an exception is checked.
        with self.assertRaises(Exception):
            IntervalI(min=-smallValue, max=tooLarge)
        with self.assertRaises(Exception):
            IntervalI(min=-tooLarge, max=smallValue)

        # Bounds that are individually valid but whose combination
        # overflows the size, min, or max.
        with self.assertRaises(OverflowError):
            IntervalI(min=-int32Max, max=int32Max)
        with self.assertRaises(OverflowError):
            IntervalI(min=2, size=int32Max)
        with self.assertRaises(OverflowError):
            IntervalI(max=-3, size=int32Max)

        # Valid intervals pushed past the int32 limits by dilating,
        # shifting, or expanding.
        atLowerLimit = IntervalI(min=-int32Max - 1, size=smallValue)
        with self.assertRaises(OverflowError):
            atLowerLimit.dilatedBy(1)
        with self.assertRaises(OverflowError):
            atLowerLimit.shiftedBy(-1)

        atUpperLimit = IntervalI(max=int32Max, size=smallValue)
        with self.assertRaises(OverflowError):
            atUpperLimit.dilatedBy(1)
        with self.assertRaises(OverflowError):
            atUpperLimit.shiftedBy(1)

        maxSized = IntervalI(min=-smallValue, size=int32Max)
        with self.assertRaises(OverflowError):
            maxSized.dilatedBy(1)
        with self.assertRaises(OverflowError):
            maxSized.expandedTo(-smallValue - 1)
        with self.assertRaises(OverflowError):
            maxSized.expandedTo(IntervalI(max=smallValue, size=int32Max))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the leak checks inherited from `lsst.utils.tests.MemoryTestCase`.

    No additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    Parameters
    ----------
    module : `module`
        The module being set up (unused).
    """
    lsst.utils.tests.init()
# Allow running this file directly as well as through a test runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()