Coverage for python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py: 38%

56 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-25 04:56 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("AssociatedSourcesTractAnalysisConfig", "AssociatedSourcesTractAnalysisTask") 

24 

25import numpy as np 

26import pandas as pd 

27from lsst.geom import Box2D 

28from lsst.pipe.base import connectionTypes as ct 

29 

30from .base import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask 

31 

32 

class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    defaultTemplates={
        "outputName": "isolated_star_sources",
        "associatedSourcesInputName": "isolated_star_sources",
    },
    dimensions=("skymap", "tract"),
):
    """Butler connections for the associated-sources tract analysis.

    Quanta are defined per ``("skymap", "tract")``. Three inputs are
    declared: the visit-level source tables, the table of associated
    sources (whose dataset type name is taken from the
    ``associatedSourcesInputName`` template), and the sky map.
    """

    # One deferred-load handle per visit-level source table; ``multiple=True``
    # means a list of handles is delivered for each quantum.
    sourceCatalogs = ct.Input(
        name="sourceTable_visit",
        doc="Visit based source table to load from the butler",
        storageClass="DataFrame",
        dimensions=("visit", "band"),
        multiple=True,
        deferLoad=True,
    )

    # Association table; its dataset type name is filled in from the
    # ``associatedSourcesInputName`` template.
    associatedSources = ct.Input(
        name="{associatedSourcesInputName}",
        doc="Table of associated sources",
        storageClass="DataFrame",
        dimensions=("instrument", "skymap", "tract"),
        deferLoad=True,
    )

    # Sky map; loaded eagerly (no ``deferLoad``).
    skyMap = ct.Input(
        name="skyMap",
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

64 

65 

class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig,
    pipelineConnections=AssociatedSourcesTractAnalysisConnections,
):
    """Configuration for the associated-sources tract analysis.

    Adds no fields of its own; it binds
    `AssociatedSourcesTractAnalysisConnections` to the base analysis
    configuration.
    """

    def setDefaults(self):
        # No overrides specific to this config: defer entirely to the parent.
        super().setDefaults()

71 

72 

class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Run the configured analyses on associated sources within one tract.

    The visit-level source tables are concatenated, joined to the table of
    associated sources on ``sourceId``, and restricted to association groups
    (rows sharing an ``obj_index``) whose members all fall inside the tract
    bounding box. The resulting table is handed to ``run``.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysisTask"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries.

        Parameters
        ----------
        skymap
            Sky map object; must provide ``generateTract``.
        tract
            Tract identifier forwarded to ``skymap.generateTract``.

        Returns
        -------
        tractBox
            Bounding box returned by ``tractInfo.getBBox()``.
        wcs
            WCS returned by ``tractInfo.getWcs()``.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    @classmethod
    def callback(cls, inputs, dataId):
        """Callback function to be used with reconstructor.

        Parameters
        ----------
        inputs
            Mapping with the keys ``"skyMap"``, ``"sourceCatalogs"`` and
            ``"associatedSources"``; the two tables must already be loaded
            (see `prepareAssociatedSources`).
        dataId
            Mapping providing the ``"tract"`` key.

        Returns
        -------
        dataFiltered
            The table produced by `prepareAssociatedSources`.
        """
        return cls.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
        )

    @classmethod
    def prepareAssociatedSources(cls, skymap, tract, sourceCatalogs, associatedSources):
        """Concatenate source catalogs and join on associated object index.

        Parameters
        ----------
        skymap
            Sky map object used to look up the tract geometry.
        tract
            Tract identifier.
        sourceCatalogs
            Iterable of pandas tables, each with at least the columns
            ``sourceId``, ``coord_ra`` and ``coord_dec``.
        associatedSources
            pandas table with at least the columns ``sourceId`` and
            ``obj_index``.

        Returns
        -------
        dataFiltered
            Joined table indexed by ``sourceId``, containing only rows whose
            ``obj_index`` group lies entirely inside the tract bounding box.
        """
        # Keep only sources with associations (inner join drops the rest).
        dataJoined = pd.concat(sourceCatalogs).merge(associatedSources, on="sourceId", how="inner")
        dataJoined.set_index("sourceId", inplace=True)

        # Determine which sources are contained in tract.
        # NOTE(review): coordinates are presumably stored in degrees and
        # skyToPixelArray presumably expects radians -- confirm.
        ra = np.radians(dataJoined["coord_ra"].values)
        dec = np.radians(dataJoined["coord_dec"].values)
        box, wcs = cls.getBoxWcs(skymap, tract)
        box = Box2D(box)
        x, y = wcs.skyToPixelArray(ra, dec)
        dataJoined["boxSelection"] = box.contains(x, y)

        # Keep only the sources in groups that are fully contained within the
        # tract. ``Series.all()`` reduces each group in vectorised code rather
        # than iterating over its elements in Python.
        dataFiltered = dataJoined.groupby("obj_index").filter(lambda group: group["boxSelection"].all())

        # The helper column is internal to this method; return a frame
        # without it instead of mutating the filter result in place.
        return dataFiltered.drop(columns="boxSelection")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs for one quantum, run the analysis, store outputs.

        Parameters
        ----------
        butlerQC
            Quantum context used to ``get`` the inputs and ``put`` the
            outputs; its ``quantum.dataId`` supplies the tract.
        inputRefs
            References to the input datasets for this quantum.
        outputRefs
            References to the output datasets for this quantum.
        """
        inputs = butlerQC.get(inputRefs)

        # Load specified columns from source catalogs.
        names = self.collectInputNames()
        names |= {"sourceId", "coord_ra", "coord_dec"}
        # ``obj_index`` is read from the association table below, so it is
        # excluded from the source-table column list. ``discard`` (rather than
        # ``remove``) avoids a KeyError when no configured action asked for
        # that column.
        names.discard("obj_index")
        inputs["sourceCatalogs"] = [self.loadData(handle, names) for handle in inputs["sourceCatalogs"]]

        dataId = butlerQC.quantum.dataId
        # Called while ``inputs["associatedSources"]` is still the
        # deferred-load handle, i.e. before it is replaced by the loaded table.
        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])

        data = self.callback(inputs, dataId)

        outputs = self.run(data=data, plotInfo=plotInfo, skymap=inputs["skyMap"])
        butlerQC.put(outputs, outputRefs)