Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23import numpy as np 

24import pandas as pd 

25import tempfile 

26import unittest 

27import yaml 

28 

29from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

30import lsst.afw.geom as afwGeom 

31import lsst.afw.image as afwImage 

32from lsst.ap.association import LoadDiaCatalogsTask, LoadDiaCatalogsConfig 

33import lsst.daf.base as dafBase 

34from lsst.dax.apdb import ApdbSql, ApdbSqlConfig 

35import lsst.geom as geom 

36from lsst.utils import getPackageDir 

37import lsst.utils.tests 

38 

39 

def _data_file_name(basename, module_name):
    """Build the full path to a file in a stack package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name, relative to the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose root directory is looked up
        with `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        Path of the form ``<package dir>/data/<basename>``.
    """
    package_root = getPackageDir(module_name)
    data_dir = os.path.join(package_root, "data")
    return os.path.join(data_dir, basename)

57 

58 

def makeExposure(flipX=False, flipY=False):
    """Build a test exposure whose WCS may have its x and/or y axis flipped.

    Flipping a pixel axis negates the matching column of the CD matrix.
    The sky polygon around the bounding box can then be either right- or
    left-handed.

    Parameters
    ----------
    flipX : `bool`
        Negate the CD-matrix terms tied to the pixel x axis.
    flipY : `bool`
        Negate the CD-matrix terms tied to the pixel y axis.

    Returns
    -------
    exposure : `lsst.afw.image.Exposure`
        Exposure with a bounding box, WCS, detector, visit info and a
        ``g``-band filter label attached.
    """
    xSign = -1 if flipX else 1
    ySign = -1 if flipY else 1

    metadata = dafBase.PropertySet()

    # Generic FITS header keywords.
    # NOTE(review): NAXIS1/NAXIS2 are 1024/1153 here, but the pixel array
    # below is 1024 rows x 1153 columns (i.e. width 1153); confirm whether
    # makeSkyWcs reads these values at all.
    for key, value in (("SIMPLE", "T"),
                       ("BITPIX", -32),
                       ("NAXIS", 2),
                       ("NAXIS1", 1024),
                       ("NAXIS2", 1153),
                       ("RADECSYS", 'FK5'),
                       ("EQUINOX", 2000.)):
        metadata.set(key, value)

    # Tangent point and reference pixel.
    for key, value in (("CRVAL1", 215.604025685476),
                       ("CRVAL2", 53.1595451514076),
                       ("CRPIX1", 1109.99981456774),
                       ("CRPIX2", 560.018167811613)):
        metadata.setDouble(key, value)
    metadata.set("CTYPE1", 'RA---SIN')
    metadata.set("CTYPE2", 'DEC--SIN')

    # CD matrix; the sign factors implement the optional axis flips.
    for key, value in (("CD1_1", xSign * 5.10808596133527E-05),
                       ("CD1_2", ySign * 1.85579539217196E-07),
                       ("CD2_2", ySign * -5.10281493481982E-05),
                       ("CD2_1", xSign * -8.27440751733828E-07)):
        metadata.setDouble(key, value)

    skyWcs = afwGeom.makeSkyWcs(metadata)
    pixels = afwImage.makeMaskedImageFromArrays(np.ones((1024, 1153)))
    exposure = afwImage.makeExposure(pixels, skyWcs)

    detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
    observationDate = dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                       dafBase.DateTime.Timescale.TAI)
    visitInfo = afwImage.VisitInfo(exposureId=1234,
                                   exposureTime=200.,
                                   date=observationDate)

    exposure.setDetector(detector)
    exposure.getInfo().setVisitInfo(visitInfo)
    exposure.setFilterLabel(afwImage.FilterLabel(band='g'))
    return exposure

119 

120 

def makeDiaObjects(nObjects, exposure):
    """Generate random DiaObjects spread uniformly over an exposure.

    Parameters
    ----------
    nObjects : `int`
        How many DiaObjects to generate.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box and WCS define where objects are placed.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per object. ``diaObjectId`` runs from 0 to ``nObjects - 1``
        and all count columns are zero.
    """
    box = geom.Box2D(exposure.getBBox())
    # Draw x before y so the global RNG sequence is consumed in a fixed order.
    xs = np.random.uniform(box.getMinX(), box.getMaxX(), size=nObjects)
    ys = np.random.uniform(box.getMinY(), box.getMaxY(), size=nObjects)

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)
    skyWcs = exposure.getWcs()

    # Per-band PSF flux counts, appended after the fixed columns.
    bandCounts = {"%sPSFluxNdata" % band: 0 for band in ("u", "g", "r", "i", "z", "y")}

    rows = []
    for objectId, (x, y) in enumerate(zip(xs, ys)):
        sky = skyWcs.pixelToSky(x, y)
        rows.append({
            "ra": sky.getRa().asDegrees(),
            "decl": sky.getDec().asDegrees(),
            "radecTai": visitMjd,
            "diaObjectId": objectId,
            "pmParallaxNdata": 0,
            "nearbyObj1": 0,
            "nearbyObj2": 0,
            "nearbyObj3": 0,
            **bandCounts,
        })

    return pd.DataFrame(data=rows)

161 

162 

def makeDiaSources(nSources, diaObjectIds, exposure):
    """Generate random DiaSources over an exposure, each tied to a DiaObject.

    Parameters
    ----------
    nSources : `int`
        How many DiaSources to generate.
    diaObjectIds : `numpy.ndarray`
        Candidate DiaObject ids. Each source is assigned one of these,
        drawn uniformly with replacement.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box and WCS define where sources are placed.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        One row per source. ``diaSourceId`` runs from 0 to ``nSources - 1``.
    """
    box = geom.Box2D(exposure.getBBox())
    # RNG draw order (x, y, then object ids) is kept fixed for reproducibility.
    xs = np.random.uniform(box.getMinX(), box.getMaxX(), size=nSources)
    ys = np.random.uniform(box.getMinY(), box.getMaxY(), size=nSources)
    chosen = np.random.randint(len(diaObjectIds), size=nSources)
    parentIds = diaObjectIds[chosen]

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)
    skyWcs = exposure.getWcs()

    rows = []
    for sourceId, (x, y, parentId) in enumerate(zip(xs, ys, parentIds)):
        sky = skyWcs.pixelToSky(x, y)
        rows.append({
            "ra": sky.getRa().asDegrees(),
            "decl": sky.getDec().asDegrees(),
            "diaObjectId": parentId,
            "diaSourceId": sourceId,
            "midPointTai": visitMjd,
        })

    return pd.DataFrame(data=rows)

200 

201 

def makeDiaForcedSources(nForcedSources, diaObjectIds, exposure):
    """Generate DiaForcedSources, each tied to a randomly chosen DiaObject.

    Parameters
    ----------
    nForcedSources : `int`
        How many forced sources to generate.
    diaObjectIds : `numpy.ndarray`
        Candidate DiaObject ids. Each forced source is assigned one of these,
        drawn uniformly with replacement.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose visit date is used as ``midPointTai``.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        One row per forced source. ``diaForcedSourceId`` and ``ccdVisitId``
        both equal the row index.
    """
    chosen = np.random.randint(len(diaObjectIds), size=nForcedSources)
    parentIds = diaObjectIds[chosen]

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    rows = [
        {"diaObjectId": parentId,
         "diaForcedSourceId": rowIndex,
         "ccdVisitId": rowIndex,
         "midPointTai": visitMjd}
        for rowIndex, parentId in enumerate(parentIds)
    ]
    return pd.DataFrame(data=rows)

233 

234 

class TestLoadDiaCatalogs(unittest.TestCase):
    """Tests for `LoadDiaCatalogsTask` run against a scratch SQLite APDB.

    The APDB is populated with randomly generated DiaObjects, DiaSources
    and DiaForcedSources.
    """

    def setUp(self):
        np.random.seed(1234)

        # Scratch SQLite file backing the APDB. Register its removal right
        # away: unlike tearDown, cleanups still run when a later step of
        # setUp (schema creation, store, ...) raises, so the file is never
        # left behind in the test directory.
        self.db_file_fd, self.db_file = tempfile.mkstemp(
            dir=os.path.dirname(__file__))
        self.addCleanup(self._removeDbFile)

        self.apdbConfig = ApdbSqlConfig()
        self.apdbConfig.db_url = "sqlite:///" + self.db_file
        self.apdbConfig.dia_object_index = "baseline"
        self.apdbConfig.dia_object_columns = []
        self.apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        self.apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        self.apdb = ApdbSql(config=self.apdbConfig)
        self.apdb.makeSchema()

        self.exposure = makeExposure(False, False)

        self.diaObjects = makeDiaObjects(20, self.exposure)
        self.diaSources = makeDiaSources(
            100,
            self.diaObjects["diaObjectId"].to_numpy(),
            self.exposure)
        self.diaForcedSources = makeDiaForcedSources(
            200,
            self.diaObjects["diaObjectId"].to_numpy(),
            self.exposure)

        self.dateTime = self.exposure.getInfo().getVisitInfo().getDate()
        self.apdb.store(self.dateTime,
                        self.diaObjects,
                        self.diaSources,
                        self.diaForcedSources)

        # These columns are not in the DPDD, yet do appear in DiaSource.yaml.
        # We don't need to check them against the default APDB schema.
        self.ignoreColumns = ["filterName", "bboxSize", "isDipole"]

    def _removeDbFile(self):
        """Close and delete the temporary database file created in `setUp`."""
        os.close(self.db_file_fd)
        os.remove(self.db_file)

    def testRun(self):
        """Test the full run method for the loader.
        """
        diaLoader = LoadDiaCatalogsTask()
        result = diaLoader.run(self.exposure, self.apdb)

        self.assertEqual(len(result.diaObjects), len(self.diaObjects))
        self.assertEqual(len(result.diaSources), len(self.diaSources))
        self.assertEqual(len(result.diaForcedSources),
                         len(self.diaForcedSources))

    def testLoadDiaObjects(self):
        """Test that the correct number of diaObjects are loaded.
        """
        diaLoader = LoadDiaCatalogsTask()
        region = diaLoader._getRegion(self.exposure)
        diaObjects = diaLoader.loadDiaObjects(region,
                                              self.apdb)
        self.assertEqual(len(diaObjects), len(self.diaObjects))

    def testLoadDiaForcedSources(self):
        """Test that the correct number of diaForcedSources are loaded.
        """
        diaLoader = LoadDiaCatalogsTask()
        region = diaLoader._getRegion(self.exposure)
        diaForcedSources = diaLoader.loadDiaForcedSources(
            self.diaObjects,
            region,
            self.dateTime,
            self.apdb)
        self.assertEqual(len(diaForcedSources), len(self.diaForcedSources))

    def testLoadDiaSources(self):
        """Test that the correct number of diaSources are loaded.

        Also check that they can be properly loaded both by location and
        ``diaObjectId``.
        """
        diaConfig = LoadDiaCatalogsConfig()
        diaLoader = LoadDiaCatalogsTask(config=diaConfig)

        region = diaLoader._getRegion(self.exposure)
        diaSources = diaLoader.loadDiaSources(self.diaObjects,
                                              region,
                                              self.dateTime,
                                              self.apdb)
        self.assertEqual(len(diaSources), len(self.diaSources))

    def test_apdbSchema(self):
        """Test that the DiaSource columns defined in
        ap_association/data/DiaSource.yaml (minus ``self.ignoreColumns``)
        all exist in the default DiaSource schema from dax_apdb.
        """
        functorFile = _data_file_name("DiaSource.yaml", "ap_association")
        apdbSchemaFile = _data_file_name("apdb-schema.yaml", "dax_apdb")

        with open(apdbSchemaFile, encoding="utf-8") as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
        # Take the first DiaSource table. Fail with a clear message, rather
        # than a NameError, if the schema file has none.
        diaSourceTable = next(
            (table for table in table_list if table['table'] == 'DiaSource'),
            None)
        self.assertIsNotNone(
            diaSourceTable,
            f"No 'DiaSource' table found in {apdbSchemaFile}")
        apdbSchemaColumns = {column['name'] for column in diaSourceTable['columns']}

        with open(functorFile, encoding="utf-8") as yaml_stream:
            functorDocs = list(yaml.safe_load_all(yaml_stream))
        self.assertTrue(functorDocs,
                        f"No functor definitions found in {functorFile}")

        # Check every YAML document in the functor file, not only the last one.
        for functor in functorDocs:
            diaSourceColumns = {column for column in functor['funcs']
                                if column not in self.ignoreColumns}
            # A plain subset check: the APDB may have extra columns the
            # functors do not produce, but every functor column must exist.
            missing = diaSourceColumns - apdbSchemaColumns
            self.assertFalse(
                missing,
                f"DiaSource functor columns missing from the APDB schema: "
                f"{sorted(missing)}")

347 

348 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Hook for the standard `lsst.utils.tests.MemoryTestCase` checks.

    Adds nothing of its own; subclassing is enough to enable the inherited
    tests for this module.
    """

351 

352 

def setup_module(module):
    """Module-level setup hook (pytest xunit style).

    Initializes the LSST test utilities once before this module's tests run.
    The ``module`` argument is supplied by the test runner and is not used.
    """
    lsst.utils.tests.init()

355 

356 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities first, then let
    # unittest discover and run the test cases in this module.
    lsst.utils.tests.init()
    unittest.main()