Coverage for python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py: 41%

58 statements  

coverage.py v6.4.4, created at 2022-09-07 04:48 -0700

# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

import numpy as np
import pandas as pd
from lsst.daf.butler import DataCoordinate
from lsst.geom import Box2D
from lsst.pipe.base import connectionTypes as ct

from .base import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask


class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract"),
    defaultTemplates={
        "outputName": "isolated_star_sources",
        "associatedSourcesInputName": "isolated_star_sources",
    },
):
    sourceCatalogs = ct.Input(
        doc="Visit based source table to load from the butler",
        name="sourceTable_visit",
        storageClass="DataFrame",
        deferLoad=True,
        dimensions=("visit", "band"),
        multiple=True,
    )

    associatedSources = ct.Input(
        doc="Table of associated sources",
        name="{associatedSourcesInputName}",
        storageClass="DataFrame",
        deferLoad=True,
        dimensions=("instrument", "skymap", "tract"),
    )

    skyMap = ct.Input(
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        name="skyMap",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
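Note that the associatedSources connection uses a name template ("{associatedSourcesInputName}"), so the dataset this task reads, and the output name, can be redirected per pipeline through the task config rather than by subclassing the connections class. A minimal sketch of such an override, assuming the standard pipe_base template mechanism; the alternate dataset name below is hypothetical:

# In a config override applied to this task within a pipeline definition
config.connections.associatedSourcesInputName = "my_alternate_associations"  # hypothetical dataset name
config.connections.outputName = "my_alternate_associations"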

class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=AssociatedSourcesTractAnalysisConnections
):
    def setDefaults(self):
        super().setDefaults()


class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysisTask"

    @staticmethod
    def getBoxWcs(skymap, tract):
        """Get box that defines tract boundaries."""
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        return tractBox, wcs

    @classmethod
    def callback(cls, inputs, dataId):
        """Callback function to be used with reconstructor."""
        return cls.prepareAssociatedSources(
            inputs["skyMap"],
            dataId["tract"],
            inputs["sourceCatalogs"],
            inputs["associatedSources"],
        )

    @classmethod
    def prepareAssociatedSources(cls, skymap, tract, sourceCatalogs, associatedSources):
        """Concatenate source catalogs and join on associated object index."""

        # Keep only sources with associations
        dataJoined = pd.concat(sourceCatalogs).merge(associatedSources, on="sourceId", how="inner")
        dataJoined.set_index("sourceId", inplace=True)

        # Determine which sources are contained in tract
        ra = np.radians(dataJoined["coord_ra"].values)
        dec = np.radians(dataJoined["coord_dec"].values)
        box, wcs = cls.getBoxWcs(skymap, tract)
        box = Box2D(box)
        x, y = wcs.skyToPixelArray(ra, dec)
        boxSelection = box.contains(x, y)

        # Keep only the sources in groups that are fully contained within the
        # tract
        dataJoined["boxSelection"] = boxSelection
        dataFiltered = dataJoined.groupby("obj_index").filter(lambda x: all(x["boxSelection"]))
        dataFiltered.drop(columns="boxSelection", inplace=True)

        return dataFiltered
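The core of prepareAssociatedSources above is an inner merge on sourceId followed by a groupby/filter that discards any association group with at least one member failing the tract-box test. A minimal standalone sketch of that pandas pattern, using toy frames and made-up values in place of the real catalogs:

import pandas as pd

# Toy stand-ins for the concatenated visit catalogs and the association table.
sources = pd.DataFrame({"sourceId": [1, 2, 3, 4], "flux": [10.0, 11.0, 9.5, 12.0]})
associations = pd.DataFrame({"sourceId": [1, 2, 3], "obj_index": [100, 100, 101]})

# The inner merge keeps only sources that have an association (sourceId 4 is dropped).
joined = sources.merge(associations, on="sourceId", how="inner").set_index("sourceId")

# Pretend sourceId 2 falls outside the tract box.
joined["boxSelection"] = [True, False, True]

# filter() keeps a group only if every member passes the box test, so the whole
# obj_index 100 group (sourceIds 1 and 2) is removed along with the bad member.
filtered = joined.groupby("obj_index").filter(lambda g: g["boxSelection"].all())
print(filtered.index.tolist())  # [3]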

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

        # Load specified columns from source catalogs
        names = self.collectInputNames()
        names |= {"sourceId", "coord_ra", "coord_dec"}
        names.remove("obj_index")
        sourceCatalogs = []
        for handle in inputs["sourceCatalogs"]:
            sourceCatalogs.append(self.loadData(handle, names))
        inputs["sourceCatalogs"] = sourceCatalogs

        dataId = butlerQC.quantum.dataId
        if dataId is not None:
            dataId = DataCoordinate.standardize(dataId, universe=butlerQC.registry.dimensions)

        plotInfo = self.parsePlotInfo(inputs, dataId, connectionName="associatedSources")

        # TODO: make key used for object index configurable
        inputs["associatedSources"] = self.loadData(inputs["associatedSources"], ["obj_index", "sourceId"])

        data = self.callback(inputs, dataId)

        kwargs = {"data": data, "plotInfo": plotInfo, "skymap": inputs["skyMap"]}
        outputs = self.run(**kwargs)
        butlerQC.put(outputs, outputRefs)
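One detail of runQuantum worth spelling out: obj_index is among the columns requested by the configured tools but is removed before loading the per-visit sourceTable_visit catalogs, presumably because that column comes from the association table (loaded separately with ["obj_index", "sourceId"]) and only appears after the merge in prepareAssociatedSources. A short sketch of that set manipulation, with hypothetical tool columns:

# Hypothetical columns requested by the configured analysis tools.
names = {"psfFlux", "psfFluxErr", "obj_index"}

# Columns needed for the sourceId join and the tract-box containment test.
names |= {"sourceId", "coord_ra", "coord_dec"}

# obj_index is dropped from the per-visit load list; it is supplied later by the
# merge with the association table.
names.remove("obj_index")

print(sorted(names))  # ['coord_dec', 'coord_ra', 'psfFlux', 'psfFluxErr', 'sourceId']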