Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32 

33from lsst.ap.association import (PackageAlertsConfig, 

34 PackageAlertsTask, 

35 make_dia_source_schema, 

36 make_dia_object_schema) 

37from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

38import lsst.afw.image as afwImage 

39import lsst.afw.image.utils as afwImageUtils 

40import lsst.daf.base as dafBase 

41from lsst.dax.apdb import Apdb, ApdbConfig 

42import lsst.geom as geom 

43import lsst.meas.base.tests 

44from lsst.utils import getPackageDir 

45import lsst.utils.tests 

46 

47 

def _data_file_name(basename, module_name):
    """Build the path to a file inside a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name to append to the package's data directory.
    module_name : `str`
        Name of the LSST stack package whose installation directory is
        looked up with `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        ``<package dir>/data/<basename>``.
    """
    package_root = getPackageDir(module_name)
    data_dir = os.path.join(package_root, "data")
    return os.path.join(data_dir, basename)

65 

66 

def makeDiaObjects(nObjects, exposure):
    """Build a synthetic DiaObject catalog scattered over an exposure.

    Parameters
    ----------
    nObjects : `int`
        How many DiaObjects to generate.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box, WCS and visit date are used to place
        and time-stamp the objects.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per object, with ``diaObjectId`` running from 0 to
        ``nObjects - 1`` and per-band ``<band>PSFluxNdata`` columns set to 0.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # All x positions are drawn before any y positions, so the global numpy
    # RNG stream is consumed in a fixed order.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(), size=nObjects)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(), size=nObjects)

    visitDate = exposure.getInfo().getVisitInfo().getDate()
    midPointTaiMJD = visitDate.get(system=dafBase.DateTime.MJD)

    # Named to avoid shadowing the module-level ``astropy.wcs`` import.
    exposureWcs = exposure.getWcs()

    # Every test object shares the same placeholder spatial index.
    htmIdx = 1
    bands = ["u", "g", "r", "i", "z", "y"]

    rows = []
    for objectId, (xPix, yPix) in enumerate(zip(xs, ys)):
        skyPos = exposureWcs.pixelToSky(xPix, yPix)
        row = {
            "ra": skyPos.getRa().asDegrees(),
            "decl": skyPos.getDec().asDegrees(),
            "radecTai": midPointTaiMJD,
            "diaObjectId": objectId,
            "pixelId": htmIdx,
            "pmParallaxNdata": 0,
            "nearbyObj1": 0,
            "nearbyObj2": 0,
            "nearbyObj3": 0,
            "flags": 1,
            "nDiaSources": 5,
        }
        # Per-band counters are appended after the fixed columns, in band order.
        row.update({f"{band}PSFluxNdata": 0 for band in bands})
        rows.append(row)

    return pd.DataFrame(data=rows)

111 

112 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of diaobjects to "associate" with the DiaSources. Ids are
        assigned cyclically: source ``idx`` gets
        ``diaObjectIds[idx % len(diaObjectIds)]``.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over. Its bounding box, WCS, visit date,
        exposure id and filter are used to fill the catalog.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure. ``diaSourceId`` runs from
        0 to ``nSources - 1`` and ``midPointTai`` is the exposure's visit MJD
        plus ``1.0 * diaSourceId``.
    """
    bbox = geom.Box2D(exposure.getBBox())
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()

    # Named ``expWcs`` to avoid shadowing the module-level ``astropy.wcs``.
    expWcs = exposure.getWcs()

    # Loop-invariant: look the band up once rather than per source.
    # TODO DM-27170: fix this [0] workaround which gets a
    # single character representation of the band.
    filterName = exposure.getFilter().getCanonicalName()[0]

    # All test sources share one placeholder spatial index.
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objId,
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "filterId": 0,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)

170 

171 

def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of diaobjects to "associate" with the DiaForcedSources,
        assigned cyclically: row ``idx`` gets
        ``diaObjectIds[idx % len(diaObjectIds)]``.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose visit date, exposure id and filter seed the catalog.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources with ``diaForcedSourceId`` running from 0 to
        ``nSources - 1``. Both ``ccdVisitId`` and ``midPointTai`` are offset
        by the row index, so every row has a distinct ``ccdVisitId``.
    """
    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()

    # Loop-invariant: look the band up once rather than per row.
    # TODO DM-27170: fix this [0] workaround which gets a
    # single character representation of the band.
    filterName = exposure.getFilter().getCanonicalName()[0]

    data = []
    for idx in range(nSources):
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert.
        data.append({"diaForcedSourceId": idx,
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objId,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "flags": 0})

    return pd.DataFrame(data=data)

208 

209 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    The function creates a throw-away SQLite Apdb backed by a temporary file.
    It appends the input catalogs to whatever the Apdb returns for the same
    pixel range and object ids, stores the result, and reads it back. The
    returned frames therefore use the Apdb's column set.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip. Must contain ``pixelId`` and
        ``diaObjectId`` columns.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId`` (column kept).
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by (``diaObjectId``, ``filterName``,
        ``diaSourceId``) (columns kept).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId`` (column
        kept).
    """
    # The temporary file backs the SQLite database. Keep it alive for the
    # whole round trip and delete it deterministically once data is read.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        inputObjectIds = np.unique(objects["diaObjectId"])

        # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
        # pd.concat([a, b]) is the equivalent (index-preserving, non-sorting)
        # operation.
        diaObjects = pd.concat(
            [apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True),
             objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(inputObjectIds, dateTime, return_pandas=True),
             sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(inputObjectIds, dateTime,
                                      return_pandas=True),
             forcedSources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaForcedSources(diaForcedSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        # Read everything back so the returned frames use the Apdb schema.
        diaObjects = apdb.getDiaObjects([[minId, maxId + 1]],
                                        return_pandas=True)
        storedObjectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(storedObjectIds,
                                        dateTime,
                                        return_pandas=True)
        diaForcedSources = apdb.getDiaForcedSources(storedObjectIds,
                                                    dateTime,
                                                    return_pandas=True)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)

282 

283 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask`.

    The fixtures are a small synthetic exposure and DiaObject, DiaSource and
    DiaForcedSource catalogs that have been round-tripped through a
    temporary Apdb.
    """

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.filter_names = ["g"]
        afwImageUtils.resetFilters()
        afwImageUtils.defineFilter('g', lambdaEff=487, alias="g.MP9401")
        self.exposure.setFilter(afwImage.Filter('g'))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # Sources 8 and 9 (one per DiaObject, the latest in midPointTai)
        # become the "new" detections; the remaining sources are history.
        newSourceLabels = [(0, "g", 8), (1, "g", 9)]
        # Explicit copy so the column assignment below modifies an
        # independent frame rather than a possible view of the history.
        self.diaSources = diaSourceHistory.loc[newSourceLabels, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceLabels)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateBBox(self):
        """Test the bbox creation
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below minCutoutSize is expected to come back at
        # minCutoutSize.
        extent = packageAlerts.createDiaSourceBBox(packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceBBox(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getPhotoCalib())
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is the (diaObjectId, filterName, diaSourceId) index tuple.
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(ccdCutout)
            objSources = self.diaSourceHistory.loc[diaObjectId]
            objForcedSources = self.diaForcedSources.loc[diaObjectId]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"], cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"], cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register cleanup immediately so the directory is removed even when
        # run() raises or an assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Relative comparison: the avro round trip may
                        # change float precision.
                        self.assertAlmostEqual(1 - value / expected, 0.)
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

496 

497 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Memory test case for this module.

    All behavior is inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`.
    """

500 

501 

def setup_module(module):
    """pytest module-level setup hook; initializes the LSST test utilities.

    Parameters
    ----------
    module : `module`
        The test module being set up (unused).
    """
    lsst.utils.tests.init()

504 

505 

if __name__ == "__main__":
    # Initialize the LSST test utilities first, then hand control to
    # unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()