Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23import numpy as np 

24import pandas as pd 

25import shutil 

26import tempfile 

27import unittest 

28 

29from lsst.ap.association import (PackageAlertsConfig, 

30 PackageAlertsTask, 

31 make_dia_source_schema, 

32 make_dia_object_schema) 

33from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

34import lsst.afw.fits as afwFits 

35import lsst.afw.image as afwImage 

36import lsst.afw.image.utils as afwImageUtils 

37import lsst.daf.base as dafBase 

38from lsst.dax.apdb import Apdb, ApdbConfig 

39import lsst.geom as geom 

40import lsst.meas.base.tests 

41from lsst.utils import getPackageDir 

42import lsst.utils.tests 

43 

44 

def _data_file_name(basename, module_name):
    """Build the path of a file inside a package's "data" directory.

    Parameters
    ----------
    basename : `str`
        File name to append to the package's "data" directory.
    module_name : `str`
        Name of the lsst stack package whose directory is looked up with
        `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        Path of ``basename`` under the "data" directory of the requested
        package.
    """
    package_root = getPackageDir(module_name)
    data_dir = os.path.join(package_root, "data")
    return os.path.join(data_dir, basename)

62 

63 

def makeDiaObjects(nObjects, exposure):
    """Generate a DataFrame of randomly placed test DiaObjects.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box, WCS and visit date are used to place
        and time-stamp the objects.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per DiaObject, with positions drawn uniformly over the
        exposure bounding box.
    """
    bounds = geom.Box2D(exposure.getBBox())
    # x positions are drawn before y positions; keep this order so the
    # sequence of values taken from the global numpy random state is stable.
    xs = np.random.uniform(bounds.getMinX(), bounds.getMaxX(), size=nObjects)
    ys = np.random.uniform(bounds.getMinY(), bounds.getMaxY(), size=nObjects)

    visitDate = exposure.getInfo().getVisitInfo().getDate()
    midPointTaiMJD = visitDate.get(system=dafBase.DateTime.MJD)

    skyWcs = exposure.getWcs()

    # Every object shares the same placeholder spatial index.
    htmIdx = 1
    bands = ("u", "g", "r", "i", "z", "y")

    rows = []
    for objectId, pixel in enumerate(zip(xs, ys)):
        skyPos = skyWcs.pixelToSky(*pixel)
        row = {"ra": skyPos.getRa().asDegrees(),
               "decl": skyPos.getDec().asDegrees(),
               "radecTai": midPointTaiMJD,
               "diaObjectId": objectId,
               "pixelId": htmIdx,
               "pmParallaxNdata": 0,
               "nearbyObj1": 0,
               "nearbyObj2": 0,
               "nearbyObj3": 0,
               "flags": 1,
               "nDiaSources": 5}
        # Per-band columns are appended after the fixed columns, in band
        # order, which fixes the column order of the resulting DataFrame.
        row.update(("%sPSFluxNdata" % band, 0) for band in bands)
        rows.append(row)

    return pd.DataFrame(data=rows)

108 

109 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : array-like of `int`
        Integer ids of DiaObjects to "associate" with the DiaSources. The
        ids are assigned to the sources cyclically, by position. Any
        sequence convertible with `numpy.asarray` is accepted (e.g. a
        `numpy.ndarray` or a `pandas.Series`).
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure.

    Raises
    ------
    ValueError
        Raised if sources are requested but ``diaObjectIds`` is empty.
    """
    # Convert to a plain array so that the lookup below is positional even
    # when a pandas.Series with a non-default index is passed in.
    objectIds = np.asarray(diaObjectIds)
    nObjectIds = len(objectIds)
    if nSources > 0 and nObjectIds == 0:
        raise ValueError("diaObjectIds must not be empty when nSources > 0")

    bbox = geom.Box2D(exposure.getBBox())
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)

    wcs = exposure.getWcs()
    ccdVisitId = visitInfo.getExposureId()

    # Values shared by every source; computed once rather than per row.
    filterName = exposure.getFilter().getCanonicalName()
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = wcs.pixelToSky(x, y)
        objId = objectIds[idx % nObjectIds]
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objId,
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     # Each successive source is offset by 1.0 from the
                     # visit date so the sources have distinct times.
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "filterId": 0,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)

165 

166 

def _roundTripThroughApdb(objects, sources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        ``(diaObjectId, filterName, diaSourceId)``.
    """
    # The context manager closes the temporary sqlite file explicitly once
    # all database work is done, instead of leaving that to garbage
    # collection.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        pixelRanges = [[minId, maxId + 1]]

        # Query the freshly created tables first and add the test rows to
        # the query results before storing them. pd.concat is used because
        # DataFrame.append is deprecated and absent from pandas >= 2.0.
        diaObjects = pd.concat(
            [apdb.getDiaObjects(pixelRanges, return_pandas=True), objects],
            sort=False)
        diaSources = pd.concat(
            [apdb.getDiaSources(np.unique(objects["diaObjectId"]),
                                dateTime,
                                return_pandas=True),
             sources],
            sort=False)

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        # Read everything back so the returned frames come from the Apdb.
        diaObjects = apdb.getDiaObjects(pixelRanges, return_pandas=True)
        diaSources = apdb.getDiaSources(np.unique(diaObjects["diaObjectId"]),
                                        dateTime,
                                        return_pandas=True)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)

    return (diaObjects, diaSources)

226 

227 

class TestPackageAlerts(unittest.TestCase):
    """Tests of `PackageAlertsTask` using a simulated exposure and
    DiaObject/DiaSource catalogs round tripped through an Apdb.
    """

    def setUp(self):
        """Create the exposure and the catalogs shared by all tests.
        """
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.filter_names = ["g"]
        afwImageUtils.resetFilters()
        afwImageUtils.defineFilter('g', lambdaEff=487, alias="g.MP9401")
        self.exposure.setFilter(afwImage.Filter('g'))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        self.diaObjects, diaSourceHistory = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)

        # Columns added by hand after the round trip; the loops keep the
        # columns in the same order as assigning them one by one.
        diaSourceHistory["programId"] = 0
        for column in ("ra_decl_Cov", "x_y_Cov", "ps_Cov",
                       "trail_Cov", "dip_Cov", "i_cov"):
            diaSourceHistory[column] = None

        for column in ("ra_decl_Cov", "pm_parallax_Cov"):
            self.diaObjects[column] = None
        for suffix in ("LcPeriodic", "LcNonPeriodic"):
            for band in ("u", "g", "r", "i", "z", "y"):
                self.diaObjects[band + suffix] = None

        # The sources with ids 8 and 9 (one per DiaObject) act as the
        # sources being alerted on; the remaining rows are the history.
        alertSourceKeys = [(0, "g", 8), (1, "g", 9)]
        # Copy explicitly so that the column assignment below acts on an
        # independent frame rather than on a selection of the history.
        self.diaSources = diaSourceHistory.loc[alertSourceKeys, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=alertSourceKeys)

    def testCreateBBox(self):
        """Test the bbox creation
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below the configured minimum is expected to come back
        # at the minimum size.
        bbox = packageAlerts.createDiaSourceBBox(packConfig.minCutoutSize - 5)
        self.assertEqual(bbox, geom.Extent2I(packConfig.minCutoutSize,
                                             packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        bbox = packageAlerts.createDiaSourceBBox(self.cutoutSize)
        self.assertEqual(bbox, geom.Extent2I(self.cutoutSize,
                                             self.cutoutSize))

    def testMakeCutoutBytes(self):
        """Test round tripping an exposure/cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))

        cutoutBytes = packageAlerts.makeCutoutBytes(cutout)
        # Load the bytes into an in-memory FITS file and read them back.
        tempMemFile = afwFits.MemFileManager(len(cutoutBytes))
        tempMemFile.setData(cutoutBytes, len(cutoutBytes))
        cutoutFromBytes = afwImage.ExposureF(tempMemFile)
        self.assertTrue(
            np.all(cutout.getImage().array == cutoutFromBytes.getImage().array))

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is the (diaObjectId, filterName, diaSourceId) index.
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            cutoutBytes = packageAlerts.makeCutoutBytes(cutout)
            objSources = self.diaSourceHistory.loc[diaObjectId]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                objSources,
                cutout,
                None)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"]["stampData"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             None)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register the removal up front so the directory is deleted even
        # when one of the assertions below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.exposure,
                          None,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                # Alerts are compared positionally against the input rows.
                expected = self.diaSources.iloc[idx]
                for key, value in alert["diaSource"].items():
                    if isinstance(value, float):
                        if np.isnan(expected[key]):
                            self.assertTrue(np.isnan(value))
                        else:
                            # Floats are compared by relative difference.
                            self.assertAlmostEqual(
                                1 - value / expected[key],
                                0.)
                    else:
                        self.assertEqual(value, expected[key])
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["decl"],
                                            geom.degrees)
                cutout = self.exposure.getCutout(sphPoint,
                                                 geom.Extent2I(self.cutoutSize,
                                                               self.cutoutSize))
                self.assertEqual(alert["cutoutDifference"]["stampData"],
                                 packageAlerts.makeCutoutBytes(cutout))

402 

403 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Test case that adds nothing to `lsst.utils.tests.MemoryTestCase`.
    """

406 

407 

def setup_module(module):
    """Initialize the lsst test utilities for this module.

    Presumably invoked by the test runner through its module-level setup
    convention; verify against the runner in use.

    Parameters
    ----------
    module : module
        Not used by this function.
    """
    lsst.utils.tests.init()

410 

411 

if __name__ == "__main__":
    # Executed as a script: initialize the lsst test utilities, then hand
    # control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()