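"""
Unit tests for parallelCatalogWriter: verify that catalogs written in
parallel from the same CatalogDBObject contain the expected columns and
that each catalog's cannot_be_null constraints are applied independently.
"""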

from __future__ import with_statement
from builtins import range
from builtins import super
import unittest
import sqlite3
import os
import numpy as np
import tempfile
import shutil

import lsst.utils.tests
from lsst.sims.utils.CodeUtilities import sims_clean_up
from lsst.sims.catalogs.definitions import parallelCatalogWriter
from lsst.sims.catalogs.definitions import InstanceCatalog
from lsst.sims.catalogs.decorators import compound, cached
from lsst.sims.catalogs.db import CatalogDBObject


ROOT = os.path.abspath(os.path.dirname(__file__))


def setup_module(module):
    lsst.utils.tests.init()


class DbClass(CatalogDBObject):
    tableid = 'test'

    host = None
    port = None
    driver = 'sqlite'
    objid = 'parallel_writer_test_db'
    idColKey = 'id'

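# Each of the three catalog classes below pairs a value column ('test1',
# 'test2', 'test3') with a validity column listed in cannot_be_null.
# parallelCatalogWriter is expected to drop any row whose validity column
# evaluates to None from the corresponding output file.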

class ParallelCatClass1(InstanceCatalog):
    column_outputs = ['id', 'test1', 'ii']
    cannot_be_null = ['valid1']

    @compound('test1', 'valid1')
    def get_values(self):
        ii = self.column_by_name('ii')
        return np.array([self.column_by_name('id')**2,
                         np.where(ii%2 == 1, ii, None)])


class ParallelCatClass2(InstanceCatalog):
    column_outputs = ['id', 'test2', 'ii']
    cannot_be_null = ['valid2']

    @compound('test2', 'valid2')
    def get_values(self):
        ii = self.column_by_name('id')
        return np.array([self.column_by_name('id')**3,
                         np.where(ii%2 == 1, ii, None)])


class ParallelCatClass3(InstanceCatalog):
    column_outputs = ['id', 'test3', 'ii']
    cannot_be_null = ['valid3']

    @cached
    def get_test3(self):
        return self.column_by_name('id')**4

    @cached
    def get_valid3(self):
        ii = self.column_by_name('id')
        return np.where(ii%5 == 0, ii, None)

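# ControlCatalog applies no cannot_be_null constraint, so it returns every
# row in the test table; the tests use it as a baseline when checking which
# rows should have been excluded from the filtered catalogs.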

class ControlCatalog(InstanceCatalog):
    column_outputs = ['id', 'ii']

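# setUpClass builds a temporary sqlite database containing a single 100-row
# table of (id, ii) pairs; each test writes the three filtered catalogs in
# parallel and checks them row by row against ControlCatalog.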

class ParallelWriterTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="ParallelWriterTestCase")

        cls.db_name = os.path.join(cls.scratch_dir, 'parallel_test_db.db')
        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)

        rng = np.random.RandomState(88)
        conn = sqlite3.connect(cls.db_name)
        c = conn.cursor()
        c.execute('''CREATE TABLE test (id int, ii int)''')
        for ii in range(100):
            c.execute('''INSERT INTO test VALUES(%i, %i)''' % (ii, rng.randint(0, 101)))

        conn.commit()
        conn.close()

    @classmethod
    def tearDownClass(cls):
        sims_clean_up()
        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def test_parallel_writing(self):
        """
        Test that parallelCatalogWriter writes the correct rows and columns
        to each of the catalogs it is given.
        """
        db_name = os.path.join(self.scratch_dir, 'parallel_test_db.db')
        db = DbClass(database=db_name)

        class_dict = {os.path.join(self.scratch_dir, 'par_test1.txt'): ParallelCatClass1(db),
                      os.path.join(self.scratch_dir, 'par_test2.txt'): ParallelCatClass2(db),
                      os.path.join(self.scratch_dir, 'par_test3.txt'): ParallelCatClass3(db)}

        for file_name in class_dict:
            if os.path.exists(file_name):
                os.unlink(file_name)

        parallelCatalogWriter(class_dict)

        dtype = np.dtype([('id', int), ('test', int), ('ii', int)])
        data1 = np.genfromtxt(os.path.join(self.scratch_dir, 'par_test1.txt'), dtype=dtype, delimiter=',')
        data2 = np.genfromtxt(os.path.join(self.scratch_dir, 'par_test2.txt'), dtype=dtype, delimiter=',')
        data3 = np.genfromtxt(os.path.join(self.scratch_dir, 'par_test3.txt'), dtype=dtype, delimiter=',')

        # verify that the contents of the catalogs fit with the constraints in cannot_be_null
        self.assertEqual(len(np.where(data1['ii']%2 == 0)[0]), 0)
        self.assertEqual(len(np.where(data2['id']%2 == 0)[0]), 0)
        self.assertEqual(len(np.where(data3['id']%5 != 0)[0]), 0)

        # verify that the added value columns came out to the correct values
        np.testing.assert_array_equal(data1['id']**2, data1['test'])
        np.testing.assert_array_equal(data2['id']**3, data2['test'])
        np.testing.assert_array_equal(data3['id']**4, data3['test'])

        # now verify that all of the rows that were excluded from our catalogs
        # really should have been excluded

        control_cat = ControlCatalog(db)
        iterator = control_cat.iter_catalog()
        ct = 0
        ct_in_1 = 0
        ct_in_2 = 0
        ct_in_3 = 0
        for control_data in iterator:
            ct += 1

            if control_data[1] % 2 == 0:
                self.assertNotIn(control_data[0], data1['id'])
            else:
                ct_in_1 += 1
                self.assertIn(control_data[0], data1['id'])
                dex = np.where(data1['id'] == control_data[0])[0][0]
                self.assertEqual(control_data[1], data1['ii'][dex])

            if control_data[0] % 2 == 0:
                self.assertNotIn(control_data[0], data2['id'])
            else:
                ct_in_2 += 1
                self.assertIn(control_data[0], data2['id'])
                dex = np.where(data2['id'] == control_data[0])[0][0]
                self.assertEqual(control_data[1], data2['ii'][dex])

            if control_data[0] % 5 != 0:
                self.assertNotIn(control_data[0], data3['id'])
            else:
                ct_in_3 += 1
                self.assertIn(control_data[0], data3['id'])
                dex = np.where(data3['id'] == control_data[0])[0][0]
                self.assertEqual(control_data[1], data3['ii'][dex])

        self.assertEqual(ct_in_1, len(data1['id']))
        self.assertEqual(ct_in_2, len(data2['id']))
        self.assertEqual(ct_in_3, len(data3['id']))
        self.assertEqual(ct, 100)

        for file_name in class_dict:
            if os.path.exists(file_name):
                os.unlink(file_name)

    def test_parallel_writing_chunk_size(self):
        """
        Test that parallelCatalogWriter writes the correct rows and columns
        to each catalog when chunk_size is not None (this is a repeat of
        test_parallel_writing).
        """
        db_name = os.path.join(self.scratch_dir, 'parallel_test_db.db')
        db = DbClass(database=db_name)

        class_dict = {os.path.join(self.scratch_dir, 'par_test1.txt'): ParallelCatClass1(db),
                      os.path.join(self.scratch_dir, 'par_test2.txt'): ParallelCatClass2(db),
                      os.path.join(self.scratch_dir, 'par_test3.txt'): ParallelCatClass3(db)}

        for file_name in class_dict:
            if os.path.exists(file_name):
                os.unlink(file_name)

        parallelCatalogWriter(class_dict, chunk_size=7)

        dtype = np.dtype([('id', int), ('test', int), ('ii', int)])
        data1 = np.genfromtxt(os.path.join(self.scratch_dir, 'par_test1.txt'), dtype=dtype, delimiter=',')
        data2 = np.genfromtxt(os.path.join(self.scratch_dir, 'par_test2.txt'), dtype=dtype, delimiter=',')
        data3 = np.genfromtxt(os.path.join(self.scratch_dir, 'par_test3.txt'), dtype=dtype, delimiter=',')

        # verify that the contents of the catalogs fit with the constraints in cannot_be_null
        self.assertEqual(len(np.where(data1['ii']%2 == 0)[0]), 0)
        self.assertEqual(len(np.where(data2['id']%2 == 0)[0]), 0)
        self.assertEqual(len(np.where(data3['id']%5 != 0)[0]), 0)

        # verify that the added value columns came out to the correct values
        np.testing.assert_array_equal(data1['id']**2, data1['test'])
        np.testing.assert_array_equal(data2['id']**3, data2['test'])
        np.testing.assert_array_equal(data3['id']**4, data3['test'])

        # now verify that all of the rows that were excluded from our catalogs
        # really should have been excluded

        control_cat = ControlCatalog(db)
        iterator = control_cat.iter_catalog()
        ct = 0
        ct_in_1 = 0
        ct_in_2 = 0
        ct_in_3 = 0
        for control_data in iterator:
            ct += 1

            if control_data[1] % 2 == 0:
                self.assertNotIn(control_data[0], data1['id'])
            else:
                ct_in_1 += 1
                self.assertIn(control_data[0], data1['id'])
                dex = np.where(data1['id'] == control_data[0])[0][0]
                self.assertEqual(control_data[1], data1['ii'][dex])

            if control_data[0] % 2 == 0:
                self.assertNotIn(control_data[0], data2['id'])
            else:
                ct_in_2 += 1
                self.assertIn(control_data[0], data2['id'])
                dex = np.where(data2['id'] == control_data[0])[0][0]
                self.assertEqual(control_data[1], data2['ii'][dex])

            if control_data[0] % 5 != 0:
                self.assertNotIn(control_data[0], data3['id'])
            else:
                ct_in_3 += 1
                self.assertIn(control_data[0], data3['id'])
                dex = np.where(data3['id'] == control_data[0])[0][0]
                self.assertEqual(control_data[1], data3['ii'][dex])

        self.assertEqual(ct_in_1, len(data1['id']))
        self.assertEqual(ct_in_2, len(data2['id']))
        self.assertEqual(ct_in_3, len(data3['id']))
        self.assertEqual(ct, 100)

        for file_name in class_dict:
            if os.path.exists(file_name):
                os.unlink(file_name)

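# Standard LSST test-suite boilerplate; lsst.utils.tests.MemoryTestCase
# checks for resource leaks (such as open file descriptors) left behind
# by the tests above.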

class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    pass


if __name__ == "__main__":
    setup_module(None)
    unittest.main()