Coverage for python / lsst / analysis / tools / tasks / sourceObjectTableAnalysis.py: 27%
189 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:23 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:23 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public API of this module: the two analysis task/config pairs.
__all__ = (
    "SourceObjectTableAnalysisConfig", "SourceObjectTableAnalysisTask",
    "ObjectEpochTableConfig", "ObjectEpochTableTask",
)
30import astropy.time
31import astropy.units as u
32import lsst.pex.config as pexConfig
33import lsst.pipe.base as pipeBase
34import numpy as np
35import pandas as pd
36from astropy.table import Table, join, vstack
37from lsst.drp.tasks.gbdesAstrometricFit import calculate_apparent_motion
38from lsst.pipe.base import AlgorithmError
39from lsst.pipe.base import connectionTypes as ct
40from smatch import Matcher
42from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask
class IndexMismatchError(AlgorithmError):
    """Raised when some ``sourceId`` values listed in the associated-sources
    catalog cannot be found in the input visit data."""

    def __init__(self) -> None:
        message = "Not all sourceIds in the associated sources catalog are available in the input data."
        super().__init__(message)

    @property
    def metadata(self) -> dict:
        """Return an empty mapping; this error carries no metadata."""
        return dict()
class NoMatchError(AlgorithmError):
    """Raised if there are no matches between the source and reference
    catalogs.

    This can happen if areas of the source or reference image were not
    processed successfully.

    Parameters
    ----------
    targetCatalogSize
        Number of rows in the catalog that was matched against the
        reference catalog (callers in this module pass an `int`).
    refCatalogSize
        Number of rows in the reference catalog (callers in this module
        pass an `int`).
    """

    def __init__(self, targetCatalogSize, refCatalogSize) -> None:
        self._metadata = dict(targetCatalogSize=targetCatalogSize, refCatalogSize=refCatalogSize)
        super().__init__("No matches were made between the source and reference catalogs.")

    @property
    def metadata(self) -> dict:
        """Return the catalog sizes recorded at construction.

        Raises
        ------
        TypeError
            If any stored value is not an `int`, `float` or `str`.
        """
        permitted = (int, float, str)
        for name, entry in self._metadata.items():
            if isinstance(entry, permitted):
                continue
            raise TypeError(f"{name} is of type {type(entry)}, but only (int, float, str) are allowed.")
        return self._metadata
class ObjectEpochTableConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
):
    # Per-patch object tables covering the tract; loaded lazily (deferLoad)
    # so that each patch can be read with only the needed columns.
    objectCat = ct.Input(
        name="objectTable",
        doc="Catalog of positions in each patch.",
        dimensions=["skymap", "tract", "patch"],
        storageClass="ArrowAstropy",
        multiple=True,
        deferLoad=True,
        deferGraphConstraint=True,
    )

    # One mean-epoch HealSparse map per band for the tract.
    epochMap = ct.Input(
        name="deepCoadd_epoch_map_mean",
        doc="Healsparse map of mean epoch of objectCat in each band.",
        dimensions=("skymap", "tract", "band"),
        storageClass="HealSparseMap",
        multiple=True,
        deferLoad=True,
    )

    # One epoch table per patch, aligned with the objectCat inputs.
    objectEpochs = ct.Output(
        name="object_epoch",
        doc="Catalog of epochs for objectCat objects.",
        dimensions=["skymap", "tract", "patch"],
        storageClass="ArrowAstropy",
        multiple=True,
    )
class ObjectEpochTableConfig(pipeBase.PipelineTaskConfig, pipelineConnections=ObjectEpochTableConnections):
    """Configuration for `ObjectEpochTableTask`."""

    # Each band is used to read the ``{band}_ra``/``{band}_dec`` objectCat
    # columns and to name the ``{band}_epoch`` output column.
    bands = pexConfig.ListField(
        doc=(
            "Bands for which to compute epochs; used to build the `{band}_ra` and `{band}_dec` objectCat"
            " column names and the `{band}_epoch` output column names."
        ),
        dtype=str,
        default=["u", "g", "r", "i", "z", "y"],
    )
class ObjectEpochTableTask(pipeBase.PipelineTask):
    """Collect mean epochs for the observations that went into each object.

    TODO: DM-46202, Remove this task once the object epochs are available
    elsewhere.
    """

    ConfigClass = ObjectEpochTableConfig
    _DefaultName = "objectEpochTable"

    def getEpochs(self, cat, epochMapDict):
        """Get mean epoch of the visits corresponding to object position.

        Parameters
        ----------
        cat : `astropy.table.Table`
            Catalog containing object positions. Must contain ``{band}_ra``
            and ``{band}_dec`` columns for every band in
            ``self.config.bands``, plus an ``objectId`` column.
        epochMapDict : `dict` [`str`, `healsparse.HealSparseMap`]
            Already-loaded healsparse maps of the mean epoch, keyed by band.
            A map is only looked up for bands that have at least one object
            with finite coordinates.

        Returns
        -------
        epochTable : `astropy.table.Table`
            Table with one ``{band}_epoch`` column per configured band and
            an ``objectId`` column. Epochs are NaN where the object position
            is not finite or where the map has no valid pixel.
        """
        allEpochs = {}
        for band in self.config.bands:
            ra = cat[f"{band}_ra"]
            dec = cat[f"{band}_dec"]
            epochs = np.full(len(cat), np.nan)
            validPositions = np.isfinite(ra) & np.isfinite(dec)
            if validPositions.any():
                epochMap = epochMapDict[band]
                validRa = ra[validPositions]
                validDec = dec[validPositions]
                bandEpochs = epochMap.get_values_pos(validRa, validDec)
                # Positions that fall outside the map's valid pixels would
                # otherwise carry the map's sentinel value.
                epochsValid = epochMap.get_values_pos(validRa, validDec, valid_mask=True)
                bandEpochs[~epochsValid] = np.nan
                epochs[validPositions] = bandEpochs
            allEpochs[f"{band}_epoch"] = epochs
        allEpochs["objectId"] = cat["objectId"]

        return Table(allEpochs)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Write one epoch table for every input object-table patch."""
        inputs = butlerQC.get(inputRefs)

        columns = [f"{band}_{coord}" for band in self.config.bands for coord in ["ra", "dec"]]
        columns.append("objectId")

        # The per-band maps are loaded once and reused for every patch.
        epochMaps = {ref.dataId["band"]: ref.get() for ref in inputs["epochMap"]}

        outputEpochRefs = {outputRef.dataId["patch"]: outputRef for outputRef in outputRefs.objectEpochs}
        for objectCatRef in inputs["objectCat"]:
            patch = objectCatRef.dataId["patch"]
            objectCat = objectCatRef.get(parameters={"columns": columns})
            epochs = self.getEpochs(objectCat, epochMaps)
            butlerQC.put(epochs, outputEpochRefs[patch])
class SourceObjectTableAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("visit",),
    defaultTemplates={
        "inputName": "sourceTable_visit",
        "inputCoaddName": "deep",
        "associatedSourcesInputName": "isolated_star_presources",
        "associatedSourceIdsInputName": "isolated_star_presource_associations",
        "outputName": "sourceObjectTable",
    },
):
    # Source table for the quantum's visit, read lazily so only the needed
    # columns are loaded.
    data = ct.Input(
        name="sourceTable_visit",
        doc="Visit based source table to load from the butler",
        dimensions=("visit",),
        storageClass="ArrowAstropy",
        deferLoad=True,
    )

    # Per-tract isolated-star tables (one per tract the visit overlaps).
    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        doc="Table of associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        multiple=True,
        deferLoad=True,
        deferGraphConstraint=True,
    )

    associatedSourceIds = ct.Input(
        name="{associatedSourceIdsInputName}",
        doc="Table containing unique ids for the associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        multiple=True,
        deferLoad=True,
        deferGraphConstraint=True,
    )

    # Coadd object tables used as the reference positions.
    refCat = ct.Input(
        name="objectTable",
        doc="Catalog of positions to use as reference.",
        dimensions=["skymap", "tract", "patch"],
        storageClass="DataFrame",
        multiple=True,
        deferLoad=True,
        deferGraphConstraint=True,
    )
    # The next three inputs are only needed for astrometric corrections and
    # are dropped in ``__init__`` otherwise.
    astrometricCorrectionCatalog = ct.Input(
        name="isolated_star_stellar_motions",
        doc="Catalog containing proper motions and parallaxes.",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="ArrowAstropy",
        multiple=True,
        deferLoad=True,
    )
    refCatEpochs = ct.Input(
        name="object_epoch",
        doc="Catalog of epochs for refCat objects.",
        dimensions=["skymap", "tract", "patch"],
        storageClass="ArrowAstropy",
        multiple=True,
        deferLoad=True,
    )
    visitTable = ct.Input(
        name="visitTable",
        doc="Catalog containing visit information.",
        dimensions=("instrument",),
        storageClass="DataFrame",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if config.applyAstrometricCorrections:
            return
        # Without corrections, the motion catalog, object epochs and visit
        # epochs are never read, so they are not requested as inputs.
        for connectionName in ("astrometricCorrectionCatalog", "refCatEpochs", "visitTable"):
            self.inputs.remove(connectionName)
class SourceObjectTableAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=SourceObjectTableAnalysisConnections
):
    """Configuration for `SourceObjectTableAnalysisTask`."""

    ra_column = pexConfig.Field(
        doc="Name of column in refCat to use for right ascension.",
        dtype=str,
        default="r_ra",
    )
    dec_column = pexConfig.Field(
        doc="Name of column in refCat to use for declination.",
        dtype=str,
        default="r_dec",
    )
    epoch_column = pexConfig.Field(
        doc=(
            "Name of column in refCat corresponding to the epoch to which "
            "sources will be shifted. Should correspond to the positions in "
            "`ra_column` and `dec_column`."
        ),
        dtype=str,
        default="r_epoch",
    )
    refCat_bands = pexConfig.ListField(
        doc=("Bands in refCat to be combined with `refCat_selectors` to build refCat column names."),
        dtype=str,
        default=["u", "g", "r", "i", "z", "y"],
    )
    refCat_selectors = pexConfig.ListField(
        doc=(
            "Remove objects for which these flags are true. These strings are combined with `refCat_bands`"
            " to build the full refCat column names"
        ),
        dtype=str,
        default=["pixelFlags_saturated", "pixelFlags_saturatedCenter"],
    )
    # The task divides this value by 3600 before passing it to the matcher
    # in degrees, so the configured unit is arcseconds.
    refCatMatchingRadius = pexConfig.Field(
        dtype=float,
        default=1.0,
        doc=(
            "Radius in arcseconds with which to match the mean positions of the sources with the positions"
            " in the reference catalog."
        ),
    )
    applyAstrometricCorrections = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Apply proper motions and parallaxes to source positions.",
    )
    # NOTE(review): SourceObjectTableAnalysisTask joins the correction catalog
    # on ``isolated_star_id`` and does not read this field; confirm whether it
    # is still needed and whether "mas" is the intended unit.
    correctionsMatchingRadius = pexConfig.Field(
        dtype=float,
        default=0.2,
        doc=(
            "Radius in mas with which to match the mean positions of the sources with the positions in the"
            " astrometricCorrectionCatalog."
        ),
    )
    # Maps the names expected internally (keys) to the column names found in
    # the correction catalog (values).
    astrometricCorrectionParameters = pexConfig.DictField(
        keytype=str,
        itemtype=str,
        default={
            "ra": "ra",
            "dec": "dec",
            "pmRA": "raPM",
            "pmDec": "decPM",
            "parallax": "parallax",
            "isolated_star_id": "isolated_star_id",
        },
        doc="Column names for position and motion parameters in the astrometric correction catalogs.",
    )
    raiseIfNoMatches = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Raise NoMatchError if there are no matches between the source and object catalogs.",
    )

    def setDefaults(self):
        super().setDefaults()
        # Imported locally, presumably to avoid an import cycle with the
        # atools package -- TODO confirm.
        from ..atools import TargetRefCatDeltaColorMetrics

        self.atools.astromColorDiffMetrics = TargetRefCatDeltaColorMetrics
class SourceObjectTableAnalysisTask(AnalysisPipelineTask):
    """Compare visit-level isolated sources with coadd objects.

    Isolated sources from one visit are matched by position to objects in
    the object table. If configured, they are shifted to the object epochs
    using proper motions and parallaxes. The matched catalog is then passed
    to the configured analysis tools.
    """

    ConfigClass = SourceObjectTableAnalysisConfig
    _DefaultName = "sourceObjectTableAnalysis"

    def callback(self, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs : `dict`
            Prepared inputs keyed by connection name, as assembled in
            `runQuantum`.
        dataId
            Data ID providing the ``visit`` key.

        Returns
        -------
        allCat : `pandas.DataFrame`
            Matched catalog from `prepareAssociatedSources`.
        """
        return self.prepareAssociatedSources(
            dataId["visit"],
            inputs["data"],
            inputs["associatedSources"],
            inputs["associatedSourceIds"],
            inputs["refCat"],
            inputs["visitTable"],
            inputs["astrometricCorrectionCatalog"],
        )

    def applyAstrometricCorrections(
        self, isolatedSources, astrometricCorrectionCatalog, visitTable, visit, refEpochs
    ):
        """Shift source positions to match the epoch of the reference catalog
        objects.

        Parameters
        ----------
        isolatedSources : `astropy.table.Table`
            Catalog of sources which will be modified in place with the
            astrometric corrections. Must contain ``isolated_star_id``,
            ``coord_ra`` and ``coord_dec`` columns.
        astrometricCorrectionCatalog : `astropy.table.Table`
            Catalog with proper motion and parallax information. Its columns
            are renamed and given units in place.
        visitTable : `pd.DataFrame`
            Catalog containing the epoch (``expMidptMJD``) for the visit
            corresponding to the isolatedSources. If it has no index name it
            is indexed on ``visitId`` in place.
        visit : `int`
            Identifier of the isolatedSources' visit.
        refEpochs : `pandas.Series`
            MJD epoch of the matched reference object for each row of
            ``isolatedSources``, in the same order.
        """
        if len(isolatedSources) == 0:
            # Nothing to shift (possible when ``raiseIfNoMatches`` is False).
            return

        if visitTable.index.name is None:
            # The expected index may or may not be set, depending on whether
            # the table was written originally as a DataFrame or something else
            # Parquet-friendly.
            visitTable.set_index("visitId", inplace=True)
        sourceMjd = visitTable.loc[visit]["expMidptMJD"]

        # Get target date from reference catalog. There may not be a valid
        # reference epoch on the edge of a given region; use the visit epoch
        # there so that no astrometric correction is made. ``np.where`` builds
        # a new array instead of writing into ``to_numpy()``, which may be a
        # read-only view under pandas Copy-on-Write.
        targetEpochs = refEpochs.to_numpy()
        targetEpochs = np.where(np.isfinite(targetEpochs), targetEpochs, sourceMjd)

        # Get the stellar motion catalog into the column names and units used
        # below.
        for key, value in self.config.astrometricCorrectionParameters.items():
            astrometricCorrectionCatalog.rename_column(value, key)
        astrometricCorrectionCatalog["ra"] *= u.degree
        astrometricCorrectionCatalog["dec"] *= u.degree
        astrometricCorrectionCatalog["pmRA"] *= u.mas / u.yr
        astrometricCorrectionCatalog["pmDec"] *= u.mas / u.yr
        astrometricCorrectionCatalog["parallax"] *= u.mas

        # Left join with keep_order keeps one row per isolated source, in the
        # same order as ``isolatedSources`` and ``targetEpochs``.
        joinedData = join(
            isolatedSources[["isolated_star_id"]],
            astrometricCorrectionCatalog,
            keys="isolated_star_id",
            join_type="left",
            keep_order=True,
            metadata_conflicts="silent",
        )
        joinedData["MJD"] = astropy.time.Time(sourceMjd, format="mjd", scale="tai")

        raCorrection, decCorrection = calculate_apparent_motion(
            joinedData, astropy.time.Time(targetEpochs, format="mjd", scale="tai")
        )

        isolatedSources["coord_ra"] -= raCorrection.value
        isolatedSources["coord_dec"] -= decCorrection.value

    def prepareAssociatedSources(
        self,
        visit,
        data,
        associatedSourceRefs,
        associatedSourceIdRefs,
        refCats,
        visitTable,
        astrometricCorrectionCatalog,
    ):
        """Match isolated sources with reference objects and shift the sources
        to the object epochs if `self.config.applyAstrometricCorrections` is
        True.

        Parameters
        ----------
        visit : `int`
            Identifier of the visit corresponding to the data.
        data : `astropy.table.Table`
            Catalog of sources to be associated, indexed on ``sourceId``.
        associatedSourceRefs : `list` [`DeferredDatasetHandle`]
            Handle for the catalogs of isolated sources. There will be multiple
            if the visit overlaps with multiple tracts.
        associatedSourceIdRefs : `list` [`DeferredDatasetHandle`]
            Handles for the per-tract tables of ``isolated_star_id`` values,
            indexed by the ``obj_index`` column of the associated sources.
        refCats : `pd.DataFrame`
            Catalog of objects with which the sources will be compared.
        visitTable : `pd.DataFrame` or `None`
            Catalog containing the epoch for the visit corresponding to the
            isolatedSources. Only used for astrometric corrections.
        astrometricCorrectionCatalog : `astropy.table.Table` or `None`
            Catalog with proper motion and parallax information. Only used
            for astrometric corrections.

        Returns
        -------
        allCat : `pd.DataFrame`
            Matched reference objects and isolated sources, concatenated
            column-wise row by row.

        Raises
        ------
        IndexMismatchError
            If associated sources reference ``sourceId`` values missing from
            ``data``.
        lsst.pipe.base.NoWorkFound
            If no isolated sources are found for the visit.
        NoMatchError
            If nothing matches and ``config.raiseIfNoMatches`` is True.
        """
        associatedSourceIds = {
            ref.dataId["tract"]: ref.get(parameters={"columns": ["isolated_star_id"]})
            for ref in associatedSourceIdRefs
        }
        isolatedSources = []
        for associatedSourceRef in associatedSourceRefs:
            tract = associatedSourceRef.dataId["tract"]
            associatedSources = associatedSourceRef.get(
                parameters={"columns": ["visit", "sourceId", "obj_index"]}
            )
            index = associatedSources["obj_index"]
            associatedSources["isolated_star_id"] = associatedSourceIds[tract]["isolated_star_id"][index]

            visitSources = associatedSources[associatedSources["visit"] == visit]
            if len(visitSources) == 0:
                # The visit may barely overlap this tract. Table.loc raises
                # KeyError on an empty lookup, which would be misreported as
                # an index mismatch.
                continue
            try:
                rowIndices = data.loc_indices[visitSources["sourceId"]]
            except KeyError as e:
                raise IndexMismatchError() from e
            # ``loc_indices`` returns a bare int for a single match; normalize
            # so indexing always yields a Table rather than a Row.
            visitData = data[np.atleast_1d(rowIndices)]
            visitData["isolated_star_id"] = visitSources["isolated_star_id"]
            isolatedSources.append(visitData)

        # Checked before vstack, which raises ValueError on an empty list.
        if not isolatedSources:
            raise pipeBase.NoWorkFound(f"No isolated sources found for visit {visit}")
        isolatedSources = vstack(isolatedSources)

        # ``refCatMatchingRadius`` is configured in arcseconds; the matcher
        # works in degrees.
        with Matcher(np.asarray(isolatedSources["coord_ra"]), np.asarray(isolatedSources["coord_dec"])) as m:
            _, isolatedMatchIndices, refMatchIndices, _ = m.query_radius(
                np.asarray(refCats[self.config.ra_column]),
                np.asarray(refCats[self.config.dec_column]),
                self.config.refCatMatchingRadius / 3600.0,
                return_indices=True,
            )

        matchIS = isolatedSources[isolatedMatchIndices]

        if len(matchIS) == 0 and self.config.raiseIfNoMatches:
            raise NoMatchError(len(isolatedSources), len(refCats))

        # Apply proper motions and parallaxes to visit sources.
        if self.config.applyAstrometricCorrections:
            refCatEpochs = refCats[self.config.epoch_column].iloc[refMatchIndices]
            self.applyAstrometricCorrections(
                matchIS, astrometricCorrectionCatalog, visitTable, visit, refCatEpochs
            )

        matchRef = refCats.iloc[refMatchIndices]
        matchIS = matchIS.to_pandas()

        allCat = pd.concat([matchRef.reset_index(), matchIS.reset_index()], axis=1)
        return allCat

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load and filter inputs, match sources to objects and run the
        analysis tools on the result.
        """
        inputs = butlerQC.get(inputRefs)

        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)

        # Get isolated sources:
        visit = inputs["data"].dataId["visit"]
        band = inputs["data"].dataId["band"]
        names = self.collectInputNames()
        names -= {self.config.ra_column, self.config.dec_column}
        names.add("sourceId")
        data = inputs["data"].get(parameters={"columns": names})
        data.add_index("sourceId")
        inputs["data"] = data

        if self.config.applyAstrometricCorrections:
            refCatEpochs = {
                epochTable.dataId["patch"]: epochTable.get() for epochTable in inputs["refCatEpochs"]
            }

        # Get objects:
        refCatSelectors = [
            f"{refCatBand}_{selector}"
            for refCatBand in self.config.refCat_bands
            for selector in self.config.refCat_selectors
        ]
        refCatColumns = [
            "detect_isPrimary",
            self.config.ra_column,
            self.config.dec_column,
            "objectId",
        ] + refCatSelectors

        allRefCats = []
        for refCatRef in inputs["refCat"]:
            refCat = refCatRef.get(parameters={"columns": refCatColumns})
            if self.config.applyAstrometricCorrections:
                refCat = pd.merge(refCat, refCatEpochs[refCatRef.dataId["patch"]].to_pandas(), on="objectId")
            goodInds = (
                refCat["detect_isPrimary"]
                & np.isfinite(refCat[self.config.ra_column])
                & np.isfinite(refCat[self.config.dec_column])
            )
            goodInds &= ~refCat[refCatSelectors].any(axis=1)
            allRefCats.append(refCat[goodInds])

        # ``refCat`` has a deferred graph constraint, so the list may be
        # empty; pd.concat raises ValueError on an empty list.
        noRefMessage = f"No reference catalog objects found to associate with visit {visit}"
        if not allRefCats:
            raise pipeBase.NoWorkFound(noRefMessage)
        refCat = pd.concat(allRefCats)
        if len(refCat) == 0:
            raise pipeBase.NoWorkFound(noRefMessage)
        inputs["refCat"] = refCat

        if self.config.applyAstrometricCorrections:
            correctionColumns = list(self.config.astrometricCorrectionParameters.values())
            pmCats = [
                astrometricCorrectionCatalogRef.get(parameters={"columns": correctionColumns})
                for astrometricCorrectionCatalogRef in inputs["astrometricCorrectionCatalog"]
            ]
            inputs["astrometricCorrectionCatalog"] = vstack(pmCats, metadata_conflicts="silent")
        else:
            inputs["astrometricCorrectionCatalog"] = None
            inputs["visitTable"] = None

        try:
            allCat = self.callback(inputs, dataId)
        except pipeBase.AlgorithmError as e:
            error = pipeBase.AnnotatedPartialOutputsError.annotate(e, self, log=self.log)
            raise error from e

        outputs = self.run(data=allCat, bands=band, plotInfo=plotInfo)
        butlerQC.put(outputs, outputRefs)