Coverage for python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py: 46%

58 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-29 11:31 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("AssociatedSourcesTractAnalysisConfig", "AssociatedSourcesTractAnalysisTask") 

24 

25import numpy as np 

26import pandas as pd 

27from lsst.geom import Box2D 

28from lsst.pipe.base import connectionTypes as ct 

29from lsst.skymap import BaseSkyMap 

30 

31from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask 

32 

33 

class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract", "instrument"),
    defaultTemplates={
        "outputName": "isolated_star_sources",
        "associatedSourcesInputName": "isolated_star_sources",
    },
):
    """Butler connections for `AssociatedSourcesTractAnalysisTask`.

    A quantum spans one (skymap, tract, instrument) combination. It reads
    every matching per-visit source table, the tract-level association
    table, the skymap, and the instrument's camera.
    """

    # One handle per (visit, band); loading is deferred so the task can
    # choose which columns to read from each table.
    sourceCatalogs = ct.Input(
        name="sourceTable_visit",
        storageClass="DataFrame",
        dimensions=("visit", "band"),
        multiple=True,
        deferLoad=True,
        doc="Visit based source table to load from the butler",
    )

    # Dataset name is templated so other association products can be used.
    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        storageClass="DataFrame",
        dimensions=("instrument", "skymap", "tract"),
        deferLoad=True,
        doc="Table of associated sources",
    )

    skyMap = ct.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
    )

    # The camera is a calibration dataset, hence a prerequisite input.
    camera = ct.PrerequisiteInput(
        name="camera",
        storageClass="Camera",
        dimensions=("instrument",),
        isCalibration=True,
        doc="Input camera to use for focal plane geometry.",
    )

72 

73 

class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=AssociatedSourcesTractAnalysisConnections
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`.

    No task-specific fields or default overrides are defined here; every
    setting is inherited from `AnalysisBaseConfig`. (A previous
    ``setDefaults`` override only called ``super().setDefaults()`` and was
    therefore a no-op.)
    """

80 

class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Run analysis tools on associated sources restricted to one tract.

    Per-visit source tables are joined to an association table on
    ``sourceId``. Only associated objects whose every source lies inside the
    tract's bounding box are kept, and the result is passed to ``run``.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysis"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map containing the tract.
        tract : `int`
            Tract identifier passed to ``skymap.generateTract``.

        Returns
        -------
        tractBox
            Bounding box of the tract, as returned by ``tractInfo.getBBox()``.
        wcs
            WCS of the tract, as returned by ``tractInfo.getWcs()``.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    @classmethod
    def callback(cls, inputs, dataId):
        """Callback function to be used with reconstructor.

        Unpacks ``inputs`` (keys ``skyMap``, ``sourceCatalogs``,
        ``associatedSources``) and ``dataId["tract"]`` and forwards them to
        `prepareAssociatedSources`.
        """
        return cls.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
        )

    @classmethod
    def prepareAssociatedSources(cls, skymap, tract, sourceCatalogs, associatedSources):
        """Concatenate source catalogs and join on associated object index.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map used to look up the tract geometry.
        tract : `int`
            Tract identifier.
        sourceCatalogs : `list` [`pandas.DataFrame`]
            Per-visit source tables. They must contain ``sourceId``,
            ``coord_ra`` and ``coord_dec`` (the latter two are converted from
            degrees to radians before projection).
        associatedSources : `pandas.DataFrame`
            Association table with at least ``sourceId`` and ``obj_index``.

        Returns
        -------
        dataFiltered : `pandas.DataFrame`
            Joined table indexed by ``sourceId``, containing only sources
            whose associated object (``obj_index``) has all of its sources
            inside the tract bounding box.
        """
        # Keep only sources with associations
        dataJoined = pd.concat(sourceCatalogs).merge(associatedSources, on="sourceId", how="inner")
        dataJoined.set_index("sourceId", inplace=True)

        # Determine which sources are contained in tract
        ra = np.radians(dataJoined["coord_ra"].values)
        dec = np.radians(dataJoined["coord_dec"].values)
        box, wcs = cls.getBoxWcs(skymap, tract)
        box = Box2D(box)
        x, y = wcs.skyToPixelArray(ra, dec)
        inTract = np.asarray(box.contains(x, y), dtype=bool)

        # Keep only the sources in groups that are fully contained within the
        # tract: an object qualifies iff none of its sources fall outside.
        # This vectorized test replaces ``groupby(...).filter`` with a Python
        # callable per group, which is very slow for many objects.
        objIndex = dataJoined["obj_index"].to_numpy()
        objectsOutside = np.unique(objIndex[~inTract])
        keep = ~np.isin(objIndex, objectsOutside)

        # ``.loc`` with an explicit boolean array keeps all columns, even when
        # the selection is empty.
        dataFiltered = dataJoined.loc[keep]

        return dataFiltered

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load inputs, build the tract-restricted table, and run the tools.

        Parameters
        ----------
        butlerQC
            Butler quantum context used to read inputs and write outputs.
        inputRefs
            References to the input datasets of this quantum.
        outputRefs
            References to the output datasets of this quantum.
        """
        inputs = butlerQC.get(inputRefs)

        # Load specified columns from source catalogs. ``obj_index`` lives in
        # the association table, not the source tables, so it must not be
        # requested from them. ``discard`` (not ``remove``) avoids a KeyError
        # when no configured action asks for ``obj_index``.
        names = self.collectInputNames()
        names |= {"sourceId", "coord_ra", "coord_dec"}
        names.discard("obj_index")
        inputs["sourceCatalogs"] = [self.loadData(handle, names) for handle in inputs["sourceCatalogs"]]

        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])

        data = self.callback(inputs, dataId)

        kwargs = {"data": data, "plotInfo": plotInfo, "skymap": inputs["skyMap"], "camera": inputs["camera"]}
        outputs = self.run(**kwargs)
        butlerQC.put(outputs, outputRefs)