Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from builtins import next 

2from builtins import range 

3import unittest 

4import sqlite3 

5import os 

6import numpy as np 

7import tempfile 

8import shutil 

9 

10import lsst.utils.tests 

11from lsst.sims.utils.CodeUtilities import sims_clean_up 

12from lsst.sims.catalogs.db import CatalogDBObject, DBObject 

13 

# Absolute path of the directory containing this test file; scratch
# directories for the tests are created underneath it.
ROOT = os.path.abspath(os.path.split(__file__)[0])

15 

16 

def setup_module(module):
    """
    Module-level setup hook (named per the pytest/nose ``setup_module``
    convention).

    Initialises ``lsst.utils.tests`` before any test in this file runs.
    The ``module`` argument is accepted for the hook signature but unused.
    """
    lsst.utils.tests.init()

19 

20 

class CachingTestCase(unittest.TestCase):
    """
    Tests that CatalogDBObject correctly shares connections through its
    _connection_cache, while a plain DBObject opens its own connection
    unless one is passed in explicitly.
    """

    @classmethod
    def setUpClass(cls):
        """
        Create a scratch directory (under ROOT) holding a small sqlite
        database.

        The database contains one table, ``test (id, i1, i2)``, with five
        rows where ``i1 = id**2`` and ``i2 = -id`` for ``id`` in 0..4.
        """
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="scratchSpace-")
        # mkdtemp always returns a newly created, empty directory, so there is
        # no stale database file to remove before creating it.
        cls.db_name = os.path.join(cls.scratch_dir, "connection_cache_test_db.db")

        try:
            conn = sqlite3.connect(cls.db_name)
            try:
                conn.execute('''CREATE TABLE test (id int, i1 int, i2 int)''')
                # Parameterized insert rather than %-formatting values into SQL.
                conn.executemany('''INSERT INTO test VALUES (?, ?, ?)''',
                                 [(ii, ii*ii, -ii) for ii in range(5)])
                conn.commit()
            finally:
                # Always release the file handle, even if creation fails.
                conn.close()
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises, so
            # remove the scratch directory here to avoid leaking it.
            shutil.rmtree(cls.scratch_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Reset shared sims state and delete the scratch directory."""
        sims_clean_up()
        # rmtree also removes the database file stored inside scratch_dir.
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def test_catalog_db_object_cacheing(self):
        """
        Test that CatalogDBObjects connecting to the same database share one
        cached connection.

        Two CatalogDBObject subclasses and one DBObject are pointed at the
        same sqlite file. The two CatalogDBObjects must share one connection
        object, and the cache must hold exactly one entry. The DBObject must
        open its own connection. A DBObject constructed from an existing
        connection must reuse that same object. Finally, both
        CatalogDBObjects must still be able to query the table.
        """

        sims_clean_up()
        self.assertEqual(len(CatalogDBObject._connection_cache), 0)

        class DbClass1(CatalogDBObject):
            database = self.db_name
            port = None
            host = None
            driver = 'sqlite'
            tableid = 'test'
            idColKey = 'id'
            objid = 'test_db_class_1'

            # 'identification' is exposed as an alias of the 'id' column.
            columns = [('identification', 'id')]

        class DbClass2(CatalogDBObject):
            database = self.db_name
            port = None
            host = None
            driver = 'sqlite'
            tableid = 'test'
            idColKey = 'id'
            objid = 'test_db_class_2'

            # 'other' is exposed as an alias of the 'i1' column.
            columns = [('other', 'i1')]

        db1 = DbClass1()
        db2 = DbClass2()
        # assertIs keeps both objects alive while comparing. Comparing id()
        # values of two separately evaluated expressions can give a false
        # result if the first object is freed before the second is created.
        self.assertIs(db1.connection, db2.connection)
        self.assertEqual(len(CatalogDBObject._connection_cache), 1)

        db3 = DBObject(database=self.db_name, driver='sqlite', host=None, port=None)
        self.assertIsNot(db1.connection, db3.connection)

        # A plain DBObject must not add entries to the CatalogDBObject cache.
        self.assertEqual(len(CatalogDBObject._connection_cache), 1)

        # check that if we had passed db1.connection to a DBObject,
        # the connections would be identical
        db4 = DBObject(connection=db1.connection)
        self.assertIs(db4.connection, db1.connection)

        self.assertEqual(len(CatalogDBObject._connection_cache), 1)

        # verify that db1 and db2 are both usable
        chunk = next(db1.query_columns(colnames=['id', 'i1', 'i2', 'identification']))
        self.assertEqual(len(chunk), 5)
        np.testing.assert_array_equal(chunk['id'], list(range(5)))
        np.testing.assert_array_equal(chunk['id'], chunk['identification'])
        np.testing.assert_array_equal(chunk['id']**2, chunk['i1'])
        np.testing.assert_array_equal(chunk['id']*(-1), chunk['i2'])

        chunk = next(db2.query_columns(colnames=['id', 'i1', 'i2', 'other']))
        self.assertEqual(len(chunk), 5)
        np.testing.assert_array_equal(chunk['id'], list(range(5)))
        np.testing.assert_array_equal(chunk['id']**2, chunk['i1'])
        np.testing.assert_array_equal(chunk['i1'], chunk['other'])
        np.testing.assert_array_equal(chunk['id']*(-1), chunk['i2'])

119 

120 

class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """Runs the checks inherited from lsst.utils.tests.MemoryTestCase unchanged."""

123 

124 

if __name__ == "__main__":
    # Direct execution: initialise the LSST test utilities, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()