Coverage for tests/test_transform.py: 42%
88 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-03 01:38 -0700
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-03 01:38 -0700
1#
2# LSST Data Management System
3# Copyright 2008-2015 AURA/LSST.
4#
5# This product includes software developed by the
6# LSST Project (http://www.lsst.org/).
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <http://www.lsstcorp.org/LegalNotices/>.
21#
22"""
23Test the basic operation of measurement transformations.
We test measurement transforms by constructing and running a simple
TransformTask on the (mocked) results of measurement tasks. The same test is
carried out against both SingleFrameMeasurementTask and ForcedMeasurementTask,
on the basis that the transformation system should be agnostic as to the
origin of the source catalog it is transforming.
32"""
33import unittest
35import lsst.utils
36import lsst.afw.table as afwTable
37import lsst.geom as geom
38import lsst.meas.base as measBase
39import lsst.utils.tests
40from lsst.pipe.tasks.transformMeasurement import TransformConfig, TransformTask
# Registry name of the trivial test plugin (used for both the single-frame and
# forced registrations below); the input catalog field written by the plugin
# carries the same name.
PLUGIN_NAME: str = "base_TrivialMeasurement"
class Placeholder:
    """Stand-in for the WCS and calibration objects handed to a transform.

    Rather than providing real WCS and calibration objects to the
    transformation, the tests pass instances of this class and afterwards
    inspect ``count`` to see how many times each one was accessed.
    """

    def __init__(self):
        # Number of times increment() has been called on this instance.
        self.count = 0

    def increment(self):
        """Record one more access."""
        self.count = self.count + 1
class TrivialMeasurementTransform(measBase.transforms.MeasurementTransform):
    """Transform used by the trivial test measurement plugin.

    Input fields matching the pattern ``<name>*`` are passed through to the
    output unchanged. A new field ``<name>_transform`` is added and filled
    with the negated plugin value.
    """

    def __init__(self, config, name, mapper):
        """Pass through all input fields to the output, and add a new field
        named after the measurement with the suffix "_transform".

        @param[in]     config  Plugin configuration, forwarded to the base class.
        @param[in]     name    Name of the measurement plugin being transformed.
        @param[in,out] mapper  Mapper from the input to the output schema; it
                               gains pass-through mappings and the new field.
        """
        measBase.transforms.MeasurementTransform.__init__(self, config, name, mapper)
        # Only the key is needed to set up a pass-through mapping.
        for key, _ in mapper.getInputSchema().extract(name + "*").values():
            mapper.addMapping(key)
        self.key = mapper.editOutputSchema().addField(name + "_transform", type="D",
                                                       doc="transformed dummy")

    def __call__(self, inputCatalog, outputCatalog, wcs, photoCalib):
        """Transform inputCatalog to outputCatalog.

        We update the wcs and photoCalib placeholders to indicate that they have
        been seen in the transformation, but do not use their values.

        @param[in]  inputCatalog   SourceCatalog of measurements for transformation.
        @param[out] outputCatalog  BaseCatalog of transformed measurements.
        @param[in]  wcs            Dummy WCS information; an instance of Placeholder.
        @param[in]  photoCalib     Dummy calibration information; an instance of Placeholder.
        """
        # Only count accesses on objects that support it; real WCS/calibration
        # objects passed here are left untouched.
        for placeholder in (wcs, photoCalib):
            if hasattr(placeholder, "increment"):
                placeholder.increment()
        inColumns = inputCatalog.getColumnView()
        outColumns = outputCatalog.getColumnView()
        outColumns[self.key] = -1.0 * inColumns[self.name]
class TrivialMeasurementBase:
    """Behaviour shared by the single-frame and forced trivial test plugins.

    Concrete subclasses must assign ``self.key`` in their constructor before
    ``measure`` is called.
    """

    @staticmethod
    def getExecutionOrder():
        """Return the execution order of this plugin (always 0)."""
        order = 0
        return order

    @staticmethod
    def getTransformClass():
        """Return the transform class used for this plugin's outputs."""
        transformClass = TrivialMeasurementTransform
        return transformClass

    def measure(self, measRecord, exposure):
        """Write the constant 1.0 into ``self.key``; ``exposure`` is ignored."""
        constantValue = 1.0
        measRecord.set(self.key, constantValue)
@measBase.register(PLUGIN_NAME)
class SFTrivialMeasurement(TrivialMeasurementBase, measBase.sfm.SingleFramePlugin):
    """Single-frame variant of the trivial test measurement.

    Adds one field of type "D", named after the plugin, to ``schema``. The
    ``measure`` method inherited from TrivialMeasurementBase sets it to 1.0.
    """

    def __init__(self, config, name, schema, metadata):
        measBase.sfm.SingleFramePlugin.__init__(self, config, name, schema, metadata)
        dummyKey = schema.addField(name, type="D", doc="dummy field")
        self.key = dummyKey
@measBase.register(PLUGIN_NAME)
class ForcedTrivialMeasurement(TrivialMeasurementBase, measBase.forcedMeasurement.ForcedPlugin):
    """Forced-photometry variant of the trivial test measurement.

    Adds one field of type "D", named after the plugin, to the output schema
    of ``schemaMapper``. The ``measure`` method inherited from
    TrivialMeasurementBase sets it to 1.0.
    """

    def __init__(self, config, name, schemaMapper, metadata):
        measBase.forcedMeasurement.ForcedPlugin.__init__(self, config, name, schemaMapper, metadata)
        outputSchema = schemaMapper.editOutputSchema()
        self.key = outputSchema.addField(name, type="D", doc="dummy field")
class TransformTestCase(lsst.utils.tests.TestCase):
    """Apply TransformTask to mocked single-frame and forced measurement output."""

    @staticmethod
    def _disableSlots(measConfig):
        """Set every slot of ``measConfig`` to None.

        The tests below run only the trivial plugin and do not use slots.
        """
        for slotName in measConfig.slots:
            setattr(measConfig.slots, slotName, None)

    def _transformAndCheck(self, measConf, schema, transformTask):
        """Check the results of applying transformTask to a SourceCatalog.

        @param[in] measConf       Measurement plugin configuration.
        @param[in] schema         Input catalog schema.
        @param[in] transformTask  Instance of TransformTask to be applied.

        For internal use by this test case.
        """
        # There should now be one transformation registered per measurement plugin.
        self.assertEqual(len(measConf.plugins.names), len(transformTask.transforms))

        # Rather than do a real measurement, we use a dummy source catalog
        # containing a source at an arbitrary position.
        inCat = afwTable.SourceCatalog(schema)
        r = inCat.addNew()
        r.setCoord(geom.SpherePoint(0.0, 11.19, geom.degrees))
        r[PLUGIN_NAME] = 1.0

        wcs, photoCalib = Placeholder(), Placeholder()
        outCat = transformTask.run(inCat, wcs, photoCalib)

        # zip() below stops at the shorter catalog, so without this check an
        # empty or truncated output catalog would pass silently.
        self.assertEqual(len(outCat), len(inCat))

        # Check that all sources have been transformed appropriately.
        copyFields = list(transformTask.config.copyFields)
        for inSrc, outSrc in zip(inCat, outCat):
            self.assertEqual(outSrc[PLUGIN_NAME], inSrc[PLUGIN_NAME])
            self.assertEqual(outSrc[PLUGIN_NAME + "_transform"], inSrc[PLUGIN_NAME] * -1.0)
            for field in copyFields:
                self.assertEqual(outSrc.get(field), inSrc.get(field))

        # Check that the wcs and photoCalib objects were accessed once per transform.
        self.assertEqual(wcs.count, len(transformTask.transforms))
        self.assertEqual(photoCalib.count, len(transformTask.transforms))

    def testSingleFrameMeasurementTransform(self):
        """Test applying a transform task to the results of single frame measurement."""
        schema = afwTable.SourceTable.makeMinimalSchema()
        sfmConfig = measBase.SingleFrameMeasurementConfig(plugins=[PLUGIN_NAME])
        self._disableSlots(sfmConfig)
        sfmTask = measBase.SingleFrameMeasurementTask(schema, config=sfmConfig)
        transformTask = TransformTask(measConfig=sfmConfig,
                                      inputSchema=sfmTask.schema, outputDataset="src")
        self._transformAndCheck(sfmConfig, sfmTask.schema, transformTask)

    def testForcedMeasurementTransform(self):
        """Test applying a transform task to the results of forced measurement."""
        schema = afwTable.SourceTable.makeMinimalSchema()
        forcedConfig = measBase.ForcedMeasurementConfig(plugins=[PLUGIN_NAME])
        self._disableSlots(forcedConfig)
        # Copy "id" into "objectId" in the forced catalog so that the
        # "objectId" entry in copyFields below exists in the input.
        forcedConfig.copyColumns = {"id": "objectId", "parent": "parentObjectId"}
        forcedTask = measBase.ForcedMeasurementTask(schema, config=forcedConfig)
        transformConfig = TransformConfig(copyFields=("objectId", "coord_ra", "coord_dec"))
        transformTask = TransformTask(measConfig=forcedConfig,
                                      inputSchema=forcedTask.schema, outputDataset="forced_src",
                                      config=transformConfig)
        self._transformAndCheck(forcedConfig, forcedTask.schema, transformTask)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from lsst.utils.tests.MemoryTestCase; adds none of its own."""
def setup_module(module):
    """Module-level setup hook (pytest xunit style): initialise LSST test utilities.

    The ``module`` argument is required by the hook signature and is unused.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Run directly (not under pytest): initialise the LSST test helpers, then
    # hand control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()