Coverage for python / lsst / analysis / tools / tasks / associatedSourcesTractAnalysis.py: 28%

104 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 18:53 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("AssociatedSourcesTractAnalysisConfig", "AssociatedSourcesTractAnalysisTask") 

24 

25import astropy.time 

26import astropy.units as u 

27import lsst.pex.config as pexConfig 

28import numpy as np 

29from astropy.table import join, vstack 

30from lsst.daf.butler import DatasetProvenance 

31from lsst.drp.tasks.gbdesAstrometricFit import calculate_apparent_motion 

32from lsst.geom import Box2D 

33from lsst.pipe.base import NoWorkFound 

34from lsst.pipe.base import connectionTypes as ct 

35from lsst.skymap import BaseSkyMap 

36 

37from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask 

38 

39 

class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract", "instrument"),
    defaultTemplates={
        "outputName": "isolated_star_presources",
        "associatedSourcesInputName": "isolated_star_presources",
        "associatedSourceIdsInputName": "isolated_star_presource_associations",
    },
):
    """Butler connections for the associated-sources tract analysis.

    The ``astrometricCorrectionCatalog`` and ``visitTable`` inputs are
    dropped when ``config.applyAstrometricCorrections`` is false.
    """

    sourceCatalogs = ct.Input(
        name="sourceTable_visit",
        doc="Visit based source table to load from the butler",
        dimensions=("visit", "band"),
        storageClass="ArrowAstropy",
        multiple=True,
        deferLoad=True,
    )

    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        doc="Table of associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        deferLoad=True,
    )

    associatedSourceIds = ct.Input(
        name="{associatedSourceIdsInputName}",
        doc="Table containing unique ids for the associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        deferLoad=True,
    )

    skyMap = ct.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    camera = ct.PrerequisiteInput(
        name="camera",
        doc="Input camera to use for focal plane geometry.",
        dimensions=("instrument",),
        storageClass="Camera",
        isCalibration=True,
    )

    astrometricCorrectionCatalog = ct.Input(
        name="isolated_star_stellar_motions",
        doc="Catalog with proper motion and parallax information.",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        deferLoad=True,
    )

    visitTable = ct.Input(
        name="visitTable",
        doc="Catalog containing visit information.",
        dimensions=("instrument",),
        storageClass="DataFrame",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Both inputs are only consumed by the astrometric correction step;
        # when that step is enabled there is nothing to prune.
        if config.applyAstrometricCorrections:
            return
        for connectionName in ("astrometricCorrectionCatalog", "visitTable"):
            self.inputs.remove(connectionName)

109 

110 

class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=AssociatedSourcesTractAnalysisConnections
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`."""

    applyAstrometricCorrections = pexConfig.Field(
        doc="Apply proper motion and parallax corrections to source positions.",
        dtype=bool,
        default=True,
    )
    # Keys are the standardized names the task uses internally; values are
    # the column names read from the astrometric correction catalog (the
    # task renames each value column to its key).
    astrometricCorrectionParameters = pexConfig.DictField(
        doc="Column names for position and motion parameters in the astrometric correction catalogs.",
        keytype=str,
        itemtype=str,
        default=dict(
            ra="ra",
            dec="dec",
            pmRA="raPM",
            pmDec="decPM",
            parallax="parallax",
            isolated_star_id="isolated_star_id",
        ),
    )

132 

133 

class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Run analysis actions on isolated-star associated sources in a tract.

    Visit-level source catalogs are concatenated and inner-joined to the
    isolated-star association tables. Positions are optionally corrected
    for proper motion and parallax. The result is restricted to sources
    inside the tract bounding box and passed to `run`.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysis"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get the box and WCS that define the tract boundaries.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map containing the tract.
        tract : `int`
            Tract identifier.

        Returns
        -------
        tractBox : `lsst.geom.Box2I`
            Bounding box of the tract in tract pixel coordinates.
        wcs : `lsst.afw.geom.SkyWcs`
            WCS of the tract.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    def callback(self, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs : `dict`
            Mapping of connection name to loaded input. The
            ``astrometricCorrectionCatalog`` and ``visitTable`` entries are
            optional; when absent (or `None`) no astrometric corrections are
            applied.
        dataId : `lsst.daf.butler.DataCoordinate` or `dict`
            Data ID providing the ``tract`` key.

        Returns
        -------
        data : `astropy.table.Table`
            Joined, tract-filtered source table.
        """
        return self.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
            inputs["associatedSourceIds"],
            # The correction connections are pruned when corrections are
            # disabled, so they may be missing from ``inputs`` entirely.
            inputs.get("astrometricCorrectionCatalog"),
            inputs.get("visitTable"),
        )

    def prepareAssociatedSources(
        self,
        skymap,
        tract,
        sourceCatalogs,
        associatedSources,
        associatedSourceIds,
        astrometricCorrectionCatalog=None,
        visitTable=None,
    ):
        """Concatenate source catalogs and join on associated source IDs.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map containing ``tract``.
        tract : `int`
            Tract identifier used to select sources by position.
        sourceCatalogs : `list` [`astropy.table.Table`]
            Visit-level source tables. They must contain ``sourceId``,
            ``coord_ra`` and ``coord_dec`` (in degrees, since they are
            converted with `numpy.radians`), plus ``visit`` when astrometric
            corrections are applied. Their metadata is stripped of
            provenance in place.
        associatedSources : `astropy.table.Table`
            Association table with ``sourceId`` and ``obj_index`` columns.
            An ``isolated_star_id`` column is added to it in place.
        associatedSourceIds : `astropy.table.Table`
            Table with an ``isolated_star_id`` column; ``obj_index`` values
            in ``associatedSources`` are row indices into this table.
        astrometricCorrectionCatalog : `astropy.table.Table`, optional
            Proper motion and parallax catalog. If `None`, positions are
            not corrected.
        visitTable : `pandas.DataFrame`, optional
            Visit information with MJDs; required when
            ``astrometricCorrectionCatalog`` is given.

        Returns
        -------
        dataFiltered : `astropy.table.Table`
            Associated sources whose positions fall inside the tract box.

        Raises
        ------
        ValueError
            Raised if ``astrometricCorrectionCatalog`` is given without a
            ``visitTable``.
        """
        if astrometricCorrectionCatalog is not None and visitTable is None:
            raise ValueError("A visitTable is required to apply astrometric corrections.")

        # Strip any provenance from tables before merging to prevent
        # warnings from conflicts being issued by astropy.utils.merge.
        for srcCat in sourceCatalogs:
            DatasetProvenance.strip_provenance_from_flat_dict(srcCat.meta)
        DatasetProvenance.strip_provenance_from_flat_dict(associatedSources.meta)
        DatasetProvenance.strip_provenance_from_flat_dict(associatedSourceIds.meta)

        # associatedSource["obj_index"] refers to the corresponding index (row)
        # in associatedSourceIds.
        index = associatedSources["obj_index"]
        associatedSources["isolated_star_id"] = associatedSourceIds["isolated_star_id"][index]

        # Keep only sources with associations
        sourceCatalogStack = vstack(sourceCatalogs, join_type="exact")
        dataJoined = join(sourceCatalogStack, associatedSources, keys="sourceId", join_type="inner")

        if astrometricCorrectionCatalog is not None:
            self.applyAstrometricCorrections(dataJoined, astrometricCorrectionCatalog, visitTable)

        # Determine which sources are contained in tract
        ra = np.radians(dataJoined["coord_ra"])
        dec = np.radians(dataJoined["coord_dec"])
        box, wcs = self.getBoxWcs(skymap, tract)
        box = Box2D(box)
        x, y = wcs.skyToPixelArray(ra, dec)
        boxSelection = box.contains(x, y)

        # Keep only sources whose own position lies inside the tract box.
        # The selection is per source, not per group: a group straddling
        # the tract boundary keeps only its members that fall inside.
        return dataJoined[boxSelection]

    def applyAstrometricCorrections(self, dataJoined, astrometricCorrectionCatalog, visitTable):
        """Use proper motion/parallax catalogs to shift positions to median
        epoch of the visits.

        The inputs ``astrometricCorrectionCatalog`` and ``visitTable`` are
        not modified, so the method can safely be called repeatedly with
        the same inputs.

        Parameters
        ----------
        dataJoined : `astropy.table.Table`
            Table containing source positions (``coord_ra``, ``coord_dec``),
            ``visit`` and ``isolated_star_id`` columns. Its coordinate
            columns are modified in place.
        astrometricCorrectionCatalog : `astropy.table.Table`
            Proper motion and parallax catalog, with columns named per
            ``config.astrometricCorrectionParameters``.
        visitTable : `pd.DataFrame`
            Table containing the MJDs of the visits (``expMidptMJD``),
            indexed by (or containing a ``visitId`` column for) visit ID.
        """
        if visitTable.index.name is None:
            # The expected index may or may not be set, depending on whether
            # the table was written originally as a DataFrame or something else
            # Parquet-friendly. Not done in place, to leave the caller's
            # DataFrame untouched.
            visitTable = visitTable.set_index("visitId")

        # Get the stellar motion catalog into the right format. Work on a
        # copy: renaming columns and attaching units in place would break a
        # second call with the same catalog (missing columns, or units
        # applied twice).
        corrections = astrometricCorrectionCatalog.copy()
        for key, value in self.config.astrometricCorrectionParameters.items():
            corrections.rename_column(value, key)
        corrections["ra"] *= u.degree
        corrections["dec"] *= u.degree
        corrections["pmRA"] *= u.mas / u.yr
        corrections["pmDec"] *= u.mas / u.yr
        corrections["parallax"] *= u.mas

        # keep_order=True keeps rows aligned with dataJoined so the corrected
        # coordinates can be assigned back positionally below.
        dataWithPM = join(
            dataJoined,
            corrections,
            keys="isolated_star_id",
            join_type="left",
            keep_order=True,
        )

        mjds = visitTable.loc[dataWithPM["visit"]]["expMidptMJD"]
        times = astropy.time.Time(mjds, format="mjd", scale="tai")
        dataWithPM["MJD"] = times
        medianMJD = astropy.time.Time(np.median(mjds), format="mjd", scale="tai")

        # NOTE(review): the correction quantities' units are defined by
        # calculate_apparent_motion; .value is assumed to be in the same
        # units as coord_ra/coord_dec (degrees) — confirm.
        raCorrection, decCorrection = calculate_apparent_motion(dataWithPM, medianMJD)

        dataJoined["coord_ra"] = dataWithPM["coord_ra"] - raCorrection.value
        dataJoined["coord_dec"] = dataWithPM["coord_dec"] - decCorrection.value

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId

        # Load the association tables first so an empty tract is rejected
        # before any visit-level source catalog is read. The deferred handles
        # stay in ``inputs`` until after parsePlotInfo, as before.
        # TODO: make key used for object index configurable
        associatedSources = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])
        if len(associatedSources) == 0:
            raise NoWorkFound(f"No associated sources in tract {dataId.tract.id}")
        associatedSourceIds = self.loadData(inputs["associatedSourceIds"], ["isolated_star_id"])

        # Load specified columns from source catalogs
        names = self.collectInputNames()
        names |= {"sourceId", "coord_ra", "coord_dec"}
        if self.config.applyAstrometricCorrections:
            # applyAstrometricCorrections looks up each source's epoch by
            # its "visit" value, whether or not any action requests it.
            names |= {"visit"}
        # These columns come from the association tables, not the visits.
        names -= {"obj_index", "isolated_star_id"}

        inputs["sourceCatalogs"] = [self.loadData(handle, names) for handle in inputs["sourceCatalogs"]]

        if self.config.applyAstrometricCorrections:
            inputs["astrometricCorrectionCatalog"] = inputs["astrometricCorrectionCatalog"].get(
                parameters={"columns": list(self.config.astrometricCorrectionParameters.values())}
            )
        else:
            inputs["astrometricCorrectionCatalog"] = None
            inputs["visitTable"] = None

        # NOTE(review): presumably parsePlotInfo needs the deferred-load
        # handle for this connection, so it is called before the handle is
        # replaced by the loaded table.
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        inputs["associatedSources"] = associatedSources
        inputs["associatedSourceIds"] = associatedSourceIds

        data = self.callback(inputs, dataId)

        kwargs = {"data": data, "plotInfo": plotInfo, "skymap": inputs["skyMap"], "camera": inputs["camera"]}
        outputs = self.run(**kwargs)
        self.putByBand(butlerQC, outputs, outputRefs)