Coverage for python / lsst / analysis / tools / tasks / associatedSourcesTractAnalysis.py: 28%
104 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:23 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:23 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public names exported by ``from ... import *``.
__all__ = (
    "AssociatedSourcesTractAnalysisConfig",
    "AssociatedSourcesTractAnalysisTask",
)
25import astropy.time
26import astropy.units as u
27import lsst.pex.config as pexConfig
28import numpy as np
29from astropy.table import join, vstack
30from lsst.daf.butler import DatasetProvenance
31from lsst.drp.tasks.gbdesAstrometricFit import calculate_apparent_motion
32from lsst.geom import Box2D
33from lsst.pipe.base import NoWorkFound
34from lsst.pipe.base import connectionTypes as ct
35from lsst.skymap import BaseSkyMap
37from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask
class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract", "instrument"),
    defaultTemplates={
        "outputName": "isolated_star_presources",
        "associatedSourcesInputName": "isolated_star_presources",
        "associatedSourceIdsInputName": "isolated_star_presource_associations",
    },
):
    """Butler connections for `AssociatedSourcesTractAnalysisTask`.

    The ``astrometricCorrectionCatalog`` and ``visitTable`` inputs are only
    declared when ``config.applyAstrometricCorrections`` is true; otherwise
    they are removed in ``__init__``.
    """

    # Per-visit source tables; read lazily so only the needed columns load.
    sourceCatalogs = ct.Input(
        name="sourceTable_visit",
        doc="Visit based source table to load from the butler",
        dimensions=("visit", "band"),
        storageClass="ArrowAstropy",
        multiple=True,
        deferLoad=True,
    )

    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        doc="Table of associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        deferLoad=True,
    )

    associatedSourceIds = ct.Input(
        name="{associatedSourceIdsInputName}",
        doc="Table containing unique ids for the associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        deferLoad=True,
    )

    skyMap = ct.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    camera = ct.PrerequisiteInput(
        name="camera",
        doc="Input camera to use for focal plane geometry.",
        dimensions=("instrument",),
        storageClass="Camera",
        isCalibration=True,
    )
    astrometricCorrectionCatalog = ct.Input(
        name="isolated_star_stellar_motions",
        doc="Catalog with proper motion and parallax information.",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        deferLoad=True,
    )

    visitTable = ct.Input(
        name="visitTable",
        doc="Catalog containing visit information.",
        dimensions=("instrument",),
        storageClass="DataFrame",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Both inputs exist only to support the proper-motion/parallax
        # correction, so keep them when that correction is enabled.
        if config.applyAstrometricCorrections:
            return
        for optionalInput in ("astrometricCorrectionCatalog", "visitTable"):
            self.inputs.remove(optionalInput)
class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=AssociatedSourcesTractAnalysisConnections
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`."""

    applyAstrometricCorrections = pexConfig.Field(
        doc="Apply proper motion and parallax corrections to source positions.",
        dtype=bool,
        default=True,
    )
    # Maps the internal column name (key) to the column name found in the
    # correction catalog (value); the task renames value -> key after loading.
    astrometricCorrectionParameters = pexConfig.DictField(
        doc="Column names for position and motion parameters in the astrometric correction catalogs.",
        keytype=str,
        itemtype=str,
        default=dict(
            ra="ra",
            dec="dec",
            pmRA="raPM",
            pmDec="decPM",
            parallax="parallax",
            isolated_star_id="isolated_star_id",
        ),
    )
class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Run analysis on per-visit sources that were associated within a tract.

    Visit-level source tables are concatenated, joined to the isolated-star
    association tables, optionally shifted for proper motion and parallax,
    and trimmed to the sources lying inside the tract bounding box before
    being handed to ``run``.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysis"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map containing the tract.
        tract : `int`
            Tract identifier.

        Returns
        -------
        tractBox : `lsst.geom.Box2I`
            Pixel bounding box of the tract.
        wcs : `lsst.afw.geom.SkyWcs`
            WCS of the tract.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    def callback(self, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs : `dict`
            Loaded inputs keyed by connection name.
            ``astrometricCorrectionCatalog`` and ``visitTable`` are optional;
            when missing (or `None`) no astrometric correction is applied.
        dataId : `lsst.daf.butler.DataCoordinate`
            Data ID providing the ``tract`` value.

        Returns
        -------
        data : `astropy.table.Table`
            Result of `prepareAssociatedSources`.
        """
        return self.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
            inputs["associatedSourceIds"],
            # These connections are removed when corrections are disabled,
            # so a reconstructed ``inputs`` may legitimately lack them.
            inputs.get("astrometricCorrectionCatalog"),
            inputs.get("visitTable"),
        )

    def prepareAssociatedSources(
        self,
        skymap,
        tract,
        sourceCatalogs,
        associatedSources,
        associatedSourceIds,
        astrometricCorrectionCatalog=None,
        visitTable=None,
    ):
        """Concatenate source catalogs and join on associated source IDs.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map used to define the tract boundaries.
        tract : `int`
            Tract identifier.
        sourceCatalogs : `list` [`astropy.table.Table`]
            Per-visit source tables; must share the same columns and include
            ``sourceId``, ``coord_ra`` and ``coord_dec`` (degrees).
        associatedSources : `astropy.table.Table`
            Association table with ``obj_index`` and ``sourceId`` columns.
            An ``isolated_star_id`` column is added to it in place.
        associatedSourceIds : `astropy.table.Table`
            Table with an ``isolated_star_id`` column, indexed by
            ``obj_index``.
        astrometricCorrectionCatalog : `astropy.table.Table`, optional
            Proper motion / parallax catalog. If `None`, positions are not
            corrected.
        visitTable : `pandas.DataFrame`, optional
            Visit epochs; required when ``astrometricCorrectionCatalog`` is
            given.

        Returns
        -------
        dataFiltered : `astropy.table.Table`
            Associated sources whose (possibly corrected) positions fall
            inside the tract bounding box.
        """
        # Strip any provenance from tables before merging to prevent
        # warnings from conflicts being issued by astropy.utils.merge.
        for srcCat in sourceCatalogs:
            DatasetProvenance.strip_provenance_from_flat_dict(srcCat.meta)
        DatasetProvenance.strip_provenance_from_flat_dict(associatedSources.meta)
        DatasetProvenance.strip_provenance_from_flat_dict(associatedSourceIds.meta)

        # associatedSource["obj_index"] refers to the corresponding index (row)
        # in associatedSourceIds.
        index = associatedSources["obj_index"]
        associatedSources["isolated_star_id"] = associatedSourceIds["isolated_star_id"][index]

        # Keep only sources with associations
        sourceCatalogStack = vstack(sourceCatalogs, join_type="exact")
        dataJoined = join(sourceCatalogStack, associatedSources, keys="sourceId", join_type="inner")

        if astrometricCorrectionCatalog is not None:
            self.applyAstrometricCorrections(dataJoined, astrometricCorrectionCatalog, visitTable)

        # Determine which sources are contained in tract (coord_* are in
        # degrees; skyToPixelArray is called with radians).
        ra = np.radians(dataJoined["coord_ra"])
        dec = np.radians(dataJoined["coord_dec"])
        box, wcs = self.getBoxWcs(skymap, tract)
        box = Box2D(box)
        x, y = wcs.skyToPixelArray(ra, dec)
        boxSelection = box.contains(x, y)

        # The selection is per source: an isolated star whose sources
        # straddle the tract boundary keeps only its in-tract sources.
        return dataJoined[boxSelection]

    def applyAstrometricCorrections(self, dataJoined, astrometricCorrectionCatalog, visitTable):
        """Use proper motion/parallax catalogs to shift positions to a common
        reference epoch.

        Parameters
        ----------
        dataJoined : `astropy.table.Table`
            Table containing source positions (``coord_ra``/``coord_dec`` in
            degrees) plus ``visit`` and ``isolated_star_id`` columns. The
            positions are modified in place.
        astrometricCorrectionCatalog : `astropy.table.Table`
            Proper motion and parallax catalog. Its columns are renamed to
            the keys of ``config.astrometricCorrectionParameters`` and given
            units in place.
        visitTable : `pd.DataFrame`
            Table containing the ``expMidptMJD`` of the visits, indexed by
            ``visitId`` (the index is set in place if missing).

        Notes
        -----
        The reference epoch is the median of the per-source epochs, so visits
        contributing more sources carry more weight than a median over
        unique visits would give them.
        """
        if len(dataJoined) == 0:
            # Nothing to shift; also avoids the median of an empty array.
            return

        if visitTable.index.name is None:
            # The expected index may or may not be set, depending on whether
            # the table was written originally as a DataFrame or something else
            # Parquet-friendly.
            visitTable.set_index("visitId", inplace=True)

        # Get the stellar motion catalog into the right format:
        for key, value in self.config.astrometricCorrectionParameters.items():
            astrometricCorrectionCatalog.rename_column(value, key)
        astrometricCorrectionCatalog["ra"] *= u.degree
        astrometricCorrectionCatalog["dec"] *= u.degree
        astrometricCorrectionCatalog["pmRA"] *= u.mas / u.yr
        astrometricCorrectionCatalog["pmDec"] *= u.mas / u.yr
        astrometricCorrectionCatalog["parallax"] *= u.mas

        # Left join with keep_order so rows stay aligned with dataJoined.
        dataWithPM = join(
            dataJoined,
            astrometricCorrectionCatalog,
            keys="isolated_star_id",
            join_type="left",
            keep_order=True,
        )

        # Select only the MJD column rather than copying every visit column.
        mjds = visitTable.loc[dataWithPM["visit"], "expMidptMJD"]
        times = astropy.time.Time(mjds, format="mjd", scale="tai")
        dataWithPM["MJD"] = times
        medianMJD = astropy.time.Time(np.median(mjds), format="mjd", scale="tai")

        raCorrection, decCorrection = calculate_apparent_motion(dataWithPM, medianMJD)

        dataJoined["coord_ra"] = dataWithPM["coord_ra"] - raCorrection.value
        dataJoined["coord_dec"] = dataWithPM["coord_dec"] - decCorrection.value

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId

        # parsePlotInfo reads the deferred handle, so call it before the
        # association table is replaced by its loaded contents.
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])
        # Bail out before reading the (potentially many) visit catalogs.
        if len(inputs["associatedSources"]) == 0:
            raise NoWorkFound(f"No associated sources in tract {dataId['tract']}")
        inputs["associatedSourceIds"] = self.loadData(inputs["associatedSourceIds"], ["isolated_star_id"])

        # Load specified columns from source catalogs. Copy the name set so
        # any state held by collectInputNames is not mutated; the
        # association-only columns do not exist in the visit tables.
        names = set(self.collectInputNames()) | {"sourceId", "coord_ra", "coord_dec"}
        names -= {"obj_index", "isolated_star_id"}
        inputs["sourceCatalogs"] = [self.loadData(handle, names) for handle in inputs["sourceCatalogs"]]

        if self.config.applyAstrometricCorrections:
            # Pass a concrete list: the column-selection parameter expects a
            # sequence of names, not a dict view.
            inputs["astrometricCorrectionCatalog"] = inputs["astrometricCorrectionCatalog"].get(
                parameters={"columns": list(self.config.astrometricCorrectionParameters.values())}
            )
        else:
            inputs["astrometricCorrectionCatalog"] = None
            inputs["visitTable"] = None

        data = self.callback(inputs, dataId)

        kwargs = {"data": data, "plotInfo": plotInfo, "skymap": inputs["skyMap"], "camera": inputs["camera"]}
        outputs = self.run(**kwargs)
        self.putByBand(butlerQC, outputs, outputRefs)