Coverage for tests/test_interval.py: 11%

316 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-15 02:34 -0700

1# 

2# Developed for the LSST Data Management System. 

3# This product includes software developed by the LSST Project 

4# (https://www.lsst.org). 

5# See the COPYRIGHT file at the top-level directory of this distribution 

6# for details of code ownership. 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the GNU General Public License 

19# along with this program. If not, see <https://www.gnu.org/licenses/>. 

20# 

21from __future__ import annotations 

22 

23import unittest 

24import itertools 

25from typing import ClassVar 

26 

27import numpy as np 

28 

29import lsst.utils.tests 

30from lsst.pex.exceptions import InvalidParameterError 

31from lsst.geom import IntervalI, IntervalD 

32 

33 

class IntervalTestData:
    """Test helper that constructs and organizes intervals to be tested.

    Parameters
    ----------
    IntervalClass : `type`
        Type object that specifies the interval class to be tested (either
        `IntervalI` or `IntervalD`).  It is called with no arguments (to make
        an empty interval) and with ``min`` and ``max`` keyword arguments.
    points : `list`
        List of scalar values to use as endpoints, sorted from smallest to
        largest.  Non-finite values (e.g. ``-inf``/``inf``) are allowed; any
        interval with a non-finite endpoint goes to `infinite` (or `empty`).
    n : `int`, optional
        If not `None`, the maximum number of intervals to retain in each
        category (a random subset drawn without replacement).
    rng : `numpy.random.Generator`, optional
        Random number generator used to draw the subsets.  If `None`
        (default), the global `numpy.random` state is used, as before.

    Attributes
    ----------
    concrete : `list`
        Non-empty intervals whose endpoints are both finite, including
        single-point (``min == max``) intervals.
    infinite : `list`
        Non-empty intervals with at least one non-finite endpoint.
    empty : `list`
        The default-constructed interval plus every ``min > max`` pairing.
    """

    def __init__(self, IntervalClass, points, n=None, rng=None):
        self.concrete = []
        self.infinite = []
        self.empty = [IntervalClass()]
        for i, lhs in enumerate(points):
            for j, rhs in enumerate(points):
                if i > j:
                    # `points` is sorted, so reversed indices give min > max,
                    # which always describes an empty interval.
                    self.empty.append(IntervalClass(min=lhs, max=rhs))
                elif np.isfinite(lhs) and np.isfinite(rhs):
                    # i <= j: ordinary interval (a single point when i == j).
                    self.concrete.append(IntervalClass(min=lhs, max=rhs))
                elif i < j:
                    self.infinite.append(IntervalClass(min=lhs, max=rhs))
                # Remaining case: i == j with a non-finite endpoint, e.g.
                # [inf, inf], which is not a valid interval; skip it.
        if n is not None:
            self.concrete = self.subset(self.concrete, n, rng)
            self.infinite = self.subset(self.infinite, n, rng)
            self.empty = self.subset(self.empty, n, rng)

    @staticmethod
    def subset(seq, n, rng=None):
        """Return at most `n` distinct random elements from the given sequence.

        Elements are drawn without replacement, so no element is repeated,
        and they are returned in their original relative order.  If
        ``len(seq) <= n``, ``seq`` itself is returned unchanged.

        Parameters
        ----------
        seq : `list`
            Sequence to draw from.
        n : `int`
            Maximum number of elements to return.
        rng : `numpy.random.Generator`, optional
            Generator to draw with; the global `numpy.random` state is used
            if `None`.
        """
        if len(seq) <= n:
            return seq
        source = np.random if rng is None else rng
        # Sampling with replacement (the numpy default) could return
        # duplicates and silently shrink test coverage.
        indices = np.sort(source.choice(len(seq), n, replace=False))
        return [seq[i] for i in indices]

    @property
    def nonempty(self):
        """Iterate over all test intervals that are not empty.
        """
        return itertools.chain(self.concrete, self.infinite)

    @property
    def finite(self):
        """Iterate over all test intervals that have finite size.
        """
        return itertools.chain(self.concrete, self.empty)

    @property
    def all(self):
        """Iterate over all test intervals.
        """
        return itertools.chain(self.concrete, self.infinite, self.empty)

99 

100 

class IntervalTests:
    """Base class that provides common tests for IntervalI and IntervalD.

    Concrete test cases combine this with `unittest.TestCase` and must
    provide the attributes declared below, plus a
    ``checkEmptyIntervalInvariants(interval)`` method that asserts the
    representation-specific properties of an empty interval.
    """

    # Interval class object to be tested.  Subclasses must define this as a
    # class attribute.
    IntervalClass: ClassVar[type]

    # Example intervals to be tested, categorized.  Subclasses must define
    # this, typically as an instance attribute during `setUp`.
    intervals: IntervalTestData

    # Points (sorted ascending) to use when testing interval operations that
    # take scalars of the appropriate type.  These points must be finite.
    # Subclasses must define this, typically as an instance attribute during
    # `setUp`.
    points: list

    # Additional points representing non-finite values; every scalar-taking
    # operation must reject them with `InvalidParameterError`.  Subclasses
    # must define this, typically as an instance attribute during `setUp`.
    # An empty list should be used for intervals whose bound type does not
    # support non-finite values.
    nonfinitePoints: list

    def assertAllTrue(self, iterable):
        """Assert that every element of ``iterable`` compares equal to `True`.

        The iterable is materialized first, so a failure shows the whole
        sequence.
        """
        seq = list(iterable)
        self.assertEqual(seq, [True]*len(seq))

    def assertAllFalse(self, iterable):
        """Assert that every element of ``iterable`` compares equal to `False`.

        The iterable is materialized first, so a failure shows the whole
        sequence.
        """
        seq = list(iterable)
        self.assertEqual(seq, [False]*len(seq))

    def testEmpty(self):
        self.assertTrue(self.IntervalClass().isEmpty())
        self.assertAllFalse(s.isEmpty() for s in self.intervals.nonempty)
        for interval in self.intervals.empty:
            with self.subTest(interval=interval):
                self.checkEmptyIntervalInvariants(interval)

    def testConstructors(self):
        for i in self.intervals.finite:
            with self.subTest(i=i):
                self.assertEqual(i, self.IntervalClass(min=i.min, max=i.max))
                self.assertEqual(i, self.IntervalClass(min=i.min, size=i.size))
                self.assertEqual(i, self.IntervalClass(max=i.max, size=i.size))
        for i in self.intervals.infinite:
            with self.subTest(i=i):
                self.assertEqual(i, self.IntervalClass(min=i.min, max=i.max))
                # An infinite interval cannot be rebuilt from one endpoint
                # plus its (infinite) size.
                with self.assertRaises(InvalidParameterError):
                    self.IntervalClass(min=i.min, size=i.size)
                with self.assertRaises(InvalidParameterError):
                    self.IntervalClass(max=i.max, size=i.size)

    def testFromSpannedPoints(self):
        for n1, p1 in enumerate(self.points):
            for n2, p2 in enumerate(self.points):
                with self.subTest(n1=n1, p1=p1, n2=n2, p2=p2):
                    # Empty when n1 > n2, matching the empty min > max interval.
                    seq = list(self.points[n1:n2+1])
                    # p1 is the overall min and p2 is the overall max because
                    # self.points is sorted.
                    i = self.IntervalClass(min=p1, max=p2)
                    self.assertEqual(i, self.IntervalClass.fromSpannedPoints(seq))
                    # The result must not depend on the order of the points.
                    np.random.shuffle(seq)
                    self.assertEqual(i, self.IntervalClass.fromSpannedPoints(seq))
                    seq.reverse()
                    self.assertEqual(i, self.IntervalClass.fromSpannedPoints(seq))

    def testContains(self):
        # Loop-invariant: the points as an array for the vectorized overload.
        array = np.array(self.points)
        for lhs in self.intervals.nonempty:
            for rhs in self.intervals.nonempty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertEqual(lhs.contains(rhs), lhs.min <= rhs.min and lhs.max >= rhs.max)
            # Every interval contains every empty interval.
            for rhs in self.intervals.empty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertTrue(lhs.contains(rhs))
            for rhs in self.points:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertEqual(lhs.contains(rhs), lhs.min <= rhs and lhs.max >= rhs)
            # The array overload must agree with elementwise comparisons.
            with self.subTest(lhs=lhs, rhs=array):
                np.testing.assert_array_equal(lhs.contains(array),
                                              np.logical_and(lhs.min <= array, lhs.max >= array))
        for lhs in self.intervals.empty:
            for rhs in self.intervals.nonempty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertFalse(lhs.contains(rhs))
            for rhs in self.intervals.empty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertTrue(lhs.contains(rhs))
            for rhs in self.points:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertFalse(lhs.contains(rhs))

    def testOverlaps(self):
        for lhs in self.intervals.nonempty:
            for rhs in self.intervals.nonempty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    # Two nonempty intervals overlap iff one contains an
                    # endpoint of the other.
                    self.assertEqual(lhs.overlaps(rhs),
                                     lhs.contains(rhs.min) or lhs.contains(rhs.max)
                                     or rhs.contains(lhs.min) or rhs.contains(lhs.max))
            for rhs in self.intervals.empty:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertFalse(lhs.overlaps(rhs))
        for lhs in self.intervals.empty:
            for rhs in self.intervals.all:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertFalse(lhs.overlaps(rhs))

    def testEquality(self):
        for lhs in self.intervals.all:
            for rhs in self.intervals.all:
                with self.subTest(lhs=lhs, rhs=rhs):
                    # Equal exactly when each contains the other, so all empty
                    # intervals compare equal to one another.
                    shouldBeEqual = lhs.contains(rhs) and rhs.contains(lhs)
                    self.assertIs(lhs == rhs, shouldBeEqual)
                    self.assertIs(lhs != rhs, not shouldBeEqual)

    def testClippedTo(self):
        for lhs in self.intervals.all:
            for rhs in self.intervals.all:
                with self.subTest(lhs=lhs, rhs=rhs):
                    clipped = lhs.clippedTo(rhs)
                    self.assertTrue(lhs.contains(clipped))
                    self.assertTrue(rhs.contains(clipped))
                    self.assertIs(
                        clipped.isEmpty(),
                        lhs.isEmpty() or rhs.isEmpty() or not lhs.overlaps(rhs)
                    )
                    self.assertIs(clipped == rhs, lhs.contains(rhs))
                    self.assertIs(clipped == lhs, rhs.contains(lhs))

    def testShiftedBy(self):
        for original in self.intervals.nonempty:
            for offset in self.points:
                with self.subTest(original=original, offset=offset):
                    shifted = original.shiftedBy(offset)
                    self.assertEqual(original.size, shifted.size)
                    self.assertEqual(original.min + offset, shifted.min)
                    self.assertEqual(original.max + offset, shifted.max)
        for original in self.intervals.empty:
            for offset in self.points:
                with self.subTest(original=original, offset=offset):
                    self.checkEmptyIntervalInvariants(original.shiftedBy(offset))
        for original in self.intervals.all:
            for offset in self.nonfinitePoints:
                with self.subTest(original=original, offset=offset):
                    with self.assertRaises(InvalidParameterError):
                        original.shiftedBy(offset)

    def testExpandedTo(self):
        for lhs in self.intervals.all:
            for rhs in self.intervals.all:
                with self.subTest(lhs=lhs, rhs=rhs):
                    expanded = lhs.expandedTo(rhs)
                    self.assertTrue(expanded.contains(lhs))
                    self.assertTrue(expanded.contains(rhs))
                    self.assertIs(
                        expanded.isEmpty(),
                        lhs.isEmpty() and rhs.isEmpty()
                    )
                    self.assertIs(expanded == rhs, rhs.contains(lhs))
                    self.assertIs(expanded == lhs, lhs.contains(rhs))
            # Expanding to a scalar is the same as expanding to the
            # single-point interval at that scalar.
            for rhs in self.points:
                with self.subTest(lhs=lhs, rhs=rhs):
                    self.assertEqual(lhs.expandedTo(rhs),
                                     lhs.expandedTo(self.IntervalClass(min=rhs, max=rhs)))
            for rhs in self.nonfinitePoints:
                with self.subTest(lhs=lhs, rhs=rhs):
                    with self.assertRaises(InvalidParameterError):
                        lhs.expandedTo(rhs)

    def testDilatedBy(self):
        for original in self.intervals.nonempty:
            for buffer in self.points:
                with self.subTest(original=original, buffer=buffer):
                    dilated = original.dilatedBy(buffer)
                    if dilated.isEmpty():
                        # A negative buffer can erode the interval away
                        # entirely; the result must still be a well-formed
                        # empty interval.
                        self.checkEmptyIntervalInvariants(dilated)
                    else:
                        self.assertEqual(original.min - buffer, dilated.min)
                        self.assertEqual(original.max + buffer, dilated.max)
        for original in self.intervals.empty:
            for buffer in self.points:
                with self.subTest(original=original, buffer=buffer):
                    self.checkEmptyIntervalInvariants(original.dilatedBy(buffer))
        for original in self.intervals.all:
            for buffer in self.nonfinitePoints:
                with self.subTest(original=original, buffer=buffer):
                    with self.assertRaises(InvalidParameterError):
                        original.dilatedBy(buffer)

    def testErodedBy(self):
        for original in self.intervals.all:
            # Eroding by b must be identical to dilating by -b.
            for buffer in self.points:
                with self.subTest(original=original, buffer=buffer):
                    self.assertEqual(original.erodedBy(buffer), original.dilatedBy(-buffer))
            for buffer in self.nonfinitePoints:
                with self.subTest(original=original, buffer=buffer):
                    with self.assertRaises(InvalidParameterError):
                        original.erodedBy(buffer)

    def testReflectedAbout(self):
        for original in self.intervals.nonempty:
            for point in self.points:
                with self.subTest(original=original, point=point):
                    # Computed inside the subTest so that an unexpected
                    # exception is attributed to this (original, point) pair.
                    reflected = original.reflectedAbout(point)
                    # Reflection sends min to 2*point - min and max to
                    # 2*point - max, so the endpoints swap roles.
                    self.assertEqual(point - original.min, -(point - reflected.max))
                    self.assertEqual(point - original.max, -(point - reflected.min))
        for original in self.intervals.empty:
            for point in self.points:
                with self.subTest(original=original, point=point):
                    self.checkEmptyIntervalInvariants(original.reflectedAbout(point))
        for original in self.intervals.all:
            for point in self.nonfinitePoints:
                with self.subTest(original=original, point=point):
                    with self.assertRaises(InvalidParameterError):
                        original.reflectedAbout(point)

325 

326 

class IntervalDTestCase(unittest.TestCase, IntervalTests):
    """Tests for `IntervalD`: the shared `IntervalTests` suite plus
    floating-point-specific checks (center, infinite bounds, bad input).
    """

    IntervalClass = IntervalD

    def setUp(self):
        inf = float("inf")
        self.points = [-1.5, 5.0, 6.75, 8.625]
        # Infinite endpoints are added only for building test intervals (so
        # that `intervals.infinite` is populated); `points` must stay finite.
        self.intervals = IntervalTestData(self.IntervalClass, [-inf] + self.points + [inf], n=3)
        self.nonfinitePoints = [np.nan, -np.inf, np.inf]

    def checkEmptyIntervalInvariants(self, interval):
        """Assert that ``interval`` is empty, has zero size, and NaN bounds."""
        self.assertTrue(interval.isEmpty())
        self.assertEqual(interval.size, 0.0)
        self.assertTrue(np.isnan(interval.min))
        self.assertTrue(np.isnan(interval.max))

    def testBadConstruction(self):
        # Both endpoints at the same infinity, or an infinite anchor combined
        # with a size, must be rejected.  Each case is its own subTest so
        # one failure does not hide the others.
        badArguments = [
            dict(min=np.inf, max=np.inf),
            dict(min=-np.inf, max=-np.inf),
            dict(min=np.inf, size=2.0),
            dict(max=-np.inf, size=2.0),
        ]
        for kwargs in badArguments:
            with self.subTest(**kwargs):
                with self.assertRaises(InvalidParameterError):
                    IntervalD(**kwargs)

    def testCenter(self):
        for interval in self.intervals.concrete:
            with self.subTest(interval=interval):
                self.assertEqual(interval.center, 0.5*(interval.min + interval.max))
        # Round trip through the (center, size) constructor, empty included.
        for interval in self.intervals.finite:
            with self.subTest(interval=interval):
                self.assertEqual(interval, self.IntervalClass(center=interval.center, size=interval.size))

    def testInfinite(self):
        for interval in self.intervals.finite:
            with self.subTest(interval=interval):
                self.assertTrue(interval.isFinite())
                self.assertTrue(np.isfinite(interval.size))
        for interval in self.intervals.infinite:
            with self.subTest(interval=interval):
                self.assertFalse(interval.isEmpty())
                self.assertFalse(interval.isFinite())
                self.assertEqual(interval.size, np.inf)

368 

369 

class IntervalITestCase(unittest.TestCase, IntervalTests):
    """Tests for `IntervalI`: the shared `IntervalTests` suite plus
    integer-specific checks (ranges, conversion from `IntervalD`, overflow).
    """

    IntervalClass = IntervalI

    def setUp(self):
        self.points = [-2, 4, 7, 11]
        self.intervals = IntervalTestData(self.IntervalClass, self.points, n=3)
        # Integer bounds can never be non-finite, so there is nothing to reject.
        self.nonfinitePoints = []

    def checkEmptyIntervalInvariants(self, interval):
        """Assert the properties every empty `IntervalI` must have.

        Only the ordering ``min > max`` is required of the bounds; their
        actual values are unspecified.  While the implementation tries to
        make them consistent, nothing (not even tests) should depend on that.
        """
        self.assertTrue(interval.isEmpty())
        self.assertEqual(interval.size, 0.0)
        self.assertGreater(interval.min, interval.max)

    def testExtensions(self):
        interval = IntervalI(min=3, max=8)
        asList = list(interval.range())
        self.assertEqual(list(range(10))[interval.slice()], asList)
        self.assertEqual(len(interval.range()), interval.size)
        np.testing.assert_array_equal(np.array(asList, dtype=np.int32), interval.arange())
        np.testing.assert_array_equal(np.array(asList, dtype=np.int64), interval.arange(dtype=np.int64))

    def testConversions(self):
        EXPAND = IntervalI.EdgeHandlingEnum.EXPAND
        SHRINK = IntervalI.EdgeHandlingEnum.SHRINK
        # Each entry: (source IntervalD, edge handling, expected IntervalI).
        # As the table shows, integer k behaves like the cell
        # [k - 0.5, k + 0.5]: EXPAND keeps every cell the source touches,
        # SHRINK keeps only cells lying entirely inside it.
        cases = [
            (IntervalD(min=0.5, max=0.5), EXPAND, IntervalI(min=0, max=1)),
            (IntervalD(min=0.5, max=0.5), SHRINK, IntervalI()),
            (IntervalD(min=0.3, max=0.8), SHRINK, IntervalI()),
            (IntervalD(min=0.3, max=1.8), SHRINK, IntervalI(min=1, max=1)),
            (IntervalD(min=0.3, max=1.3), SHRINK, IntervalI()),
            (IntervalD(min=0.0, max=0.0), EXPAND, IntervalI(min=0, max=0)),
            (IntervalD(min=0.0, max=0.1), EXPAND, IntervalI(min=0, max=0)),
            (IntervalD(min=0.9, max=1.0), EXPAND, IntervalI(min=1, max=1)),
            (IntervalD(min=0.0, max=1.0), SHRINK, IntervalI()),
            (IntervalD(min=-0.1, max=1.1), EXPAND, IntervalI(min=0, max=1)),
            (IntervalD(min=-0.1, max=1.1), SHRINK, IntervalI()),
        ]
        for source, edgeHandling, expected in cases:
            with self.subTest(intervalD=source, edgeHandling=edgeHandling, intervalI=expected):
                self.assertFalse(source.isEmpty())
                converted = IntervalI(source, edgeHandling)
                self.assertEqual(converted, expected)
                if expected.isEmpty():
                    self.checkEmptyIntervalInvariants(converted)

    def testOverflow(self):
        # 2**16: comfortably inside the int32 range.
        small = 1 << 16
        # 2**31 - 1: the largest int32 value (fits, barely).
        medium = (1 << 31) - 1
        # 2**33: definitely outside the int32 range.
        large = 1 << 33

        # A single out-of-range bound is rejected while pybind11 converts the
        # argument, so only a generic exception is expected.
        for construct in (lambda: IntervalI(min=-small, max=large),
                          lambda: IntervalI(min=-large, max=small)):
            with self.assertRaises(Exception):
                construct()

        # Individually valid arguments whose combination overflows the size,
        # min, or max.
        for construct in (lambda: IntervalI(min=-medium, max=medium),
                          lambda: IntervalI(min=2, size=medium),
                          lambda: IntervalI(max=-3, size=medium)):
            with self.assertRaises(OverflowError):
                construct()

        # Valid intervals pushed out of range by dilating, shifting, or
        # expanding.
        atBottom = IntervalI(min=-medium - 1, size=small)
        for operation in (lambda: atBottom.dilatedBy(1),
                          lambda: atBottom.shiftedBy(-1)):
            with self.assertRaises(OverflowError):
                operation()

        atTop = IntervalI(max=medium, size=small)
        for operation in (lambda: atTop.dilatedBy(1),
                          lambda: atTop.shiftedBy(1)):
            with self.assertRaises(OverflowError):
                operation()

        veryWide = IntervalI(min=-small, size=medium)
        for operation in (lambda: veryWide.dilatedBy(1),
                          lambda: veryWide.expandedTo(-small - 1),
                          lambda: veryWide.expandedTo(IntervalI(max=small, size=medium))):
            with self.assertRaises(OverflowError):
                operation()

455 

456 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Hook `lsst.utils.tests.MemoryTestCase` into this module's test suite.

    Nothing is added here; all behavior is inherited.
    """

459 

460 

def setup_module(module):
    """Module-level setup hook (``setup_module`` is the pytest/nose naming
    convention): initialize `lsst.utils.tests` before any test runs.

    The ``module`` argument is accepted for the hook signature but unused.
    """
    lsst.utils.tests.init()

463 

464 

if __name__ == "__main__":
    # Run directly as a script: initialize lsst.utils.tests first, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()