Coverage for python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py: 37%

55 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-21 02:55 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23import numpy as np 

24import pandas as pd 

25from lsst.geom import Box2D 

26from lsst.pipe.base import connectionTypes as ct 

27 

28from .base import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask 

29 

30 

class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract"),
    defaultTemplates={
        "outputName": "isolated_star_sources",
        "associatedSourcesInputName": "isolated_star_sources",
    },
):
    """Butler connections for analysing associated sources over one tract.

    Quanta are indexed by (skymap, tract). The inputs are the per-visit source
    tables, the tract-level table of associated sources (its dataset type name
    is set by the ``associatedSourcesInputName`` template), and the sky map.
    """

    # Per-visit source tables; one handle per (visit, band), loaded on demand.
    sourceCatalogs = ct.Input(
        name="sourceTable_visit",
        storageClass="DataFrame",
        dimensions=("visit", "band"),
        multiple=True,
        deferLoad=True,
        doc="Visit based source table to load from the butler",
    )

    # Association table for the tract; the handle is loaded on demand.
    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        storageClass="DataFrame",
        dimensions=("instrument", "skymap", "tract"),
        deferLoad=True,
        doc="Table of associated sources",
    )

    # Sky map from which tract geometry (bounding box and WCS) is derived.
    skyMap = ct.Input(
        name="skyMap",
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
    )

62 

63 

class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=AssociatedSourcesTractAnalysisConnections
):
    """Configuration for the associated-sources tract analysis task.

    Adds no fields of its own. ``setDefaults`` is inherited unchanged from
    `AnalysisBaseConfig`; the previous override only delegated to
    ``super().setDefaults()`` and has been removed as redundant.
    """

69 

70 

class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Run analysis on per-visit sources that are associated across a tract.

    Per-visit source tables are joined with the tract's association table
    (``obj_index`` per ``sourceId``). The join is restricted to association
    groups whose members all project inside the tract bounding box, and the
    result is passed to ``run`` together with the sky map.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysisTask"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get the bounding box and WCS that define a tract's boundaries.

        Parameters
        ----------
        skymap
            Sky map providing ``generateTract``.
        tract : `int`
            Tract identifier passed to ``skymap.generateTract``.

        Returns
        -------
        tractBox
            Pixel bounding box of the tract (``tractInfo.getBBox()``).
        wcs
            WCS of the tract (``tractInfo.getWcs()``).
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    @classmethod
    def callback(cls, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs : `dict`
            Must contain ``"skyMap"``, ``"sourceCatalogs"`` (already-loaded
            tables) and ``"associatedSources"`` (an already-loaded table).
        dataId
            Mapping that provides the ``"tract"`` key.

        Returns
        -------
        data : `pandas.DataFrame`
            Output of `prepareAssociatedSources`.
        """
        return cls.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
        )

    @classmethod
    def prepareAssociatedSources(cls, skymap, tract, sourceCatalogs, associatedSources):
        """Concatenate source catalogs and join on associated object index.

        Only sources listed in ``associatedSources`` are kept. Of those, only
        association groups (``obj_index``) whose members all fall inside the
        tract bounding box are retained.

        Parameters
        ----------
        skymap
            Sky map used to look up the tract geometry.
        tract : `int`
            Tract identifier.
        sourceCatalogs : `list` [`pandas.DataFrame`]
            Per-visit source tables with ``sourceId``, ``coord_ra`` and
            ``coord_dec`` columns (degrees; converted to radians below).
        associatedSources : `pandas.DataFrame`
            Table with ``sourceId`` and ``obj_index`` columns.

        Returns
        -------
        data : `pandas.DataFrame`
            Joined table indexed by ``sourceId``, in the row order produced by
            the join.
        """
        # Keep only sources with associations.
        dataJoined = pd.concat(sourceCatalogs).merge(associatedSources, on="sourceId", how="inner")
        dataJoined.set_index("sourceId", inplace=True)

        # Determine which sources are contained in the tract.
        ra = np.radians(dataJoined["coord_ra"].values)
        dec = np.radians(dataJoined["coord_dec"].values)
        box, wcs = cls.getBoxWcs(skymap, tract)
        box = Box2D(box)
        x, y = wcs.skyToPixelArray(ra, dec)
        boxSelection = box.contains(x, y)

        # Keep only the sources in groups that are fully contained within the
        # tract. A vectorized per-group "all" replaces groupby().filter(),
        # which builds a sub-frame and makes a Python call for every group.
        # Grouping by the raw ndarray keeps the mask positional (row order is
        # preserved by transform). obj_index is assumed non-null, since it
        # comes from the association table.
        inTract = pd.Series(boxSelection, index=dataJoined.index)
        groupInTract = inTract.groupby(dataJoined["obj_index"].to_numpy()).transform("all")
        return dataJoined[groupInTract.to_numpy(dtype=bool)]

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load inputs, build the joined source table, and run the analysis."""
        inputs = butlerQC.get(inputRefs)

        # Load from the source catalogs only the columns the configured
        # actions need, plus those used for the join and the tract test.
        # ``obj_index`` comes from the association table, not the source
        # tables. ``discard`` (rather than ``remove``) avoids a KeyError when
        # no action requests it.
        names = self.collectInputNames()
        names |= {"sourceId", "coord_ra", "coord_dec"}
        names.discard("obj_index")
        inputs["sourceCatalogs"] = [self.loadData(handle, names) for handle in inputs["sourceCatalogs"]]

        dataId = butlerQC.quantum.dataId
        # Called while inputs["associatedSources"] is still the deferred-load
        # handle; it is replaced by the loaded table just below.
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])

        data = self.callback(inputs, dataId)

        outputs = self.run(data=data, plotInfo=plotInfo, skymap=inputs["skyMap"])
        butlerQC.put(outputs, outputRefs)