Coverage for tests/convertReferenceCatalogTestBase.py: 14%

168 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-04 05:16 -0700

1# This file is part of meas_algorithms. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

__all__ = [
    "ConvertReferenceCatalogTestBase",
    "make_coord",
    "makeConvertConfig",
]

23 

24import logging 

25import math 

26import string 

27import tempfile 

28 

29import numpy as np 

30import astropy 

31import astropy.units as u 

32 

33import lsst.daf.butler 

34from lsst.meas.algorithms import IndexerRegistry, ConvertRefcatManager 

35from lsst.meas.algorithms import ConvertReferenceCatalogConfig 

36import lsst.utils 

37 

38 

def make_coord(ra, dec):
    """Return an `lsst.geom.SpherePoint` for the given RA, Dec in degrees.

    Parameters
    ----------
    ra, dec : `float`
        Right ascension and declination, both in degrees.
    """
    unit = lsst.geom.degrees
    return lsst.geom.SpherePoint(ra, dec, unit)

42 

43 

class ConvertReferenceCatalogCustomClass(ConvertRefcatManager):
    """`ConvertRefcatManager` whose `_setCoordinateCovariance` writes zeros
    instead of real covariance values.
    """

    def _setCoordinateCovariance(self, record, row):
        """Zero every off-diagonal coordinate covariance field of ``record``.

        Coordinate covariance is not used by the tests relying on this
        manager, so ``row`` is ignored.
        """
        params = ('coord_ra', 'coord_dec', 'pm_ra', 'pm_dec', 'parallax')
        # Visit each unordered pair (first, second) with first listed before
        # second, in the same order as a nested i/j index loop would.
        for secondIndex, secondName in enumerate(params):
            for firstName in params[:secondIndex]:
                covKey = self.key_map[f'{firstName}_{secondName}_Cov']
                record.set(covKey, 0)

54 

55 

def makeConvertConfig(withMagErr=False, withRaDecErr=False, withPm=False,
                      withParallax=False, withFullPositionInformation=False):
    """Build a `ConvertReferenceCatalogConfig` for tests.

    Intended mainly to simplify tests of config validation, so fields that
    are not validated are left at their defaults; it is also useful for
    trimming boilerplate from other tests.

    Parameters
    ----------
    withMagErr : `bool`, optional
        Map magnitude columns ``a``/``b`` to error columns ``a_err``/``b_err``.
    withRaDecErr : `bool`, optional
        Set RA/Dec error column names and ``coord_err_unit`` ("arcsecond").
    withPm : `bool`, optional
        Set proper-motion column names and their error column names.
    withParallax : `bool`, optional
        Set the parallax column name and its error column name.
    withFullPositionInformation : `bool`, optional
        Set ``full_position_information`` and retarget ``manager`` to
        `ConvertReferenceCatalogCustomClass`.

    Returns
    -------
    config : `ConvertReferenceCatalogConfig`
        The new config. When ``withPm`` or ``withParallax`` is set, the epoch
        column settings (``unixtime``, unix format, utc scale) are filled in
        as well.
    """
    config = ConvertReferenceCatalogConfig()
    config.dataset_config.ref_dataset_name = "testRefCat"
    config.pm_scale = 1000.0
    config.parallax_scale = 1e3
    config.ra_name = 'ra'
    config.dec_name = 'dec'
    config.mag_column_list = ['a', 'b']

    # (enabled?, [(field name, value), ...]) applied in this exact order.
    optionalSections = (
        (withMagErr, (
            ("mag_err_column_map", {'a': 'a_err', 'b': 'b_err'}),
        )),
        (withRaDecErr, (
            ("ra_err_name", "ra_err"),
            ("dec_err_name", "dec_err"),
            ("coord_err_unit", "arcsecond"),
        )),
        (withPm, (
            ("pm_ra_name", "pm_ra"),
            ("pm_dec_name", "pm_dec"),
            ("pm_ra_err_name", "pm_ra_err"),
            ("pm_dec_err_name", "pm_dec_err"),
        )),
        (withParallax, (
            ("parallax_name", "parallax"),
            ("parallax_err_name", "parallax_err"),
        )),
        # An epoch is needed whenever time-dependent quantities are present.
        (withPm or withParallax, (
            ("epoch_name", "unixtime"),
            ("epoch_format", "unix"),
            ("epoch_scale", "utc"),
        )),
        (withFullPositionInformation, (
            ("full_position_information", True),
        )),
    )
    for enabled, settings in optionalSections:
        if not enabled:
            continue
        for fieldName, value in settings:
            setattr(config, fieldName, value)

    if withFullPositionInformation:
        config.manager.retarget(ConvertReferenceCatalogCustomClass)

    return config

100 

101 

102class ConvertReferenceCatalogTestBase: 

103 """Base class for tests involving ConvertReferenceCatalogTask 

104 """ 

105 @classmethod 

106 def makeSkyCatalog(cls, outPath, size=1000, idStart=1, seed=123): 

107 """Make an on-sky catalog, and save it to a text file. 

108 

109 The catalog columns mimic the columns from the native Gaia catalog. 

110 

111 Parameters 

112 ---------- 

113 outPath : `str` or None 

114 The directory to write the catalog to. 

115 Specify None to not write any output. 

116 size : `int`, (optional) 

117 Number of items to add to the catalog. 

118 idStart : `int`, (optional) 

119 First id number to put in the catalog. 

120 seed : `float`, (optional) 

121 Random seed for ``np.random``. 

122 

123 Returns 

124 ------- 

125 refCatPath : `str` 

126 Path to the created on-sky catalog. 

127 refCatOtherDelimiterPath : `str` 

128 Path to the created on-sky catalog with a different delimiter. 

129 refCatData : `np.ndarray` 

130 The data contained in the on-sky catalog files. 

131 """ 

132 np.random.seed(seed) 

133 ident = np.arange(idStart, size + idStart, dtype=int) 

134 ra = np.random.random(size)*360. 

135 dec = np.degrees(np.arccos(2.*np.random.random(size) - 1.)) 

136 dec -= 90. 

137 ra_err = np.ones(size)*0.1 # arcsec 

138 dec_err = np.ones(size)*0.1 # arcsec 

139 a_mag = 16. + np.random.random(size)*4. 

140 a_mag_err = 0.01 + np.random.random(size)*0.2 

141 b_mag = 17. + np.random.random(size)*5. 

142 b_mag_err = 0.02 + np.random.random(size)*0.3 

143 is_photometric = np.random.randint(2, size=size) 

144 is_resolved = np.random.randint(2, size=size) 

145 is_variable = np.random.randint(2, size=size) 

146 extra_col1 = np.random.normal(size=size) 

147 extra_col2 = np.random.normal(1000., 100., size=size) 

148 # compute proper motion and PM error in arcseconds/year 

149 # and let the convert task scale them to radians 

150 pm_amt_arcsec = cls.properMotionAmt.asArcseconds() 

151 pm_dir_rad = cls.properMotionDir.asRadians() 

152 pm_ra = np.ones(size)*pm_amt_arcsec*math.cos(pm_dir_rad) 

153 pm_dec = np.ones(size)*pm_amt_arcsec*math.sin(pm_dir_rad) 

154 pm_ra_err = np.ones(size)*cls.properMotionErr.asArcseconds()*abs(math.cos(pm_dir_rad)) 

155 pm_dec_err = np.ones(size)*cls.properMotionErr.asArcseconds()*abs(math.sin(pm_dir_rad)) 

156 parallax = np.ones(size)*0.1 # arcseconds 

157 parallax_error = np.ones(size)*0.003 # arcseconds 

158 ra_dec_corr = 2 * np.random.random(size) - 1 

159 ra_parallax_corr = 2 * np.random.random(size) - 1 

160 ra_pmra_corr = 2 * np.random.random(size) - 1 

161 ra_pmdec_corr = 2 * np.random.random(size) - 1 

162 dec_parallax_corr = 2 * np.random.random(size) - 1 

163 dec_pmra_corr = 2 * np.random.random(size) - 1 

164 dec_pmdec_corr = 2 * np.random.random(size) - 1 

165 parallax_pmra_corr = 2 * np.random.random(size) - 1 

166 parallax_pmdec_corr = 2 * np.random.random(size) - 1 

167 pmra_pmdec_corr = 2 * np.random.random(size) - 1 

168 unixtime = np.ones(size)*cls.epoch.unix 

169 

170 def get_word(word_len): 

171 return "".join(np.random.choice([s for s in string.ascii_letters], word_len)) 

172 extra_col3 = np.array([get_word(num) for num in np.random.randint(11, size=size)]) 

173 

174 dtype = np.dtype([('id', float), ('ra', float), ('dec', float), 

175 ('ra_error', float), ('dec_error', float), ('a', float), 

176 ('a_err', float), ('b', float), ('b_err', float), ('is_phot', int), 

177 ('is_res', int), ('is_var', int), ('val1', float), ('val2', float), 

178 ('val3', '|S11'), ('pmra', float), ('pmdec', float), ('pmra_error', float), 

179 ('pmdec_error', float), ('parallax', float), ('parallax_error', float), 

180 ('ra_dec_corr', float), ('ra_parallax_corr', float), ('ra_pmra_corr', float), 

181 ('ra_pmdec_corr', float), ('dec_parallax_corr', float), ('dec_pmra_corr', float), 

182 ('dec_pmdec_corr', float), ('parallax_pmra_corr', float), 

183 ('parallax_pmdec_corr', float), ('pmra_pmdec_corr', float), ('unixtime', float)]) 

184 

185 arr = np.array(list(zip(ident, ra, dec, ra_err, dec_err, a_mag, a_mag_err, b_mag, b_mag_err, 

186 is_photometric, is_resolved, is_variable, extra_col1, extra_col2, extra_col3, 

187 pm_ra, pm_dec, pm_ra_err, pm_dec_err, parallax, parallax_error, ra_dec_corr, 

188 ra_parallax_corr, ra_pmra_corr, ra_pmdec_corr, dec_parallax_corr, 

189 dec_pmra_corr, dec_pmdec_corr, parallax_pmra_corr, parallax_pmdec_corr, 

190 pmra_pmdec_corr, unixtime)), 

191 dtype=dtype) 

192 if outPath is not None: 

193 # write the data with full precision; this is not realistic for 

194 # real catalogs, but simplifies tests based on round tripped data 

195 saveKwargs = dict( 

196 header="id,ra,dec,ra_err,dec_err," 

197 "a,a_err,b,b_err,is_phot,is_res,is_var,val1,val2,val3," 

198 "pm_ra,pm_dec,pm_ra_err,pm_dec_err,parallax,parallax_err,ra_dec_corr," 

199 "ra_parallax_corr,ra_pmra_corr,ra_pmdec_corr,dec_parallax_corr," 

200 "dec_pmra_corr,dec_pmdec_corr,parallax_pmra_corr,parallax_pmdec_corr," 

201 "pmra_pmdec_corr,unixtime", 

202 fmt=["%i", "%.15g", "%.15g", "%.15g", "%.15g", 

203 "%.15g", "%.15g", "%.15g", "%.15g", "%i", "%i", "%i", "%.15g", "%.15g", "%s", 

204 "%.15g", "%.15g", "%.15g", "%.15g", "%.15g", "%.15g", "%.15g", "%.15g", "%.15g", "%.15g", 

205 "%.15g", "%.15g", "%.15g", "%.15g", "%.15g", "%.15g", "%.15g"] 

206 ) 

207 

208 np.savetxt(outPath+"/ref.txt", arr, delimiter=",", **saveKwargs) 

209 np.savetxt(outPath+"/ref_test_delim.txt", arr, delimiter="|", **saveKwargs) 

210 return outPath+"/ref.txt", outPath+"/ref_test_delim.txt", arr 

211 else: 

212 return arr 

213 

214 @classmethod 

215 def tearDownClass(cls): 

216 cls.outDir.cleanup() 

217 del cls.outPath 

218 del cls.skyCatalogFile 

219 del cls.skyCatalogFileDelim 

220 del cls.skyCatalog 

221 del cls.testRas 

222 del cls.testDecs 

223 del cls.searchRadius 

224 del cls.compCats 

225 

226 @classmethod 

227 def setUpClass(cls): 

228 cls.outDir = tempfile.TemporaryDirectory() 

229 cls.outPath = cls.outDir.name 

230 # arbitrary, but reasonable, amount of proper motion (angle/year) 

231 # and direction of proper motion 

232 cls.properMotionAmt = 3.0*lsst.geom.arcseconds 

233 cls.properMotionDir = 45*lsst.geom.degrees 

234 cls.properMotionErr = 1e-3*lsst.geom.arcseconds 

235 cls.epoch = astropy.time.Time(58206.861330339219, scale="tai", format="mjd") 

236 cls.skyCatalogFile, cls.skyCatalogFileDelim, cls.skyCatalog = cls.makeSkyCatalog(cls.outPath) 

237 cls.testRas = [210., 14.5, 93., 180., 286., 0.] 

238 cls.testDecs = [-90., -51., -30.1, 0., 27.3, 62., 90.] 

239 cls.searchRadius = 3. * lsst.geom.degrees 

240 cls.compCats = {} # dict of center coord: list of IDs of stars within cls.searchRadius of center 

241 cls.depth = 4 # gives a mean area of 20 deg^2 per pixel, roughly matching a 3 deg search radius 

242 

243 config = IndexerRegistry['HTM'].ConfigClass() 

244 # Match on disk comparison file 

245 config.depth = cls.depth 

246 cls.indexer = IndexerRegistry['HTM'](config) 

247 for ra in cls.testRas: 

248 for dec in cls.testDecs: 

249 tupl = (ra, dec) 

250 cent = make_coord(*tupl) 

251 cls.compCats[tupl] = [] 

252 for rec in cls.skyCatalog: 

253 if make_coord(rec['ra'], rec['dec']).separation(cent) < cls.searchRadius: 

254 cls.compCats[tupl].append(rec['id']) 

255 

256 cls.testRepoPath = cls.outPath+"/test_repo" 

257 

258 def setUp(self): 

259 self.repoPath = tempfile.TemporaryDirectory() # cleaned up automatically when test ends 

260 self.butler = self.makeTemporaryRepo(self.repoPath.name, self.depth) 

261 self.logger = logging.getLogger('lsst.ReferenceObjectLoader') 

262 

263 def tearDown(self): 

264 self.repoPath.cleanup() 

265 

266 @staticmethod 

267 def makeTemporaryRepo(rootPath, depth): 

268 """Create a temporary butler repository, configured to support a given 

269 htm pixel depth, to use for a single test. 

270 

271 Parameters 

272 ---------- 

273 rootPath : `str` 

274 Root path for butler. 

275 depth : `int` 

276 HTM pixel depth to be used in this test. 

277 

278 Returns 

279 ------- 

280 butler : `lsst.daf.butler.Butler` 

281 The newly created and instantiated butler. 

282 """ 

283 dimensionConfig = lsst.daf.butler.DimensionConfig() 

284 dimensionConfig['skypix']['common'] = f'htm{depth}' 

285 lsst.daf.butler.Butler.makeRepo(rootPath, dimensionConfig=dimensionConfig) 

286 return lsst.daf.butler.Butler(rootPath, writeable=True) 

287 

288 def checkAllRowsInRefcat(self, refObjLoader, skyCatalog, config): 

289 """Check that every item in ``skyCatalog`` is in the converted catalog, 

290 and check that fields are correct in it. 

291 

292 Parameters 

293 ---------- 

294 refObjLoader : `lsst.meas.algorithms.ReferenceObjectLoader` 

295 A reference object loader to use to search for rows from 

296 ``skyCatalog``. 

297 skyCatalog : `np.ndarray` 

298 The original data to compare with. 

299 config : `lsst.meas.algorithms.LoadReferenceObjectsConfig` 

300 The Config that was used to generate the refcat. 

301 """ 

302 for row in skyCatalog: 

303 center = lsst.geom.SpherePoint(row['ra'], row['dec'], lsst.geom.degrees) 

304 with self.assertLogs(self.logger.name, level="INFO") as cm: 

305 cat = refObjLoader.loadSkyCircle(center, 2*lsst.geom.arcseconds, filterName='a').refCat 

306 self.assertIn("Loading reference objects from testRefCat in region", cm.output[0]) 

307 self.assertGreater(len(cat), 0, "No objects found in loaded catalog.") 

308 msg = f"input row not found in loaded catalog:\nrow:\n{row}\n{row.dtype}\n\ncatalog:\n{cat[0]}" 

309 self.assertEqual(row['id'], cat[0]['id'], msg) 

310 # coordinates won't match perfectly due to rounding in radian/degree conversions 

311 self.assertFloatsAlmostEqual(row['ra'], cat[0]['coord_ra'].asDegrees(), 

312 rtol=1e-14, msg=msg) 

313 self.assertFloatsAlmostEqual(row['dec'], cat[0]['coord_dec'].asDegrees(), 

314 rtol=1e-14, msg=msg) 

315 if config.coord_err_unit is not None: 

316 # coordinate errors are not lsst.geom.Angle, so we have to use the 

317 # `units` field to convert them, and they are float32, so the tolerance is wider. 

318 raErr = cat[0]['coord_raErr']*u.Unit(cat.schema['coord_raErr'].asField().getUnits()) 

319 decErr = cat[0]['coord_decErr']*u.Unit(cat.schema['coord_decErr'].asField().getUnits()) 

320 self.assertFloatsAlmostEqual(row['ra_error'], raErr.to_value(config.coord_err_unit), 

321 rtol=1e-7, msg=msg) 

322 self.assertFloatsAlmostEqual(row['dec_error'], decErr.to_value(config.coord_err_unit), 

323 rtol=1e-7, msg=msg) 

324 

325 if config.parallax_name is not None: 

326 self.assertFloatsAlmostEqual(row['parallax'], cat[0]['parallax'].asArcseconds()) 

327 parallaxErr = cat[0]['parallaxErr'].asArcseconds() 

328 # larger tolerance: input data is float32 

329 self.assertFloatsAlmostEqual(row['parallax_error'], parallaxErr, rtol=3e-8)