Coverage for tests/testDBObject.py : 9%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
from builtins import next
from builtins import range
import contextlib
import os
import shutil
import sqlite3
import sys
import tempfile
import unittest
import warnings

import numpy as np

import lsst.utils.tests
from lsst.sims.maf.db import DBObject
from lsst.sims.utils.CodeUtilities import sims_clean_up
# Absolute path of the directory containing this test module; the scratch
# directory for the test database is created underneath it.
ROOT = os.path.abspath(os.path.dirname(__file__))
def setup_module(module):
    """
    Module-level setup hook (pytest calls a function with this name once
    before running the tests in the module).

    Delegates to lsst.utils.tests.init(); the `module` argument is unused.
    """
    lsst.utils.tests.init()
class DBObjectTestCase(unittest.TestCase):
    """
    Tests of DBObject run against a scratch SQLite database.

    setUpClass builds a database with three tables (intTable, doubleTable,
    junkTable) of meaningless data; the tests then exercise table/column
    introspection, the read-only filter on execute_arbitrary(), dtype
    handling, JOIN queries, SQL functions, connection sharing and the
    validation of constructor arguments.
    """

    @staticmethod
    def _build_table(conn, create_cmd, insert_cmd, rows, error_msg):
        """
        Create one table on an open sqlite3 connection and fill it.

        Parameters
        ----------
        conn : sqlite3.Connection
            Open connection to the database being built.
        create_cmd : str
            The CREATE TABLE statement.
        insert_cmd : str
            An INSERT statement using '?' placeholders, one per column.
        rows : iterable of tuples
            The values to insert, one tuple per row.
        error_msg : str
            Message of the RuntimeError raised if the table cannot be created.

        Raises
        ------
        RuntimeError
            If sqlite3 reports an error while creating the table.
        """
        cursor = conn.cursor()
        try:
            cursor.execute(create_cmd)
            conn.commit()
        except sqlite3.Error:
            raise RuntimeError(error_msg)

        # Bound parameters instead of string-formatted SQL: no quoting issues
        # and floats are stored at full precision.
        cursor.executemany(insert_cmd, rows)
        conn.commit()

    @classmethod
    def setUpClass(cls):
        """
        Create a database with two tables of meaningless data to make sure that JOIN queries
        can be executed using DBObject, plus a third table (junkTable) that the
        read-only-filter test names in its forbidden statements.

        intTable holds (2*i, 4*i, 6*i) for i in [0, 100);
        doubleTable and junkTable hold (n, sqrt(n), log(n)) for n in [1, 200].
        """
        # mkdtemp always returns a brand-new directory, so the database file
        # cannot already exist inside it.
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')
        cls.db_name = os.path.join(cls.scratch_dir, 'testDBObjectDB.db')

        int_rows = [(2*ii, 4*ii, 6*ii) for ii in range(100)]
        # float() makes sure sqlite3 is handed plain Python floats
        float_rows = [(nn, float(np.sqrt(float(nn))), float(np.log(float(nn))))
                      for nn in range(1, 201)]

        try:
            # closing() guarantees the connection is released even on failure
            with contextlib.closing(sqlite3.connect(cls.db_name)) as conn:
                cls._build_table(conn,
                                 '''CREATE TABLE intTable (id int, twice int, thrice int)''',
                                 '''INSERT INTO intTable VALUES (?, ?, ?)''',
                                 int_rows,
                                 "Error creating database.")
                cls._build_table(conn,
                                 '''CREATE TABLE doubleTable (id int, sqrt float, log float)''',
                                 '''INSERT INTO doubleTable VALUES (?, ?, ?)''',
                                 float_rows,
                                 "Error creating database (double).")
                cls._build_table(conn,
                                 '''CREATE TABLE junkTable (id int, sqrt float, log float)''',
                                 '''INSERT INTO junkTable VALUES (?, ?, ?)''',
                                 float_rows,
                                 "Error creating database (junk).")
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the scratch directory here before re-raising.
            shutil.rmtree(cls.scratch_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """
        Remove the test database and the scratch directory that held it.
        """
        # NOTE(review): sims_clean_up() is called before the files are deleted;
        # presumably it releases cached database connections -- confirm.
        sims_clean_up()
        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def setUp(self):
        """
        Select the database driver used by every test.
        """
        self.driver = 'sqlite'

    def _assert_rejected(self, dbobj, query):
        """
        Fail unless dbobj.execute_arbitrary(query) raises RuntimeError.

        Equivalent to assertRaises(RuntimeError, ...), but the failure message
        names the offending query.  Any other exception propagates unchanged.
        """
        try:
            dbobj.execute_arbitrary(query)
        except RuntimeError:
            return
        self.fail("execute_arbitrary did not reject %r" % (query,))

    def _check_join(self, dbobj):
        """
        Run a JOIN of doubleTable and intTable through dbobj, first with a
        chunk iterator and then with execute_arbitrary(), and validate every
        returned row.

        The ids common to both tables are the even numbers 2, 4, ..., 198,
        so exactly 99 rows are expected from each call.
        """
        query = 'SELECT doubleTable.id, intTable.id, doubleTable.log, intTable.thrice '
        query += 'FROM doubleTable, intTable WHERE doubleTable.id = intTable.id'

        dtype = [
            ('id', int),
            ('id_1', int),
            ('log', float),
            ('thrice', int)]

        n_rows = 0
        for chunk in dbobj.get_chunk_iterator(query, chunk_size=10):
            # every chunk but the last (which holds the remaining 9 rows) is full
            if n_rows < 90:
                self.assertEqual(len(chunk), 10)
            for row in chunk:
                self.assertEqual(2*(n_rows+1), row[0])
                self.assertEqual(row[0], row[1])
                self.assertAlmostEqual(np.log(row[0]), row[2], 6)
                self.assertEqual(3*row[0], row[3])
                self.assertEqual(dtype, row.dtype)
                n_rows += 1
        # make sure that we found all the matches we should have
        self.assertEqual(n_rows, 99)

        results = dbobj.execute_arbitrary(query)
        self.assertEqual(dtype, results.dtype)
        n_rows = 0
        for row in results:
            self.assertEqual(2*(n_rows+1), row[0])
            self.assertEqual(row[0], row[1])
            self.assertAlmostEqual(np.log(row[0]), row[2], 6)
            self.assertEqual(3*row[0], row[3])
            n_rows += 1
        # make sure we found all the matches we should have
        self.assertEqual(n_rows, 99)

    def _check_detected_chunks(self, db, query, sentence, sentence_dtype):
        """
        Iterate over `query` in chunks of 3 rows and check the detected dtype
        and the contents of every chunk.

        `query` must select id, val and sentence for the even ids of
        testTable, i.e. 5 rows in total.
        """
        n_rows = 0
        for chunk in db.get_arbitrary_chunk_iterator(query, chunk_size=3):
            self.assertEqual(str(chunk.dtype['id']), 'int64')
            self.assertEqual(str(chunk.dtype['val']), 'float64')
            # check the chunk itself, not the result of an earlier query
            self.assertEqual(str(chunk.dtype['sentence']), sentence_dtype)
            self.assertEqual(len(chunk.dtype), 3)

            for line in chunk:
                n_rows += 1
                self.assertEqual(line['sentence'], sentence)
                self.assertAlmostEqual(line['val'], line['id']*5.234, 5)
                self.assertEqual(line['id'] % 2, 0)

        self.assertEqual(n_rows, 5)

    def testTableNames(self):
        """
        Test the method that returns the names of tables in a database
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        names = dbobj.get_table_names()
        self.assertEqual(len(names), 3)
        self.assertIn('doubleTable', names)
        self.assertIn('intTable', names)
        self.assertIn('junkTable', names)

    def testReadOnlyFilter(self):
        """
        Test that the filters we placed on queries made with execute_arbitrary()
        work
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        controlQuery = 'SELECT doubleTable.id, intTable.id, doubleTable.log, intTable.thrice '
        controlQuery += 'FROM doubleTable, intTable WHERE doubleTable.id = intTable.id'
        # a plain SELECT must pass the filter
        dbobj.execute_arbitrary(controlQuery)

        # make sure that execute_arbitrary only accepts strings
        self._assert_rejected(dbobj, ['a', 'list'])

        # check that our filter catches different capitalization permutations of the
        # verboten commands
        upper_case = ('DROP TABLE junkTable',
                      'DELETE FROM junkTable WHERE id=4',
                      'UPDATE junkTable SET sqrt=0.0, log=0.0 WHERE id=4',
                      'INSERT INTO junkTable VALUES (9999, 1.0, 1.0)')
        for query in upper_case:
            self._assert_rejected(dbobj, query)
            self._assert_rejected(dbobj, query.lower())

        mixed_case = ('Drop Table junkTable',
                      'Delete FROM junkTable WHERE id=4',
                      'Update junkTable SET sqrt=0.0, log=0.0 WHERE id=4',
                      'Insert INTO junkTable VALUES (9999, 1.0, 1.0)',
                      'dRoP TaBlE junkTable',
                      'dElEtE FROM junkTable WHERE id=4',
                      'uPdAtE junkTable SET sqrt=0.0, log=0.0 WHERE id=4',
                      'iNsErT INTO junkTable VALUES (9999, 1.0, 1.0)')
        for query in mixed_case:
            self._assert_rejected(dbobj, query)

    def testColumnNames(self):
        """
        Test the method that returns the names of columns in a table
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        names = dbobj.get_column_names('doubleTable')
        self.assertEqual(len(names), 3)
        self.assertIn('id', names)
        self.assertIn('sqrt', names)
        self.assertIn('log', names)

        names = dbobj.get_column_names('intTable')
        self.assertEqual(len(names), 3)
        self.assertIn('id', names)
        self.assertIn('twice', names)
        self.assertIn('thrice', names)

        # with no argument, the result is indexed by table name
        names = dbobj.get_column_names()
        keys = ['doubleTable', 'intTable', 'junkTable']
        for kk in names:
            self.assertIn(kk, keys)

        self.assertEqual(len(names['doubleTable']), 3)
        self.assertEqual(len(names['intTable']), 3)
        self.assertIn('id', names['doubleTable'])
        self.assertIn('sqrt', names['doubleTable'])
        self.assertIn('log', names['doubleTable'])
        self.assertIn('id', names['intTable'])
        self.assertIn('twice', names['intTable'])
        self.assertIn('thrice', names['intTable'])

    def testSingleTableQuery(self):
        """
        Test a query on a single table (using chunk iterator)
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        query = 'SELECT id, sqrt FROM doubleTable'
        results = dbobj.get_chunk_iterator(query)

        dtype = [('id', int),
                 ('sqrt', float)]

        # doubleTable ids run from 1 to 200 in insertion order
        i = 1
        for chunk in results:
            for row in chunk:
                self.assertEqual(row[0], i)
                self.assertAlmostEqual(row[1], np.sqrt(i))
                self.assertEqual(dtype, row.dtype)
                i += 1

        self.assertEqual(i, 201)

    def testDtype(self):
        """
        Test that passing dtype to a query works

        (also test a query on a single table using .execute_arbitrary() directly)
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        query = 'SELECT id, log FROM doubleTable'
        dtype = [('id', int), ('log', float)]
        results = dbobj.execute_arbitrary(query, dtype=dtype)

        self.assertEqual(results.dtype, dtype)
        for xx in results:
            self.assertAlmostEqual(np.log(xx[0]), xx[1], 6)

        self.assertEqual(len(results), 200)

        results = dbobj.get_chunk_iterator(query, chunk_size=10, dtype=dtype)
        # the first chunk is consumed without being checked; the loop below
        # covers the remaining ones
        next(results)
        for chunk in results:
            self.assertEqual(chunk.dtype, dtype)

    def testJoin(self):
        """
        Test a join
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        self._check_join(dbobj)

    def testMinMax(self):
        """
        Test queries on SQL functions by using the MIN and MAX functions
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        query = 'SELECT MAX(thrice), MIN(thrice) FROM intTable'
        results = dbobj.execute_arbitrary(query)
        # thrice is 6*i for i in [0, 100), so it spans 0 .. 594
        self.assertEqual(results[0][0], 594)
        self.assertEqual(results[0][1], 0)

        dtype = [('MAXthrice', int), ('MINthrice', int)]
        self.assertEqual(results.dtype, dtype)

    def testPassingConnection(self):
        """
        Repeat the test from testJoin, but with a DBObject whose connection was passed
        directly from another DBObject, to make sure that passing a connection works
        """
        dbobj_base = DBObject(driver=self.driver, database=self.db_name)
        dbobj = DBObject(connection=dbobj_base.connection)
        self._check_join(dbobj)

    def testValidationErrors(self):
        """ Test that appropriate errors and warnings are thrown when connecting
        """

        # missing database
        self.assertRaises(AttributeError, DBObject, driver=self.driver)
        # missing driver
        self.assertRaises(AttributeError, DBObject, database=self.db_name)
        # missing host
        self.assertRaises(AttributeError, DBObject, driver='mssql+pymssql')
        # missing port
        self.assertRaises(AttributeError, DBObject, driver='mssql+pymssql', host='localhost')

    def testDetectDtype(self):
        """
        Test that DBObject.execute_arbitrary can correctly detect the dtypes
        of the rows it is returning
        """
        db_name = os.path.join(self.scratch_dir, 'testDBObject_dtype_DB.db')
        if os.path.exists(db_name):
            os.unlink(db_name)

        sentence = 'this, has; punctuation'
        rows = [(ii, 5.234*ii, sentence) for ii in range(10)]

        # NOTE(review): the sentence column is declared int but holds text;
        # presumably deliberate, so that detection must rely on the returned
        # values rather than the declared column type -- confirm.
        with contextlib.closing(sqlite3.connect(db_name)) as conn:
            self._build_table(conn,
                              '''CREATE TABLE testTable (id int, val real, sentence int)''',
                              '''INSERT INTO testTable VALUES (?, ?, ?)''',
                              rows,
                              "Error creating database.")

        # numpy names a 22-character string dtype differently on Python 2 and 3
        if sys.version_info.major == 2:
            sentence_dtype = '|S22'
        else:
            sentence_dtype = '<U22'

        db = DBObject(database=db_name, driver='sqlite')
        query = 'SELECT id, val, sentence FROM testTable WHERE id%2 = 0'
        results = db.execute_arbitrary(query)

        np.testing.assert_array_equal(results['id'], np.arange(0, 9, 2, dtype=int))
        np.testing.assert_array_almost_equal(results['val'], 5.234*np.arange(0, 9, 2), decimal=5)
        for value in results['sentence']:
            self.assertEqual(value, sentence)

        self.assertEqual(str(results.dtype['id']), 'int64')
        self.assertEqual(str(results.dtype['val']), 'float64')
        self.assertEqual(str(results.dtype['sentence']), sentence_dtype)
        self.assertEqual(len(results.dtype), 3)

        # now test that it works when getting a ChunkIterator
        self._check_detected_chunks(db, query, sentence, sentence_dtype)

        # test that doing a different query does not spoil dtype detection
        short_query = 'SELECT id, sentence FROM testTable WHERE id%2 = 0'
        results = db.execute_arbitrary(short_query)
        self.assertGreater(len(results), 0)
        self.assertEqual(len(results.dtype.names), 2)
        self.assertEqual(str(results.dtype['id']), 'int64')
        self.assertEqual(str(results.dtype['sentence']), sentence_dtype)

        # ... and that the original three-column query is still handled correctly
        self._check_detected_chunks(db, query, sentence, sentence_dtype)

        if os.path.exists(db_name):
            os.unlink(db_name)
class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """
    Adds the checks inherited from lsst.utils.tests.MemoryTestCase to this
    module; nothing is overridden here.
    """
    pass
# When executed as a script, perform the same initialisation as
# setup_module() and then run every test case defined in this module.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()