Coverage for python/lsst/meas/transiNet/rbTransiNetInterface.py: 33%

49 statements  

Generated by coverage.py v7.3.2 on 2023-11-02 10:58 +0000.

1# This file is part of meas_transiNet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["RBTransiNetInterface", "CutoutInputs"] 

23 

24import numpy as np 

25import dataclasses 

26import torch 

27 

28from .modelPackages.nnModelPackage import NNModelPackage 

29 

30 

31@dataclasses.dataclass(frozen=True, kw_only=True) 

32class CutoutInputs: 

33 """Science/template/difference cutouts of a single object plus other 

34 metadata. 

35 """ 

36 science: np.ndarray 

37 template: np.ndarray 

38 difference: np.ndarray 

39 

40 label: bool = None 

41 """Known truth of whether this is a real or bogus object.""" 

42 

43 

class RBTransiNetInterface:
    """
    The interface between the LSST AP pipeline and a trained pytorch-based
    RBTransiNet neural network model.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package to load.
    package_storage_mode : {'local', 'neighbor'}
        Storage mode of the model package
    device : `str`
        Device to load and run the neural network on, e.g. 'cpu' or 'cuda:0'
    """

    def __init__(self, model_package_name, package_storage_mode, device='cpu'):
        self.model_package_name = model_package_name
        self.package_storage_mode = package_storage_mode
        self.device = device
        self.init_model()

    def init_model(self):
        """Create and initialize an NN model.

        Loads the package named ``self.model_package_name`` (using
        ``self.package_storage_mode``), passing ``self.device`` to the
        loader, and stores the loaded model in ``self.model``.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        self.model = model_package.load(self.device)

        # Put the model in evaluation mode instead of training mode.
        self.model.eval()

    def input_to_batches(self, inputs, batchSize):
        """Split a list of inputs into consecutive batches.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored. Must support ``len()`` and slicing.
        batchSize : `int`
            Maximum number of inputs per batch; the last batch may be
            smaller.

        Yields
        ------
        batch : `list` [`CutoutInputs`]
            Consecutive slices of ``inputs`` with at most ``batchSize``
            elements each.

        Raises
        ------
        ValueError
            Raised (on first iteration) if ``batchSize`` is less than 1.
        """
        if batchSize < 1:
            raise ValueError(f"batchSize must be a positive integer, got {batchSize}.")
        for start in range(0, len(inputs), batchSize):
            yield inputs[start:start + batchSize]

    def prepare_input(self, inputs):
        """Convert inputs from numpy arrays to a single torch tensor blob.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored. Must be non-empty, since ``torch.stack``
            rejects an empty sequence. All cutouts must have the same shape.

        Returns
        -------
        blob : `torch.Tensor`
            Tensor of shape ``(len(inputs), 3, *cutout_shape)``. Along axis 1,
            index 0 is the template, 1 the science and 2 the difference
            cutout.
        labels : `list`
            Truth labels (``inp.label``), in the same order as ``inputs``.
        """
        cutouts = []
        labels = []
        for inp in inputs:
            # Channel order is (template, science, difference); presumably
            # this is the order the model was trained with, so don't reorder.
            cutouts.append(torch.stack(
                (torch.from_numpy(inp.template),
                 torch.from_numpy(inp.science),
                 torch.from_numpy(inp.difference)),
                dim=0,
            ))
            labels.append(inp.label)

        return torch.stack(cutouts), labels

    def infer(self, inputs, batchSize=64):
        """Return the score of each cutout.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored.
        batchSize : `int`, optional
            Number of inputs passed through the model at once. Larger batches
            need more memory on ``self.device``.

        Returns
        -------
        scores : `numpy.ndarray`
            Float scores (sigmoid of the model output) for each element of
            ``inputs``, flattened to one dimension. An empty float32 array is
            returned when ``inputs`` is empty.
        """
        # TODO: When deploying parallel instances of the task, memory limits
        # should be taken into account when choosing batchSize, if necessary.
        scoreBatches = []
        for batch in self.input_to_batches(inputs, batchSize):
            torchBlob, _ = self.prepare_input(batch)
            # init_model passes self.device to the model loader, so the
            # inputs must live on the same device as the model.
            torchBlob = torchBlob.to(self.device)

            with torch.no_grad():
                output = torch.sigmoid(self.model(torchBlob))

            # Bring every batch back to the CPU so that the batches can be
            # concatenated and converted to numpy whatever the device.
            scoreBatches.append(output.cpu())

        if not scoreBatches:
            return np.zeros(0, dtype=np.float32)

        # A single concatenation avoids re-copying the accumulated scores on
        # every batch.
        scores = torch.cat(scoreBatches, dim=0)
        return scores.detach().numpy().ravel()