Coverage for python/lsst/meas/transiNet/rbTransiNetTask.py: 48%

48 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-13 10:24 +0000

1# This file is part of meas_transiNet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

# Names exported by ``from <module> import *``.
__all__ = ["RBTransiNetTask", "RBTransiNetConfig"]

23 

import lsst.afw.table
import lsst.geom
import lsst.pex.config
import lsst.pipe.base
from lsst.utils.timer import timeMethod
import numpy as np

from . import rbTransiNetInterface

31 

32 

class RBTransiNetConnections(lsst.pipe.base.PipelineTaskConnections,
                             dimensions=("instrument", "visit", "detector"),
                             defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for `RBTransiNetTask`.

    The inputs are the per-detector products of image subtraction. The
    single output is a catalog of classifier scores.
    """

    # --- Inputs ---------------------------------------------------------

    # NOTE: Do we want the "ready to difference" template, or something
    # earlier? This one is warped, but not PSF-matched.
    template = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_templateExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Input warped template to subtract.",
    )
    science = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}calexp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Input science exposure to subtract from.",
    )
    difference = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Result of subtracting convolved template from science image.",
    )
    diaSources = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Detected sources on the difference image.",
    )

    # --- Outputs --------------------------------------------------------

    classifications = lsst.pipe.base.connectionTypes.Output(
        name="{fakesType}{coaddName}RealBogusSources",
        storageClass="Catalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Catalog of real/bogus classifications for each diaSource, "
            "element-wise aligned with diaSources.",
    )

71 

72 

class RBTransiNetConfig(lsst.pipe.base.PipelineTaskConfig, pipelineConnections=RBTransiNetConnections):
    """Configuration for `RBTransiNetTask`."""

    # Which model package to load; no default is provided here.
    modelPackageName = lsst.pex.config.Field(
        doc="A unique identifier of a model package. ",
        dtype=str,
    )

    # Where the model package identified above is looked up.
    modelPackageStorageMode = lsst.pex.config.ChoiceField(
        doc="A string that indicates _where_ and _how_ the model package is stored.",
        dtype=str,
        default="neighbor",
        allowed={
            "local": "packages stored in the meas_transiNet repository",
            "neighbor": "packages stored in the rbClassifier_data repository",
        },
    )

    # Side length, in pixels, of the square cutouts fed to the classifier.
    cutoutSize = lsst.pex.config.Field(
        doc="Width/height of square cutouts to send to classifier.",
        dtype=int,
        default=256,
    )

91 

92 

class RBTransiNetTask(lsst.pipe.base.PipelineTask):
    """Task for running TransiNet real/bogus classification on the output of
    the image subtraction pipeline.
    """
    _DefaultName = "rbTransiNet"
    ConfigClass = RBTransiNetConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Build the model interface once per task instance; it is reused by
        # every call to ``run``.
        self.interface = rbTransiNetInterface.RBTransiNetInterface(self.config.modelPackageName,
                                                                   self.config.modelPackageStorageMode)

    @timeMethod
    def run(self, template, science, difference, diaSources):
        """Score every diaSource with the real/bogus classifier.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
            Warped template exposure.
        science : `lsst.afw.image.ExposureF`
            Science exposure.
        difference : `lsst.afw.image.ExposureF`
            Difference image.
        diaSources : `lsst.afw.table.SourceCatalog`
            Sources detected on the difference image.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct with a single field:

            ``classifications``
                `lsst.afw.table.BaseCatalog` with one record per diaSource,
                element-wise aligned with ``diaSources``, holding the source
                ``id`` and its real/bogus ``score``.

        Raises
        ------
        RuntimeError
            Raised if the classifier returns a different number of scores
            than there are input sources.
        """
        cutouts = [self._make_cutouts(template, science, difference, source) for source in diaSources]
        self.log.info("Extracted %d cutouts.", len(cutouts))

        if cutouts:
            scores = self.interface.infer(cutouts)
        else:
            # Nothing to classify: do not hand an empty batch to the model.
            scores = np.zeros(0, dtype=np.float32)
        self.log.info("Scored %d cutouts.", len(scores))

        # The output catalog must stay aligned with diaSources record for
        # record; fail loudly rather than let numpy broadcast a mismatch.
        if len(scores) != len(diaSources):
            raise RuntimeError(f"Classifier returned {len(scores)} scores for "
                               f"{len(diaSources)} diaSources.")

        schema = lsst.afw.table.Schema()
        schema.addField(diaSources.schema["id"].asField())
        schema.addField("score", doc="real/bogus score of this source", type=float)
        classifications = lsst.afw.table.BaseCatalog(schema)
        classifications.resize(len(diaSources))

        classifications["id"] = diaSources["id"]
        classifications["score"] = scores

        return lsst.pipe.base.Struct(classifications=classifications)

    def _make_cutouts(self, template, science, difference, source):
        """Return cutouts of each image centered at the source location.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
        science : `lsst.afw.image.ExposureF`
        difference : `lsst.afw.image.ExposureF`
            Exposures to cut images out of.
        source : `lsst.afw.table.SourceRecord`
            Source to make cutouts of.

        Returns
        -------
        cutouts : `lsst.meas.transiNet.CutoutInputs`
            Cutouts of each of the input images. If the cutout box does not
            lie entirely inside all three exposures' bounding boxes, every
            cutout is instead an all-zero ``float32`` array of shape
            ``(cutoutSize, cutoutSize)``.
        """
        size = self.config.cutoutSize
        extent = lsst.geom.Extent2I(size)
        box = lsst.geom.Box2I.makeCenteredBox(source.getCentroid(), extent)

        # Cut out real pixels only when the box fits inside every exposure
        # (not just the science one). Otherwise, e.g. for sources near an
        # edge, fall back to empty cutouts.
        if all(exposure.getBBox().contains(box) for exposure in (science, template, difference)):
            science_cutout = science.Factory(science, box).image.array
            template_cutout = template.Factory(template, box).image.array
            difference_cutout = difference.Factory(difference, box).image.array
        else:
            science_cutout = np.zeros((size, size), dtype=np.float32)
            template_cutout = np.zeros_like(science_cutout)
            difference_cutout = np.zeros_like(science_cutout)

        return rbTransiNetInterface.CutoutInputs(science=science_cutout,
                                                 template=template_cutout,
                                                 difference=difference_cutout)