Coverage for tests/test_Config.py: 19%

333 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-23 09:46 +0000

1# This file is part of pex_config. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28import io 

29import itertools 

30import os 

31import pickle 

32import re 

33import unittest 

34 

35try: 

36 import yaml 

37except ImportError: 

38 yaml = None 

39 

40import lsst.pex.config as pexConfig 

41 

42# Some tests depend on daf_base. 

43# Skip them if it is not found. 

44try: 

45 import lsst.daf.base as dafBase 

46except ImportError: 

47 dafBase = None 

48 

# Registry mapping a name to a Config class; used as the ``typemap`` of the
# ConfigChoiceFields declared on ``Complex``.
GLOBAL_REGISTRY = dict()

50 

51 

class Simple(pexConfig.Config):
    """Config exercising Field, ChoiceField, RangeField, ListField and
    DictField with a variety of defaults and checks.
    """

    i = pexConfig.Field(doc="integer test", dtype=int, optional=True)
    f = pexConfig.Field(doc="float test", dtype=float, default=3.0)
    b = pexConfig.Field(doc="boolean test", dtype=bool, default=False, optional=False)
    c = pexConfig.ChoiceField(
        doc="choice test",
        dtype=str,
        default="Hello",
        allowed={"Hello": "First choice", "World": "second choice"},
    )
    # Value must be >= 3.0 (inclusive minimum, no maximum).
    r = pexConfig.RangeField(
        doc="Range test",
        dtype=float,
        default=3.0,
        optional=False,
        min=3.0,
        inclusiveMin=True,
    )
    # At most 5 items; every item must be non-None and strictly positive.
    ll = pexConfig.ListField(
        doc="list test",
        dtype=int,
        default=[1, 2, 3],
        maxLength=5,
        itemCheck=lambda item: item is not None and item > 0,
    )
    # Every value stored in the dict must start with "v".
    d = pexConfig.DictField(
        doc="dict test",
        keytype=str,
        itemtype=str,
        default={"key": "value"},
        itemCheck=lambda item: item.startswith("v"),
    )
    # Default is NaN; note that NaN never compares equal to itself.
    n = pexConfig.Field(doc="nan test", dtype=float, default=float("NAN"))

67 

68 

# Make ``Simple`` selectable under the registry name "AAA".
GLOBAL_REGISTRY.update(AAA=Simple)

70 

71 

class InnerConfig(pexConfig.Config):
    """Minimal config holding a single required, non-negative float."""

    f = pexConfig.Field(
        doc="Inner.f",
        dtype=float,
        default=0.0,
        check=lambda value: value >= 0,
        optional=False,
    )

74 

75 

# Make ``InnerConfig`` selectable under the registry name "BBB".
GLOBAL_REGISTRY.update(BBB=InnerConfig)

77 

78 

class OuterConfig(InnerConfig, pexConfig.Config):
    """Config that inherits ``f`` from `InnerConfig` and nests a second
    `InnerConfig` as ``i``, with an extra constraint ``i.f >= 5`` enforced
    by `validate`.
    """

    i = pexConfig.ConfigField("Outer.i", InnerConfig)

    def __init__(self):
        super().__init__()
        # The nested InnerConfig defaults to f == 0.0, which would fail
        # validate(); start from a value that satisfies the constraint.
        self.i.f = 5.0

    def validate(self):
        """Run the standard field validation, then require ``i.f >= 5``.

        Raises
        ------
        ValueError
            Raised if ``i.f`` is less than 5, or by the base-class
            validation.
        """
        super().validate()
        if self.i.f < 5:
            raise ValueError("validation failed, outer.i.f must be greater than or equal to 5")

90 

91 

class Complex(pexConfig.Config):
    """Config combining a nested config field with two registry-backed
    choice fields that share ``GLOBAL_REGISTRY`` as their typemap.
    """

    c = pexConfig.ConfigField(doc="an inner config", dtype=InnerConfig)
    # Required registry selection; defaults to the "AAA" entry.
    r = pexConfig.ConfigChoiceField(
        doc="a registry field",
        typemap=GLOBAL_REGISTRY,
        default="AAA",
        optional=False,
    )
    # Optional registry selection; defaults to the "BBB" entry.
    p = pexConfig.ConfigChoiceField(
        doc="another registry",
        typemap=GLOBAL_REGISTRY,
        default="BBB",
        optional=True,
    )

98 

99 

class Deprecation(pexConfig.Config):
    """Config with a single field flagged as deprecated."""

    old = pexConfig.Field(doc="Something.", dtype=int, default=10, deprecated="not used!")

102 

103 

class ConfigTest(unittest.TestCase):
    """Unit tests for `lsst.pex.config.Config` and its field types.

    `setUp` builds fresh instances of the module-level config classes for
    every test, so mutations made in one test do not leak into another.
    """

    def setUp(self):
        self.simple = Simple()
        self.inner = InnerConfig()
        self.outer = OuterConfig()
        self.comp = Complex()
        self.deprecation = Deprecation()

    def tearDown(self):
        # Release every fixture created in setUp.
        del self.simple
        del self.inner
        del self.outer
        del self.comp
        del self.deprecation

    def testInit(self):
        """Test that freshly constructed configs carry their defaults."""
        self.assertIsNone(self.simple.i)
        self.assertEqual(self.simple.f, 3.0)
        self.assertFalse(self.simple.b)
        self.assertEqual(self.simple.c, "Hello")
        self.assertEqual(list(self.simple.ll), [1, 2, 3])
        self.assertEqual(self.simple.d["key"], "value")
        self.assertEqual(self.inner.f, 0.0)
        self.assertEqual(self.deprecation.old, 10)

        self.assertEqual(self.deprecation._fields["old"].doc, "Something. Deprecated: not used!")

        self.assertEqual(self.outer.i.f, 5.0)
        self.assertEqual(self.outer.f, 0.0)

        self.assertEqual(self.comp.c.f, 0.0)
        self.assertEqual(self.comp.r.name, "AAA")
        self.assertEqual(self.comp.r.active.f, 3.0)
        self.assertEqual(self.comp.r["BBB"].f, 0.0)

    def testDeprecationWarning(self):
        """Test that a deprecated field emits a warning when it is set."""
        with self.assertWarns(FutureWarning) as w:
            self.deprecation.old = 5
        self.assertEqual(self.deprecation.old, 5)

        self.assertIn(self.deprecation._fields["old"].deprecated, str(w.warnings[-1].message))

    def testDeprecationOutput(self):
        """Test that a deprecated field is not written out unless it is set."""
        stream = io.StringIO()
        self.deprecation.saveToStream(stream)
        self.assertNotIn("config.old", stream.getvalue())
        with self.assertWarns(FutureWarning):
            self.deprecation.old = 5
        stream = io.StringIO()
        self.deprecation.saveToStream(stream)
        self.assertIn("config.old=5\n", stream.getvalue())

    def testValidate(self):
        """Test field checks, custom ``validate`` overrides, and required
        registry selections.
        """
        self.simple.validate()

        self.inner.validate()
        # InnerConfig.f has a check requiring a non-negative value.
        self.assertRaises(ValueError, setattr, self.outer.i, "f", -5)
        self.outer.i.f = 10.0
        self.outer.validate()

        # Simple.d's itemCheck requires values starting with "v", so this
        # assignment must be rejected and must leave the config valid.
        with self.assertRaises(pexConfig.FieldValidationError):
            self.simple.d["failKey"] = "failValue"
        self.simple.validate()

        # OuterConfig.validate requires i.f >= 5, but a newly assigned
        # InnerConfig (class or instance) carries the default f == 0.0.
        self.outer.i = InnerConfig
        self.assertRaises(ValueError, self.outer.validate)
        self.outer.i = InnerConfig()
        self.assertRaises(ValueError, self.outer.validate)

        # Complex.r is a non-optional registry selection.
        self.comp.validate()
        self.comp.r = None
        self.assertRaises(ValueError, self.comp.validate)
        self.comp.r = "BBB"
        self.comp.validate()

    def testRangeFieldConstructor(self):
        """Test RangeField constructor's checking of min, max"""
        val = 3
        self.assertRaises(ValueError, pexConfig.RangeField, "", int, default=val, min=val, max=val - 1)
        self.assertRaises(ValueError, pexConfig.RangeField, "", float, default=val, min=val, max=val - 1e-15)
        for inclusiveMin, inclusiveMax in itertools.product((False, True), (False, True)):
            if inclusiveMin and inclusiveMax:
                # should not raise
                class Cfg1(pexConfig.Config):
                    r1 = pexConfig.RangeField(
                        doc="",
                        dtype=int,
                        default=val,
                        min=val,
                        max=val,
                        inclusiveMin=inclusiveMin,
                        inclusiveMax=inclusiveMax,
                    )
                    r2 = pexConfig.RangeField(
                        doc="",
                        dtype=float,
                        default=val,
                        min=val,
                        max=val,
                        inclusiveMin=inclusiveMin,
                        inclusiveMax=inclusiveMax,
                    )

                Cfg1()
            else:
                # raise while constructing the RangeField (hence cannot make
                # it part of a Config)
                self.assertRaises(
                    ValueError,
                    pexConfig.RangeField,
                    doc="",
                    dtype=int,
                    default=val,
                    min=val,
                    max=val,
                    inclusiveMin=inclusiveMin,
                    inclusiveMax=inclusiveMax,
                )
                self.assertRaises(
                    ValueError,
                    pexConfig.RangeField,
                    doc="",
                    dtype=float,
                    default=val,
                    min=val,
                    max=val,
                    inclusiveMin=inclusiveMin,
                    inclusiveMax=inclusiveMax,
                )

    def testRangeFieldDefault(self):
        """Test RangeField's checking of the default value"""
        minVal = 3
        maxVal = 4
        for val, inclusiveMin, inclusiveMax, shouldRaise in (
            (minVal, False, True, True),
            (minVal, True, True, False),
            (maxVal, True, False, True),
            (maxVal, True, True, False),
        ):

            class Cfg1(pexConfig.Config):
                r = pexConfig.RangeField(
                    doc="",
                    dtype=int,
                    default=val,
                    min=minVal,
                    max=maxVal,
                    inclusiveMin=inclusiveMin,
                    inclusiveMax=inclusiveMax,
                )

            class Cfg2(pexConfig.Config):
                r2 = pexConfig.RangeField(
                    doc="",
                    dtype=float,
                    default=val,
                    min=minVal,
                    max=maxVal,
                    inclusiveMin=inclusiveMin,
                    inclusiveMax=inclusiveMax,
                )

            if shouldRaise:
                self.assertRaises(pexConfig.FieldValidationError, Cfg1)
                self.assertRaises(pexConfig.FieldValidationError, Cfg2)
            else:
                Cfg1()
                Cfg2()

    def testSave(self):
        """Test round-tripping a Config through a file, an open stream, a
        string, and a file saved under a non-default root name.
        """
        import tempfile

        self.comp.r = "BBB"
        self.comp.p = "AAA"
        self.comp.c.f = 5.0

        # Work inside a private temporary directory so the test neither
        # writes into the current directory nor leaves files behind when an
        # assertion or load fails part-way through.
        with tempfile.TemporaryDirectory() as tmpDir:
            filename = os.path.join(tmpDir, "roundtrip.test")

            self.comp.save(filename)
            roundTrip = Complex()
            roundTrip.load(filename)
            self.assertEqual(self.comp.c.f, roundTrip.c.f)
            self.assertEqual(self.comp.r.name, roundTrip.r.name)
            del roundTrip

            # test saving to an open file
            with open(filename, "w") as outfile:
                self.comp.saveToStream(outfile)
            roundTrip = Complex()
            with open(filename, "r") as infile:
                roundTrip.loadFromStream(infile)
            self.assertEqual(self.comp.c.f, roundTrip.c.f)
            self.assertEqual(self.comp.r.name, roundTrip.r.name)
            del roundTrip

            # test saving to a string.
            saved_string = self.comp.saveToString()
            roundTrip = Complex()
            roundTrip.loadFromString(saved_string)
            self.assertEqual(self.comp.c.f, roundTrip.c.f)
            self.assertEqual(self.comp.r.name, roundTrip.r.name)
            del roundTrip

            # Test an override of the default variable name.
            with open(filename, "w") as outfile:
                self.comp.saveToStream(outfile, root="root")
            roundTrip = Complex()
            with self.assertRaises(NameError):
                roundTrip.load(filename)
            roundTrip.load(filename, root="root")
            self.assertEqual(self.comp.c.f, roundTrip.c.f)
            self.assertEqual(self.comp.r.name, roundTrip.r.name)

    def testDuplicateRegistryNames(self):
        """Test that two registry fields keep independent per-name configs."""
        self.comp.r["AAA"].f = 5.0
        self.assertEqual(self.comp.p["AAA"].f, 3.0)

    def testInheritance(self):
        """Test field inheritance, including conflicting bases, non-Config
        bases, and modifying an inherited field's default.
        """

        class AAA(pexConfig.Config):
            a = pexConfig.Field("AAA.a", int, default=4)

        class BBB(AAA):
            b = pexConfig.Field("BBB.b", int, default=3)

        class CCC(BBB):
            c = pexConfig.Field("CCC.c", int, default=2)

        # test multi-level inheritance
        c = CCC()
        self.assertIn("a", c.toDict())
        self.assertEqual(c._fields["a"].dtype, int)
        self.assertEqual(c.a, 4)

        # test conflicting multiple inheritance
        class DDD(pexConfig.Config):
            a = pexConfig.Field("DDD.a", float, default=0.0)

        class EEE(DDD, AAA):
            pass

        e = EEE()
        self.assertEqual(e._fields["a"].dtype, float)
        self.assertIn("a", e.toDict())
        self.assertEqual(e.a, 0.0)

        class FFF(AAA, DDD):
            pass

        f = FFF()
        self.assertEqual(f._fields["a"].dtype, int)
        self.assertIn("a", f.toDict())
        self.assertEqual(f.a, 4)

        # test inheritance from non Config objects
        class GGG:
            a = pexConfig.Field("AAA.a", float, default=10.0)

        class HHH(GGG, AAA):
            pass

        h = HHH()
        self.assertEqual(h._fields["a"].dtype, float)
        self.assertIn("a", h.toDict())
        self.assertEqual(h.a, 10.0)

        # test partial Field redefinition

        class III(AAA):
            pass

        III.a.default = 5

        self.assertEqual(III.a.default, 5)
        self.assertEqual(AAA.a.default, 4)

    @unittest.skipIf(dafBase is None, "lsst.daf.base is required")
    def testConvertPropertySet(self):
        """Test conversion of configs to a PropertySet."""
        ps = pexConfig.makePropertySet(self.simple)
        self.assertFalse(ps.exists("i"))
        self.assertEqual(ps.getScalar("f"), self.simple.f)
        self.assertEqual(ps.getScalar("b"), self.simple.b)
        self.assertEqual(ps.getScalar("c"), self.simple.c)
        self.assertEqual(list(ps.getArray("ll")), list(self.simple.ll))

        ps = pexConfig.makePropertySet(self.comp)
        self.assertEqual(ps.getScalar("c.f"), self.comp.c.f)

    def testFreeze(self):
        """Test that a frozen config rejects assignments, including nested ones."""
        self.comp.freeze()

        self.assertRaises(pexConfig.FieldValidationError, setattr, self.comp.c, "f", 10.0)
        self.assertRaises(pexConfig.FieldValidationError, setattr, self.comp, "r", "AAA")
        self.assertRaises(pexConfig.FieldValidationError, setattr, self.comp, "p", "AAA")
        self.assertRaises(pexConfig.FieldValidationError, setattr, self.comp.p["AAA"], "f", 5.0)

    def checkImportRoundTrip(self, importStatement, searchString, shouldBeThere):
        """Save ``self.comp`` preceded by ``importStatement``, reload it,
        re-save it, and check whether ``searchString`` appears in the
        re-saved text.

        Parameters
        ----------
        importStatement : `str`
            Python source written before the saved config and executed
            when it is loaded.
        searchString : `str`
            Regular expression searched for in the re-saved config.
        shouldBeThere : `bool`
            Whether ``searchString`` is expected to match.
        """
        self.comp.c.f = 5.0

        # Generate a Config through loading
        stream = io.StringIO()
        stream.write(str(importStatement))
        self.comp.saveToStream(stream)
        roundtrip = Complex()
        roundtrip.loadFromStream(stream.getvalue())
        self.assertEqual(self.comp.c.f, roundtrip.c.f)

        # Check the save stream
        stream = io.StringIO()
        roundtrip.saveToStream(stream)
        self.assertEqual(self.comp.c.f, roundtrip.c.f)
        streamStr = stream.getvalue()
        if shouldBeThere:
            self.assertIsNotNone(re.search(searchString, streamStr))
        else:
            self.assertIsNone(re.search(searchString, streamStr))

    def testImports(self):
        """Test that a successful import in a loaded config is re-saved."""
        # A module not used by anything else, but which exists
        importing = "import lsst.pex.config._doNotImportMe\n"
        self.checkImportRoundTrip(importing, importing, True)

    def testBadImports(self):
        """Test that a failed, guarded import is not re-saved."""
        dummy = "somethingThatDoesntExist"
        importing = (
            """
try:
    import %s
except ImportError:
    pass
"""
            % dummy
        )
        self.checkImportRoundTrip(importing, dummy, False)

    def testPickle(self):
        """Test that configs survive a pickle round trip."""
        self.simple.f = 5
        simple = pickle.loads(pickle.dumps(self.simple))
        self.assertIsInstance(simple, Simple)
        self.assertEqual(self.simple.f, simple.f)

        self.comp.c.f = 5
        comp = pickle.loads(pickle.dumps(self.comp))
        self.assertIsInstance(comp, Complex)
        self.assertEqual(self.comp.c.f, comp.c.f)

    @unittest.skipIf(yaml is None, "Test requires pyyaml")
    def testYaml(self):
        """Test that configs survive a YAML round trip."""
        self.simple.f = 5
        simple = yaml.safe_load(yaml.dump(self.simple))
        self.assertIsInstance(simple, Simple)
        self.assertEqual(self.simple.f, simple.f)

        self.comp.c.f = 5
        # Use a different loader to check that it also works
        comp = yaml.load(yaml.dump(self.comp), Loader=yaml.FullLoader)
        self.assertIsInstance(comp, Complex)
        self.assertEqual(self.comp.c.f, comp.c.f)

    def testCompare(self):
        """Test ``compare`` and ``==``, including the diagnostic output."""
        comp2 = Complex()
        inner2 = InnerConfig()
        simple2 = Simple()
        self.assertTrue(self.comp.compare(comp2))
        self.assertTrue(comp2.compare(self.comp))
        self.assertTrue(self.comp.c.compare(inner2))
        self.assertTrue(self.simple.compare(simple2))
        self.assertTrue(simple2.compare(self.simple))
        self.assertEqual(self.simple, simple2)
        self.assertEqual(simple2, self.simple)
        outList = []

        def outFunc(msg):
            outList.append(msg)

        simple2.b = True
        simple2.ll.append(4)
        simple2.d["foo"] = "var"
        self.assertFalse(self.simple.compare(simple2, shortcut=True, output=outFunc))
        self.assertEqual(len(outList), 1)
        del outList[:]
        self.assertFalse(self.simple.compare(simple2, shortcut=False, output=outFunc))
        output = "\n".join(outList)
        self.assertIn("Inequality in b", output)
        self.assertIn("Inequality in size for ll", output)
        self.assertIn("Inequality in keys for d", output)
        del outList[:]
        self.simple.d["foo"] = "vast"
        self.simple.ll.append(5)
        self.simple.b = True
        self.simple.f += 1e8
        self.assertFalse(self.simple.compare(simple2, shortcut=False, output=outFunc))
        output = "\n".join(outList)
        self.assertIn("Inequality in f", output)
        self.assertIn("Inequality in ll[3]", output)
        self.assertIn("Inequality in d['foo']", output)
        del outList[:]
        comp2.r["BBB"].f = 1.0  # changing the non-selected item shouldn't break equality
        self.assertTrue(self.comp.compare(comp2))
        comp2.r["AAA"].i = 56  # changing the selected item should break equality
        comp2.c.f = 1.0
        self.assertFalse(self.comp.compare(comp2, shortcut=False, output=outFunc))
        output = "\n".join(outList)
        self.assertIn("Inequality in c.f", output)
        self.assertIn("Inequality in r['AAA']", output)
        self.assertNotIn("Inequality in r['BBB']", output)

        # Before DM-16561, this incorrectly returned `True`.
        self.assertFalse(self.inner.compare(self.outer))
        # Before DM-16561, this raised.
        self.assertFalse(self.outer.compare(self.inner))

    def testLoadError(self):
        """Check that loading allows errors in the file being loaded to
        propagate.
        """
        self.assertRaises(SyntaxError, self.simple.loadFromStream, "bork bork bork")
        self.assertRaises(NameError, self.simple.loadFromStream, "config.f = bork")

    def testNames(self):
        """Check that the names() method returns valid keys.

        Also check that we have the right number of keys, and as they are
        all known to be valid we know that we got them all.
        """

        names = self.simple.names()
        self.assertEqual(len(names), 8)
        for name in names:
            self.assertTrue(hasattr(self.simple, name))

    def testIteration(self):
        """Test containment and that keys(), values() and items() agree."""
        self.assertIn("ll", self.simple)
        self.assertIn("ll", self.simple.keys())
        self.assertIn("Hello", self.simple.values())
        self.assertEqual(len(self.simple.values()), 8)

        for k, v, (k1, v1) in zip(self.simple.keys(), self.simple.values(), self.simple.items()):
            self.assertEqual(k, k1)
            if k == "n":
                # "n" defaults to NaN, which never compares equal to itself.
                self.assertNotEqual(v, v1)
            else:
                self.assertEqual(v, v1)

551 

552 

if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main(module="__main__")