Coverage for tests/test_packageAlerts.py: 17%

183 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-28 04:57 -0700

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32 

33from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask 

34from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

35import lsst.afw.image as afwImage 

36import lsst.daf.base as dafBase 

37from lsst.dax.apdb import ApdbSql, ApdbSqlConfig 

38import lsst.geom as geom 

39import lsst.meas.base.tests 

40from lsst.sphgeom import Box 

41import lsst.utils.tests 

42 

43 

def makeDiaObjects(nObjects, exposure):
    """Build a synthetic DiaObject catalog spread over an exposure.

    Parameters
    ----------
    nObjects : `int`
        How many DiaObjects to generate.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the bounding box, WCS, and visit date used to
        place and time-stamp the objects.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per generated DiaObject.
    """
    box = geom.Box2D(exposure.getBBox())
    # All x positions are drawn before any y position so the random stream
    # is consumed in a fixed order.
    xPositions = np.random.uniform(box.getMinX(), box.getMaxX(), size=nObjects)
    yPositions = np.random.uniform(box.getMinY(), box.getMaxY(), size=nObjects)

    visitDate = exposure.getInfo().getVisitInfo().getDate()
    epochMjd = visitDate.get(system=dafBase.DateTime.MJD)

    skyWcs = exposure.getWcs()
    bands = ("u", "g", "r", "i", "z", "y")

    rows = []
    # Object ids are 1-based.
    for objectId, (x, y) in enumerate(zip(xPositions, yPositions), start=1):
        skyCoord = skyWcs.pixelToSky(x, y)
        row = {
            "ra": skyCoord.getRa().asDegrees(),
            "decl": skyCoord.getDec().asDegrees(),
            "radecTai": epochMjd,
            "diaObjectId": objectId,
            "pmParallaxNdata": 0,
            "nearbyObj1": 0,
            "nearbyObj2": 0,
            "nearbyObj3": 0,
            "flags": 1,
            "nDiaSources": 5,
        }
        # One per-band flux counter column, appended after the fixed columns.
        for band in bands:
            row["%sPSFluxNdata" % band] = 0
        rows.append(row)

    return pd.DataFrame(data=rows)

86 

87 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Build a synthetic DiaSource catalog spread over an exposure.

    Parameters
    ----------
    nSources : `int`
        How many DiaSources to generate.
    diaObjectIds : `numpy.ndarray`
        Integer DiaObject ids; sources are assigned to them cyclically.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the bounding box, WCS, visit date, id, and band
        used to fill in each source.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        One row per generated DiaSource.
    """
    box = geom.Box2D(exposure.getBBox())
    xPositions = np.random.uniform(box.getMinX(), box.getMaxX(), size=nSources)
    yPositions = np.random.uniform(box.getMinY(), box.getMaxY(), size=nSources)

    visitDate = exposure.getInfo().getVisitInfo().getDate()
    epochMjd = visitDate.get(system=dafBase.DateTime.MJD)

    skyWcs = exposure.getWcs()
    visitId = exposure.info.id

    def buildRow(index, x, y):
        # Assemble the minimum set of values an alert needs for one source.
        skyCoord = skyWcs.pixelToSky(x, y)
        # Wrap around the supplied ids when there are more sources than
        # objects.
        associatedId = diaObjectIds[index % len(diaObjectIds)]
        return {
            "ra": skyCoord.getRa().asDegrees(),
            "decl": skyCoord.getDec().asDegrees(),
            "x": x,
            "y": y,
            "ccdVisitId": visitId,
            "diaObjectId": associatedId,
            "ssObjectId": 0,
            "parentDiaSourceId": 0,
            "prv_procOrder": 0,
            "diaSourceId": index + 1,
            # Each successive source is offset by one day.
            "midPointTai": epochMjd + 1.0 * index,
            "filterName": exposure.getFilter().bandLabel,
            "psNdata": 0,
            "trailNdata": 0,
            "dipNdata": 0,
            "flags": 1,
        }

    rows = [buildRow(index, x, y)
            for index, (x, y) in enumerate(zip(xPositions, yPositions))]

    return pd.DataFrame(data=rows)

138 

139 

def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Build a synthetic DiaForcedSource catalog.

    Parameters
    ----------
    nSources : `int`
        How many DiaForcedSources to generate.
    diaObjectIds : `numpy.ndarray`
        Integer DiaObject ids; forced sources are assigned to them
        cyclically.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the visit date, id, and band used to fill in
        each forced source.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        One row per generated DiaForcedSource.
    """
    visitDate = exposure.getInfo().getVisitInfo().getDate()
    epochMjd = visitDate.get(system=dafBase.DateTime.MJD)

    baseVisitId = exposure.info.id

    # Minimum set of values an alert needs for each forced source.  Both the
    # ccdVisitId and the mid-point time advance by one per row, and object
    # ids wrap around when there are more sources than objects.
    rows = [
        {
            "diaForcedSourceId": index + 1,
            "ccdVisitId": baseVisitId + index,
            "diaObjectId": diaObjectIds[index % len(diaObjectIds)],
            "midPointTai": epochMjd + 1.0 * index,
            "filterName": exposure.getFilter().bandLabel,
            "flags": 0,
        }
        for index in range(nSources)
    ]

    return pd.DataFrame(data=rows)

174 

175 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        ``(diaObjectId, filterName, diaSourceId)``.
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # The SQLite file is only needed while talking to the Apdb.  The context
    # manager closes (and thereby removes) it deterministically, including
    # when the Apdb raises, instead of relying on garbage collection.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSqlConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []

        apdb = ApdbSql(config=apdbConfig)
        apdb.makeSchema()

        wholeSky = Box.full()
        # Concatenate onto the catalogs read back from the freshly created
        # Apdb so the test data picks up the full set of schema columns.
        diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(wholeSky, [], dateTime), sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        # Read everything back; the same id list selects both source tables.
        diaObjects = apdb.getDiaObjects(wholeSky)
        objectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, objectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, objectIds, dateTime)

    # Keep the index columns as regular columns too (drop=False).
    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)

227 

228 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests of `PackageAlertsTask` run against a small synthetic exposure
    and catalogs that have been round tripped through an Apdb.
    """

    def setUp(self):
        """Create the exposure, catalogs, and reference cutout WCS shared by
        the tests.
        """
        # Fixed seed so the randomly placed test sources are reproducible.
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        # Only the exposure is used; the realized catalog is discarded.
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.getInfo().setVisitInfo(visit)

        self.exposure.setFilter(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate())
        # Represent missing values uniformly as NaN rather than None.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # Split the round tripped sources: two rows (one per DiaObject) act
        # as the "new" DiaSources and the remainder as their history.
        self.diaSources = diaSourceHistory.loc[
            [(1, "g", 9), (2, "g", 10)], :]
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=[(1, "g", 9),
                                                              (2, "g", 10)])

        # Reference astropy WCS built from the exposure WCS, used to compare
        # against the WCS attached to generated cutouts.
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below the minimum is expected to come back at the
        # minimum size.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # Requesting a cutout centered at (0, 0) degrees is expected to
        # yield no cutout at all.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            # srcIdx is the (diaObjectId, filterName, diaSourceId) index
            # tuple; its first element selects the rows of this DiaObject.
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            # The same cutout stands in for both the difference and the
            # template image.
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register the removal immediately so the directory is cleaned up
        # even when the task or one of the assertions below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            # Only the alert records are checked; the writer schema is
            # ignored.
            _, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            # Materialize the alerts while the file is still open.
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Compare floats through their ratio, i.e. as a
                        # relative difference.
                        self.assertAlmostEqual(
                            1 - value / expected,
                            0.)
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

459 

460 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Test case that adds nothing of its own; all behaviour is inherited
    from `lsst.utils.tests.MemoryTestCase`.
    """

463 

464 

def setup_module(module):
    """Module-level setup hook: initialize the `lsst.utils.tests` framework.

    Parameters
    ----------
    module
        The module being set up; not used here.
    """
    lsst.utils.tests.init()

467 

468 

# Direct execution: perform the same initialization as setup_module, then
# hand over to the unittest runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()