Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from collections import Counter, namedtuple 

23from glob import glob 

24import os 

25import re 

26import unittest 

27import logging 

28 

29from lsst.daf.butler.core.utils import findFileResources, getFullTypeName, globToRegex, iterable, Singleton 

30from lsst.daf.butler import Formatter, Registry 

31from lsst.daf.butler import NamedKeyDict, NamedValueSet, StorageClass 

32from lsst.daf.butler.core.utils import isplit, time_this 

33 

34TESTDIR = os.path.dirname(__file__) 

35 

36 

class IterableTestCase(unittest.TestCase):
    """Exercise the `iterable` helper on a scalar, a string and sequences.
    """

    def testNonIterable(self):
        # A scalar comes back as a single-element sequence.
        result = list(iterable(0))
        self.assertEqual(result, [0])

    def testString(self):
        # A string is yielded whole, not split into characters.
        result = list(iterable("hello"))
        self.assertEqual(result, ["hello"])

    def testIterableNoString(self):
        # Non-string iterables have their elements passed through in order.
        for values in ([0, 1, 2], ["hello", "world"]):
            self.assertEqual(list(iterable(values)), values)

50 

51 

class SingletonTestCase(unittest.TestCase):
    """Check the instance-sharing behavior of the `Singleton` metaclass."""

    class IsSingleton(metaclass=Singleton):
        # Minimal singleton holding some mutable state to compare.
        def __init__(self):
            self.data = {}
            self.id = 0

    class IsBadSingleton(IsSingleton):
        # Declares a constructor argument, which the test expects to fail.
        def __init__(self, arg):
            """Constructor that (invalidly) requires an argument."""
            self.arg = arg

    class IsSingletonSubclass(IsSingleton):
        # Subclass used to check that it does not share the base instance.
        def __init__(self):
            super().__init__()

    def testSingleton(self):
        first = self.IsSingleton()
        second = self.IsSingleton()

        # State written through one handle must be visible through the other.
        first.data["test"] = 52
        self.assertEqual(first.data, second.data)
        second.id += 1
        self.assertEqual(first.id, second.id)

        # The subclass is expected not to share the base class's instance.
        derived = self.IsSingletonSubclass()
        self.assertNotEqual(first.id, derived.id)

        # Passing a constructor argument is expected to raise TypeError.
        with self.assertRaises(TypeError):
            self.IsBadSingleton(52)

84 

85 

class NamedKeyDictTest(unittest.TestCase):
    """Tests for `NamedKeyDict`, whose entries can be addressed either by a
    key object or by that key's ``name``.
    """

    def setUp(self):
        self.TestTuple = namedtuple("TestTuple", ("name", "id"))
        self.a = self.TestTuple(name="a", id=1)
        self.b = self.TestTuple(name="b", id=2)
        self.dictionary = {self.a: 10, self.b: 20}
        self.names = {key.name for key in (self.a, self.b)}

    def check(self, nkd):
        """Assert that ``nkd`` holds exactly the contents of
        ``self.dictionary``, in the same order.
        """
        reference = self.dictionary
        expectedValues = list(reference.values())
        self.assertEqual(len(nkd), 2)
        self.assertEqual(nkd.names, self.names)
        self.assertEqual(nkd.keys(), reference.keys())
        self.assertEqual(list(nkd.values()), expectedValues)
        self.assertEqual(list(nkd.items()), list(reference.items()))
        self.assertEqual(list(nkd.byName().values()), expectedValues)
        self.assertEqual(list(nkd.byName().keys()), list(nkd.names))

    def testConstruction(self):
        # Build both from a mapping and from an iterator of (key, value) pairs.
        for source in (self.dictionary, iter(self.dictionary.items())):
            self.check(NamedKeyDict(source))

    def testDuplicateNameConstruction(self):
        # A second key that reuses the name "a" must be rejected.
        self.dictionary[self.TestTuple(name="a", id=3)] = 30
        for source in (self.dictionary, iter(self.dictionary.items())):
            with self.assertRaises(AssertionError):
                NamedKeyDict(source)

    def testNoNameConstruction(self):
        # A plain string key has no ``name`` attribute.
        self.dictionary["a"] = 30
        for source in (self.dictionary, iter(self.dictionary.items())):
            with self.assertRaises(AttributeError):
                NamedKeyDict(source)

    def testGetItem(self):
        nkd = NamedKeyDict(self.dictionary)
        # Every value is reachable both by name and by key object.
        for key, value in self.dictionary.items():
            self.assertEqual(nkd[key.name], value)
            self.assertEqual(nkd[key], value)
        self.assertIn("a", nkd)
        self.assertIn(self.b, nkd)

    def testSetItem(self):
        nkd = NamedKeyDict(self.dictionary)
        # Writing through a key is visible by name, and the reverse.
        nkd[self.a] = 30
        self.assertEqual(nkd["a"], 30)
        nkd["b"] = 40
        self.assertEqual(nkd[self.b], 40)
        # An unknown name cannot be used to add an entry.
        with self.assertRaises(KeyError):
            nkd["c"] = 50
        # A different key whose name clashes with an existing one is refused.
        with self.assertRaises(AssertionError):
            nkd[self.TestTuple("a", 3)] = 60

    def testDelItem(self):
        nkd = NamedKeyDict(self.dictionary)
        # Delete once by key object and once by name.
        del nkd[self.a]
        self.assertNotIn("a", nkd)
        del nkd["b"]
        self.assertNotIn(self.b, nkd)
        self.assertEqual(len(nkd), 0)

    def testIter(self):
        iterated = set(iter(NamedKeyDict(self.dictionary)))
        self.assertEqual(iterated, set(self.dictionary))

    def testEquality(self):
        nkd = NamedKeyDict(self.dictionary)
        # Equality with a plain dict must hold in both directions.
        for lhs, rhs in ((nkd, self.dictionary), (self.dictionary, nkd)):
            self.assertEqual(lhs, rhs)

157 

158 

class NamedValueSetTest(unittest.TestCase):
    """Tests for `NamedValueSet`, a set whose members can also be looked up
    by their ``name``.
    """

    def setUp(self):
        self.TestTuple = namedtuple("TestTuple", ("name", "id"))
        self.a, self.b, self.c = (
            self.TestTuple(name=label, id=index) for index, label in enumerate("abc", start=1)
        )

    def testConstruction(self):
        expectedMapping = [(self.a.name, self.a), (self.b.name, self.b)]
        # Accept both a set and a tuple; check the mutable and frozen forms.
        for members in ({self.a, self.b}, (self.a, self.b)):
            mutable = NamedValueSet(members)
            frozen = NamedValueSet(members).freeze()
            for nvs in (mutable, frozen):
                self.assertEqual(len(nvs), 2)
                self.assertEqual(nvs.names, {"a", "b"})
                self.assertCountEqual(nvs, {self.a, self.b})
                self.assertCountEqual(nvs.asMapping().items(), expectedMapping)

    def testNoNameConstruction(self):
        # A bare string has no ``name`` attribute.
        with self.assertRaises(AttributeError):
            NamedValueSet([self.a, "a"])

    def testGetItem(self):
        nvs = NamedValueSet({self.a, self.b, self.c})
        # Lookup by name and by the member itself both return the member.
        for member in (self.a, self.b):
            self.assertEqual(nvs[member.name], member)
            self.assertEqual(nvs[member], member)
        self.assertIn("a", nvs)
        self.assertIn(self.b, nvs)

    def testEquality(self):
        members = {self.a, self.b, self.c}
        nvs = NamedValueSet(members)
        # Comparison with a plain set must work from either side.
        self.assertEqual(nvs, members)
        self.assertEqual(members, nvs)

    def checkOperator(self, result, expected):
        """Assert that a set operation returned a `NamedValueSet` equal to
        ``expected``.
        """
        self.assertIsInstance(result, NamedValueSet)
        self.assertEqual(result, expected)

    def testOperators(self):
        ab = NamedValueSet({self.a, self.b})
        bc = NamedValueSet({self.b, self.c})
        cases = [
            (ab & bc, {self.b}),
            (ab | bc, {self.a, self.b, self.c}),
            (ab ^ bc, {self.a, self.c}),
            (ab - bc, {self.a}),
        ]
        for result, expected in cases:
            self.checkOperator(result, expected)

205 

206 

class TestButlerUtils(unittest.TestCase):
    """Tests of the simple utilities."""

    def testTypeNames(self):
        """Check `getFullTypeName` on classes and on an instance."""
        # Each entry pairs a type (or an instance) with its expected
        # fully-qualified name.
        tests = [(Formatter, "lsst.daf.butler.core.formatter.Formatter"),
                 (int, "int"),
                 (StorageClass, "lsst.daf.butler.core.storageClass.StorageClass"),
                 (StorageClass(None), "lsst.daf.butler.core.storageClass.StorageClass"),
                 (Registry, "lsst.daf.butler.registry.Registry")]

        for item, typeName in tests:
            self.assertEqual(getFullTypeName(item), typeName)

    def testIsplit(self):
        """Check that `isplit` yields the same pieces as ``split`` for both
        `str` and `bytes` input.
        """
        seps = ("\n", " ", "d")
        input_str = "ab\ncd ef\n"

        for sep in seps:
            # ``text`` rather than ``input`` so the builtin is not shadowed.
            for text in (input_str, input_str.encode()):
                # The separator must have the same type as the input.
                test_sep = sep.encode() if isinstance(text, bytes) else sep
                # subTest reports which separator/type combination failed.
                with self.subTest(sep=sep, type=type(text).__name__):
                    isp = list(isplit(text, sep=test_sep))
                    ssp = text.split(test_sep)
                    self.assertEqual(isp, ssp)

232 

233 

class FindFileResourcesTestCase(unittest.TestCase):
    """Tests for `findFileResources`, comparing its results with `glob`."""

    def test_getSingleFile(self):
        """Test getting a file by its file name."""
        # Build the path with os.path.join, as the other tests do, instead of
        # hard-coding "/" separators.
        filename = os.path.join(TESTDIR, "config", "basic", "butler.yaml")
        self.assertEqual([filename], findFileResources([filename]))

    def test_getAllFiles(self):
        """Test getting all the files by not passing a regex."""
        expected = Counter([p for p in glob(os.path.join(TESTDIR, "config", "**"), recursive=True)
                            if os.path.isfile(p)])
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        # Counter comparison also catches files reported more than once.
        files = Counter(findFileResources([os.path.join(TESTDIR, "config")]))
        self.assertEqual(expected, files)

    def test_getAllFilesRegex(self):
        """Test getting all the files with a regex-specified file ending."""
        expected = Counter(glob(os.path.join(TESTDIR, "config", "**", "*.yaml"), recursive=True))
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config")], r"\.yaml\b"))
        self.assertEqual(expected, files)

    def test_multipleInputs(self):
        """Test specifying more than one location in which to find files."""
        expected = Counter(glob(os.path.join(TESTDIR, "config", "basic", "**", "*.yaml"), recursive=True))
        expected.update(glob(os.path.join(TESTDIR, "config", "templates", "**", "*.yaml"), recursive=True))
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config", "basic"),
                                           os.path.join(TESTDIR, "config", "templates")],
                                          r"\.yaml\b"))
        self.assertEqual(expected, files)

265 

266 

class GlobToRegexTestCase(unittest.TestCase):
    """Tests for `globToRegex`."""

    def testStarInList(self):
        """Test that if one of the items in the expression list is a
        stand-alone star then ``...`` is returned (which implies no
        restrictions).
        """
        self.assertIs(globToRegex(["foo", "*", "bar"]), ...)

    def testGlobList(self):
        """Test that a list of glob strings is converted to a list of regular
        expressions, one per glob and in the same order.
        """
        # A glob with no wildcards matches only the literal string.
        patterns = globToRegex(["bar"])
        self.assertEqual(len(patterns), 1)
        self.assertIsNotNone(re.fullmatch(patterns[0], "bar"))
        self.assertIsNone(re.fullmatch(patterns[0], "boz"))

        # Leading and trailing wildcards, one per pattern.
        patterns = globToRegex(["ba*", "*.fits"])
        self.assertEqual(len(patterns), 2)
        # Check the "ba*" pattern.
        self.assertIsNotNone(re.fullmatch(patterns[0], "bar"))
        self.assertIsNotNone(re.fullmatch(patterns[0], "baz"))
        self.assertIsNone(re.fullmatch(patterns[0], "boz.fits"))
        # Check the "*.fits" pattern.
        self.assertIsNotNone(re.fullmatch(patterns[1], "bar.fits"))
        self.assertIsNotNone(re.fullmatch(patterns[1], "boz.fits"))
        self.assertIsNone(re.fullmatch(patterns[1], "boz.hdf5"))

296 

297 

class TimerTestCase(unittest.TestCase):
    """Tests for the `time_this` context manager."""

    def testTimer(self):
        """Check the logger name, level and message emitted by `time_this`
        for different combinations of ``log``, ``prefix``, ``level`` and
        ``msg``/``args``.
        """
        # The records are expected to name this test module as their source
        # file; derive it rather than hard-coding "test_utils.py".
        thisFile = os.path.basename(__file__)

        # Default arguments.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this():
                pass
        self.assertEqual(cm.records[0].name, "timer")
        self.assertEqual(cm.records[0].levelname, "DEBUG")
        self.assertEqual(cm.records[0].filename, thisFile)

        # No prefix: the record is expected to come from the root logger.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(prefix=None):
                pass
        self.assertEqual(cm.records[0].name, "root")
        self.assertEqual(cm.records[0].levelname, "DEBUG")
        self.assertIn("Took", cm.output[0])
        self.assertEqual(cm.records[0].filename, thisFile)

        # Change logging level.
        with self.assertLogs(level="INFO") as cm:
            with time_this(level=logging.INFO, prefix=None):
                pass
        self.assertEqual(cm.records[0].name, "root")
        self.assertEqual(cm.records[0].levelname, "INFO")
        self.assertIn("Took", cm.output[0])
        self.assertIn("seconds", cm.output[0])

        # Use a new logger with a message.
        msg = "Test message %d"
        test_num = 42
        logname = "test"
        with self.assertLogs(level="DEBUG") as cm:
            # Pass test_num (not a duplicated literal) so the expectation
            # below formats the same value that was logged.
            with time_this(log=logging.getLogger(logname),
                           msg=msg, args=(test_num,), prefix=None):
                pass
        self.assertEqual(cm.records[0].name, logname)
        self.assertIn("Took", cm.output[0])
        self.assertIn(msg % test_num, cm.output[0])

        # Prefix the logger.
        prefix = "prefix"
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(prefix=prefix):
                pass
        self.assertEqual(cm.records[0].name, prefix)
        self.assertIn("Took", cm.output[0])

        # Prefix explicit logger.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(log=logging.getLogger(logname),
                           prefix=prefix):
                pass
        self.assertEqual(cm.records[0].name, f"{prefix}.{logname}")

        # Trigger a problem: the exception propagates and an ERROR record
        # is emitted on the prefixed logger.
        with self.assertLogs(level="ERROR") as cm:
            with self.assertRaises(RuntimeError):
                with time_this(log=logging.getLogger(logname),
                               prefix=prefix):
                    raise RuntimeError("A problem")
        self.assertEqual(cm.records[0].name, f"{prefix}.{logname}")
        self.assertEqual(cm.records[0].levelname, "ERROR")

359 

360 

# Run the tests when this module is executed directly.
if __name__ == "__main__":
    unittest.main()