Coverage for tests/test_iteration.py: 12%
38 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-21 09:53 +0000
1# This file is part of utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
11#
13import itertools
14import unittest
16from lsst.utils.iteration import chunk_iterable, ensure_iterable, isplit
class IterationTestCase(unittest.TestCase):
    """Tests for the `ensure_iterable`, `chunk_iterable` and `isplit`
    helpers imported from `lsst.utils.iteration`.
    """

    def testIterable(self):
        """Compare the items produced by `ensure_iterable` against the
        expected items for a range of inputs.
        """
        test_data = (
            # Boolean indicates to test that we know it's
            # meant to be returned unchanged.
            (0, False),
            ("hello", False),
            ({1: 2, 3: 4}, False),
            ([0, 1, 2], True),
            (["hello", "world"], True),
            ({"a", "b", "c"}, True),
        )

        for data, is_iterable in test_data:
            # One sub-test per input so a failure names the offending case
            # and the remaining cases are still exercised.
            with self.subTest(data=data, is_iterable=is_iterable):
                iterator = ensure_iterable(data)
                # Inputs flagged False are expected to come back as a single
                # item, so wrap them in a list for the comparison below.
                expected = data if is_iterable else [data]
                # zip_longest pads the shorter side with None, so a length
                # mismatch shows up as a failed comparison rather than being
                # silently truncated.
                for original, from_iterable in itertools.zip_longest(expected, iterator):
                    self.assertEqual(original, from_iterable)

    def testChunking(self):
        """Chunk iterables."""
        # 101 items in chunks of 10: ten full chunks plus one holding the
        # final item.
        simple = list(range(101))
        out = []
        n_chunks = 0
        for chunk in chunk_iterable(simple, chunk_size=10):
            out.extend(chunk)
            n_chunks += 1
        # Concatenating the chunks must reproduce the input, in order.
        self.assertEqual(out, simple)
        self.assertEqual(n_chunks, 11)

        # Chunking a dict: 101 keys in chunks of 45 gives 45 + 45 + 11.
        test_dict = {k: 1 for k in range(101)}
        n_chunks = 0
        for chunk in chunk_iterable(test_dict, chunk_size=45):
            # Subtract 1 for each key in chunk
            for k in chunk:
                test_dict[k] -= 1
            n_chunks += 1
        # Should have matched every key
        self.assertEqual(sum(test_dict.values()), 0)
        self.assertEqual(n_chunks, 3)

    def testIsplit(self):
        """Check that `isplit` yields the same pieces as the ``split``
        method of `str` and `bytes`.
        """
        # Test compatibility with str.split
        seps = ("\n", " ", "d")
        input_str = "ab\ncd ef\n"

        for sep in seps:
            # Exercise both the text and the encoded (bytes) form.
            for text in (input_str, input_str.encode()):
                # The separator has to be the same type as the data being
                # split.
                test_sep = sep.encode() if isinstance(text, bytes) else sep
                with self.subTest(text=text, sep=test_sep):
                    isp = list(isplit(text, sep=test_sep))
                    ssp = text.split(test_sep)
                    self.assertEqual(isp, ssp)
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()