Coverage for tests/test_packageAlerts.py: 20%

149 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-01 10:57 +0000

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32 

33from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask 

34from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

35import lsst.afw.image as afwImage 

36import lsst.daf.base as dafBase 

37from lsst.dax.apdb import ApdbSql, ApdbSqlConfig 

38import lsst.geom as geom 

39import lsst.meas.base.tests 

40from lsst.sphgeom import Box 

41import lsst.utils.tests 

42import utils_tests 

43 

44 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        ``(diaObjectId, band, diaSourceId)``.
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # Scope the sqlite backing file with a context manager so it is closed
    # (and therefore removed) deterministically, even if the Apdb raises.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSqlConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []

        apdb = ApdbSql(config=apdbConfig)
        apdb.makeSchema()

        wholeSky = Box.full()
        # Concatenate onto the (empty) frames read back from the fresh Apdb
        # so the inputs are combined with the Apdb's own column set.
        diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
        diaSources = pd.concat([apdb.getDiaSources(wholeSky, [], dateTime), sources])
        diaForcedSources = pd.concat([apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        # Everything below is read into memory before the file goes away.
        diaObjects = apdb.getDiaObjects(wholeSky)
        objectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, objectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(wholeSky, objectIds, dateTime)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)

96 

97 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests of `PackageAlertsTask` run against a synthetic exposure and
    DiaObject/DiaSource/DiaForcedSource catalogs that have been round
    tripped through an Apdb.
    """

    def setUp(self):
        """Build the exposure, catalogs, and reference cutout WCS shared by
        the tests.
        """
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = geom.Point2D(50.1, 49.8)
        self.bbox = geom.Box2I(geom.Point2I(-20, -30),
                               geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        # Only the exposure is used; the realized catalog is discarded.
        self.exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        detector = DetectorWrapper(id=23, bbox=self.exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = utils_tests.makeDiaObjects(2, self.exposure)
        diaSourceHistory = utils_tests.makeDiaSources(10,
                                                      diaObjects["diaObjectId"],
                                                      self.exposure)
        diaForcedSources = utils_tests.makeDiaForcedSources(10,
                                                            diaObjects["diaObjectId"],
                                                            self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date)
        # Normalize missing values to NaN in all three frames.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # Split the sources: these two rows act as the sources to alert on,
        # the remainder is kept as history.
        newSourceKeys = [(1, "g", 9), (2, "g", 10)]
        # Explicit copy so adding a column never touches diaSourceHistory.
        self.diaSources = diaSourceHistory.loc[newSourceKeys, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceKeys)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        skyOrigin = self.exposure.getWcs().getSkyOrigin()
        self.cutoutWcs.wcs.crval = [
            skyOrigin.getRa().asDegrees(),
            skyOrigin.getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below the minimum is expected to come back at the minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A cutout requested at (0, 0) degrees is expected to yield no data.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234
        cutoutExtent = geom.Extent2I(self.cutoutSize, self.cutoutSize)

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is the (diaObjectId, band, diaSourceId) index tuple.
            objectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint, cutoutExtent)
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                cutoutExtent,
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[objectId]
            objForcedSources = self.diaForcedSources.loc[objectId]
            # The same cutout is passed for both cutout arguments.
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[objectId],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register removal immediately so the directory does not leak when
        # an assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))

        cutoutExtent = geom.Extent2I(self.cutoutSize, self.cutoutSize)
        for idx, alert in enumerate(data):
            # Look the expected row up once per alert rather than per key.
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                if isinstance(value, float):
                    if np.isnan(expectedSource[key]):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Relative comparison of the stored float value.
                        self.assertAlmostEqual(
                            1 - value / expectedSource[key],
                            0.)
                else:
                    self.assertEqual(value, expectedSource[key])
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint, cutoutExtent)
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                cutoutExtent,
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

327 

328 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Test case that inherits everything from
    `lsst.utils.tests.MemoryTestCase` and adds nothing of its own.
    """

331 

332 

def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    Parameters
    ----------
    module : `types.ModuleType`
        Module being set up. Unused.
    """
    lsst.utils.tests.init()

335 

336 

if __name__ == "__main__":
    # Direct script execution: perform the same initialization that
    # setup_module does, then hand control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()