Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from collections import Counter, namedtuple 

23from glob import glob 

24import os 

25import re 

26import unittest 

27import logging 

28 

29from lsst.daf.butler.core.utils import findFileResources, getFullTypeName, globToRegex, iterable, Singleton 

30from lsst.daf.butler import Formatter, Registry 

31from lsst.daf.butler import NamedKeyDict, NamedValueSet, StorageClass 

32from lsst.daf.butler.core.utils import isplit, time_this, chunk_iterable 

33 

# Directory holding this test module; test data paths are built relative to it.
TESTDIR = os.path.split(__file__)[0]

35 

36 

class IterableTestCase(unittest.TestCase):
    """Tests for the `iterable` and `chunk_iterable` helpers.
    """

    def testNonIterable(self):
        # A scalar is wrapped so that it iterates as a single item.
        wrapped = iterable(0)
        self.assertEqual(list(wrapped), [0])

    def testString(self):
        # A string is treated as one item, not as a sequence of characters.
        wrapped = iterable("hello")
        self.assertEqual(list(wrapped), ["hello"])

    def testIterableNoString(self):
        # Non-string iterables are passed through element by element.
        cases = (
            ([0, 1, 2], [0, 1, 2]),
            (["hello", "world"], ["hello", "world"]),
        )
        for given, expected in cases:
            self.assertEqual(list(iterable(given)), expected)

    def testChunking(self):
        """Chunk iterables."""
        # 101 list items in chunks of 10: ten full chunks plus one holding the
        # last item; concatenating the chunks must reproduce the input.
        source = list(range(101))
        rebuilt = []
        chunk_count = 0
        for piece in chunk_iterable(source, chunk_size=10):
            rebuilt += piece
            chunk_count += 1
        self.assertEqual(rebuilt, source)
        self.assertEqual(chunk_count, 11)

        # Chunking a dict yields its keys. Each key starts with a counter of 1
        # that is decremented whenever the key is seen, so all counters should
        # net to zero; 101 keys in chunks of 45 gives 3 chunks.
        counters = dict.fromkeys(range(101), 1)
        chunk_count = 0
        for piece in chunk_iterable(counters, chunk_size=45):
            for key in piece:
                counters[key] -= 1
            chunk_count += 1
        self.assertEqual(sum(counters.values()), 0)
        self.assertEqual(chunk_count, 3)

72 

73 

class SingletonTestCase(unittest.TestCase):
    """Tests for the `Singleton` metaclass."""

    class IsSingleton(metaclass=Singleton):
        def __init__(self):
            self.id = 0
            self.data = {}

    class IsBadSingleton(IsSingleton):
        def __init__(self, arg):
            """Require an argument, which a singleton is not allowed to take."""
            self.arg = arg

    class IsSingletonSubclass(IsSingleton):
        def __init__(self):
            super().__init__()

    def testSingleton(self):
        owner = SingletonTestCase
        first = owner.IsSingleton()
        second = owner.IsSingleton()

        # Changes made through one handle must be visible through the other.
        first.data["test"] = 52
        self.assertEqual(first.data, second.data)
        second.id += 1
        self.assertEqual(first.id, second.id)

        # The subclass does not share the parent's (now incremented) id.
        derived = owner.IsSingletonSubclass()
        self.assertNotEqual(first.id, derived.id)

        # Constructor arguments are rejected for singleton classes.
        with self.assertRaises(TypeError):
            owner.IsBadSingleton(52)

106 

107 

class NamedKeyDictTest(unittest.TestCase):
    """Tests for `NamedKeyDict`, whose entries can be addressed either by the
    key object itself or by that key's ``name``.
    """

    def setUp(self):
        TestTuple = namedtuple("TestTuple", ("name", "id"))
        self.TestTuple = TestTuple
        self.a, self.b = TestTuple(name="a", id=1), TestTuple(name="b", id=2)
        self.dictionary = {self.a: 10, self.b: 20}
        self.names = {key.name for key in self.dictionary}

    def _constructorInputs(self):
        """Yield the two constructor inputs used by these tests: the mapping
        itself, then (created lazily) an iterator over its items.
        """
        yield self.dictionary
        yield iter(self.dictionary.items())

    def check(self, nkd):
        """Assert that ``nkd`` mirrors ``self.dictionary`` in content and
        order.
        """
        reference = self.dictionary
        self.assertEqual(len(nkd), 2)
        self.assertEqual(nkd.names, self.names)
        self.assertEqual(nkd.keys(), reference.keys())
        self.assertEqual(list(nkd.values()), list(reference.values()))
        self.assertEqual(list(nkd.items()), list(reference.items()))
        # The name-keyed view holds the same values, keyed by name.
        self.assertEqual(list(nkd.byName().values()), list(reference.values()))
        self.assertEqual(list(nkd.byName().keys()), list(nkd.names))

    def testConstruction(self):
        for source in self._constructorInputs():
            self.check(NamedKeyDict(source))

    def testDuplicateNameConstruction(self):
        # Two distinct keys sharing the name "a" must be rejected.
        self.dictionary[self.TestTuple(name="a", id=3)] = 30
        for source in self._constructorInputs():
            with self.assertRaises(AssertionError):
                NamedKeyDict(source)

    def testNoNameConstruction(self):
        # A plain string key has no ``name`` attribute.
        self.dictionary["a"] = 30
        for source in self._constructorInputs():
            with self.assertRaises(AttributeError):
                NamedKeyDict(source)

    def testGetItem(self):
        nkd = NamedKeyDict(self.dictionary)
        lookups = (("a", 10), (self.a, 10), ("b", 20), (self.b, 20))
        for key, value in lookups:
            self.assertEqual(nkd[key], value)
        self.assertIn("a", nkd)
        self.assertIn(self.b, nkd)

    def testSetItem(self):
        nkd = NamedKeyDict(self.dictionary)
        nkd[self.a] = 30
        self.assertEqual(nkd["a"], 30)
        nkd["b"] = 40
        self.assertEqual(nkd[self.b], 40)
        # An unknown name cannot be used to create a new entry.
        with self.assertRaises(KeyError):
            nkd["c"] = 50
        # A new key reusing an existing name is rejected.
        with self.assertRaises(AssertionError):
            nkd[self.TestTuple(name="a", id=3)] = 60

    def testDelItem(self):
        nkd = NamedKeyDict(self.dictionary)
        # Deletion works through either the key object or its name.
        del nkd[self.a]
        self.assertNotIn("a", nkd)
        del nkd["b"]
        self.assertNotIn(self.b, nkd)
        self.assertEqual(len(nkd), 0)

    def testIter(self):
        nkd = NamedKeyDict(self.dictionary)
        self.assertEqual(set(iter(nkd)), set(self.dictionary))

    def testEquality(self):
        nkd = NamedKeyDict(self.dictionary)
        # Compare in both operand orders.
        for left, right in ((nkd, self.dictionary), (self.dictionary, nkd)):
            self.assertEqual(left, right)

179 

180 

class NamedValueSetTest(unittest.TestCase):
    """Tests for `NamedValueSet`, whose elements can also be looked up by
    their ``name``.
    """

    def setUp(self):
        TestTuple = namedtuple("TestTuple", ("name", "id"))
        self.TestTuple = TestTuple
        self.a, self.b, self.c = (
            TestTuple(name=name, id=ident)
            for name, ident in (("a", 1), ("b", 2), ("c", 3))
        )

    def testConstruction(self):
        expected_items = [(self.a.name, self.a), (self.b.name, self.b)]
        # Construct from both a set and a tuple, as mutable and frozen forms.
        for members in ({self.a, self.b}, (self.a, self.b)):
            mutable = NamedValueSet(members)
            frozen = NamedValueSet(members).freeze()
            for nvs in (mutable, frozen):
                self.assertEqual(len(nvs), 2)
                self.assertEqual(nvs.names, {"a", "b"})
                self.assertCountEqual(nvs, {self.a, self.b})
                self.assertCountEqual(nvs.asMapping().items(), expected_items)

    def testNoNameConstruction(self):
        # The string "a" has no ``name`` attribute.
        with self.assertRaises(AttributeError):
            NamedValueSet([self.a, "a"])

    def testGetItem(self):
        nvs = NamedValueSet({self.a, self.b, self.c})
        lookups = (("a", self.a), (self.a, self.a), ("b", self.b), (self.b, self.b))
        for key, element in lookups:
            self.assertEqual(nvs[key], element)
        self.assertIn("a", nvs)
        self.assertIn(self.b, nvs)

    def testEquality(self):
        plain = {self.a, self.b, self.c}
        nvs = NamedValueSet(plain)
        # Compare in both operand orders.
        for left, right in ((nvs, plain), (plain, nvs)):
            self.assertEqual(left, right)

    def checkOperator(self, result, expected):
        """Assert that a set operator returned a `NamedValueSet` equal to
        ``expected``.
        """
        self.assertIsInstance(result, NamedValueSet)
        self.assertEqual(result, expected)

    def testOperators(self):
        ab = NamedValueSet({self.a, self.b})
        bc = NamedValueSet({self.b, self.c})
        cases = (
            (ab & bc, {self.b}),
            (ab | bc, {self.a, self.b, self.c}),
            (ab ^ bc, {self.a, self.c}),
            (ab - bc, {self.a}),
        )
        for result, expected in cases:
            self.checkOperator(result, expected)

227 

228 

class TestButlerUtils(unittest.TestCase):
    """Tests of the simple utilities."""

    def testTypeNames(self):
        # Check types and also an object
        tests = [(Formatter, "lsst.daf.butler.core.formatter.Formatter"),
                 (int, "int"),
                 (StorageClass, "lsst.daf.butler.core.storageClass.StorageClass"),
                 (StorageClass(None), "lsst.daf.butler.core.storageClass.StorageClass"),
                 (Registry, "lsst.daf.butler.registry.Registry")]

        for item, typeName in tests:
            with self.subTest(typeName=typeName):
                self.assertEqual(getFullTypeName(item), typeName)

    def testIsplit(self):
        """Check that `isplit` produces the same pieces as ``str.split`` /
        ``bytes.split`` for several separators.
        """
        seps = ("\n", " ", "d")
        # The trailing newline checks that an empty final piece is produced
        # the same way ``split`` produces it.
        input_str = "ab\ncd ef\n"

        for sep in seps:
            # Exercise both the str form and its bytes encoding; the separator
            # must match the type of the text being split.
            for text in (input_str, input_str.encode()):
                test_sep = sep.encode() if isinstance(text, bytes) else sep
                with self.subTest(sep=test_sep):
                    isp = list(isplit(text, sep=test_sep))
                    ssp = text.split(test_sep)
                    self.assertEqual(isp, ssp)

254 

255 

class FindFileResourcesTestCase(unittest.TestCase):
    """Tests for `findFileResources`."""

    def test_getSingleFile(self):
        """Test getting a file by its file name."""
        # Build the path from separate components, like the other tests here,
        # instead of embedding "/" separators in a single join argument.
        filename = os.path.join(TESTDIR, "config", "basic", "butler.yaml")
        self.assertEqual([filename], findFileResources([filename]))

    def test_getAllFiles(self):
        """Test getting all the files by not passing a regex."""
        expected = Counter(p for p in glob(os.path.join(TESTDIR, "config", "**"), recursive=True)
                           if os.path.isfile(p))
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config")]))
        self.assertEqual(expected, files)

    def test_getAllFilesRegex(self):
        """Test getting all the files with a regex-specified file ending."""
        expected = Counter(glob(os.path.join(TESTDIR, "config", "**", "*.yaml"), recursive=True))
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config")], r"\.yaml\b"))
        self.assertEqual(expected, files)

    def test_multipleInputs(self):
        """Test specifying more than one location in which to find files."""
        expected = Counter(glob(os.path.join(TESTDIR, "config", "basic", "**", "*.yaml"), recursive=True))
        expected.update(glob(os.path.join(TESTDIR, "config", "templates", "**", "*.yaml"), recursive=True))
        self.assertNotEqual(len(expected), 0)  # verify some files were found
        files = Counter(findFileResources([os.path.join(TESTDIR, "config", "basic"),
                                           os.path.join(TESTDIR, "config", "templates")],
                                          r"\.yaml\b"))
        self.assertEqual(expected, files)

287 

288 

class GlobToRegexTestCase(unittest.TestCase):
    """Tests for `globToRegex`."""

    def testStarInList(self):
        """Test that if one of the items in the expression list is a
        stand-alone star then ``...`` is returned (which implies no
        restrictions).
        """
        self.assertIs(globToRegex(["foo", "*", "bar"]), ...)

    def testGlobList(self):
        """Test that a list of glob strings converts as expected to a list of
        regexes, one per glob and in the same order.
        """
        # re.fullmatch returns a Match or None, so positive checks use
        # assertIsNotNone to pair with the assertIsNone negative checks.

        # test a literal string (no wildcards)
        patterns = globToRegex(["bar"])
        self.assertEqual(len(patterns), 1)
        self.assertIsNotNone(re.fullmatch(patterns[0], "bar"))
        self.assertIsNone(re.fullmatch(patterns[0], "boz"))

        # test leading & trailing wildcard in multiple patterns
        patterns = globToRegex(["ba*", "*.fits"])
        self.assertEqual(len(patterns), 2)
        # check the "ba*" pattern:
        self.assertIsNotNone(re.fullmatch(patterns[0], "bar"))
        self.assertIsNotNone(re.fullmatch(patterns[0], "baz"))
        self.assertIsNone(re.fullmatch(patterns[0], "boz.fits"))
        # check the "*.fits" pattern:
        self.assertIsNotNone(re.fullmatch(patterns[1], "bar.fits"))
        self.assertIsNotNone(re.fullmatch(patterns[1], "boz.fits"))
        self.assertIsNone(re.fullmatch(patterns[1], "boz.hdf5"))

318 

319 

class TimerTestCase(unittest.TestCase):
    """Tests for the `time_this` context manager."""

    def testTimer(self):
        # Records are expected to be attributed to the calling code, i.e. this
        # module; derive its file name rather than hard-coding it.
        thisFile = os.path.basename(__file__)

        # Default settings: a DEBUG record on the "timer" logger.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this():
                pass
        self.assertEqual(cm.records[0].name, "timer")
        self.assertEqual(cm.records[0].levelname, "DEBUG")
        self.assertEqual(cm.records[0].filename, thisFile)

        # No prefix and no explicit logger: the root logger is used.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(prefix=None):
                pass
        self.assertEqual(cm.records[0].name, "root")
        self.assertEqual(cm.records[0].levelname, "DEBUG")
        self.assertIn("Took", cm.output[0])
        self.assertEqual(cm.records[0].filename, thisFile)

        # Change logging level
        with self.assertLogs(level="INFO") as cm:
            with time_this(level=logging.INFO, prefix=None):
                pass
        self.assertEqual(cm.records[0].name, "root")
        self.assertEqual(cm.records[0].levelname, "INFO")
        self.assertIn("Took", cm.output[0])
        self.assertIn("seconds", cm.output[0])

        # Use a new logger with a message.
        msg = "Test message %d"
        test_num = 42
        logname = "test"
        logger = logging.getLogger(logname)
        with self.assertLogs(level="DEBUG") as cm:
            # Pass test_num itself so the expected text below is built from
            # the very same argument.
            with time_this(log=logger, msg=msg, args=(test_num,), prefix=None):
                pass
        self.assertEqual(cm.records[0].name, logname)
        self.assertIn("Took", cm.output[0])
        self.assertIn(msg % test_num, cm.output[0])

        # Prefix the logger.
        prefix = "prefix"
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(prefix=prefix):
                pass
        self.assertEqual(cm.records[0].name, prefix)
        self.assertIn("Took", cm.output[0])

        # Prefix explicit logger.
        with self.assertLogs(level="DEBUG") as cm:
            with time_this(log=logger, prefix=prefix):
                pass
        self.assertEqual(cm.records[0].name, f"{prefix}.{logname}")

        # Trigger a problem.
        with self.assertLogs(level="ERROR") as cm:
            with self.assertRaises(RuntimeError):
                with time_this(log=logger, prefix=prefix):
                    raise RuntimeError("A problem")
        self.assertEqual(cm.records[0].name, f"{prefix}.{logname}")
        self.assertEqual(cm.records[0].levelname, "ERROR")

381 

382 

if __name__ == "__main__":
    # Load and run the tests from this module when executed as a script
    # ("__main__" is unittest.main's default module).
    unittest.main(module="__main__")