Coverage for python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py: 40%
58 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 03:18 -0700
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 03:18 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("AssociatedSourcesTractAnalysisConfig", "AssociatedSourcesTractAnalysisTask")
25import numpy as np
26import pandas as pd
27from lsst.cp.pipe._lookupStaticCalibration import lookupStaticCalibration
28from lsst.geom import Box2D
29from lsst.pipe.base import connectionTypes as ct
31from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask
class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract", "instrument"),
    defaultTemplates={
        "outputName": "isolated_star_sources",
        "associatedSourcesInputName": "isolated_star_sources",
    },
):
    """Butler connections for `AssociatedSourcesTractAnalysisTask`.

    One quantum per (skymap, tract, instrument). It reads the matching
    per-visit source tables (``multiple=True``), the tract-level associated
    sources table, the sky map and the camera.
    """

    # Per-visit source tables; loaded lazily so only the needed columns are
    # read later in ``runQuantum``.
    sourceCatalogs = ct.Input(
        name="sourceTable_visit",
        doc="Visit based source table to load from the butler",
        dimensions=("visit", "band"),
        storageClass="DataFrame",
        multiple=True,
        deferLoad=True,
    )

    # Tract-level table linking each ``sourceId`` to an associated object;
    # the dataset name comes from the ``associatedSourcesInputName`` template.
    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        doc="Table of associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="DataFrame",
        deferLoad=True,
    )

    skyMap = ct.Input(
        name="skyMap",
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    # Calibration-type prerequisite resolved via ``lookupStaticCalibration``.
    camera = ct.PrerequisiteInput(
        name="camera",
        doc="Input camera to use for focal plane geometry.",
        dimensions=("instrument",),
        storageClass="Camera",
        lookupFunction=lookupStaticCalibration,
        isCalibration=True,
    )
class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig,
    pipelineConnections=AssociatedSourcesTractAnalysisConnections,
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`.

    Uses `AssociatedSourcesTractAnalysisConnections` for its butler
    connections; no task-specific fields are defined here.
    """

    def setDefaults(self):
        # Nothing task-specific to set; defer entirely to the base class.
        super().setDefaults()
class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Analyze per-visit sources grouped by associated object over a tract.

    Visit source tables are joined to the associated-sources table on
    ``sourceId``. Only objects whose member sources all fall inside the
    tract bounding box are kept before the data are passed to ``run``.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysis"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries.

        Parameters
        ----------
        skymap
            Sky map providing ``generateTract``.
        tract : `int`
            Tract identifier.

        Returns
        -------
        tractBox
            Bounding box from ``tractInfo.getBBox()``.
        wcs
            WCS from ``tractInfo.getWcs()``.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    @classmethod
    def callback(cls, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs : `dict`
            Must contain ``"skyMap"``, ``"sourceCatalogs"`` (loaded
            DataFrames) and ``"associatedSources"`` (loaded DataFrame).
        dataId
            Data ID mapping; must contain ``"tract"``.

        Returns
        -------
        data : `pandas.DataFrame`
            Result of `prepareAssociatedSources`.
        """
        return cls.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
        )

    @classmethod
    def prepareAssociatedSources(cls, skymap, tract, sourceCatalogs, associatedSources):
        """Concatenate source catalogs and join on associated object index.

        Parameters
        ----------
        skymap
            Sky map used to look up the tract geometry.
        tract : `int`
            Tract identifier.
        sourceCatalogs : sequence of `pandas.DataFrame`
            Visit source tables; must contain ``sourceId``, ``coord_ra`` and
            ``coord_dec`` (the coordinates are converted from degrees).
        associatedSources : `pandas.DataFrame`
            Table with ``sourceId`` and ``obj_index`` columns.

        Returns
        -------
        dataFiltered : `pandas.DataFrame`
            Joined sources indexed by ``sourceId``, restricted to objects
            whose every member source lies inside the tract box. Row order
            of the joined table is preserved.
        """
        # Keep only sources with associations
        dataJoined = pd.concat(sourceCatalogs).merge(associatedSources, on="sourceId", how="inner")
        dataJoined.set_index("sourceId", inplace=True)

        # Determine which sources are contained in tract
        ra = np.radians(dataJoined["coord_ra"].values)
        dec = np.radians(dataJoined["coord_dec"].values)
        box, wcs = cls.getBoxWcs(skymap, tract)
        box = Box2D(box)
        x, y = wcs.skyToPixelArray(ra, dec)
        boxSelection = box.contains(x, y)

        # Keep only the sources in groups that are fully contained within the
        # tract. A single vectorized per-group reduction replaces calling a
        # Python function on every group, which is slow when a tract holds
        # many associated objects. ``isin`` also drops rows whose
        # ``obj_index`` is null, matching ``groupby(...).filter``.
        dataJoined["boxSelection"] = boxSelection
        groupInTract = dataJoined.groupby("obj_index")["boxSelection"].all()
        keepObjects = groupInTract.index[groupInTract.to_numpy(dtype=bool)]
        keepRows = dataJoined["obj_index"].isin(keepObjects).to_numpy()
        dataFiltered = dataJoined[keepRows].drop(columns="boxSelection")

        return dataFiltered

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load inputs, select in-tract associated sources and run analysis.

        Parameters
        ----------
        butlerQC
            Butler quantum context used to get inputs and put outputs.
        inputRefs
            Input dataset references.
        outputRefs
            Output dataset references.
        """
        inputs = butlerQC.get(inputRefs)

        # Load specified columns from source catalogs
        names = self.collectInputNames()
        names |= {"sourceId", "coord_ra", "coord_dec"}
        # ``obj_index`` comes from the associated-sources table, not the visit
        # tables. ``discard`` (unlike ``remove``) does not raise when no
        # configured action requested that column.
        names.discard("obj_index")
        inputs["sourceCatalogs"] = [self.loadData(handle, names) for handle in inputs["sourceCatalogs"]]

        dataId = butlerQC.quantum.dataId
        # Called while ``associatedSources`` is still the deferred handle.
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])

        data = self.callback(inputs, dataId)

        kwargs = {"data": data, "plotInfo": plotInfo, "skymap": inputs["skyMap"], "camera": inputs["camera"]}
        outputs = self.run(**kwargs)
        butlerQC.put(outputs, outputRefs)