Coverage for tests/test_loadDiaCatalogs.py: 21%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

149 statements  

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23import numpy as np 

24import pandas as pd 

25import tempfile 

26import unittest 

27import yaml 

28 

29from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

30import lsst.afw.geom as afwGeom 

31import lsst.afw.image as afwImage 

32from lsst.ap.association import LoadDiaCatalogsTask, LoadDiaCatalogsConfig 

33import lsst.daf.base as dafBase 

34from lsst.dax.apdb import ApdbSql, ApdbSqlConfig 

35import lsst.geom as geom 

36from lsst.utils import getPackageDir 

37import lsst.utils.tests 

38 

39 

def _data_file_name(basename, module_name):
    """Build the path of a file inside a stack package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name to append under the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose location is looked up with
        `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        ``<package directory>/data/<basename>``.
    """
    package_root = getPackageDir(module_name)
    data_dir = os.path.join(package_root, "data")
    return os.path.join(data_dir, basename)

57 

58 

def makeExposure(flipX=False, flipY=False):
    """Create a test exposure whose WCS may be mirrored in x and/or y.

    Flipping an axis negates the corresponding column of the CD matrix.
    That changes the handedness of the WCS, and therefore of the sky
    polygon traced around the exposure's bounding box.

    Parameters
    ----------
    flipX : `bool`, optional
        If `True`, negate the x-column CD terms (``CD1_1`` and ``CD2_1``).
    flipY : `bool`, optional
        If `True`, negate the y-column CD terms (``CD1_2`` and ``CD2_2``).

    Returns
    -------
    exposure : `lsst.afw.image.Exposure`
        Exposure with an image of ones and a ``RA---SIN``/``DEC--SIN`` WCS.
        It also has detector id 23, exposure id 1234, a ``g`` band filter
        label and visit info dated 2014-05-13T17:00:00 TAI with a 200 s
        exposure time.
    """
    xSign = -1 if flipX else 1
    ySign = -1 if flipY else 1

    metadata = dafBase.PropertySet()
    # Basic FITS image keywords.
    # NOTE(review): NAXIS1/NAXIS2 say 1024 x 1153, but the pixel array below
    # has numpy shape (1024, 1153), i.e. 1153 columns by 1024 rows.
    # Confirm which is intended before relying on these header values.
    for key, value in (("SIMPLE", "T"),
                       ("BITPIX", -32),
                       ("NAXIS", 2),
                       ("NAXIS1", 1024),
                       ("NAXIS2", 1153),
                       ("RADECSYS", 'FK5'),
                       ("EQUINOX", 2000.)):
        metadata.set(key, value)

    # Reference sky position and reference pixel.
    for key, value in (("CRVAL1", 215.604025685476),
                       ("CRVAL2", 53.1595451514076),
                       ("CRPIX1", 1109.99981456774),
                       ("CRPIX2", 560.018167811613)):
        metadata.setDouble(key, value)
    metadata.set("CTYPE1", 'RA---SIN')
    metadata.set("CTYPE2", 'DEC--SIN')

    # CD matrix: the *_1 terms scale with x (flipX), the *_2 terms with y.
    for key, value in (("CD1_1", xSign * 5.10808596133527E-05),
                       ("CD1_2", ySign * 1.85579539217196E-07),
                       ("CD2_2", ySign * -5.10281493481982E-05),
                       ("CD2_1", xSign * -8.27440751733828E-07)):
        metadata.setDouble(key, value)

    wcs = afwGeom.makeSkyWcs(metadata)
    maskedImage = afwImage.makeMaskedImageFromArrays(np.ones((1024, 1153)))
    exposure = afwImage.makeExposure(maskedImage, wcs)

    visitDate = dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                 dafBase.DateTime.Timescale.TAI)
    visitInfo = afwImage.VisitInfo(exposureId=1234,
                                   exposureTime=200.,
                                   date=visitDate)
    detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector

    exposure.info.id = 1234
    exposure.setDetector(detector)
    exposure.getInfo().setVisitInfo(visitInfo)
    exposure.setFilterLabel(afwImage.FilterLabel(band='g'))
    return exposure

120 

121 

def makeDiaObjects(nObjects, exposure):
    """Generate DiaObjects at uniformly random pixel positions on an exposure.

    Parameters
    ----------
    nObjects : `int`
        How many DiaObjects to generate.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box bounds the random positions. Its WCS
        converts pixels to sky coordinates, and its visit date supplies
        ``radecTai``.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per object, with ``diaObjectId`` equal to the row index
        (0 to ``nObjects - 1``). All count-like columns (``pmParallaxNdata``,
        ``nearbyObj1-3`` and the per-band ``*PSFluxNdata``) are zero.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # x must be drawn before y so the seeded random stream is consumed in a
    # fixed order.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(),
                           size=nObjects)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(),
                           size=nObjects)

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)
    skyWcs = exposure.getWcs()
    bandCounts = {"%sPSFluxNdata" % band: 0 for band in "ugrizy"}

    def _toRecord(objectId, x, y):
        # Convert one pixel position into a DiaObject row.
        sky = skyWcs.pixelToSky(x, y)
        record = {"ra": sky.getRa().asDegrees(),
                  "decl": sky.getDec().asDegrees(),
                  "radecTai": visitMjd,
                  "diaObjectId": objectId,
                  "pmParallaxNdata": 0,
                  "nearbyObj1": 0,
                  "nearbyObj2": 0,
                  "nearbyObj3": 0}
        record.update(bandCounts)
        return record

    records = [_toRecord(objectId, x, y)
               for objectId, (x, y) in enumerate(zip(xs, ys))]
    return pd.DataFrame(data=records)

162 

163 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Generate DiaSources at random positions, each linked to a DiaObject.

    Parameters
    ----------
    nSources : `int`
        How many DiaSources to generate.
    diaObjectIds : `numpy.ndarray`
        Integer DiaObject ids. Each source gets one, drawn uniformly at
        random with replacement.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box bounds the random positions. Its WCS
        converts pixels to sky coordinates, and its visit date supplies
        ``midPointTai``.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        One row per source, with ``diaSourceId`` equal to the row index.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # Random draws happen in a fixed order (x, y, then object picks) so the
    # seeded output is reproducible.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(),
                           size=nSources)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(),
                           size=nSources)
    picks = np.random.randint(len(diaObjectIds), size=nSources)
    objectIds = diaObjectIds[picks]

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)
    skyWcs = exposure.getWcs()

    def _toRecord(sourceId, x, y, objectId):
        # Convert one pixel position into a DiaSource row.
        sky = skyWcs.pixelToSky(x, y)
        return {"ra": sky.getRa().asDegrees(),
                "decl": sky.getDec().asDegrees(),
                "diaObjectId": objectId,
                "diaSourceId": sourceId,
                "midPointTai": visitMjd}

    records = [_toRecord(sourceId, x, y, objectId)
               for sourceId, (x, y, objectId)
               in enumerate(zip(xs, ys, objectIds))]
    return pd.DataFrame(data=records)

201 

202 

def makeDiaForcedSources(nForcedSources, diaObjectIds, exposure):
    """Generate DiaForcedSources, each linked to a random DiaObject.

    Parameters
    ----------
    nForcedSources : `int`
        How many DiaForcedSources to generate.
    diaObjectIds : `numpy.ndarray`
        Integer DiaObject ids. Each forced source gets one, drawn uniformly
        at random with replacement.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose visit date supplies ``midPointTai``.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        One row per forced source. Both ``diaForcedSourceId`` and
        ``ccdVisitId`` are set to the row index.
    """
    picks = np.random.randint(len(diaObjectIds), size=nForcedSources)
    objectIds = diaObjectIds[picks]

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    records = [{"diaObjectId": objectId,
                "diaForcedSourceId": rowIndex,
                "ccdVisitId": rowIndex,
                "midPointTai": visitMjd}
               for rowIndex, objectId in enumerate(objectIds)]
    return pd.DataFrame(data=records)

234 

235 

class TestLoadDiaCatalogs(unittest.TestCase):
    """Tests for `LoadDiaCatalogsTask` run against a temporary SQLite APDB.

    ``setUp`` fills a fresh APDB with 20 DiaObjects, 100 DiaSources and
    200 DiaForcedSources generated over one test exposure. The loader tests
    then check that every stored row is read back.
    """

    def setUp(self):
        # Fixed seed so the randomly generated catalogs are reproducible.
        np.random.seed(1234)

        self.db_file_fd, self.db_file = tempfile.mkstemp(
            dir=os.path.dirname(__file__))
        # Register cleanup immediately instead of relying on tearDown.
        # Cleanups also run when a later step of setUp raises (e.g. schema
        # creation), so the temporary database file is never leaked.
        # Cleanups run LIFO: the descriptor is closed, then the file removed.
        self.addCleanup(os.remove, self.db_file)
        self.addCleanup(os.close, self.db_file_fd)

        self.apdbConfig = ApdbSqlConfig()
        self.apdbConfig.db_url = "sqlite:///" + self.db_file
        self.apdbConfig.dia_object_index = "baseline"
        self.apdbConfig.dia_object_columns = []
        self.apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        self.apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        self.apdb = ApdbSql(config=self.apdbConfig)
        self.apdb.makeSchema()

        self.exposure = makeExposure(False, False)

        # Catalogs are generated in a fixed order so that the seeded random
        # stream gives the same data every run.
        self.diaObjects = makeDiaObjects(20, self.exposure)
        diaObjectIds = self.diaObjects["diaObjectId"].to_numpy()
        self.diaSources = makeDiaSources(100, diaObjectIds, self.exposure)
        self.diaForcedSources = makeDiaForcedSources(
            200, diaObjectIds, self.exposure)

        self.dateTime = self.exposure.getInfo().getVisitInfo().getDate()
        self.apdb.store(self.dateTime,
                        self.diaObjects,
                        self.diaSources,
                        self.diaForcedSources)

        # These columns are not in the DPDD, yet do appear in DiaSource.yaml.
        # We don't need to check them against the default APDB schema.
        self.ignoreColumns = ["filterName", "bboxSize", "isDipole"]

    def testRun(self):
        """Test the full run method for the loader.
        """
        diaLoader = LoadDiaCatalogsTask()
        result = diaLoader.run(self.exposure, self.apdb)

        self.assertEqual(len(result.diaObjects), len(self.diaObjects))
        self.assertEqual(len(result.diaSources), len(self.diaSources))
        self.assertEqual(len(result.diaForcedSources),
                         len(self.diaForcedSources))

    def testLoadDiaObjects(self):
        """Test that the correct number of diaObjects are loaded.
        """
        diaLoader = LoadDiaCatalogsTask()
        region = diaLoader._getRegion(self.exposure)
        diaObjects = diaLoader.loadDiaObjects(region,
                                              self.apdb)
        self.assertEqual(len(diaObjects), len(self.diaObjects))

    def testLoadDiaForcedSources(self):
        """Test that the correct number of diaForcedSources are loaded.
        """
        diaLoader = LoadDiaCatalogsTask()
        region = diaLoader._getRegion(self.exposure)
        diaForcedSources = diaLoader.loadDiaForcedSources(
            self.diaObjects,
            region,
            self.dateTime,
            self.apdb)
        self.assertEqual(len(diaForcedSources), len(self.diaForcedSources))

    def testLoadDiaSources(self):
        """Test that the correct number of diaSources are loaded.

        Also check that they can be properly loaded both by location and
        ``diaObjectId``.
        """
        diaConfig = LoadDiaCatalogsConfig()
        diaLoader = LoadDiaCatalogsTask(config=diaConfig)

        region = diaLoader._getRegion(self.exposure)
        diaSources = diaLoader.loadDiaSources(self.diaObjects,
                                              region,
                                              self.dateTime,
                                              self.apdb)
        self.assertEqual(len(diaSources), len(self.diaSources))

    def test_apdbSchema(self):
        """Test that every DiaSource column defined in
        ap_association/data/DiaSource.yaml (minus ``self.ignoreColumns``)
        exists in the default DiaSource schema from dax_apdb.
        """
        functorFile = _data_file_name("DiaSource.yaml", "ap_association")
        apdbSchemaFile = _data_file_name("apdb-schema.yaml", "dax_apdb")

        with open(apdbSchemaFile) as yaml_stream:
            diaSourceTable = next(
                (table for table in yaml.safe_load_all(yaml_stream)
                 if table['table'] == 'DiaSource'),
                None)
        # Fail clearly instead of with a NameError if the table is missing.
        self.assertIsNotNone(
            diaSourceTable,
            "No 'DiaSource' table found in %s" % apdbSchemaFile)
        apdbSchemaColumns = {column['name']
                             for column in diaSourceTable['columns']}

        with open(functorFile) as yaml_stream:
            for functor in yaml.safe_load_all(yaml_stream):
                diaSourceColumns = {column for column in functor['funcs']
                                    if column not in self.ignoreColumns}
                # Subset check (not proper subset): equal column sets are
                # fine. Reporting only the missing names keeps failures
                # readable.
                missing = diaSourceColumns - apdbSchemaColumns
                self.assertEqual(
                    missing, set(),
                    "DiaSource.yaml columns missing from %s: %s"
                    % (apdbSchemaFile, sorted(missing)))

348 

349 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pull in the checks from `lsst.utils.tests.MemoryTestCase` for this
    module. See that class for exactly what is verified; presumably these
    are resource-leak checks.
    """

352 

353 

def setup_module(module):
    """Module-level setup hook, following pytest's ``setup_module`` convention.

    Calls `lsst.utils.tests.init` before this module's tests run. The
    ``module`` argument is accepted for the hook signature but not used.
    """
    initTestUtils = lsst.utils.tests.init
    initTestUtils()

356 

357 

if __name__ == "__main__":
    # Run directly as a script: initialise the LSST test utilities, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()