Coverage for python / lsst / drp / tasks / fit_stellar_motion.py: 17%

182 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 00:19 +0000

1# This file is part of drp_tasks. 

2# 

3# LSST Data Management System 

4# This product includes software developed by the 

5# LSST Project (http://www.lsst.org/). 

6# See COPYRIGHT file at the top of the source tree. 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <https://www.lsstcorp.org/LegalNotices/>. 

21# 

22 

# Public API of this module.
__all__ = [
    "FitStellarMotionConfig",
    "FitStellarMotionConnections",
    "FitStellarMotionTask",
]

24 

import astropy.coordinates
import astropy.time
import astropy.units as u
import numpy as np
import wcsfit
from astropy.table import Table, hstack, join, vstack

import lsst.afw.geom as afwGeom
import lsst.geom
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader
from lsst.skymap import BaseSkyMap

37 

38 

class FitStellarMotionConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=(
        "instrument",
        "tract",
        "skymap",
    ),
):
    """Connections for `FitStellarMotionTask`.

    The ``referenceCatalog`` and ``skymap`` connections are removed when
    ``config.includeReferenceCatalog`` is false, and the ``visitTable``
    connection is removed when ``config.outputEpoch`` is set.
    """

    visitSummaries = pipeBase.connectionTypes.Input(
        doc=(
            "Per-visit consolidated exposure metadata built from calexps. "
            "These catalogs use detector id for the id and must be sorted for "
            "fast lookups of a detector."
        ),
        name="preliminary_visit_summary",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
        multiple=True,
        deferLoad=True,
    )
    starSourceRef = pipeBase.connectionTypes.Input(
        doc="Catalog of matched sources.",
        name="isolated_star",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
        deferLoad=True,
    )
    starCatalogRef = pipeBase.connectionTypes.Input(
        doc="Catalog of objects corresponding to the matched sources.",
        name="isolated_star_association",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
        deferLoad=True,
    )
    inputSources = pipeBase.connectionTypes.Input(
        doc="Source table in parquet format, per visit.",
        name="recalibrated_star",
        storageClass="ArrowAstropy",
        dimensions=("instrument", "visit"),
        deferLoad=True,
        multiple=True,
    )
    referenceCatalog = pipeBase.connectionTypes.PrerequisiteInput(
        doc="The astrometry reference catalog to match to loaded input catalog sources.",
        name="the_monster_20250219",
        storageClass="SimpleCatalog",
        dimensions=("skypix",),
        deferLoad=True,
        multiple=True,
    )
    skymap = pipeBase.connectionTypes.Input(
        doc="Input definition of bbox containing the associated sources.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    visitTable = pipeBase.connectionTypes.Input(
        doc="Survey-wide table of visits, which will be used to get median epoch.",
        name="preliminary_visit_table",
        storageClass="ArrowAstropy",
        dimensions=("instrument",),
        deferLoad=True,
    )
    outputCatalog = pipeBase.connectionTypes.Output(
        doc="Best fit position, proper motion and parallax for input objects.",
        name="isolated_star_stellar_motions",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
    )
    predictedPositions = pipeBase.connectionTypes.Output(
        doc="Predicted position for each source at the epoch of observation.",
        name="isolated_star_predicted_positions",
        storageClass="ArrowAstropy",
        dimensions=(
            "instrument",
            "skymap",
            "tract",
        ),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if not self.config.includeReferenceCatalog:
            # ``referenceCatalog`` is declared as a PrerequisiteInput, so it is
            # tracked in ``prerequisiteInputs`` rather than ``inputs``;
            # removing it from ``inputs`` would raise a KeyError.
            self.prerequisiteInputs.remove("referenceCatalog")
            # The skymap is only used to build the tract region for the
            # reference catalog query.
            self.inputs.remove("skymap")
        if self.config.outputEpoch:
            # A fixed output epoch makes the survey-wide visit table (used for
            # the median epoch) unnecessary.
            self.inputs.remove("visitTable")

139 

140 

class FitStellarMotionConfig(pipeBase.PipelineTaskConfig, pipelineConnections=FitStellarMotionConnections):
    """Configuration for `FitStellarMotionTask`."""

    includeReferenceCatalog = pexConfig.Field(
        doc="Include the reference catalog in the fit.",
        dtype=bool,
        default=True,
    )
    referenceFilter = pexConfig.Field(
        dtype=str,
        doc=(
            "Name of filter to load from reference catalog. This is a required argument, although "
            "the values returned are not used."
        ),
        default="phot_g_mean",
    )
    referenceMatchRadius = pexConfig.Field(
        dtype=float,
        doc="Maximum matching distance in arcseconds between the star catalog and the reference catalog.",
        default=0.1,
    )
    outputEpoch = pexConfig.Field(
        dtype=float,
        doc="Epoch to which output positions will correspond. If not set, the median epoch of all visits in "
        "visitTable will be used.",
        default=None,
        optional=True,
    )

165 

166 

class FitStellarMotionTask(pipeBase.PipelineTask):
    """Fit proper motion and parallax for associated sources.

    Input sources are assumed to be isolated point sources. For every isolated
    star with enough detections (or with a matched reference object), a 5-d
    solution (position, proper motion and parallax) is fit at a common epoch
    with `wcsfit.PMMatch`.
    """

    ConfigClass = FitStellarMotionConfig
    _DefaultName = "fitStellarMotions"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Override runQuantum to set up the refObjectLoader, determine the
        # output epoch, and turn the per-visit input lists into dicts keyed by
        # visit id.
        inputs = butlerQC.get(inputRefs)

        inputs["inputSources"] = {handle.dataId["visit"]: handle for handle in inputs["inputSources"]}
        inputs["visitSummaries"] = {handle.dataId["visit"]: handle for handle in inputs["visitSummaries"]}

        if self.config.includeReferenceCatalog:
            tractId = inputs["starCatalogRef"].dataId["tract"]
            skymap = inputs.pop("skymap")
            tractRegion = skymap.generateTract(tractId).outer_sky_polygon

            refConfig = LoadReferenceObjectsConfig()
            refConfig.requireProperMotion = True
            refObjectLoader = ReferenceObjectLoader(
                dataIds=[ref.datasetRef.dataId for ref in inputRefs.referenceCatalog],
                refCats=inputs.pop("referenceCatalog"),
                config=refConfig,
                log=self.log,
            )
        else:
            refObjectLoader = None
            tractRegion = None

        # Always take the visit table out of ``inputs`` (it is absent when the
        # connections dropped it) so that it is never forwarded to ``run``.
        visitTableHandle = inputs.pop("visitTable", None)
        if self.config.outputEpoch is not None:
            epoch = astropy.time.Time(self.config.outputEpoch, format="mjd")
        else:
            # Use the median epoch of all visits in the survey.
            allVisits = visitTableHandle.get(parameters={"columns": ["expMidptMJD"]})
            epoch = astropy.time.Time(np.median(allVisits["expMidptMJD"]), format="mjd")

        output = self.run(**inputs, epoch=epoch, refObjectLoader=refObjectLoader, tractRegion=tractRegion)

        butlerQC.put(output.outputCatalog, outputRefs.outputCatalog)
        butlerQC.put(output.predictedPositions, outputRefs.predictedPositions)

    def run(
        self,
        starSourceRef,
        inputSources,
        starCatalogRef,
        visitSummaries,
        epoch,
        refObjectLoader=None,
        tractRegion=None,
    ):
        """Fit proper motion and parallax for isolated stars.

        Parameters
        ----------
        starSourceRef : `DeferredDatasetHandle`
            Handle pointing to catalog of associated sources.
        inputSources : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of source catalog handles, keyed by their visit id.
        starCatalogRef : `DeferredDatasetHandle`
            Handle pointing to catalog of objects corresponding to associated
            sources.
        visitSummaries : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of handles to per-detector visit summary catalogs
            (`lsst.afw.table.ExposureCatalog`), keyed by their visit id.
        epoch : `astropy.time.Time`
            Epoch at which to fit positions of objects.
        refObjectLoader : \
                `lsst.meas.algorithms.loadReferenceObjects.ReferenceObjectLoader`, \
                optional
            Reference object loader. Required when
            ``config.includeReferenceCatalog`` is true.
        tractRegion : `lsst.sphgeom.Region`, optional
            Region containing the associated sources. Required when
            ``config.includeReferenceCatalog`` is true.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``outputCatalog`` : `astropy.table.Table`
                Catalog with position, proper motion and parallax for all input
                objects, with NaN for objects without enough data to fit
                parameters.
            ``predictedPositions`` : `astropy.table.Table`
                Catalog with predicted positions for all input sources at their
                epoch of observation, with NaN for objects with insufficient
                data.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if there are no isolated stars, or none of their visits
            have both a visit summary and a source catalog.
        ValueError
            Raised if the reference catalog is requested but
            ``refObjectLoader`` or ``tractRegion`` is missing.
        """
        # Load needed columns for associated sources.
        starSources = starSourceRef.get(parameters={"columns": ["visit", "sourceId", "obj_index"]})
        if len(starSources) == 0:
            raise pipeBase.NoWorkFound("No isolated stars found in this region.")

        starSources.add_index("sourceId")

        # Load reference objects.
        if self.config.includeReferenceCatalog:
            if refObjectLoader is None or tractRegion is None:
                raise ValueError(
                    "refObjectLoader and tractRegion are required when includeReferenceCatalog is True."
                )
            refCatalog = self._load_refCat(refObjectLoader, tractRegion, epoch)
        else:
            refCatalog = None

        # Load needed columns from source catalogs and get visit info.
        visitStars, visitInfo = self._load_sources(starSources, visitSummaries, inputSources)

        # Fit position, proper motion and parallax for all objects.
        outCat, predictedRADec = self._fit_objects(
            visitStars, starCatalogRef, starSources, visitInfo, epoch, refCatalog=refCatalog
        )

        return pipeBase.Struct(outputCatalog=outCat, predictedPositions=predictedRADec)

    def _load_refCat(self, refObjectLoader, region, epoch):
        """Load reference catalog.

        Parameters
        ----------
        refObjectLoader : \
                `lsst.meas.algorithms.loadReferenceObjects.ReferenceObjectLoader`
            Reference object loader.
        region : `lsst.sphgeom.Region`
            Region containing the associated sources.
        epoch : `astropy.time.Time`
            Epoch to which the reference catalog will be shifted.

        Returns
        -------
        refCatalog : `astropy.table.Table`
            Catalog of reference objects with ``ra``/``dec`` (degrees),
            ``raPM``/``decPM``/``parallax`` (mas) and a 5x5 ``covariance``
            (mas^2) per object.
        """
        refCat = refObjectLoader.loadRegion(region, self.config.referenceFilter, epoch=epoch).refCat
        refCat = refCat.asAstropy()

        # In Gaia DR3, missing values are denoted by NaNs.
        finiteInd = np.isfinite(refCat["coord_ra"]) & np.isfinite(refCat["coord_dec"])
        refCat = refCat[finiteInd]

        ra = refCat["coord_ra"].to(u.degree)
        dec = refCat["coord_dec"].to(u.degree)
        raPM = refCat["pm_ra"].to(u.marcsec)
        decPM = refCat["pm_dec"].to(u.marcsec)
        parallax = refCat["parallax"].to(u.marcsec)

        # Build the symmetric 5x5 covariance matrix. The catalog errors and
        # covariances are treated as radians and converted to mas; off-diagonal
        # terms are stored in the catalog as "{earlier}_{later}_Cov".
        positionParameters = ["coord_ra", "coord_dec", "pm_ra", "pm_dec", "parallax"]
        nParams = len(positionParameters)
        cov = np.zeros((len(refCat), nParams, nParams))
        for i, pi in enumerate(positionParameters):
            cov[:, i, i] = (refCat[f"{pi}Err"].value ** 2 * u.radian**2).to_value(u.marcsec**2)
            for j in range(i + 1, nParams):
                pj = positionParameters[j]
                offDiagonal = (refCat[f"{pi}_{pj}_Cov"].value * u.radian**2).to_value(u.marcsec**2)
                cov[:, i, j] = offDiagonal
                cov[:, j, i] = offDiagonal

        refCatalog = Table(
            {"ra": ra, "dec": dec, "raPM": raPM, "decPM": decPM, "parallax": parallax, "covariance": cov}
        )
        return refCatalog

    def _load_sources(self, starSources, visitSummaries, inputSources):
        """Load isolated sources and get visit information.

        Parameters
        ----------
        starSources : `astropy.table.Table`
            Catalog of associated sources.
        visitSummaries : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of handles to per-detector visit summary catalogs,
            keyed by their visit id.
        inputSources : `dict` [`int`, `DeferredDatasetHandle`]
            Dictionary of source catalog handles, keyed by their visit id.

        Returns
        -------
        allVisitStars : `astropy.table.Table`
            Catalog with all needed information for associated sources.
        visitInfo : `astropy.table.Table`
            Catalog with one row per usable visit: ``visit``, ``mjd`` and
            ``observatory`` (ICRS cartesian position in AU), indexed on
            ``visit``.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if none of the visits in ``starSources`` has both a visit
            summary and a source catalog.
        """
        observatories = []
        mjds = []
        allVisitStars = []
        finalVisits = []
        # np.unique returns the visits already sorted.
        for visit in np.unique(starSources["visit"]):
            if (visit not in visitSummaries) or (visit not in inputSources):
                continue

            visitSummary = visitSummaries[visit].get()
            finalVisits.append(visit)
            obsInfo = visitSummary[0].visitInfo

            # Get MJD
            obsDate = obsInfo.getDate()
            obsMJD = obsDate.get(obsDate.MJD)
            mjds.append(obsMJD)

            # Get the observatory ICRS position for use in fitting parallax
            obsLon = obsInfo.observatory.getLongitude().asDegrees()
            obsLat = obsInfo.observatory.getLatitude().asDegrees()
            obsElev = obsInfo.observatory.getElevation()
            earthLocation = astropy.coordinates.EarthLocation.from_geodetic(obsLon, obsLat, obsElev)
            observatory_gcrs = earthLocation.get_gcrs(astropy.time.Time(obsMJD, format="mjd"))
            observatory_icrs = observatory_gcrs.transform_to(astropy.coordinates.ICRS())
            observatories.append(observatory_icrs.cartesian.xyz.to(u.AU).value)

            # Load sources and keep only those that are associated isolated
            # stars in this visit.
            visitSources = inputSources[visit].get(
                parameters={
                    "columns": [
                        "sourceId",
                        "ra",
                        "dec",
                        "raErr",
                        "decErr",
                        "ra_dec_Cov",
                    ]
                }
            )
            visitStars = join(
                visitSources,
                starSources[starSources["visit"] == visit],
                keys="sourceId",
                join_type="inner",
            )
            allVisitStars.append(visitStars)

        if not allVisitStars:
            # vstack cannot stack an empty list of tables.
            raise pipeBase.NoWorkFound(
                "No visit has both a visit summary and a source catalog for the isolated stars."
            )

        allVisitStars = vstack(allVisitStars)
        visitInfo = Table({"visit": finalVisits, "observatory": observatories, "mjd": mjds})
        visitInfo.add_index("visit")

        return allVisitStars, visitInfo

    @staticmethod
    def _median_ra(ra):
        """Return the median of right ascensions in degrees, handling the
        0/360 degree wrap.

        Parameters
        ----------
        ra : array-like of `float`
            Non-empty sequence of right ascensions in degrees.

        Returns
        -------
        medianRA : `float`
            Median right ascension in degrees, in the range [0, 360).
        """
        ra = np.asarray(ra, dtype=float)
        # Offsets from the first value, wrapped into [-180, 180), so that
        # detections straddling RA=0 do not produce a median near 180.
        offsets = (ra - ra[0] + 180.0) % 360.0 - 180.0
        return (ra[0] + np.median(offsets)) % 360.0

    def _fit_objects(self, visitStars, starCatalogRef, starSources, visitInfo, fitEpoch, refCatalog=None):
        """Fit full 5-d position, proper motion, and parallax for associated
        sources.

        Parameters
        ----------
        visitStars : `astropy.table.Table`
            Catalog with position information for associated sources.
        starCatalogRef : `DeferredDatasetHandle`
            Handle pointing to catalog of objects corresponding to associated
            sources.
        starSources : `astropy.table.Table`
            Catalog of associated sources, indexed on ``sourceId``.
        visitInfo : `astropy.table.Table`
            Catalog with observation epoch and location in ICRS coordinates,
            indexed on ``visit``.
        fitEpoch : `astropy.time.Time`
            Epoch at which to fit positions of objects.
        refCatalog : `astropy.table.Table`, optional
            Catalog of reference objects. Used if
            self.config.includeReferenceCatalog is true.

        Returns
        -------
        outCat : `astropy.table.Table`
            Catalog with position, proper motion and parallax for all input
            objects, with NaN for objects without enough data to fit
            parameters. Also contains the matched reference values
            (``ref_*`` columns), ``hasReference``, ``nSources`` (number of
            detections used, including any reference object) and
            ``isolated_star_id``.
        predictedPositions : `astropy.table.Table`
            Catalog with predicted positions for all input sources at their
            epoch of observation, with NaN for objects with insufficient data.
        """
        starCatalog = starCatalogRef.get(parameters={"columns": ["isolated_star_id", "ra", "dec"]})
        nObjects = len(starCatalog)

        # match_to_catalog_sky cannot match against an empty catalog, so only
        # use the reference catalog when it actually contains objects.
        useReference = (
            self.config.includeReferenceCatalog and refCatalog is not None and len(refCatalog) > 0
        )
        if useReference:
            starCoord = astropy.coordinates.SkyCoord(
                starCatalog["ra"] * u.degree, starCatalog["dec"] * u.degree
            )
            refCoord = astropy.coordinates.SkyCoord(refCatalog["ra"], refCatalog["dec"])
            refId, refD2d, _ = starCoord.match_to_catalog_sky(refCoord)

            identity = wcsfit.IdentityMap()
            icrs = wcsfit.SphericalICRS()
            refWcs = wcsfit.Wcs(identity, icrs, "Identity", np.pi / 180.0)
        matchRadius = self.config.referenceMatchRadius

        # Make arrays to fill in, with NaN for any unfittable objects.
        objectPositions = np.full((nObjects, 5), np.nan)
        objectCovariances = np.full((nObjects, 5, 5), np.nan)
        predictedRADec = np.full((len(starSources), 2), np.nan)
        includesReference = np.zeros(nObjects, dtype=bool)
        nSources = np.zeros(nObjects, dtype=int)
        refPositions = Table(
            np.full((nObjects, 5), np.nan),
            names=("ref_ra", "ref_dec", "ref_raPM", "ref_decPM", "ref_parallax"),
            dtype=("f8", "f8", "f8", "f8", "f8"),
        )
        refCovariances = np.full((nObjects, 5, 5), np.nan)

        # Tangent-plane projection with a scale of 1 degree per pixel, so the
        # projected coordinates are in degrees; only the center varies per
        # object.
        cdMatrix = afwGeom.makeCdMatrix(1.0 * lsst.geom.degrees, 0 * lsst.geom.degrees, True)

        # Per-visit lookup arrays, indexed by the row number of ``visitInfo``.
        visitObservatories = np.asarray(visitInfo["observatory"])
        visitMjds = np.asarray(visitInfo["mjd"])

        # Group detections by object with a single stable sort instead of
        # scanning the whole catalog once per object. The stable sort keeps
        # each object's detections in their original catalog order.
        objIndices = np.asarray(visitStars["obj_index"])
        order = np.argsort(objIndices, kind="stable")
        objects, groupStarts = np.unique(objIndices[order], return_index=True)
        groupEnds = np.append(groupStarts[1:], len(order))

        for objIndex, start, end in zip(objects, groupStarts, groupEnds):
            # Get all detections for this object.
            detections = visitStars[order[start:end]]
            nDetections = len(detections)

            hasReference = useReference and (refD2d[objIndex].arcsecond < matchRadius)
            if not hasReference and nDetections < 3:
                # If there is no associated reference object, there must be at
                # least three detections in order to fit the 5-d solution.
                continue

            # astropy's loc/loc_indices return a bare scalar when a single row
            # matches; atleast_1d keeps the per-detection arrays 1-d so the
            # observatory positions always have shape (nDetections, 3).
            visitRows = np.atleast_1d(visitInfo.loc_indices[detections["visit"]])
            objectObservatories = visitObservatories[visitRows]
            objectMjds = visitMjds[visitRows]

            # Move detections to be tangent plane around median position.
            medRA = self._median_ra(detections["ra"])
            medDec = np.median(detections["dec"])
            tangentPoint = lsst.geom.SpherePoint(medRA, medDec, lsst.geom.degrees)
            iwcToSkyWcs = afwGeom.makeSkyWcs(lsst.geom.Point2D(0, 0), tangentPoint, cdMatrix)
            tanX, tanY = iwcToSkyWcs.skyToPixelArray(detections["ra"], detections["dec"], degrees=True)

            match = wcsfit.PMMatch(
                tanX,
                tanY,
                detections["raErr"] ** 2,
                detections["decErr"] ** 2,
                detections["ra_dec_Cov"],
                objectMjds,
                objectObservatories,
                medRA,
                medDec,
                fitEpoch.mjd,
            )

            # Flags which entries of the fit's detection list are science
            # detections (the reference object, if added, is appended last).
            scienceDetections = np.ones(nDetections, dtype=bool)
            if hasReference:
                nDetections += 1
                refMatch = refCatalog[refId[objIndex]]
                match.addPMDetection(
                    refMatch["ra"],
                    refMatch["dec"],
                    refMatch["raPM"],
                    refMatch["decPM"],
                    refMatch["parallax"],
                    refMatch["covariance"],
                    refWcs,
                )
                scienceDetections = np.append(scienceDetections, False)
                includesReference[objIndex] = True
                refPositions[objIndex] = refMatch[["ra", "dec", "raPM", "decPM", "parallax"]]
                refCovariances[objIndex] = refMatch["covariance"]

            # Solve, get best-fit position and covariance, and prediction for
            # the object position at the detection epochs.
            match.solve()
            objectPositions[objIndex] = match.getFit()
            objectCovariances[objIndex] = match.getFitCovariance()
            nSources[objIndex] = nDetections
            predictedPositions = match.predictAtDetections()
            sourceRows = np.atleast_1d(starSources.loc_indices[detections["sourceId"]])
            predictedRADec[sourceRows] = predictedPositions[scienceDetections]

        outCat = Table(objectPositions, names=("ra", "dec", "raPM", "decPM", "parallax"))
        outCat["hasReference"] = includesReference
        outCat["covariance"] = objectCovariances
        outCat = hstack([outCat, refPositions])
        outCat["ref_covariance"] = refCovariances
        outCat["nSources"] = nSources
        outCat["isolated_star_id"] = starCatalog["isolated_star_id"]
        outCat.meta["epoch"] = fitEpoch

        predictedRADec = Table(predictedRADec, names=("ra", "dec"))
        predictedRADec["sourceId"] = starSources["sourceId"]

        return outCat, predictedRADec