Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32 

33from lsst.ap.association import (PackageAlertsConfig, 

34 PackageAlertsTask, 

35 make_dia_source_schema, 

36 make_dia_object_schema) 

37from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

38import lsst.afw.image as afwImage 

39import lsst.afw.image.utils as afwImageUtils 

40import lsst.daf.base as dafBase 

41from lsst.dax.apdb import Apdb, ApdbConfig 

42import lsst.geom as geom 

43import lsst.meas.base.tests 

44from lsst.utils import getPackageDir 

45import lsst.utils.tests 

46 

47 

48def _data_file_name(basename, module_name): 

49 """Return path name of a data file. 

50 

51 Parameters 

52 ---------- 

53 basename : `str` 

54 Name of the file to add to the path string. 

55 module_name : `str` 

56 Name of lsst stack package environment variable. 

57 

58 Returns 

59 ------- 

60 data_file_path : `str` 

61 Full path of the file to load from the "data" directory in a given 

62 repository. 

63 """ 

64 return os.path.join(getPackageDir(module_name), "data", basename) 

65 

66 

67def makeDiaObjects(nObjects, exposure): 

68 """Make a test set of DiaObjects. 

69 

70 Parameters 

71 ---------- 

72 nObjects : `int` 

73 Number of objects to create. 

74 exposure : `lsst.afw.image.Exposure` 

75 Exposure to create objects over. 

76 

77 Returns 

78 ------- 

79 diaObjects : `pandas.DataFrame` 

80 DiaObjects generated across the exposure. 

81 """ 

82 bbox = geom.Box2D(exposure.getBBox()) 

83 rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nObjects) 

84 rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nObjects) 

85 

86 midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get( 

87 system=dafBase.DateTime.MJD) 

88 

89 wcs = exposure.getWcs() 

90 

91 data = [] 

92 for idx, (x, y) in enumerate(zip(rand_x, rand_y)): 

93 coord = wcs.pixelToSky(x, y) 

94 htmIdx = 1 

95 newObject = {"ra": coord.getRa().asDegrees(), 

96 "decl": coord.getDec().asDegrees(), 

97 "radecTai": midPointTaiMJD, 

98 "diaObjectId": idx, 

99 "pixelId": htmIdx, 

100 "pmParallaxNdata": 0, 

101 "nearbyObj1": 0, 

102 "nearbyObj2": 0, 

103 "nearbyObj3": 0, 

104 "flags": 1, 

105 "nDiaSources": 5} 

106 for f in ["u", "g", "r", "i", "z", "y"]: 

107 newObject["%sPSFluxNdata" % f] = 0 

108 data.append(newObject) 

109 

110 return pd.DataFrame(data=data) 

111 

112 

113def makeDiaSources(nSources, diaObjectIds, exposure): 

114 """Make a test set of DiaSources. 

115 

116 Parameters 

117 ---------- 

118 nSources : `int` 

119 Number of sources to create. 

120 diaObjectIds : `numpy.ndarray` 

121 Integer Ids of diaobjects to "associate" with the DiaSources. 

122 exposure : `lsst.afw.image.Exposure` 

123 Exposure to create sources over. 

124 pixelator : `lsst.sphgeom.HtmPixelization` 

125 Object to compute spatial indicies from. 

126 

127 Returns 

128 ------- 

129 diaSources : `pandas.DataFrame` 

130 DiaSources generated across the exposure. 

131 """ 

132 bbox = geom.Box2D(exposure.getBBox()) 

133 rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources) 

134 rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources) 

135 

136 midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get( 

137 system=dafBase.DateTime.MJD) 

138 

139 wcs = exposure.getWcs() 

140 ccdVisitId = exposure.getInfo().getVisitInfo().getExposureId() 

141 

142 data = [] 

143 for idx, (x, y) in enumerate(zip(rand_x, rand_y)): 

144 coord = wcs.pixelToSky(x, y) 

145 htmIdx = 1 

146 objId = diaObjectIds[idx % len(diaObjectIds)] 

147 # Put together the minimum values for the alert. 

148 data.append({"ra": coord.getRa().asDegrees(), 

149 "decl": coord.getDec().asDegrees(), 

150 "x": x, 

151 "y": y, 

152 "ccdVisitId": ccdVisitId, 

153 "diaObjectId": objId, 

154 "ssObjectId": 0, 

155 "parentDiaSourceId": 0, 

156 "prv_procOrder": 0, 

157 "diaSourceId": idx, 

158 "pixelId": htmIdx, 

159 "midPointTai": midPointTaiMJD + 1.0 * idx, 

160 # TODO DM-21333: Remove [0] (first character only) workaround 

161 "filterName": exposure.getFilter().getCanonicalName()[0], 

162 "filterId": 0, 

163 "psNdata": 0, 

164 "trailNdata": 0, 

165 "dipNdata": 0, 

166 "flags": 1}) 

167 

168 return pd.DataFrame(data=data) 

169 

170 

171def makeDiaForcedSources(nSources, diaObjectIds, exposure): 

172 """Make a test set of DiaSources. 

173 

174 Parameters 

175 ---------- 

176 nSources : `int` 

177 Number of sources to create. 

178 diaObjectIds : `numpy.ndarray` 

179 Integer Ids of diaobjects to "associate" with the DiaSources. 

180 exposure : `lsst.afw.image.Exposure` 

181 Exposure to create sources over. 

182 

183 Returns 

184 ------- 

185 diaSources : `pandas.DataFrame` 

186 DiaSources generated across the exposure. 

187 """ 

188 midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get( 

189 system=dafBase.DateTime.MJD) 

190 

191 ccdVisitId = exposure.getInfo().getVisitInfo().getExposureId() 

192 

193 data = [] 

194 for idx in range(nSources): 

195 objId = diaObjectIds[idx % len(diaObjectIds)] 

196 # Put together the minimum values for the alert. 

197 data.append({"diaForcedSourceId": idx, 

198 "ccdVisitId": ccdVisitId + idx, 

199 "diaObjectId": objId, 

200 "midPointTai": midPointTaiMJD + 1.0 * idx, 

201 # TODO DM-21333: Remove [0] (first character only) workaround 

202 "filterName": exposure.getFilter().getCanonicalName()[0], 

203 "flags": 0}) 

204 

205 return pd.DataFrame(data=data) 

206 

207 

208def _roundTripThroughApdb(objects, sources, forcedSources, dateTime): 

209 """Run object and source catalogs through the Apdb to get the correct 

210 table schemas. 

211 

212 Parameters 

213 ---------- 

214 objects : `pandas.DataFrame` 

215 Set of test DiaObjects to round trip. 

216 sources : `pandas.DataFrame` 

217 Set of test DiaSources to round trip. 

218 forcedSources : `pandas.DataFrame` 

219 Set of test DiaForcedSources to round trip. 

220 dateTime : `datetime.datetime` 

221 Time for the Apdb. 

222 

223 Returns 

224 ------- 

225 objects : `pandas.DataFrame` 

226 Round tripped objects. 

227 sources : `pandas.DataFrame` 

228 Round tripped sources. 

229 """ 

230 tmpFile = tempfile.NamedTemporaryFile() 

231 

232 apdbConfig = ApdbConfig() 

233 apdbConfig.db_url = "sqlite:///" + tmpFile.name 

234 apdbConfig.isolation_level = "READ_UNCOMMITTED" 

235 apdbConfig.dia_object_index = "baseline" 

236 apdbConfig.dia_object_columns = [] 

237 apdbConfig.schema_file = _data_file_name( 

238 "apdb-schema.yaml", "dax_apdb") 

239 apdbConfig.column_map = _data_file_name( 

240 "apdb-ap-pipe-afw-map.yaml", "ap_association") 

241 apdbConfig.extra_schema_file = _data_file_name( 

242 "apdb-ap-pipe-schema-extra.yaml", "ap_association") 

243 

244 apdb = Apdb(config=apdbConfig, 

245 afw_schemas=dict(DiaObject=make_dia_object_schema(), 

246 DiaSource=make_dia_source_schema())) 

247 apdb.makeSchema() 

248 

249 minId = objects["pixelId"].min() 

250 maxId = objects["pixelId"].max() 

251 diaObjects = apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True).append(objects) 

252 diaSources = apdb.getDiaSources(np.unique(objects["diaObjectId"]), 

253 dateTime, 

254 return_pandas=True).append(sources) 

255 diaForcedSources = apdb.getDiaForcedSources( 

256 np.unique(objects["diaObjectId"]), 

257 dateTime, 

258 return_pandas=True).append(forcedSources) 

259 

260 apdb.storeDiaSources(diaSources) 

261 apdb.storeDiaForcedSources(diaForcedSources) 

262 apdb.storeDiaObjects(diaObjects, dateTime) 

263 

264 diaObjects = apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True) 

265 diaSources = apdb.getDiaSources(np.unique(diaObjects["diaObjectId"]), 

266 dateTime, 

267 return_pandas=True) 

268 diaForcedSources = apdb.getDiaForcedSources( 

269 np.unique(diaObjects["diaObjectId"]), 

270 dateTime, 

271 return_pandas=True) 

272 

273 diaObjects.set_index("diaObjectId", drop=False, inplace=True) 

274 diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"], 

275 drop=False, 

276 inplace=True) 

277 diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True) 

278 

279 return (diaObjects, diaSources, diaForcedSources) 

280 

281 

282class TestPackageAlerts(lsst.utils.tests.TestCase): 

283 

284 def setUp(self): 

285 np.random.seed(1234) 

286 self.cutoutSize = 35 

287 self.center = lsst.geom.Point2D(50.1, 49.8) 

288 self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30), 

289 lsst.geom.Extent2I(140, 160)) 

290 self.dataset = lsst.meas.base.tests.TestDataset(self.bbox) 

291 self.dataset.addSource(100000.0, self.center) 

292 exposure, catalog = self.dataset.realize( 

293 10.0, 

294 self.dataset.makeMinimalSchema(), 

295 randomSeed=0) 

296 self.exposure = exposure 

297 detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector 

298 self.exposure.setDetector(detector) 

299 

300 visit = afwImage.VisitInfo( 

301 exposureId=1234, 

302 exposureTime=200., 

303 date=dafBase.DateTime("2014-05-13T17:00:00.000000000", 

304 dafBase.DateTime.Timescale.TAI)) 

305 self.exposure.getInfo().setVisitInfo(visit) 

306 

307 self.filter_names = ["g"] 

308 afwImageUtils.resetFilters() 

309 afwImageUtils.defineFilter('g', lambdaEff=487, alias="g.MP9401") 

310 self.exposure.setFilter(afwImage.Filter('g')) 

311 

312 diaObjects = makeDiaObjects(2, self.exposure) 

313 diaSourceHistory = makeDiaSources(10, 

314 diaObjects["diaObjectId"], 

315 self.exposure) 

316 diaForcedSources = makeDiaForcedSources(10, 

317 diaObjects["diaObjectId"], 

318 self.exposure) 

319 self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb( 

320 diaObjects, 

321 diaSourceHistory, 

322 diaForcedSources, 

323 self.exposure.getInfo().getVisitInfo().getDate().toPython()) 

324 self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True) 

325 diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True) 

326 self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True) 

327 diaSourceHistory["programId"] = 0 

328 

329 self.diaSources = diaSourceHistory.loc[ 

330 [(0, "g", 8), (1, "g", 9)], :] 

331 self.diaSources["bboxSize"] = self.cutoutSize 

332 self.diaSourceHistory = diaSourceHistory.drop(labels=[(0, "g", 8), 

333 (1, "g", 9)]) 

334 

335 self.cutoutWcs = wcs.WCS(naxis=2) 

336 self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]] 

337 self.cutoutWcs.wcs.crval = [ 

338 self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(), 

339 self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()] 

340 self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix() 

341 self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] 

342 

343 def testCreateBBox(self): 

344 """Test the bbox creation 

345 """ 

346 packConfig = PackageAlertsConfig() 

347 # Just create a minimum less than the default cutout. 

348 packConfig.minCutoutSize = self.cutoutSize - 5 

349 packageAlerts = PackageAlertsTask(config=packConfig) 

350 bbox = packageAlerts.createDiaSourceBBox(packConfig.minCutoutSize - 5) 

351 self.assertTrue(bbox == geom.Extent2I(packConfig.minCutoutSize, 

352 packConfig.minCutoutSize)) 

353 # Test that the cutout size is correct. 

354 bbox = packageAlerts.createDiaSourceBBox(self.cutoutSize) 

355 self.assertTrue(bbox == geom.Extent2I(self.cutoutSize, 

356 self.cutoutSize)) 

357 

358 def testCreateCcdDataCutout(self): 

359 """Test that the data is being extracted into the CCDData cutout 

360 correctly. 

361 """ 

362 packageAlerts = PackageAlertsTask() 

363 

364 ccdData = packageAlerts.createCcdDataCutout( 

365 self.exposure, 

366 self.exposure.getWcs().getSkyOrigin(), 

367 self.exposure.getPhotoCalib()) 

368 calibExposure = self.exposure.getPhotoCalib().calibrateImage( 

369 self.exposure.getMaskedImage()) 

370 

371 self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd, 

372 self.cutoutWcs.wcs.cd) 

373 self.assertFloatsAlmostEqual(ccdData.data, 

374 calibExposure.getImage().array) 

375 

376 def testMakeLocalTransformMatrix(self): 

377 """Test that the local WCS approximation is correct. 

378 """ 

379 packageAlerts = PackageAlertsTask() 

380 

381 sphPoint = self.exposure.getWcs().pixelToSky(self.center) 

382 cutout = self.exposure.getCutout(sphPoint, 

383 geom.Extent2I(self.cutoutSize, 

384 self.cutoutSize)) 

385 cd = packageAlerts.makeLocalTransformMatrix( 

386 cutout.getWcs(), self.center, sphPoint) 

387 self.assertFloatsAlmostEqual( 

388 cd, 

389 cutout.getWcs().getCdMatrix(), 

390 rtol=1e-11, 

391 atol=1e-11) 

392 

393 def testStreamCcdDataToBytes(self): 

394 """Test round tripping an CCDData cutout to bytes and back. 

395 """ 

396 packageAlerts = PackageAlertsTask() 

397 

398 sphPoint = self.exposure.getWcs().pixelToSky(self.center) 

399 cutout = self.exposure.getCutout(sphPoint, 

400 geom.Extent2I(self.cutoutSize, 

401 self.cutoutSize)) 

402 cutoutCcdData = CCDData( 

403 data=cutout.getImage().array, 

404 wcs=self.cutoutWcs, 

405 unit="adu") 

406 

407 cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData) 

408 with io.BytesIO(cutoutBytes) as bytesIO: 

409 cutoutFromBytes = CCDData.read(bytesIO, format="fits") 

410 self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data) 

411 

412 def testMakeAlertDict(self): 

413 """Test stripping data from the various data products and into a 

414 dictionary "alert". 

415 """ 

416 packageAlerts = PackageAlertsTask() 

417 alertId = 1234 

418 

419 for srcIdx, diaSource in self.diaSources.iterrows(): 

420 sphPoint = geom.SpherePoint(diaSource["ra"], 

421 diaSource["decl"], 

422 geom.degrees) 

423 cutout = self.exposure.getCutout(sphPoint, 

424 geom.Extent2I(self.cutoutSize, 

425 self.cutoutSize)) 

426 ccdCutout = packageAlerts.createCcdDataCutout( 

427 cutout, sphPoint, cutout.getPhotoCalib()) 

428 cutoutBytes = packageAlerts.streamCcdDataToBytes( 

429 ccdCutout) 

430 objSources = self.diaSourceHistory.loc[srcIdx[0]] 

431 objForcedSources = self.diaForcedSources.loc[srcIdx[0]] 

432 alert = packageAlerts.makeAlertDict( 

433 alertId, 

434 diaSource, 

435 self.diaObjects.loc[srcIdx[0]], 

436 objSources, 

437 objForcedSources, 

438 ccdCutout, 

439 ccdCutout) 

440 self.assertEqual(len(alert), 9) 

441 

442 self.assertEqual(alert["alertId"], alertId) 

443 self.assertEqual(alert["diaSource"], diaSource.to_dict()) 

444 self.assertEqual(alert["cutoutDifference"], 

445 cutoutBytes) 

446 self.assertEqual(alert["cutoutTemplate"], 

447 cutoutBytes) 

448 

449 def testRun(self): 

450 """Test the run method of package alerts. 

451 """ 

452 packConfig = PackageAlertsConfig() 

453 tempdir = tempfile.mkdtemp(prefix='alerts') 

454 packConfig.alertWriteLocation = tempdir 

455 packageAlerts = PackageAlertsTask(config=packConfig) 

456 

457 packageAlerts.run(self.diaSources, 

458 self.diaObjects, 

459 self.diaSourceHistory, 

460 self.diaForcedSources, 

461 self.exposure, 

462 self.exposure, 

463 None) 

464 

465 ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId() 

466 with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f: 

467 writer_schema, data_stream = \ 

468 packageAlerts.alertSchema.retrieve_alerts(f) 

469 data = list(data_stream) 

470 self.assertEqual(len(data), len(self.diaSources)) 

471 for idx, alert in enumerate(data): 

472 for key, value in alert["diaSource"].items(): 

473 if isinstance(value, float): 

474 if np.isnan(self.diaSources.iloc[idx][key]): 

475 self.assertTrue(np.isnan(value)) 

476 else: 

477 self.assertAlmostEqual( 

478 1 - value / self.diaSources.iloc[idx][key], 

479 0.) 

480 else: 

481 self.assertEqual(value, self.diaSources.iloc[idx][key]) 

482 sphPoint = geom.SpherePoint(alert["diaSource"]["ra"], 

483 alert["diaSource"]["decl"], 

484 geom.degrees) 

485 cutout = self.exposure.getCutout(sphPoint, 

486 geom.Extent2I(self.cutoutSize, 

487 self.cutoutSize)) 

488 ccdCutout = packageAlerts.createCcdDataCutout( 

489 cutout, sphPoint, cutout.getPhotoCalib()) 

490 self.assertEqual(alert["cutoutDifference"], 

491 packageAlerts.streamCcdDataToBytes(ccdCutout)) 

492 

493 shutil.rmtree(tempdir) 

494 

495 

496class MemoryTester(lsst.utils.tests.MemoryTestCase): 

497 pass 

498 

499 

500def setup_module(module): 

501 lsst.utils.tests.init() 

502 

503 

504if __name__ == "__main__": 504 ↛ 505line 504 didn't jump to line 505, because the condition on line 504 was never true

505 lsst.utils.tests.init() 

506 unittest.main()