Coverage for tests/test_packageAlerts.py: 17%

193 statements  

« prev     ^ index     » next       coverage.py v7.2.1, created at 2023-03-12 10:45 +0000

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32 

33from lsst.ap.association import (PackageAlertsConfig, 

34 PackageAlertsTask, 

35 make_dia_source_schema, 

36 make_dia_object_schema) 

37from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

38import lsst.afw.image as afwImage 

39import lsst.daf.base as dafBase 

40from lsst.dax.apdb import Apdb, ApdbConfig 

41import lsst.geom as geom 

42import lsst.meas.base.tests 

43from lsst.utils import getPackageDir 

44import lsst.utils.tests 

45 

46 

def _data_file_name(basename, module_name):
    """Build the absolute path of a file in a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name, relative to the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose install directory is looked up
        with `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        The path ``<package dir>/data/<basename>``.
    """
    packageDataDir = os.path.join(getPackageDir(module_name), "data")
    return os.path.join(packageDataDir, basename)

64 

65 

def makeDiaObjects(nObjects, exposure):
    """Make a test set of DiaObjects.

    Object positions are drawn uniformly over the exposure's bounding box
    from the global `numpy.random` state, so the output is reproducible
    only if the caller seeds that state first.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create objects over. Its WCS is used to convert pixel
        positions to sky coordinates, and its VisitInfo date supplies
        ``radecTai``.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        DiaObjects generated across the exposure, one row per object, with
        ``diaObjectId`` running from 0 to ``nObjects - 1``.
    """
    bbox = geom.Box2D(exposure.getBBox())
    # Keep the x-then-y draw order: it determines the positions for a seed.
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nObjects)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nObjects)

    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    # Named ``expWcs`` so it does not shadow the module-level astropy ``wcs``.
    expWcs = exposure.getWcs()
    # Every object gets the same placeholder pixelId.
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        newObject = {"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "radecTai": midPointTaiMJD,
                     "diaObjectId": idx,
                     "pixelId": htmIdx,
                     "pmParallaxNdata": 0,
                     "nearbyObj1": 0,
                     "nearbyObj2": 0,
                     "nearbyObj3": 0,
                     "flags": 1,
                     "nDiaSources": 5}
        for band in ["u", "g", "r", "i", "z", "y"]:
            newObject[f"{band}PSFluxNdata"] = 0
        data.append(newObject)

    return pd.DataFrame(data=data)

110 

111 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Source positions are drawn uniformly over the exposure's bounding box
    from the global `numpy.random` state, and sources are assigned to the
    given DiaObjects round-robin.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of diaobjects to "associate" with the DiaSources.
        Source ``i`` is associated with
        ``diaObjectIds[i % len(diaObjectIds)]``.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure. ``diaSourceId`` runs from
        0 to ``nSources - 1``, and ``midPointTai`` starts at the exposure's
        MJD and increases by 1.0 per source.
    """
    bbox = geom.Box2D(exposure.getBBox())
    # Keep the x-then-y draw order: it determines the positions for a seed.
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()

    # Named ``expWcs`` so it does not shadow the module-level astropy ``wcs``.
    expWcs = exposure.getWcs()
    # Loop-invariant: look the band up once rather than once per source.
    band = exposure.getFilterLabel().bandLabel
    # Every source gets the same placeholder pixelId.
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objId,
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": band,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)

166 

167 

def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of diaobjects to "associate" with the forced sources.
        Forced source ``i`` is associated with
        ``diaObjectIds[i % len(diaObjectIds)]``.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose visit date, exposure id and band seed the records.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources, one row per forced source. ``diaForcedSourceId``
        runs from 0 to ``nSources - 1``. ``ccdVisitId`` and ``midPointTai``
        are offset from the exposure's values by the row index.
    """
    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()
    # Loop-invariant: look the band up once rather than once per source.
    band = exposure.getFilterLabel().bandLabel

    data = []
    for idx in range(nSources):
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert. Each forced source
        # gets its own ccdVisitId so they look like distinct visits.
        data.append({"diaForcedSourceId": idx,
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objId,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": band,
                     "flags": 0})

    return pd.DataFrame(data=data)

202 

203 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    A throwaway SQLite-backed Apdb is created in a temporary file. The input
    catalogs are stored in it and then read back, so the returned frames
    have the Apdb's table layout.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip. Must have ``pixelId`` and
        ``diaObjectId`` columns.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``filterName``, ``diaSourceId``).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # The context manager guarantees the temporary database file is removed
    # once the round trip is finished.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        objectIds = np.unique(objects["diaObjectId"])

        # DataFrame.append was removed in pandas 2.0; pd.concat is the
        # supported equivalent (same row order, index kept as-is).
        diaObjects = pd.concat(
            [apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True),
             objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(objectIds, dateTime, return_pandas=True),
             sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(objectIds, dateTime,
                                      return_pandas=True),
             forcedSources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaForcedSources(diaForcedSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        # Read everything back so the frames carry the Apdb's table layout.
        diaObjects = apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True)
        storedIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(storedIds,
                                        dateTime,
                                        return_pandas=True)
        diaForcedSources = apdb.getDiaForcedSources(storedIds,
                                                    dateTime,
                                                    return_pandas=True)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)

276 

277 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask`.

    The tests use a simulated exposure plus a small set of DiaObjects,
    DiaSources and DiaForcedSources round-tripped through a temporary Apdb.
    """

    def setUp(self):
        # Fixed seed so the randomly placed objects/sources are reproducible.
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.exposure.setFilterLabel(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        # Replace None with NaN so the np.isnan-based comparisons below work.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # Sources 8 and 9 are the last ones assigned to objects 0 and 1. They
        # play the "new" detections; the remaining sources are their history.
        newSourceLabels = [(0, "g", 8), (1, "g", 9)]
        # Explicit copy: a column is added below, and it must not be written
        # into a slice of diaSourceHistory.
        self.diaSources = diaSourceHistory.loc[newSourceLabels, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceLabels)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below the configured minimum should come back at the
        # minimum size.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # The test expects no cutout for a sky position at RA=Dec=0.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            # srcIdx is (diaObjectId, filterName, diaSourceId); the first
            # level selects this source's object.
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register cleanup right away so the directory is removed even if
        # run() or one of the assertions below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _, data_stream = packageAlerts.alertSchema.retrieve_alerts(f)
            # Materialize while the file is still open.
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Relative tolerance equivalent to the former
                        # ``1 - value / expected ~= 0`` (7 places) check,
                        # without dividing by ``expected``, which broke
                        # when ``expected`` was 0.
                        self.assertTrue(
                            np.isclose(value, expected, rtol=5e-8, atol=0.),
                            msg=f"{key}: {value} != {expected}")
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

507 

508 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pull the LSST stack's standard memory/leak checks into this module."""

511 

512 

def setup_module(module):
    """Initialize the LSST test utilities (pytest module-level setup hook).

    ``module`` is required by the hook signature but not used.
    """
    lsst.utils.tests.init()

515 

516 

if __name__ == "__main__":
    # Direct execution (outside pytest): initialize the LSST test
    # utilities, then hand control to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()