Coverage for python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py: 41%
59 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-24 03:43 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("AssociatedSourcesTractAnalysisConfig", "AssociatedSourcesTractAnalysisTask")
25import numpy as np
26import pandas as pd
27from lsst.cp.pipe._lookupStaticCalibration import lookupStaticCalibration
28from lsst.geom import Box2D
29from lsst.pipe.base import connectionTypes as ct
30from lsst.skymap import BaseSkyMap
32from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask
class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract", "instrument"),
    defaultTemplates=dict(
        outputName="isolated_star_sources",
        associatedSourcesInputName="isolated_star_sources",
    ),
):
    """Butler connections for the per-tract associated-sources analysis.

    Each quantum is defined by ``(skymap, tract, instrument)``.  The inputs
    are the per-visit source tables, the table of associated sources for the
    tract, the sky map, and the camera (a prerequisite calibration input).
    """

    # Per-visit source tables; one deferred handle per (visit, band).
    sourceCatalogs = ct.Input(
        name="sourceTable_visit",
        storageClass="DataFrame",
        dimensions=("visit", "band"),
        doc="Visit based source table to load from the butler",
        deferLoad=True,
        multiple=True,
    )

    # Association table; the dataset type name comes from the
    # ``associatedSourcesInputName`` template.
    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        storageClass="DataFrame",
        dimensions=("instrument", "skymap", "tract"),
        doc="Table of associated sources",
        deferLoad=True,
    )

    # Sky map providing the tract geometry.
    skyMap = ct.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
    )

    # Camera geometry, located through the static-calibration lookup helper.
    camera = ct.PrerequisiteInput(
        name="camera",
        storageClass="Camera",
        dimensions=("instrument",),
        doc="Input camera to use for focal plane geometry.",
        isCalibration=True,
        lookupFunction=lookupStaticCalibration,
    )
class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig,
    pipelineConnections=AssociatedSourcesTractAnalysisConnections,
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`.

    No fields are added here; everything is inherited from
    `AnalysisBaseConfig`.
    """

    def setDefaults(self):
        """Apply the inherited defaults; no task-specific overrides are set."""
        super().setDefaults()
class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Run the configured analyses on associated sources within one tract.

    The per-visit source tables are concatenated, joined to the table of
    associated sources, and restricted to the association groups that lie
    entirely inside the tract bounding box before being handed to ``run``.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysis"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries.

        Parameters
        ----------
        skymap
            Sky map object providing ``generateTract``.
        tract
            Identifier of the tract to look up.

        Returns
        -------
        tractBox
            Bounding box returned by the tract's ``getBBox``.
        wcs
            WCS returned by the tract's ``getWcs``.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    @classmethod
    def callback(cls, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs
            Mapping holding the already-loaded ``"skyMap"``,
            ``"sourceCatalogs"`` and ``"associatedSources"`` entries.
        dataId
            Mapping providing the ``"tract"`` identifier.

        Returns
        -------
        data : `pandas.DataFrame`
            Result of `prepareAssociatedSources` for these inputs.
        """
        return cls.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
        )

    @classmethod
    def prepareAssociatedSources(cls, skymap, tract, sourceCatalogs, associatedSources):
        """Concatenate source catalogs and join on associated object index.

        Parameters
        ----------
        skymap
            Sky map used to obtain the tract bounding box and WCS.
        tract
            Identifier of the tract being analysed.
        sourceCatalogs : iterable of `pandas.DataFrame`
            Source tables to concatenate.  They must provide ``sourceId``,
            ``coord_ra`` and ``coord_dec``.
        associatedSources : `pandas.DataFrame`
            Association table providing ``sourceId`` and ``obj_index``.

        Returns
        -------
        dataFiltered : `pandas.DataFrame`
            Joined table indexed by ``sourceId``, restricted to the
            ``obj_index`` groups whose members all fall inside the tract
            bounding box.
        """
        # Keep only sources with associations
        dataJoined = pd.concat(sourceCatalogs).merge(associatedSources, on="sourceId", how="inner")
        dataJoined.set_index("sourceId", inplace=True)

        # Determine which sources are contained in tract.  The coordinates
        # are converted to radians before being passed to the WCS
        # (presumably they are stored in degrees; confirm against the
        # source table schema).
        ra = np.radians(dataJoined["coord_ra"].values)
        dec = np.radians(dataJoined["coord_dec"].values)
        box, wcs = cls.getBoxWcs(skymap, tract)
        box = Box2D(box)
        x, y = wcs.skyToPixelArray(ra, dec)
        boxSelection = box.contains(x, y)

        # Keep only the sources in groups that are fully contained within the
        # tract
        dataJoined["boxSelection"] = boxSelection
        dataFiltered = dataJoined.groupby("obj_index").filter(lambda group: group["boxSelection"].all())
        # The helper column was only needed for the group filter.
        dataFiltered.drop(columns="boxSelection", inplace=True)

        return dataFiltered

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs for one quantum, run the analysis, store outputs.

        Parameters
        ----------
        butlerQC
            Quantum context used to ``get`` the inputs and ``put`` the
            outputs; its ``quantum.dataId`` identifies the tract.
        inputRefs
            References to the input datasets.
        outputRefs
            References to the output datasets.
        """
        inputs = butlerQC.get(inputRefs)

        # Load specified columns from source catalogs.  Work on a fresh set
        # so the edits below stay local to this method.
        names = set(self.collectInputNames())
        names |= {"sourceId", "coord_ra", "coord_dec"}
        # "obj_index" is supplied by the association table rather than the
        # source catalogs.  ``discard`` avoids a KeyError when no configured
        # action requested that column.
        names.discard("obj_index")
        inputs["sourceCatalogs"] = [self.loadData(handle, names) for handle in inputs["sourceCatalogs"]]

        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])

        data = self.callback(inputs, dataId)

        kwargs = {"data": data, "plotInfo": plotInfo, "skymap": inputs["skyMap"], "camera": inputs["camera"]}
        outputs = self.run(**kwargs)
        butlerQC.put(outputs, outputRefs)