Coverage for tests/test_iteration.py: 12%

38 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-20 10:50 +0000

1# This file is part of utils. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import itertools 

23import unittest 

24 

25from lsst.utils.iteration import chunk_iterable, ensure_iterable, isplit 

26 

27 

class IterationTestCase(unittest.TestCase):
    """Tests for the `lsst.utils.iteration` helpers.

    Covers `ensure_iterable`, `chunk_iterable` and `isplit`.
    """

    def testIterable(self):
        """Check that `ensure_iterable` wraps single values and passes
        iterables through unchanged.
        """
        test_data = (
            # The boolean says whether the value is an iterable that should
            # be returned unchanged (True) or a value that should be wrapped
            # as a single item (False). Strings and mappings are expected to
            # be treated as single items.
            (0, False),
            ("hello", False),
            ({1: 2, 3: 4}, False),
            ([0, 1, 2], True),
            (["hello", "world"], True),
            ({"a", "b", "c"}, True),
        )

        for data, is_iterable in test_data:
            with self.subTest(data=data):
                iterator = ensure_iterable(data)
                # Single values should come back as a one-element sequence,
                # so wrap them to compare item by item.
                expected = data if is_iterable else [data]
                # zip_longest pads the shorter side with None, so a length
                # mismatch shows up as a failed comparison.
                for original, from_iterable in itertools.zip_longest(expected, iterator):
                    self.assertEqual(original, from_iterable)

    def testChunking(self):
        """Check that `chunk_iterable` visits every item in the expected
        number of chunks.
        """
        simple = list(range(101))
        out = []
        n_chunks = 0
        for chunk in chunk_iterable(simple, chunk_size=10):
            out.extend(chunk)
            n_chunks += 1
        # 101 items in chunks of 10: ten full chunks plus a final one.
        self.assertEqual(out, simple)
        self.assertEqual(n_chunks, 11)

        test_dict = {k: 1 for k in range(101)}
        n_chunks = 0
        for chunk in chunk_iterable(test_dict, chunk_size=45):
            # Subtract 1 for each key in the chunk.
            for k in chunk:
                test_dict[k] -= 1
            n_chunks += 1
        # Every key must have been seen exactly once. Compare the whole
        # mapping: a sum-based check would miss a skipped key cancelled
        # out by a key that was visited twice.
        self.assertEqual(test_dict, dict.fromkeys(range(101), 0))
        self.assertEqual(n_chunks, 3)

    def testIsplit(self):
        """Check that `isplit` agrees with `str.split` / `bytes.split`."""
        seps = ("\n", " ", "d")
        # The trailing newline checks that a final empty element is produced,
        # as str.split does.
        input_str = "ab\ncd ef\n"

        for sep in seps:
            for text in (input_str, input_str.encode()):
                with self.subTest(sep=sep, type=type(text).__name__):
                    # The separator must have the same type as the input.
                    test_sep = sep.encode() if isinstance(text, bytes) else sep
                    isp = list(isplit(text, sep=test_sep))
                    ssp = text.split(test_sep)
                    self.assertEqual(isp, ssp)

85 

86 

if __name__ == "__main__":
    # When executed directly as a script, run the test cases defined in
    # this module via the standard unittest command-line runner.
    unittest.main(module="__main__")