Coverage for tests/test_ModelPackage.py: 23% of 108 statements
# This file is part of meas_transiNet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import unittest
import os
import shutil

import torch

from lsst.meas.transiNet.modelPackages.nnModelPackage import NNModelPackage
from lsst.meas.transiNet.modelPackages.storageAdapterLocal import StorageAdapterLocal
from lsst.meas.transiNet.modelPackages.storageAdapterNeighbor import StorageAdapterNeighbor

import lsst.utils
import lsst.utils.tests
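# Probe for the rbClassifier_data package up front, so that the neighbor-mode
# tests below can be skipped when it is not set up.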
try:
    neighborDirectory = lsst.utils.getPackageDir("rbClassifier_data")
except LookupError:
    neighborDirectory = None


class TestModelPackageLocal(unittest.TestCase):
    def setUp(self):
        self.model_package_name = 'dummy'
        self.package_storage_mode = 'local'
    def test_load(self):
        """Test loading of a local model package.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        model = model_package.load(device='cpu')

        weights = next(model.parameters())

        # Test the shape of the loaded weights.
        self.assertTupleEqual(weights.shape, (16, 3, 3, 3))

        # Test weight values.
        # Only test a single tensor: the probability of the weights matching
        # by chance, even in a single tensor, is extremely low.
        torch.testing.assert_close(weights[0][0],
                                   torch.tensor([[0.14145353, -0.10257456, 0.17189537],
                                                 [-0.03069756, -0.1093155, 0.15207087],
                                                 [0.06509985, 0.11900973, -0.16013929]]),
                                   rtol=1e-8, atol=1e-8)
    def test_arch_weights_mismatch(self):
        """Test loading of a model package with mismatched architecture and
        weights.

        This test does not use PyTorch's built-in serialization, so that it
        stays generic and independent of the backend.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        # Point the adapter at a fake architecture file, keeping a backup of
        # the original filename.
        arch_f = os.path.basename(model_package.adapter.model_filename)
        model_filename_backup = model_package.adapter.model_filename
        model_package.adapter.model_filename = model_package.adapter.model_filename.replace(
            arch_f, 'fake_' + arch_f)
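        # The checkpoint (weights) file is left untouched, so loading should
        # end up applying the real weights to the fake architecture written
        # below and fail with a RuntimeError.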
        try:
            with open(model_package.adapter.model_filename, 'w') as f:
                # Write a dummy 1-layer fully connected network into the file.
                f.write('__all__ = ["Net"]\n')
                f.write('import torch\n')
                f.write('import torch.nn as nn\n')
                f.write('class Net(nn.Module):\n')
                f.write('    def __init__(self):\n')
                f.write('        super(Net, self).__init__()\n')
                f.write('        self.fc1 = nn.Linear(3, 16)\n')
                f.write('    def forward(self, x):\n')
                f.write('        x = self.fc1(x)\n')
                f.write('        return x\n')

            # Loading must fail, since the stored weights do not match the
            # fake architecture written above.
            with self.assertRaises(RuntimeError):
                model_package.load(device='cpu')
        finally:
            # Clean up.
            os.remove(model_package.adapter.model_filename)
            model_package.adapter.model_filename = model_filename_backup
    def test_invalid_inputs(self):
        """Test invalid and missing inputs to the NNModelPackage constructor,
        as well as to the load method.
        """
        with self.assertRaises(ValueError):
            NNModelPackage('dummy', 'invalid')

        with self.assertRaises(ValueError):
            NNModelPackage('invalid', None)

        with self.assertRaises(ValueError):
            NNModelPackage(None, 'local')

        with self.assertRaises(ValueError):
            NNModelPackage(None, 'invalid')

        with self.assertRaises(ValueError):
            NNModelPackage(None, None)

        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        with self.assertRaises(ValueError):
            model_package.load(device='invalid')

        with self.assertRaises(ValueError):
            model_package.load(device='gpu199')

        with self.assertRaises(ValueError):
            model_package.load(device=None)

    def test_metadata(self):
        """Test loading of metadata.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        # Test whether the metadata object exists.
        # (It should have been loaded automatically when the model package
        # was constructed.)
        self.assertTrue(hasattr(model_package, 'metadata'))

        # Test whether the metadata object is a dictionary.
        self.assertIsInstance(model_package.metadata, dict)

        # Test whether the metadata object contains the mandatory keys.
        self.assertListEqual(list(model_package.metadata.keys()),
                             ['version', 'description',
                              'input_shape', 'input_scale_factor'],
                             msg='Metadata object does not contain the mandatory keys.')
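        # For reference, and only as a sketch (the on-disk format of the
        # metadata is an implementation detail of the storage adapter), the
        # dummy package's metadata is expected to carry values roughly like
        #     input_shape: [256, 256, 3]
        #     input_scale_factor: [1.0, 0.0033333333333333335, 1.0]  # 1/300 in the middle
        # plus free-form 'version' and 'description' entries.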
        # Test whether the metadata-related methods return the correct values
        # for the dummy model package.
        self.assertEqual(model_package.get_model_input_shape(), (256, 256, 3))
        self.assertEqual(model_package.get_input_scale_factors(), (1.0, 0.0033333333333333335, 1.0))
        with self.assertRaises(KeyError):
            model_package.get_boost_factor()  # No boost factor for dummy

        # Test whether the number of scale factor elements matches the number
        # of input channels.
        self.assertEqual(len(model_package.get_input_scale_factors()),
                         model_package.get_model_input_shape()[2])
@unittest.skipIf(neighborDirectory is None, "rbClassifier_data is not set up")
class TestModelPackageNeighbor(unittest.TestCase):
    def setUp(self):
        # Create a dummy model package in the neighboring repository.
        source_dir = os.path.join(StorageAdapterLocal.get_base_path(), 'dummy')
        self.temp_package_dir = os.path.join(StorageAdapterNeighbor.get_base_path(), 'dummy')

        try:
            shutil.copytree(source_dir, self.temp_package_dir)
        except FileExistsError:
            raise RuntimeError('A dummy model package already exists in neighbor mode!')

        self.model_package_name = 'dummy'
        self.package_storage_mode = 'neighbor'

    def tearDown(self):
        # Remove the neighbor-mode dummy model package.
        shutil.rmtree(self.temp_package_dir)

    def test_load(self):
        """Test loading of a model package in neighbor mode.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        model = model_package.load(device='cpu')

        weights = next(model.parameters())

        # Test that the model package is indeed loaded from the neighbor
        # repository.
        #
        # TODO: if this test is later moved to the neighbor package itself,
        # this check needs to be updated.
        self.assertTrue(model_package.adapter.checkpoint_filename.startswith(
            lsst.utils.getPackageDir("rbClassifier_data")))

        # Test the shape of the loaded weights.
        self.assertTupleEqual(weights.shape, (16, 3, 3, 3))

        # Test weight values.
        torch.testing.assert_close(weights[0][0],
                                   torch.tensor([[0.14145353, -0.10257456, 0.17189537],
                                                 [-0.03069756, -0.1093155, 0.15207087],
                                                 [0.06509985, 0.11900973, -0.16013929]]),
                                   rtol=1e-8, atol=1e-8)
    def test_metadata(self):
        """Test loading of metadata.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        # Test whether the metadata object exists.
        # (It should have been loaded automatically when the model package
        # was constructed.)
        self.assertTrue(hasattr(model_package, 'metadata'))

        # Test whether the metadata object is a dictionary.
        self.assertIsInstance(model_package.metadata, dict)

        # Test whether the metadata object contains the mandatory keys.
        self.assertListEqual(list(model_package.metadata.keys()),
                             ['version', 'description',
                              'input_shape', 'input_scale_factor'],
                             msg='Metadata object does not contain the mandatory keys.')

        # Test whether the metadata-related methods return the correct values
        # for the dummy model package.
        self.assertEqual(model_package.get_model_input_shape(), (256, 256, 3))
        self.assertEqual(model_package.get_input_scale_factors(), (1.0, 0.0033333333333333335, 1.0))
        with self.assertRaises(KeyError):
            model_package.get_boost_factor()  # No boost factor for dummy

        # Test whether the number of scale factor elements matches the number
        # of input channels.
        self.assertEqual(len(model_package.get_input_scale_factors()),
                         model_package.get_model_input_shape()[2])
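# Standard LSST test-suite boilerplate: MemoryTestCase checks for resource
# (e.g. file descriptor) leaks left behind by the tests above.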
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    pass
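# setup_module() is invoked by pytest once before the tests in this module
# run; lsst.utils.tests.init() initializes the LSST testing utilities needed
# by MemoryTester.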
def setup_module(module):
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()