Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from __future__ import with_statement 

2from builtins import zip 

3from builtins import str 

4import sys 

# Type used for the text columns of the numpy dtypes in this module.  Under
# Python 2 the `past` package's str is used instead of the builtin one.
if sys.version_info.major == 2:
    from past.builtins import str as past_str
    str_type = past_str
else:
    str_type = str

9from builtins import range 

10import os 

11import numpy as np 

12import sqlite3 

13import unittest 

14import tempfile 

15import shutil 

16import lsst.utils.tests 

17from lsst.sims.utils.CodeUtilities import sims_clean_up 

18from lsst.sims.utils import ObservationMetaData 

19from lsst.sims.catalogs.db import CatalogDBObject 

20from lsst.sims.catalogs.utils import myTestStars, makeStarTestDB 

21from lsst.sims.catalogs.definitions import InstanceCatalog 

22from lsst.sims.utils import Site 

23 

# Absolute path of the directory containing this test file; the test classes
# create their temporary scratch directories underneath it.
ROOT = os.path.abspath(os.path.dirname(__file__))

25 

26 

def setup_module(module):
    """
    Module-level setup hook (presumably picked up by the test runner via the
    pytest/nose ``setup_module`` naming convention): initialise the LSST
    test utilities before any test in this module runs.

    Parameters
    ----------
    module : module
        The module being set up; unused here.
    """
    lsst.utils.tests.init()

29 

30 

def createCannotBeNullTestDB(filename=None, add_nans=True, dir=None):
    """
    Create a database to test the 'cannot_be_null' functionality in InstanceCatalog.

    The database holds one table, testTable, with 100 rows and the columns
    (id int, n1 float, n2 float, n3 float, n4 text, n5 text).

    * n1, n2 and n3 are drawn with rng.random_sample.  If add_nans is True,
      each of them is replaced by NULL with probability 0.5.
    * n4 and n5 are each the string 'None' or 'word' with probability 0.5.

    The random number generator is seeded (RandomState(32)), so the contents
    are reproducible.

    Parameters
    ----------
    filename : str, optional
        Path of the database file to write.  Any existing file at that path
        is deleted first.
    add_nans : bool
        If True, randomly replace float values with NULL in the database
        (NaN in the returned array).
    dir : str, optional
        Only used when filename is None.  In that case the database is
        written to ``dir/cannotBeNullTest.db``, or to 'cannotBeNullTest.db' in
        the current directory if dir is also None.

    Returns
    -------
    numpy.ndarray
        Structured array with the same contents as the database, for use as
        a baseline in the unit tests.

    Raises
    ------
    RuntimeError
        If the table cannot be created.
    """

    if filename is None:
        dbName = 'cannotBeNullTest.db'
        if dir is not None:
            dbName = os.path.join(dir, dbName)
    else:
        dbName = filename

    rng = np.random.RandomState(32)
    dtype = np.dtype([('id', int), ('n1', np.float64), ('n2', np.float64), ('n3', np.float64),
                      ('n4', (str_type, 40)), ('n5', (str_type, 40))])

    if os.path.exists(dbName):
        os.unlink(dbName)

    rows = []
    conn = sqlite3.connect(dbName)
    try:
        c = conn.cursor()
        try:
            c.execute('''CREATE TABLE testTable (id int, n1 float, n2 float, n3 float, n4 text, n5 text)''')
            conn.commit()
        except sqlite3.Error:
            raise RuntimeError("Error creating database.")

        for ii in range(100):
            # The order of the draws from rng must not change: it fixes the
            # (seeded, reproducible) contents of the database.
            values = rng.random_sample(3)
            for i in range(len(values)):
                draw = rng.random_sample(1)
                if draw[0] < 0.5 and add_nans:
                    values[i] = np.nan

            w1 = 'None' if rng.random_sample(1)[0] < 0.5 else 'word'
            # n5 is deliberately built with builtins.str (not the native str
            # used for n4) so the two text columns can differ in type.
            w2 = str('None') if rng.random_sample(1)[0] < 0.5 else str('word')

            rows.append((ii, values[0], values[1], values[2], w1, w2))

            # Bind values as parameters instead of formatting them into the
            # SQL text; NaN is stored as SQL NULL.
            sqlValues = [None if np.isnan(v) else float(v) for v in values]
            c.execute('INSERT INTO testTable VALUES (?, ?, ?, ?, ?, ?)',
                      (ii, sqlValues[0], sqlValues[1], sqlValues[2], w1, w2))

        conn.commit()
    finally:
        conn.close()

    # Build the baseline array once instead of growing it row by row.
    return np.array(rows, dtype=dtype)

112 

113 

class myCannotBeNullDBObject(CatalogDBObject):
    """
    CatalogDBObject for the testTable written by createCannotBeNullTestDB.

    The tests obtain it with CatalogDBObject.from_objid('cannotBeNull'),
    which matches the objid declared below.  The ``database`` value is only a
    default: InstanceCatalogCannotBeNullTest.setUp overwrites it with the path
    of its scratch database, and InstanceCatalogMetaDataTest passes
    database=... explicitly.
    """
    driver = 'sqlite'
    database = 'cannotBeNullTest.db'
    tableid = 'testTable'
    objid = 'cannotBeNull'
    idColKey = 'id'
    # Explicit declaration of the n5 text column; the tuple is presumably
    # (column name, database expression, type, length) -- confirm against the
    # CatalogDBObject documentation.
    columns = [('n5', 'n5', str, 40)]

121 

122 

class floatCannotBeNullCatalog(InstanceCatalog):
    """
    This catalog class will not write rows with a null value in the n2 column
    """
    # All columns of testTable are written.
    column_outputs = ['id', 'n1', 'n2', 'n3', 'n4', 'n5']
    # n2 is a float column; in the test database its nulls are SQL NULL (NaN).
    cannot_be_null = ['n2']

129 

130 

class strCannotBeNullCatalog(InstanceCatalog):
    """
    This catalog class will not write rows with a null value in the n4 column.

    n4 is a text column; in the test database its 'null' entries are stored
    as the string 'None'.
    """
    # All columns of testTable are written.
    column_outputs = ['id', 'n1', 'n2', 'n3', 'n4', 'n5']
    cannot_be_null = ['n4']

137 

138 

class unicodeCannotBeNullCatalog(InstanceCatalog):
    """
    This catalog class will not write rows with a null value in the n5 column.

    n5 is the text column written with builtins.str in the test database;
    its 'null' entries are stored as the string 'None'.
    """
    # All columns of testTable are written.
    column_outputs = ['id', 'n1', 'n2', 'n3', 'n4', 'n5']
    cannot_be_null = ['n5']

145 

146 

class severalCannotBeNullCatalog(InstanceCatalog):
    """
    This catalog class will not write rows with null values in the n2 or n4 columns
    """
    # All columns of testTable are written.
    column_outputs = ['id', 'n1', 'n2', 'n3', 'n4', 'n5']
    # A row is dropped if either the float column n2 or the text column n4 is null.
    cannot_be_null = ['n2', 'n4']

153 

154 

class CanBeNullCatalog(InstanceCatalog):
    """
    This catalog class will write all rows to the catalog
    """
    column_outputs = ['id', 'n1', 'n2', 'n3', 'n4', 'n5']
    # Name under which the tests request this catalog:
    # dbobj.getCatalog('canBeNull') in testCanBeNull.
    catalog_type = 'canBeNull'

161 

162 

class testStellarCatalogClass(InstanceCatalog):
    """
    Minimal star catalog (raJ2000, decJ2000) used by InstanceCatalogMetaDataTest
    to exercise obs_metadata handling and constructor-supplied column_outputs.
    """
    column_outputs = ['raJ2000', 'decJ2000']
    # presumably: format float ('f') columns with '%le' -- confirm against
    # InstanceCatalog.default_formats documentation
    default_formats = {'f': '%le'}

166 

167 

class cartoonValueCatalog(InstanceCatalog):
    """
    Catalog that writes n1 and n2 by default and can also produce a
    'difference' column (n1 - n3) when it is requested, e.g. through the
    column_outputs constructor argument.
    """
    column_outputs = ['n1', 'n2']
    default_formats = {'f': '%le'}

    def get_difference(self):
        """Return the n1 column minus the n3 column."""
        return self.column_by_name('n1') - self.column_by_name('n3')

176 

177 

class InstanceCatalogMetaDataTest(unittest.TestCase):
    """
    This class will test how Instance catalog handles the metadata
    class variables (pointingRA, pointingDec, etc.)
    """

    @classmethod
    def setUpClass(cls):
        # One scratch directory and one star database shared by every test in the class.
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="scratchSpace-")
        cls.database = os.path.join(cls.scratch_dir, 'testInstanceCatalogDatabase.db')
        makeStarTestDB(filename=cls.database)

    @classmethod
    def tearDownClass(cls):
        sims_clean_up()
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def setUp(self):
        self.myDB = myTestStars(driver='sqlite', database=self.database)

    def tearDown(self):
        del self.myDB

    def _assertColumnNames(self, catalog, expected):
        """
        Assert that catalog.iter_column_names() yields exactly the names in
        ``expected``: in any order, with none missing and none repeated.
        """
        generated = list(catalog.iter_column_names())
        self.assertEqual(sorted(generated), sorted(expected),
                         msg='columns returned: %s; expected: %s' % (generated, expected))

    def testObsMetaDataAssignment(self):
        """
        Test that you get an error when you pass something that is not
        ObservationMetaData as obs_metadata
        """

        xx = 5.0
        self.assertRaises(ValueError, testStellarCatalogClass, self.myDB, obs_metadata=xx)

    def testColumnArg(self):
        """
        A unit test to make sure that the code allowing you to add
        new column_outputs to an InstanceCatalog using its constructor
        works properly.
        """
        mjd = 5120.0
        RA = 1.5
        Dec = -1.1
        rotSkyPos = -10.0

        testSite = Site(longitude=2.0, latitude=-1.0, height=4.0,
                        temperature=100.0, pressure=500.0, humidity=0.1,
                        lapseRate=0.1)

        testObsMD = ObservationMetaData(site=testSite,
                                        mjd=mjd, pointingRA=RA,
                                        pointingDec=Dec,
                                        rotSkyPos=rotSkyPos,
                                        bandpassName='z')

        # make sure the correct column names are returned
        # according to class definition
        testCat = testStellarCatalogClass(self.myDB, obs_metadata=testObsMD)
        self._assertColumnNames(testCat, ['raJ2000', 'decJ2000'])

        # make sure that new column names can be added
        newColumns = ['properMotionRa', 'properMotionDec']
        testCat = testStellarCatalogClass(self.myDB, obs_metadata=testObsMD, column_outputs=newColumns)
        self._assertColumnNames(testCat, ['raJ2000', 'decJ2000', 'properMotionRa', 'properMotionDec'])

        # make sure that, if we include a duplicate column in newColumns,
        # the column is not duplicated
        newColumns = ['properMotionRa', 'properMotionDec', 'raJ2000']
        testCat = testStellarCatalogClass(self.myDB, obs_metadata=testObsMD, column_outputs=newColumns)
        columnsShouldBe = ['raJ2000', 'decJ2000', 'properMotionRa', 'properMotionDec']

        for col in columnsShouldBe:
            self.assertIn(col, testCat._actually_calculated_columns)

        # each expected column exactly once, i.e. raJ2000 is not repeated
        self._assertColumnNames(testCat, columnsShouldBe)

        cat_name = os.path.join(self.scratch_dir, 'testArgCatalog.txt')
        testCat.write_catalog(cat_name)
        with open(cat_name, 'r') as inCat:
            lines = inCat.readlines()
        # the first line is the '# col1, col2, ...' header
        header = lines[0]
        header = header.strip('#')
        header = header.strip('\n')
        header = header.split(', ')
        self.assertIn('raJ2000', header)
        self.assertIn('decJ2000', header)
        self.assertIn('properMotionRa', header)
        self.assertIn('properMotionDec', header)
        if os.path.exists(cat_name):
            os.unlink(cat_name)

    def testArgValues(self):
        """
        Test that columns added using the constructor args return the correct value
        """
        with lsst.utils.tests.getTempFilePath(".db") as dbName:
            baselineData = createCannotBeNullTestDB(filename=dbName, add_nans=False)
            db = myCannotBeNullDBObject(driver='sqlite', database=dbName)
            dtype = np.dtype([('n1', float), ('n2', float), ('n3', float), ('difference', float)])
            cat = cartoonValueCatalog(db, column_outputs=['n3', 'difference'])

            columns = ['n1', 'n2', 'n3', 'difference']
            for col in columns:
                self.assertIn(col, cat._actually_calculated_columns)

            cat_name = os.path.join(self.scratch_dir, 'cartoonValCat.txt')
            cat.write_catalog(cat_name)
            testData = np.genfromtxt(cat_name, dtype=dtype, delimiter=',')

            # zip() stops at the shorter input, so check the row count
            # explicitly; otherwise missing catalog rows would go unnoticed.
            self.assertEqual(len(testData), len(baselineData))
            for testLine, controlLine in zip(testData, baselineData):
                self.assertAlmostEqual(testLine[0], controlLine['n1'], 6)
                self.assertAlmostEqual(testLine[1], controlLine['n2'], 6)
                self.assertAlmostEqual(testLine[2], controlLine['n3'], 6)
                self.assertAlmostEqual(testLine[3], controlLine['n1']-controlLine['n3'], 6)

            if os.path.exists(cat_name):
                os.unlink(cat_name)

    def testAllCalculatedColumns(self):
        """
        Unit test to make sure that _actually_calculated_columns contains all of the dependent columns
        """
        class otherCartoonValueCatalog(InstanceCatalog):
            column_outputs = ['n1', 'n2', 'difference']

            def get_difference(self):
                n1 = self.column_by_name('n1')
                n3 = self.column_by_name('n3')
                return n1-n3

        with lsst.utils.tests.getTempFilePath(".db") as dbName:
            createCannotBeNullTestDB(filename=dbName, add_nans=False)
            db = myCannotBeNullDBObject(driver='sqlite', database=dbName)
            cat = otherCartoonValueCatalog(db)
            # n3 is not an output column, but 'difference' depends on it
            columns = ['n1', 'n2', 'n3', 'difference']
            for col in columns:
                self.assertIn(col, cat._actually_calculated_columns)

336 

337 

class InstanceCatalogCannotBeNullTest(unittest.TestCase):
    """
    Test the InstanceCatalog cannot_be_null filter against a database written
    by createCannotBeNullTestDB.
    """

    # Layout of the catalogs written by these tests; mirrors the columns of
    # testTable and of the baseline array returned by createCannotBeNullTestDB.
    catalogDtype = np.dtype([('id', int), ('n1', np.float64), ('n2', np.float64), ('n3', np.float64),
                             ('n4', (str_type, 40)), ('n5', (str_type, 40))])

    def setUp(self):
        self.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')
        # Force the class to understand where the DB is meant to be.  The
        # previous value is restored in tearDown so this change to a class
        # attribute does not leak into other tests.
        self._originalDatabase = myCannotBeNullDBObject.database
        myCannotBeNullDBObject.database = os.path.join(self.scratch_dir,
                                                       'cannotBeNullTest.db')
        self.baselineOutput = createCannotBeNullTestDB(dir=self.scratch_dir)

    def tearDown(self):
        sims_clean_up()
        del self.baselineOutput
        myCannotBeNullDBObject.database = self._originalDatabase
        # NOTE(review): setUp writes the database into scratch_dir, so nothing
        # here creates cannotBeNullTest.db in the current directory; this
        # cleanup looks like a leftover -- confirm before removing it.
        if os.path.exists('cannotBeNullTest.db'):
            os.unlink('cannotBeNullTest.db')
        if os.path.exists(self.scratch_dir):
            shutil.rmtree(self.scratch_dir)

    @staticmethod
    def _isNull(value):
        """
        Return True if a baseline value counts as null in these tests: a
        string equal to 'none' (case-insensitive, surrounding whitespace
        ignored) or a NaN number.
        """
        if isinstance(value, (str, str_type)):
            return value.strip().lower() == 'none'
        return np.isnan(value)

    def _assertRowMatches(self, baselineRow, testRow, context=''):
        """
        Assert that a row read back from a catalog matches a baseline row.

        The first four columns (id, n1, n2, n3) must agree to 3 decimal
        places, with NaN matching NaN.  The text columns must agree after
        stripping whitespace.  ``context`` is appended to failure messages.
        """
        for k, xx in enumerate(baselineRow):
            if k < 4:
                if np.isnan(xx):
                    # NaN != NaN, so use numpy's NaN-aware equality assertion
                    np.testing.assert_equal(testRow[k], np.nan)
                else:
                    msg = ('k: %d -- %s %s -- %s' %
                           (k, str(xx), str(testRow[k]), context))
                    self.assertAlmostEqual(xx, testRow[k], 3, msg=msg)
            else:
                msg = ('%s (%s) is not %s (%s) -- %s' %
                       (xx, type(xx), testRow[k], type(testRow[k]), context))
                self.assertEqual(xx.strip(), testRow[k].strip(), msg=msg)

    def testCannotBeNull(self):
        """
        Test to make sure that the code for filtering out rows with null values
        in key rows works.
        """

        # each of these classes flags a different column with a different datatype as cannot_be_null
        availableCatalogs = [floatCannotBeNullCatalog, strCannotBeNullCatalog, unicodeCannotBeNullCatalog,
                             severalCannotBeNullCatalog]
        dbobj = CatalogDBObject.from_objid('cannotBeNull')

        ct_n2 = 0  # number of rows in floatCannotBeNullCatalog
        ct_n4 = 0  # number of rows in strCannotBeNullCatalog
        ct_n2_n4 = 0  # number of rows in severalCannotBeNullCatalog

        fileName = os.path.join(self.scratch_dir, 'cannotBeNullTestFile.txt')
        ct_total = len(self.baselineOutput)

        for catClass in availableCatalogs:
            cat = catClass(dbobj)
            cat.write_catalog(fileName)
            testData = np.genfromtxt(fileName, dtype=self.catalogDtype, delimiter=',')

            ct_good = 0  # number of catalog rows checked so far

            # self.baselineOutput contains all of the rows from the dbobj;
            # only those that pass the cannot_be_null test belong in the catalog
            for baselineRow in self.baselineOutput:
                if any(self._isNull(baselineRow[col_name]) for col_name in cat.cannot_be_null):
                    continue

                if catClass is floatCannotBeNullCatalog:
                    ct_n2 += 1
                elif catClass is strCannotBeNullCatalog:
                    ct_n4 += 1
                elif catClass is severalCannotBeNullCatalog:
                    ct_n2_n4 += 1

                # fail cleanly (rather than with IndexError) if the catalog is short
                self.assertLess(ct_good, len(testData),
                                msg='catalog from %s has too few rows' % catClass.__name__)
                self._assertRowMatches(baselineRow, testData[ct_good], context=cat.cannot_be_null)
                ct_good += 1

            self.assertEqual(ct_good, len(testData))  # make sure that we tested all of the testData rows
            msg = '%d >= %d' % (ct_good, ct_total)
            # make sure that some rows did not make it into the catalog
            self.assertLess(ct_good, ct_total, msg=msg)

        # make sure that severalCannotBeNullCatalog weeded out rows that were individually in
        # floatCannotBeNullCatalog or strCannotBeNullCatalog
        self.assertGreater(ct_n2, ct_n2_n4)
        self.assertGreater(ct_n4, ct_n2_n4)

        if os.path.exists(fileName):
            os.unlink(fileName)

    def testCannotBeNull_pre_screen(self):
        """
        Check that writing a catalog with self._pre_screen = True produces
        the same results as writing one with self._pre_screen = False.
        """

        # each of these classes flags a different column with a different datatype as cannot_be_null
        availableCatalogs = [floatCannotBeNullCatalog, strCannotBeNullCatalog, unicodeCannotBeNullCatalog,
                             severalCannotBeNullCatalog]
        dbobj = CatalogDBObject.from_objid('cannotBeNull')

        fileName = os.path.join(self.scratch_dir, 'cannotBeNullTestFile_prescreen.txt')
        control_fileName = os.path.join(self.scratch_dir, 'cannotBeNullTestFile_prescreen_control.txt')

        for catClass in availableCatalogs:
            cat = catClass(dbobj)
            cat._pre_screen = True
            control_cat = catClass(dbobj)
            cat.write_catalog(fileName)
            control_cat.write_catalog(control_fileName)

            with open(fileName, 'r') as test_file:
                test_lines = test_file.readlines()
            with open(control_fileName, 'r') as control_file:
                control_lines = control_file.readlines()
            # every line of each catalog must appear in the other (order-insensitive)
            self.assertEqual(set(control_lines), set(test_lines))

        if os.path.exists(fileName):
            os.unlink(fileName)
        if os.path.exists(control_fileName):
            os.unlink(control_fileName)

    def testCanBeNull(self):
        """
        Test to make sure that we can still write all rows to catalogs,
        even those with null values in key columns
        """
        dbobj = CatalogDBObject.from_objid('cannotBeNull')
        cat = dbobj.getCatalog('canBeNull')
        fileName = os.path.join(self.scratch_dir, 'canBeNullTestFile.txt')
        cat.write_catalog(fileName)
        testData = np.genfromtxt(fileName, dtype=self.catalogDtype, delimiter=',')

        # check the row count first so that a short catalog fails cleanly
        # instead of raising IndexError in the comparison loop
        self.assertEqual(len(testData), len(self.baselineOutput))
        self.assertEqual(len(self.baselineOutput), 100)

        # make sure that all of the rows in self.baselineOutput are represented in testData
        for baselineRow, testRow in zip(self.baselineOutput, testData):
            self._assertRowMatches(baselineRow, testRow)

        if os.path.exists(fileName):
            os.unlink(fileName)

500 

501 

class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """
    Runs the checks inherited from lsst.utils.tests.MemoryTestCase
    (presumably a resource-leak check); adds nothing of its own.
    """

504 

if __name__ == "__main__":
    # Allow running this file directly: initialise the LSST test utilities,
    # then hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()