Coverage for python/lsst/meas/transiNet/modelPackages/nnModelPackage.py: 24%

28 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-02-27 13:19 +0000

1# This file is part of meas_transiNet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module: only the model-package wrapper is exported.
__all__ = ["NNModelPackage"]

# Factory that maps a storage-mode name to a concrete storage adapter.
from .storageAdapterFactory import StorageAdapterFactory

# Used for device validation and model placement in NNModelPackage.load.
import torch

27 

28 

class NNModelPackage:
    """
    An interface to abstract physical storage of network architecture &
    pretrained models out of clients' code.

    It handles all necessary required tasks, including fetching,
    decompression, etc. per need and creates a "Model Package"
    ready to use: a model architecture loaded with specific pretrained
    weights.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package to load.
    package_storage_mode : `str`
        Storage mode of the package; must be one of the keys of
        ``StorageAdapterFactory.storageAdapterClasses``.
    **kwargs
        Extra keyword arguments, forwarded verbatim to
        ``StorageAdapterFactory.create``.

    Raises
    ------
    ValueError
        If ``model_package_name`` or ``package_storage_mode`` is None, or
        if ``package_storage_mode`` is not a supported storage mode.
    """

    def __init__(self, model_package_name, package_storage_mode, **kwargs):
        # Validate passed arguments. The None check runs first so that a
        # missing storage mode is reported as missing (not as "unsupported")
        # and the adapter registry is never consulted for invalid input.
        if model_package_name is None or package_storage_mode is None:
            raise ValueError("None is not a valid argument")
        if package_storage_mode not in StorageAdapterFactory.storageAdapterClasses:
            # Wrap in a 1-tuple so %-formatting cannot misfire on tuple values.
            raise ValueError("Unsupported storage mode: %s" % (package_storage_mode,))

        self.model_package_name = model_package_name
        self.package_storage_mode = package_storage_mode

        # Storage-specific backend that knows how to fetch the architecture,
        # weights and metadata for this package.
        self.adapter = StorageAdapterFactory.create(self.model_package_name,
                                                    self.package_storage_mode,
                                                    **kwargs)

        # Package metadata, read once at construction and queried by the
        # get_* accessors below.
        self.metadata = self.adapter.load_metadata()

    def load(self, device):
        """Load model architecture and pretrained weights.

        This method handles all different modes of storages.

        Parameters
        ----------
        device : `str` or `torch.device`
            Device to create the model on, e.g. 'cpu' or 'cuda:0'.
            A `torch.device` is accepted and converted to its string form.

        Returns
        -------
        model : `torch.nn.Module`
            The neural network model, loaded with pretrained weights.
            Its type should be a subclass of nn.Module, defined by
            the architecture module.

        Raises
        ------
        ValueError
            If ``device`` is neither 'cpu' nor 'cuda:<i>' for an index ``i``
            below ``torch.cuda.device_count()``.
        """
        # Normalize to the canonical string form; a str passes through as is.
        device = str(device)

        # Check if the specified device is valid.
        valid_devices = ['cpu'] + ['cuda:%d' % i for i in range(torch.cuda.device_count())]
        if device not in valid_devices:
            raise ValueError("Invalid device: %s" % (device,))

        # Load various components.
        # Note that because of the way the StorageAdapterButler works,
        # the model architecture and the pretrained weights are loaded
        # into the cpu memory, and only then moved to the target device.
        model = self.adapter.load_arch(device='cpu')
        network_data = self.adapter.load_weights(device='cpu')

        # Load pretrained weights into model; strict=True makes any key
        # mismatch between architecture and checkpoint an error.
        model.load_state_dict(network_data['state_dict'], strict=True)

        # Move model to the specified device, if it is not already there.
        if device != 'cpu':
            model = model.to(device)

        return model

    def get_model_input_shape(self):
        """Return the input shape of the model.

        Returns
        -------
        input_shape : `tuple`
            The input shape of the model -- (height, width), ignores
            the other dimensions.

        Raises
        ------
        KeyError
            If the input shape is not found in the metadata.
        """
        return tuple(self.metadata['input_shape'])

    def get_input_scale_factors(self):
        """
        Return the scale factors to be applied to the input data.

        Returns
        -------
        scale_factors : `tuple`
            The scale factors to be applied to the input data.

        Raises
        ------
        KeyError
            If the scale factors are not found in the metadata.
        """
        return tuple(self.metadata['input_scale_factor'])

    def get_boost_factor(self):
        """
        Return the boost factor to be applied to the output data.

        If the boost factor is not found in the metadata, return None.
        It is the responsibility of the client to know whether this type
        of model requires a boost factor or not.

        Returns
        -------
        boost_factor : `float` or `None`
            The boost factor to be applied to the output data, or None if
            the metadata does not define one.
        """
        # Optional key: absence is a normal condition, not an error.
        if 'boost_factor' not in self.metadata:
            return None
        return self.metadata['boost_factor']