Coverage for tests/test_utils.py: 23%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from collections import Counter, namedtuple
23from glob import glob
24import os
25import re
26import unittest
27import logging
29from lsst.daf.butler.core.utils import findFileResources, getFullTypeName, globToRegex, iterable, Singleton
30from lsst.daf.butler import Formatter, Registry
31from lsst.daf.butler import NamedKeyDict, NamedValueSet, StorageClass
32from lsst.daf.butler.core.utils import isplit, time_this
# Directory that holds this test module; the test-data paths below
# (e.g. the "config" tree) are resolved relative to it.
TESTDIR = os.path.split(__file__)[0]
class IterableTestCase(unittest.TestCase):
    """Exercise the `iterable` helper on scalars, strings and sequences.
    """

    def testNonIterable(self):
        # A scalar comes back wrapped so it can be iterated exactly once.
        result = list(iterable(0))
        self.assertEqual(result, [0])

    def testString(self):
        # A string is treated as a single item, not as a sequence of chars.
        result = list(iterable("hello"))
        self.assertEqual(result, ["hello"])

    def testIterableNoString(self):
        # Real (non-string) iterables are passed through element by element.
        for values in ([0, 1, 2], ["hello", "world"]):
            self.assertEqual(list(iterable(values)), values)
class SingletonTestCase(unittest.TestCase):
    """Tests of the Singleton metaclass"""

    class IsSingleton(metaclass=Singleton):
        def __init__(self):
            self.data = {}
            self.id = 0

    class IsBadSingleton(IsSingleton):
        def __init__(self, arg):
            """Take a constructor argument, which the metaclass must reject."""
            self.arg = arg

    class IsSingletonSubclass(IsSingleton):
        def __init__(self):
            super().__init__()

    def testSingleton(self):
        singletonCls = SingletonTestCase.IsSingleton
        first, second = singletonCls(), singletonCls()

        # Mutating state through one handle must be visible via the other.
        first.data["test"] = 52
        self.assertEqual(first.data, second.data)
        second.id += 1
        self.assertEqual(first.id, second.id)

        # A subclass gets its own, independent instance.
        subInstance = SingletonTestCase.IsSingletonSubclass()
        self.assertNotEqual(first.id, subInstance.id)

        # Passing constructor arguments is an error.
        badCls = SingletonTestCase.IsBadSingleton
        with self.assertRaises(TypeError):
            badCls(52)
class NamedKeyDictTest(unittest.TestCase):
    """Tests for `NamedKeyDict`, a mapping whose keys carry a ``name``."""

    def setUp(self):
        self.TestTuple = namedtuple("TestTuple", ("name", "id"))
        self.a = self.TestTuple(name="a", id=1)
        self.b = self.TestTuple(name="b", id=2)
        self.dictionary = {self.a: 10, self.b: 20}
        self.names = {key.name for key in (self.a, self.b)}

    def check(self, nkd):
        """Assert that ``nkd`` mirrors the contents of ``self.dictionary``."""
        expectedValues = list(self.dictionary.values())
        expectedItems = list(self.dictionary.items())
        self.assertEqual(len(nkd), 2)
        self.assertEqual(nkd.names, self.names)
        self.assertEqual(nkd.keys(), self.dictionary.keys())
        self.assertEqual(list(nkd.values()), expectedValues)
        self.assertEqual(list(nkd.items()), expectedItems)
        self.assertEqual(list(nkd.byName().values()), expectedValues)
        self.assertEqual(list(nkd.byName().keys()), list(nkd.names))

    def testConstruction(self):
        # Both a mapping and an iterator of (key, value) pairs are accepted.
        for source in (self.dictionary, iter(self.dictionary.items())):
            self.check(NamedKeyDict(source))

    def testDuplicateNameConstruction(self):
        # Two distinct keys sharing the name "a" must be rejected.
        self.dictionary[self.TestTuple(name="a", id=3)] = 30
        for source in (self.dictionary, iter(self.dictionary.items())):
            with self.assertRaises(AssertionError):
                NamedKeyDict(source)

    def testNoNameConstruction(self):
        # A plain string key has no ``name`` attribute.
        self.dictionary["a"] = 30
        for source in (self.dictionary, iter(self.dictionary.items())):
            with self.assertRaises(AttributeError):
                NamedKeyDict(source)

    def testGetItem(self):
        nkd = NamedKeyDict(self.dictionary)
        # Lookup works both by name and by the key object itself.
        for key, value in ((self.a, 10), (self.b, 20)):
            self.assertEqual(nkd[key.name], value)
            self.assertEqual(nkd[key], value)
        self.assertIn("a", nkd)
        self.assertIn(self.b, nkd)

    def testSetItem(self):
        nkd = NamedKeyDict(self.dictionary)
        nkd[self.a] = 30
        self.assertEqual(nkd["a"], 30)
        nkd["b"] = 40
        self.assertEqual(nkd[self.b], 40)
        # Unknown names cannot be assigned by name alone.
        with self.assertRaises(KeyError):
            nkd["c"] = 50
        # A different key object reusing an existing name is rejected.
        with self.assertRaises(AssertionError):
            nkd[self.TestTuple("a", 3)] = 60

    def testDelItem(self):
        nkd = NamedKeyDict(self.dictionary)
        del nkd[self.a]
        self.assertNotIn("a", nkd)
        del nkd["b"]
        self.assertNotIn(self.b, nkd)
        self.assertEqual(len(nkd), 0)

    def testIter(self):
        self.assertEqual(set(NamedKeyDict(self.dictionary)), set(self.dictionary))

    def testEquality(self):
        nkd = NamedKeyDict(self.dictionary)
        # Equality must hold in both directions.
        self.assertEqual(nkd, self.dictionary)
        self.assertEqual(self.dictionary, nkd)
class NamedValueSetTest(unittest.TestCase):
    """Tests for `NamedValueSet`, a set of objects that carry a ``name``."""

    def setUp(self):
        self.TestTuple = namedtuple("TestTuple", ("name", "id"))
        self.a = self.TestTuple(name="a", id=1)
        self.b = self.TestTuple(name="b", id=2)
        self.c = self.TestTuple(name="c", id=3)

    def testConstruction(self):
        expectedMembers = {self.a, self.b}
        expectedItems = [(self.a.name, self.a), (self.b.name, self.b)]
        # Construct from a set and from a tuple; check mutable and frozen forms.
        for arg in ({self.a, self.b}, (self.a, self.b)):
            for nvs in (NamedValueSet(arg), NamedValueSet(arg).freeze()):
                self.assertEqual(len(nvs), 2)
                self.assertEqual(nvs.names, {"a", "b"})
                self.assertCountEqual(nvs, expectedMembers)
                self.assertCountEqual(nvs.asMapping().items(), expectedItems)

    def testNoNameConstruction(self):
        # A bare string has no ``name`` attribute.
        with self.assertRaises(AttributeError):
            NamedValueSet([self.a, "a"])

    def testGetItem(self):
        nvs = NamedValueSet({self.a, self.b, self.c})
        # Lookup works both by name and by the element itself.
        for element in (self.a, self.b):
            self.assertEqual(nvs[element.name], element)
            self.assertEqual(nvs[element], element)
        self.assertIn("a", nvs)
        self.assertIn(self.b, nvs)

    def testEquality(self):
        members = {self.a, self.b, self.c}
        nvs = NamedValueSet(members)
        # Equality with a plain set must hold in both directions.
        self.assertEqual(nvs, members)
        self.assertEqual(members, nvs)

    def checkOperator(self, result, expected):
        """Assert ``result`` is a `NamedValueSet` equal to ``expected``."""
        self.assertIsInstance(result, NamedValueSet)
        self.assertEqual(result, expected)

    def testOperators(self):
        ab = NamedValueSet({self.a, self.b})
        bc = NamedValueSet({self.b, self.c})
        cases = [
            (ab & bc, {self.b}),
            (ab | bc, {self.a, self.b, self.c}),
            (ab ^ bc, {self.a, self.c}),
            (ab - bc, {self.a}),
        ]
        for result, expected in cases:
            self.checkOperator(result, expected)
class TestButlerUtils(unittest.TestCase):
    """Tests of the simple utilities."""

    def testTypeNames(self):
        # Check types and also an object (a StorageClass instance).
        tests = [(Formatter, "lsst.daf.butler.core.formatter.Formatter"),
                 (int, "int"),
                 (StorageClass, "lsst.daf.butler.core.storageClass.StorageClass"),
                 (StorageClass(None), "lsst.daf.butler.core.storageClass.StorageClass"),
                 (Registry, "lsst.daf.butler.registry.Registry")]

        for item, typeName in tests:
            with self.subTest(typeName=typeName):
                self.assertEqual(getFullTypeName(item), typeName)

    def testIsplit(self):
        """`isplit` must yield the same pieces as ``str.split`` /
        ``bytes.split`` for both text and byte input.
        """
        seps = ("\n", " ", "d")
        input_str = "ab\ncd ef\n"

        for sep in seps:
            # Named ``data`` rather than ``input`` to avoid shadowing the builtin.
            for data in (input_str, input_str.encode()):
                # The separator must match the type of the data being split.
                test_sep = sep.encode() if isinstance(data, bytes) else sep
                with self.subTest(sep=test_sep):
                    isp = list(isplit(data, sep=test_sep))
                    ssp = data.split(test_sep)
                    self.assertEqual(isp, ssp)
class FindFileResourcesTestCase(unittest.TestCase):
    """Tests for `findFileResources`, given file and/or directory paths."""

    def test_getSingleFile(self):
        """Test getting a file by its file name."""
        # Build the path component-wise, consistent with the other tests.
        filename = os.path.join(TESTDIR, "config", "basic", "butler.yaml")
        self.assertEqual([filename], findFileResources([filename]))

    def test_getAllFiles(self):
        """Test getting all the files by not passing a regex."""
        expected = Counter([p for p in glob(os.path.join(TESTDIR, "config", "**"), recursive=True)
                            if os.path.isfile(p)])
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config")]))
        self.assertEqual(expected, files)

    def test_getAllFilesRegex(self):
        """Test getting all the files with a regex-specified file ending."""
        expected = Counter(glob(os.path.join(TESTDIR, "config", "**", "*.yaml"), recursive=True))
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config")], r"\.yaml\b"))
        self.assertEqual(expected, files)

    def test_multipleInputs(self):
        """Test specifying more than one location in which to find files."""
        expected = Counter(glob(os.path.join(TESTDIR, "config", "basic", "**", "*.yaml"), recursive=True))
        expected.update(glob(os.path.join(TESTDIR, "config", "templates", "**", "*.yaml"), recursive=True))
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config", "basic"),
                                           os.path.join(TESTDIR, "config", "templates")],
                                          r"\.yaml\b"))
        self.assertEqual(expected, files)
class GlobToRegexTestCase(unittest.TestCase):
    """Tests for `globToRegex`."""

    def testStarInList(self):
        """Test that if one of the items in the expression list is a
        stand-alone star (``*``) then ``...`` is returned, which implies no
        restrictions.
        """
        self.assertIs(globToRegex(["foo", "*", "bar"]), ...)

    def testGlobList(self):
        """Test that a list of glob strings converts as expected to a list of
        regular expressions, one per input pattern and in the same order.
        """
        # test an absolute string
        patterns = globToRegex(["bar"])
        self.assertEqual(len(patterns), 1)
        self.assertIsNotNone(re.fullmatch(patterns[0], "bar"))
        self.assertIsNone(re.fullmatch(patterns[0], "boz"))

        # test leading & trailing wildcard in multiple patterns
        patterns = globToRegex(["ba*", "*.fits"])
        self.assertEqual(len(patterns), 2)
        # check the "ba*" pattern:
        self.assertIsNotNone(re.fullmatch(patterns[0], "bar"))
        self.assertIsNotNone(re.fullmatch(patterns[0], "baz"))
        self.assertIsNone(re.fullmatch(patterns[0], "boz.fits"))
        # check the "*.fits" pattern:
        self.assertIsNotNone(re.fullmatch(patterns[1], "bar.fits"))
        self.assertIsNotNone(re.fullmatch(patterns[1], "boz.fits"))
        self.assertIsNone(re.fullmatch(patterns[1], "boz.hdf5"))
class TimerTestCase(unittest.TestCase):
    """Tests for the `time_this` timing context manager."""

    def testTimer(self):
        # Defaults: a DEBUG record on the "timer" logger, attributed to
        # this test file.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this():
                pass
        self.assertEqual(cm.records[0].name, "timer")
        self.assertEqual(cm.records[0].levelname, "DEBUG")
        self.assertEqual(cm.records[0].filename, "test_utils.py")

        # With no prefix the root logger is used.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(prefix=None):
                pass
        self.assertEqual(cm.records[0].name, "root")
        self.assertEqual(cm.records[0].levelname, "DEBUG")
        self.assertIn("Took", cm.output[0])
        self.assertEqual(cm.records[0].filename, "test_utils.py")

        # Change logging level
        with self.assertLogs(level="INFO") as cm:
            with time_this(level=logging.INFO, prefix=None):
                pass
        self.assertEqual(cm.records[0].name, "root")
        self.assertIn("Took", cm.output[0])
        self.assertIn("seconds", cm.output[0])

        # Use a new logger with a message.
        msg = "Test message %d"
        test_num = 42
        logname = "test"
        with self.assertLogs(level="DEBUG") as cm:
            # Pass ``test_num`` rather than a literal so the expected-message
            # check below cannot silently drift from the value logged.
            with time_this(log=logging.getLogger(logname),
                           msg=msg, args=(test_num,), prefix=None):
                pass
        self.assertEqual(cm.records[0].name, logname)
        self.assertIn("Took", cm.output[0])
        self.assertIn(msg % test_num, cm.output[0])

        # Prefix the logger.
        prefix = "prefix"
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(prefix=prefix):
                pass
        self.assertEqual(cm.records[0].name, prefix)
        self.assertIn("Took", cm.output[0])

        # Prefix explicit logger.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(log=logging.getLogger(logname),
                           prefix=prefix):
                pass
        self.assertEqual(cm.records[0].name, f"{prefix}.{logname}")

        # Trigger a problem: the exception propagates and an ERROR record
        # is emitted on the prefixed logger.
        with self.assertLogs(level="ERROR") as cm:
            with self.assertRaises(RuntimeError):
                with time_this(log=logging.getLogger(logname),
                               prefix=prefix):
                    raise RuntimeError("A problem")
        self.assertEqual(cm.records[0].name, f"{prefix}.{logname}")
        self.assertEqual(cm.records[0].levelname, "ERROR")
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()