Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32 

33from lsst.ap.association import (PackageAlertsConfig, 

34 PackageAlertsTask, 

35 make_dia_source_schema, 

36 make_dia_object_schema) 

37from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

38import lsst.afw.image as afwImage 

39import lsst.daf.base as dafBase 

40from lsst.dax.apdb import Apdb, ApdbConfig 

41import lsst.geom as geom 

42import lsst.meas.base.tests 

43from lsst.utils import getPackageDir 

44import lsst.utils.tests 

45 

46 

def _data_file_name(basename, module_name):
    """Build the absolute path of a file in a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name to append to the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose directory is looked up with
        ``getPackageDir``.

    Returns
    -------
    data_file_path : `str`
        The path ``<package dir>/data/<basename>``.
    """
    packageRoot = getPackageDir(module_name)
    dataDir = os.path.join(packageRoot, "data")
    return os.path.join(dataDir, basename)

64 

65 

def makeDiaObjects(nObjects, exposure):
    """Build a synthetic DiaObject catalog spread over an exposure.

    Pixel positions are drawn uniformly (via the global numpy random state)
    within the exposure's bounding box and converted to sky coordinates with
    the exposure's WCS.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create objects over.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per object; ``diaObjectId`` equals the row position.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # x positions are drawn before y positions to keep the random stream order.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(), size=nObjects)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(), size=nObjects)

    visitDate = exposure.getInfo().getVisitInfo().getDate()
    mjdTai = visitDate.get(system=dafBase.DateTime.MJD)

    # Named so it does not shadow the module-level astropy ``wcs`` import.
    exposureWcs = exposure.getWcs()

    # Per-band data-count columns, all zero, appended after the base columns.
    bandColumns = {f"{band}PSFluxNdata": 0 for band in "ugrizy"}

    def _makeRow(objectId, xPix, yPix):
        # Assemble one DiaObject record at pixel position (xPix, yPix).
        sky = exposureWcs.pixelToSky(xPix, yPix)
        row = {"ra": sky.getRa().asDegrees(),
               "decl": sky.getDec().asDegrees(),
               "radecTai": mjdTai,
               "diaObjectId": objectId,
               # Constant pixel index; no real HTM pixelization is computed.
               "pixelId": 1,
               "pmParallaxNdata": 0,
               "nearbyObj1": 0,
               "nearbyObj2": 0,
               "nearbyObj3": 0,
               "flags": 1,
               "nDiaSources": 5}
        row.update(bandColumns)
        return row

    rows = [_makeRow(objectId, xPix, yPix)
            for objectId, (xPix, yPix) in enumerate(zip(xs, ys))]
    return pd.DataFrame(data=rows)

110 

111 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Pixel positions are drawn uniformly (via the global numpy random state)
    within the exposure's bounding box. Each source is assigned to one of
    ``diaObjectIds`` in round-robin order.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of diaobjects to "associate" with the DiaSources. Must be
        non-empty.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure. ``diaSourceId`` equals the
        row position, and ``midPointTai`` increases by one day per source.
    """
    bbox = geom.Box2D(exposure.getBBox())
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()
    # Named so it does not shadow the module-level astropy ``wcs`` import.
    expWcs = exposure.getWcs()
    # Loop invariant: every source shares the exposure's band.
    band = exposure.getFilterLabel().bandLabel
    # Constant pixel index; no real HTM pixelization is computed.
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        # Cycle through the object ids so that each object receives sources.
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objId,
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": band,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)

166 

167 

def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    Each forced source is assigned to one of ``diaObjectIds`` in round-robin
    order. Both ``ccdVisitId`` and ``midPointTai`` are offset by the source
    index, so every row looks like it came from a different visit.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of diaobjects to "associate" with the DiaForcedSources.
        Must be non-empty.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the base visit time, visit id and band.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources with ``diaForcedSourceId`` equal to the row position.
    """
    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()
    # Loop invariant: every forced source shares the exposure's band.
    band = exposure.getFilterLabel().bandLabel

    data = []
    for idx in range(nSources):
        # Cycle through the object ids so that each object receives sources.
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert.
        data.append({"diaForcedSourceId": idx,
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objId,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": band,
                     "flags": 0})

    return pd.DataFrame(data=data)

202 

203 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    A throw-away sqlite Apdb is created in a temporary file. The input
    catalogs are stored in it and then read back, so the returned frames
    carry the Apdb column set.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``filterName``, ``diaSourceId``).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # The temporary file backs the sqlite database. The context manager
    # closes it, which deletes it, once the catalogs have been read back
    # into memory.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        pixelRanges = [[minId, maxId + 1]]
        inputObjectIds = np.unique(objects["diaObjectId"])

        # Concatenate the inputs onto the query results from the freshly
        # created database, so the stored frames use the Apdb column layout.
        # ``pd.concat`` replaces ``DataFrame.append``, which was removed in
        # pandas 2.0; the two are equivalent here.
        diaObjects = pd.concat(
            [apdb.getDiaObjects(pixelRanges, return_pandas=True), objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(inputObjectIds, dateTime, return_pandas=True),
             sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(inputObjectIds, dateTime, return_pandas=True),
             forcedSources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaForcedSources(diaForcedSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        diaObjects = apdb.getDiaObjects(pixelRanges, return_pandas=True)
        storedObjectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(storedObjectIds,
                                        dateTime,
                                        return_pandas=True)
        diaForcedSources = apdb.getDiaForcedSources(storedObjectIds,
                                                    dateTime,
                                                    return_pandas=True)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)

276 

277 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask` on a simulated exposure.

    The DiaObject, DiaSource and DiaForcedSource catalogs used here are
    round-tripped through a temporary Apdb.
    """

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.exposure.setFilterLabel(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # Sources 8 and 9 (one per object) act as the new detections; the
        # remaining sources form the history.
        newSourceLabels = [(0, "g", 8), (1, "g", 9)]
        # ``.copy()`` so that adding ``bboxSize`` below does not write into a
        # slice of ``diaSourceHistory`` (avoids SettingWithCopyWarning).
        self.diaSources = diaSourceHistory.loc[newSourceLabels, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceLabels)

        # Reference astropy WCS matching the exposure's tangent-plane WCS,
        # centred on ``self.center``.
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateBBox(self):
        """Test the bbox creation.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A requested size below the minimum is clamped up to the minimum.
        bbox = packageAlerts.createDiaSourceBBox(packConfig.minCutoutSize - 5)
        self.assertEqual(bbox, geom.Extent2I(packConfig.minCutoutSize,
                                             packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        bbox = packageAlerts.createDiaSourceBBox(self.cutoutSize)
        self.assertEqual(bbox, geom.Extent2I(self.cutoutSize,
                                             self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getPhotoCalib())
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping a CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is (diaObjectId, filterName, diaSourceId).
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[diaObjectId]
            objForcedSources = self.diaForcedSources.loc[diaObjectId]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register cleanup immediately so the directory is removed even when
        # ``run`` or one of the assertions below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _writerSchema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    elif expected == 0:
                        # A relative comparison is undefined for zero.
                        self.assertAlmostEqual(value, 0.)
                    else:
                        self.assertAlmostEqual(1 - value / expected, 0.)
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

488 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Memory test case inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`.
    """

491 

492 

def setup_module(module):
    """Initialize the LSST test utilities before this module's tests run.

    Presumably invoked by pytest as its module-level setup hook.

    Parameters
    ----------
    module : `module`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

495 

496 

# Allow running this file directly, outside pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()