Coverage for python/lsst/meas/transiNet/rbTransiNetInterface.py: 43%

40 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-05 11:21 +0000

1# This file is part of meas_transiNet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["RBTransiNetInterface", "CutoutInputs"] 

23 

24import numpy as np 

25import dataclasses 

26import torch 

27 

28from .modelPackages.nnModelPackage import NNModelPackage 

29 

30 

31@dataclasses.dataclass(frozen=True, kw_only=True) 

32class CutoutInputs: 

33 """Science/template/difference cutouts of a single object plus other 

34 metadata. 

35 """ 

36 science: np.ndarray 

37 template: np.ndarray 

38 difference: np.ndarray 

39 

40 label: bool = None 

41 """Known truth of whether this is a real or bogus object.""" 

42 

43 

class RBTransiNetInterface:
    """
    The interface between the LSST AP pipeline and a trained pytorch-based
    RBTransiNet neural network model.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package to load.
    package_storage_mode : {'local', 'neighbor'}
        Storage mode of the model package
    device : `str`
        Device to load and run the neural network on, e.g. 'cpu' or 'cuda:0'
    """

    def __init__(self, model_package_name, package_storage_mode, device='cpu'):
        self.model_package_name = model_package_name
        self.package_storage_mode = package_storage_mode
        self.device = device
        self.init_model()

    def init_model(self):
        """Create and initialize an NN model.

        Loads the model package named by ``self.model_package_name`` with
        ``self.device`` and stores the result as ``self.model``.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        self.model = model_package.load(self.device)

        # Put the model in evaluation mode instead of training mode.
        self.model.eval()

    def prepare_input(self, inputs):
        """
        Convert inputs from numpy arrays, etc. to a torch.tensor blob.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored.

        Returns
        -------
        blob : `torch.Tensor`
            Prepared torch tensor blob to run the model on. The first axis
            indexes the elements of ``inputs``; the second axis holds the
            template, science and difference cutouts, in that order.
        labels : `list`
            Truth labels, concatenated into a single list.
        """
        cutoutsList = []
        labelsList = []
        for inp in inputs:
            # Convert each cutout to a torch tensor
            template = torch.from_numpy(inp.template)
            science = torch.from_numpy(inp.science)
            difference = torch.from_numpy(inp.difference)

            # Stack the components to create a single blob
            singleBlob = torch.stack((template, science, difference), dim=0)

            # And append them to the temporary list
            cutoutsList.append(singleBlob)

            labelsList.append(inp.label)

        torchBlob = torch.stack(cutoutsList)
        return torchBlob, labelsList

    def infer(self, inputs):
        """Return the scores of the given cutouts.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored.

        Returns
        -------
        scores : `numpy.array`
            Float scores for each element of ``inputs``.
        """
        blob, _ = self.prepare_input(inputs)

        # The cutouts are built from numpy arrays and therefore live on the
        # CPU; move them to the device the model was loaded with.
        # NOTE(review): assumes NNModelPackage.load(device) places the model
        # on that device -- confirm.
        blob = blob.to(self.device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            result = self.model(blob)
            scores = torch.sigmoid(result)

        # numpy conversion requires a CPU tensor (no-op if already there).
        npyScores = scores.detach().cpu().numpy().ravel()

        return npyScores