Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from builtins import next 

2from builtins import range 

3import os 

4import sqlite3 

5import sys 

6 

7import numpy as np 

8import unittest 

9import warnings 

10import tempfile 

11import shutil 

12import lsst.utils.tests 

13from lsst.sims.catalogs.db import DBObject 

14from lsst.sims.utils.CodeUtilities import sims_clean_up 

15 

# Absolute path of the directory holding this test file; the scratch
# directories for the test databases are created beneath it.
ROOT = os.path.dirname(os.path.abspath(__file__))

17 

18 

def setup_module(module):
    """
    Module-level setup hook (pytest/nose ``setup_module`` convention).

    Initialises the LSST test utilities via ``lsst.utils.tests.init()``.
    The ``module`` argument is supplied by the test runner and is unused.
    """
    lsst.utils.tests.init()

21 

22 

class DBObjectTestCase(unittest.TestCase):
    """Tests of DBObject run against a scratch SQLite database built in setUpClass."""

    @staticmethod
    def _build_table(conn, create_cmd, insert_cmd, rows, error_msg):
        """
        Create a table on ``conn`` and fill it with ``rows``.

        Parameters
        ----------
        conn : sqlite3.Connection
            Open connection to the database being populated.
        create_cmd : str
            The ``CREATE TABLE`` statement.
        insert_cmd : str
            Parameterized ``INSERT`` statement (``?`` placeholders), executed
            once per entry of ``rows``.
        rows : sequence of tuples
            Values bound to ``insert_cmd``.
        error_msg : str
            Message of the RuntimeError raised if the table cannot be created.

        Raises
        ------
        RuntimeError
            If ``create_cmd`` fails with a sqlite3 error.
        """
        cursor = conn.cursor()
        try:
            cursor.execute(create_cmd)
            conn.commit()
        except sqlite3.Error:
            # Only database errors are translated (a bare except would also
            # swallow KeyboardInterrupt); on Python 3 the sqlite3 error stays
            # attached as the implicit __context__.
            raise RuntimeError(error_msg)
        # bound parameters rather than %-formatted SQL text
        cursor.executemany(insert_cmd, rows)
        conn.commit()

    @classmethod
    def setUpClass(cls):
        """
        Create a database with three tables of meaningless data to make sure that
        JOIN queries can be executed using DBObject

        intTable    : (id, 2*id, 3*id) for the even ids 0, 2, ..., 198
        doubleTable : (id, sqrt(id), log(id)) for ids 1, 2, ..., 200
        junkTable   : same contents as doubleTable; target of the write
                      queries in testReadOnlyFilter
        """
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')
        cls.db_name = os.path.join(cls.scratch_dir, 'testDBObjectDB.db')
        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)

        int_rows = [(ll, 2*ll, 3*ll) for ll in range(0, 200, 2)]
        float_rows = [(ll, float(np.sqrt(ll)), float(np.log(ll)))
                      for ll in range(1, 201)]

        try:
            conn = sqlite3.connect(cls.db_name)
            try:
                cls._build_table(conn,
                                 '''CREATE TABLE intTable (id int, twice int, thrice int)''',
                                 '''INSERT INTO intTable VALUES (?, ?, ?)''',
                                 int_rows, "Error creating database.")
                cls._build_table(conn,
                                 '''CREATE TABLE doubleTable (id int, sqrt float, log float)''',
                                 '''INSERT INTO doubleTable VALUES (?, ?, ?)''',
                                 float_rows, "Error creating database (double).")
                cls._build_table(conn,
                                 '''CREATE TABLE junkTable (id int, sqrt float, log float)''',
                                 '''INSERT INTO junkTable VALUES (?, ?, ?)''',
                                 float_rows, "Error creating database (junk).")
            finally:
                # closed even when building a table fails
                conn.close()
        except Exception:
            # tearDownClass is not run when setUpClass raises, so remove the
            # scratch directory here before propagating the error
            shutil.rmtree(cls.scratch_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Call sims_clean_up() and delete the scratch database and its directory."""
        sims_clean_up()
        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def setUp(self):
        self.driver = 'sqlite'

    def testTableNames(self):
        """
        Test the method that returns the names of tables in a database
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        names = dbobj.get_table_names()
        self.assertEqual(len(names), 3)
        self.assertIn('doubleTable', names)
        self.assertIn('intTable', names)
        self.assertIn('junkTable', names)

    def testReadOnlyFilter(self):
        """
        Test that the filters we placed on queries made with execute_arbitrary()
        work
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        controlQuery = ('SELECT doubleTable.id, intTable.id, doubleTable.log, intTable.thrice '
                        'FROM doubleTable, intTable WHERE doubleTable.id = intTable.id')
        # a harmless SELECT must get through the filter
        dbobj.execute_arbitrary(controlQuery)

        # make sure that execute_arbitrary only accepts strings
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, ['a', 'list'])

        # check that our filter catches different capitalization permutations of the
        # verboten commands
        upper_case = ['DROP TABLE junkTable',
                      'DELETE FROM junkTable WHERE id=4',
                      'UPDATE junkTable SET sqrt=0.0, log=0.0 WHERE id=4',
                      'INSERT INTO junkTable VALUES (9999, 1.0, 1.0)']
        for query in upper_case:
            self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
            self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query.lower())

        mixed_case = ['Drop Table junkTable',
                      'Delete FROM junkTable WHERE id=4',
                      'Update junkTable SET sqrt=0.0, log=0.0 WHERE id=4',
                      'Insert INTO junkTable VALUES (9999, 1.0, 1.0)',
                      'dRoP TaBlE junkTable',
                      'dElEtE FROM junkTable WHERE id=4',
                      'uPdAtE junkTable SET sqrt=0.0, log=0.0 WHERE id=4',
                      'iNsErT INTO junkTable VALUES (9999, 1.0, 1.0)']
        for query in mixed_case:
            self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)

    def testColumnNames(self):
        """
        Test the method that returns the names of columns in a table
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        expected = {'doubleTable': ('id', 'sqrt', 'log'),
                    'intTable': ('id', 'twice', 'thrice'),
                    'junkTable': ('id', 'sqrt', 'log')}

        # one table at a time
        for table, columns in sorted(expected.items()):
            names = dbobj.get_column_names(table)
            self.assertEqual(len(names), len(columns))
            for column in columns:
                self.assertIn(column, names)

        # no argument: a mapping from table name to that table's column names
        names = dbobj.get_column_names()
        for table in names:
            self.assertIn(table, expected)
        for table, columns in sorted(expected.items()):
            self.assertEqual(len(names[table]), len(columns))
            for column in columns:
                self.assertIn(column, names[table])

    def testSingleTableQuery(self):
        """
        Test a query on a single table (using chunk iterator)
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        query = 'SELECT id, sqrt FROM doubleTable'
        results = dbobj.get_chunk_iterator(query)

        dtype = [('id', int),
                 ('sqrt', float)]

        expected_id = 1
        for chunk in results:
            for row in chunk:
                self.assertEqual(row[0], expected_id)
                self.assertAlmostEqual(row[1], np.sqrt(expected_id))
                self.assertEqual(dtype, row.dtype)
                expected_id += 1

        # ids 1, ..., 200 were all returned
        self.assertEqual(expected_id, 201)

    def testDtype(self):
        """
        Test that passing dtype to a query works

        (also test a query on a single table using .execute_arbitrary() directly)
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        query = 'SELECT id, log FROM doubleTable'
        dtype = [('id', int), ('log', float)]
        results = dbobj.execute_arbitrary(query, dtype=dtype)

        self.assertEqual(results.dtype, dtype)
        for xx in results:
            self.assertAlmostEqual(np.log(xx[0]), xx[1], 6)

        self.assertEqual(len(results), 200)

        results = dbobj.get_chunk_iterator(query, chunk_size=10, dtype=dtype)
        # exercise next() explicitly, and check the chunk it returns as well
        first_chunk = next(results)
        self.assertEqual(first_chunk.dtype, dtype)
        for chunk in results:
            self.assertEqual(chunk.dtype, dtype)

    def _check_join_row(self, row, index):
        """Check the ``index``-th (0-based) row returned by the join query."""
        self.assertEqual(2*(index+1), row[0])
        self.assertEqual(row[0], row[1])
        self.assertAlmostEqual(np.log(row[0]), row[2], 6)
        self.assertEqual(3*row[0], row[3])

    def _check_join_query(self, dbobj):
        """
        Run the doubleTable/intTable join through ``dbobj``, first with a chunk
        iterator and then with execute_arbitrary(), and validate every row.

        The join is expected to match the 99 even ids 2, 4, ..., 198, in order.
        """
        query = ('SELECT doubleTable.id, intTable.id, doubleTable.log, intTable.thrice '
                 'FROM doubleTable, intTable WHERE doubleTable.id = intTable.id')
        # the second 'id' column is expected back under the name 'id_1'
        dtype = [('id', int), ('id_1', int), ('log', float), ('thrice', int)]

        n_rows = 0
        for chunk in dbobj.get_chunk_iterator(query, chunk_size=10):
            # every chunk except the last one (rows 90-98) must be full
            if n_rows < 90:
                self.assertEqual(len(chunk), 10)
            for row in chunk:
                self._check_join_row(row, n_rows)
                self.assertEqual(dtype, row.dtype)
                n_rows += 1
        # make sure that we found all the matches we should have
        self.assertEqual(n_rows, 99)

        results = dbobj.execute_arbitrary(query)
        self.assertEqual(dtype, results.dtype)
        n_rows = 0
        for row in results:
            self._check_join_row(row, n_rows)
            n_rows += 1
        # make sure we found all the matches we should have
        self.assertEqual(n_rows, 99)

    def testJoin(self):
        """
        Test a join
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        self._check_join_query(dbobj)

    def testMinMax(self):
        """
        Test queries on SQL functions by using the MIN and MAX functions
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        query = 'SELECT MAX(thrice), MIN(thrice) FROM intTable'
        results = dbobj.execute_arbitrary(query)
        # thrice = 3*id for the ids 0, 2, ..., 198
        self.assertEqual(results[0][0], 594)
        self.assertEqual(results[0][1], 0)

        # expected column names: the SQL expressions without their parentheses
        dtype = [('MAXthrice', int), ('MINthrice', int)]
        self.assertEqual(results.dtype, dtype)

    def testPassingConnection(self):
        """
        Repeat the test from testJoin, but with a DBObject whose connection was passed
        directly from another DBObject, to make sure that passing a connection works
        """
        dbobj_base = DBObject(driver=self.driver, database=self.db_name)
        dbobj = DBObject(connection=dbobj_base.connection)
        self._check_join_query(dbobj)

    def testValidationErrors(self):
        """ Test that appropriate errors and warnings are thrown when connecting
        """

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            DBObject('sqlite:///' + self.db_name)
            # unittest assertion rather than a bare assert (stripped under -O)
            self.assertEqual(len(w), 1)

        # missing database
        self.assertRaises(AttributeError, DBObject, driver=self.driver)
        # missing driver
        self.assertRaises(AttributeError, DBObject, database=self.db_name)
        # missing host
        self.assertRaises(AttributeError, DBObject, driver='mssql+pymssql')
        # missing port
        self.assertRaises(AttributeError, DBObject, driver='mssql+pymssql', host='localhost')

    def _assert_sentence_dtype(self, dtype):
        """Assert that the 'sentence' field of ``dtype`` is a 22-character string type."""
        if sys.version_info.major == 2:
            self.assertEqual(str(dtype['sentence']), '|S22')
        else:
            self.assertEqual(str(dtype['sentence']), '<U22')

    def _check_detected_chunks(self, db, query):
        """
        Iterate ``query`` (id, val and sentence of the even ids of testTable) in
        chunks of 3 and check each chunk's detected dtypes and row contents.
        """
        chunk_iter = db.get_arbitrary_chunk_iterator(query, chunk_size=3)
        n_rows = 0
        for chunk in chunk_iter:
            self.assertEqual(str(chunk.dtype['id']), 'int64')
            self.assertEqual(str(chunk.dtype['val']), 'float64')
            # check this chunk's own dtype, not that of an earlier full query
            self._assert_sentence_dtype(chunk.dtype)
            self.assertEqual(len(chunk.dtype), 3)

            for line in chunk:
                n_rows += 1
                self.assertEqual(line['sentence'], 'this, has; punctuation')
                self.assertAlmostEqual(line['val'], line['id']*5.234, 5)
                self.assertEqual(line['id'] % 2, 0)

        # ids 0, 2, 4, 6, 8
        self.assertEqual(n_rows, 5)

    def testDetectDtype(self):
        """
        Test that DBObject.execute_arbitrary can correctly detect the dtypes
        of the rows it is returning
        """
        db_name = os.path.join(self.scratch_dir, 'testDBObject_dtype_DB.db')

        def remove_db():
            if os.path.exists(db_name):
                os.unlink(db_name)

        remove_db()
        # registered up front so the file is removed even if an assertion fails
        self.addCleanup(remove_db)

        # 'sentence' is declared int but holds text; the assertions below expect
        # a string dtype for it regardless of the declared column type
        rows = [(ii, round(5.234*ii, 5), 'this, has; punctuation') for ii in range(10)]
        conn = sqlite3.connect(db_name)
        try:
            self._build_table(conn,
                              '''CREATE TABLE testTable (id int, val real, sentence int)''',
                              '''INSERT INTO testTable VALUES (?, ?, ?)''',
                              rows, "Error creating database.")
        finally:
            conn.close()

        db = DBObject(database=db_name, driver='sqlite')
        query = 'SELECT id, val, sentence FROM testTable WHERE id%2 = 0'
        results = db.execute_arbitrary(query)

        np.testing.assert_array_equal(results['id'], np.arange(0, 9, 2, dtype=int))
        np.testing.assert_array_almost_equal(results['val'], 5.234*np.arange(0, 9, 2), decimal=5)
        for sentence in results['sentence']:
            self.assertEqual(sentence, 'this, has; punctuation')

        self.assertEqual(str(results.dtype['id']), 'int64')
        self.assertEqual(str(results.dtype['val']), 'float64')
        self._assert_sentence_dtype(results.dtype)
        self.assertEqual(len(results.dtype), 3)

        # now test that it works when getting a ChunkIterator
        self._check_detected_chunks(db, query)

        # test that doing a different query does not spoil dtype detection
        results = db.execute_arbitrary('SELECT id, sentence FROM testTable WHERE id%2 = 0')
        self.assertGreater(len(results), 0)
        self.assertEqual(len(results.dtype.names), 2)
        self.assertEqual(str(results.dtype['id']), 'int64')
        self._assert_sentence_dtype(results.dtype)

        # ...and that the three-column query still chunks correctly afterwards
        self._check_detected_chunks(db, query)

437 

class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """Defines nothing of its own; all tests are inherited from lsst.utils.tests.MemoryTestCase."""

440 

if __name__ == "__main__":
    # Running the file directly: initialise the LSST test utilities, then hand
    # control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()