Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32 

33from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask 

34from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

35import lsst.afw.image as afwImage 

36import lsst.daf.base as dafBase 

37from lsst.dax.apdb import ApdbSql, ApdbSqlConfig 

38import lsst.geom as geom 

39import lsst.meas.base.tests 

40from lsst.sphgeom import Box 

41from lsst.utils import getPackageDir 

42import lsst.utils.tests 

43 

44 

def _data_file_name(basename, module_name):
    """Build the absolute path of a file in a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name, relative to the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose installation directory is
        looked up with `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        The path ``<package dir>/data/<basename>``.
    """
    packageRoot = getPackageDir(module_name)
    dataDir = os.path.join(packageRoot, "data")
    return os.path.join(dataDir, basename)

62 

63 

def makeDiaObjects(nObjects, exposure):
    """Generate randomly placed DiaObjects covering an exposure.

    Parameters
    ----------
    nObjects : `int`
        How many DiaObjects to generate.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box, WCS and visit date are used to
        position and time-stamp the objects.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per object, with ``diaObjectId`` numbered from 1 and sky
        positions drawn uniformly over the exposure's pixel area.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # Draw every x before any y so a seeded generator always yields the
    # same positions.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(), size=nObjects)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(), size=nObjects)

    visitDate = exposure.getInfo().getVisitInfo().getDate()
    midPointTaiMJD = visitDate.get(system=dafBase.DateTime.MJD)

    # Named expWcs so the module-level astropy ``wcs`` import is not shadowed.
    expWcs = exposure.getWcs()

    # Per-band columns appended after the common ones, in u, g, r, i, z, y order.
    bandNdataColumns = {f"{band}PSFluxNdata": 0 for band in "ugrizy"}

    rows = []
    for idx, (xPix, yPix) in enumerate(zip(xs, ys)):
        skyCoord = expWcs.pixelToSky(xPix, yPix)
        row = {
            "ra": skyCoord.getRa().asDegrees(),
            "decl": skyCoord.getDec().asDegrees(),
            "radecTai": midPointTaiMJD,
            "diaObjectId": idx + 1,
            "pmParallaxNdata": 0,
            "nearbyObj1": 0,
            "nearbyObj2": 0,
            "nearbyObj3": 0,
            "flags": 1,
            "nDiaSources": 5,
        }
        row.update(bandNdataColumns)
        rows.append(row)

    return pd.DataFrame(data=rows)

106 

107 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Generate randomly placed DiaSources covering an exposure.

    Parameters
    ----------
    nSources : `int`
        How many DiaSources to generate.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer DiaObject ids. Source ``i`` is associated with
        ``diaObjectIds[i % len(diaObjectIds)]``, so ids are assigned
        round-robin.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box, WCS, visit info and filter label are
        used to fill in the sources.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        One row per source, with ``diaSourceId`` numbered from 1 and
        ``midPointTai`` advancing by one day (MJD) per source.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # x positions are drawn before y positions to keep seeded output stable.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(), size=nSources)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    # Named expWcs so the module-level astropy ``wcs`` import is not shadowed.
    expWcs = exposure.getWcs()
    ccdVisitId = visitInfo.getExposureId()

    rows = []
    for idx, (xPix, yPix) in enumerate(zip(xs, ys)):
        skyCoord = expWcs.pixelToSky(xPix, yPix)
        # Only the minimum set of columns needed to build an alert.
        rows.append({
            "ra": skyCoord.getRa().asDegrees(),
            "decl": skyCoord.getDec().asDegrees(),
            "x": xPix,
            "y": yPix,
            "ccdVisitId": ccdVisitId,
            "diaObjectId": diaObjectIds[idx % len(diaObjectIds)],
            "ssObjectId": 0,
            "parentDiaSourceId": 0,
            "prv_procOrder": 0,
            "diaSourceId": idx + 1,
            "midPointTai": midPointTaiMJD + 1.0 * idx,
            "filterName": exposure.getFilterLabel().bandLabel,
            "psNdata": 0,
            "trailNdata": 0,
            "dipNdata": 0,
            "flags": 1,
        })

    return pd.DataFrame(data=rows)

158 

159 

def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer Ids of DiaObjects to "associate" with the DiaForcedSources.
        Forced source ``i`` is associated with
        ``diaObjectIds[i % len(diaObjectIds)]``, so ids are assigned
        round-robin.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the visit date, exposure id and band.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources with ``diaForcedSourceId`` numbered from 1 and
        ``midPointTai`` advancing by one day (MJD) per row.
    """
    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    ccdVisitId = exposure.getInfo().getVisitInfo().getExposureId()

    data = []
    for idx in range(nSources):
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Minimum set of columns needed for the alert. ccdVisitId is offset
        # per row, presumably so each forced source belongs to a distinct
        # visit. TODO: confirm against the Apdb DiaForcedSource key.
        data.append({"diaForcedSourceId": idx + 1,
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objId,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": exposure.getFilterLabel().bandLabel,
                     "flags": 0})

    return pd.DataFrame(data=data)

194 

195 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    A fresh SQLite-backed Apdb is created in a temporary file. The inputs
    are stored in it and then read back, so the returned frames have the
    columns the Apdb produces.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId`` (column kept).
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``filterName``, ``diaSourceId``) (columns kept).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``
        (column kept).
    """
    # Use the temporary file as a context manager so it is closed and
    # removed deterministically rather than whenever it is garbage collected.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSqlConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = ApdbSql(config=apdbConfig)
        apdb.makeSchema()

        wholeSky = Box.full()
        # Concatenate the inputs onto the (freshly created, hence empty)
        # query results to get the Apdb column layout. DataFrame.append
        # was removed in pandas 2.0; pd.concat is its direct equivalent.
        diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(wholeSky, [], dateTime), sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        diaObjects = apdb.getDiaObjects(wholeSky)
        objectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, objectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, objectIds, dateTime)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)

251 

252 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask`.

    ``setUp`` builds a simulated exposure plus DiaObject, DiaSource and
    DiaForcedSource catalogs that are round tripped through a temporary
    Apdb, so the catalogs carry the Apdb column layout.
    """

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = geom.Point2D(50.1, 49.8)
        self.bbox = geom.Box2I(geom.Point2I(-20, -30),
                               geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.exposure.setFilterLabel(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate())
        # Normalize None entries left after the Apdb round trip to NaN so
        # the float comparisons in the tests treat missing values uniformly.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # The last DiaSource of each DiaObject (ids 9 and 10) stands in for
        # the new detections being alerted on; the rest are their history.
        newSourceLabels = [(1, "g", 9), (2, "g", 10)]
        # Take an explicit copy so adding "bboxSize" below works on an
        # independent frame and cannot raise a SettingWithCopyWarning.
        self.diaSources = diaSourceHistory.loc[newSourceLabels, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceLabels)

        # Reference astropy WCS for cutouts: TAN projection centered on
        # self.center, sharing the exposure's sky origin and CD matrix.
        expWcs = self.exposure.getWcs()
        skyOrigin = expWcs.getSkyOrigin()
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [skyOrigin.getRa().asDegrees(),
                                    skyOrigin.getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = expWcs.getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A requested size below the configured minimum is raised to it.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A center outside the exposure is expected to yield no cutout.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234
        cutoutExtent = geom.Extent2I(self.cutoutSize, self.cutoutSize)

        for srcIdx, diaSource in self.diaSources.iterrows():
            # The first level of the source index is the diaObjectId.
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint, cutoutExtent)
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                cutoutExtent,
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[diaObjectId]
            objForcedSources = self.diaForcedSources.loc[diaObjectId]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register removal right away so the directory is cleaned up even
        # when an assertion below fails (a trailing rmtree would be skipped).
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _, data_stream = packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))

        cutoutExtent = geom.Extent2I(self.cutoutSize, self.cutoutSize)
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    elif expected == 0:
                        # The relative check below would divide by zero.
                        self.assertAlmostEqual(value, 0.)
                    else:
                        self.assertAlmostEqual(1 - value / expected, 0.)
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint, cutoutExtent)
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                cutoutExtent,
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

482 

483 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the standard `lsst.utils.tests.MemoryTestCase` checks for this
    module; no additional tests are defined here.
    """

486 

487 

def setup_module(module):
    """Module-level setup hook (pytest xunit-style ``setup_module``).

    Initializes the LSST test utilities before any test in this module runs.

    Parameters
    ----------
    module : `types.ModuleType`
        The test module being set up; not used.
    """
    lsst.utils.tests.init()

490 

491 

if __name__ == "__main__":
    # Running the file directly: initialize the LSST test utilities, then
    # hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()