Coverage for python / lsst / drp / tasks / fit_stellar_motion.py: 16%
200 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-04 17:41 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-04 17:41 +0000
1# This file is part of drp_tasks.
2#
3# LSST Data Management System
4# This product includes software developed by the
5# LSST Project (http://www.lsst.org/).
6# See COPYRIGHT file at the top of the source tree.
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <https://www.lsstcorp.org/LegalNotices/>.
21#
# Names exported by ``from fit_stellar_motion import *``.
__all__ = ["FitStellarMotionConfig", "FitStellarMotionConnections", "FitStellarMotionTask", "assemble_position_covariance"]
import astropy.coordinates
import astropy.time
import astropy.units as u
import numpy as np
import wcsfit
from astropy.table import Table, hstack, join, vstack

import lsst.afw.geom as afwGeom
import lsst.geom
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.daf.butler import DatasetProvenance
from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader
from lsst.skymap import BaseSkyMap
def assemble_position_covariance(table, names=("ra", "dec", "raPM", "decPM", "parallax")):
    """Build per-row covariance matrices from error and covariance columns.

    Diagonal entries are the squares of the ``{name}Err`` columns. The
    off-diagonal entry for a pair of names is read from the
    ``{first}_{second}_Cov`` column, where ``first`` precedes ``second`` in
    ``names``, and is written to both symmetric positions.

    Parameters
    ----------
    table : `astropy.table.Table`
        Table with error and covariance columns.
    names : sequence [`str`], optional
        Names of the coordinate, proper motion, and parallax values, in the
        order used for the matrix axes.

    Returns
    -------
    covariance : `np.ndarray`
        Array of shape ``(len(table), len(names), len(names))`` holding one
        symmetric covariance matrix per table row, ordered as in ``names``.
    """
    nParams = len(names)
    covariance = np.zeros((len(table), nParams, nParams))
    for row, rowName in enumerate(names):
        # Fill the strictly-lower triangle (and mirror it) for this row first,
        # then the diagonal, matching the original column-read order.
        for col, colName in enumerate(names[:row]):
            pairCov = table[f"{colName}_{rowName}_Cov"]
            covariance[:, row, col] = pairCov
            covariance[:, col, row] = pairCov
        covariance[:, row, row] = table[f"{rowName}Err"] ** 2
    return covariance
class FitStellarMotionConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=(
        "instrument",
        "tract",
        "skymap",
    ),
):
    """Butler connections for `FitStellarMotionTask`.

    The reference catalog and skymap are dropped when
    ``config.includeReferenceCatalog`` is false, and the visit table is
    dropped when ``config.outputEpoch`` is set.
    """

    visitSummaries = pipeBase.connectionTypes.Input(
        doc=(
            "Per-visit consolidated exposure metadata built from calexps. "
            "These catalogs use detector id for the id and must be sorted for "
            "fast lookups of a detector."
        ),
        name="preliminary_visit_summary",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
        multiple=True,
        deferLoad=True,
    )
    starSourceRef = pipeBase.connectionTypes.Input(
        doc="Catalog of matched sources.",
        name="isolated_star",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
        deferLoad=True,
    )
    starCatalogRef = pipeBase.connectionTypes.Input(
        doc="Catalog of objects corresponding to the matched sources.",
        name="isolated_star_association",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
        deferLoad=True,
    )
    inputSources = pipeBase.connectionTypes.Input(
        doc="Source table in parquet format, per visit.",
        name="recalibrated_star",
        storageClass="ArrowAstropy",
        dimensions=("instrument", "visit"),
        deferLoad=True,
        multiple=True,
    )
    referenceCatalog = pipeBase.connectionTypes.PrerequisiteInput(
        doc="The astrometry reference catalog to match to loaded input catalog sources.",
        name="the_monster_20250219",
        storageClass="SimpleCatalog",
        dimensions=("skypix",),
        deferLoad=True,
        multiple=True,
    )
    skymap = pipeBase.connectionTypes.Input(
        doc="Input definition of bbox containing the associated sources.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    visitTable = pipeBase.connectionTypes.Input(
        doc="Survey-wide table of visits, which will be used to get median epoch.",
        name="preliminary_visit_table",
        storageClass="ArrowAstropy",
        dimensions=("instrument",),
        deferLoad=True,
    )
    outputCatalog = pipeBase.connectionTypes.Output(
        doc="Best fit position, proper motion and parallax for input objects.",
        name="isolated_star_stellar_motions",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
    )
    predictedPositions = pipeBase.connectionTypes.Output(
        doc="Predicted position for each source at the epoch of observation.",
        name="isolated_star_predicted_positions",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if not self.config.includeReferenceCatalog:
            # ``referenceCatalog`` is declared as a PrerequisiteInput, so it
            # must be removed from ``prerequisiteInputs``; removing it from
            # ``inputs`` (where it is not registered) raises KeyError.
            self.prerequisiteInputs.remove("referenceCatalog")
            self.inputs.remove("skymap")
        # NOTE: this is a truthiness test, so an outputEpoch of 0.0 is treated
        # the same as an unset value here.
        if self.config.outputEpoch:
            # An explicit output epoch makes the survey-wide visit table
            # unnecessary.
            self.inputs.remove("visitTable")
class FitStellarMotionConfig(pipeBase.PipelineTaskConfig, pipelineConnections=FitStellarMotionConnections):
    """Configuration for `FitStellarMotionTask`."""

    includeReferenceCatalog = pexConfig.Field(
        doc="Include the reference catalog in the fit.",
        dtype=bool,
        default=True,
    )
    referenceFilter = pexConfig.Field(
        dtype=str,
        # The two literals are concatenated, so the trailing space is needed
        # to avoid "valuesreturned".
        doc="Name of filter to load from reference catalog. This is a required argument, although the values "
        "returned are not used.",
        default="phot_g_mean",
    )
    referenceMatchRadius = pexConfig.Field(
        dtype=float,
        doc="Maximum matching distance in arcseconds between the star catalog and the reference catalog.",
        default=0.1,
    )
    # The task interprets this value as an MJD.
    outputEpoch = pexConfig.Field(
        dtype=float,
        doc="Epoch to which output positions will correspond. If not set, the median epoch of all visits in "
        "visitTable will be used.",
        default=None,
        optional=True,
    )
    positionNames = pexConfig.ListField(
        dtype=str,
        default=["ra", "dec", "raPM", "decPM", "parallax"],
        doc="Names of position, proper motion, and parallax columns.",
    )
class FitStellarMotionTask(pipeBase.PipelineTask):
    """Fit proper motion and parallax for associated sources.

    Input sources are assumed to be isolated point sources.
    """

    ConfigClass = FitStellarMotionConfig
    _DefaultName = "fitStellarMotions"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Override runQuantum to set up the refObjLoaders and turn input lists
        # into dicts keyed by visit id.
        inputs = butlerQC.get(inputRefs)

        inputs["inputSources"] = {handle.dataId["visit"]: handle for handle in inputs["inputSources"]}
        inputs["visitSummaries"] = {handle.dataId["visit"]: handle for handle in inputs["visitSummaries"]}

        if self.config.includeReferenceCatalog:
            tractId = inputs["starCatalogRef"].dataId["tract"]
            skymap = inputs.pop("skymap")
            tractRegion = skymap.generateTract(tractId).outer_sky_polygon

            refConfig = LoadReferenceObjectsConfig()
            refConfig.requireProperMotion = True
            refObjectLoader = ReferenceObjectLoader(
                dataIds=[ref.datasetRef.dataId for ref in inputRefs.referenceCatalog],
                refCats=inputs.pop("referenceCatalog"),
                config=refConfig,
                log=self.log,
            )
        else:
            refObjectLoader = None
            tractRegion = None

        # The visit table is only needed when no explicit epoch is configured.
        # Always pop it (it may be absent) so it is never forwarded to ``run``.
        visitTable = inputs.pop("visitTable", None)
        # Compare against None so that an explicit epoch of 0.0 is honored.
        if self.config.outputEpoch is not None:
            epoch = astropy.time.Time(self.config.outputEpoch, format="mjd")
        else:
            # Use the median epoch of all visits in the survey.
            allVisits = visitTable.get(parameters={"columns": ["expMidptMJD"]})
            epoch = astropy.time.Time(np.median(allVisits["expMidptMJD"]), format="mjd")

        output = self.run(**inputs, epoch=epoch, refObjectLoader=refObjectLoader, tractRegion=tractRegion)

        butlerQC.put(output.outputCatalog, outputRefs.outputCatalog)
        butlerQC.put(output.predictedPositions, outputRefs.predictedPositions)

    def run(
        self,
        starSourceRef,
        inputSources,
        starCatalogRef,
        visitSummaries,
        epoch,
        refObjectLoader=None,
        tractRegion=None,
    ):
        """Fit proper motion and parallax for isolated stars.

        Parameters
        ----------
        starSourceRef : `DeferredDatasetHandle`
            Handle pointing to catalog of associated sources.
        inputSources : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of source catalog handles, keyed by their visit id.
        starCatalogRef : `DeferredDatasetHandle`
            Handle pointing to catalog of objects corresponding to associated
            sources.
        visitSummaries : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of handles to catalogs with per-detector summary
            information, keyed by their visit id.
        epoch : `astropy.time.Time`
            Epoch at which to fit positions of objects.
        refObjectLoader :
            `lsst.meas.algorithms.loadReferenceObjects.ReferenceObjectLoader`,
            optional
            Reference object loader. Required if
            ``self.config.includeReferenceCatalog`` is true.
        tractRegion : `lsst.sphgeom.Region`, optional
            Region containing the associated sources.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``outputCatalog`` : `astropy.table.Table`
                Catalog with position, proper motion and parallax for all
                input objects, with NAN for objects without enough data to fit
                parameters.
            ``predictedPositions`` : `astropy.table.Table`
                Catalog with predicted positions for all input sources at
                their epoch of observation, with NAN for objects with
                insufficient data.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if there are no isolated stars, or if none of their visits
            have both a visit summary and a source catalog.
        """
        # Load needed columns for associated sources.
        starSources = starSourceRef.get(parameters={"columns": ["visit", "sourceId", "obj_index"]})
        if len(starSources) == 0:
            raise pipeBase.NoWorkFound("No isolated stars found in this region.")

        DatasetProvenance.strip_provenance_from_flat_dict(starSources.meta)
        starSources.add_index("sourceId")

        # Load reference objects.
        if self.config.includeReferenceCatalog:
            refCatalog = self._load_refCat(refObjectLoader, tractRegion, epoch)
        else:
            refCatalog = None

        # Load needed columns from source catalogs and get visit info.
        visitStars, visitInfo = self._load_sources(starSources, visitSummaries, inputSources)

        # Fit position, proper motion and parallax for all objects.
        outCat, predictedRADec = self._fit_objects(
            visitStars, starCatalogRef, starSources, visitInfo, epoch, refCatalog=refCatalog
        )

        return pipeBase.Struct(outputCatalog=outCat, predictedPositions=predictedRADec)

    def _load_refCat(self, refObjectLoader, region, epoch):
        """Load reference catalog.

        Parameters
        ----------
        refObjectLoader :
            `lsst.meas.algorithms.loadReferenceObjects.ReferenceObjectLoader`
            Reference object loader.
        region : `lsst.sphgeom.Region`
            Region containing the associated sources.
        epoch : `astropy.time.Time`
            Epoch to which the reference catalog will be shifted.

        Returns
        -------
        refCatalog : `astropy.table.Table`
            Catalog of reference objects with ``id``, ``ra``, ``dec`` (deg),
            ``raPM``, ``decPM``, ``parallax`` (mas) and a 5x5 ``covariance``
            column (mas^2).
        """
        refCat = refObjectLoader.loadRegion(region, self.config.referenceFilter, epoch=epoch).refCat
        refCat = refCat.asAstropy()

        # In Gaia DR3, missing values are denoted by NaNs.
        finiteInd = np.isfinite(refCat["coord_ra"]) & np.isfinite(refCat["coord_dec"])
        refCat = refCat[finiteInd]

        ra = refCat["coord_ra"].to(u.degree)
        dec = refCat["coord_dec"].to(u.degree)
        raPM = refCat["pm_ra"].to(u.marcsec)
        decPM = refCat["pm_dec"].to(u.marcsec)
        parallax = refCat["parallax"].to(u.marcsec)

        # Covariance values are interpreted as radians^2 and converted to
        # mas^2 with a single scale factor.
        radToMas2 = (u.radian**2).to(u.marcsec**2)
        cov = (
            assemble_position_covariance(
                refCat, names=("coord_ra", "coord_dec", "pm_ra", "pm_dec", "parallax")
            )
            * radToMas2
        )
        refCatalog = Table(
            {
                "id": refCat["id"],
                "ra": ra,
                "dec": dec,
                "raPM": raPM,
                "decPM": decPM,
                "parallax": parallax,
                "covariance": cov,
            }
        )
        return refCatalog

    def _load_sources(self, starSources, visitSummaries, inputSources):
        """Load isolated sources and get visit information.

        Parameters
        ----------
        starSources : `astropy.table.Table`
            Catalog of associated sources.
        visitSummaries : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of handles to catalogs with per-detector summary
            information keyed by their visit id.
        inputSources : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of source catalog handles, keyed by their visit id.

        Returns
        -------
        allVisitStars : `astropy.table.Table`
            Catalog with all needed information for associated sources.
        visitInfo : `astropy.table.Table`
            Catalog with observation epoch and location in ICRS coordinates,
            indexed by visit.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if no visit has both a visit summary and a source catalog.
        """
        observatories = []
        mjds = []
        allVisitStars = []
        finalVisits = []
        # np.unique already returns the visits in sorted order.
        for visit in np.unique(starSources["visit"]):
            if (visit not in visitSummaries) or (visit not in inputSources):
                continue

            visitSummary = visitSummaries[visit].get()
            finalVisits.append(visit)
            exposureVisitInfo = visitSummary[0].visitInfo

            # Get MJD
            obsDate = exposureVisitInfo.getDate()
            obsMJD = obsDate.get(obsDate.MJD)
            mjds.append(obsMJD)

            # Get the observatory ICRS position for use in fitting parallax.
            site = exposureVisitInfo.observatory
            earthLocation = astropy.coordinates.EarthLocation.from_geodetic(
                site.getLongitude().asDegrees(),
                site.getLatitude().asDegrees(),
                site.getElevation(),
            )
            observatory_gcrs = earthLocation.get_gcrs(astropy.time.Time(obsMJD, format="mjd"))
            observatory_icrs = observatory_gcrs.transform_to(astropy.coordinates.ICRS())
            observatories.append(observatory_icrs.cartesian.xyz.to(u.AU).value)

            # Load sources and keep isolated ones.
            visitSources = inputSources[visit].get(
                parameters={
                    "columns": [
                        "sourceId",
                        "ra",
                        "dec",
                        "raErr",
                        "decErr",
                        "ra_dec_Cov",
                    ]
                }
            )
            DatasetProvenance.strip_provenance_from_flat_dict(visitSources.meta)
            visitStars = join(
                visitSources,
                starSources[starSources["visit"] == visit],
                keys="sourceId",
                join_type="inner",
            )
            allVisitStars.append(visitStars)

        # vstack of an empty list raises an unhelpful ValueError.
        if not allVisitStars:
            raise pipeBase.NoWorkFound(
                "No visits with both a visit summary and a source catalog for the isolated stars."
            )

        allVisitStars = vstack(allVisitStars)
        visitInfo = Table({"visit": finalVisits, "observatory": observatories, "mjd": mjds})
        visitInfo.add_index("visit")

        return allVisitStars, visitInfo

    def _fit_objects(self, visitStars, starCatalogRef, starSources, visitInfo, fitEpoch, refCatalog=None):
        """Fit full 5-d position, proper motion, and parallax for associated
        sources.

        Parameters
        ----------
        visitStars : `astropy.table.Table`
            Catalog with position information for associated sources.
        starCatalogRef : `DeferredDatasetHandle`
            Handle pointing to catalog of objects corresponding to associated
            sources.
        starSources : `astropy.table.Table`
            Catalog of associated sources, indexed on ``sourceId``.
        visitInfo : `astropy.table.Table`
            Catalog with observation epoch and location in ICRS coordinates.
        fitEpoch : `astropy.time.Time`
            Epoch at which to fit positions of objects.
        refCatalog : `astropy.table.Table`, optional
            Catalog of reference objects. Used if
            self.config.includeReferenceCatalog is true and it is non-empty.

        Returns
        -------
        outCat : `astropy.table.Table`
            Catalog with position, proper motion and parallax for all input
            objects, with NAN for objects without enough data to fit
            parameters.
        predictedPositions : `astropy.table.Table`
            Catalog with predicted positions for all input sources at their
            epoch of observation, with NAN for objects with insufficient data.
        """
        starCatalog = starCatalogRef.get(parameters={"columns": ["isolated_star_id", "ra", "dec"]})

        # Sky matching against an empty reference catalog raises, so only
        # match when there is at least one reference object.
        useReference = (
            self.config.includeReferenceCatalog and refCatalog is not None and len(refCatalog) > 0
        )
        if useReference:
            starCoord = astropy.coordinates.SkyCoord(
                starCatalog["ra"] * u.degree, starCatalog["dec"] * u.degree
            )
            refCoord = astropy.coordinates.SkyCoord(refCatalog["ra"], refCatalog["dec"])
            refId, refD2d, _ = starCoord.match_to_catalog_sky(refCoord)

        identity = wcsfit.IdentityMap()
        icrs = wcsfit.SphericalICRS()
        refWcs = wcsfit.Wcs(identity, icrs, "Identity", np.pi / 180.0)

        # The tangent-plane CD matrix is the same for every object; only the
        # tangent point changes, so build it once.
        cdMatrix = afwGeom.makeCdMatrix(1.0 * lsst.geom.degrees, 0 * lsst.geom.degrees, True)

        # Make empty arrays to fill in, with NaN for any unfittable objects.
        objectPositions = np.full((len(starCatalog), 5), np.nan)
        objectCovariances = np.full((len(starCatalog), 5, 5), np.nan)
        predictedRADec = np.full((len(starSources), 2), np.nan)
        referenceId = np.zeros(len(starCatalog), dtype=int)
        refPositions = Table(
            np.full((len(starCatalog), 5), np.nan),
            names=("ref_ra", "ref_dec", "ref_raPM", "ref_decPM", "ref_parallax"),
            dtype=("f8", "f8", "f8", "f8", "f8"),
        )
        # np.unique returns sorted object indices.
        for objIndex in np.unique(visitStars["obj_index"]):
            # Get all detections for this object.
            detections = visitStars[visitStars["obj_index"] == objIndex]
            nDetections = len(detections)
            scienceDetections = np.ones(nDetections, dtype=bool)

            detectionVisitInfo = visitInfo.loc[detections["visit"]]
            objectObservatories = detectionVisitInfo["observatory"]
            objectMjds = detectionVisitInfo["mjd"]

            # Move detections to be tangent plane around median position.
            # RA offsets are taken relative to the first detection and wrapped
            # into [-180, 180) so objects straddling RA = 0/360 get a median
            # near their true position rather than ~180 degrees away.
            raValues = np.asarray(detections["ra"], dtype=float)
            ra0 = raValues[0]
            deltaRA = (raValues - ra0 + 180.0) % 360.0 - 180.0
            medRA = (ra0 + np.median(deltaRA)) % 360.0
            medDec = np.median(detections["dec"])
            tangentPoint = lsst.geom.SpherePoint(medRA, medDec, lsst.geom.degrees)
            iwcToSkyWcs = afwGeom.makeSkyWcs(lsst.geom.Point2D(0, 0), tangentPoint, cdMatrix)
            tanX, tanY = iwcToSkyWcs.skyToPixelArray(detections["ra"], detections["dec"], degrees=True)

            match = wcsfit.PMMatch(
                tanX,
                tanY,
                detections["raErr"] ** 2,
                detections["decErr"] ** 2,
                detections["ra_dec_Cov"],
                objectMjds,
                objectObservatories,
                medRA,
                medDec,
                fitEpoch.mjd,
            )

            if useReference and (refD2d[objIndex].arcsecond < self.config.referenceMatchRadius):
                nDetections += 1
                refMatch = refCatalog[refId[objIndex]]
                match.addPMDetection(
                    refMatch["ra"],
                    refMatch["dec"],
                    refMatch["raPM"],
                    refMatch["decPM"],
                    refMatch["parallax"],
                    refMatch["covariance"],
                    refWcs,
                )
                # The reference detection is not a science source, so it is
                # excluded when storing per-source predictions below.
                scienceDetections = np.append(scienceDetections, False)
                referenceId[objIndex] = refMatch["id"]
                refPositions[objIndex] = refMatch[["ra", "dec", "raPM", "decPM", "parallax"]]

            elif nDetections < 3:
                # If there is no associated reference object, there must be at
                # least three detections in order to fit the 5-d solution.
                continue

            # Solve, get best-fit position and covariance, and prediction for
            # the object position at the detection epochs.
            match.solve()
            objectPositions[objIndex] = match.getFit()
            objectCovariances[objIndex] = match.getFitCovariance()
            predictedPositions = match.predictAtDetections()
            predictedRADec[starSources.loc_indices[detections["sourceId"]]] = predictedPositions[
                scienceDetections
            ]

        outCat = Table(objectPositions, names=self.config.positionNames)
        outCat["referenceId"] = referenceId
        for i, name1 in enumerate(self.config.positionNames):
            for j, name2 in enumerate(self.config.positionNames[: i + 1]):
                if i == j:
                    outCat[f"{name1}Err"] = objectCovariances[:, i, i] ** 0.5
                else:
                    outCat[f"{name2}_{name1}_Cov"] = objectCovariances[:, i, j]

        outCat = hstack([outCat, refPositions])
        outCat["isolated_star_id"] = starCatalog["isolated_star_id"]
        outCat["epoch"] = fitEpoch.mjd

        predictedRADec = Table(predictedRADec, names=("ra", "dec"))
        predictedRADec["sourceId"] = starSources["sourceId"]

        return outCat, predictedRADec