Coverage for tests/test_ModelPackage.py: 23%

108 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-12 02:22 -0700

1# This file is part of meas_transiNet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23import torch 

24import os 

25import shutil 

26 

27from lsst.meas.transiNet.modelPackages.nnModelPackage import NNModelPackage 

28from lsst.meas.transiNet.modelPackages.storageAdapterLocal import StorageAdapterLocal 

29from lsst.meas.transiNet.modelPackages.storageAdapterNeighbor import StorageAdapterNeighbor 

30 

import lsst.utils
# lsst.utils.tests is used below (MemoryTestCase, init()); importing the
# parent package alone does not guarantee the submodule attribute exists.
import lsst.utils.tests

# Location of the optional rbClassifier_data package. When it is not set up,
# getPackageDir raises LookupError; the None sentinel makes the neighbor-mode
# test class below skip itself instead of erroring.
try:
    neighborDirectory = lsst.utils.getPackageDir("rbClassifier_data")
except LookupError:
    neighborDirectory = None

36 

37 

class TestModelPackageLocal(unittest.TestCase):
    """Tests for the ``dummy`` model package stored in ``local`` mode."""

    def setUp(self):
        self.model_package_name = 'dummy'
        self.package_storage_mode = 'local'

    def test_load(self):
        """Test loading of a local model package.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        model = model_package.load(device='cpu')

        weights = next(model.parameters())

        # Test shape of loaded weights.
        self.assertTupleEqual(weights.shape, (16, 3, 3, 3))

        # Test weight values.
        # Only test a single tensor, as the probability of randomly having
        # matching weights "only" in a single tensor is extremely low.
        torch.testing.assert_close(weights[0][0],
                                   torch.tensor([[0.14145353, -0.10257456, 0.17189537],
                                                 [-0.03069756, -0.1093155, 0.15207087],
                                                 [0.06509985, 0.11900973, -0.16013929]]),
                                   rtol=1e-8, atol=1e-8)

    def test_arch_weights_mismatch(self):
        """Test loading of a model package with mismatching architecture and
        weights.

        Does not use PyTorch's built-in serialization to be generic and
        independent of the backend.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        # Redirect the adapter to a fake architecture file placed next to the
        # real one. The path is rebuilt from its components rather than with
        # str.replace(), which would also rewrite any directory name that
        # happens to contain the file's base name.
        model_filename_backup = model_package.adapter.model_filename
        arch_dir, arch_f = os.path.split(model_filename_backup)
        fake_model_filename = os.path.join(arch_dir, 'fake_' + arch_f)
        model_package.adapter.model_filename = fake_model_filename

        # A dummy 1-layer fully connected network whose parameters do not
        # match the stored (convolutional) weights checked in test_load.
        # The nested indentation must be real Python indentation, otherwise
        # importing the file fails for the wrong reason.
        fake_arch_source = (
            '__all__ = ["Net"]\n'
            'import torch\n'
            'import torch.nn as nn\n'
            'class Net(nn.Module):\n'
            '    def __init__(self):\n'
            '        super(Net, self).__init__()\n'
            '        self.fc1 = nn.Linear(3, 16)\n'
            '    def forward(self, x):\n'
            '        x = self.fc1(x)\n'
            '        return x\n'
        )

        try:
            with open(fake_model_filename, 'w') as f:
                f.write(fake_arch_source)

            # Loading the stored weights into the mismatched architecture
            # is expected to raise RuntimeError.
            with self.assertRaises(RuntimeError):
                model_package.load(device='cpu')
        finally:
            # Clean up unconditionally -- even if writing the file or the
            # assertion failed -- so no fake architecture file is left behind
            # in the package directory.
            try:
                os.remove(fake_model_filename)
            except FileNotFoundError:
                pass
            model_package.adapter.model_filename = model_filename_backup

    def test_invalid_inputs(self):
        """Test invalid and missing inputs
        (of NNModelPackage constructor, as well as the load method).
        """
        with self.assertRaises(ValueError):
            NNModelPackage('dummy', 'invalid')

        with self.assertRaises(ValueError):
            NNModelPackage('invalid', None)

        with self.assertRaises(ValueError):
            NNModelPackage(None, 'local')

        with self.assertRaises(ValueError):
            NNModelPackage(None, 'invalid')

        with self.assertRaises(ValueError):
            NNModelPackage(None, None)

        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        with self.assertRaises(ValueError):
            model_package.load(device='invalid')

        with self.assertRaises(ValueError):
            model_package.load(device='gpu199')

        with self.assertRaises(ValueError):
            model_package.load(device=None)

    def test_metadata(self):
        """Test loading of metadata.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        # Test whether the metadata object exists.
        # (it should be automatically loaded when the model package
        # is constructed)
        self.assertTrue(hasattr(model_package, 'metadata'))

        # Test whether the metadata object is a dictionary.
        self.assertIsInstance(model_package.metadata, dict)

        # Test whether the metadata object contains the mandatory keys.
        self.assertListEqual(list(model_package.metadata.keys()),
                             ['version', 'description',
                              'input_shape', 'input_scale_factor'],
                             msg='Metadata object does not contain the mandatory keys.')

        # Test whether the metadata-related methods return the correct values
        # for the dummy model package.
        self.assertEqual(model_package.get_model_input_shape(), (256, 256, 3))
        self.assertEqual(model_package.get_input_scale_factors(), (1.0, 0.0033333333333333335, 1.0))
        with self.assertRaises(KeyError):
            model_package.get_boost_factor()  # No boost factor for dummy

        # Test whether the number of scale factor elements matches the number
        # of input channels.
        self.assertEqual(len(model_package.get_input_scale_factors()),
                         model_package.get_model_input_shape()[2])

160 

161 

@unittest.skipIf(neighborDirectory is None, "rbClassifier_data not setup")
class TestModelPackageNeighbor(unittest.TestCase):
    """Tests for the ``dummy`` model package stored in ``neighbor`` mode.

    ``setUp`` copies the local ``dummy`` package into the neighbor
    repository, and ``tearDown`` deletes that copy again.
    """

    def setUp(self):
        # Create a dummy model package in the neighboring repository.
        source_dir = os.path.join(StorageAdapterLocal.get_base_path(), 'dummy')
        self.temp_package_dir = os.path.join(StorageAdapterNeighbor.get_base_path(), 'dummy')

        try:
            shutil.copytree(source_dir, self.temp_package_dir)
        except FileExistsError as err:
            # Never reuse a pre-existing directory: tearDown would delete it.
            # Chain the original error so the offending path stays visible.
            raise RuntimeError('Dummy model package in neighbor mode!') from err

        self.model_package_name = 'dummy'
        self.package_storage_mode = 'neighbor'

    def tearDown(self):
        # Remove the neighbor-mode dummy model package created in setUp.
        shutil.rmtree(self.temp_package_dir)

    def test_load(self):
        """Test loading of a model package of neighbor mode.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        model = model_package.load(device='cpu')

        weights = next(model.parameters())

        # Test to make sure the model package is loading from the
        # neighbor repository (the same directory the skip decorator checks).
        #
        # TODO: later if we move this test to the neighbor package itself,
        # this check needs to be updated.
        self.assertTrue(model_package.adapter.checkpoint_filename.startswith(neighborDirectory))

        # Test shape of loaded weights.
        self.assertTupleEqual(weights.shape, (16, 3, 3, 3))

        # Test weight values.
        torch.testing.assert_close(weights[0][0],
                                   torch.tensor([[0.14145353, -0.10257456, 0.17189537],
                                                 [-0.03069756, -0.1093155, 0.15207087],
                                                 [0.06509985, 0.11900973, -0.16013929]]),
                                   rtol=1e-8, atol=1e-8)

    def test_metadata(self):
        """Test loading of metadata.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        # Test whether the metadata object exists.
        # (it should be automatically loaded when the model package
        # is constructed)
        self.assertTrue(hasattr(model_package, 'metadata'))

        # Test whether the metadata object is a dictionary.
        self.assertIsInstance(model_package.metadata, dict)

        # Test whether the metadata object contains the mandatory keys.
        self.assertListEqual(list(model_package.metadata.keys()),
                             ['version', 'description',
                              'input_shape', 'input_scale_factor'],
                             msg='Metadata object does not contain the mandatory keys.')

        # Test whether the metadata-related methods return the correct values
        # for the dummy model package.
        self.assertEqual(model_package.get_model_input_shape(), (256, 256, 3))
        self.assertEqual(model_package.get_input_scale_factors(), (1.0, 0.0033333333333333335, 1.0))
        with self.assertRaises(KeyError):
            model_package.get_boost_factor()  # No boost factor for dummy

        # Test whether the number of scale factor elements matches the number
        # of input channels.
        self.assertEqual(len(model_package.get_input_scale_factors()),
                         model_package.get_model_input_shape()[2])

237 

238 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from ``lsst.utils.tests.MemoryTestCase``.

    No extra tests are defined here; the class exists only so the LSST test
    harness picks up the inherited checks for this module.
    """

241 

242 

def setup_module(module):
    """Module-level setup hook (xunit-style, invoked by pytest).

    Parameters
    ----------
    module : `module`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

245 

246 

if __name__ == "__main__":
    # Script entry point: initialise the LSST test utilities, then hand
    # control to unittest's runner.
    # (Coverage note: this condition was never true in the recorded run.)
    lsst.utils.tests.init()
    unittest.main()