Coverage for tests/test_packageAlerts.py: 18%

186 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-19 02:58 -0800

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32 

33from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask 

34from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

35import lsst.afw.image as afwImage 

36import lsst.daf.base as dafBase 

37from lsst.dax.apdb import ApdbSql, ApdbSqlConfig 

38import lsst.geom as geom 

39import lsst.meas.base.tests 

40from lsst.sphgeom import Box 

41from lsst.utils import getPackageDir 

42import lsst.utils.tests 

43 

44 

def _data_file_name(basename, module_name):
    """Build the path of a file inside a stack package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name to append to the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose root directory is looked up
        with `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        Path of the form ``<package root>/data/<basename>``.
    """
    packageRoot = getPackageDir(module_name)
    dataDir = os.path.join(packageRoot, "data")
    return os.path.join(dataDir, basename)

62 

63 

def makeDiaObjects(nObjects, exposure):
    """Make a test set of DiaObjects scattered over an exposure.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box and WCS define where objects are placed.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per object. ``diaObjectId`` runs from 1 to ``nObjects``;
        ``ra``/``decl`` come from pixel positions drawn uniformly over the
        exposure's bounding box; ``radecTai`` is the visit date as an MJD.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # All x draws happen before all y draws; keep this order so results
    # are reproducible under a fixed ``np.random.seed``.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(), size=nObjects)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(), size=nObjects)

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    # Named ``skyWcs`` so it does not shadow the module-level astropy ``wcs``.
    skyWcs = exposure.getWcs()

    rows = []
    for objectId, (x, y) in enumerate(zip(xs, ys), start=1):
        skyCoord = skyWcs.pixelToSky(x, y)
        row = {
            "ra": skyCoord.getRa().asDegrees(),
            "decl": skyCoord.getDec().asDegrees(),
            "radecTai": visitMjd,
            "diaObjectId": objectId,
            "pmParallaxNdata": 0,
            "nearbyObj1": 0,
            "nearbyObj2": 0,
            "nearbyObj3": 0,
            "flags": 1,
            "nDiaSources": 5,
        }
        # Per-band PSF-flux Ndata columns, appended in u, g, r, i, z, y order.
        row.update({f"{band}PSFluxNdata": 0 for band in "ugrizy"})
        rows.append(row)

    return pd.DataFrame(data=rows)

106 

107 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources scattered over an exposure.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of DiaObjects to "associate" with the DiaSources. They
        are handed out round-robin: source ``idx`` gets
        ``diaObjectIds[idx % len(diaObjectIds)]``.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the bounding box, WCS, ``ccdVisitId``, visit
        date and band for the sources.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        One row per source holding the minimum columns needed for an
        alert. ``diaSourceId`` runs from 1 to ``nSources``, and
        ``midPointTai`` is the visit MJD plus ``idx`` days.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # All x draws happen before all y draws; keep this order so results
    # are reproducible under a fixed ``np.random.seed``.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(), size=nSources)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(), size=nSources)

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    # Named ``skyWcs`` so it does not shadow the module-level astropy ``wcs``.
    skyWcs = exposure.getWcs()
    ccdVisitId = exposure.info.id

    rows = []
    for idx, (x, y) in enumerate(zip(xs, ys)):
        skyCoord = skyWcs.pixelToSky(x, y)
        associatedId = diaObjectIds[idx % len(diaObjectIds)]
        rows.append({
            "ra": skyCoord.getRa().asDegrees(),
            "decl": skyCoord.getDec().asDegrees(),
            "x": x,
            "y": y,
            "ccdVisitId": ccdVisitId,
            "diaObjectId": associatedId,
            "ssObjectId": 0,
            "parentDiaSourceId": 0,
            "prv_procOrder": 0,
            "diaSourceId": idx + 1,
            "midPointTai": visitMjd + float(idx),
            "filterName": exposure.getFilter().bandLabel,
            "psNdata": 0,
            "trailNdata": 0,
            "dipNdata": 0,
            "flags": 1,
        })

    return pd.DataFrame(data=rows)

158 

159 

def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer Ids of DiaObjects to "associate" with the DiaForcedSources.
        They are handed out round-robin: row ``idx`` gets
        ``diaObjectIds[idx % len(diaObjectIds)]``.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the base ``ccdVisitId``, the visit date and the
        band of the forced sources.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        One row per forced source holding the minimum columns needed for an
        alert. Row ``idx`` has ``diaForcedSourceId = idx + 1``.
        ``ccdVisitId`` is the exposure id plus ``idx`` and ``midPointTai``
        is the visit MJD plus ``idx`` days, so every row has a distinct
        visit id and time.
    """
    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    ccdVisitId = exposure.info.id

    data = []
    for idx in range(nSources):
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert.
        data.append({"diaForcedSourceId": idx + 1,
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objId,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": exposure.getFilter().bandLabel,
                     "flags": 0})

    return pd.DataFrame(data=data)

194 

195 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    The function works as follows:

    - A throwaway SQLite-backed `ApdbSql` is created in a temporary file.
    - Each input frame is concatenated with the Apdb's query result for the
      corresponding table, which is empty because the schema has just been
      created.
    - The frames are stored and then read back.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId`` (column kept).
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``filterName``, ``diaSourceId``) (columns kept).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``
        (column kept).
    """
    # The context manager closes (and thereby deletes) the SQLite backing
    # file deterministically instead of leaving it to garbage collection.
    # All Apdb reads finish inside the block, so the returned frames do
    # not depend on the file afterwards.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSqlConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []

        apdb = ApdbSql(config=apdbConfig)
        apdb.makeSchema()

        wholeSky = Box.full()
        diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(wholeSky, [], dateTime), sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        diaObjects = apdb.getDiaObjects(wholeSky)
        objectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, objectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, objectIds, dateTime)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)

247 

248 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests of `PackageAlertsTask`: cutout extents, CCDData cutouts, local
    WCS approximation, FITS byte streaming, alert dictionaries and the full
    ``run`` method.
    """

    def setUp(self):
        """Build a simulated exposure and Apdb-round-tripped catalogs.

        Two DiaObjects get ten DiaSources and ten DiaForcedSources,
        assigned round-robin. The last DiaSource of each object
        (diaSourceId 9 for object 1, 10 for object 2) is split off as the
        new detections (``self.diaSources``). The remaining DiaSources
        form ``self.diaSourceHistory``.
        """
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = geom.Point2D(50.1, 49.8)
        self.bbox = geom.Box2I(geom.Point2I(-20, -30),
                               geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.getInfo().setVisitInfo(visit)

        self.exposure.setFilter(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # (diaObjectId, filterName, diaSourceId) labels of the new detections.
        newSourceLabels = [(1, "g", 9), (2, "g", 10)]
        # Explicit copy: the selection is modified below, and writing to a
        # pandas selection can trigger SettingWithCopyWarning.
        self.diaSources = diaSourceHistory.loc[newSourceLabels, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceLabels)

        # Expected astropy WCS of a cutout centered on ``self.center``.
        skyOrigin = self.exposure.getWcs().getSkyOrigin()
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            skyOrigin.getRa().asDegrees(),
            skyOrigin.getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A requested size below the minimum is clamped up to the minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly, and that a sky position away from the exposure
        (RA = Dec = 0 deg) yields no cutout.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is the (diaObjectId, filterName, diaSourceId) label.
            objectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[objectId]
            objForcedSources = self.diaForcedSources.loc[objectId]
            # The same cutout is passed as both difference and template.
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[objectId],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Registered immediately so the directory is removed even when
        # run() or one of the assertions below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _, data_stream = packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Relative comparison: alert floats may lose
                        # precision in serialization.
                        self.assertAlmostEqual(1 - value / expected, 0.)
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

479 

480 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`;
    this subclass adds nothing of its own.
    """

483 

484 

def setup_module(module):
    """pytest-style module setup hook that initializes `lsst.utils.tests`.

    Parameters
    ----------
    module : `module`
        This test module; required by the hook signature but not used.
    """
    lsst.utils.tests.init()

487 

488 

if __name__ == "__main__":
    # Only reached when the file is executed directly rather than imported
    # by a test runner: initialize the LSST test utilities, then hand
    # control to unittest.
    lsst.utils.tests.init()
    unittest.main()