Coverage for tests/test_utils.py : 21%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
22from collections import Counter, namedtuple
23from glob import glob
24import os
25import re
26import unittest
27import logging
29from lsst.daf.butler.core.utils import findFileResources, getFullTypeName, globToRegex, iterable, Singleton
30from lsst.daf.butler import Formatter, Registry
31from lsst.daf.butler import NamedKeyDict, NamedValueSet, StorageClass
32from lsst.daf.butler.core.utils import isplit, time_this, chunk_iterable
# Directory containing this test module; the file-discovery tests below
# build their paths (e.g. the "config" tree) relative to it.
TESTDIR = os.path.dirname(__file__)
class IterableTestCase(unittest.TestCase):
    """Checks of the ``iterable`` and ``chunk_iterable`` helpers."""

    def testNonIterable(self):
        """A bare integer is expected back as a one-element sequence."""
        wrapped = list(iterable(0))
        self.assertEqual(wrapped, [0])

    def testString(self):
        """A string is expected back whole, not split into characters."""
        wrapped = list(iterable("hello"))
        self.assertEqual(wrapped, ["hello"])

    def testIterableNoString(self):
        """Lists are expected to be passed through element by element."""
        numbers = [0, 1, 2]
        words = ["hello", "world"]
        self.assertEqual(list(iterable(numbers)), [0, 1, 2])
        self.assertEqual(list(iterable(words)), ["hello", "world"])

    def testChunking(self):
        """Chunk iterables."""
        sequence = list(range(101))
        collected = []
        count = 0
        for piece in chunk_iterable(sequence, chunk_size=10):
            count += 1
            collected.extend(piece)
        # Re-assembling the chunks must reproduce the input, and 101 items
        # in chunks of 10 are expected to need 11 chunks.
        self.assertEqual(collected, sequence)
        self.assertEqual(count, 11)

        counters = dict.fromkeys(range(101), 1)
        count = 0
        for piece in chunk_iterable(counters, chunk_size=45):
            count += 1
            # Subtract 1 for each key in chunk
            for key in piece:
                counters[key] -= 1
        # Every key should have been visited exactly once.
        self.assertEqual(sum(counters.values()), 0)
        self.assertEqual(count, 3)
class SingletonTestCase(unittest.TestCase):
    """Tests of the Singleton metaclass"""

    class IsSingleton(metaclass=Singleton):
        """Minimal class built with the ``Singleton`` metaclass."""

        def __init__(self):
            self.id = 0
            self.data = {}

    class IsBadSingleton(IsSingleton):
        """Variant whose constructor wrongly requires an argument."""

        def __init__(self, arg):
            """A singleton can not accept any arguments."""
            self.arg = arg

    class IsSingletonSubclass(IsSingleton):
        """Subclass used to check it does not share state with its parent."""

        def __init__(self):
            super().__init__()

    def testSingleton(self):
        """Check state sharing between repeated constructions."""
        first = SingletonTestCase.IsSingleton()
        second = SingletonTestCase.IsSingleton()

        # Now update the first one and check the second
        first.data["test"] = 52
        self.assertEqual(first.data, second.data)
        second.id += 1
        self.assertEqual(first.id, second.id)

        # The subclass is expected to carry its own, unmodified ``id``.
        other = SingletonTestCase.IsSingletonSubclass()
        self.assertNotEqual(first.id, other.id)

        # Passing a constructor argument is expected to raise.
        with self.assertRaises(TypeError):
            SingletonTestCase.IsBadSingleton(52)
class NamedKeyDictTest(unittest.TestCase):
    """Tests of ``NamedKeyDict`` using namedtuple keys with a ``name``."""

    def setUp(self):
        """Build two named keys and a plain dict that uses them."""
        self.TestTuple = namedtuple("TestTuple", ("name", "id"))
        self.a = self.TestTuple(name="a", id=1)
        self.b = self.TestTuple(name="b", id=2)
        self.names = {self.a.name, self.b.name}
        self.dictionary = {self.a: 10, self.b: 20}

    def check(self, nkd):
        """Compare ``nkd`` against the reference dict made in ``setUp``."""
        expected_values = list(self.dictionary.values())
        expected_items = list(self.dictionary.items())
        self.assertEqual(len(nkd), 2)
        self.assertEqual(nkd.names, self.names)
        self.assertEqual(nkd.keys(), self.dictionary.keys())
        self.assertEqual(list(nkd.values()), expected_values)
        self.assertEqual(list(nkd.items()), expected_items)
        by_name = nkd.byName()
        self.assertEqual(list(by_name.values()), expected_values)
        self.assertEqual(list(nkd.byName().keys()), list(nkd.names))

    def testConstruction(self):
        """Construct from a mapping and from an iterator of pairs."""
        from_mapping = NamedKeyDict(self.dictionary)
        self.check(from_mapping)
        from_pairs = NamedKeyDict(iter(self.dictionary.items()))
        self.check(from_pairs)

    def testDuplicateNameConstruction(self):
        """Two keys sharing a name must be rejected at construction."""
        duplicate = self.TestTuple(name="a", id=3)
        self.dictionary[duplicate] = 30
        with self.assertRaises(AssertionError):
            NamedKeyDict(self.dictionary)
        with self.assertRaises(AssertionError):
            NamedKeyDict(iter(self.dictionary.items()))

    def testNoNameConstruction(self):
        """A key without a ``name`` attribute must be rejected."""
        self.dictionary["a"] = 30
        with self.assertRaises(AttributeError):
            NamedKeyDict(self.dictionary)
        with self.assertRaises(AttributeError):
            NamedKeyDict(iter(self.dictionary.items()))

    def testGetItem(self):
        """Look up and test membership by name and by key object."""
        nkd = NamedKeyDict(self.dictionary)
        self.assertEqual(nkd["a"], 10)
        self.assertEqual(nkd[self.a], 10)
        self.assertEqual(nkd["b"], 20)
        self.assertEqual(nkd[self.b], 20)
        self.assertIn("a", nkd)
        self.assertIn(self.b, nkd)

    def testSetItem(self):
        """Assign through existing keys; reject unknown or clashing ones."""
        nkd = NamedKeyDict(self.dictionary)
        nkd[self.a] = 30
        self.assertEqual(nkd["a"], 30)
        nkd["b"] = 40
        self.assertEqual(nkd[self.b], 40)
        # A name that is not already present cannot be assigned.
        with self.assertRaises(KeyError):
            nkd["c"] = 50
        # A different key object reusing the name "a" is refused.
        clash = self.TestTuple("a", 3)
        with self.assertRaises(AssertionError):
            nkd[clash] = 60

    def testDelItem(self):
        """Delete by key object and by name."""
        nkd = NamedKeyDict(self.dictionary)
        del nkd[self.a]
        self.assertNotIn("a", nkd)
        del nkd["b"]
        self.assertNotIn(self.b, nkd)
        self.assertEqual(len(nkd), 0)

    def testIter(self):
        """Iteration yields the same keys as the source dict."""
        nkd = NamedKeyDict(self.dictionary)
        self.assertEqual(set(iter(nkd)), set(self.dictionary))

    def testEquality(self):
        """Equality with a plain dict holds in both operand orders."""
        nkd = NamedKeyDict(self.dictionary)
        self.assertEqual(nkd, self.dictionary)
        self.assertEqual(self.dictionary, nkd)
class NamedValueSetTest(unittest.TestCase):
    """Tests of ``NamedValueSet`` using namedtuple elements with a ``name``."""

    def setUp(self):
        """Create three named elements shared by the tests."""
        self.TestTuple = namedtuple("TestTuple", ("name", "id"))
        self.a = self.TestTuple(name="a", id=1)
        self.b = self.TestTuple(name="b", id=2)
        self.c = self.TestTuple(name="c", id=3)

    def testConstruction(self):
        """Construct from a set and a tuple, with and without ``freeze``."""
        expected_pairs = [(self.a.name, self.a), (self.b.name, self.b)]
        for source in ({self.a, self.b}, (self.a, self.b)):
            mutable = NamedValueSet(source)
            frozen = NamedValueSet(source).freeze()
            for nvs in (mutable, frozen):
                self.assertEqual(len(nvs), 2)
                self.assertEqual(nvs.names, {"a", "b"})
                self.assertCountEqual(nvs, {self.a, self.b})
                self.assertCountEqual(nvs.asMapping().items(), expected_pairs)

    def testNoNameConstruction(self):
        """An element without a ``name`` attribute must be rejected."""
        with self.assertRaises(AttributeError):
            NamedValueSet([self.a, "a"])

    def testGetItem(self):
        """Look up and test membership by name and by element."""
        nvs = NamedValueSet({self.a, self.b, self.c})
        self.assertEqual(nvs["a"], self.a)
        self.assertEqual(nvs[self.a], self.a)
        self.assertEqual(nvs["b"], self.b)
        self.assertEqual(nvs[self.b], self.b)
        self.assertIn("a", nvs)
        self.assertIn(self.b, nvs)

    def testEquality(self):
        """Equality with a plain set holds in both operand orders."""
        plain = {self.a, self.b, self.c}
        nvs = NamedValueSet(plain)
        self.assertEqual(nvs, plain)
        self.assertEqual(plain, nvs)

    def checkOperator(self, result, expected):
        """Assert ``result`` is a ``NamedValueSet`` equal to ``expected``."""
        self.assertIsInstance(result, NamedValueSet)
        self.assertEqual(result, expected)

    def testOperators(self):
        """Binary set operators return ``NamedValueSet`` results."""
        left = NamedValueSet({self.a, self.b})
        right = NamedValueSet({self.b, self.c})
        self.checkOperator(left & right, {self.b})
        self.checkOperator(left | right, {self.a, self.b, self.c})
        self.checkOperator(left ^ right, {self.a, self.c})
        self.checkOperator(left - right, {self.a})
class TestButlerUtils(unittest.TestCase):
    """Tests of the simple utilities."""

    def testTypeNames(self):
        """Check ``getFullTypeName`` for classes and for an instance."""
        # Check types and also an object
        tests = [(Formatter, "lsst.daf.butler.core.formatter.Formatter"),
                 (int, "int"),
                 (StorageClass, "lsst.daf.butler.core.storageClass.StorageClass"),
                 (StorageClass(None), "lsst.daf.butler.core.storageClass.StorageClass"),
                 (Registry, "lsst.daf.butler.registry.Registry")]

        for item, typeName in tests:
            # subTest keeps going after a failure and reports which entry
            # of the table was at fault.
            with self.subTest(item=repr(item), typeName=typeName):
                self.assertEqual(getFullTypeName(item), typeName)

    def testIsplit(self):
        """Check that ``isplit`` yields the same pieces as ``split``."""
        # Test compatibility with str.split
        seps = ("\n", " ", "d")
        input_str = "ab\ncd ef\n"

        for sep in seps:
            # Exercise both the str and the bytes form of the same text.
            # (Named ``text`` so the ``input`` builtin is not shadowed.)
            for text in (input_str, input_str.encode()):
                # The separator must be of the same type as the text.
                test_sep = sep.encode() if isinstance(text, bytes) else sep
                with self.subTest(text=text, sep=test_sep):
                    isp = list(isplit(text, sep=test_sep))
                    ssp = text.split(test_sep)
                    self.assertEqual(isp, ssp)
class FindFileResourcesTestCase(unittest.TestCase):
    """Tests of ``findFileResources`` against the ``config`` tree in TESTDIR.

    Results are compared as ``Counter`` objects so that ordering is ignored
    while duplicated entries would still be detected.
    """

    def test_getSingleFile(self):
        """Test getting a file by its file name."""
        # Build the path from components, like every other path in this class.
        filename = os.path.join(TESTDIR, "config", "basic", "butler.yaml")
        self.assertEqual([filename], findFileResources([filename]))

    def test_getAllFiles(self):
        """Test getting all the files by not passing a regex."""
        # The recursive glob also returns directories, so keep only files.
        expected = Counter(p for p in glob(os.path.join(TESTDIR, "config", "**"), recursive=True)
                           if os.path.isfile(p))
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config")]))
        self.assertEqual(expected, files)

    def test_getAllFilesRegex(self):
        """Test getting all the files with a regex-specified file ending."""
        expected = Counter(glob(os.path.join(TESTDIR, "config", "**", "*.yaml"), recursive=True))
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config")], r"\.yaml\b"))
        self.assertEqual(expected, files)

    def test_multipleInputs(self):
        """Test specifying more than one location in which to find files."""
        expected = Counter(glob(os.path.join(TESTDIR, "config", "basic", "**", "*.yaml"), recursive=True))
        expected.update(glob(os.path.join(TESTDIR, "config", "templates", "**", "*.yaml"), recursive=True))
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config", "basic"),
                                           os.path.join(TESTDIR, "config", "templates")],
                                          r"\.yaml\b"))
        self.assertEqual(expected, files)
class GlobToRegexTestCase(unittest.TestCase):
    """Tests of ``globToRegex``."""

    def testStarInList(self):
        """Test that if one of the items in the expression list is a star
        (stand-alone) then ``...`` is returned (which implies no restrictions)
        """
        self.assertIs(globToRegex(["foo", "*", "bar"]), ...)

    def testGlobList(self):
        """Test that a list of glob strings converts as expected to a regex and
        returns in the expected list.
        """
        # ``re.fullmatch`` returns a match object or `None`, so positive and
        # negative cases are checked symmetrically with assertIsNotNone /
        # assertIsNone.

        # test an absolute string
        patterns = globToRegex(["bar"])
        self.assertEqual(len(patterns), 1)
        self.assertIsNotNone(re.fullmatch(patterns[0], "bar"))
        self.assertIsNone(re.fullmatch(patterns[0], "boz"))

        # test leading & trailing wildcard in multiple patterns
        patterns = globToRegex(["ba*", "*.fits"])
        self.assertEqual(len(patterns), 2)
        # check the "ba*" pattern:
        self.assertIsNotNone(re.fullmatch(patterns[0], "bar"))
        self.assertIsNotNone(re.fullmatch(patterns[0], "baz"))
        self.assertIsNone(re.fullmatch(patterns[0], "boz.fits"))
        # check the "*.fits" pattern:
        self.assertIsNotNone(re.fullmatch(patterns[1], "bar.fits"))
        self.assertIsNotNone(re.fullmatch(patterns[1], "boz.fits"))
        self.assertIsNone(re.fullmatch(patterns[1], "boz.hdf5"))
class TimerTestCase(unittest.TestCase):
    """Tests of the ``time_this`` context manager."""

    def testTimer(self):
        """Check the logger name, level and message text of the record
        emitted by ``time_this`` for several argument combinations.
        """
        # No arguments at all.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this():
                pass
        self.assertEqual(cm.records[0].name, "timer")
        self.assertEqual(cm.records[0].levelname, "DEBUG")
        self.assertEqual(cm.records[0].filename, "test_utils.py")

        # Explicitly no prefix.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(prefix=None):
                pass
        self.assertEqual(cm.records[0].name, "root")
        self.assertEqual(cm.records[0].levelname, "DEBUG")
        self.assertIn("Took", cm.output[0])
        self.assertEqual(cm.records[0].filename, "test_utils.py")

        # Change logging level
        with self.assertLogs(level="INFO") as cm:
            with time_this(level=logging.INFO, prefix=None):
                pass
        self.assertEqual(cm.records[0].name, "root")
        self.assertIn("Took", cm.output[0])
        self.assertIn("seconds", cm.output[0])

        # Use a new logger with a message.
        msg = "Test message %d"
        test_num = 42
        logname = "test"
        with self.assertLogs(level="DEBUG") as cm:
            # Pass ``test_num`` itself so the expected text checked below is
            # derived from the same value that is handed to the timer.
            with time_this(log=logging.getLogger(logname),
                           msg=msg, args=(test_num,), prefix=None):
                pass
        self.assertEqual(cm.records[0].name, logname)
        self.assertIn("Took", cm.output[0])
        self.assertIn(msg % test_num, cm.output[0])

        # Prefix the logger.
        prefix = "prefix"
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(prefix=prefix):
                pass
        self.assertEqual(cm.records[0].name, prefix)
        self.assertIn("Took", cm.output[0])

        # Prefix explicit logger.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(log=logging.getLogger(logname),
                           prefix=prefix):
                pass
        self.assertEqual(cm.records[0].name, f"{prefix}.{logname}")

        # Trigger a problem.
        # The exception must propagate out of the timer, and a record at
        # ERROR level is expected on the prefixed logger.
        with self.assertLogs(level="ERROR") as cm:
            with self.assertRaises(RuntimeError):
                with time_this(log=logging.getLogger(logname),
                               prefix=prefix):
                    raise RuntimeError("A problem")
        self.assertEqual(cm.records[0].name, f"{prefix}.{logname}")
        self.assertEqual(cm.records[0].levelname, "ERROR")
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()