Coverage for python/lsst/meas/transiNet/rbTransiNetTask.py: 44%

46 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-15 03:15 -0700

1# This file is part of meas_transiNet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22__all__ = ["RBTransiNetTask", "RBTransiNetConfig"] 

23 

import lsst.afw.table
import lsst.geom
import lsst.pex.config
import lsst.pipe.base
import numpy as np

from . import rbTransiNetInterface

30 

31 

class RBTransiNetConnections(lsst.pipe.base.PipelineTaskConnections,
                             dimensions=("instrument", "visit", "detector"),
                             defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for `RBTransiNetTask`.

    All inputs and the output share the ("instrument", "visit", "detector")
    dimensions: the task reads the template, science and difference exposures
    together with the diaSources detected on the difference image, and writes
    one catalog of real/bogus scores.
    """

    # ---- Inputs ----

    # NOTE: Do we want the "ready to difference" template, or something
    # earlier? This one is warped, but not PSF-matched.
    template = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_templateExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Input warped template to subtract.",
    )
    science = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}calexp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Input science exposure to subtract from.",
    )
    difference = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Result of subtracting convolved template from science image.",
    )
    diaSources = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Detected sources on the difference image.",
    )

    # ---- Outputs ----

    classifications = lsst.pipe.base.connectionTypes.Output(
        name="{fakesType}{coaddName}RealBogusSources",
        storageClass="Catalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Catalog of real/bogus classifications for each diaSource, "
            "element-wise aligned with diaSources.",
    )

70 

71 

class RBTransiNetConfig(lsst.pipe.base.PipelineTaskConfig, pipelineConnections=RBTransiNetConnections):
    """Configuration for `RBTransiNetTask`."""

    # No default is given, so this must be set before the config validates.
    modelPackageName = lsst.pex.config.Field(
        dtype=str,
        doc="A unique identifier of a model package. ",
    )
    modelPackageStorageMode = lsst.pex.config.ChoiceField(
        dtype=str,
        doc="A string that indicates _where_ and _how_ the model package is stored.",
        allowed={'local': 'packages stored in the meas_transiNet repository',
                 'neighbor': 'packages stored in the rbClassifier_data repository',
                 },
        default='neighbor',
    )
    cutoutSize = lsst.pex.config.Field(
        dtype=int,
        doc="Width/height of square cutouts to send to classifier.",
        default=256,
        # A zero or negative size yields a degenerate cutout box; reject it at
        # config-validation time rather than producing unusable cutouts later.
        check=lambda x: x > 0,
    )

90 

91 

class RBTransiNetTask(lsst.pipe.base.PipelineTask):
    """Task for running TransiNet real/bogus classification on the output of
    the image subtraction pipeline.
    """
    _DefaultName = "rbTransiNet"
    ConfigClass = RBTransiNetConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # The classifier interface is built once per task instance and reused
        # by every call to `run`.
        self.interface = rbTransiNetInterface.RBTransiNetInterface(self.config.modelPackageName,
                                                                   self.config.modelPackageStorageMode)

    def run(self, template, science, difference, diaSources):
        """Compute a real/bogus score for every diaSource.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
            Warped template exposure.
        science : `lsst.afw.image.ExposureF`
            Science exposure.
        difference : `lsst.afw.image.ExposureF`
            Difference image.
        diaSources : `lsst.afw.table.SourceCatalog`
            Sources detected on the difference image.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with component:

            ``classifications``
                Catalog with ``id`` and ``score`` columns, one row per entry
                of ``diaSources`` and in the same order
                (`lsst.afw.table.BaseCatalog`).

        Raises
        ------
        RuntimeError
            Raised if the classifier does not return exactly one score per
            diaSource.
        """
        schema = lsst.afw.table.Schema()
        schema.addField(diaSources.schema["id"].asField())
        schema.addField("score", doc="real/bogus score of this source", type=float)
        classifications = lsst.afw.table.BaseCatalog(schema)

        # Nothing to classify: skip cutout extraction and inference and return
        # an empty catalog that still carries the output schema.
        if len(diaSources) == 0:
            self.log.info("No diaSources to classify.")
            return lsst.pipe.base.Struct(classifications=classifications)

        cutouts = [self._make_cutouts(template, science, difference, source) for source in diaSources]
        self.log.info("Extracted %d cutouts.", len(cutouts))
        scores = self.interface.infer(cutouts)
        self.log.info("Scored %d cutouts.", len(scores))

        # The output is row-aligned with diaSources, so the catalog is sized
        # from the input and the classifier must agree on the count.
        if len(scores) != len(diaSources):
            raise RuntimeError(f"Classifier returned {len(scores)} scores for {len(diaSources)} diaSources.")

        classifications.resize(len(diaSources))
        classifications["id"] = diaSources["id"]
        classifications["score"] = scores

        return lsst.pipe.base.Struct(classifications=classifications)

    def _make_cutouts(self, template, science, difference, source):
        """Return cutouts of each image centered at the source location.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
        science : `lsst.afw.image.ExposureF`
        difference : `lsst.afw.image.ExposureF`
            Exposures to cut images out of.
        source : `lsst.afw.table.SourceRecord`
            Source to make cutouts of.

        Returns
        -------
        cutouts : `lsst.meas.transiNet.CutoutInputs`
            Cutouts of each of the input images. If the cutout box does not
            lie entirely within all three exposures, each cutout is instead a
            zero-filled float32 array of shape ``(cutoutSize, cutoutSize)``.
        """
        extent = lsst.geom.Extent2I(self.config.cutoutSize)
        box = lsst.geom.Box2I.makeCenteredBox(source.getCentroid(), extent)

        # A box that runs off the edge of any of the three images (most
        # probably a source near the border) cannot be cut out. Substitute
        # blank cutouts so every diaSource still yields one classifier input.
        if all(exposure.getBBox().contains(box) for exposure in (science, template, difference)):
            science_cutout = science.Factory(science, box).image.array
            template_cutout = template.Factory(template, box).image.array
            difference_cutout = difference.Factory(difference, box).image.array
        else:
            science_cutout = np.zeros((self.config.cutoutSize, self.config.cutoutSize), dtype=np.float32)
            template_cutout = np.zeros_like(science_cutout)
            difference_cutout = np.zeros_like(science_cutout)

        return rbTransiNetInterface.CutoutInputs(science=science_cutout,
                                                 template=template_cutout,
                                                 difference=difference_cutout)