Coverage for tests/test_Config.py: 19%

321 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-06-02 11:08 +0000

1# This file is part of pex_config. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

import io
import itertools
import os
import pickle
import re
import tempfile
import unittest

34 

35try: 

36 import yaml 

37except ImportError: 

38 yaml = None 

39 

40import lsst.pex.config as pexConfig 

41 

42# Some tests depend on daf_base. 

43# Skip them if it is not found. 

44try: 

45 import lsst.daf.base as dafBase 

46except ImportError: 

47 dafBase = None 

48 

# Name -> config-class mapping used as the ``typemap`` of the
# ConfigChoiceFields in ``Complex``; entries are added below as the
# candidate config classes are defined.
GLOBAL_REGISTRY = dict()

50 

51 

class Simple(pexConfig.Config):
    """Config exercising one field of each basic field type.

    Field declaration order is significant (it determines iteration and
    save order), so it matches the historical order exactly.
    """

    i = pexConfig.Field(doc="integer test", dtype=int, optional=True)
    f = pexConfig.Field(doc="float test", dtype=float, default=3.0)
    b = pexConfig.Field(doc="boolean test", dtype=bool, default=False, optional=False)
    c = pexConfig.ChoiceField(
        doc="choice test",
        dtype=str,
        default="Hello",
        allowed={"Hello": "First choice", "World": "second choice"},
    )
    r = pexConfig.RangeField(
        doc="Range test",
        dtype=float,
        default=3.0,
        optional=False,
        min=3.0,
        inclusiveMin=True,
    )
    ll = pexConfig.ListField(
        doc="list test",
        dtype=int,
        default=[1, 2, 3],
        maxLength=5,
        # Every item must be non-None and strictly positive.
        itemCheck=lambda x: x is not None and x > 0,
    )
    d = pexConfig.DictField(
        "dict test",
        str,
        str,
        default={"key": "value"},
        # Every value must start with "v".
        itemCheck=lambda x: x.startswith("v"),
    )
    # Default is NaN, which exercises NaN handling in comparisons.
    n = pexConfig.Field(doc="nan test", dtype=float, default=float("NAN"))

67 

68 

# Register ``Simple`` as the "AAA" choice.
GLOBAL_REGISTRY.update(AAA=Simple)

70 

71 

class InnerConfig(pexConfig.Config):
    """Config with a single required float field that must be non-negative."""

    f = pexConfig.Field(
        doc="Inner.f",
        dtype=float,
        default=0.0,
        # Reject negative values.
        check=lambda x: x >= 0,
        optional=False,
    )

74 

75 

# Register ``InnerConfig`` as the "BBB" choice.
GLOBAL_REGISTRY.update(BBB=InnerConfig)

77 

78 

class OuterConfig(InnerConfig, pexConfig.Config):
    """Config that nests an `InnerConfig` and adds a cross-field check.

    Inherits field ``f`` from `InnerConfig` and adds a nested `InnerConfig`
    as ``i``, whose ``f`` is initialised to 5.0 so that a freshly
    constructed instance validates.
    """

    i = pexConfig.ConfigField("Outer.i", InnerConfig)

    def __init__(self):
        super().__init__()
        # Start at the lowest value accepted by validate().
        self.i.f = 5.0

    def validate(self):
        """Validate all fields, then require ``self.i.f >= 5``.

        Raises
        ------
        ValueError
            Raised if ``self.i.f`` is below 5, or if any field fails its
            own validation.
        """
        super().validate()
        if self.i.f < 5:
            # 5 itself is accepted, so the message says "at least 5".
            raise ValueError("validation failed, outer.i.f must be at least 5")

90 

91 

class Complex(pexConfig.Config):
    """Config combining a nested config with two registry choice fields.

    Both choice fields draw from the same ``GLOBAL_REGISTRY`` typemap but
    have different defaults and optionality.
    """

    c = pexConfig.ConfigField(doc="an inner config", dtype=InnerConfig)
    r = pexConfig.ConfigChoiceField(
        doc="a registry field",
        typemap=GLOBAL_REGISTRY,
        default="AAA",
        optional=False,
    )
    p = pexConfig.ConfigChoiceField(
        doc="another registry",
        typemap=GLOBAL_REGISTRY,
        default="BBB",
        optional=True,
    )

98 

99 

class Deprecation(pexConfig.Config):
    """Config holding a single field marked as deprecated."""

    old = pexConfig.Field(
        doc="Something.",
        dtype=int,
        default=10,
        deprecated="not used!",
    )

102 

103 

class ConfigTest(unittest.TestCase):
    """Tests for `lsst.pex.config.Config` and its field types."""

    def setUp(self):
        self.simple = Simple()
        self.inner = InnerConfig()
        self.outer = OuterConfig()
        self.comp = Complex()
        self.deprecation = Deprecation()

    def tearDown(self):
        # Release every fixture created in setUp.
        del self.simple
        del self.inner
        del self.outer
        del self.comp
        del self.deprecation

    def testInit(self):
        """Check the default values of freshly constructed configs."""
        self.assertIsNone(self.simple.i)
        self.assertEqual(self.simple.f, 3.0)
        self.assertFalse(self.simple.b)
        self.assertEqual(self.simple.c, "Hello")
        self.assertEqual(list(self.simple.ll), [1, 2, 3])
        self.assertEqual(self.simple.d["key"], "value")
        self.assertEqual(self.inner.f, 0.0)
        self.assertEqual(self.deprecation.old, 10)

        self.assertEqual(self.deprecation._fields["old"].doc, "Something. Deprecated: not used!")

        self.assertEqual(self.outer.i.f, 5.0)
        self.assertEqual(self.outer.f, 0.0)

        self.assertEqual(self.comp.c.f, 0.0)
        self.assertEqual(self.comp.r.name, "AAA")
        self.assertEqual(self.comp.r.active.f, 3.0)
        self.assertEqual(self.comp.r["BBB"].f, 0.0)

    def testDeprecationWarning(self):
        """Test that a deprecated field emits a warning when it is set."""
        with self.assertWarns(FutureWarning) as w:
            self.deprecation.old = 5
            self.assertEqual(self.deprecation.old, 5)

        self.assertIn(self.deprecation._fields["old"].deprecated, str(w.warnings[-1].message))

    def testDeprecationOutput(self):
        """Test that a deprecated field is not written out unless it is set."""
        stream = io.StringIO()
        self.deprecation.saveToStream(stream)
        self.assertNotIn("config.old", stream.getvalue())
        with self.assertWarns(FutureWarning):
            self.deprecation.old = 5
        stream = io.StringIO()
        self.deprecation.saveToStream(stream)
        self.assertIn("config.old=5\n", stream.getvalue())

    def testValidate(self):
        """Check field-level and config-level validation."""
        self.simple.validate()

        self.inner.validate()
        self.assertRaises(ValueError, setattr, self.outer.i, "f", -5)
        self.outer.i.f = 10.0
        self.outer.validate()

        # A value rejected by the DictField itemCheck must raise on
        # assignment, and the config must still validate afterwards.
        with self.assertRaises(pexConfig.FieldValidationError):
            self.simple.d["failKey"] = "failValue"
        self.simple.validate()

        self.outer.i = InnerConfig
        self.assertRaises(ValueError, self.outer.validate)
        self.outer.i = InnerConfig()
        self.assertRaises(ValueError, self.outer.validate)

        self.comp.validate()
        self.comp.r = None
        self.assertRaises(ValueError, self.comp.validate)
        self.comp.r = "BBB"
        self.comp.validate()

    def testRangeFieldConstructor(self):
        """Test RangeField constructor's checking of min, max"""
        val = 3
        self.assertRaises(ValueError, pexConfig.RangeField, "", int, default=val, min=val, max=val - 1)
        self.assertRaises(ValueError, pexConfig.RangeField, "", float, default=val, min=val, max=val - 1e-15)
        for inclusiveMin, inclusiveMax in itertools.product((False, True), (False, True)):
            if inclusiveMin and inclusiveMax:
                # should not raise
                class Cfg1(pexConfig.Config):
                    r1 = pexConfig.RangeField(
                        doc="",
                        dtype=int,
                        default=val,
                        min=val,
                        max=val,
                        inclusiveMin=inclusiveMin,
                        inclusiveMax=inclusiveMax,
                    )
                    r2 = pexConfig.RangeField(
                        doc="",
                        dtype=float,
                        default=val,
                        min=val,
                        max=val,
                        inclusiveMin=inclusiveMin,
                        inclusiveMax=inclusiveMax,
                    )

                Cfg1()
            else:
                # raise while constructing the RangeField (hence cannot make
                # it part of a Config)
                self.assertRaises(
                    ValueError,
                    pexConfig.RangeField,
                    doc="",
                    dtype=int,
                    default=val,
                    min=val,
                    max=val,
                    inclusiveMin=inclusiveMin,
                    inclusiveMax=inclusiveMax,
                )
                self.assertRaises(
                    ValueError,
                    pexConfig.RangeField,
                    doc="",
                    dtype=float,
                    default=val,
                    min=val,
                    max=val,
                    inclusiveMin=inclusiveMin,
                    inclusiveMax=inclusiveMax,
                )

    def testRangeFieldDefault(self):
        """Test RangeField's checking of the default value"""
        minVal = 3
        maxVal = 4
        for val, inclusiveMin, inclusiveMax, shouldRaise in (
            (minVal, False, True, True),
            (minVal, True, True, False),
            (maxVal, True, False, True),
            (maxVal, True, True, False),
        ):

            class Cfg1(pexConfig.Config):
                r = pexConfig.RangeField(
                    doc="",
                    dtype=int,
                    default=val,
                    min=minVal,
                    max=maxVal,
                    inclusiveMin=inclusiveMin,
                    inclusiveMax=inclusiveMax,
                )

            class Cfg2(pexConfig.Config):
                r2 = pexConfig.RangeField(
                    doc="",
                    dtype=float,
                    default=val,
                    min=minVal,
                    max=maxVal,
                    inclusiveMin=inclusiveMin,
                    inclusiveMax=inclusiveMax,
                )

            if shouldRaise:
                self.assertRaises(pexConfig.FieldValidationError, Cfg1)
                self.assertRaises(pexConfig.FieldValidationError, Cfg2)
            else:
                Cfg1()
                Cfg2()

    def testSave(self):
        """Round-trip a config through a file, an open stream and a string."""
        self.comp.r = "BBB"
        self.comp.p = "AAA"
        self.comp.c.f = 5.0

        # Work in a private temporary directory so the test neither writes
        # into the current working directory nor leaves files behind when
        # an assertion or load fails part way through.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, "roundtrip.test")

            # test saving to and loading from a named file
            self.comp.save(filename)
            roundTrip = Complex()
            roundTrip.load(filename)
            self.assertEqual(self.comp.c.f, roundTrip.c.f)
            self.assertEqual(self.comp.r.name, roundTrip.r.name)
            del roundTrip

            # test saving to an open file
            with open(filename, "w") as outfile:
                self.comp.saveToStream(outfile)
            roundTrip = Complex()
            with open(filename, "r") as infile:
                roundTrip.loadFromStream(infile)
            self.assertEqual(self.comp.c.f, roundTrip.c.f)
            self.assertEqual(self.comp.r.name, roundTrip.r.name)
            del roundTrip

            # test saving to a string.
            saved_string = self.comp.saveToString()
            roundTrip = Complex()
            roundTrip.loadFromString(saved_string)
            self.assertEqual(self.comp.c.f, roundTrip.c.f)
            self.assertEqual(self.comp.r.name, roundTrip.r.name)
            del roundTrip

            # test backwards compatibility feature of allowing "root" instead
            # of "config"
            with open(filename, "w") as outfile:
                self.comp.saveToStream(outfile, root="root")
            roundTrip = Complex()
            roundTrip.load(filename)
            self.assertEqual(self.comp.c.f, roundTrip.c.f)
            self.assertEqual(self.comp.r.name, roundTrip.r.name)

    def testDuplicateRegistryNames(self):
        """Check that the two registry fields hold independent instances."""
        self.comp.r["AAA"].f = 5.0
        self.assertEqual(self.comp.p["AAA"].f, 3.0)

    def testInheritance(self):
        """Check field inheritance, including conflicting multiple bases."""

        class AAA(pexConfig.Config):
            a = pexConfig.Field("AAA.a", int, default=4)

        class BBB(AAA):
            b = pexConfig.Field("BBB.b", int, default=3)

        class CCC(BBB):
            c = pexConfig.Field("CCC.c", int, default=2)

        # test multi-level inheritance
        c = CCC()
        self.assertIn("a", c.toDict())
        self.assertEqual(c._fields["a"].dtype, int)
        self.assertEqual(c.a, 4)

        # test conflicting multiple inheritance
        class DDD(pexConfig.Config):
            a = pexConfig.Field("DDD.a", float, default=0.0)

        class EEE(DDD, AAA):
            pass

        e = EEE()
        self.assertEqual(e._fields["a"].dtype, float)
        self.assertIn("a", e.toDict())
        self.assertEqual(e.a, 0.0)

        class FFF(AAA, DDD):
            pass

        f = FFF()
        self.assertEqual(f._fields["a"].dtype, int)
        self.assertIn("a", f.toDict())
        self.assertEqual(f.a, 4)

        # test inheritance from non Config objects
        class GGG:
            a = pexConfig.Field("AAA.a", float, default=10.0)

        class HHH(GGG, AAA):
            pass

        h = HHH()
        self.assertEqual(h._fields["a"].dtype, float)
        self.assertIn("a", h.toDict())
        self.assertEqual(h.a, 10.0)

        # test partial Field redefinition

        class III(AAA):
            pass

        III.a.default = 5

        self.assertEqual(III.a.default, 5)
        self.assertEqual(AAA.a.default, 4)

    @unittest.skipIf(dafBase is None, "lsst.daf.base is required")
    def testConvertPropertySet(self):
        """Check conversion of configs with makePropertySet."""
        ps = pexConfig.makePropertySet(self.simple)
        self.assertFalse(ps.exists("i"))
        self.assertEqual(ps.getScalar("f"), self.simple.f)
        self.assertEqual(ps.getScalar("b"), self.simple.b)
        self.assertEqual(ps.getScalar("c"), self.simple.c)
        self.assertEqual(list(ps.getArray("ll")), list(self.simple.ll))

        ps = pexConfig.makePropertySet(self.comp)
        self.assertEqual(ps.getScalar("c.f"), self.comp.c.f)

    def testFreeze(self):
        """Check that a frozen config rejects assignment, including nested fields."""
        self.comp.freeze()

        self.assertRaises(pexConfig.FieldValidationError, setattr, self.comp.c, "f", 10.0)
        self.assertRaises(pexConfig.FieldValidationError, setattr, self.comp, "r", "AAA")
        self.assertRaises(pexConfig.FieldValidationError, setattr, self.comp, "p", "AAA")
        self.assertRaises(pexConfig.FieldValidationError, setattr, self.comp.p["AAA"], "f", 5.0)

    def checkImportRoundTrip(self, importStatement, searchString, shouldBeThere):
        """Check how an import statement survives a load/save round trip.

        Parameters
        ----------
        importStatement : `str`
            Python source prepended to the saved config before reloading.
        searchString : `str`
            Literal text to look for in the config saved after reloading.
        shouldBeThere : `bool`
            Whether ``searchString`` is expected in that saved output.
        """
        self.comp.c.f = 5.0

        # Generate a Config through loading
        stream = io.StringIO()
        stream.write(str(importStatement))
        self.comp.saveToStream(stream)
        roundtrip = Complex()
        roundtrip.loadFromStream(stream.getvalue())
        self.assertEqual(self.comp.c.f, roundtrip.c.f)

        # Check the save stream
        stream = io.StringIO()
        roundtrip.saveToStream(stream)
        self.assertEqual(self.comp.c.f, roundtrip.c.f)
        streamStr = stream.getvalue()
        # Escape the search text so characters such as "." match literally.
        pattern = re.escape(searchString)
        if shouldBeThere:
            self.assertRegex(streamStr, pattern)
        else:
            self.assertNotRegex(streamStr, pattern)

    def testImports(self):
        """Check that an importable module's import is preserved on save."""
        # A module not used by anything else, but which exists
        importing = "import lsst.pex.config._doNotImportMe\n"
        self.checkImportRoundTrip(importing, importing, True)

    def testBadImports(self):
        """Check that a failed, guarded import is not written back out."""
        dummy = "somethingThatDoesntExist"
        importing = f"""
try:
    import {dummy}
except ImportError:
    pass
"""
        self.checkImportRoundTrip(importing, dummy, False)

    def testPickle(self):
        """Check that configs survive a pickle round trip."""
        self.simple.f = 5
        simple = pickle.loads(pickle.dumps(self.simple))
        self.assertIsInstance(simple, Simple)
        self.assertEqual(self.simple.f, simple.f)

        self.comp.c.f = 5
        comp = pickle.loads(pickle.dumps(self.comp))
        self.assertIsInstance(comp, Complex)
        self.assertEqual(self.comp.c.f, comp.c.f)

    @unittest.skipIf(yaml is None, "Test requires pyyaml")
    def testYaml(self):
        """Check that configs survive a YAML round trip."""
        self.simple.f = 5
        simple = yaml.safe_load(yaml.dump(self.simple))
        self.assertIsInstance(simple, Simple)
        self.assertEqual(self.simple.f, simple.f)

        self.comp.c.f = 5
        # Use a different loader to check that it also works
        comp = yaml.load(yaml.dump(self.comp), Loader=yaml.FullLoader)
        self.assertIsInstance(comp, Complex)
        self.assertEqual(self.comp.c.f, comp.c.f)

    def testCompare(self):
        """Check Config.compare and the messages it reports."""
        comp2 = Complex()
        inner2 = InnerConfig()
        simple2 = Simple()
        self.assertTrue(self.comp.compare(comp2))
        self.assertTrue(comp2.compare(self.comp))
        self.assertTrue(self.comp.c.compare(inner2))
        self.assertTrue(self.simple.compare(simple2))
        self.assertTrue(simple2.compare(self.simple))
        self.assertEqual(self.simple, simple2)
        self.assertEqual(simple2, self.simple)
        outList = []

        def outFunc(msg):
            outList.append(msg)

        simple2.b = True
        simple2.ll.append(4)
        simple2.d["foo"] = "var"
        self.assertFalse(self.simple.compare(simple2, shortcut=True, output=outFunc))
        self.assertEqual(len(outList), 1)
        outList.clear()
        self.assertFalse(self.simple.compare(simple2, shortcut=False, output=outFunc))
        output = "\n".join(outList)
        self.assertIn("Inequality in b", output)
        self.assertIn("Inequality in size for ll", output)
        self.assertIn("Inequality in keys for d", output)
        outList.clear()
        self.simple.d["foo"] = "vast"
        self.simple.ll.append(5)
        self.simple.b = True
        self.simple.f += 1e8
        self.assertFalse(self.simple.compare(simple2, shortcut=False, output=outFunc))
        output = "\n".join(outList)
        self.assertIn("Inequality in f", output)
        self.assertIn("Inequality in ll[3]", output)
        self.assertIn("Inequality in d['foo']", output)
        outList.clear()
        comp2.r["BBB"].f = 1.0  # changing the non-selected item shouldn't break equality
        self.assertTrue(self.comp.compare(comp2))
        comp2.r["AAA"].i = 56  # changing the selected item should break equality
        comp2.c.f = 1.0
        self.assertFalse(self.comp.compare(comp2, shortcut=False, output=outFunc))
        output = "\n".join(outList)
        self.assertIn("Inequality in c.f", output)
        self.assertIn("Inequality in r['AAA']", output)
        self.assertNotIn("Inequality in r['BBB']", output)

        # Before DM-16561, this incorrectly returned `True`.
        self.assertFalse(self.inner.compare(self.outer))
        # Before DM-16561, this raised.
        self.assertFalse(self.outer.compare(self.inner))

    def testLoadError(self):
        """Check that loading allows errors in the file being loaded to
        propagate.
        """
        self.assertRaises(SyntaxError, self.simple.loadFromStream, "bork bork bork")
        self.assertRaises(NameError, self.simple.loadFromStream, "config.f = bork")

    def testNames(self):
        """Check that the names() method returns valid keys.

        Also check that we have the right number of keys, and as they are
        all known to be valid we know that we got them all.
        """

        names = self.simple.names()
        self.assertEqual(len(names), 8)
        for name in names:
            self.assertTrue(hasattr(self.simple, name))

537 

538 

if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()