Coverage for python / lsst / analysis / tools / tasks / associatedSourcesTractAnalysis.py: 24%

120 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-07 08:53 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public names of this module: the task and its configuration class.
__all__ = ("AssociatedSourcesTractAnalysisConfig", "AssociatedSourcesTractAnalysisTask")

24 

25import astropy.time 

26import astropy.units as u 

27import numpy as np 

28from astropy.table import Table, hstack 

29from scipy.spatial import KDTree 

30 

31import lsst.pex.config as pexConfig 

32from lsst.daf.butler import DatasetProvenance 

33from lsst.drp.tasks.gbdesAstrometricFit import calculate_apparent_motion 

34from lsst.pipe.base import NoWorkFound 

35from lsst.pipe.base import connectionTypes as ct 

36from lsst.skymap import BaseSkyMap 

37 

38from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask 

39 

40 

class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract", "instrument"),
    defaultTemplates=dict(
        outputName="isolated_star_presources",
        associatedSourcesInputName="isolated_star_presources",
        associatedSourceIdsInputName="isolated_star_presource_associations",
    ),
):
    """Butler connections for `AssociatedSourcesTractAnalysisTask`.

    The stellar-motion catalog and the visit table are only declared as
    inputs when ``config.applyAstrometricCorrections`` is true.
    """

    # Per-visit source tables; read lazily so only the needed columns load.
    sourceCatalogs = ct.Input(
        doc="Visit based source table to load from the butler",
        name="sourceTable_visit",
        storageClass="ArrowAstropy",
        deferLoad=True,
        dimensions=("visit", "band"),
        multiple=True,
    )

    associatedSources = ct.Input(
        doc="Table of associated sources",
        name="{associatedSourcesInputName}",
        storageClass="ArrowAstropy",
        deferLoad=True,
        dimensions=("instrument", "skymap", "tract"),
    )

    associatedSourceIds = ct.Input(
        doc="Table containing unique ids for the associated sources",
        name="{associatedSourceIdsInputName}",
        storageClass="ArrowAstropy",
        deferLoad=True,
        dimensions=("instrument", "skymap", "tract"),
    )

    skyMap = ct.Input(
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    camera = ct.PrerequisiteInput(
        doc="Input camera to use for focal plane geometry.",
        name="camera",
        storageClass="Camera",
        dimensions=("instrument",),
        isCalibration=True,
    )
    astrometricCorrectionCatalog = ct.Input(
        doc="Catalog with proper motion and parallax information.",
        name="isolated_star_stellar_motions",
        storageClass="ArrowAstropy",
        deferLoad=True,
        dimensions=("instrument", "skymap", "tract"),
    )

    visitTable = ct.Input(
        doc="Catalog containing visit information.",
        name="visitTable",
        storageClass="DataFrame",
        dimensions=("instrument",),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Both of these inputs are only consumed when corrections are on.
        if config.applyAstrometricCorrections:
            return
        for unusedInput in ("astrometricCorrectionCatalog", "visitTable"):
            self.inputs.remove(unusedInput)

111 

class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=AssociatedSourcesTractAnalysisConnections
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`."""

    # When true, source positions are shifted using the stellar-motion
    # catalog and the visit epochs.
    applyAstrometricCorrections = pexConfig.Field(
        doc="Apply proper motion and parallax corrections to source positions.",
        dtype=bool,
        default=True,
    )
    # Each key is the column name the task uses internally; each value is
    # the column name to read from the astrometric correction catalog (the
    # task renames value -> key after loading).
    astrometricCorrectionParameters = pexConfig.DictField(
        doc="Column names for position and motion parameters in the astrometric correction catalogs.",
        keytype=str,
        itemtype=str,
        default=dict(
            ra="ra",
            dec="dec",
            pmRA="raPM",
            pmDec="decPM",
            parallax="parallax",
            isolated_star_id="isolated_star_id",
        ),
    )

133 

134 

class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Run analysis tools on per-visit sources associated within a tract.

    Per-visit source catalogs are matched by ``sourceId`` to the associated
    source table, concatenated, optionally shifted to the median visit epoch
    using proper motion and parallax, and then passed to ``run``.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysis"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map providing the tract.
        tract : `int`
            Tract identifier passed to ``skymap.generateTract``.

        Returns
        -------
        tractBox
            Bounding box from ``tractInfo.getBBox()``.
        wcs
            WCS from ``tractInfo.getWcs()``.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    def callback(self, inputs, dataId):
        """Callback function to be used with reconstructor.

        ``astrometricCorrectionCatalog`` and ``visitTable`` may be absent
        from ``inputs``; in that case no astrometric correction is applied.
        """
        return self.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
            inputs["associatedSourceIds"],
            inputs.get("astrometricCorrectionCatalog"),
            inputs.get("visitTable"),
        )

    @staticmethod
    def _matchIds(referenceIds, queryIds):
        """Exactly match each query ID to a position in ``referenceIds``.

        Parameters
        ----------
        referenceIds : array-like
            IDs to search in.
        queryIds : array-like
            IDs to look up.

        Returns
        -------
        found : `numpy.ndarray` [`bool`]
            True where the query ID occurs in ``referenceIds``.
        indices : `numpy.ndarray` [`int`]
            Index into ``referenceIds`` for each query ID; only meaningful
            where ``found`` is True.
        """
        # Sort + binary search on the integer values themselves. A KDTree
        # converts its data to float64, which cannot distinguish 64-bit IDs
        # larger than 2**53 and so can silently join the wrong rows.
        referenceIds = np.asarray(referenceIds)
        queryIds = np.asarray(queryIds)
        if len(referenceIds) == 0:
            return np.zeros(len(queryIds), dtype=bool), np.zeros(len(queryIds), dtype=np.intp)
        order = np.argsort(referenceIds, kind="stable")
        sortedIds = referenceIds[order]
        # Clip so IDs beyond the largest reference still index safely; they
        # are rejected by the equality test below.
        positions = np.minimum(np.searchsorted(sortedIds, queryIds), len(sortedIds) - 1)
        found = sortedIds[positions] == queryIds
        return found, order[positions]

    def _sourceColumnsNeeded(self, needVisit=False):
        """Return the set of columns to read from each visit source table."""
        # isolated_star_id and obj_index come from the association tables,
        # not the visit source tables, so they are not requested here.
        # sourceId, coord_ra and coord_dec are always needed by this task even
        # if no downstream atool asks for them; visit is needed to look up
        # the epoch of each source when applying astrometric corrections.
        columns = set(self.collectInputNames()) - {"isolated_star_id", "obj_index"}
        columns |= {"sourceId", "coord_ra", "coord_dec"}
        if needVisit:
            columns.add("visit")
        return columns

    @staticmethod
    def _concatenateTables(tables):
        """Stack tables with identical columns into one preallocated table."""
        # Rather than using astropy's built-in (slower) stacking, preallocate
        # the output and copy each table into its slice.
        columns = tables[0].columns
        dtypes = tables[0].dtype
        totalLen = sum(len(table) for table in tables)
        fullCat = Table(data=np.zeros((totalLen, len(columns))), names=columns, dtype=dtypes)
        start = 0
        for table in tables:
            fullCat[start : start + len(table)] = table
            start += len(table)
        return fullCat

    def prepareAssociatedSources(
        self,
        skymap,
        tract,
        sourceCatalogs,
        associatedSources,
        associatedSourceIds,
        astrometricCorrectionCatalog=None,
        visitTable=None,
    ):
        """Concatenate source catalogs and join on associated source IDs.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Not used here; kept for interface compatibility.
        tract : `int`
            Tract identifier; only used in error messages.
        sourceCatalogs : `list`
            Deferred handles for the per-visit source tables; each is read
            with ``get(parameters={"columns": ...})``.
        associatedSources : `astropy.table.Table`
            Associated sources with ``obj_index`` and ``sourceId`` columns.
            An ``isolated_star_id`` column is added to it in place.
        associatedSourceIds : `astropy.table.Table`
            Table with an ``isolated_star_id`` column whose rows are indexed
            by ``associatedSources["obj_index"]``.
        astrometricCorrectionCatalog : `astropy.table.Table`, optional
            Proper motion and parallax catalog. If given, positions are
            shifted to the median epoch of the visits.
        visitTable : `pandas.DataFrame`, optional
            Visit information; required when
            ``astrometricCorrectionCatalog`` is given.

        Returns
        -------
        fullCat : `astropy.table.Table`
            One row per matched source, restricted to rows with finite
            ``coord_ra`` and ``coord_dec``.

        Raises
        ------
        ValueError
            Raised if corrections are requested without a ``visitTable``.
        NoWorkFound
            Raised if ``sourceCatalogs`` is empty.
        """
        applyCorrections = astrometricCorrectionCatalog is not None
        if applyCorrections and visitTable is None:
            raise ValueError("A visitTable is required to apply astrometric corrections.")

        # Strip any provenance from tables before merging to prevent
        # warnings from conflicts being issued by astropy.utils.merge.
        DatasetProvenance.strip_provenance_from_flat_dict(associatedSources.meta)
        DatasetProvenance.strip_provenance_from_flat_dict(associatedSourceIds.meta)

        # associatedSources["obj_index"] refers to the corresponding index
        # (row) in associatedSourceIds.
        index = associatedSources["obj_index"]
        associatedSources["isolated_star_id"] = associatedSourceIds["isolated_star_id"][index]

        colsNeeded = self._sourceColumnsNeeded(needVisit=applyCorrections)
        assocSourceIds = np.asarray(associatedSources["sourceId"])

        # Each visit catalog is matched to the associated sources by sourceId
        # and the matched rows are placed side by side. Most of the remaining
        # time is spent in the butler get.
        trimmedSourceCatalogs = []
        for sourceCatalogRef in sourceCatalogs:
            sourceCatalog = sourceCatalogRef.get(parameters={"columns": colsNeeded})
            DatasetProvenance.strip_provenance_from_flat_dict(sourceCatalog.meta)
            found, inds = self._matchIds(sourceCatalog["sourceId"], assocSourceIds)
            # Keep only the sources in groups that are fully contained within
            # the tract by matching to the associated sources table
            trimmedSourceCatalogs.append(hstack([associatedSources[found], sourceCatalog[inds[found]]]))

        if not trimmedSourceCatalogs:
            raise NoWorkFound(f"No source catalogs provided for tract {tract}")

        fullCat = self._concatenateTables(trimmedSourceCatalogs)

        if applyCorrections:
            self.applyAstrometricCorrections(fullCat, astrometricCorrectionCatalog, visitTable)

        # Keep only finite ras and decs
        keep = np.isfinite(fullCat["coord_ra"]) & np.isfinite(fullCat["coord_dec"])
        return fullCat[keep]

    def applyAstrometricCorrections(self, dataJoined, astrometricCorrectionCatalog, visitTable):
        """Use proper motion/parallax catalogs to shift positions to median
        epoch of the visits.

        Rows whose ``isolated_star_id`` has no entry in the correction
        catalog keep their measured positions.

        Parameters
        ----------
        dataJoined : `astropy.table.Table`
            Table containing source positions (``coord_ra``, ``coord_dec``),
            ``isolated_star_id`` and ``visit``; modified in place.
        astrometricCorrectionCatalog : `astropy.table.Table`
            Proper motion and parallax catalog. Its columns are renamed in
            place according to ``config.astrometricCorrectionParameters``.
        visitTable : `pd.DataFrame`
            Table containing the MJDs of the visits (``expMidptMJD``). Its
            index is set to ``visitId`` in place if it has no named index.
        """
        if visitTable.index.name is None:
            # The expected index may or may not be set, depending on whether
            # the table was written originally as a DataFrame or something else
            # Parquet-friendly.
            visitTable.set_index("visitId", inplace=True)

        # Get the stellar motion catalog into the right format:
        for key, value in self.config.astrometricCorrectionParameters.items():
            astrometricCorrectionCatalog.rename_column(value, key)
        astrometricCorrectionCatalog["ra"] *= u.degree
        astrometricCorrectionCatalog["dec"] *= u.degree
        astrometricCorrectionCatalog["pmRA"] *= u.mas / u.yr
        astrometricCorrectionCatalog["pmDec"] *= u.mas / u.yr
        astrometricCorrectionCatalog["parallax"] *= u.mas

        found, inds = self._matchIds(
            astrometricCorrectionCatalog["isolated_star_id"], dataJoined["isolated_star_id"]
        )
        if not np.any(found):
            # Nothing has motion information; positions stay as measured.
            return

        dataWithPM = hstack([dataJoined[found], astrometricCorrectionCatalog[inds[found]]])

        mjds = np.asarray(visitTable.loc[np.asarray(dataWithPM["visit"])]["expMidptMJD"])
        dataWithPM["MJD"] = astropy.time.Time(mjds, format="mjd", scale="tai")
        medianMJD = astropy.time.Time(np.median(mjds), format="mjd", scale="tai")

        raCorrection, decCorrection = calculate_apparent_motion(dataWithPM, medianMJD)

        # dataWithPM only holds the matched rows, so write the corrections
        # back into those rows of full-length arrays; assigning the shorter
        # arrays to the columns directly would fail on a length mismatch.
        correctedRa = np.array(dataJoined["coord_ra"], dtype=float)
        correctedDec = np.array(dataJoined["coord_dec"], dtype=float)
        correctedRa[found] = np.asarray(dataWithPM["coord_ra"]) - raCorrection.value
        correctedDec[found] = np.asarray(dataWithPM["coord_dec"]) - decCorrection.value
        dataJoined["coord_ra"] = correctedRa
        dataJoined["coord_dec"] = correctedDec

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load inputs, build the joined source table and run the atools.

        Raises
        ------
        NoWorkFound
            Raised if the associated source table is empty.
        """
        inputs = butlerQC.get(inputRefs)

        if self.config.applyAstrometricCorrections:
            inputs["astrometricCorrectionCatalog"] = inputs["astrometricCorrectionCatalog"].get(
                parameters={"columns": list(self.config.astrometricCorrectionParameters.values())}
            )
        else:
            inputs["astrometricCorrectionCatalog"] = None
            inputs["visitTable"] = None

        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])
        inputs["associatedSourceIds"] = self.loadData(inputs["associatedSourceIds"], ["isolated_star_id"])

        if len(inputs["associatedSources"]) == 0:
            raise NoWorkFound(f"No associated sources in tract {dataId['tract']}")

        data = self.callback(inputs, dataId)

        kwargs = {"data": data, "plotInfo": plotInfo, "skymap": inputs["skyMap"], "camera": inputs["camera"]}
        outputs = self.run(**kwargs)
        self.putByBand(butlerQC, outputs, outputRefs)