Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32 

33from lsst.ap.association import (PackageAlertsConfig, 

34 PackageAlertsTask, 

35 make_dia_source_schema, 

36 make_dia_object_schema) 

37from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

38import lsst.afw.image as afwImage 

39import lsst.afw.image.utils as afwImageUtils 

40import lsst.daf.base as dafBase 

41from lsst.dax.apdb import Apdb, ApdbConfig 

42import lsst.geom as geom 

43import lsst.meas.base.tests 

44from lsst.utils import getPackageDir 

45import lsst.utils.tests 

46 

47 

def _data_file_name(basename, module_name):
    """Build the path of a file inside a package's "data" directory.

    Parameters
    ----------
    basename : `str`
        File name to append to the data directory path.
    module_name : `str`
        Name of the lsst stack package, as understood by
        `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        Path of ``basename`` within the "data" directory of the requested
        package.
    """
    package_root = getPackageDir(module_name)
    return os.path.join(package_root, "data", basename)

65 

66 

def makeDiaObjects(nObjects, exposure):
    """Create a catalog of test DiaObjects scattered over an exposure.

    Pixel positions are drawn uniformly from the exposure bounding box using
    the global `numpy.random` state and converted to sky coordinates with the
    exposure's Wcs.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create objects over.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        DiaObjects generated across the exposure.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # x is drawn before y so the random sequence matches for a given seed.
    xPositions = np.random.uniform(
        pixelBox.getMinX(), pixelBox.getMaxX(), size=nObjects)
    yPositions = np.random.uniform(
        pixelBox.getMinY(), pixelBox.getMaxY(), size=nObjects)

    visitDate = exposure.getInfo().getVisitInfo().getDate()
    epochMJD = visitDate.get(system=dafBase.DateTime.MJD)

    skyWcs = exposure.getWcs()

    # Per-band data-point counts, all zero; appended after the base columns.
    bandCounts = {"%sPSFluxNdata" % band: 0
                  for band in ["u", "g", "r", "i", "z", "y"]}

    rows = []
    for objectId, (xPix, yPix) in enumerate(zip(xPositions, yPositions)):
        skyCoord = skyWcs.pixelToSky(xPix, yPix)
        row = {"ra": skyCoord.getRa().asDegrees(),
               "decl": skyCoord.getDec().asDegrees(),
               "radecTai": epochMJD,
               "diaObjectId": objectId,
               # Every test object shares the same fixed spatial index.
               "pixelId": 1,
               "pmParallaxNdata": 0,
               "nearbyObj1": 0,
               "nearbyObj2": 0,
               "nearbyObj3": 0,
               "flags": 1,
               "nDiaSources": 5}
        row.update(bandCounts)
        rows.append(row)

    return pd.DataFrame(data=rows)

111 

112 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Pixel positions are drawn uniformly from the exposure bounding box using
    the global `numpy.random` state. Each source is assigned a DiaObject id
    by cycling through ``diaObjectIds``.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of diaobjects to "associate" with the DiaSources. Must
        be non-empty when ``nSources`` is greater than zero.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure.
    """
    bbox = geom.Box2D(exposure.getBBox())
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()

    # Named ``skyWcs`` so the module-level ``astropy.wcs`` import is not
    # shadowed inside this function.
    skyWcs = exposure.getWcs()

    # Loop invariants: the same filter and id pool apply to every source.
    filterName = exposure.getFilter().getCanonicalName()
    nObjectIds = len(diaObjectIds)

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = skyWcs.pixelToSky(x, y)
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     # Cycle through the available object ids.
                     "diaObjectId": diaObjectIds[idx % nObjectIds],
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     # Every test source shares the same fixed spatial index.
                     "pixelId": 1,
                     # Space successive sources one day apart.
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "filterId": 0,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)

168 

169 

def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    Each forced source is assigned a DiaObject id by cycling through
    ``diaObjectIds``; its ``ccdVisitId`` and ``midPointTai`` are offset from
    the exposure's values by the source index.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of diaobjects to "associate" with the DiaForcedSources.
        Must be non-empty when ``nSources`` is greater than zero.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the visit date, exposure id and filter name.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources generated for the requested DiaObjects.
    """
    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()

    # Loop invariants: the same filter and id pool apply to every source.
    filterName = exposure.getFilter().getCanonicalName()
    nObjectIds = len(diaObjectIds)

    data = []
    for idx in range(nSources):
        # Put together the minimum values for the alert.
        data.append({"diaForcedSourceId": idx,
                     "ccdVisitId": ccdVisitId + idx,
                     # Cycle through the available object ids.
                     "diaObjectId": diaObjectIds[idx % nObjectIds],
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "flags": 0})

    return pd.DataFrame(data=data)

204 

205 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    The catalogs are stored in, then read back from, an Apdb backed by a
    temporary sqlite file.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed on ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed on
        ``(diaObjectId, filterName, diaSourceId)``.
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed on ``diaObjectId``.
    """
    # The context manager closes (and thereby removes) the scratch sqlite
    # file deterministically, including when an Apdb call raises.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        inputObjectIds = np.unique(objects["diaObjectId"])

        # Read from the (still empty) tables first and concatenate the test
        # rows onto the result so the frames carry the Apdb column layout.
        # ``pd.concat`` replaces ``DataFrame.append``, which was removed in
        # pandas 2.0.
        diaObjects = pd.concat(
            [apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True),
             objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(inputObjectIds,
                                dateTime,
                                return_pandas=True),
             sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(inputObjectIds,
                                      dateTime,
                                      return_pandas=True),
             forcedSources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaForcedSources(diaForcedSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        # Read everything back so the returned frames are what the Apdb
        # itself produces.
        diaObjects = apdb.getDiaObjects([[minId, maxId + 1]],
                                        return_pandas=True)
        storedObjectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(storedObjectIds,
                                        dateTime,
                                        return_pandas=True)
        diaForcedSources = apdb.getDiaForcedSources(storedObjectIds,
                                                    dateTime,
                                                    return_pandas=True)

    # Index the frames while keeping the key columns as regular columns.
    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)

278 

279 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests of `PackageAlertsTask` run on a small simulated exposure and
    test catalogs round tripped through an Apdb.
    """

    def setUp(self):
        """Build the test exposure, catalogs and reference cutout Wcs."""
        # Fixed seed so the randomly placed test objects are reproducible.
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        # Only the exposure is needed; the realized catalog is discarded.
        self.exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        detector = DetectorWrapper(
            id=23, bbox=self.exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.filter_names = ["g"]
        afwImageUtils.resetFilters()
        afwImageUtils.defineFilter('g', lambdaEff=487, alias="g.MP9401")
        self.exposure.setFilter(afwImage.Filter('g'))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        (self.diaObjects,
         diaSourceHistory,
         self.diaForcedSources) = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None],
                                      value=np.nan,
                                      inplace=True)
        diaSourceHistory["programId"] = 0

        # One source per object is treated as the "new" DiaSource to alert
        # on; the remaining rows form the history. Keys are
        # (diaObjectId, filterName, diaSourceId).
        alertSourceKeys = [(0, "g", 8), (1, "g", 9)]
        # Explicit copy so adding a column below never writes through to,
        # or warns about, ``diaSourceHistory``.
        self.diaSources = diaSourceHistory.loc[alertSourceKeys, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=alertSourceKeys)

        # Reference astropy Wcs built from the exposure's sky origin and
        # CD matrix.
        expWcs = self.exposure.getWcs()
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            expWcs.getSkyOrigin().getRa().asDegrees(),
            expWcs.getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = expWcs.getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateBBox(self):
        """Test the bbox creation
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below the minimum must come back at the minimum size.
        bbox = packageAlerts.createDiaSourceBBox(packConfig.minCutoutSize - 5)
        self.assertEqual(bbox, geom.Extent2I(packConfig.minCutoutSize,
                                             packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        bbox = packageAlerts.createDiaSourceBBox(self.cutoutSize)
        self.assertEqual(bbox, geom.Extent2I(self.cutoutSize,
                                             self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getPhotoCalib())
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
            self.assertFloatsAlmostEqual(cutoutCcdData.data,
                                         cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # First level of the source index is the diaObjectId.
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[diaObjectId]
            objForcedSources = self.diaForcedSources.loc[diaObjectId]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                objSources,
                objForcedSources,
                ccdCutout,
                None)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            # No template cutout was supplied, so none should be packaged.
            self.assertIsNone(alert["cutoutTemplate"])

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Registered immediately so the directory is removed even when an
        # assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir, ignore_errors=True)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          None,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            # The writer schema is not inspected by this test.
            _, data = packageAlerts.alertSchema.retrieve_alerts(f)
            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                expectedSource = self.diaSources.iloc[idx]
                for key, value in alert["diaSource"].items():
                    expected = expectedSource[key]
                    if isinstance(value, float):
                        if np.isnan(expected):
                            self.assertTrue(np.isnan(value))
                        else:
                            # Compare floats through their relative
                            # difference.
                            self.assertAlmostEqual(
                                1 - value / expected,
                                0.)
                    else:
                        self.assertEqual(value, expected)
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["decl"],
                                            geom.degrees)
                cutout = self.exposure.getCutout(
                    sphPoint,
                    geom.Extent2I(self.cutoutSize,
                                  self.cutoutSize))
                ccdCutout = packageAlerts.createCcdDataCutout(
                    cutout, sphPoint, cutout.getPhotoCalib())
                self.assertEqual(
                    alert["cutoutDifference"],
                    packageAlerts.streamCcdDataToBytes(ccdCutout))

491 

492 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Test case inheriting all of its checks, unchanged, from
    `lsst.utils.tests.MemoryTestCase`.
    """

495 

496 

def setup_module(module):
    """Module-level setup hook: initialise the LSST test utilities.

    Parameters
    ----------
    module : `module`
        Module being set up; not used.
    """
    lsst.utils.tests.init()

499 

500 

if __name__ == "__main__":
    # Direct execution: initialise the LSST test utilities, then hand over
    # to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()