Coverage for tests/test_gbdesAstrometricFit.py: 11%
286 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-20 01:57 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-20 01:57 -0800
1# This file is part of drp_tasks
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
import unittest
import os.path

import numpy as np
import yaml
import astropy.units as u
import pandas as pd
import wcsfit

import lsst.utils
import lsst.utils.tests
import lsst.afw.table as afwTable
import lsst.afw.geom as afwgeom
from lsst.drp.tasks import GbdesAstrometricFitConfig, GbdesAstrometricFitTask
from lsst.daf.base import PropertyList
from lsst.daf.butler import DimensionUniverse, DatasetType, DatasetRef, StorageClass
from lsst.meas.algorithms import ReferenceObjectLoader
from lsst.pipe.base import InMemoryDatasetHandle
from lsst import sphgeom
import lsst.geom
class MockRefCatDataId:
    """Minimal stand-in for a reference-catalog dataId.

    It carries only a sky ``region`` and a ``ref`` to a fake
    ``gaia_dr2_20200414`` dataset. These are presumably the only attributes
    `ReferenceObjectLoader` needs from each dataId in these tests.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Sky region covered by the mock reference catalog.
    """

    def __init__(self, region):
        self.region = region

        htmDimensions = DimensionUniverse().extract(['htm7'])
        refStorageClass = StorageClass("SimpleCatalog")
        refDatasetType = DatasetType('gaia_dr2_20200414', htmDimensions, refStorageClass)
        self.ref = DatasetRef(refDatasetType, {'htm7': "mockRefCat"})
class TestGbdesAstrometricFit(lsst.utils.tests.TestCase):
    """Tests for `GbdesAstrometricFitTask`.

    The inputs are simulated stars projected through a known ("true") WCS
    model onto the geometry of real HSC visits.
    """

    @classmethod
    def setUpClass(cls):
        """Build the simulated inputs shared by every test.

        The setup has four stages:

        - Read the visit summaries and the true WCS model from the package
          test data.
        - Simulate a random star field covering the visits.
        - Build a reference catalog and a reference object loader from the
          simulated stars.
        - Build one source catalog per visit from the simulated stars.
        """
        # Fraction of simulated stars in the reference catalog and science
        # exposures
        inReferenceFraction = 1
        inScienceFraction = 1

        # Use a seeded generator so the simulated star field is reproducible
        # between runs. This keeps the statistical thresholds checked in the
        # tests stable, and it avoids touching numpy's global random state.
        cls.rng = np.random.default_rng(1234)

        # Make fake data
        packageDir = lsst.utils.getPackageDir('drp_tasks')
        cls.datadir = os.path.join(packageDir, 'tests', "data")

        cls.fieldNumber = 0
        cls.instrumentName = 'HSC'
        cls.instrument = wcsfit.Instrument(cls.instrumentName)
        cls.refEpoch = 57205.5

        # Make test inputVisitSummary. VisitSummaryTables are taken from
        # collection HSC/runs/RC2/w_2022_20/DM-34794
        cls.testVisits = [1176, 17900, 17930, 17934]
        cls.inputVisitSummary = [
            afwTable.ExposureCatalog.readFits(os.path.join(cls.datadir, f'visitSummary_{testVisit}.fits'))
            for testVisit in cls.testVisits
        ]

        cls.config = GbdesAstrometricFitConfig()
        cls.config.systematicError = 0
        cls.config.devicePolyOrder = 4
        cls.config.exposurePolyOrder = 6
        cls.task = GbdesAstrometricFitTask(config=cls.config)

        cls.exposureInfo, cls.exposuresHelper, cls.extensionInfo = cls.task._get_exposure_info(
            cls.inputVisitSummary, cls.instrument, refEpoch=cls.refEpoch)

        cls.fields, cls.fieldCenter, cls.fieldRadius = cls.task._prep_sky(
            cls.inputVisitSummary, cls.exposureInfo.medianEpoch)

        # Bounding box of all observations, taken from the detector corners
        # in the visit summaries. The corners are interpreted as degrees when
        # building the SpherePoints below.
        raMin = min(visSum['raCorners'].min() for visSum in cls.inputVisitSummary)
        raMax = max(visSum['raCorners'].max() for visSum in cls.inputVisitSummary)
        decMin = min(visSum['decCorners'].min() for visSum in cls.inputVisitSummary)
        decMax = max(visSum['decCorners'].max() for visSum in cls.inputVisitSummary)

        corners = [lsst.geom.SpherePoint(raMin, decMin, lsst.geom.degrees).getVector(),
                   lsst.geom.SpherePoint(raMax, decMin, lsst.geom.degrees).getVector(),
                   lsst.geom.SpherePoint(raMax, decMax, lsst.geom.degrees).getVector(),
                   lsst.geom.SpherePoint(raMin, decMax, lsst.geom.degrees).getVector()]
        cls.boundingPolygon = sphgeom.ConvexPolygon(corners)

        # Make a random set of stars, uniformly distributed in RA/Dec within
        # the bounding box determined by the input visits.
        cls.nStars = 10000
        starIds = np.arange(cls.nStars)
        starRAs = cls.rng.random(cls.nStars) * (raMax - raMin) + raMin
        starDecs = cls.rng.random(cls.nStars) * (decMax - decMin) + decMin

        # Make a reference catalog and load it into ReferenceObjectLoader
        refDataId, deferredRefCat = cls._make_refCat(starIds, starRAs, starDecs, inReferenceFraction)
        cls.refObjectLoader = ReferenceObjectLoader([refDataId], [deferredRefCat])
        cls.refObjectLoader.config.requireProperMotion = False
        cls.refObjectLoader.config.anyFilterMapsToThis = 'test_filter'

        cls.task.refObjectLoader = cls.refObjectLoader

        # Make WCS objects for the "true" model, from which the simulated
        # source pixel positions are computed.
        # NOTE: yaml.Loader is not a safe loader; acceptable here only
        # because the file is packaged test data.
        with open(os.path.join(cls.datadir, 'sample_wcs.yaml'), 'r') as f:
            cls.trueModel = yaml.load(f, Loader=yaml.Loader)

        trueWCSs = cls._make_wcs(cls.trueModel, cls.inputVisitSummary)

        # Make source catalogs:
        cls.inputCatalogRefs = cls._make_sourceCat(starIds, starRAs, starDecs, trueWCSs,
                                                   inScienceFraction)

    @classmethod
    def _make_refCat(cls, starIds, starRas, starDecs, inReferenceFraction):
        """Make reference catalog from a subset of the simulated data

        Parameters
        ----------
        starIds : `np.ndarray` of `int`
            Source ids for the simulated stars
        starRas : `np.ndarray` of `float`
            RAs of the simulated stars, in degrees
        starDecs : `np.ndarray` of `float`
            Decs of the simulated stars, in degrees
        inReferenceFraction : `float`
            Fraction (0-1) of simulated stars to include in reference catalog

        Returns
        -------
        refDataId : MockRefCatDataId
            Object that replicates the functionality of a dataId.
        deferredRefCat : `lsst.pipe.base.InMemoryDatasetHandle`
            Dataset handle for reference catalog.
        """
        nRefs = int(cls.nStars * inReferenceFraction)
        refStarIndices = cls.rng.choice(cls.nStars, nRefs, replace=False)
        # Make simpleCatalog to hold data, create datasetRef with `region`
        # determined by bounding box used in above simulate.
        refSchema = afwTable.SimpleTable.makeMinimalSchema()
        idKey = refSchema.addField("sourceId", type="I")
        fluxKey = refSchema.addField("test_filter_flux", units='nJy', type=np.float64)
        raErrKey = refSchema.addField("coord_raErr", type=np.float64)
        decErrKey = refSchema.addField("coord_decErr", type=np.float64)
        pmraErrKey = refSchema.addField("pm_raErr", type=np.float64)
        pmdecErrKey = refSchema.addField("pm_decErr", type=np.float64)
        refCat = afwTable.SimpleCatalog(refSchema)
        ref_md = PropertyList()
        ref_md.set("REFCAT_FORMAT_VERSION", 1)
        refCat.table.setMetadata(ref_md)
        for i in refStarIndices:
            record = refCat.addNew()
            record.set(idKey, starIds[i])
            record.setRa(lsst.geom.Angle(starRas[i], lsst.geom.degrees))
            record.setDec(lsst.geom.Angle(starDecs[i], lsst.geom.degrees))
            record.set(fluxKey, 1)
            record.set(raErrKey, 0.00001)
            record.set(decErrKey, 0.00001)
            record.set(pmraErrKey, 1e-9)
            record.set(pmdecErrKey, 1e-9)
        refDataId = MockRefCatDataId(cls.boundingPolygon)
        deferredRefCat = InMemoryDatasetHandle(refCat, storageClass="SourceCatalog")

        return refDataId, deferredRefCat

    @classmethod
    def _make_sourceCat(cls, starIds, starRas, starDecs, trueWCSs, inScienceFraction):
        """Make per-visit `pd.DataFrame` source catalogs with the columns
        needed for the object selector.

        For each visit, a random subset of the simulated stars is taken.
        Stars that fall on a detector get pixel positions computed from that
        detector's true WCS.

        Parameters
        ----------
        starIds : `np.ndarray` of `int`
            Source ids for the simulated stars
        starRas : `np.ndarray` of `float`
            RAs of the simulated stars, in degrees
        starDecs : `np.ndarray` of `float`
            Decs of the simulated stars, in degrees
        trueWCSs : `dict` [`int`, `lsst.afw.table.ExposureCatalog`]
            Per-visit catalogs holding, for each detector, the WCS with which
            to simulate the source pixel coordinates
        inScienceFraction : `float`
            Fraction (0-1) of simulated stars to include in each visit

        Returns
        -------
        inputCatalogRefs : `list` of `lsst.pipe.base.InMemoryDatasetHandle`
            Handles to the source catalogs, one per visit, in the same order
            as ``cls.testVisits``.
        """
        # Every detector is assumed to share the pixel bounding box of the
        # first detector of the first visit.
        firstDetector = cls.inputVisitSummary[0][0]
        bbox = lsst.geom.BoxD(lsst.geom.Point2D(firstDetector['bbox_min_x'], firstDetector['bbox_min_y']),
                              lsst.geom.Point2D(firstDetector['bbox_max_x'], firstDetector['bbox_max_y']))
        bboxCorners = bbox.getCorners()

        nVisStars = int(cls.nStars * inScienceFraction)
        inputCatalogRefs = []
        for v, visit in enumerate(cls.testVisits):
            visitSummary = cls.inputVisitSummary[v]
            # Take a subset of the simulated data for this visit
            visitStarIndices = cls.rng.choice(cls.nStars, nVisStars, replace=False)
            visitStarIds = starIds[visitStarIndices]
            visitStarRas = starRas[visitStarIndices]
            visitStarDecs = starDecs[visitStarIndices]
            sourceCats = []
            for detector in trueWCSs[visit]:
                detWcs = detector.getWcs()
                detectorId = detector['id']
                # Use the true WCS to find the stars inside this detector's
                # footprint on the sky.
                radecCorners = detWcs.pixelToSky(bboxCorners)
                detectorFootprint = sphgeom.ConvexPolygon([rd.getVector() for rd in radecCorners])
                detectorIndices = detectorFootprint.contains((visitStarRas*u.degree).to(u.radian),
                                                             (visitStarDecs*u.degree).to(u.radian))
                nDetectorStars = detectorIndices.sum()

                x, y = detWcs.skyToPixelArray(visitStarRas[detectorIndices], visitStarDecs[detectorIndices],
                                              degrees=True)

                # Sky positions of the simulated pixel coordinates according
                # to the original WCS stored in the visit summary.
                origWcs = visitSummary[visitSummary['id'] == detectorId][0].getWcs()
                inputRa, inputDec = origWcs.pixelToSkyArray(x, y, degrees=True)

                ones = np.ones(nDetectorStars)
                falses = np.zeros(nDetectorStars, dtype=bool)

                sourceDict = {}
                sourceDict['detector'] = np.full(nDetectorStars, detectorId)
                sourceDict['sourceId'] = visitStarIds[detectorIndices]
                sourceDict['x'] = x
                sourceDict['y'] = y
                sourceDict['xErr'] = 1e-3 * ones
                sourceDict['yErr'] = 1e-3 * ones
                sourceDict['inputRA'] = inputRa
                sourceDict['inputDec'] = inputDec
                sourceDict['trueRA'] = visitStarRas[detectorIndices]
                sourceDict['trueDec'] = visitStarDecs[detectorIndices]
                for key in ['apFlux_12_0_flux', 'apFlux_12_0_instFlux', 'ixx', 'iyy']:
                    sourceDict[key] = ones
                for key in ['pixelFlags_edge', 'pixelFlags_saturated', 'pixelFlags_interpolatedCenter',
                            'pixelFlags_interpolated', 'pixelFlags_crCenter', 'pixelFlags_bad',
                            'hsmPsfMoments_flag', 'apFlux_12_0_flag', 'extendedness', 'parentSourceId',
                            'deblend_nChild', 'ixy']:
                    sourceDict[key] = falses
                sourceDict['apFlux_12_0_instFluxErr'] = 1e-3 * ones

                sourceCats.append(pd.DataFrame(sourceDict))

            visitSourceTable = pd.concat(sourceCats)

            inputCatalogRef = InMemoryDatasetHandle(visitSourceTable, storageClass="DataFrame",
                                                    dataId={"visit": visit})

            inputCatalogRefs.append(inputCatalogRef)

        return inputCatalogRefs

    @classmethod
    def _make_wcs(cls, model, inputVisitSummaries):
        """Make per-detector `lsst.afw.geom.SkyWcs` objects from given model
        parameters.

        Parameters
        ----------
        model : `dict`
            Dictionary with WCS model parameters, keyed by map name
            (``'{visit}/poly'`` and ``'{instrumentName}/{detector}/poly'``).
        inputVisitSummaries : `list` of `lsst.afw.table.ExposureCatalog`
            Visit summary catalogs

        Returns
        -------
        catalogs : `dict` [`int`, `lsst.afw.table.ExposureCatalog`]
            Catalogs keyed by visit. Each has one record per detector, with
            the detector id as the record id and the WCS built from the input
            model. Each catalog is sorted by id.
        """
        # Pixels will need to be rescaled before going into the mappings
        xscale = inputVisitSummaries[0][0]['bbox_max_x'] - inputVisitSummaries[0][0]['bbox_min_x']
        yscale = inputVisitSummaries[0][0]['bbox_max_y'] - inputVisitSummaries[0][0]['bbox_min_y']

        catalogs = {}
        schema = afwTable.ExposureTable.makeMinimalSchema()
        schema.addField('visit', type='L', doc='Visit number')
        for visitSum in inputVisitSummaries:
            visit = visitSum[0]['visit']
            visitMapName = f'{visit}/poly'
            visitModel = model[visitMapName]

            catalog = afwTable.ExposureCatalog(schema)
            catalog.resize(len(visitSum))
            catalog['visit'] = visit

            raDec = visitSum[0].getVisitInfo().getBoresightRaDec()

            visitMapType = visitModel['Type']
            visitDict = {'Type': visitMapType}
            if visitMapType == 'Poly':
                mapCoefficients = (visitModel['XPoly']['Coefficients']
                                   + visitModel['YPoly']['Coefficients'])
                visitDict["Coefficients"] = mapCoefficients

            for d, detector in enumerate(visitSum):
                detectorId = detector['id']
                # Detector maps are named independently of the visit, so all
                # visits share the same map for a given detector.
                detectorMapName = f'{cls.instrumentName}/{detectorId}/poly'
                detectorModel = model[detectorMapName]

                detectorMapType = detectorModel['Type']
                mapDict = {detectorMapName: {'Type': detectorMapType},
                           visitMapName: visitDict}
                if detectorMapType == 'Poly':
                    mapCoefficients = (detectorModel['XPoly']['Coefficients']
                                       + detectorModel['YPoly']['Coefficients'])
                    mapDict[detectorMapName]['Coefficients'] = mapCoefficients

                outWCS = cls.task._make_afw_wcs(mapDict, raDec.getRa(), raDec.getDec(),
                                                doNormalizePixels=True, xScale=xscale, yScale=yscale)
                catalog[d].setId(detectorId)
                catalog[d].setWcs(outWCS)

            catalog.sort()
            catalogs[visit] = catalog

        return catalogs

    def test_get_exposure_info(self):
        """Test that information for input exposures is as expected and that
        the WCS in the class object gives approximately the same results as the
        input `lsst.afw.geom.SkyWcs`.
        """

        # The total number of extensions is the number of detectors for each
        # visit plus one for the reference catalog
        totalExtensions = sum(len(visSum) for visSum in self.inputVisitSummary) + 1

        self.assertEqual(totalExtensions, len(self.extensionInfo.visit))

        # The extra (reference catalog) extension is expected to have visit -1.
        taskVisits = set(self.extensionInfo.visit)
        self.assertEqual(taskVisits, set(self.testVisits + [-1]))

        xx = np.linspace(0, 2000, 3)
        yy = np.linspace(0, 4000, 6)
        xgrid, ygrid = np.meshgrid(xx, yy)
        for visSum in self.inputVisitSummary:
            visit = visSum[0]['visit']
            for detectorInfo in visSum:
                detector = detectorInfo['id']
                extensionIndex = np.flatnonzero((self.extensionInfo.visit == visit)
                                                & (self.extensionInfo.detector == detector))[0]
                fitWcs = self.extensionInfo.wcs[extensionIndex]
                calexpWcs = detectorInfo.getWcs()

                tanPlaneXY = np.array([fitWcs.toWorld(x, y) for (x, y) in zip(xgrid.ravel(),
                                                                              ygrid.ravel())])

                calexpra, calexpdec = calexpWcs.pixelToSkyArray(xgrid.ravel(), ygrid.ravel(), degrees=True)

                # Project the fit's tangent-plane coordinates onto the sky
                # about the calexp WCS's tangent point.
                tangentPoint = calexpWcs.pixelToSky(calexpWcs.getPixelOrigin().getX(),
                                                    calexpWcs.getPixelOrigin().getY())
                cdMatrix = afwgeom.makeCdMatrix(1.0 * lsst.geom.degrees, 0 * lsst.geom.degrees, True)
                iwcToSkyWcs = afwgeom.makeSkyWcs(lsst.geom.Point2D(0, 0), tangentPoint, cdMatrix)
                newRAdeg, newDecdeg = iwcToSkyWcs.pixelToSkyArray(tanPlaneXY[:, 0], tanPlaneXY[:, 1],
                                                                  degrees=True)

                # One WCS is in SIP and the other is TPV. The pixel-to-sky
                # conversion is not exactly the same but should be close.
                # TODO: sip_tpv + astropy.wcs.WCS gets a better result here,
                # particularly for detector # >= 100. See if we can improve/if
                # improving is necessary. Check if matching in corner detectors
                # is ok.
                rtol = (1e-3 if (detector >= 100) else 1e-5)
                np.testing.assert_allclose(calexpra, newRAdeg, rtol=rtol)
                np.testing.assert_allclose(calexpdec, newDecdeg, rtol=rtol)

    def test_refCatLoader(self):
        """Test that we can load objects from refCat
        """

        tmpAssociations = wcsfit.FoFClass(self.fields, [self.instrument], self.exposuresHelper,
                                          [self.fieldRadius.asDegrees()],
                                          (self.task.config.matchRadius * u.arcsec).to(u.degree).value)

        self.task._load_refcat(tmpAssociations, self.refObjectLoader, self.fieldCenter, self.fieldRadius,
                               self.extensionInfo, epoch=2015)

        # We have only loaded one catalog, so getting the 'matches' should just
        # return the same objects we put in, except some random objects that
        # are too close together.
        tmpAssociations.sortMatches(self.fieldNumber, minMatches=1)

        # Each match group starts with sequence number 0, so this counts
        # the groups.
        nMatches = (np.array(tmpAssociations.sequence) == 0).sum()

        self.assertLessEqual(nMatches, self.nStars)
        self.assertGreater(nMatches, self.nStars * 0.9)

    def test_load_catalogs_and_associate(self):
        """Test that sources from the input catalogs are associated with
        other detections of the same simulated star.
        """
        tmpAssociations = wcsfit.FoFClass(self.fields, [self.instrument], self.exposuresHelper,
                                          [self.fieldRadius.asDegrees()],
                                          (self.task.config.matchRadius * u.arcsec).to(u.degree).value)
        self.task._load_catalogs_and_associate(tmpAssociations, self.inputCatalogRefs, self.extensionInfo)

        tmpAssociations.sortMatches(self.fieldNumber, minMatches=2)

        # Fetch each visit's catalog once instead of once per matched source.
        visitCatalogs = [catalogRef.get() for catalogRef in self.inputCatalogRefs]

        matchIds = []
        correctMatches = []
        for (s, e, o) in zip(tmpAssociations.sequence, tmpAssociations.extn, tmpAssociations.obj):
            objVisitInd = self.extensionInfo.visitIndex[e]
            objDet = self.extensionInfo.detector[e]
            visitCatalog = visitCatalogs[objVisitInd]
            # `o` is treated as the row index among this detector's sources.
            objInfo = visitCatalog[visitCatalog['detector'] == objDet].iloc[o]
            # A sequence number of 0 starts a new match group, so the
            # previous group is complete and can be scored.
            if s == 0:
                if len(matchIds) > 0:
                    correctMatches.append(len(set(matchIds)) == 1)
                matchIds = []

            matchIds.append(objInfo['sourceId'])

        # Groups are only scored when the next one starts, so the final
        # group must be scored after the loop.
        if len(matchIds) > 0:
            correctMatches.append(len(set(matchIds)) == 1)

        # A few matches may incorrectly associate sources because of the random
        # positions
        self.assertGreater(sum(correctMatches), len(correctMatches) * 0.95)

    def test_make_outputs(self):
        """Test that the output WCSs from the run method map the simulated
        pixel positions back onto the true star coordinates.
        """
        task = GbdesAstrometricFitTask(config=self.config)

        outputs = task.run(self.inputCatalogRefs, self.inputVisitSummary, instrumentName=self.instrumentName,
                           refEpoch=self.refEpoch, refObjectLoader=self.refObjectLoader)

        for v, visit in enumerate(self.testVisits):
            visitSummary = self.inputVisitSummary[v]
            outputWcsCatalog = outputs.outputWCSs[visit]
            visitSources = self.inputCatalogRefs[v].get()
            for d, detectorRow in enumerate(visitSummary):
                detectorId = detectorRow['id']
                # NOTE(review): assumes output catalog rows are in the same
                # detector order as the input visit summary.
                fitwcs = outputWcsCatalog[d].getWcs()
                detSources = visitSources[visitSources['detector'] == detectorId]
                fitRA, fitDec = fitwcs.pixelToSkyArray(detSources['x'], detSources['y'], degrees=True)
                dRA = fitRA - detSources['trueRA']
                dDec = fitDec - detSources['trueDec']
                # Check that input coordinates match the output coordinates
                self.assertAlmostEqual(np.mean(dRA), 0)
                self.assertAlmostEqual(np.std(dRA), 0)
                self.assertAlmostEqual(np.mean(dDec), 0)
                self.assertAlmostEqual(np.std(dDec), 0)

    def _assert_poly_matches(self, origModel, fitParams, tolerance):
        """Assert that fitted polynomial coefficients match a true model.

        Parameters
        ----------
        origModel : `dict`
            True map description with ``'XPoly'`` and ``'YPoly'``
            coefficient lists.
        fitParams : array-like
            Fitted parameters: the X coefficients followed by the Y
            coefficients (presumably a `np.ndarray`).
        tolerance : `float`
            Maximum allowed absolute difference for every coefficient.
        """
        origXPoly = origModel['XPoly']['Coefficients']
        origYPoly = origModel['YPoly']['Coefficients']
        fitXPoly = fitParams[:len(origXPoly)]
        fitYPoly = fitParams[len(origXPoly):]

        np.testing.assert_array_less(abs(fitXPoly - origXPoly), tolerance)
        np.testing.assert_array_less(abs(fitYPoly - origYPoly), tolerance)

    def test_run(self):
        """Test that run method recovers the input model parameters
        """
        task = GbdesAstrometricFitTask(config=self.config)

        outputs = task.run(self.inputCatalogRefs, self.inputVisitSummary, instrumentName=self.instrumentName,
                           refEpoch=self.refEpoch, refObjectLoader=self.refObjectLoader)

        outputMaps = outputs.fitModel.mapCollection.getParamDict()

        for v, visit in enumerate(self.testVisits):
            visitSummary = self.inputVisitSummary[v]
            visitMapName = f'{visit}/poly'

            origModel = self.trueModel[visitMapName]
            if origModel['Type'] != 'Identity':
                # Check that input visit model matches fit
                self._assert_poly_matches(origModel, outputMaps[visitMapName], 1e-6)
            for detectorRow in visitSummary:
                detectorId = detectorRow['id']
                detectorMapName = f'{self.instrumentName}/{detectorId}/poly'
                origModel = self.trueModel[detectorMapName]
                # Detector maps are shared by all visits, so they only need
                # to be checked once (for the first visit).
                if (origModel['Type'] != 'Identity') and (v == 0):
                    # Check that input detector model matches fit
                    self._assert_poly_matches(origModel, outputMaps[detectorMapName], 1e-7)
def setup_module(module):
    """pytest module-level hook: initialize the LSST test utilities.

    ``module`` is the test module object supplied by pytest and is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Executed directly rather than through pytest: initialize the LSST test
    # utilities, then hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()