Coverage for python / lsst / analysis / tools / tasks / sourceObjectTableAnalysis.py: 27%

189 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 09:26 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public task/config classes of this module. The exception classes defined
# below are not part of this list.
__all__ = (
    # Visit-level comparison of isolated sources with coadd objects.
    "SourceObjectTableAnalysisConfig",
    "SourceObjectTableAnalysisTask",
    # Per-patch catalog of mean object epochs.
    "ObjectEpochTableConfig",
    "ObjectEpochTableTask",
)

29 

30import astropy.time 

31import astropy.units as u 

32import numpy as np 

33import pandas as pd 

34from astropy.table import Table, join, vstack 

35from smatch import Matcher 

36 

37import lsst.pex.config as pexConfig 

38import lsst.pipe.base as pipeBase 

39from lsst.drp.tasks.gbdesAstrometricFit import calculate_apparent_motion 

40from lsst.pipe.base import AlgorithmError 

41from lsst.pipe.base import connectionTypes as ct 

42 

43from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask 

44 

45 

class IndexMismatchError(AlgorithmError):
    """Error raised when associated sources are missing from the input data.

    Every ``sourceId`` selected from the associated-sources catalog is looked
    up in the visit source table; this error signals that at least one of
    them could not be found there.
    """

    def __init__(self) -> None:
        message = (
            "Not all sourceIds in the associated sources catalog are available in the input data."
        )
        super().__init__(message)

    @property
    def metadata(self) -> dict:
        """Return an empty mapping; this error carries no metadata."""
        return dict()

59 

60 

class NoMatchError(AlgorithmError):
    """Error raised when no source matches any reference-catalog object.

    This can happen if areas of the source or reference image were not
    processed successfully.

    Parameters
    ----------
    targetCatalogSize
        Size of the catalog that was matched (callers in this module pass an
        `int` row count).
    refCatalogSize
        Size of the reference catalog (callers in this module pass an `int`
        row count).
    """

    def __init__(self, targetCatalogSize, refCatalogSize) -> None:
        self._metadata = dict(targetCatalogSize=targetCatalogSize, refCatalogSize=refCatalogSize)
        super().__init__("No matches were made between the source and reference catalogs.")

    @property
    def metadata(self) -> dict:
        """Catalog sizes recorded when the error was raised.

        Raises
        ------
        TypeError
            Raised if a recorded value is not an `int`, `float`, or `str`.
        """
        allowedTypes = (int, float, str)
        offenders = [
            (name, val) for name, val in self._metadata.items() if not isinstance(val, allowedTypes)
        ]
        if offenders:
            # Report the first offending entry, in insertion order.
            key, value = offenders[0]
            raise TypeError(f"{key} is of type {type(value)}, but only (int, float, str) are allowed.")
        return self._metadata

76 

77 

class ObjectEpochTableConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
):
    """Connections for `ObjectEpochTableTask`.

    One quantum per tract: the tract's per-patch object catalogs and per-band
    mean-epoch maps are read, and one epoch catalog is written per patch.
    """

    # Per-patch object catalogs; ``ObjectEpochTableTask.runQuantum`` loads
    # them one patch at a time with only the position columns it needs.
    objectCat = ct.Input(
        doc="Catalog of positions in each patch.",
        name="objectTable",
        storageClass="ArrowAstropy",
        dimensions=["skymap", "tract", "patch"],
        multiple=True,
        deferLoad=True,
        deferGraphConstraint=True,
    )

    # One mean-epoch map per band for the tract; the task keys them by band.
    epochMap = ct.Input(
        doc="Healsparse map of mean epoch of objectCat in each band.",
        name="deepCoadd_epoch_map_mean",
        storageClass="HealSparseMap",
        dimensions=("skymap", "tract", "band"),
        multiple=True,
        deferLoad=True,
    )

    # One output epoch catalog per patch, matched to the input patch by the
    # ``patch`` data ID value.
    objectEpochs = ct.Output(
        doc="Catalog of epochs for objectCat objects.",
        name="object_epoch",
        storageClass="ArrowAstropy",
        dimensions=["skymap", "tract", "patch"],
        multiple=True,
    )

108 

109 

class ObjectEpochTableConfig(pipeBase.PipelineTaskConfig, pipelineConnections=ObjectEpochTableConnections):
    """Configuration for `ObjectEpochTableTask`."""

    bands = pexConfig.ListField(
        doc=(
            "Bands for which object epochs are computed. Each band is combined with `ra` and `dec` "
            "to build the objectCat position column names (e.g. `r_ra`, `r_dec`)."
        ),
        dtype=str,
        default=["u", "g", "r", "i", "z", "y"],
    )

116 

117 

class ObjectEpochTableTask(pipeBase.PipelineTask):
    """Collect mean epochs for the observations that went into each object.

    For each patch of the tract, the per-band mean-epoch maps are sampled at
    the per-band object positions and the result is written as a per-patch
    catalog.

    TODO: DM-46202, Remove this task once the object epochs are available
    elsewhere.
    """

    ConfigClass = ObjectEpochTableConfig
    _DefaultName = "objectEpochTable"

    def getEpochs(self, cat, epochMapDict):
        """Get mean epoch of the visits corresponding to object position.

        Parameters
        ----------
        cat : `astropy.table.Table`
            Catalog containing object positions in ``{band}_ra`` and
            ``{band}_dec`` columns for each band in ``config.bands``, plus an
            ``objectId`` column.
        epochMapDict : `dict` [`str`, `healsparse.HealSparseMap`]
            Loaded healsparse maps of the mean epoch, keyed by band. A
            configured band with no entry gets NaN epochs for every object.

        Returns
        -------
        epochTable : `astropy.table.Table`
            Table with one ``{band}_epoch`` column per configured band and the
            ``objectId`` column. Epochs are NaN where the band position is not
            finite or falls outside the map's valid pixels.
        """
        allEpochs = {}
        for band in self.config.bands:
            epochs = np.full(len(cat), np.nan)
            epochMap = epochMapDict.get(band)
            if epochMap is not None:
                ra = cat[f"{band}_ra"]
                dec = cat[f"{band}_dec"]
                validPositions = np.isfinite(ra) & np.isfinite(dec)
                if validPositions.any():
                    raValid = ra[validPositions]
                    decValid = dec[validPositions]
                    bandEpochs = epochMap.get_values_pos(raValid, decValid)
                    # Positions outside the map's valid pixels are reset to NaN.
                    epochsValid = epochMap.get_values_pos(raValid, decValid, valid_mask=True)
                    bandEpochs[~epochsValid] = np.nan
                    epochs[validPositions] = bandEpochs
            allEpochs[f"{band}_epoch"] = epochs
        allEpochs["objectId"] = cat["objectId"]

        return Table(allEpochs)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

        # Only the per-band positions and the object identifier are needed.
        columns = [f"{band}_{coord}" for band in self.config.bands for coord in ["ra", "dec"]]
        columns.append("objectId")

        epochMaps = {ref.dataId["band"]: ref.get() for ref in inputs["epochMap"]}
        missingBands = [band for band in self.config.bands if band not in epochMaps]
        if missingBands:
            self.log.warning("No epoch map found for bands %s; their epochs will be NaN.", missingBands)

        outputEpochRefs = {outputRef.dataId["patch"]: outputRef for outputRef in outputRefs.objectEpochs}
        for objectCatRef in inputs["objectCat"]:
            patch = objectCatRef.dataId["patch"]
            objectCat = objectCatRef.get(parameters={"columns": columns})
            epochs = self.getEpochs(objectCat, epochMaps)
            butlerQC.put(epochs, outputEpochRefs[patch])

177 

178 

class SourceObjectTableAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("visit",),
    defaultTemplates={
        "inputName": "sourceTable_visit",
        "inputCoaddName": "deep",
        "associatedSourcesInputName": "isolated_star_presources",
        "associatedSourceIdsInputName": "isolated_star_presource_associations",
        "outputName": "sourceObjectTable",
    },
):
    """Connections for `SourceObjectTableAnalysisTask`.

    The proper-motion catalog, object epochs and visit table are only
    declared when ``config.applyAstrometricCorrections`` is set.
    """

    # Per-visit source table; columns are selected at load time.
    data = ct.Input(
        doc="Visit based source table to load from the butler",
        name="sourceTable_visit",
        storageClass="ArrowAstropy",
        dimensions=("visit",),
        deferLoad=True,
    )

    # Per-tract isolated-star source lists and their star identifiers.
    associatedSources = ct.Input(
        doc="Table of associated sources",
        name="{associatedSourcesInputName}",
        storageClass="ArrowAstropy",
        multiple=True,
        deferLoad=True,
        dimensions=("instrument", "skymap", "tract"),
        deferGraphConstraint=True,
    )

    associatedSourceIds = ct.Input(
        doc="Table containing unique ids for the associated sources",
        name="{associatedSourceIdsInputName}",
        storageClass="ArrowAstropy",
        deferLoad=True,
        multiple=True,
        dimensions=("instrument", "skymap", "tract"),
        deferGraphConstraint=True,
    )

    # Per-patch object catalogs used as the matching reference.
    refCat = ct.Input(
        doc="Catalog of positions to use as reference.",
        name="objectTable",
        storageClass="DataFrame",
        dimensions=["skymap", "tract", "patch"],
        multiple=True,
        deferLoad=True,
        deferGraphConstraint=True,
    )
    # Inputs used only for the astrometric (proper motion/parallax) shift.
    astrometricCorrectionCatalog = ct.Input(
        doc="Catalog containing proper motions and parallaxes.",
        name="isolated_star_stellar_motions",
        storageClass="ArrowAstropy",
        dimensions=("instrument", "skymap", "tract"),
        multiple=True,
        deferLoad=True,
    )
    refCatEpochs = ct.Input(
        doc="Catalog of epochs for refCat objects.",
        name="object_epoch",
        storageClass="ArrowAstropy",
        dimensions=["skymap", "tract", "patch"],
        multiple=True,
        deferLoad=True,
    )
    visitTable = ct.Input(
        doc="Catalog containing visit information.",
        name="visitTable",
        storageClass="DataFrame",
        dimensions=("instrument",),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # The inputs below only feed the astrometric correction; drop them
        # when that correction is disabled.
        if config.applyAstrometricCorrections:
            return
        for connectionName in ("astrometricCorrectionCatalog", "refCatEpochs", "visitTable"):
            self.inputs.remove(connectionName)

257 

258 

class SourceObjectTableAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=SourceObjectTableAnalysisConnections
):
    """Configuration for `SourceObjectTableAnalysisTask`."""

    ra_column = pexConfig.Field(
        doc="Name of column in refCat to use for right ascension.",
        dtype=str,
        default="r_ra",
    )
    dec_column = pexConfig.Field(
        doc="Name of column in refCat to use for declination.",
        dtype=str,
        default="r_dec",
    )
    epoch_column = pexConfig.Field(
        doc=(
            "Name of column in refCat corresponding to the epoch to which "
            "sources will be shifted. Should correspond to the positions in "
            "`ra_column` and `dec_column`."
        ),
        dtype=str,
        default="r_epoch",
    )
    refCat_bands = pexConfig.ListField(
        doc=("Bands in refCat to be combined with `refCat_selectors` to build refCat column names."),
        dtype=str,
        default=["u", "g", "r", "i", "z", "y"],
    )
    refCat_selectors = pexConfig.ListField(
        doc=(
            "Remove objects for which these flags are true. These strings are combined with `refCat_bands`"
            " to build the full refCat column names"
        ),
        dtype=str,
        default=["pixelFlags_saturated", "pixelFlags_saturatedCenter"],
    )
    # The task converts this value to degrees by dividing by 3600, so it is
    # expressed in arcseconds.
    refCatMatchingRadius = pexConfig.Field(
        dtype=float,
        default=1.0,
        doc=(
            "Radius in arcseconds with which to match the positions of the isolated sources with the"
            " positions in the reference catalog."
        ),
    )
    applyAstrometricCorrections = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Apply proper motions and parallaxes to source positions.",
    )
    # NOTE(review): this field is not read by SourceObjectTableAnalysisTask in
    # this module; confirm whether it is still needed and what its units are.
    correctionsMatchingRadius = pexConfig.Field(
        dtype=float,
        default=0.2,
        doc=(
            "Radius in mas with which to match the mean positions of the sources with the positions in the"
            " astrometricCorrectionCatalog."
        ),
    )
    astrometricCorrectionParameters = pexConfig.DictField(
        keytype=str,
        itemtype=str,
        default={
            "ra": "ra",
            "dec": "dec",
            "pmRA": "raPM",
            "pmDec": "decPM",
            "parallax": "parallax",
            "isolated_star_id": "isolated_star_id",
        },
        doc="Column names for position and motion parameters in the astrometric correction catalogs.",
    )
    raiseIfNoMatches = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Raise NoMatchError if there are no matches between the source and object catalogs.",
    )

    def setDefaults(self):
        super().setDefaults()
        from ..atools import TargetRefCatDeltaColorMetrics

        self.atools.astromColorDiffMetrics = TargetRefCatDeltaColorMetrics

339 

340 

class SourceObjectTableAnalysisTask(AnalysisPipelineTask):
    """Compare isolated visit sources with matched coadd objects.

    Sources of one visit that belong to isolated-star associations are
    matched by position to objects from the overlapping ``objectTable``
    patches. If ``config.applyAstrometricCorrections`` is set, the sources
    are first shifted to the object epochs using proper motions and
    parallaxes. The side-by-side catalog is then passed to ``run``.
    """

    ConfigClass = SourceObjectTableAnalysisConfig
    _DefaultName = "sourceObjectTableAnalysis"

    def callback(self, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs : `dict`
            Inputs as prepared by `runQuantum` (keys ``data``,
            ``associatedSources``, ``associatedSourceIds``, ``refCat``,
            ``visitTable`` and ``astrometricCorrectionCatalog``).
        dataId
            Data ID providing the ``visit`` value.

        Returns
        -------
        allCat : `pandas.DataFrame`
            Matched reference objects and sources; see
            `prepareAssociatedSources`.
        """
        return self.prepareAssociatedSources(
            dataId["visit"],
            inputs["data"],
            inputs["associatedSources"],
            inputs["associatedSourceIds"],
            inputs["refCat"],
            inputs["visitTable"],
            inputs["astrometricCorrectionCatalog"],
        )

    def applyAstrometricCorrections(
        self, isolatedSources, astrometricCorrectionCatalog, visitTable, visit, refEpochs
    ):
        """Shift source positions to match the epoch of the reference catalog
        objects.

        Parameters
        ----------
        isolatedSources : `astropy.table.Table`
            Catalog of sources with ``isolated_star_id``, ``coord_ra`` and
            ``coord_dec`` columns. The coordinates are modified in place with
            the astrometric corrections.
        astrometricCorrectionCatalog : `astropy.table.Table`
            Catalog with proper motion and parallax information, with the
            column names given by ``config.astrometricCorrectionParameters``.
            It is not modified.
        visitTable : `pandas.DataFrame`
            Catalog containing the epoch (``expMidptMJD``) for the visit
            corresponding to the isolatedSources. If it has no named index it
            is re-indexed in place on ``visitId``.
        visit : `int`
            Identifier of the isolatedSources' visit.
        refEpochs : `pandas.Series`
            Epochs (MJD) to shift each source to, aligned row-by-row with
            ``isolatedSources``. Sources with a non-finite epoch are not
            shifted.
        """
        if visitTable.index.name is None:
            # The expected index may or may not be set, depending on whether
            # the table was written originally as a DataFrame or something else
            # Parquet-friendly.
            visitTable.set_index("visitId", inplace=True)
        sourceMjd = visitTable.loc[visit]["expMidptMJD"]

        # Get target date from reference catalog. Take an explicit writable
        # float64 copy: ``Series.to_numpy`` may return a read-only view under
        # pandas Copy-on-Write, and the fill below must not write through to
        # the reference catalog.
        targetEpochs = refEpochs.to_numpy(dtype=np.float64, na_value=np.nan, copy=True)
        # There may not be a valid reference epoch on the edge of a given
        # region. Do not make an astrometric correction for any sources on the
        # edge.
        targetEpochs[~np.isfinite(targetEpochs)] = sourceMjd

        # Get the stellar motion catalog into the right format. Work on a
        # shallow copy so the caller's catalog keeps its original column names
        # and units; otherwise a second call with the same catalog would fail
        # on the renames or attach the units twice.
        corrections = astrometricCorrectionCatalog.copy(copy_data=False)
        for key, value in self.config.astrometricCorrectionParameters.items():
            corrections.rename_column(value, key)
        corrections["ra"] *= u.degree
        corrections["dec"] *= u.degree
        corrections["pmRA"] *= u.mas / u.yr
        corrections["pmDec"] *= u.mas / u.yr
        corrections["parallax"] *= u.mas

        joinedData = join(
            isolatedSources[["isolated_star_id"]],
            corrections,
            keys="isolated_star_id",
            join_type="left",
            keep_order=True,
            metadata_conflicts="silent",
        )
        joinedData["MJD"] = astropy.time.Time(sourceMjd, format="mjd", scale="tai")

        raCorrection, decCorrection = calculate_apparent_motion(
            joinedData, astropy.time.Time(targetEpochs, format="mjd", scale="tai")
        )

        isolatedSources["coord_ra"] -= raCorrection.value
        isolatedSources["coord_dec"] -= decCorrection.value

    def prepareAssociatedSources(
        self,
        visit,
        data,
        associatedSourceRefs,
        associatedSourceIdRefs,
        refCats,
        visitTable,
        astrometricCorrectionCatalog,
    ):
        """Match isolated sources with reference objects and shift the sources
        to the object epochs if `self.config.applyAstrometricCorrections` is
        True.

        Parameters
        ----------
        visit : `int`
            Identifier of the visit corresponding to the data.
        data : `astropy.table.Table`
            Catalog of sources to be associated, indexed on ``sourceId``.
        associatedSourceRefs : `list` [`DeferredDatasetHandle`]
            Handles for the catalogs of isolated sources. There will be
            multiple if the visit overlaps with multiple tracts.
        associatedSourceIdRefs : `list` [`DeferredDatasetHandle`]
            Handles for the per-tract catalogs of isolated-star identifiers,
            indexed by the ``obj_index`` column of the isolated sources.
        refCats : `pandas.DataFrame`
            Catalog of objects with which the sources will be compared.
        visitTable : `pandas.DataFrame` or `None`
            Catalog containing the epoch for the visit corresponding to the
            isolatedSources; only used when applying corrections.
        astrometricCorrectionCatalog : `astropy.table.Table` or `None`
            Catalog with proper motion and parallax information; only used
            when applying corrections.

        Returns
        -------
        allCat : `pandas.DataFrame`
            Matched reference objects and sources, concatenated column-wise
            row by row.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if no isolated sources from this visit are found.
        IndexMismatchError
            Raised if an associated ``sourceId`` is missing from ``data``.
        NoMatchError
            Raised if nothing matches and ``config.raiseIfNoMatches`` is set.
        """
        associatedSourceIds = {
            ref.dataId["tract"]: ref.get(parameters={"columns": ["isolated_star_id"]})
            for ref in associatedSourceIdRefs
        }

        isolatedSources = []
        for associatedSourceRef in associatedSourceRefs:
            tract = associatedSourceRef.dataId["tract"]
            associatedSources = associatedSourceRef.get(
                parameters={"columns": ["visit", "sourceId", "obj_index"]}
            )
            visitSources = associatedSources[associatedSources["visit"] == visit]
            if len(visitSources) == 0:
                # This tract holds no isolated sources from this visit; an
                # empty ``.loc`` lookup would otherwise raise KeyError.
                continue
            isolatedStarIds = associatedSourceIds[tract]["isolated_star_id"][visitSources["obj_index"]]

            try:
                # ``loc_indices`` gives an int for a single match and a list
                # otherwise; normalize so ``data[...]`` always yields a Table
                # (``data.loc`` would return a Row for a single match).
                rowIndices = np.atleast_1d(data.loc_indices[visitSources["sourceId"]])
            except KeyError as err:
                raise IndexMismatchError() from err
            visitData = data[rowIndices]
            visitData["isolated_star_id"] = isolatedStarIds
            isolatedSources.append(visitData)

        if not isolatedSources:
            raise pipeBase.NoWorkFound(f"No isolated sources found for visit {visit}")
        isolatedSources = vstack(isolatedSources)

        with Matcher(np.asarray(isolatedSources["coord_ra"]), np.asarray(isolatedSources["coord_dec"])) as m:
            _, isolatedMatchIndices, refMatchIndices, _ = m.query_radius(
                np.asarray(refCats[self.config.ra_column]),
                np.asarray(refCats[self.config.dec_column]),
                # Matching radius is configured in arcseconds; convert to degrees.
                self.config.refCatMatchingRadius / 3600.0,
                return_indices=True,
            )

        matchIS = isolatedSources[isolatedMatchIndices]

        if len(matchIS) == 0 and self.config.raiseIfNoMatches:
            raise NoMatchError(len(isolatedSources), len(refCats))

        # Apply proper motions and parallaxes to visit sources (nothing to
        # shift when there are no matches).
        if self.config.applyAstrometricCorrections and len(matchIS) > 0:
            refCatEpochs = refCats[self.config.epoch_column].iloc[refMatchIndices]
            self.applyAstrometricCorrections(
                matchIS, astrometricCorrectionCatalog, visitTable, visit, refCatEpochs
            )

        matchRef = refCats.iloc[refMatchIndices]
        matchIS = matchIS.to_pandas()

        allCat = pd.concat([matchRef.reset_index(), matchIS.reset_index()], axis=1)
        return allCat

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)

        # Get isolated sources:
        visit = inputs["data"].dataId["visit"]
        band = inputs["data"].dataId["band"]
        names = self.collectInputNames()
        # The refCat position columns come from the object catalog, not the
        # source table.
        names -= {self.config.ra_column, self.config.dec_column}
        names.add("sourceId")
        data = inputs["data"].get(parameters={"columns": names})
        # Index on sourceId so associated sources can be looked up by ID.
        data.add_index("sourceId")
        inputs["data"] = data

        if self.config.applyAstrometricCorrections:
            refCatEpochs = {epochRef.dataId["patch"]: epochRef.get() for epochRef in inputs["refCatEpochs"]}

        # Get objects:
        refCatSelectors = [
            f"{refCatBand}_{selector}"
            for refCatBand in self.config.refCat_bands
            for selector in self.config.refCat_selectors
        ]
        refCatColumns = [
            "detect_isPrimary",
            self.config.ra_column,
            self.config.dec_column,
            "objectId",
        ] + refCatSelectors

        allRefCats = []
        for refCatRef in inputs["refCat"]:
            refCat = refCatRef.get(parameters={"columns": refCatColumns})
            if self.config.applyAstrometricCorrections:
                refCat = pd.merge(refCat, refCatEpochs[refCatRef.dataId["patch"]].to_pandas(), on="objectId")
            goodInds = (
                refCat["detect_isPrimary"]
                & np.isfinite(refCat[self.config.ra_column])
                & np.isfinite(refCat[self.config.dec_column])
            )
            goodInds &= ~refCat[refCatSelectors].any(axis=1)
            allRefCats.append(refCat[goodInds])

        # ``pd.concat`` raises on an empty list, so check before concatenating.
        if not allRefCats:
            raise pipeBase.NoWorkFound(f"No reference catalog objects found to associate with visit {visit}")
        refCat = pd.concat(allRefCats)
        inputs["refCat"] = refCat
        if len(refCat) == 0:
            raise pipeBase.NoWorkFound(f"No reference catalog objects found to associate with visit {visit}")

        if self.config.applyAstrometricCorrections:
            correctionColumns = list(self.config.astrometricCorrectionParameters.values())
            pmCats = [
                astrometricCorrectionCatalogRef.get(parameters={"columns": correctionColumns})
                for astrometricCorrectionCatalogRef in inputs["astrometricCorrectionCatalog"]
            ]
            inputs["astrometricCorrectionCatalog"] = vstack(pmCats, metadata_conflicts="silent")
        else:
            inputs["astrometricCorrectionCatalog"] = None
            inputs["visitTable"] = None

        try:
            allCat = self.callback(inputs, dataId)
        except pipeBase.AlgorithmError as e:
            error = pipeBase.AnnotatedPartialOutputsError.annotate(e, self, log=self.log)
            raise error from e

        outputs = self.run(data=allCat, bands=band, plotInfo=plotInfo)
        butlerQC.put(outputs, outputRefs)