Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from builtins import str 

2from builtins import range 

3from builtins import object 

4import os 

5import tempfile 

6import shutil 

7import numpy as np 

8import unittest 

9import sqlite3 

10import lsst.utils.tests 

11from lsst.sims.catalogs.definitions import InstanceCatalog 

12from lsst.sims.catalogs.decorators import cached, compound 

13from lsst.sims.catalogs.db import CatalogDBObject 

14 

15 

# Directory holding this test module; the test cases create their scratch
# database directories underneath it.
ROOT = os.path.dirname(os.path.abspath(__file__))

17 

18 

def setup_module(module):
    """
    Module-level setup hook: initialise lsst.utils.tests before any test runs.

    @param [in] module is the module object passed in by the test runner;
    it is not used here
    """
    lsst.utils.tests.init()

21 

22 

def makeTestDB(name, size=10, **kwargs):
    """
    Make a test database

    @param [in] name is a string indicating the name of the database file
    to be created

    @param [in] size is an int indicating the number of objects to include
    in the database (default=10)

    @param [in] kwargs are accepted for backwards compatibility and ignored

    The database holds a single table, testTable, with columns
    (id, aa, bb, ra, decl).  Row i has aa = 2*i and bb = 3*i, plus a random
    position with ra drawn uniformly from [0, 360) and decl from [-90, 90)
    (degrees).
    """
    conn = sqlite3.connect(name)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE testTable
                  (id int, aa float, bb float, ra float, decl float)''')

        rows = []
        for i in range(size):
            # draw ra before dec for each row, as the original loop did
            ra = np.random.sample()*360.0
            dec = (np.random.sample()-0.5)*180.0
            rows.append((i, 2.0*i, 3.0*i, float(ra), float(dec)))

        # Bind the values as parameters instead of formatting them into the
        # SQL text; the old '%f' formatting truncated them to 6 decimals.
        c.executemany('INSERT INTO testTable VALUES (?, ?, ?, ?, ?)', rows)
        conn.commit()
    finally:
        # release the file handle even if table creation or insertion fails
        conn.close()

53 

54 

class testDBObject(CatalogDBObject):
    """
    CatalogDBObject reading the testTable written by makeTestDB.

    ra and decl are stored in degrees; raJ2000 and decJ2000 expose them
    converted to radians.
    """
    objid = 'testDBObject'
    tableid = 'testTable'
    idColKey = 'id'
    # Make this implausibly large?
    appendint = 1023
    database = 'colOriginsTestDatabase.db'
    driver = 'sqlite'
    raColName = 'ra'
    decColName = 'decl'
    # Format pi/180 at full double precision: '%f' truncated it to 0.017453,
    # a relative error of ~1.7e-5 in every converted coordinate.
    columns = [('objid', 'id', int),
               ('raJ2000', 'ra*%.17g' % (np.pi/180.)),
               ('decJ2000', 'decl*%.17g' % (np.pi/180.)),
               ('aa', None),
               ('bb', None)]

70 

71 

# Below we define mixins which calculate the columns 'cc' and 'dd' in different
# ways. The idea is to check whether InstanceCatalog correctly identifies where
# the columns come from in each of those cases.

class mixin1(object):
    """Provide cc and dd through two separate cached getters."""

    @cached
    def get_cc(self):
        """Return cc = aa - bb."""
        return np.array(self.column_by_name('aa') - self.column_by_name('bb'))

    @cached
    def get_dd(self):
        """Return dd = aa + bb."""
        return np.array(self.column_by_name('aa') + self.column_by_name('bb'))

89 

90 

class mixin2(object):
    """Provide cc and dd together through a single compound getter."""

    @compound('cc', 'dd')
    def get_both(self):
        """Return the stacked rows [cc, dd] = [aa - bb, aa + bb]."""
        aa_values = self.column_by_name('aa')
        bb_values = self.column_by_name('bb')
        difference = aa_values - bb_values
        total = aa_values + bb_values
        return np.array([difference, total])

98 

99 

class mixin3(object):
    """Provide only cc (= aa - bb); dd has to come from somewhere else."""

    @cached
    def get_cc(self):
        first, second = self.column_by_name('aa'), self.column_by_name('bb')
        return np.array(first - second)

107 

108 

109# Below we define catalog classes that use different combinations 

110# of the mixins above to calculate the columns 'cc' and 'dd' 

class testCatalogDefaults(InstanceCatalog):
    """No getters for cc or dd: both are available only as default columns."""
    column_outputs = [
        'objid', 'aa', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]

114 

115 

class testCatalogMixin1(InstanceCatalog, mixin1):
    """cc and dd have both defaults and mixin1 getters; the getters are expected to win."""
    column_outputs = [
        'objid', 'aa', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]

119 

120 

class testCatalogMixin2(InstanceCatalog, mixin2):
    """cc and dd have both defaults and mixin2's compound getter."""
    column_outputs = [
        'objid', 'aa', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]

124 

125 

class testCatalogMixin3(InstanceCatalog, mixin3):
    """mixin3 supplies a getter for cc only, so dd has nothing but its default."""
    column_outputs = [
        'objid', 'aa', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]

129 

130 

class testCatalogMixin3Mixin1(InstanceCatalog, mixin3, mixin1):
    """mixin3 precedes mixin1 in the bases, so its get_cc shadows mixin1's; dd comes from mixin1."""
    column_outputs = [
        'objid', 'aa', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]

134 

135 

class testCatalogAunspecified(InstanceCatalog, mixin3, mixin1):
    """aa is left out of column_outputs and given a default, although the database also has it."""
    column_outputs = [
        'objid', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('aa', -1.0, float),
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]

139 

140 

class testColumnOrigins(unittest.TestCase):
    """
    Check that InstanceCatalog._column_origins attributes each output column
    to the right source: the database, a default column, or a mixin getter.
    """

    @classmethod
    def setUpClass(cls):
        # one scratch sqlite database shared by every test in this class
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')
        cls.dbName = os.path.join(cls.scratch_dir, 'colOriginsTestDatabase.db')
        makeTestDB(cls.dbName)

    @classmethod
    def tearDownClass(cls):
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def setUp(self):
        self.myDBobject = testDBObject(database=self.dbName)

    def tearDown(self):
        del self.myDBobject

    def _assertFromDatabase(self, catalog):
        """Assert that objid, raJ2000, decJ2000, aa and bb are attributed to 'the database'."""
        for column_name in ('objid', 'raJ2000', 'decJ2000', 'aa', 'bb'):
            self.assertEqual(catalog._column_origins[column_name], 'the database')

    @staticmethod
    def _originName(catalog, column_name):
        """
        Return the last '.'-separated piece of str(origin) after removing any
        "'>"; for getter-supplied columns the tests expect the mixin's name.
        """
        origin_text = str(catalog._column_origins[column_name])
        return origin_text.replace("'>", '').split('.')[-1]

    def testDefaults(self):
        """cc and dd can only come from default_columns."""
        catalog = testCatalogDefaults(self.myDBobject)
        self._assertFromDatabase(catalog)
        for column_name in ('cc', 'dd'):
            self.assertEqual(catalog._column_origins[column_name], 'default column')

    def testMixin1(self):
        """cc and dd come from mixin1's separate (non-compound) getters."""
        catalog = testCatalogMixin1(self.myDBobject)
        self._assertFromDatabase(catalog)
        for column_name in ('cc', 'dd'):
            self.assertEqual(self._originName(catalog, column_name), 'mixin1')

    def testMixin2(self):
        """cc and dd come from mixin2's compound getter."""
        catalog = testCatalogMixin2(self.myDBobject)
        self._assertFromDatabase(catalog)
        for column_name in ('cc', 'dd'):
            self.assertEqual(self._originName(catalog, column_name), 'mixin2')

    def testMixin3(self):
        """cc comes from mixin3; dd, which mixin3 lacks, falls back to its default."""
        catalog = testCatalogMixin3(self.myDBobject)
        self._assertFromDatabase(catalog)
        self.assertEqual(self._originName(catalog, 'cc'), 'mixin3')
        self.assertEqual(str(catalog._column_origins['dd']), 'default column')

    def testMixin3Mixin1(self):
        """mixin3's get_cc takes precedence over mixin1's; dd still comes from mixin1."""
        catalog = testCatalogMixin3Mixin1(self.myDBobject)
        self._assertFromDatabase(catalog)
        self.assertEqual(self._originName(catalog, 'cc'), 'mixin3')
        self.assertEqual(self._originName(catalog, 'dd'), 'mixin1')

    def testAunspecified(self):
        """aa is not requested (and has a default) yet is still attributed to the database."""
        catalog = testCatalogAunspecified(self.myDBobject)
        self._assertFromDatabase(catalog)
        self.assertEqual(self._originName(catalog, 'cc'), 'mixin3')
        self.assertEqual(self._originName(catalog, 'dd'), 'mixin1')

251 

252 

class myDummyCatalogClass(InstanceCatalog):
    """Catalog with one default column, one plain getter (cc) and one compound getter (dd, ee, ff)."""

    default_columns = [('sillyDefault', 2.0, float)]

    def get_cc(self):
        return self.column_by_name('aa') + 1.0

    @compound('dd', 'ee', 'ff')
    def get_compound(self):
        # aa shifted by 2, 3 and 4 respectively; aa is looked up once per row
        return np.array([self.column_by_name('aa') + offset
                         for offset in (2.0, 3.0, 4.0)])

266 

267 

class myDependentColumnsClass_shouldPass(InstanceCatalog):
    """Defines get_dd only; with no 'ee' available, dd is computed as aa + bb."""

    def get_dd(self):
        # prefer ee as the offset when the catalog can provide it, else use bb
        offset_column = 'ee' if 'ee' in self._all_available_columns else 'bb'
        delta = self.column_by_name(offset_column)
        return self.column_by_name('aa') + delta

278 

279 

class myDependentColumnsClass_shouldFail(InstanceCatalog):
    """
    Defines getters for cc, dd and ee.  Because ee exists, dd depends on it,
    and ee in turn needs the nonexistent column 'doesNotExist'.
    """

    def get_cc(self):
        return self.column_by_name('aa') + 1.0

    def get_dd(self):
        # prefer ee as the offset when the catalog can provide it, else use bb
        offset_column = 'ee' if 'ee' in self._all_available_columns else 'bb'
        delta = self.column_by_name(offset_column)
        return self.column_by_name('aa') + delta

    def get_ee(self):
        return self.column_by_name('aa') + self.column_by_name('doesNotExist')

296 

297 

class AllAvailableColumns(unittest.TestCase):
    """
    Verify that InstanceCatalog._all_available_columns lists every column
    the catalog and its CatalogDBObject can provide.
    """

    @classmethod
    def setUpClass(cls):
        # one scratch sqlite database shared by every test in this class
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')
        cls.dbName = os.path.join(cls.scratch_dir, 'allGettersTestDatabase.db')
        makeTestDB(cls.dbName)

    @classmethod
    def tearDownClass(cls):
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def setUp(self):
        self.db = testDBObject(database=self.dbName)

    def testAllGetters(self):
        """
        _all_available_columns must include getter columns (plain and compound),
        raw database columns, the CatalogDBObject's mapped columns, and defaults.
        """
        cat = myDummyCatalogClass(self.db, column_outputs=['aa'])
        expected_columns = ['cc', 'dd', 'ee', 'ff', 'compound',
                            'id', 'aa', 'bb', 'ra', 'decl',
                            'decJ2000', 'raJ2000', 'objid',
                            'sillyDefault']
        for column_name in expected_columns:
            self.assertIn(column_name, cat._all_available_columns)

    def testDependentColumns(self):
        """
        A getter may consult self._all_available_columns to decide which other
        columns it needs.  In both dependent-column classes, get_dd uses 'ee'
        when it is available and 'bb' otherwise.

        myDependentColumnsClass_shouldPass has no getter for 'ee', so dd only
        needs aa and bb, and construction succeeds.
        myDependentColumnsClass_shouldFail does define 'ee', whose getter needs
        the nonexistent column 'doesNotExist', so requesting 'dd' must raise
        ValueError, while requesting only 'cc' still works.
        """
        myDependentColumnsClass_shouldPass(self.db, column_outputs=['dd'])

        # dd is not requested, so the broken ee getter is never needed
        myDependentColumnsClass_shouldFail(self.db, column_outputs=['cc'])

        # dd -> ee -> 'doesNotExist' makes construction fail
        with self.assertRaises(ValueError):
            myDependentColumnsClass_shouldFail(self.db, column_outputs=['dd'])

362 

363 

class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """Hooks lsst.utils.tests.MemoryTestCase into this module; adds no tests of its own."""

366 

367 

# Allow running this file directly, outside of a test runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()