Coverage for python/lsst/analysis/tools/tasks/associatedSourcesTractAnalysis.py: 43%

41 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-04 03:18 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from lsst.pipe.base import connectionTypes as ct 

24 

25from .base import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask 

26 

27# These need to be updated for this analysis context 

28# from ..analysisPlots.analysisPlots import ShapeSizeFractionalDiffScatter 

29# from ..analysisPlots.analysisPlots import Ap12PsfSkyPlot 

30# from ..analysisMetrics.analysisMetrics import ShapeSizeFractionalMetric 

31 

32 

class AssociatedSourcesTractAnalysisConnections(
    AnalysisBaseConnections,
    dimensions=("skymap", "tract"),
    defaultTemplates={
        # Name of the association table dataset; see ``associatedSources``.
        "inputName": "isolated_star_sources",
    },
):
    """Butler connections for the per-tract associated-sources analysis.

    Inputs are all ``sourceTable_visit`` datasets attached to the quantum
    (``multiple=True``), the association table for the tract, and the
    skymap describing the tract geometry.
    """

    # TODO(review): these inputs are loaded eagerly; ``deferLoad=True`` was
    # considered (for this and ``associatedSources``) but is not enabled.
    data = ct.Input(
        doc="Visit based source table to load from the butler",
        name="sourceTable_visit",
        storageClass="DataFrame",
        dimensions=("visit", "band"),
        multiple=True,
    )

    # The dataset name is taken from the ``inputName`` template; a dedicated
    # template for this input may be introduced later.
    associatedSources = ct.Input(
        doc="Table of associated sources",
        name="{inputName}",
        storageClass="DataFrame",
        dimensions=("instrument", "skymap", "tract"),
    )

    skyMap = ct.Input(
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        name="skyMap",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

65 

66 

class AssociatedSourcesTractAnalysisConfig(
    AnalysisBaseConfig, pipelineConnections=AssociatedSourcesTractAnalysisConnections
):
    """Configuration for `AssociatedSourcesTractAnalysisTask`."""

    def setDefaults(self):
        """Apply the base-class defaults.

        No plots or metrics are enabled by default for this task yet.
        """
        super().setDefaults()
        # TODO: enable plots and metrics once they have been adapted to the
        # associated-sources context (their imports are currently commented
        # out at the top of this module), e.g.:
        #   self.plots.shapeSizeFractionalDiffScatter = ShapeSizeFractionalDiffScatter()
        #   self.plots.Ap12PsfSkyPlot = Ap12PsfSkyPlot()
        #   self.metrics.shapeSizeFractionalMetric = ShapeSizeFractionalMetric()

81 

82 

class AssociatedSourcesTractAnalysisTask(AnalysisPipelineTask):
    """Run the configured analysis on associated sources within one tract.

    Per-visit source tables are joined to the association table, and only
    sources whose whole association group (``obj_index``) lies inside the
    tract bounding box are passed to `run`.
    """

    ConfigClass = AssociatedSourcesTractAnalysisConfig
    _DefaultName = "associatedSourcesTractAnalysisTask"

    def getBoxWcs(self, skymap, tract):
        """Return the pixel bounding box and WCS of a tract.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map containing the tract.
        tract : `int`
            Tract identifier.

        Returns
        -------
        tractBox : `lsst.geom.Box2I`
            Pixel bounding box of the tract.
        wcs : `lsst.afw.geom.SkyWcs`
            WCS of the tract.
        """
        tractInfo = skymap.generateTract(tract)
        wcs = tractInfo.getWcs()
        tractBox = tractInfo.getBBox()
        self.log.info("Running tract: %s", tract)
        return tractBox, wcs

    def prepareAssociatedSources(self, skymap, tract, sourceCatalogs, associatedSources):
        """Join visit sources to their associations and trim to the tract.

        Parameters
        ----------
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map used to look up the tract geometry.
        tract : `int`
            Tract identifier.
        sourceCatalogs : iterable of `pandas.DataFrame`
            Per-visit source tables. They are joined on ``sourceId`` and must
            provide ``coord_ra`` and ``coord_dec`` in degrees.
        associatedSources : `pandas.DataFrame`
            Association table providing ``sourceId`` and ``obj_index``.

        Returns
        -------
        dataFiltered : `pandas.DataFrame`
            Joined table indexed by ``sourceId``. It contains only sources
            whose entire ``obj_index`` group falls inside the tract bounding
            box.

        Notes
        -----
        Apart from logging, this uses no task state and could be a standalone
        function.
        """
        import lsst.geom as geom
        import numpy as np
        import pandas as pd

        # Inner join: visit sources without an association are discarded.
        dataJoined = pd.concat(sourceCatalogs).merge(associatedSources, on="sourceId", how="inner")
        dataJoined.set_index("sourceId", inplace=True)

        # Flag each source whose position falls inside the tract pixel box.
        # Coordinates are converted from degrees to radians before being
        # passed to skyToPixelArray.
        ra = np.radians(dataJoined["coord_ra"].values)
        dec = np.radians(dataJoined["coord_dec"].values)
        box, wcs = self.getBoxWcs(skymap, tract)
        box = geom.Box2D(box)
        xPix, yPix = wcs.skyToPixelArray(ra, dec)
        boxSelection = box.contains(xPix, yPix)

        # Keep a group only if *every* member lies inside the tract. A
        # vectorised groupby-transform gives a per-row mask; it avoids calling
        # a Python lambda once per group and preserves the original row order.
        # NOTE(review): assumes ``obj_index`` is never null after the join.
        dataJoined["boxSelection"] = boxSelection
        groupInTract = dataJoined.groupby("obj_index")["boxSelection"].transform("all")
        dataFiltered = dataJoined.loc[groupInTract.to_numpy(dtype=bool)].drop(columns="boxSelection")

        return dataFiltered

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs, trim them to the tract, run and store the outputs.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Butler interface for this quantum.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            References to the input datasets.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            References to the output datasets.
        """
        inputs = butlerQC.get(inputRefs)

        # The tract being processed comes from the association dataset's data ID.
        tract = inputRefs.associatedSources.dataId.byName()["tract"]
        dataFiltered = self.prepareAssociatedSources(
            inputs["skyMap"],
            tract,
            inputs["data"],
            inputs["associatedSources"],
        )

        outputs = self.run(data=dataFiltered)
        butlerQC.put(outputs, outputRefs)