Coverage for python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py: 46%

59 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-25 11:38 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("AssociatedSourcesTractAnalysisConfig", "AssociatedSourcesTractAnalysisTask") 

24 

25import numpy as np 

26import pandas as pd 

27from lsst.cp.pipe._lookupStaticCalibration import lookupStaticCalibration 

28from lsst.geom import Box2D 

29from lsst.pipe.base import connectionTypes as ct 

30from lsst.skymap import BaseSkyMap 

31 

32from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask 

33 

34 

class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract", "instrument"),
    defaultTemplates=dict(
        outputName="isolated_star_sources",
        associatedSourcesInputName="isolated_star_sources",
    ),
):
    """Butler connections for `AssociatedSourcesTractAnalysisTask`.

    Inputs are the per-visit source tables (loaded lazily), the tract-level
    table of associated sources (also loaded lazily), the sky map, and the
    camera (a prerequisite calibration found via `lookupStaticCalibration`).
    """

    # Per-visit source tables; deferred so only the needed columns are read.
    sourceCatalogs = ct.Input(
        name="sourceTable_visit",
        doc="Visit based source table to load from the butler",
        dimensions=("visit", "band"),
        storageClass="DataFrame",
        multiple=True,
        deferLoad=True,
    )

    # Tract-level association table; name set by a connection template.
    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        doc="Table of associated sources",
        dimensions=("instrument", "skymap", "tract"),
        storageClass="DataFrame",
        deferLoad=True,
    )

    skyMap = ct.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    camera = ct.PrerequisiteInput(
        name="camera",
        doc="Input camera to use for focal plane geometry.",
        dimensions=("instrument",),
        storageClass="Camera",
        lookupFunction=lookupStaticCalibration,
        isCalibration=True,
    )

75 

class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig,
    pipelineConnections=AssociatedSourcesTractAnalysisConnections,
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`."""

    def setDefaults(self):
        """Apply default configuration values.

        No task-specific defaults are set here; everything is delegated to
        the base class.
        """
        super().setDefaults()

81 

82 

class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Analyze visit-level sources grouped by a tract-level association.

    The per-visit source tables are inner-joined with the associated-sources
    table on ``sourceId``, restricted to association groups (``obj_index``)
    whose members all lie inside the tract's bounding box, and passed to
    ``run``.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysis"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map containing the tract.
        tract : `int`
            Tract identifier.

        Returns
        -------
        tractBox : `lsst.geom.Box2I`
            Pixel bounding box of the tract.
        wcs : `lsst.afw.geom.SkyWcs`
            WCS of the tract.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    @classmethod
    def callback(cls, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs : `dict`
            Must contain ``"skyMap"``, ``"sourceCatalogs"`` (already loaded
            DataFrames) and ``"associatedSources"`` (already loaded DataFrame).
        dataId : mapping
            Must contain ``"tract"``.

        Returns
        -------
        data : `pandas.DataFrame`
            Result of `prepareAssociatedSources`.
        """
        return cls.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
        )

    @classmethod
    def prepareAssociatedSources(cls, skymap, tract, sourceCatalogs, associatedSources):
        """Concatenate source catalogs and join on associated object index.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map containing the tract.
        tract : `int`
            Tract identifier.
        sourceCatalogs : `list` [`pandas.DataFrame`]
            Per-visit source tables; must contain ``sourceId``, ``coord_ra``
            and ``coord_dec`` (the coordinates in degrees).
        associatedSources : `pandas.DataFrame`
            Association table with ``sourceId`` and ``obj_index`` columns.

        Returns
        -------
        dataFiltered : `pandas.DataFrame`
            Joined table indexed by ``sourceId`` that keeps only sources whose
            entire ``obj_index`` group falls inside the tract bounding box,
            in the original row order.
        """
        # Keep only sources with associations.
        dataJoined = pd.concat(sourceCatalogs).merge(associatedSources, on="sourceId", how="inner")
        dataJoined.set_index("sourceId", inplace=True)

        # Determine which sources are contained in the tract.  Catalog
        # coordinates are converted from degrees to radians for
        # skyToPixelArray.
        ra = np.radians(dataJoined["coord_ra"].values)
        dec = np.radians(dataJoined["coord_dec"].values)
        box, wcs = cls.getBoxWcs(skymap, tract)
        box = Box2D(box)
        x, y = wcs.skyToPixelArray(ra, dec)
        inTract = np.asarray(box.contains(x, y), dtype=bool)

        # Keep only the sources in groups that are fully contained within the
        # tract, i.e. drop every group with at least one member outside it.
        # This is a vectorized replacement for
        # groupby("obj_index").filter(lambda g: all(...)), which calls a
        # Python function once per group and scales poorly with many groups.
        objIndex = dataJoined["obj_index"]
        partialGroups = objIndex[~inTract].unique()
        # groupby drops NaN keys by default, so filter() never kept such rows;
        # notna() preserves that behavior.
        keep = objIndex.notna().to_numpy() & ~objIndex.isin(partialGroups).to_numpy()
        return dataJoined[keep]

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

        # Load from the visit source tables only the columns the configured
        # actions need, plus those required for the join and tract selection.
        names = self.collectInputNames()
        names |= {"sourceId", "coord_ra", "coord_dec"}
        # obj_index comes from the associated-sources table, not the visit
        # tables.  discard() (not remove()) so configurations whose actions do
        # not request obj_index do not fail with KeyError.
        names.discard("obj_index")
        inputs["sourceCatalogs"] = [self.loadData(handle, names) for handle in inputs["sourceCatalogs"]]

        dataId = butlerQC.quantum.dataId
        # NOTE(review): called while inputs["associatedSources"] is still the
        # deferred handle (it is replaced just below); presumably
        # parsePlotInfo reads that handle -- keep this ordering.
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])

        data = self.callback(inputs, dataId)

        outputs = self.run(
            data=data,
            plotInfo=plotInfo,
            skymap=inputs["skyMap"],
            camera=inputs["camera"],
        )
        butlerQC.put(outputs, outputRefs)