Coverage for python/lsst/meas/transiNet/rbTransiNetTask.py: 42%

60 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-25 11:34 +0000

1# This file is part of meas_transiNet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

# Public API of this module: the task, its config, and its butler connections
# (the connections class is referenced by the config via pipelineConnections).
__all__ = ["RBTransiNetConnections", "RBTransiNetTask", "RBTransiNetConfig"]

23 

import lsst.afw.table
import lsst.geom
import lsst.pex.config
import lsst.pipe.base
from lsst.utils.timer import timeMethod
import numpy as np

from . import rbTransiNetInterface
from lsst.meas.transiNet.modelPackages.storageAdapterButler import StorageAdapterButler

32 

33 

class RBTransiNetConnections(lsst.pipe.base.PipelineTaskConnections,
                             dimensions=("instrument", "visit", "detector"),
                             defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for `RBTransiNetTask`.

    Per (instrument, visit, detector), the task reads the warped template,
    the science exposure, the difference image and the diaSources detected
    on it, and writes one catalog of real/bogus classifications aligned
    element-wise with the diaSources.

    The ``pretrainedModel`` prerequisite input is only kept when
    ``config.modelPackageStorageMode == "butler"``; it is removed in
    ``__init__`` otherwise.
    """
    # NOTE: Do we want the "ready to difference" template, or something
    # earlier? This one is warped, but not PSF-matched.
    template = lsst.pipe.base.connectionTypes.Input(
        doc="Input warped template to subtract.",
        dimensions=("instrument", "visit", "detector"),
        storageClass="ExposureF",
        name="{fakesType}{coaddName}Diff_templateExp"
    )
    science = lsst.pipe.base.connectionTypes.Input(
        doc="Input science exposure to subtract from.",
        dimensions=("instrument", "visit", "detector"),
        storageClass="ExposureF",
        name="{fakesType}calexp"
    )
    difference = lsst.pipe.base.connectionTypes.Input(
        doc="Result of subtracting convolved template from science image.",
        dimensions=("instrument", "visit", "detector"),
        storageClass="ExposureF",
        name="{fakesType}{coaddName}Diff_differenceExp",
    )
    diaSources = lsst.pipe.base.connectionTypes.Input(
        doc="Detected sources on the difference image.",
        dimensions=("instrument", "visit", "detector"),
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc",
    )
    # Model package stored in the butler. It has no dimensions, so a single
    # dataset (selected by collection) serves every quantum. The dataset type
    # name is owned by StorageAdapterButler so both sides agree on it.
    pretrainedModel = lsst.pipe.base.connectionTypes.PrerequisiteInput(
        doc="Pretrained neural network model (-package) for the RBClassifier.",
        dimensions=(),
        storageClass="NNModelPackagePayload",
        name=StorageAdapterButler.dataset_type_name,
    )

    # Outputs
    classifications = lsst.pipe.base.connectionTypes.Output(
        doc="Catalog of real/bogus classifications for each diaSource, "
            "element-wise aligned with diaSources.",
        dimensions=("instrument", "visit", "detector"),
        storageClass="Catalog",
        name="{fakesType}{coaddName}RealBogusSources",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Only the "butler" storage mode reads the model package through the
        # butler; in any other mode the connection is removed so the task
        # no longer declares it as a prerequisite input.
        if self.config.modelPackageStorageMode != "butler":
            del self.pretrainedModel

84 

85 

class RBTransiNetConfig(lsst.pipe.base.PipelineTaskConfig, pipelineConnections=RBTransiNetConnections):
    """Configuration for `RBTransiNetTask`."""

    modelPackageName = lsst.pex.config.Field(
        optional=True,
        dtype=str,
        doc="A unique identifier of a model package. ",
    )
    modelPackageStorageMode = lsst.pex.config.ChoiceField(
        dtype=str,
        doc="A string that indicates _where_ and _how_ the model package is stored.",
        allowed={'local': 'packages stored in the meas_transiNet repository',
                 'neighbor': 'packages stored in the rbClassifier_data repository',
                 'butler': 'packages stored in the butler repository',
                 },
        default='neighbor',
    )
    cutoutSize = lsst.pex.config.Field(
        dtype=int,
        doc="Width/height of square cutouts to send to classifier.",
        default=256,
    )

    def validate(self):
        """Validate the configuration.

        Runs the base-class validation (field and connection checks) first,
        then the task-specific consistency checks.

        Raises
        ------
        ValueError
            Raised if ``modelPackageStorageMode`` is ``"butler"`` while
            ``modelPackageName`` is also set. In butler mode the model
            package is chosen through the input collection, not a config
            field.
        """
        # Without this call the base-class checks (e.g. required fields,
        # ChoiceField values, connection templates) would be skipped.
        super().validate()

        # In butler mode the user should not set a modelPackageName as a
        # config field.
        if self.modelPackageStorageMode == "butler" and self.modelPackageName is not None:
            raise ValueError("In a _real_ run of a pipeline when the "
                             "modelPackageStorageMode is 'butler', "
                             "the modelPackageName cannot be specified "
                             "as a config field. Pass it as a collection "
                             "name in the command-line instead.")

117 

118 

119class RBTransiNetTask(lsst.pipe.base.PipelineTask): 

120 """Task for running TransiNet real/bogus classification on the output of 

121 the image subtraction pipeline. 

122 """ 

123 _DefaultName = "rbTransiNet" 

124 ConfigClass = RBTransiNetConfig 

125 

126 def __init__(self, **kwargs): 

127 super().__init__(**kwargs) 

128 

129 self.butler_loaded_package = None 

130 

131 @timeMethod 

132 def run(self, template, science, difference, diaSources, pretrainedModel=None): 

133 

134 # Create the TransiNet interface object. 

135 # Note: assuming each quanta creates one instance of this task, this is 

136 # a proper place for doing this since loading of the model is run only 

137 # once. However, if in the future we come up with a design in which one 

138 # task instance is used for multiple quanta, this will need to be moved 

139 # somewhere else -- e.g. to the __init__ method, or even to runQuantum. 

140 self.butler_loaded_package = pretrainedModel # This will be used by the interface 

141 self.interface = rbTransiNetInterface.RBTransiNetInterface(self) 

142 

143 cutouts = [self._make_cutouts(template, science, difference, source) for source in diaSources] 

144 self.log.info("Extracted %d cutouts.", len(cutouts)) 

145 scores = self.interface.infer(cutouts) 

146 self.log.info("Scored %d cutouts.", len(scores)) 

147 schema = lsst.afw.table.Schema() 

148 schema.addField(diaSources.schema["id"].asField()) 

149 schema.addField("score", doc="real/bogus score of this source", type=float) 

150 classifications = lsst.afw.table.BaseCatalog(schema) 

151 classifications.resize(len(scores)) 

152 

153 classifications["id"] = diaSources["id"] 

154 classifications["score"] = scores 

155 

156 return lsst.pipe.base.Struct(classifications=classifications) 

157 

158 def _make_cutouts(self, template, science, difference, source): 

159 """Return cutouts of each image centered at the source location. 

160 

161 Parameters 

162 ---------- 

163 template : `lsst.afw.image.ExposureF` 

164 science : `lsst.afw.image.ExposureF` 

165 difference : `lsst.afw.image.ExposureF` 

166 Exposures to cut images out of. 

167 source : `lsst.afw.table.SourceRecord` 

168 Source to make cutouts of. 

169 

170 Returns 

171 ------- 

172 cutouts, `lsst.meas.transiNet.CutoutInputs` 

173 Cutouts of each of the input images. 

174 """ 

175 

176 # Try to create cutouts, or simply return empty cutouts if 

177 # failed (most probably out-of-border box) 

178 extent = lsst.geom.Extent2I(self.config.cutoutSize) 

179 box = lsst.geom.Box2I.makeCenteredBox(source.getCentroid(), extent) 

180 

181 if science.getBBox().contains(box): 

182 science_cutout = science.Factory(science, box).image.array 

183 template_cutout = template.Factory(template, box).image.array 

184 difference_cutout = difference.Factory(difference, box).image.array 

185 else: 

186 science_cutout = np.zeros((self.config.cutoutSize, self.config.cutoutSize), dtype=np.float32) 

187 template_cutout = np.zeros_like(science_cutout) 

188 difference_cutout = np.zeros_like(science_cutout) 

189 

190 return rbTransiNetInterface.CutoutInputs(science=science_cutout, 

191 template=template_cutout, 

192 difference=difference_cutout)