Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23import numpy as np 

24import pandas as pd 

25import shutil 

26import tempfile 

27import unittest 

28 

29from lsst.ap.association import (PackageAlertsConfig, 

30 PackageAlertsTask, 

31 make_dia_source_schema, 

32 make_dia_object_schema) 

33from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

34import lsst.afw.fits as afwFits 

35import lsst.afw.image as afwImage 

36import lsst.afw.image.utils as afwImageUtils 

37import lsst.daf.base as dafBase 

38from lsst.dax.apdb import Apdb, ApdbConfig 

39import lsst.geom as geom 

40import lsst.meas.base.tests 

41from lsst.utils import getPackageDir 

42import lsst.utils.tests 

43 

44 

def _data_file_name(basename, module_name):
    """Build the path of a file inside a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name appended after the ``data`` directory.
    module_name : `str`
        Name of the lsst stack package; handed to `getPackageDir` to find
        the package root.

    Returns
    -------
    data_file_path : `str`
        The package root joined with ``"data"`` and ``basename``.
    """
    package_root = getPackageDir(module_name)
    return os.path.join(package_root, "data", basename)

62 

63 

def makeDiaObjects(nObjects, exposure):
    """Create a table of test DiaObjects at random positions on an exposure.

    Parameters
    ----------
    nObjects : `int`
        How many DiaObjects to generate.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the bounding box, WCS and visit date used to
        place and time-stamp the objects.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per generated DiaObject; ``diaObjectId`` runs from 0 to
        ``nObjects - 1``.
    """
    box = geom.Box2D(exposure.getBBox())
    # Draw all x positions first, then all y positions, from numpy's
    # global random state.
    xs = np.random.uniform(box.getMinX(), box.getMaxX(), size=nObjects)
    ys = np.random.uniform(box.getMinY(), box.getMaxY(), size=nObjects)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)

    wcs = exposure.getWcs()

    def _makeRow(objectId, x, y):
        # Assemble a single DiaObject record located at pixel (x, y).
        skyPos = wcs.pixelToSky(x, y)
        row = {"ra": skyPos.getRa().asDegrees(),
               "decl": skyPos.getDec().asDegrees(),
               "radecTai": midPointTaiMJD,
               "diaObjectId": objectId,
               "pixelId": 1,
               "pmParallaxNdata": 0,
               "nearbyObj1": 0,
               "nearbyObj2": 0,
               "nearbyObj3": 0,
               "flags": 1,
               "nDiaSources": 5}
        # One zeroed flux-count column per band, appended after the
        # fixed columns.
        row.update({"%sPSFluxNdata" % band: 0
                    for band in ("u", "g", "r", "i", "z", "y")})
        return row

    rows = [_makeRow(objectId, x, y)
            for objectId, (x, y) in enumerate(zip(xs, ys))]
    return pd.DataFrame(data=rows)

108 

109 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray` or other array-like of `int`
        Ids of the DiaObjects to "associate" with the DiaSources. Ids are
        assigned round-robin by position, so source ``idx`` gets
        ``diaObjectIds[idx % len(diaObjectIds)]``.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure.

    Raises
    ------
    ValueError
        Raised if sources are requested but ``diaObjectIds`` is empty.
    """
    # Convert to a plain array so the round-robin lookup below is positional
    # even when a pandas Series with a non-default index is passed in.
    objectIds = np.asarray(diaObjectIds)
    if nSources > 0 and len(objectIds) == 0:
        raise ValueError(
            "diaObjectIds must not be empty when nSources > 0")

    bbox = geom.Box2D(exposure.getBBox())
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)

    wcs = exposure.getWcs()
    ccdVisitId = visitInfo.getExposureId()
    # These are the same for every source, so look them up once.
    filterName = exposure.getFilter().getName()
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = wcs.pixelToSky(x, y)
        objId = objectIds[idx % len(objectIds)]
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objId,
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     # Each successive source is stamped one day later.
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "filterId": 0,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)

165 

166 

def _roundTripThroughApdb(objects, sources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    The catalogs are stored in, and read back from, an Apdb backed by a
    temporary sqlite file that is closed before this function returns.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        ``(diaObjectId, filterName, diaSourceId)``.
    """
    # The context manager guarantees the temporary database file is closed
    # even if any Apdb call below raises.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        pixelRanges = [[minId, maxId + 1]]

        # Put the Apdb's own result tables first and the test rows after
        # them. pd.concat replaces DataFrame.append, which was removed in
        # pandas 2.0.
        diaObjects = pd.concat(
            [apdb.getDiaObjects(pixelRanges, return_pandas=True), objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(np.unique(objects["diaObjectId"]),
                                dateTime,
                                return_pandas=True),
             sources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        # Read everything back so the returned tables are what the Apdb
        # itself produces.
        diaObjects = apdb.getDiaObjects(pixelRanges, return_pandas=True)
        diaSources = apdb.getDiaSources(np.unique(diaObjects["diaObjectId"]),
                                        dateTime,
                                        return_pandas=True)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)

    return (diaObjects, diaSources)

226 

227 

class TestPackageAlerts(unittest.TestCase):
    """Tests for `PackageAlertsTask`: cutout serialization, alert
    dictionary construction and the full ``run`` method.
    """

    def setUp(self):
        """Build the exposure and the DiaObject/DiaSource tables shared by
        the tests.

        Creates a synthetic exposure containing one source and attaches a
        detector, visit info and a ``g`` filter to it. Then makes 2
        DiaObjects and 10 DiaSources, round trips them through an Apdb, and
        splits the sources into the two "new" ones (``self.diaSources``) and
        the remaining history (``self.diaSourceHistory``).
        """
        np.random.seed(1234)
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        # Only the exposure is needed; the realized catalog is discarded.
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.filter_names = ["g"]
        afwImageUtils.resetFilters()
        afwImageUtils.defineFilter('g', lambdaEff=487, alias="g.MP9401")
        self.exposure.setFilter(afwImage.Filter('g'))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        self.diaObjects, diaSourceHistory = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)

        # Extra columns added after the round trip.
        # NOTE(review): presumably fields the alert schema expects but the
        # Apdb tables do not carry - confirm.
        diaSourceHistory["programId"] = 0
        for column in ("ra_decl_Cov",
                       "x_y_Cov",
                       "ps_Cov",
                       "trail_Cov",
                       "dip_Cov",
                       "i_cov"):
            diaSourceHistory[column] = None

        for column in ("ra_decl_Cov",
                       "pm_parallax_Cov",
                       "uLcPeriodic",
                       "gLcPeriodic",
                       "rLcPeriodic",
                       "iLcPeriodic",
                       "zLcPeriodic",
                       "yLcPeriodic",
                       "uLcNonPeriodic",
                       "gLcNonPeriodic",
                       "rLcNonPeriodic",
                       "iLcNonPeriodic",
                       "zLcNonPeriodic",
                       "yLcNonPeriodic"):
            self.diaObjects[column] = None

        # Index keys are (diaObjectId, filterName, diaSourceId). These two
        # rows are pulled out as the new sources; the rest stay as history.
        newSourceKeys = [(0, "g", 8), (1, "g", 9)]
        self.diaSources = diaSourceHistory.loc[newSourceKeys, :]
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceKeys)

    def testMakeCutoutBytes(self):
        """Test round tripping an exposure/cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint, packageAlerts.cutoutBBox)

        cutoutBytes = packageAlerts.makeCutoutBytes(cutout)
        # Load the bytes into an in-memory FITS file and read them back as
        # an exposure.
        tempMemFile = afwFits.MemFileManager(len(cutoutBytes))
        tempMemFile.setData(cutoutBytes, len(cutoutBytes))
        cutoutFromBytes = afwImage.ExposureF(tempMemFile)
        self.assertTrue(
            np.all(cutout.getImage().array == cutoutFromBytes.getImage().array))

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             packageAlerts.cutoutBBox)
            cutoutBytes = packageAlerts.makeCutoutBytes(cutout)
            # srcIdx[0] is the diaObjectId level of the source index.
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                cutout,
                None)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"]["stampData"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             None)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register removal right away so the directory is cleaned up even
        # when an assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.exposure,
                          None,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _, data = packageAlerts.alertSchema.retrieve_alerts(f)
            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                # Alerts are compared with the input sources by position.
                expected = self.diaSources.iloc[idx]
                for key, value in alert["diaSource"].items():
                    if isinstance(value, float):
                        if np.isnan(expected[key]):
                            self.assertTrue(np.isnan(value))
                        else:
                            # Relative comparison of the float values.
                            self.assertAlmostEqual(
                                1 - value / expected[key],
                                0.)
                    else:
                        self.assertEqual(value, expected[key])
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["decl"],
                                            geom.degrees)
                cutout = self.exposure.getCutout(sphPoint,
                                                 packageAlerts.cutoutBBox)
                self.assertEqual(alert["cutoutDifference"]["stampData"],
                                 packageAlerts.makeCutoutBytes(cutout))

381 

382 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    Adds no tests or overrides of its own.
    """

385 

386 

def setup_module(module):
    """Initialize the lsst test utilities for this module.

    Named to match pytest's module-level setup hook.

    Parameters
    ----------
    module : module
        The test module being set up; not used here.
    """
    lsst.utils.tests.init()

389 

390 

if __name__ == "__main__":
    # Run directly as a script: initialize the lsst test utilities, then
    # hand control to unittest.
    lsst.utils.tests.init()
    unittest.main()