Coverage for python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py: 40%

58 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-14 03:17 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("AssociatedSourcesTractAnalysisConfig", "AssociatedSourcesTractAnalysisTask") 

24 

25import numpy as np 

26import pandas as pd 

27from lsst.cp.pipe._lookupStaticCalibration import lookupStaticCalibration 

28from lsst.geom import Box2D 

29from lsst.pipe.base import connectionTypes as ct 

30 

31from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask 

32 

33 

class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract", "instrument"),
    defaultTemplates={
        "outputName": "isolated_star_sources",
        "associatedSourcesInputName": "isolated_star_sources",
    },
):
    """Butler connections for the per-tract associated-sources analysis.

    A quantum covers one (skymap, tract, instrument) combination and gathers
    every visit-level source table, the per-tract association table, the sky
    map and the camera.
    """

    # One source table per (visit, band); handles are deferred so the task can
    # read only the columns it needs.
    sourceCatalogs = ct.Input(
        name="sourceTable_visit",
        doc="Visit based source table to load from the butler",
        dimensions=("visit", "band"),
        storageClass="DataFrame",
        multiple=True,
        deferLoad=True,
    )

    # Association table for the tract; its dataset name is a template so it
    # can be overridden through ``associatedSourcesInputName``.
    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        doc="Table of associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="DataFrame",
        deferLoad=True,
    )

    # Sky map used to obtain the tract bounding box and WCS.
    skyMap = ct.Input(
        name="skyMap",
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    # Camera is a calibration-collection prerequisite, resolved through the
    # static-calibration lookup function.
    camera = ct.PrerequisiteInput(
        name="camera",
        doc="Input camera to use for focal plane geometry.",
        dimensions=("instrument",),
        storageClass="Camera",
        lookupFunction=lookupStaticCalibration,
        isCalibration=True,
    )

73 

74 

class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=AssociatedSourcesTractAnalysisConnections
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`.

    No defaults differ from `AnalysisBaseConfig`, so the inherited
    ``setDefaults`` is used directly. Add an override here if per-task
    defaults are ever needed.
    """

80 

81 

class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Run analysis tools on sources associated across visits in one tract.

    The visit-level source tables are joined with the association table
    (``obj_index`` per ``sourceId``). Only associated groups whose members
    all lie inside the tract bounding box are kept, and the result is passed
    to ``run``.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysis"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries, together with the tract WCS.

        Parameters
        ----------
        skymap
            Sky map from which the tract description is generated.
        tract
            Tract identifier, as found in the quantum data ID.

        Returns
        -------
        tractBox
            Pixel bounding box of the tract, as returned by ``getBBox``.
        wcs
            WCS of the tract, as returned by ``getWcs``.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    @classmethod
    def callback(cls, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs : `dict`
            Must contain ``"skyMap"``, ``"sourceCatalogs"`` (already loaded
            tables) and ``"associatedSources"`` (already loaded table).
        dataId
            Data ID providing the ``"tract"`` key.

        Returns
        -------
        data : `pandas.DataFrame`
            Output of `prepareAssociatedSources`.
        """
        return cls.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
        )

    @classmethod
    def prepareAssociatedSources(cls, skymap, tract, sourceCatalogs, associatedSources):
        """Concatenate source catalogs and join on associated object index.

        Parameters
        ----------
        skymap
            Sky map used to look up the tract bounding box and WCS.
        tract
            Tract identifier.
        sourceCatalogs : `list` [`pandas.DataFrame`]
            Source tables with at least ``sourceId``, ``coord_ra`` and
            ``coord_dec`` columns (coordinates in degrees; they are converted
            to radians below).
        associatedSources : `pandas.DataFrame`
            Association table with ``sourceId`` and ``obj_index`` columns.

        Returns
        -------
        dataFiltered : `pandas.DataFrame`
            Associated sources indexed by ``sourceId``. Every row of an
            ``obj_index`` group is kept only if all rows of that group fall
            inside the tract bounding box. Original row order is preserved.
        """
        # Keep only sources with associations.
        dataJoined = pd.concat(sourceCatalogs).merge(associatedSources, on="sourceId", how="inner")
        dataJoined.set_index("sourceId", inplace=True)

        # Determine which sources are contained in the tract bounding box.
        ra = np.radians(dataJoined["coord_ra"].values)
        dec = np.radians(dataJoined["coord_dec"].values)
        box, wcs = cls.getBoxWcs(skymap, tract)
        box = Box2D(box)
        xPix, yPix = wcs.skyToPixelArray(ra, dec)
        inside = np.asarray(box.contains(xPix, yPix), dtype=bool)

        # Keep only the sources in groups that are fully contained within the
        # tract. Instead of calling a Python function per group, find every
        # object with at least one member outside the box and drop all rows of
        # those objects in one vectorized pass. Rows whose ``obj_index`` is
        # missing are dropped too, matching ``groupby``'s default handling of
        # NaN keys.
        objIndex = dataJoined["obj_index"]
        objectsOutside = objIndex[~inside].unique()
        keep = objIndex.notna() & ~objIndex.isin(objectsOutside)
        dataFiltered = dataJoined[keep.to_numpy()]

        return dataFiltered

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the needed columns, prepare the joined table and run the tools.

        Parameters
        ----------
        butlerQC
            Butler quantum context used to read inputs and write outputs.
        inputRefs
            Dataset references for the connection inputs.
        outputRefs
            Dataset references for the connection outputs.
        """
        inputs = butlerQC.get(inputRefs)

        # Columns to read from each source table: whatever the configured
        # tools request, plus the columns needed for the join and the
        # tract-containment test. ``obj_index`` comes from the association
        # table (loaded below), so it is not requested from the source tables.
        # ``discard`` is used because the tools may not request ``obj_index``
        # at all.
        names = self.collectInputNames() | {"sourceId", "coord_ra", "coord_dec"}
        names.discard("obj_index")
        inputs["sourceCatalogs"] = [self.loadData(handle, names) for handle in inputs["sourceCatalogs"]]

        dataId = butlerQC.quantum.dataId
        # Called while ``associatedSources`` is still the deferred handle,
        # before it is replaced by the loaded table below.
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])

        data = self.callback(inputs, dataId)

        outputs = self.run(data=data, plotInfo=plotInfo, skymap=inputs["skyMap"], camera=inputs["camera"])
        butlerQC.put(outputs, outputRefs)