Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32 

33from lsst.ap.association import (PackageAlertsConfig, 

34 PackageAlertsTask, 

35 make_dia_source_schema, 

36 make_dia_object_schema) 

37from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

38import lsst.afw.image as afwImage 

39import lsst.afw.image.utils as afwImageUtils 

40import lsst.daf.base as dafBase 

41from lsst.dax.apdb import Apdb, ApdbConfig 

42import lsst.geom as geom 

43import lsst.meas.base.tests 

44from lsst.utils import getPackageDir 

45import lsst.utils.tests 

46 

47 

def _data_file_name(basename, module_name):
    """Build the path to a file inside a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name, relative to the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose directory is looked up with
        `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        ``<package directory>/data/<basename>``.
    """
    package_root = getPackageDir(module_name)
    data_dir = os.path.join(package_root, "data")
    return os.path.join(data_dir, basename)

65 

66 

def makeDiaObjects(nObjects, exposure):
    """Make a test set of DiaObjects.

    Pixel positions are drawn uniformly with ``np.random`` over the
    exposure's bounding box and converted to sky coordinates with the
    exposure's WCS.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create objects over.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        DiaObjects generated across the exposure, one row per object, with
        ``diaObjectId`` running from 0 to ``nObjects - 1``.
    """
    bbox = geom.Box2D(exposure.getBBox())
    # Draw x before y: the tests seed np.random, so the draw order fixes
    # the generated positions.
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nObjects)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nObjects)

    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    # Named ``expWcs`` so it does not shadow the module-level ``astropy.wcs``
    # import used elsewhere in this file.
    expWcs = exposure.getWcs()
    # Every object gets the same fixed pixelId.
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        newObject = {"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "radecTai": midPointTaiMJD,
                     "diaObjectId": idx,
                     "pixelId": htmIdx,
                     "pmParallaxNdata": 0,
                     "nearbyObj1": 0,
                     "nearbyObj2": 0,
                     "nearbyObj3": 0,
                     "flags": 1,
                     "nDiaSources": 5}
        # One zeroed per-band data count column for each of ugrizy.
        for f in ["u", "g", "r", "i", "z", "y"]:
            newObject[f"{f}PSFluxNdata"] = 0
        data.append(newObject)

    return pd.DataFrame(data=data)

111 

112 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Pixel positions are drawn uniformly with ``np.random`` over the
    exposure's bounding box and converted to sky coordinates with the
    exposure's WCS.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer ids of DiaObjects to "associate" with the DiaSources. They
        are assigned round-robin by position: source ``i`` gets
        ``diaObjectIds[i % len(diaObjectIds)]``. A Series is indexed by
        position, not by label.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure. ``diaSourceId`` runs from 0
        to ``nSources - 1`` and ``midPointTai`` increases by 1.0 (MJD days)
        per source, starting at the exposure's date.
    """
    bbox = geom.Box2D(exposure.getBBox())
    # Draw x before y: the tests seed np.random, so the draw order fixes
    # the generated positions.
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    # Named ``expWcs`` so it does not shadow the module-level ``astropy.wcs``
    # import used elsewhere in this file.
    expWcs = exposure.getWcs()
    ccdVisitId = exposure.getInfo().getVisitInfo().getExposureId()
    filterName = exposure.getFilter().getCanonicalName()
    # Positional view of the ids: indexing a pandas Series with ``[int]``
    # would be label-based.
    objIds = np.asarray(diaObjectIds)
    # Every source gets the same fixed pixelId.
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        objId = objIds[idx % len(objIds)]
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objId,
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "filterId": 0,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)

168 

169 

def _roundTripThroughApdb(objects, sources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    A throwaway sqlite-backed Apdb is created in a temporary file, the
    catalogs are stored into it and then read back.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId`` (column kept).
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``filterName``, ``diaSourceId``) (columns kept).
    """
    # The context manager deletes the backing sqlite file on exit, so all
    # Apdb work happens inside it.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        # Concatenate whatever the database returns for these pixel / object
        # ids with the test rows before storing. ``pd.concat`` replaces
        # ``DataFrame.append``, which was removed in pandas 2.0.
        diaObjects = pd.concat(
            [apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True),
             objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(np.unique(objects["diaObjectId"]),
                                dateTime,
                                return_pandas=True),
             sources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        # Read back so the returned frames come out of the Apdb.
        diaObjects = apdb.getDiaObjects([[minId, maxId + 1]],
                                        return_pandas=True)
        diaSources = apdb.getDiaSources(np.unique(diaObjects["diaObjectId"]),
                                        dateTime,
                                        return_pandas=True)
        diaObjects.set_index("diaObjectId", drop=False, inplace=True)
        diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                             drop=False,
                             inplace=True)

        return (diaObjects, diaSources)

229 

230 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask` and its cutout / alert helpers."""

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.filter_names = ["g"]
        afwImageUtils.resetFilters()
        afwImageUtils.defineFilter('g', lambdaEff=487, alias="g.MP9401")
        self.exposure.setFilter(afwImage.Filter('g'))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        self.diaObjects, diaSourceHistory = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # makeDiaSources assigns objects round-robin, so diaSourceIds 8 and 9
        # are the latest sources of objects 0 and 1. They become the new
        # sources under test; the rest are the history.
        newSourceIdx = [(0, "g", 8), (1, "g", 9)]
        # Explicit copy so the ``bboxSize`` assignment acts on an independent
        # frame (avoids pandas' SettingWithCopyWarning).
        self.diaSources = diaSourceHistory.loc[newSourceIdx, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceIdx)

        # Reference astropy WCS: TAN projection centred on self.center with
        # the exposure's sky origin and CD matrix.
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateBBox(self):
        """Test the bbox creation.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below minCutoutSize must come back at the minimum.
        extent = packageAlerts.createDiaSourceBBox(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceBBox(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getPhotoCalib())
        # Expected pixels: the exposure calibrated with its own PhotoCalib.
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping a CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        # srcIdx is the (diaObjectId, filterName, diaSourceId) index tuple.
        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                ccdCutout,
                None)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             None)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register cleanup immediately so the directory is removed even when
        # run() or one of the assertions below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.exposure,
                          None,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _, data = packageAlerts.alertSchema.retrieve_alerts(f)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    elif expected == 0:
                        # A relative difference is undefined at zero.
                        self.assertAlmostEqual(value, 0.)
                    else:
                        self.assertAlmostEqual(1 - value / expected, 0.)
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

434 

435 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Inherits all of its tests from `lsst.utils.tests.MemoryTestCase`
    (presumably a resource/leak check; see that class)."""

438 

439 

def setup_module(module):
    """Module-level setup hook (pytest ``setup_module`` convention).

    Parameters
    ----------
    module : `module`
        The test module being set up; not used.
    """
    lsst.utils.tests.init()

442 

443 

if __name__ == "__main__":
    # Initialise the LSST test utilities, then hand control to unittest.
    lsst.utils.tests.init()
    unittest.main()