Coverage for python / lsst / analysis / tools / tasks / associatedSourcesTractAnalysis.py: 24%
120 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:21 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("AssociatedSourcesTractAnalysisConfig", "AssociatedSourcesTractAnalysisTask")
25import astropy.time
26import astropy.units as u
27import numpy as np
28from astropy.table import Table, hstack
29from scipy.spatial import KDTree
31import lsst.pex.config as pexConfig
32from lsst.daf.butler import DatasetProvenance
33from lsst.drp.tasks.gbdesAstrometricFit import calculate_apparent_motion
34from lsst.pipe.base import NoWorkFound
35from lsst.pipe.base import connectionTypes as ct
36from lsst.skymap import BaseSkyMap
38from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask
class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract", "instrument"),
    defaultTemplates={
        "outputName": "isolated_star_presources",
        "associatedSourcesInputName": "isolated_star_presources",
        "associatedSourceIdsInputName": "isolated_star_presource_associations",
    },
):
    """Butler connections for `AssociatedSourcesTractAnalysisTask`.

    Inputs are per-visit source tables plus the per-tract isolated-star
    association tables; the astrometric-correction inputs are optional and
    are removed in ``__init__`` when the config disables them.
    """

    # Per-visit source tables; loaded lazily (deferLoad) so that only the
    # columns actually needed can be read in runQuantum.
    sourceCatalogs = ct.Input(
        doc="Visit based source table to load from the butler",
        name="sourceTable_visit",
        storageClass="ArrowAstropy",
        deferLoad=True,
        dimensions=("visit", "band"),
        multiple=True,
    )

    # Per-tract table linking sources into isolated-star groups.
    associatedSources = ct.Input(
        doc="Table of associated sources",
        name="{associatedSourcesInputName}",
        storageClass="ArrowAstropy",
        deferLoad=True,
        dimensions=("instrument", "skymap", "tract"),
    )

    # One row per isolated-star group; rows are referenced by the
    # "obj_index" column of associatedSources.
    associatedSourceIds = ct.Input(
        doc="Table containing unique ids for the associated sources",
        name="{associatedSourceIdsInputName}",
        storageClass="ArrowAstropy",
        deferLoad=True,
        dimensions=("instrument", "skymap", "tract"),
    )

    skyMap = ct.Input(
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    camera = ct.PrerequisiteInput(
        doc="Input camera to use for focal plane geometry.",
        name="camera",
        storageClass="Camera",
        dimensions=("instrument",),
        isCalibration=True,
    )

    # Proper-motion/parallax catalog; only used when
    # config.applyAstrometricCorrections is True.
    astrometricCorrectionCatalog = ct.Input(
        doc="Catalog with proper motion and parallax information.",
        name="isolated_star_stellar_motions",
        storageClass="ArrowAstropy",
        deferLoad=True,
        dimensions=("instrument", "skymap", "tract"),
    )

    # Visit metadata (exposure mid-point MJDs) needed to evaluate the
    # astrometric corrections at each visit epoch.
    visitTable = ct.Input(
        doc="Catalog containing visit information.",
        name="visitTable",
        storageClass="DataFrame",
        dimensions=("instrument",),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Drop the correction-related inputs entirely when corrections are
        # disabled, so the pipeline does not require those datasets to exist.
        if not config.applyAstrometricCorrections:
            self.inputs.remove("astrometricCorrectionCatalog")
            self.inputs.remove("visitTable")
class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=AssociatedSourcesTractAnalysisConnections
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`."""

    # When False, the astrometricCorrectionCatalog/visitTable connections are
    # removed (see the connections class) and positions are used as-is.
    applyAstrometricCorrections = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Apply proper motion and parallax corrections to source positions.",
    )

    # Mapping from the canonical parameter names used internally (keys) to the
    # actual column names present in the correction catalog (values); the
    # catalog columns are renamed value -> key in applyAstrometricCorrections.
    astrometricCorrectionParameters = pexConfig.DictField(
        keytype=str,
        itemtype=str,
        default={
            "ra": "ra",
            "dec": "dec",
            "pmRA": "raPM",
            "pmDec": "decPM",
            "parallax": "parallax",
            "isolated_star_id": "isolated_star_id",
        },
        doc="Column names for position and motion parameters in the astrometric correction catalogs.",
    )
class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Analysis task that joins per-visit source tables with the per-tract
    isolated-star association catalogs (optionally applying proper-motion and
    parallax corrections) and runs the configured analysis tools on the
    combined table.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysis"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Skymap from which the tract geometry is generated.
        tract : `int`
            Tract id.

        Returns
        -------
        tractBox : `lsst.geom.Box2I`
            Bounding box of the tract.
        wcs : `lsst.afw.geom.SkyWcs`
            WCS of the tract.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    def callback(self, inputs, dataId):
        """Callback function to be used with reconstructor.

        Thin adapter that unpacks the butler ``inputs`` dict and forwards to
        `prepareAssociatedSources`; kept as a separate method so the data
        preparation can be re-run outside of ``runQuantum``.
        """
        return self.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
            inputs["associatedSourceIds"],
            inputs["astrometricCorrectionCatalog"],
            inputs["visitTable"],
        )

    def prepareAssociatedSources(
        self,
        skymap,
        tract,
        sourceCatalogs,
        associatedSources,
        associatedSourceIds,
        astrometricCorrectionCatalog=None,
        visitTable=None,
    ):
        """Concatenate source catalogs and join on associated source IDs.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Skymap (currently unused here; kept for the callback signature).
        tract : `int`
            Tract id (currently unused here; kept for the callback signature).
        sourceCatalogs : `list`
            Deferred-load handles for the per-visit source tables.
        associatedSources : `astropy.table.Table`
            Association table with ``obj_index`` and ``sourceId`` columns.
        associatedSourceIds : `astropy.table.Table`
            Table with one ``isolated_star_id`` per associated object.
        astrometricCorrectionCatalog : `astropy.table.Table`, optional
            Proper motion/parallax catalog; if not `None`, positions are
            corrected in place via `applyAstrometricCorrections`.
        visitTable : `pd.DataFrame`, optional
            Visit MJD table, required when corrections are applied.

        Returns
        -------
        fullCat : `astropy.table.Table`
            Joined table restricted to rows with finite ra/dec.
        """

        # Strip any provenance from tables before merging to prevent
        # warnings from conflicts being issued by astropy.utils.merge.
        DatasetProvenance.strip_provenance_from_flat_dict(associatedSources.meta)
        DatasetProvenance.strip_provenance_from_flat_dict(associatedSourceIds.meta)

        # associatedSource["obj_index"] refers to the corresponding index (row)
        # in associatedSourceIds.
        index = associatedSources["obj_index"]
        associatedSources["isolated_star_id"] = associatedSourceIds["isolated_star_id"][index]

        trimmedSourceCatalogs = []
        fullCatLen = 0
        # It would be preferable to use astropy's built in functions
        # but they are too slow so we have this wonderful masterpiece
        # Which is still not fast but two thirds of the time is the butler get
        # NOTE: a 1-D KDTree over integer ids with a tight distance bound is
        # used below as a fast exact-match join.
        reshapedAssocSources = associatedSources["sourceId"].reshape(len(associatedSources), 1)
        colsNeeded = list(self.collectInputNames())
        # Only get the columns needed for the source catalogues.
        # The isolated_star_id and the obj_index are added later
        # from other tables so remove these from the list. Also
        # add the coord_ra and coord_dec as well because this bit
        # of code needs it even if it isn't requested by a
        # downstream atool.
        if "isolated_star_id" in colsNeeded:
            colsNeeded.remove("isolated_star_id")
        if "obj_index" in colsNeeded:
            colsNeeded.remove("obj_index")
        colsNeeded += ["sourceId", "coord_ra", "coord_dec"]
        for sourceCatalogRef in sourceCatalogs:
            sourceCatalog = sourceCatalogRef.get(parameters={"columns": set(colsNeeded)})
            DatasetProvenance.strip_provenance_from_flat_dict(sourceCatalog.meta)
            reshapedSourceCat = sourceCatalog["sourceId"].reshape(len(sourceCatalog), 1)

            # distance_upper_bound=0.1 on integer ids means only exact id
            # matches count; unmatched queries get index == len(sourceCatalog)
            # (scipy's sentinel for "no neighbor"), hence the mask below.
            tree = KDTree(reshapedSourceCat)
            _, inds = tree.query(reshapedAssocSources, distance_upper_bound=0.1)
            ids = inds < len(sourceCatalog)

            # Keep only the sources in groups that are fully contained within
            # the tract by matching to the associated sources table
            trimmedSourceCatalogs.append(hstack([associatedSources[ids], sourceCatalog[inds[ids]]]))
            fullCatLen += np.sum(ids)

        # Pre-allocate the output table and fill it per-visit; faster than
        # vstacking the per-visit tables.
        # NOTE(review): assumes at least one source catalog was provided;
        # an empty sourceCatalogs list would raise IndexError here.
        columns = trimmedSourceCatalogs[0].columns
        dtypes = trimmedSourceCatalogs[0].dtype
        zeros = np.zeros((fullCatLen, len(columns)))
        fullCat = Table(data=zeros, names=columns, dtype=dtypes)
        n = 0
        for trimmedSourceCatalog in trimmedSourceCatalogs:
            fullCat[n : n + len(trimmedSourceCatalog)] = trimmedSourceCatalog
            n += len(trimmedSourceCatalog)

        if astrometricCorrectionCatalog is not None:
            # Modifies fullCat's coord_ra/coord_dec in place.
            self.applyAstrometricCorrections(fullCat, astrometricCorrectionCatalog, visitTable)

        # Keep only finite ras and decs
        keep = np.isfinite(fullCat["coord_ra"]) & np.isfinite(fullCat["coord_dec"])
        return fullCat[keep]

    def applyAstrometricCorrections(self, dataJoined, astrometricCorrectionCatalog, visitTable):
        """Use proper motion/parallax catalogs to shift positions to median
        epoch of the visits.

        Parameters
        ----------
        dataJoined : `astropy.table.Table`
            Table containing source positions, which will be modified in place.
        astrometricCorrectionCatalog : `astropy.table.Table`
            Proper motion and parallax catalog.
        visitTable : `pd.DataFrame`
            Table containing the MJDs of the visits.
        """
        if visitTable.index.name is None:
            # The expected index may or may not be set, depending on whether
            # the table was written originally as a DataFrame or something else
            # Parquet-friendly.
            visitTable.set_index("visitId", inplace=True)

        # Get the stellar motion catalog into the right format:
        # rename configured column names (values) to the canonical names
        # (keys) expected by calculate_apparent_motion, then attach units.
        for key, value in self.config.astrometricCorrectionParameters.items():
            astrometricCorrectionCatalog.rename_column(value, key)
        astrometricCorrectionCatalog["ra"] *= u.degree
        astrometricCorrectionCatalog["dec"] *= u.degree
        astrometricCorrectionCatalog["pmRA"] *= u.mas / u.yr
        astrometricCorrectionCatalog["pmDec"] *= u.mas / u.yr
        astrometricCorrectionCatalog["parallax"] *= u.mas

        # Again using astropy join would have been great but this is four
        # times faster
        # (same exact-match KDTree join trick as in prepareAssociatedSources;
        # unmatched rows get index == lenAstroCorrCat).
        lenAstroCorrCat = len(astrometricCorrectionCatalog)
        tree = KDTree(astrometricCorrectionCatalog["isolated_star_id"].reshape(lenAstroCorrCat, 1))
        _, inds = tree.query(
            dataJoined["isolated_star_id"].reshape(len(dataJoined), 1), distance_upper_bound=0.5
        )
        ids = inds < lenAstroCorrCat

        dataWithPM = hstack([dataJoined[ids], astrometricCorrectionCatalog[inds[ids]]])

        # Per-row observation epochs, looked up by visit id.
        mjds = visitTable.loc[dataWithPM["visit"]]["expMidptMJD"]
        times = astropy.time.Time(mjds, format="mjd", scale="tai")
        dataWithPM["MJD"] = times
        medianMJD = astropy.time.Time(np.median(mjds), format="mjd", scale="tai")

        raCorrection, decCorrection = calculate_apparent_motion(dataWithPM, medianMJD)

        # NOTE(review): this full-column write-back assumes every row of
        # dataJoined matched the correction catalog (ids all True); if any
        # row were unmatched, dataWithPM would be shorter than dataJoined and
        # these assignments would fail on length mismatch — TODO confirm.
        dataJoined["coord_ra"] = dataWithPM["coord_ra"] - raCorrection.value
        dataJoined["coord_dec"] = dataWithPM["coord_dec"] - decCorrection.value

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load inputs, build the joined catalog, run the tools, and write
        outputs per band.
        """
        inputs = butlerQC.get(inputRefs)

        # Load specified columns from source catalogs
        names = self.collectInputNames()
        names |= {"sourceId", "coord_ra", "coord_dec"}
        # These two columns come from the association tables, not the
        # per-visit source catalogs, so drop them from the read list.
        for item in ["obj_index", "isolated_star_id"]:
            if item in names:
                names.remove(item)

        if self.config.applyAstrometricCorrections:
            # Materialize the deferred correction catalog, reading only the
            # configured columns.
            astrometricCorrections = inputs["astrometricCorrectionCatalog"].get(
                parameters={"columns": self.config.astrometricCorrectionParameters.values()}
            )
            inputs["astrometricCorrectionCatalog"] = astrometricCorrections
        else:
            # The connections were removed in this case; use None sentinels so
            # prepareAssociatedSources skips the correction step.
            inputs["astrometricCorrectionCatalog"] = None
            inputs["visitTable"] = None

        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])
        inputs["associatedSourceIds"] = self.loadData(inputs["associatedSourceIds"], ["isolated_star_id"])

        if len(inputs["associatedSources"]) == 0:
            raise NoWorkFound(f"No associated sources in tract {dataId.tract.id}")

        data = self.callback(inputs, dataId)

        kwargs = {"data": data, "plotInfo": plotInfo, "skymap": inputs["skyMap"], "camera": inputs["camera"]}
        outputs = self.run(**kwargs)
        self.putByBand(butlerQC, outputs, outputRefs)