Coverage for tests/test_diaPipe.py: 21%

128 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-01 17:01 -0700

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23import warnings 

24 

25import numpy as np 

26import pandas as pd 

27 

28import lsst.afw.image as afwImage 

29import lsst.afw.table as afwTable 

30from lsst.pipe.base.testUtils import assertValidOutput 

31from utils_tests import makeExposure, makeDiaObjects 

32import lsst.utils.tests 

33import lsst.utils.timer 

34from unittest.mock import patch, Mock, MagicMock, DEFAULT 

35 

36from lsst.ap.association import DiaPipelineTask 

37 

38 

39def _makeMockDataFrame(): 

40 """Create a new mock of a DataFrame. 

41 

42 Returns 

43 ------- 

44 mock : `unittest.mock.Mock` 

45 A mock guaranteed to accept all operations used by `pandas.DataFrame`. 

46 """ 

47 with warnings.catch_warnings(): 

48 # spec triggers deprecation warnings on DataFrame, but will 

49 # automatically adapt to any removals. 

50 warnings.simplefilter("ignore", category=DeprecationWarning) 

51 return MagicMock(spec=pd.DataFrame()) 

52 

53 

class TestDiaPipelineTask(unittest.TestCase):
    """Unit tests for `lsst.ap.association.DiaPipelineTask`."""

    @classmethod
    def _makeDefaultConfig(cls,
                           doPackageAlerts=False,
                           doSolarSystemAssociation=False):
        """Build a minimal `DiaPipelineTask` config for testing.

        The APDB URL is set to ``sqlite://`` (no file path, i.e. an in-memory
        database in SQLAlchemy's URL convention).

        Parameters
        ----------
        doPackageAlerts : `bool`, optional
            Value assigned to ``config.doPackageAlerts``.
        doSolarSystemAssociation : `bool`, optional
            Value assigned to ``config.doSolarSystemAssociation``.

        Returns
        -------
        config : `DiaPipelineTask.ConfigClass`
            The populated config instance.
        """
        config = DiaPipelineTask.ConfigClass()
        config.apdb.db_url = "sqlite://"
        config.doPackageAlerts = doPackageAlerts
        config.doSolarSystemAssociation = doSolarSystemAssociation
        return config

    def setUp(self):
        # schemas are persisted in both Gen 2 and Gen 3 butler as prototypical catalogs
        srcSchema = afwTable.SourceTable.makeMinimalSchema()
        srcSchema.addField("base_PixelFlags_flag", type="Flag")
        srcSchema.addField("base_PixelFlags_flag_offimage", type="Flag")
        self.srcSchema = afwTable.SourceCatalog(srcSchema)

    def testRun(self):
        """Test running with alert packaging and solar system association.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=True)

    def testRunWithSolarSystemAssociation(self):
        """Test running with solar system association but without packaging
        alerts.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=True)

    def testRunWithAlerts(self):
        """Test running while creating and packaging alerts, without solar
        system association.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False)

    def testRunWithoutAlertsOrSolarSystem(self):
        """Test running without creating and packaging alerts and without
        solar system association.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=False)

    def _testRun(self, doPackageAlerts=False, doSolarSystemAssociation=False):
        """Run `DiaPipelineTask.run` with subtasks and the APDB mocked out,
        then check the outputs and the recorded task metadata.

        Parameters
        ----------
        doPackageAlerts : `bool`, optional
            Whether the task is configured to package alerts.
        doSolarSystemAssociation : `bool`, optional
            Whether the task is configured to run solar system association.
        """
        config = self._makeDefaultConfig(
            doPackageAlerts=doPackageAlerts,
            doSolarSystemAssociation=doSolarSystemAssociation)
        task = DiaPipelineTask(config=config)
        # Set DataFrame index testing to always return False. Mocks return
        # true for this check otherwise.
        task.testDataFrameIndex = lambda x: False
        diffIm = Mock(spec=afwImage.ExposureF)
        exposure = Mock(spec=afwImage.ExposureF)
        template = Mock(spec=afwImage.ExposureF)
        diaSrc = _makeMockDataFrame()
        ssObjects = _makeMockDataFrame()
        ccdExposureIdBits = 32

        # Each of these subtasks should be called once during diaPipe
        # execution. We use mocks here to check they are being executed
        # appropriately.
        subtasksToMock = [
            "diaCatalogLoader",
            "diaCalculation",
            "diaForcedSource",
        ]
        if doPackageAlerts:
            subtasksToMock.append("alertPackager")
        else:
            # When disabled, the task should not have the subtask attribute.
            self.assertFalse(hasattr(task, "alertPackager"))

        if not doSolarSystemAssociation:
            self.assertFalse(hasattr(task, "solarSystemAssociator"))

        def concatMock(_data, **_kwargs):
            return _makeMockDataFrame()

        # Mock out the run() methods of these two Tasks to ensure they
        # return data in the correct form. The metadata checks below look
        # up keys built from these function names (e.g.
        # ``associator_runEndUtc``), so do not rename them.
        @lsst.utils.timer.timeMethod
        def solarSystemAssociator_run(self, unAssocDiaSources, solarSystemObjectTable, diffIm):
            return lsst.pipe.base.Struct(nTotalSsObjects=42,
                                         nAssociatedSsObjects=30,
                                         ssoAssocDiaSources=_makeMockDataFrame(),
                                         unAssocDiaSources=_makeMockDataFrame())

        @lsst.utils.timer.timeMethod
        def associator_run(self, table, diaObjects, exposure_time=None):
            return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3,
                                         matchedDiaSources=_makeMockDataFrame(),
                                         unAssocDiaSources=_makeMockDataFrame(),
                                         longTrailedSources=None)

        # apdb isn't a subtask, but still needs to be mocked out for correct
        # execution in the test environment. The comprehension variable is
        # ``name`` so it does not shadow the ``task`` instance being patched.
        with patch.multiple(
            task, **{name: DEFAULT for name in subtasksToMock + ["apdb"]}
        ):
            with patch('lsst.ap.association.diaPipe.pd.concat', new=concatMock), \
                    patch('lsst.ap.association.association.AssociationTask.run', new=associator_run), \
                    patch('lsst.ap.association.ssoAssociation.SolarSystemAssociationTask.run',
                          new=solarSystemAssociator_run):

                result = task.run(diaSrc,
                                  ssObjects,
                                  diffIm,
                                  exposure,
                                  template,
                                  ccdExposureIdBits,
                                  "g")
                # These checks must run while the patches are active: once
                # patch.multiple exits, the mocked attributes are restored.
                for subtaskName in subtasksToMock:
                    getattr(task, subtaskName).run.assert_called_once()
                assertValidOutput(task, result)
                self.assertEqual(result.apdbMarker.db_url, "sqlite://")
                meta = task.getFullMetadata()
                # Check that the expected metadata has been set.
                self.assertEqual(meta["diaPipe.numUpdatedDiaObjects"], 2)
                self.assertEqual(meta["diaPipe.numUnassociatedDiaObjects"], 3)
                # and that associators ran once or not at all.
                self.assertEqual(len(meta.getArray("diaPipe:associator.associator_runEndUtc")), 1)
                if doSolarSystemAssociation:
                    self.assertEqual(len(meta.getArray("diaPipe:solarSystemAssociator."
                                                       "solarSystemAssociator_runEndUtc")), 1)
                else:
                    self.assertNotIn("diaPipe:solarSystemAssociator", meta)

    def test_createDiaObjects(self):
        """Test that creating new DiaObjects works as expected.
        """
        nSources = 5
        diaSources = pd.DataFrame(data=[
            {"ra": 0.04*idx, "dec": 0.04*idx,
             "diaSourceId": idx + 1 + nSources, "diaObjectId": 0,
             "ssObjectId": 0}
            for idx in range(nSources)])

        config = self._makeDefaultConfig(doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        result = task.createNewDiaObjects(diaSources)
        self.assertEqual(nSources, len(result.newDiaObjects))
        # Each new DiaObject takes its id from the DiaSource that seeded it.
        self.assertTrue(np.all(np.equal(
            result.diaSources["diaObjectId"].to_numpy(),
            result.diaSources["diaSourceId"].to_numpy())))
        self.assertTrue(np.all(np.equal(
            result.newDiaObjects["diaObjectId"].to_numpy(),
            result.diaSources["diaSourceId"].to_numpy())))

    def test_purgeDiaObjects(self):
        """Remove diaObjects that are outside an image's bounding box.
        """

        config = self._makeDefaultConfig(doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        exposure = makeExposure(False, False)
        nObj0 = 20

        # Create diaObjects
        diaObjects = makeDiaObjects(nObj0, exposure)
        # Shrink the bounding box so that some of the diaObjects will be outside
        bbox = exposure.getBBox()
        size = np.minimum(bbox.getHeight(), bbox.getWidth())
        bbox.grow(-size//4)
        exposureCut = exposure[bbox]
        sizeCut = np.minimum(bbox.getHeight(), bbox.getWidth())
        buffer = 10
        # Grow the reference box by the same ``buffer`` later passed to
        # purgeDiaObjects (presumably mirroring how the task applies it).
        bbox.grow(buffer)

        def check_diaObjects(bbox, wcs, diaObjects):
            """Return a boolean array that is True where a diaObject's
            (ra, dec) maps through ``wcs`` to a pixel inside ``bbox``."""
            xVals, yVals = wcs.skyToPixelArray(diaObjects.ra.to_numpy(),
                                               diaObjects.dec.to_numpy(),
                                               degrees=True)
            return bbox.contains(xVals, yVals)

        selector0 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects)
        nIn0 = np.count_nonzero(selector0)
        nOut0 = np.count_nonzero(~selector0)
        self.assertEqual(nObj0, nIn0 + nOut0)

        diaObjects1 = task.purgeDiaObjects(exposureCut.getBBox(), exposureCut.getWcs(), diaObjects,
                                           buffer=buffer)
        # Verify that the bounding box was not changed
        sizeCheck = np.minimum(exposureCut.getBBox().getHeight(), exposureCut.getBBox().getWidth())
        self.assertEqual(sizeCut, sizeCheck)
        selector1 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects1)
        nIn1 = np.count_nonzero(selector1)
        nOut1 = np.count_nonzero(~selector1)
        nObj1 = len(diaObjects1)
        self.assertEqual(nObj1, nIn0)
        # Verify that not all diaObjects were removed
        self.assertGreater(nObj1, 0)
        # Check that some diaObjects were removed
        self.assertLess(nObj1, nObj0)
        # Verify that no objects outside the bounding box remain
        self.assertEqual(nOut1, 0)
        # Verify that no objects inside the bounding box were removed
        self.assertEqual(nIn1, nIn0)

252 

253 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Collect the standard checks inherited from
    `lsst.utils.tests.MemoryTestCase`; no extra tests are defined here."""

256 

257 

def setup_module(module):
    """pytest module-level setup hook: initialise `lsst.utils.tests`.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up (not used).
    """
    lsst.utils.tests.init()

260 

261 

if __name__ == "__main__":
    # When run as a script (not via pytest), perform the same initialisation
    # that setup_module does, then hand over to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()