Coverage for python/lsst/meas/transiNet/modelPackages/nnModelPackage.py: 27%

26 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-17 03:12 -0700

1# This file is part of meas_transiNet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module: only the model-package interface is exported.
__all__ = ["NNModelPackage"]

23 

24from .storageAdapterFactory import StorageAdapterFactory 

25 

26import torch 

27 

28 

class NNModelPackage:
    """
    An interface to abstract physical storage of network architecture &
    pretrained models out of clients' code.

    It handles all necessary required tasks, including fetching,
    decompression, etc. per need and creates a "Model Package"
    ready to use: a model architecture loaded with specific pretrained
    weights.

    Parameters
    ----------
    model_package_name : `str`
        Name identifying the model package; passed unchanged to
        ``StorageAdapterFactory.create`` (presumably a string -- confirm
        against the storage adapters).
    package_storage_mode : `str`
        Storage mode of the package; must be one of the keys of
        ``StorageAdapterFactory.storageAdapterClasses``.

    Raises
    ------
    ValueError
        If either argument is None, or if ``package_storage_mode`` is not
        a supported storage mode.
    """

    def __init__(self, model_package_name, package_storage_mode):
        # Reject missing arguments first, so a None storage mode is reported
        # as a missing argument rather than as an "unsupported" mode.
        if model_package_name is None or package_storage_mode is None:
            raise ValueError("None is not a valid argument")
        if package_storage_mode not in StorageAdapterFactory.storageAdapterClasses:
            raise ValueError(f"Unsupported storage mode: {package_storage_mode}")

        # Name and storage mode of this package, as given by the caller.
        self.model_package_name = model_package_name
        self.package_storage_mode = package_storage_mode

        # Storage-specific adapter that performs the actual loading.
        self.adapter = StorageAdapterFactory.create(self.model_package_name, self.package_storage_mode)
        # Metadata is loaded eagerly so the getters below can query it.
        self.metadata = self.adapter.load_metadata()

    def load(self, device):
        """Load model architecture and pretrained weights.

        This method handles all different modes of storages by delegating
        to the storage adapter selected at construction time.

        Parameters
        ----------
        device : `str`
            Device to create the model on, e.g. 'cpu' or 'cuda:0'. Only
            'cpu' and 'cuda:<i>' for the CUDA devices reported by
            ``torch.cuda.device_count()`` are accepted.

        Returns
        -------
        model : `torch.nn.Module`
            The neural network model, loaded with pretrained weights.
            Its type should be a subclass of nn.Module, defined by
            the architecture module.

        Raises
        ------
        ValueError
            If ``device`` is not one of the accepted device strings.
        """
        # Enumerate the devices torch can see; any other string is rejected.
        valid_devices = ['cpu'] + [f'cuda:{i}' for i in range(torch.cuda.device_count())]
        if device not in valid_devices:
            raise ValueError(f"Invalid device: {device}")

        # Load the architecture and the stored weights via the adapter.
        model = self.adapter.load_arch(device)
        network_data = self.adapter.load_weights(device)

        # strict=True: the state dict keys must match the model's exactly.
        # Assumes load_weights returns a mapping with a 'state_dict' entry.
        model.load_state_dict(network_data['state_dict'], strict=True)

        return model

    def get_model_input_shape(self):
        """Return the input shape of the model.

        Returns
        -------
        input_shape : `tuple`
            The input shape of the model as stored in the metadata --
            (height, width), ignoring the other dimensions.

        Raises
        ------
        KeyError
            If the input shape is not found in the metadata.
        """
        return tuple(self.metadata['input_shape'])

    def get_input_scale_factors(self):
        """Return the scale factors to be applied to the input data.

        Returns
        -------
        scale_factors : `tuple`
            The scale factors to be applied to the input data.

        Raises
        ------
        KeyError
            If the scale factors are not found in the metadata.
        """
        return tuple(self.metadata['input_scale_factor'])

    def get_boost_factor(self):
        """Return the boost factor to be applied to the output data.

        If the boost factor is not found in the metadata, return None.
        It is the responsibility of the client to know whether this type
        of model requires a boost factor or not.

        Returns
        -------
        boost_factor : `float` or `None`
            The boost factor to be applied to the output data, or None if
            the metadata does not define one.
        """
        # A missing boost factor is a legitimate state for some models, so
        # report it as None instead of raising.
        if 'boost_factor' not in self.metadata:
            return None
        return self.metadata['boost_factor']