Coverage for tests/test_ModelPackage.py: 24%

128 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-02 11:25 +0000

1# This file is part of meas_transiNet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23import torch 

24import os 

25import shutil 

26import tempfile 

27 

28from lsst.meas.transiNet.modelPackages.nnModelPackage import NNModelPackage 

29from lsst.meas.transiNet.modelPackages.storageAdapterLocal import StorageAdapterLocal 

30from lsst.meas.transiNet.modelPackages.storageAdapterNeighbor import StorageAdapterNeighbor 

31from lsst.meas.transiNet.modelPackages.storageAdapterButler import StorageAdapterButler 

32from lsst.daf.butler import Butler 

33from lsst.daf.butler.registry._exceptions import ConflictingDefinitionError 

34import lsst.utils 

35try: 

36 neighborDirectory = lsst.utils.getPackageDir("rbClassifier_data") 

37except LookupError: 

38 neighborDirectory = None 

39 

40 

41def sanity_check_dummy_model(test, model): 

42 weights = next(model.parameters()) 

43 

44 # Test shape of loaded weights. 

45 test.assertTupleEqual(weights.shape, (16, 3, 3, 3)) 

46 

47 # Test weight values. 

48 # Only test a single tensor, as the probability of randomly having 

49 # matching weights "only" in a single tensor is extremely low. 

50 torch.testing.assert_close(weights[0][0], 

51 torch.tensor([[0.14145353, -0.10257456, 0.17189537], 

52 [-0.03069756, -0.1093155, 0.15207087], 

53 [0.06509985, 0.11900973, -0.16013929]]), 

54 rtol=1e-8, atol=1e-8) 

55 

56 

57class TestModelPackageLocal(unittest.TestCase): 

58 def setUp(self): 

59 self.model_package_name = 'dummy' 

60 self.package_storage_mode = 'local' 

61 

62 def test_load(self): 

63 """Test loading of a local model package 

64 """ 

65 model_package = NNModelPackage(self.model_package_name, self.package_storage_mode) 

66 model = model_package.load(device='cpu') 

67 sanity_check_dummy_model(self, model) 

68 

69 def test_arch_weights_mismatch(self): 

70 """Test loading of a model package with mismatching architecture and 

71 weights. 

72 

73 Does not use PyTorch's built-in serialization to be generic and 

74 independent of the backend. 

75 """ 

76 model_package = NNModelPackage(self.model_package_name, self.package_storage_mode) 

77 

78 # Create a fake architecture file. 

79 arch_f = os.path.basename(model_package.adapter.model_filename) 

80 model_filename_backup = model_package.adapter.model_filename 

81 model_package.adapter.model_filename = model_package.adapter.model_filename.replace(arch_f, 

82 'fake_' + arch_f) 

83 

84 try: 

85 with open(model_package.adapter.model_filename, 'w') as f: 

86 # Write a dummy 1-layer fully connected network into the file. 

87 f.write('__all__ = ["Net"]\n') 

88 f.write('import torch\n') 

89 f.write('import torch.nn as nn\n') 

90 f.write('class Net(nn.Module):\n') 

91 f.write(' def __init__(self):\n') 

92 f.write(' super(Net, self).__init__()\n') 

93 f.write(' self.fc1 = nn.Linear(3, 16)\n') 

94 f.write(' def forward(self, x):\n') 

95 f.write(' x = self.fc1(x)\n') 

96 f.write(' return x\n') 

97 finally: 

98 # Now try to load the model. 

99 with self.assertRaises(RuntimeError): 

100 model_package.load(device='cpu') 

101 

102 # Clean up. 

103 os.remove(model_package.adapter.model_filename) 

104 model_package.adapter.model_filename = model_filename_backup 

105 

106 def test_invalid_inputs(self): 

107 """Test invalid and missing inputs 

108 (of NNModelPackage constructor, as well as the load method) 

109 """ 

110 with self.assertRaises(ValueError): 

111 NNModelPackage('dummy', 'invalid') 

112 

113 with self.assertRaises(ValueError): 

114 NNModelPackage('invalid', None) 

115 

116 with self.assertRaises(ValueError): 

117 NNModelPackage(None, 'local') 

118 

119 with self.assertRaises(ValueError): 

120 NNModelPackage(None, 'invalid') 

121 

122 with self.assertRaises(ValueError): 

123 NNModelPackage(None, None) 

124 

125 model_package = NNModelPackage(self.model_package_name, self.package_storage_mode) 

126 

127 with self.assertRaises(ValueError): 

128 model_package.load(device='invalid') 

129 

130 with self.assertRaises(ValueError): 

131 model_package.load(device='gpu199') 

132 

133 with self.assertRaises(ValueError): 

134 model_package.load(device=None) 

135 

136 def test_metadata(self): 

137 """Test loading of metadata 

138 """ 

139 model_package = NNModelPackage(self.model_package_name, self.package_storage_mode) 

140 

141 # Test whether the metadata object exists. 

142 # (it should be automatically loaded when the model package 

143 # is constructed) 

144 self.assertTrue(hasattr(model_package, 'metadata')) 

145 

146 # Test whether the metadata object is a dictionary. 

147 self.assertIsInstance(model_package.metadata, dict) 

148 

149 # Test whether the metadata object contains the mandatory keys. 

150 self.assertListEqual(list(model_package.metadata.keys()), 

151 ['version', 'description', 

152 'input_shape', 'input_scale_factor'], 

153 msg='Metadata object does not contain the mandatory keys.') 

154 

155 # Test whether the metadata-related methods return the correct values 

156 # for the dummy model package. 

157 self.assertEqual(model_package.get_model_input_shape(), (256, 256, 3)) 

158 self.assertEqual(model_package.get_input_scale_factors(), (1.0, 0.0033333333333333335, 1.0)) 

159 with self.assertRaises(KeyError): 

160 model_package.get_boost_factor() # No boost factor for dummy 

161 

162 # Test whether the number of scale factor elements matches the number 

163 # of input channels. 

164 self.assertEqual(len(model_package.get_input_scale_factors()), 

165 model_package.get_model_input_shape()[2]) 

166 

167 

168@unittest.skipIf(neighborDirectory is None, "rbClassifier_data not setup") 

169class TestModelPackageNeighbor(unittest.TestCase): 

170 def setUp(self): 

171 # Create a dummy model package in the neighboring repository 

172 source_dir = os.path.join(StorageAdapterLocal.get_base_path(), 'dummy') 

173 self.temp_package_dir = os.path.join(StorageAdapterNeighbor.get_base_path(), 'dummy') 

174 

175 try: 

176 shutil.copytree(source_dir, self.temp_package_dir) 

177 except FileExistsError: 

178 raise RuntimeError('Dummy model package in neighbor mode!') 

179 

180 self.model_package_name = 'dummy' 

181 self.package_storage_mode = 'neighbor' 

182 

183 def tearDown(self): 

184 # Remove the neighbor-mode dummy model package 

185 shutil.rmtree(self.temp_package_dir) 

186 

187 def test_load(self): 

188 """Test loading of a model package of neighbor mode 

189 """ 

190 model_package = NNModelPackage(self.model_package_name, self.package_storage_mode) 

191 model = model_package.load(device='cpu') 

192 

193 # test to make sure the model package is loading from the 

194 # neighbor repository. 

195 # 

196 # TODO: later if we move this test to the neighbor package itself, this 

197 # check needs to be updated. 

198 self.assertTrue(model_package.adapter.checkpoint_filename.startswith( 

199 lsst.utils.getPackageDir("rbClassifier_data"))) 

200 

201 sanity_check_dummy_model(self, model) 

202 

203 def test_metadata(self): 

204 """Test loading of metadata 

205 """ 

206 model_package = NNModelPackage(self.model_package_name, self.package_storage_mode) 

207 

208 # Test whether the metadata object exists. 

209 # (it should be automatically loaded when the model package 

210 # is constructed) 

211 self.assertTrue(hasattr(model_package, 'metadata')) 

212 

213 # Test whether the metadata object is a dictionary. 

214 self.assertIsInstance(model_package.metadata, dict) 

215 

216 # Test whether the metadata object contains the mandatory keys. 

217 self.assertListEqual(list(model_package.metadata.keys()), 

218 ['version', 'description', 

219 'input_shape', 'input_scale_factor'], 

220 msg='Metadata object does not contain the mandatory keys.') 

221 

222 # Test whether the metadata-related methods return the correct values 

223 # for the dummy model package. 

224 self.assertEqual(model_package.get_model_input_shape(), (256, 256, 3)) 

225 self.assertEqual(model_package.get_input_scale_factors(), (1.0, 0.0033333333333333335, 1.0)) 

226 with self.assertRaises(KeyError): 

227 model_package.get_boost_factor() # No boost factor for dummy 

228 

229 # Test whether the number of scale factor elements matches the number 

230 # of input channels. 

231 self.assertEqual(len(model_package.get_input_scale_factors()), 

232 model_package.get_model_input_shape()[2]) 

233 

234 

235class TestModelPackageButler(unittest.TestCase): 

236 def setUp(self): 

237 self.model_package_name = 'dummy' 

238 

239 # Create a dummy butler repository (in a temporary directory). 

240 # Note that a test repo using makeTestRepo would not suffice 

241 # as we need to test the ingestion of a model package too. 

242 self.repo_root = tempfile.mkdtemp(prefix='butler_') 

243 Butler.makeRepo(root=self.repo_root) 

244 self.butler = Butler(self.repo_root, writeable=True) 

245 

246 def tearDown(self): 

247 shutil.rmtree(self.repo_root) 

248 

249 def ingest(self): 

250 # Load a local model package, to transfer/ingest to 

251 # the butler repository. 

252 local_model_package = NNModelPackage('dummy', 'local') 

253 StorageAdapterButler.ingest(local_model_package, 

254 self.butler, 

255 model_package_name=self.model_package_name) 

256 

257 def load_from_butler(self): 

258 # Load the model package from the butler repository. 

259 model_package = NNModelPackage(model_package_name=self.model_package_name, 

260 package_storage_mode='butler', 

261 butler=self.butler) 

262 return model_package 

263 

264 def test_double_ingest(self): 

265 """Test whether redundant ingestion of a model package to the butler 

266 repository fails as expected. 

267 """ 

268 self.ingest() 

269 # assert that the second one raises ConflictingDefinitionError 

270 with self.assertRaises(ConflictingDefinitionError): 

271 self.ingest() 

272 

273 def test_ingest_load(self): 

274 """Test ingesting and loading of a model package of butler mode 

275 """ 

276 self.ingest() 

277 model_package = self.load_from_butler() 

278 model = model_package.load(device='cpu') 

279 sanity_check_dummy_model(self, model)