Coverage for python / lsst / analysis / tools / tasks / sourceObjectTableAnalysis.py: 27%
189 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 08:45 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 08:45 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = (
24 "SourceObjectTableAnalysisConfig",
25 "SourceObjectTableAnalysisTask",
26 "ObjectEpochTableConfig",
27 "ObjectEpochTableTask",
28)
30import astropy.time
31import astropy.units as u
32import numpy as np
33import pandas as pd
34from astropy.table import Table, join, vstack
35from smatch import Matcher
37import lsst.pex.config as pexConfig
38import lsst.pipe.base as pipeBase
39from lsst.drp.tasks.gbdesAstrometricFit import calculate_apparent_motion
40from lsst.pipe.base import AlgorithmError
41from lsst.pipe.base import connectionTypes as ct
43from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask
class IndexMismatchError(AlgorithmError):
    """Signal that the associated-sources catalog refers to sourceIds that
    cannot be found in the input data.
    """

    def __init__(self) -> None:
        message = "Not all sourceIds in the associated sources catalog are available in the input data."
        super().__init__(message)

    @property
    def metadata(self) -> dict:
        """Return an empty mapping; this error carries no metadata."""
        return {}
class NoMatchError(AlgorithmError):
    """Signal that matching the source catalog against the reference catalog
    produced no pairs.

    This can happen if areas of the source or reference image were not
    processed successfully.

    Parameters
    ----------
    targetCatalogSize
        Size of the source (target) catalog, recorded in the metadata.
    refCatalogSize
        Size of the reference catalog, recorded in the metadata.
    """

    def __init__(self, targetCatalogSize, refCatalogSize) -> None:
        # The sizes are stored before calling the base constructor, as in the
        # original ordering.
        self._metadata = dict(targetCatalogSize=targetCatalogSize, refCatalogSize=refCatalogSize)
        super().__init__("No matches were made between the source and reference catalogs.")

    @property
    def metadata(self) -> dict:
        """Return the catalog sizes after checking that every value is an
        `int`, `float` or `str`.

        Raises
        ------
        TypeError
            Raised if any stored value is of another type.
        """
        allowedTypes = (int, float, str)
        for key, value in self._metadata.items():
            if isinstance(value, allowedTypes):
                continue
            raise TypeError(f"{key} is of type {type(value)}, but only (int, float, str) are allowed.")
        return self._metadata
class ObjectEpochTableConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
):
    """Connections for a per-tract quantum that reads per-patch object
    catalogs and per-band epoch maps and writes per-patch epoch catalogs.
    """

    # Input: one object catalog per patch, loaded lazily.
    objectCat = ct.Input(
        name="objectTable",
        doc="Catalog of positions in each patch.",
        dimensions=["skymap", "tract", "patch"],
        storageClass="ArrowAstropy",
        deferLoad=True,
        multiple=True,
        deferGraphConstraint=True,
    )

    # Input: one mean-epoch map per band, loaded lazily.
    epochMap = ct.Input(
        name="deepCoadd_epoch_map_mean",
        doc="Healsparse map of mean epoch of objectCat in each band.",
        dimensions=("skymap", "tract", "band"),
        storageClass="HealSparseMap",
        deferLoad=True,
        multiple=True,
    )

    # Output: one epoch catalog per patch.
    objectEpochs = ct.Output(
        name="object_epoch",
        doc="Catalog of epochs for objectCat objects.",
        dimensions=["skymap", "tract", "patch"],
        storageClass="ArrowAstropy",
        multiple=True,
    )
class ObjectEpochTableConfig(pipeBase.PipelineTaskConfig, pipelineConnections=ObjectEpochTableConnections):
    """Configuration for `ObjectEpochTableTask`."""

    # The previous doc referred to an `objectCat_selectors` field which this
    # config does not define. The task uses each band to build the
    # `<band>_ra`/`<band>_dec` input column names, to pick the epoch map, and
    # to name the `<band>_epoch` output column.
    bands = pexConfig.ListField(
        doc=(
            "Bands in objectCat for which epochs are collected. Each band is combined with `ra` and `dec`"
            " to build the objectCat position column names and yields a `<band>_epoch` output column."
        ),
        dtype=str,
        default=["u", "g", "r", "i", "z", "y"],
    )
class ObjectEpochTableTask(pipeBase.PipelineTask):
    """Collect mean epochs for the observations that went into each object.

    TODO: DM-46202, Remove this task once the object epochs are available
    elsewhere.
    """

    ConfigClass = ObjectEpochTableConfig
    _DefaultName = "objectEpochTable"

    def getEpochs(self, cat, epochMapDict):
        """Look up the mean epoch at each object's per-band position.

        Parameters
        ----------
        cat : `astropy.table.Table`
            Catalog with ``objectId`` and, for every configured band,
            ``<band>_ra`` and ``<band>_dec`` columns.
        epochMapDict : `dict`
            Mapping from band name to an epoch map providing
            ``get_values_pos``. `runQuantum` passes the already-loaded maps.

        Returns
        -------
        epochTable : `astropy.table.Table`
            Table with one ``<band>_epoch`` column per configured band plus
            ``objectId``. Entries are NaN where the position is not finite or
            the map reports the position as not valid.
        """
        epochColumns = {}
        for band in self.config.bands:
            ra = cat[f"{band}_ra"]
            dec = cat[f"{band}_dec"]
            # Start from all-NaN and only fill in rows with usable positions.
            bandEpochs = np.full(len(cat), np.nan)
            hasPosition = np.isfinite(ra) & np.isfinite(dec)
            if hasPosition.any():
                epochMap = epochMapDict[band]
                goodRa, goodDec = ra[hasPosition], dec[hasPosition]
                values = epochMap.get_values_pos(goodRa, goodDec)
                # Second lookup returns the validity mask for the same
                # positions; invalid entries are blanked out.
                covered = epochMap.get_values_pos(goodRa, goodDec, valid_mask=True)
                values[~covered] = np.nan
                bandEpochs[hasPosition] = values
            epochColumns[f"{band}_epoch"] = bandEpochs
        epochColumns["objectId"] = cat["objectId"]

        return Table(epochColumns)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Write one epoch table per input object catalog, keyed by patch."""
        inputs = butlerQC.get(inputRefs)

        # Only the per-band positions and the object identifier are read.
        columns = []
        for band in self.config.bands:
            columns.extend([f"{band}_ra", f"{band}_dec"])
        columns.append("objectId")

        epochMaps = {handle.dataId["band"]: handle.get() for handle in inputs["epochMap"]}
        inputs["epochMap"] = epochMaps

        refsByPatch = {ref.dataId["patch"]: ref for ref in outputRefs.objectEpochs}
        for handle in inputs["objectCat"]:
            patch = handle.dataId["patch"]
            catalog = handle.get(parameters={"columns": columns})
            butlerQC.put(self.getEpochs(catalog, epochMaps), refsByPatch[patch])
class SourceObjectTableAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("visit",),
    defaultTemplates={
        "inputName": "sourceTable_visit",
        "inputCoaddName": "deep",
        "associatedSourcesInputName": "isolated_star_presources",
        "associatedSourceIdsInputName": "isolated_star_presource_associations",
        "outputName": "sourceObjectTable",
    },
):
    """Connections for the per-visit source/object comparison.

    The three inputs that are only used for astrometric corrections are
    dropped when ``config.applyAstrometricCorrections`` is false.
    """

    # Visit-level source table.
    data = ct.Input(
        name="sourceTable_visit",
        doc="Visit based source table to load from the butler",
        dimensions=("visit",),
        storageClass="ArrowAstropy",
        deferLoad=True,
    )

    # Per-tract isolated-star source associations.
    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        doc="Table of associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        deferLoad=True,
        multiple=True,
        deferGraphConstraint=True,
    )

    # Per-tract identifiers for the associated sources.
    associatedSourceIds = ct.Input(
        name="{associatedSourceIdsInputName}",
        doc="Table containing unique ids for the associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        deferLoad=True,
        multiple=True,
        deferGraphConstraint=True,
    )

    # Per-patch reference (object) catalogs.
    refCat = ct.Input(
        name="objectTable",
        doc="Catalog of positions to use as reference.",
        dimensions=["skymap", "tract", "patch"],
        storageClass="DataFrame",
        deferLoad=True,
        multiple=True,
        deferGraphConstraint=True,
    )
    # The remaining inputs are only needed for astrometric corrections.
    astrometricCorrectionCatalog = ct.Input(
        name="isolated_star_stellar_motions",
        doc="Catalog containing proper motions and parallaxes.",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        deferLoad=True,
        multiple=True,
    )
    refCatEpochs = ct.Input(
        name="object_epoch",
        doc="Catalog of epochs for refCat objects.",
        dimensions=["skymap", "tract", "patch"],
        storageClass="ArrowAstropy",
        deferLoad=True,
        multiple=True,
    )
    visitTable = ct.Input(
        name="visitTable",
        doc="Catalog containing visit information.",
        dimensions=("instrument",),
        storageClass="DataFrame",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if config.applyAstrometricCorrections:
            return
        # Without corrections none of the motion/epoch inputs are read.
        for unusedInput in ("astrometricCorrectionCatalog", "refCatEpochs", "visitTable"):
            self.inputs.remove(unusedInput)
class SourceObjectTableAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=SourceObjectTableAnalysisConnections
):
    """Configuration for `SourceObjectTableAnalysisTask`."""

    ra_column = pexConfig.Field(
        doc="Name of column in refCat to use for right ascension.",
        dtype=str,
        default="r_ra",
    )
    dec_column = pexConfig.Field(
        doc="Name of column in refCat to use for declination.",
        dtype=str,
        default="r_dec",
    )
    epoch_column = pexConfig.Field(
        doc=(
            "Name of column in refCat corresponding to the epoch to which "
            "sources will be shifted. Should correspond to the positions in "
            "`ra_column` and `dec_column`."
        ),
        dtype=str,
        default="r_epoch",
    )
    refCat_bands = pexConfig.ListField(
        doc=("Bands in refCat to be combined with `refCat_selectors` to build refCat column names."),
        dtype=str,
        default=["u", "g", "r", "i", "z", "y"],
    )
    refCat_selectors = pexConfig.ListField(
        doc=(
            "Remove objects for which these flags are true. These strings are combined with `refCat_bands`"
            " to build the full refCat column names"
        ),
        dtype=str,
        default=["pixelFlags_saturated", "pixelFlags_saturatedCenter"],
    )
    # The task divides this value by 3600.0 before handing it to the matcher,
    # i.e. it is treated as arcseconds; the doc previously said "mas".
    refCatMatchingRadius = pexConfig.Field(
        dtype=float,
        default=1.0,
        doc=(
            "Radius in arcseconds with which to match the mean positions of the sources with the positions"
            " in the reference catalog."
        ),
    )
    applyAstrometricCorrections = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Apply proper motions and parallaxes to source positions.",
    )
    # NOTE(review): this field is not read anywhere in the visible task code
    # (the corrections are joined on `isolated_star_id`), so its stated unit
    # cannot be confirmed here.
    correctionsMatchingRadius = pexConfig.Field(
        dtype=float,
        default=0.2,
        doc=(
            "Radius in mas with which to match the mean positions of the sources with the positions in the"
            " astrometricCorrectionCatalog."
        ),
    )
    # Maps the column names expected by the correction code (keys) to the
    # column names present in the astrometric correction catalog (values).
    astrometricCorrectionParameters = pexConfig.DictField(
        keytype=str,
        itemtype=str,
        default={
            "ra": "ra",
            "dec": "dec",
            "pmRA": "raPM",
            "pmDec": "decPM",
            "parallax": "parallax",
            "isolated_star_id": "isolated_star_id",
        },
        doc="Column names for position and motion parameters in the astrometric correction catalogs.",
    )
    # The exception raised by the task is `NoMatchError`; the doc previously
    # named a class that does not exist.
    raiseIfNoMatches = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Raise NoMatchError if there are no matches between the source and object catalogs.",
    )

    def setDefaults(self):
        super().setDefaults()
        # Imported here rather than at module scope, as in the original.
        from ..atools import TargetRefCatDeltaColorMetrics

        self.atools.astromColorDiffMetrics = TargetRefCatDeltaColorMetrics
class SourceObjectTableAnalysisTask(AnalysisPipelineTask):
    """Match the isolated sources of one visit to reference objects and run
    the configured analysis tools on the combined catalog.
    """

    ConfigClass = SourceObjectTableAnalysisConfig
    _DefaultName = "sourceObjectTableAnalysis"

    def callback(self, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs : `dict`
            Prepared inputs; must contain the keys ``data``,
            ``associatedSources``, ``associatedSourceIds``, ``refCat``,
            ``visitTable`` and ``astrometricCorrectionCatalog``.
        dataId : mapping
            Data ID providing the ``visit`` value.

        Returns
        -------
        allCat : `pd.DataFrame`
            Result of `prepareAssociatedSources`.
        """
        return self.prepareAssociatedSources(
            dataId["visit"],
            inputs["data"],
            inputs["associatedSources"],
            inputs["associatedSourceIds"],
            inputs["refCat"],
            inputs["visitTable"],
            inputs["astrometricCorrectionCatalog"],
        )

    def applyAstrometricCorrections(
        self, isolatedSources, astrometricCorrectionCatalog, visitTable, visit, refEpochs
    ):
        """Shift source positions to match the epoch of the reference catalog
        objects.

        Parameters
        ----------
        isolatedSources : `astropy.table.Table`
            Catalog of sources which will be modified in place with the
            astrometric corrections.
        astrometricCorrectionCatalog : `astropy.table.Table`
            Catalog with proper motion and parallax information. Its columns
            are renamed and given units in place.
        visitTable : `pd.DataFrame`
            Catalog containing the epoch for the visit corresponding to the
            isolatedSources. Not modified.
        visit : `int`
            Identifier of the isolatedSources' visit.
        refEpochs : `pd.Series`
            Epoch (MJD) of the reference object matched to each row of
            ``isolatedSources``. Not modified.
        """
        if visitTable.index.name is None:
            # The expected index may or may not be set, depending on whether
            # the table was written originally as a DataFrame or something else
            # Parquet-friendly. Rebind locally instead of using ``inplace`` so
            # the caller's table is left untouched.
            visitTable = visitTable.set_index("visitId")
        sourceMjd = visitTable.loc[visit]["expMidptMJD"]

        # Get target date from reference catalog. Take an explicit copy: the
        # array is written to below, and `to_numpy` may otherwise hand back a
        # (possibly read-only) view of the Series data.
        targetEpochs = refEpochs.to_numpy(copy=True)
        # There may not be a valid reference epoch on the edge of a given
        # region. Do not make an astrometric correction for any sources on the
        # edge.
        targetEpochs[~np.isfinite(targetEpochs)] = sourceMjd

        # Get the stellar motion catalog into the right format:
        for key, value in self.config.astrometricCorrectionParameters.items():
            astrometricCorrectionCatalog.rename_column(value, key)
        astrometricCorrectionCatalog["ra"] *= u.degree
        astrometricCorrectionCatalog["dec"] *= u.degree
        astrometricCorrectionCatalog["pmRA"] *= u.mas / u.yr
        astrometricCorrectionCatalog["pmDec"] *= u.mas / u.yr
        astrometricCorrectionCatalog["parallax"] *= u.mas

        # Left join keeps one row per isolated source, in the same order.
        joinedData = join(
            isolatedSources[["isolated_star_id"]],
            astrometricCorrectionCatalog,
            keys="isolated_star_id",
            join_type="left",
            keep_order=True,
            metadata_conflicts="silent",
        )
        joinedData["MJD"] = astropy.time.Time(sourceMjd, format="mjd", scale="tai")

        raCorrection, decCorrection = calculate_apparent_motion(
            joinedData, astropy.time.Time(targetEpochs, format="mjd", scale="tai")
        )

        isolatedSources["coord_ra"] -= raCorrection.value
        isolatedSources["coord_dec"] -= decCorrection.value

    def prepareAssociatedSources(
        self,
        visit,
        data,
        associatedSourceRefs,
        associatedSourceIdRefs,
        refCats,
        visitTable,
        astrometricCorrectionCatalog,
    ):
        """Match isolated sources with reference objects and shift the sources
        to the object epochs if `self.config.applyAstrometricCorrections` is
        True.

        Parameters
        ----------
        visit : `int`
            Identifier of the visit corresponding to the data.
        data : `astropy.table.Table`
            Catalog of sources to be associated, indexed on ``sourceId``.
        associatedSourceRefs : `list` [`DeferredDatasetHandle`]
            Handle for the catalogs of isolated sources. There will be multiple
            if the visit overlaps with multiple tracts.
        associatedSourceIdRefs : `list` [`DeferredDatasetHandle`]
            Handles for the per-tract tables providing ``isolated_star_id``.
        refCats : `pd.DataFrame`
            Catalog of objects with which the sources will be compared.
        visitTable : `pd.DataFrame`
            Catalog containing the epoch for the visit corresponding to the
            isolatedSources.
        astrometricCorrectionCatalog : `astropy.table.Table`
            Catalog with proper motion and parallax information.

        Returns
        -------
        allCat : `pd.DataFrame`
            Matched reference rows and matched source rows, concatenated
            column-wise.

        Raises
        ------
        IndexMismatchError
            Raised if a ``sourceId`` of the associated sources is missing from
            ``data``.
        lsst.pipe.base.NoWorkFound
            Raised if no isolated sources belong to this visit.
        NoMatchError
            Raised if nothing matches and ``config.raiseIfNoMatches`` is set.
        """
        isolatedSources = []
        associatedSourceIds = {
            ref.dataId["tract"]: ref.get(parameters={"columns": ["isolated_star_id"]})
            for ref in associatedSourceIdRefs
        }
        for associatedSourceRef in associatedSourceRefs:
            tract = associatedSourceRef.dataId["tract"]
            associatedSources = associatedSourceRef.get(
                parameters={"columns": ["visit", "sourceId", "obj_index"]}
            )
            index = associatedSources["obj_index"]
            associatedSources["isolated_star_id"] = associatedSourceIds[tract]["isolated_star_id"][index]

            visit_sources = associatedSources[associatedSources["visit"] == visit]
            if len(visit_sources) == 0:
                # Nothing from this tract belongs to the visit. Skipping avoids
                # an index lookup with an empty key list, which would
                # otherwise be reported as an index mismatch.
                continue
            try:
                visitData = data.loc[visit_sources["sourceId"]]
            except KeyError as e:
                raise IndexMismatchError() from e
            if not isinstance(visitData, Table):
                # NOTE(review): `loc` may return a single `Row` when exactly
                # one source is selected; normalise so that a column can be
                # added and the result stacked. Confirm against the astropy
                # version in use.
                visitData = Table(visitData)
            visitData["isolated_star_id"] = visit_sources["isolated_star_id"]
            isolatedSources.append(visitData)

        # `vstack` rejects an empty list, so check before stacking.
        if not isolatedSources:
            raise pipeBase.NoWorkFound(f"No isolated sources found for visit {visit}")
        isolatedSources = vstack(isolatedSources)

        # The configured radius is divided by 3600.0, i.e. it is taken to be in
        # arcseconds and converted to degrees for the matcher.
        with Matcher(np.asarray(isolatedSources["coord_ra"]), np.asarray(isolatedSources["coord_dec"])) as m:
            _, isolatedMatchIndices, refMatchIndices, _ = m.query_radius(
                np.asarray(refCats[self.config.ra_column]),
                np.asarray(refCats[self.config.dec_column]),
                self.config.refCatMatchingRadius / 3600.0,
                return_indices=True,
            )

        matchIS = isolatedSources[isolatedMatchIndices]

        if len(matchIS) == 0 and self.config.raiseIfNoMatches:
            raise NoMatchError(len(isolatedSources), len(refCats))

        # Apply proper motions and parallaxes to visit sources.
        if self.config.applyAstrometricCorrections:
            refCatEpochs = refCats[self.config.epoch_column].iloc[refMatchIndices]
            self.applyAstrometricCorrections(
                matchIS, astrometricCorrectionCatalog, visitTable, visit, refCatEpochs
            )

        matchRef = refCats.iloc[refMatchIndices]
        matchIS = matchIS.to_pandas()

        # Both halves are re-indexed so that rows pair up positionally.
        allCat = pd.concat([matchRef.reset_index(), matchIS.reset_index()], axis=1)
        return allCat

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load and filter the inputs, build the matched catalog and run the
        analysis tools on it.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if no reference objects survive the selection.
        lsst.pipe.base.AnnotatedPartialOutputsError
            Raised (chained) when building the matched catalog raises an
            `~lsst.pipe.base.AlgorithmError`.
        """
        inputs = butlerQC.get(inputRefs)

        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)

        # Get isolated sources:
        visit = inputs["data"].dataId["visit"]
        band = inputs["data"].dataId["band"]
        names = self.collectInputNames()
        # The reference position columns come from refCat, not the sources.
        names -= {self.config.ra_column, self.config.dec_column}
        names.add("sourceId")
        data = inputs["data"].get(parameters={"columns": names})
        data.add_index("sourceId")
        inputs["data"] = data

        if self.config.applyAstrometricCorrections:
            refCatEpochs = {
                epochTable.dataId["patch"]: epochTable.get() for epochTable in inputs["refCatEpochs"]
            }
        # Get objects:
        allRefCats = []
        refCatSelectors = [
            f"{refCatBand}_{selector}"
            for refCatBand in self.config.refCat_bands
            for selector in self.config.refCat_selectors
        ]
        refCatColumns = [
            "detect_isPrimary",
            self.config.ra_column,
            self.config.dec_column,
            "objectId",
        ] + refCatSelectors

        for refCatRef in inputs["refCat"]:
            refCat = refCatRef.get(parameters={"columns": refCatColumns})
            # ``objectId`` stays a regular column: it is the merge key below.
            # (A previous ``set_index`` call here discarded its result.)
            if self.config.applyAstrometricCorrections:
                refCat = pd.merge(refCat, refCatEpochs[refCatRef.dataId["patch"]].to_pandas(), on="objectId")
            goodInds = (
                refCat["detect_isPrimary"]
                & np.isfinite(refCat[self.config.ra_column])
                & np.isfinite(refCat[self.config.dec_column])
            )
            # Drop objects with any of the configured flags set in any band.
            goodInds &= ~refCat[refCatSelectors].any(axis=1)
            allRefCats.append(refCat[goodInds])

        # `pd.concat` rejects an empty list, so treat "no catalogs" the same
        # as "no surviving objects".
        refCat = pd.concat(allRefCats) if allRefCats else None
        if refCat is None or len(refCat) == 0:
            raise pipeBase.NoWorkFound(f"No reference catalog objects found to associate with visit {visit}")
        inputs["refCat"] = refCat

        if self.config.applyAstrometricCorrections:
            pmColumns = list(self.config.astrometricCorrectionParameters.values())
            pmCats = [
                astrometricCorrectionCatalogRef.get(parameters={"columns": pmColumns})
                for astrometricCorrectionCatalogRef in inputs["astrometricCorrectionCatalog"]
            ]
            inputs["astrometricCorrectionCatalog"] = vstack(pmCats, metadata_conflicts="silent")
        else:
            # These connections were removed, so provide the keys `callback`
            # expects.
            inputs["astrometricCorrectionCatalog"] = None
            inputs["visitTable"] = None

        try:
            allCat = self.callback(inputs, dataId)
        except pipeBase.AlgorithmError as e:
            error = pipeBase.AnnotatedPartialOutputsError.annotate(e, self, log=self.log)
            raise error from e

        outputs = self.run(data=allCat, bands=band, plotInfo=plotInfo)
        butlerQC.put(outputs, outputRefs)