Coverage for python / lsst / drp / tasks / fit_stellar_motion.py: 17%
182 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:19 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:19 +0000
1# This file is part of drp_tasks.
2#
3# LSST Data Management System
4# This product includes software developed by the
5# LSST Project (http://www.lsst.org/).
6# See COPYRIGHT file at the top of the source tree.
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <https://www.lsstcorp.org/LegalNotices/>.
21#
23__all__ = ["FitStellarMotionConfig", "FitStellarMotionConnections", "FitStellarMotionTask"]
25import astropy.coordinates
26import astropy.units as u
27import numpy as np
28import wcsfit
29from astropy.table import Table, hstack, join, vstack
31import lsst.afw.geom as afwGeom
32import lsst.geom
33import lsst.pex.config as pexConfig
34import lsst.pipe.base as pipeBase
35from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader
36from lsst.skymap import BaseSkyMap
class FitStellarMotionConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=(
        "instrument",
        "tract",
        "skymap",
    ),
):
    """Butler connections for `FitStellarMotionTask`.

    One quantum is run per (instrument, skymap, tract). The reference catalog
    and skymap connections are dropped when the reference catalog is not
    included in the fit, and the visit table is dropped when an explicit
    output epoch is configured.
    """

    visitSummaries = pipeBase.connectionTypes.Input(
        doc=(
            "Per-visit consolidated exposure metadata built from calexps. "
            "These catalogs use detector id for the id and must be sorted for "
            "fast lookups of a detector."
        ),
        name="preliminary_visit_summary",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
        multiple=True,
        deferLoad=True,
    )
    starSourceRef = pipeBase.connectionTypes.Input(
        doc="Catalog of matched sources.",
        name="isolated_star",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
        deferLoad=True,
    )
    starCatalogRef = pipeBase.connectionTypes.Input(
        doc="Catalog of objects corresponding to the matched sources.",
        name="isolated_star_association",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
        deferLoad=True,
    )
    inputSources = pipeBase.connectionTypes.Input(
        doc="Source table in parquet format, per visit.",
        name="recalibrated_star",
        storageClass="ArrowAstropy",
        dimensions=("instrument", "visit"),
        deferLoad=True,
        multiple=True,
    )
    referenceCatalog = pipeBase.connectionTypes.PrerequisiteInput(
        doc="The astrometry reference catalog to match to loaded input catalog sources.",
        name="the_monster_20250219",
        storageClass="SimpleCatalog",
        dimensions=("skypix",),
        deferLoad=True,
        multiple=True,
    )
    skymap = pipeBase.connectionTypes.Input(
        doc="Input definition of bbox containing the associated sources.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    visitTable = pipeBase.connectionTypes.Input(
        doc="Survey-wide table of visits, which will be used to get median epoch.",
        name="preliminary_visit_table",
        storageClass="ArrowAstropy",
        dimensions=("instrument",),
        deferLoad=True,
    )
    outputCatalog = pipeBase.connectionTypes.Output(
        doc="Best fit position, proper motion and parallax for input objects.",
        name="isolated_star_stellar_motions",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
    )
    predictedPositions = pipeBase.connectionTypes.Output(
        doc="Predicted position for each source at the epoch of observation.",
        name="isolated_star_predicted_positions",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if not self.config.includeReferenceCatalog:
            # ``referenceCatalog`` is declared as a PrerequisiteInput, so its
            # name is tracked in ``prerequisiteInputs``; it is not a member of
            # ``inputs`` and removing it from there raises KeyError.
            self.prerequisiteInputs.remove("referenceCatalog")
            # The skymap is only used to get the tract region for loading the
            # reference catalog.
            self.inputs.remove("skymap")
        # NOTE: this is a truthiness test, matching the one in
        # FitStellarMotionTask.runQuantum; the two must stay in step so that
        # ``visitTable`` is present exactly when runQuantum reads it.
        if self.config.outputEpoch:
            self.inputs.remove("visitTable")
class FitStellarMotionConfig(pipeBase.PipelineTaskConfig, pipelineConnections=FitStellarMotionConnections):
    """Configuration for `FitStellarMotionTask`."""

    includeReferenceCatalog = pexConfig.Field(
        doc="Include the reference catalog in the fit.",
        dtype=bool,
        default=True,
    )
    referenceFilter = pexConfig.Field(
        dtype=str,
        # The two literals are concatenated implicitly, so the separating
        # space must be part of one of them.
        doc="Name of filter to load from reference catalog. This is a required argument, although the values "
        "returned are not used.",
        default="phot_g_mean",
    )
    referenceMatchRadius = pexConfig.Field(
        dtype=float,
        doc="Maximum matching distance in arcseconds between the star catalog and the reference catalog.",
        default=0.1,
    )
    # The task interprets this value as an MJD when building the fit epoch.
    outputEpoch = pexConfig.Field(
        dtype=float,
        doc="Epoch to which output positions will correspond. If not set, the median epoch of all visits in "
        "visitTable will be used.",
        default=None,
        optional=True,
    )
class FitStellarMotionTask(pipeBase.PipelineTask):
    """Fit proper motion and parallax for associated sources.

    Input sources are assumed to be isolated point sources.
    """

    ConfigClass = FitStellarMotionConfig
    _DefaultName = "fitStellarMotions"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Override runQuantum to set up the reference object loader and turn
        # the per-visit input lists into dicts keyed by visit id.
        inputs = butlerQC.get(inputRefs)

        inputs["inputSources"] = {
            inputSource.dataId["visit"]: inputSource for inputSource in inputs["inputSources"]
        }
        inputs["visitSummaries"] = {
            visitSummary.dataId["visit"]: visitSummary for visitSummary in inputs["visitSummaries"]
        }

        if self.config.includeReferenceCatalog:
            tractId = inputs["starCatalogRef"].dataId["tract"]
            skymap = inputs.pop("skymap")
            tractRegion = skymap.generateTract(tractId).outer_sky_polygon

            refConfig = LoadReferenceObjectsConfig()
            refConfig.requireProperMotion = True
            refObjectLoader = ReferenceObjectLoader(
                dataIds=[ref.datasetRef.dataId for ref in inputRefs.referenceCatalog],
                refCats=inputs.pop("referenceCatalog"),
                config=refConfig,
                log=self.log,
            )
        else:
            refObjectLoader = None
            tractRegion = None

        # NOTE: truthiness test, matching the one used by the connections
        # class to decide whether ``visitTable`` is an input at all.
        if self.config.outputEpoch:
            epoch = astropy.time.Time(self.config.outputEpoch, format="mjd")
        else:
            # Use the median epoch of all visits in the survey.
            visitTable = inputs.pop("visitTable")
            allVisits = visitTable.get(parameters={"columns": ["expMidptMJD"]})
            epoch = astropy.time.Time(np.median(allVisits["expMidptMJD"]), format="mjd")

        output = self.run(**inputs, epoch=epoch, refObjectLoader=refObjectLoader, tractRegion=tractRegion)

        butlerQC.put(output.outputCatalog, outputRefs.outputCatalog)
        butlerQC.put(output.predictedPositions, outputRefs.predictedPositions)

    def run(
        self,
        starSourceRef,
        inputSources,
        starCatalogRef,
        visitSummaries,
        epoch,
        refObjectLoader=None,
        tractRegion=None,
    ):
        """Fit proper motion and parallax for isolated stars.

        Parameters
        ----------
        starSourceRef : `DeferredDatasetHandle`
            Handle pointing to catalog of associated sources.
        inputSources : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of source catalog handles, keyed by their visit id.
        starCatalogRef : `DeferredDatasetHandle`
            Handle pointing to catalog of objects corresponding to associated
            sources.
        visitSummaries : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of handles to catalogs with per-detector summary
            information, keyed by their visit id.
        epoch : `astropy.time.Time`
            Epoch at which to fit positions of objects.
        refObjectLoader :
            `lsst.meas.algorithms.loadReferenceObjects.ReferenceObjectLoader`,
            optional
            Reference object loader. Used if
            ``self.config.includeReferenceCatalog`` is true.
        tractRegion : `lsst.sphgeom.Region`, optional
            Region containing the associated sources. Used if
            ``self.config.includeReferenceCatalog`` is true.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``outputCatalog`` : `astropy.table.Table`
                Catalog with position, proper motion and parallax for all
                input objects, with NAN for objects without enough data to fit
                parameters.
            ``predictedPositions`` : `astropy.table.Table`
                Catalog with predicted positions for all input sources at their
                epoch observation, with NAN for objects with insufficient data.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if the catalog of associated sources is empty, or if no
            visit has both a visit summary and a source catalog.
        """
        # Load needed columns for associated sources.
        starSources = starSourceRef.get(parameters={"columns": ["visit", "sourceId", "obj_index"]})
        if len(starSources) == 0:
            raise pipeBase.NoWorkFound("No isolated stars found in this region.")

        # The index allows source ids to be mapped back to rows of this table
        # when the predicted positions are filled in.
        starSources.add_index("sourceId")

        # Load reference objects.
        if self.config.includeReferenceCatalog:
            refCatalog = self._load_refCat(refObjectLoader, tractRegion, epoch)
        else:
            refCatalog = None

        # Load needed columns from source catalogs and get visit info.
        visitStars, visitInfo = self._load_sources(starSources, visitSummaries, inputSources)

        # Fit position, proper motion and parallax for all objects.
        outCat, predictedRADec = self._fit_objects(
            visitStars, starCatalogRef, starSources, visitInfo, epoch, refCatalog=refCatalog
        )

        return pipeBase.Struct(outputCatalog=outCat, predictedPositions=predictedRADec)

    def _load_refCat(self, refObjectLoader, region, epoch):
        """Load reference catalog.

        Parameters
        ----------
        refObjectLoader :
            `lsst.meas.algorithms.loadReferenceObjects.ReferenceObjectLoader`
            Reference object loader
        region : `lsst.sphgeom.Region`
            Region containing the associated sources.
        epoch : `astropy.time.Time`
            Epoch to which the reference catalog will be shifted.

        Returns
        -------
        refCatalog : `astropy.table.Table`
            Catalog of reference objects, with columns ``ra`` and ``dec``
            (converted to degrees), ``raPM``, ``decPM`` and ``parallax``
            (converted to milliarcseconds) and a 5x5 ``covariance`` per row.
        """
        refCat = refObjectLoader.loadRegion(region, self.config.referenceFilter, epoch=epoch).refCat
        refCat = refCat.asAstropy()

        # In Gaia DR3, missing values are denoted by NaNs.
        finiteInd = np.isfinite(refCat["coord_ra"]) & np.isfinite(refCat["coord_dec"])
        refCat = refCat[finiteInd]

        ra = (refCat["coord_ra"]).to(u.degree)
        dec = (refCat["coord_dec"]).to(u.degree)
        raPM = (refCat["pm_ra"]).to(u.marcsec)
        decPM = (refCat["pm_dec"]).to(u.marcsec)
        parallax = (refCat["parallax"]).to(u.marcsec)

        # Build the symmetric 5x5 covariance for each reference object. The
        # error and covariance values are treated as radian-based and rescaled
        # to milliarcseconds squared. Off-diagonal column names list the
        # parameters in the order given by ``positionParameters``.
        cov = np.zeros((len(refCat), 5, 5))
        positionParameters = ["coord_ra", "coord_dec", "pm_ra", "pm_dec", "parallax"]
        for i, pi in enumerate(positionParameters):
            cov[:, i, i] = ((refCat[f"{pi}Err"].value) ** 2 * u.radian**2).to_value(u.marcsec**2)
            for j in range(i):
                pj = positionParameters[j]
                offDiagonal = (refCat[f"{pj}_{pi}_Cov"].value * u.radian**2).to_value(u.marcsec**2)
                cov[:, i, j] = offDiagonal
                cov[:, j, i] = offDiagonal

        refCatalog = Table(
            {"ra": ra, "dec": dec, "raPM": raPM, "decPM": decPM, "parallax": parallax, "covariance": cov}
        )
        return refCatalog

    def _load_sources(self, starSources, visitSummaries, inputSources):
        """Load isolated sources and get visit information.

        Parameters
        ----------
        starSources : `astropy.table.Table`
            Catalog of associated sources.
        visitSummaries : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of handles to catalogs with per-detector summary
            information, keyed by their visit id.
        inputSources : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of source catalog handles, keyed by their visit id.

        Returns
        -------
        allVisitStars : `astropy.table.Table`
            Catalog with all needed information for associated sources.
        visitInfo : `astropy.table.Table`
            Catalog with observation epoch and location in ICRS coordinates,
            indexed by ``visit``.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if no visit in ``starSources`` has both a visit summary and
            a source catalog.
        """
        # np.unique returns the visit ids already sorted.
        visits = np.unique(starSources["visit"])
        observatories = []
        mjds = []
        allVisitStars = []
        finalVisits = []
        for visit in visits:
            # Skip visits for which either input is unavailable.
            if (visit not in visitSummaries) or (visit not in inputSources):
                continue

            visitSummary = visitSummaries[visit].get()
            finalVisits.append(visit)
            # The visit-level information is read from the first detector row.
            visitInfo = visitSummary[0].visitInfo

            # Get MJD
            obsDate = visitInfo.getDate()
            obsMJD = obsDate.get(obsDate.MJD)
            mjds.append(obsMJD)

            # Get the observatory ICRS position for use in fitting parallax
            obsLon = visitInfo.observatory.getLongitude().asDegrees()
            obsLat = visitInfo.observatory.getLatitude().asDegrees()
            obsElev = visitInfo.observatory.getElevation()
            earthLocation = astropy.coordinates.EarthLocation.from_geodetic(obsLon, obsLat, obsElev)
            observatory_gcrs = earthLocation.get_gcrs(astropy.time.Time(obsMJD, format="mjd"))
            observatory_icrs = observatory_gcrs.transform_to(astropy.coordinates.ICRS())
            observatory = observatory_icrs.cartesian.xyz.to(u.AU).value
            observatories.append(observatory)

            # Load sources and keep isolated ones.
            visitSources = inputSources[visit].get(
                parameters={
                    "columns": [
                        "sourceId",
                        "ra",
                        "dec",
                        "raErr",
                        "decErr",
                        "ra_dec_Cov",
                    ]
                }
            )
            visitStars = join(
                visitSources,
                starSources[starSources["visit"] == visit],
                keys="sourceId",
                join_type="inner",
            )
            allVisitStars.append(visitStars)

        if not allVisitStars:
            # vstack cannot stack an empty list; report the lack of usable
            # inputs in the way the pipeline framework expects.
            raise pipeBase.NoWorkFound(
                "No visits with both a visit summary and a source catalog were found for the isolated stars."
            )

        allVisitStars = vstack(allVisitStars)
        visitInfo = Table({"visit": finalVisits, "observatory": observatories, "mjd": mjds})
        visitInfo.add_index("visit")

        return allVisitStars, visitInfo

    def _fit_objects(self, visitStars, starCatalogRef, starSources, visitInfo, fitEpoch, refCatalog=None):
        """Fit full 5-d position, proper motion, and parallax for associated
        sources.

        Parameters
        ----------
        visitStars : `astropy.table.Table`
            Catalog with position information for associated sources.
        starCatalogRef : `DeferredDatasetHandle`
            Handle pointing to catalog of objects corresponding to associated
            sources.
        starSources : `astropy.table.Table`
            Catalog of associated sources, indexed by ``sourceId``.
        visitInfo : `astropy.table.Table`
            Catalog with observation epoch and location in ICRS coordinates,
            indexed by ``visit``.
        fitEpoch : `astropy.time.Time`
            Epoch at which to fit positions of objects.
        refCatalog : `astropy.table.Table`, optional
            Catalog of reference objects. Used if
            self.config.includeReferenceCatalog is true.

        Returns
        -------
        outCat : `astropy.table.Table`
            Catalog with position, proper motion and parallax for all input
            objects, with NAN for objects without enough data to fit
            parameters. For objects matched to a reference object, the
            ``ref_*`` columns hold the reference values used in the fit.
        predictedPositions : `astropy.table.Table`
            Catalog with predicted positions for all input sources at their
            epoch observation, with NAN for objects with insufficient data.
        """
        starCatalog = starCatalogRef.get(parameters={"columns": ["isolated_star_id", "ra", "dec"]})
        nObjects = len(starCatalog)

        useReference = self.config.includeReferenceCatalog
        hasRefMatch = None
        if useReference:
            # Find the nearest reference object for every object and flag the
            # ones that are within the configured matching radius.
            starCoord = astropy.coordinates.SkyCoord(
                starCatalog["ra"] * u.degree, starCatalog["dec"] * u.degree
            )
            refCoord = astropy.coordinates.SkyCoord(refCatalog["ra"], refCatalog["dec"])
            refId, refD2d, _ = starCoord.match_to_catalog_sky(refCoord)
            hasRefMatch = refD2d.arcsecond < self.config.referenceMatchRadius

            identity = wcsfit.IdentityMap()
            icrs = wcsfit.SphericalICRS()
            refWcs = wcsfit.Wcs(identity, icrs, "Identity", np.pi / 180.0)

        # Make empty arrays to fill in, with NaN for any unfittable objects.
        # Rows are addressed by ``obj_index``, which is used as a row number
        # into the object catalog.
        objectPositions = np.full((nObjects, 5), np.nan)
        objectCovariances = np.full((nObjects, 5, 5), np.nan)
        predictedRADec = np.full((len(starSources), 2), np.nan)
        includesReference = np.zeros(nObjects, dtype=bool)
        refPositions = Table(
            np.full((nObjects, 5), np.nan),
            names=("ref_ra", "ref_dec", "ref_raPM", "ref_decPM", "ref_parallax"),
            dtype=("f8", "f8", "f8", "f8", "f8"),
        )
        refCovariances = np.full((nObjects, 5, 5), np.nan)

        # Group the detection rows by object with a single stable sort, rather
        # than building a boolean mask over all detections for every object.
        # The stable sort keeps each object's detections in table order.
        objIndexColumn = np.asarray(visitStars["obj_index"])
        order = np.argsort(objIndexColumn, kind="stable")
        objIndices, groupStarts = np.unique(objIndexColumn[order], return_index=True)
        groupedRows = np.split(order, groupStarts[1:])

        for objIndex, rows in zip(objIndices, groupedRows):
            # Get all detections for this object.
            detections = visitStars[rows]
            nDetections = len(detections)

            hasReference = bool(useReference and hasRefMatch[objIndex])
            if not hasReference and nDetections < 3:
                # If there is no associated reference object, there must be at
                # least three detections in order to fit the 5-d solution.
                continue

            # Look up the visit rows once. ``loc_indices`` returns a bare
            # integer for a single detection, so force a 1-d index to keep the
            # per-detection arrays array-shaped.
            visitRows = np.atleast_1d(visitInfo.loc_indices[detections["visit"]])
            objectObservatories = visitInfo["observatory"][visitRows]
            objectMjds = visitInfo["mjd"][visitRows]

            # Move detections to be tangent plane around median position.
            medRA = np.median(detections["ra"])
            medDec = np.median(detections["dec"])
            tangentPoint = lsst.geom.SpherePoint(medRA, medDec, lsst.geom.degrees)
            cdMatrix = afwGeom.makeCdMatrix(1.0 * lsst.geom.degrees, 0 * lsst.geom.degrees, True)
            iwcToSkyWcs = afwGeom.makeSkyWcs(lsst.geom.Point2D(0, 0), tangentPoint, cdMatrix)
            tanX, tanY = iwcToSkyWcs.skyToPixelArray(detections["ra"], detections["dec"], degrees=True)

            match = wcsfit.PMMatch(
                tanX,
                tanY,
                detections["raErr"] ** 2,
                detections["decErr"] ** 2,
                detections["ra_dec_Cov"],
                objectMjds,
                objectObservatories,
                medRA,
                medDec,
                fitEpoch.mjd,
            )

            # Marks which entries of the per-detection prediction correspond
            # to science detections (as opposed to the reference object).
            scienceDetections = np.ones(nDetections, dtype=bool)
            if hasReference:
                refMatch = refCatalog[refId[objIndex]]
                match.addPMDetection(
                    refMatch["ra"],
                    refMatch["dec"],
                    refMatch["raPM"],
                    refMatch["decPM"],
                    refMatch["parallax"],
                    refMatch["covariance"],
                    refWcs,
                )
                # The reference object is appended after the science
                # detections and is excluded from the predicted positions.
                scienceDetections = np.append(scienceDetections, False)
                includesReference[objIndex] = True
                refPositions[objIndex] = refMatch[["ra", "dec", "raPM", "decPM", "parallax"]]
                refCovariances[objIndex] = refMatch["covariance"]

            # Solve, get best-fit position and covariance, and prediction for
            # the object position at the detection epochs.
            match.solve()
            objectPositions[objIndex] = match.getFit()
            objectCovariances[objIndex] = match.getFitCovariance()
            predictedPositions = match.predictAtDetections()
            predictedRADec[starSources.loc_indices[detections["sourceId"]]] = predictedPositions[
                scienceDetections
            ]

        outCat = Table(objectPositions, names=("ra", "dec", "raPM", "decPM", "parallax"))
        outCat["hasReference"] = includesReference
        outCat["covariance"] = objectCovariances
        outCat = hstack([outCat, refPositions])
        outCat["ref_covariance"] = refCovariances
        outCat["isolated_star_id"] = starCatalog["isolated_star_id"]
        outCat.meta["epoch"] = fitEpoch

        predictedRADec = Table(predictedRADec, names=("ra", "dec"))
        predictedRADec["sourceId"] = starSources["sourceId"]

        return outCat, predictedRADec