Coverage for tests/test_transform.py: 42%

88 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-24 02:36 -0700

1# 

2# LSST Data Management System 

3# Copyright 2008-2015 AURA/LSST. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <http://www.lsstcorp.org/LegalNotices/>. 

21# 

"""
Test the basic operation of measurement transformations.

We construct and run a simple TransformTask on the (mocked) results of
measurement tasks. The same test is carried out against both
SingleFrameMeasurementTask and ForcedMeasurementTask, on the basis that the
transformation system should be agnostic as to the origin of the source
catalog it is transforming.
"""

33import unittest 

34 

35import lsst.utils 

36import lsst.afw.table as afwTable 

37import lsst.geom as geom 

38import lsst.meas.base as measBase 

39import lsst.utils.tests 

40from lsst.pipe.tasks.transformMeasurement import TransformConfig, TransformTask 

41 

# Registry name shared by the single-frame and forced trivial plugins defined
# below; it is also the name of the schema field those plugins add.
PLUGIN_NAME = "base_TrivialMeasurement"

43 

class Placeholder:
    """Stand-in for the WCS and calibration objects handed to a transform.

    Rather than providing real WCS and calibration objects to the
    transformation, the tests pass instances of this class and afterwards
    read ``count`` to see how many times ``increment`` was called.
    """

    def __init__(self):
        # Number of calls to increment() so far.
        self.count = 0

    def increment(self):
        """Record one more access by bumping ``count``."""
        self.count = self.count + 1

56 

57 

class TrivialMeasurementTransform(measBase.transforms.MeasurementTransform):
    """Pass-through transform that also adds a negated copy of the measurement.

    Every input field matched by ``name + "*"`` is mapped unchanged to the
    output; a new double field ``name + "_transform"`` is added and filled
    with ``-1.0`` times the input column ``self.name`` (presumably ``name``,
    stored by the base class -- the assignment is not visible here).
    """

    def __init__(self, config, name, mapper):
        """Pass through all input fields to the output, and add a new field
        named after the measurement with the suffix "_transform".

        @param[in]     config  Plugin configuration; forwarded to the base class.
        @param[in]     name    Plugin name; input fields matching name + "*"
                               are mapped through unchanged.
        @param[in,out] mapper  Schema mapper; receives the pass-through
                               mappings and the new "_transform" output field.
        """
        measBase.transforms.MeasurementTransform.__init__(self, config, name, mapper)
        # Only the key is needed to set up a mapping; the field itself is unused.
        for key, _ in mapper.getInputSchema().extract(name + "*").values():
            mapper.addMapping(key)
        self.key = mapper.editOutputSchema().addField(name + "_transform", type="D",
                                                      doc="transformed dummy")

    def __call__(self, inputCatalog, outputCatalog, wcs, photoCalib):
        """Transform inputCatalog to outputCatalog.

        We update the wcs and photoCalib placeholders to indicate that they have
        been seen in the transformation, but do not use their values.

        @param[in]  inputCatalog   SourceCatalog of measurements for transformation.
        @param[out] outputCatalog  BaseCatalog of transformed measurements.
        @param[in]  wcs            Dummy WCS information; an instance of Placeholder.
        @param[in]  photoCalib     Dummy calibration information; an instance of Placeholder.
        """
        # Objects without an increment() method are accepted and simply not counted.
        if hasattr(wcs, "increment"):
            wcs.increment()
        if hasattr(photoCalib, "increment"):
            photoCalib.increment()
        inColumns = inputCatalog.getColumnView()
        outColumns = outputCatalog.getColumnView()
        # Whole-column assignment: the output field is the negated input field.
        outColumns[self.key] = -1.0 * inColumns[self.name]

87 

88 

class TrivialMeasurementBase:
    """Behaviour shared by the single-frame and forced trivial plugins.

    Concrete subclasses are expected to set ``self.key`` (the field that
    ``measure`` writes) in their constructors.
    """

    def measure(self, measRecord, exposure):
        """Store the constant 1.0 in ``measRecord`` under ``self.key``.

        ``exposure`` is accepted for interface compatibility but not used.
        """
        measRecord.set(self.key, 1.0)

    @staticmethod
    def getTransformClass():
        """Return the transform class associated with this plugin."""
        return TrivialMeasurementTransform

    @staticmethod
    def getExecutionOrder():
        """Return the plugin execution order, which is always 0."""
        return 0

102 

103 

@measBase.register(PLUGIN_NAME)
class SFTrivialMeasurement(TrivialMeasurementBase, measBase.sfm.SingleFramePlugin):
    """Single-frame flavour of the trivial measurement plugin.

    Adds one double-precision field, named after the plugin, to ``schema``
    and keeps its key in ``self.key`` for ``measure`` to write into.
    """

    def __init__(self, config, name, schema, metadata):
        pluginBase = measBase.sfm.SingleFramePlugin
        pluginBase.__init__(self, config, name, schema, metadata)
        self.key = schema.addField(name, doc="dummy field", type="D")

112 

113 

@measBase.register(PLUGIN_NAME)
class ForcedTrivialMeasurement(TrivialMeasurementBase, measBase.forcedMeasurement.ForcedPlugin):
    """Forced-measurement flavour of the trivial measurement plugin.

    Adds one double-precision field, named after the plugin, to the output
    schema of ``schemaMapper`` and keeps its key in ``self.key``.
    """

    def __init__(self, config, name, schemaMapper, metadata):
        pluginBase = measBase.forcedMeasurement.ForcedPlugin
        pluginBase.__init__(self, config, name, schemaMapper, metadata)
        outputSchema = schemaMapper.editOutputSchema()
        self.key = outputSchema.addField(name, doc="dummy field", type="D")

122 

123 

class TransformTestCase(lsst.utils.tests.TestCase):
    """Apply TransformTask to (mocked) single-frame and forced measurement output."""

    @staticmethod
    def _disableSlots(measConfig):
        """Set every slot in ``measConfig.slots`` to None; slots are not used here.

        @param[in,out] measConfig  Measurement task configuration to modify.
        """
        for key in measConfig.slots:
            setattr(measConfig.slots, key, None)

    def _transformAndCheck(self, measConf, schema, transformTask):
        """Check the results of applying transformTask to a SourceCatalog.

        @param[in] measConf       Measurement plugin configuration.
        @param[in] schema         Input catalog schema.
        @param[in] transformTask  Instance of TransformTask to be applied.

        For internal use by this test case.
        """
        # There should now be one transformation registered per measurement plugin.
        self.assertEqual(len(measConf.plugins.names), len(transformTask.transforms))

        # Rather than do a real measurement, we use a dummy source catalog
        # containing a source at an arbitrary position.
        inCat = afwTable.SourceCatalog(schema)
        r = inCat.addNew()
        r.setCoord(geom.SpherePoint(0.0, 11.19, geom.degrees))
        r[PLUGIN_NAME] = 1.0

        wcs, photoCalib = Placeholder(), Placeholder()
        outCat = transformTask.run(inCat, wcs, photoCalib)

        # zip() below stops at the shorter catalog, so without this check a
        # transform that dropped every row would pass the per-source checks
        # vacuously.
        self.assertEqual(len(outCat), len(inCat))

        # Check that all sources have been transformed appropriately.
        copyFields = transformTask.config.toDict()['copyFields']
        for inSrc, outSrc in zip(inCat, outCat):
            self.assertEqual(outSrc[PLUGIN_NAME], inSrc[PLUGIN_NAME])
            self.assertEqual(outSrc[PLUGIN_NAME + "_transform"], inSrc[PLUGIN_NAME] * -1.0)
            for field in copyFields:
                self.assertEqual(outSrc.get(field), inSrc.get(field))

        # Check that the wcs and photoCalib objects were accessed once per transform.
        self.assertEqual(wcs.count, len(transformTask.transforms))
        self.assertEqual(photoCalib.count, len(transformTask.transforms))

    def testSingleFrameMeasurementTransform(self):
        """Test applying a transform task to the results of single frame measurement."""
        schema = afwTable.SourceTable.makeMinimalSchema()
        sfmConfig = measBase.SingleFrameMeasurementConfig(plugins=[PLUGIN_NAME])
        self._disableSlots(sfmConfig)
        sfmTask = measBase.SingleFrameMeasurementTask(schema, config=sfmConfig)
        transformTask = TransformTask(measConfig=sfmConfig,
                                      inputSchema=sfmTask.schema, outputDataset="src")
        self._transformAndCheck(sfmConfig, sfmTask.schema, transformTask)

    def testForcedMeasurementTransform(self):
        """Test applying a transform task to the results of forced measurement."""
        schema = afwTable.SourceTable.makeMinimalSchema()
        forcedConfig = measBase.ForcedMeasurementConfig(plugins=[PLUGIN_NAME])
        self._disableSlots(forcedConfig)
        forcedConfig.copyColumns = {"id": "objectId", "parent": "parentObjectId"}
        forcedTask = measBase.ForcedMeasurementTask(schema, config=forcedConfig)
        transformConfig = TransformConfig(copyFields=("objectId", "coord_ra", "coord_dec"))
        transformTask = TransformTask(measConfig=forcedConfig,
                                      inputSchema=forcedTask.schema, outputDataset="forced_src",
                                      config=transformConfig)
        self._transformAndCheck(forcedConfig, forcedTask.schema, transformTask)

185 

186 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Reuse lsst.utils.tests.MemoryTestCase unchanged; it adds no tests of its own."""

189 

190 

def setup_module(module):
    """Module-level setup hook (pytest xunit-style); ``module`` is unused.

    Calls lsst.utils.tests.init() before the tests in this module run.
    """
    testUtils = lsst.utils.tests
    testUtils.init()

193 

194 

def _main():
    """Run this module's tests directly with unittest, outside pytest."""
    lsst.utils.tests.init()
    unittest.main()


if __name__ == "__main__":
    _main()