Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from __future__ import with_statement 

2from builtins import range 

3import unittest 

4import os 

5import numpy as np 

6import tempfile 

7import shutil 

8import lsst.utils.tests 

9 

10from lsst.sims.utils.CodeUtilities import sims_clean_up 

11from lsst.sims.catalogs.db import fileDBObject, CatalogDBObject 

12from lsst.sims.catalogs.definitions import InstanceCatalog 

13 

# Absolute path of the directory holding this test file; the tests create
# their temporary scratch directories underneath it.
ROOT = os.path.dirname(os.path.abspath(__file__))

15 

16 

def setup_module(module):
    """
    Module-level (xunit-style) setup hook: initialise the LSST test utilities.

    ``module`` is the module object passed in by the test runner; it is unused.
    """
    init_lsst_tests = lsst.utils.tests.init
    init_lsst_tests()

19 

20 

21class ConnectionPassingTest(unittest.TestCase): 

22 """ 

23 This will test whether we can can construct InstanceCatalogs 

24 containing multiple classes of object using only a single 

25 connection to the database. 

26 """ 

27 

28 @classmethod 

29 def write_star_txt(cls): 

30 np.random.seed(77) 

31 cls.n_stars = 20 

32 cls.star_ra = np.random.random_sample(cls.n_stars)*360.0 

33 cls.star_dec = (np.random.random_sample(cls.n_stars)-0.5)*180.0 

34 cls.star_umag = np.random.random_sample(cls.n_stars)*10.0 + 15.0 

35 cls.star_gmag = np.random.random_sample(cls.n_stars)*10.0 + 15.0 

36 

37 cls.star_txt_name = os.path.join(cls.scratch_dir, 

38 'ConnectionPassingTestStars.txt') 

39 

40 if os.path.exists(cls.star_txt_name): 

41 os.unlink(cls.star_txt_name) 

42 

43 with open(cls.star_txt_name, 'w') as output_file: 

44 output_file.write("#id raJ2000 decJ2000 umag gmag\n") 

45 for ix in range(cls.n_stars): 

46 output_file.write("%d %.4f %.4f %.4f %.4f\n" 

47 % (ix, cls.star_ra[ix], cls.star_dec[ix], 

48 cls.star_umag[ix], cls.star_gmag[ix])) 

49 

50 @classmethod 

51 def write_galaxy_txt(cls): 

52 np.random.seed(88) 

53 cls.n_galaxies = 100 

54 cls.gal_ra = np.random.random_sample(cls.n_galaxies)*360.0 

55 cls.gal_dec = (np.random.random_sample(cls.n_galaxies)-0.5)*180.0 

56 cls.gal_redshift = np.random.random_sample(cls.n_galaxies)*5.0 

57 cls.gal_umag = np.random.random_sample(cls.n_galaxies)*10.0+21.0 

58 cls.gal_gmag = np.random.random_sample(cls.n_galaxies)*10.0+21.0 

59 

60 cls.gal_txt_name = os.path.join(cls.scratch_dir, 

61 'ConnectionPassingTestGal.txt') 

62 

63 if os.path.exists(cls.gal_txt_name): 

64 os.unlink(cls.gal_txt_name) 

65 

66 with open(cls.gal_txt_name, 'w') as output_file: 

67 output_file.write("#id raJ2000 decJ2000 redshift umag gmag\n") 

68 for ix in range(cls.n_galaxies): 

69 output_file.write("%d %.4f %.4f %.4f %.4f %.4f\n" 

70 % (ix, cls.gal_ra[ix], cls.gal_dec[ix], 

71 cls.gal_redshift[ix], 

72 cls.gal_umag[ix], cls.gal_gmag[ix])) 

73 

74 @classmethod 

75 def setUpClass(cls): 

76 

77 cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="scratchSpace-") 

78 cls.write_star_txt() 

79 cls.write_galaxy_txt() 

80 

81 cls.dbName = os.path.join(cls.scratch_dir, 'ConnectionPassingTestDB.db') 

82 

83 if os.path.exists(cls.dbName): 

84 os.unlink(cls.dbName) 

85 

86 galDtype = np.dtype([('id', np.int), 

87 ('raJ2000', np.float), ('decJ2000', np.float), 

88 ('redshift', np.float), ('umag', np.float), 

89 ('gmag', np.float)]) 

90 

91 starDtype = np.dtype([('id', np.int), ('raJ2000', np.float), 

92 ('decJ2000', np.float), ('umag', np.float), 

93 ('gmag', np.float)]) 

94 

95 fileDBObject(cls.star_txt_name, 

96 database=cls.dbName, driver='sqlite', 

97 runtable='stars', idColKey='id', 

98 dtype=starDtype) 

99 

100 fileDBObject(cls.gal_txt_name, 

101 database=cls.dbName, driver='sqlite', 

102 runtable='galaxies', idColKey='id', 

103 dtype=galDtype) 

104 

105 @classmethod 

106 def tearDownClass(cls): 

107 sims_clean_up() 

108 if os.path.exists(cls.dbName): 

109 os.unlink(cls.dbName) 

110 

111 if os.path.exists(cls.star_txt_name): 

112 os.unlink(cls.star_txt_name) 

113 

114 if os.path.exists(cls.gal_txt_name): 

115 os.unlink(cls.gal_txt_name) 

116 

117 if os.path.exists(cls.scratch_dir): 

118 shutil.rmtree(cls.scratch_dir) 

119 

120 def test_passing(self): 

121 """ 

122 Test that we can produce a catalog of multiple object types 

123 drawn from different tables of the same database by passing 

124 DBConnections 

125 """ 

126 

127 class starDBObj(CatalogDBObject): 

128 database = self.dbName 

129 driver = 'sqlite' 

130 tableid = 'stars' 

131 idColKey = 'id' 

132 

133 class galDBObj(CatalogDBObject): 

134 database = self.dbName 

135 driver = 'sqlite' 

136 tableid = 'galaxies' 

137 idColKey = 'id' 

138 

139 class starCatalog(InstanceCatalog): 

140 column_outputs = ['id', 'raJ2000', 'decJ2000', 

141 'gmag', 'umag'] 

142 

143 default_formats = {'f': '%.4f'} 

144 

145 class galCatalog(InstanceCatalog): 

146 column_outputs = ['id', 'decJ2000', 'raJ2000', 

147 'gmag', 'umag', 'redshift'] 

148 

149 default_formats = {'f': '%.4f'} 

150 

151 catName = os.path.join(self.scratch_dir, 

152 'ConnectionPassingTestOutputCatalog.txt') 

153 

154 if os.path.exists(catName): 

155 os.unlink(catName) 

156 

157 stars = starDBObj() 

158 galaxies = galDBObj(connection=stars.connection) 

159 

160 self.assertEqual(stars.connection, galaxies.connection) 

161 

162 starCat = starCatalog(stars) 

163 galCat = galCatalog(galaxies) 

164 starCat.write_catalog(catName, chunk_size=5) 

165 galCat.write_catalog(catName, write_mode='a', chunk_size=5) 

166 

167 with open(catName, 'r') as input_file: 

168 lines = input_file.readlines() 

169 self.assertEqual(len(lines), self.n_stars+self.n_galaxies+2) 

170 for ix in range(self.n_stars): 

171 vals = lines[ix+1].split(',') 

172 dex = np.int(vals[0]) 

173 self.assertEqual(round(self.star_ra[dex], 4), np.float(vals[1])) 

174 self.assertEqual(round(self.star_dec[dex], 4), np.float(vals[2])) 

175 self.assertEqual(round(self.star_gmag[dex], 4), np.float(vals[3])) 

176 self.assertEqual(round(self.star_umag[dex], 4), np.float(vals[4])) 

177 

178 offset = 2 + self.n_stars 

179 for ix in range(self.n_galaxies): 

180 vals = lines[ix+offset].split(',') 

181 dex = np.int(vals[0]) 

182 self.assertEqual(round(self.gal_dec[dex], 4), np.float(vals[1])) 

183 self.assertEqual(round(self.gal_ra[dex], 4), np.float(vals[2])) 

184 self.assertEqual(round(self.gal_gmag[dex], 4), np.float(vals[3])) 

185 self.assertEqual(round(self.gal_umag[dex], 4), np.float(vals[4])) 

186 self.assertEqual(round(self.gal_redshift[dex], 4), np.float(vals[5])) 

187 

188 if os.path.exists(catName): 

189 os.unlink(catName) 

190 

191 

class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """
    Pull in the checks defined by lsst.utils.tests.MemoryTestCase
    (presumably memory/resource-leak checks -- see that base class).

    Adds no tests of its own.
    """

194 

if __name__ == "__main__":
    # Executed directly (not via a test collector): initialise the LSST
    # test utilities, then hand control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()