Coverage for python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py: 46%
58 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-02 12:29 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-02 12:29 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("AssociatedSourcesTractAnalysisConfig", "AssociatedSourcesTractAnalysisTask")
25import numpy as np
26import pandas as pd
27from lsst.geom import Box2D
28from lsst.pipe.base import connectionTypes as ct
29from lsst.skymap import BaseSkyMap
31from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask
class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract", "instrument"),
    defaultTemplates={
        "outputName": "isolated_star_sources",
        "associatedSourcesInputName": "isolated_star_sources",
    },
):
    """Butler connections for the associated-sources tract analysis.

    Quanta are defined per (skymap, tract, instrument). The dataset type
    name of the association table is taken from the
    ``associatedSourcesInputName`` template, which defaults to
    ``isolated_star_sources``.
    """

    # Per-visit source tables. Declared ``multiple`` so many may feed one
    # quantum, and ``deferLoad`` so they arrive as handles rather than
    # loaded tables.
    sourceCatalogs = ct.Input(
        name="sourceTable_visit",
        doc="Visit based source table to load from the butler",
        storageClass="DataFrame",
        dimensions=("visit", "band"),
        deferLoad=True,
        multiple=True,
    )

    # Association table for the tract; also deferred.
    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        doc="Table of associated sources",
        storageClass="DataFrame",
        dimensions=("instrument", "skymap", "tract"),
        deferLoad=True,
    )

    # Sky map dataset, keyed only by the skymap dimension.
    skyMap = ct.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    # Camera is a prerequisite calibration input, keyed by instrument.
    camera = ct.PrerequisiteInput(
        name="camera",
        doc="Input camera to use for focal plane geometry.",
        storageClass="Camera",
        dimensions=("instrument",),
        isCalibration=True,
    )
class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig,
    pipelineConnections=AssociatedSourcesTractAnalysisConnections,
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`.

    Declares no fields of its own; it only binds the base analysis
    configuration to `AssociatedSourcesTractAnalysisConnections`.
    """

    def setDefaults(self):
        """Apply the defaults of the base configuration."""
        # No task-specific overrides are set here; defer to the parent.
        super().setDefaults()
class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Analyze associated sources that fall inside a single tract.

    The per-visit source tables are concatenated, joined to the table of
    associated sources on ``sourceId``, and reduced to those association
    groups (``obj_index``) whose members all lie within the tract bounding
    box. The resulting table is passed to ``run`` together with plot
    information, the sky map and the camera.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysis"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries.

        Parameters
        ----------
        skymap
            Sky map object; must provide ``generateTract``.
        tract
            Tract identifier handed to ``skymap.generateTract``.

        Returns
        -------
        tractBox
            Bounding box returned by the tract's ``getBBox``.
        wcs
            WCS returned by the tract's ``getWcs``.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    @classmethod
    def callback(cls, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs
            Mapping with the keys ``"skyMap"``, ``"sourceCatalogs"`` and
            ``"associatedSources"``; the latter two must already be loaded
            tables.
        dataId
            Mapping providing the ``"tract"`` value.

        Returns
        -------
        data
            The table produced by `prepareAssociatedSources`.
        """
        return cls.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
        )

    @classmethod
    def prepareAssociatedSources(cls, skymap, tract, sourceCatalogs, associatedSources):
        """Concatenate source catalogs and join on associated object index.

        Parameters
        ----------
        skymap
            Sky map used to look up the tract bounding box and WCS.
        tract
            Identifier of the tract to restrict to.
        sourceCatalogs
            Iterable of data frames, each with at least the columns
            ``sourceId``, ``coord_ra`` and ``coord_dec``.
        associatedSources
            Data frame with at least the columns ``sourceId`` and
            ``obj_index``.

        Returns
        -------
        dataFiltered
            Joined table indexed by ``sourceId``, containing only the
            ``obj_index`` groups whose sources all fall inside the tract
            bounding box.
        """
        # Keep only sources with associations (inner join), then index the
        # result by source ID.
        dataJoined = (
            pd.concat(sourceCatalogs)
            .merge(associatedSources, on="sourceId", how="inner")
            .set_index("sourceId")
        )

        # Determine which sources are contained in tract.
        # NOTE(review): np.radians implies the catalog coordinates are in
        # degrees -- confirm against the table schema.
        ra = np.radians(dataJoined["coord_ra"].values)
        dec = np.radians(dataJoined["coord_dec"].values)
        box, wcs = cls.getBoxWcs(skymap, tract)
        # Convert to a floating-point box before testing the pixel
        # positions returned by the WCS.
        box = Box2D(box)
        x, y = wcs.skyToPixelArray(ra, dec)

        # Keep only the sources in groups that are fully contained within
        # the tract; the flag column is temporary and dropped on return.
        dataJoined["boxSelection"] = box.contains(x, y)
        dataFiltered = dataJoined.groupby("obj_index").filter(lambda group: group["boxSelection"].all())
        return dataFiltered.drop(columns="boxSelection")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs, build the associated-source table and run.

        Parameters
        ----------
        butlerQC
            Quantum context used to fetch inputs and store outputs.
        inputRefs
            References to the input datasets of this quantum.
        outputRefs
            References to the output datasets of this quantum.
        """
        inputs = butlerQC.get(inputRefs)

        # Load specified columns from source catalogs
        names = self.collectInputNames()
        names |= {"sourceId", "coord_ra", "coord_dec"}
        # obj_index is read from the associated-sources table below, not
        # from the visit tables. Use discard() so a configuration in which
        # no tool requests obj_index does not fail with KeyError.
        names.discard("obj_index")
        inputs["sourceCatalogs"] = [self.loadData(handle, names) for handle in inputs["sourceCatalogs"]]

        dataId = butlerQC.quantum.dataId
        # Called before inputs["associatedSources"] is replaced by the
        # loaded table, i.e. while it is still the deferred handle.
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])

        data = self.callback(inputs, dataId)

        outputs = self.run(
            data=data,
            plotInfo=plotInfo,
            skymap=inputs["skyMap"],
            camera=inputs["camera"],
        )
        butlerQC.put(outputs, outputRefs)