Coverage for tests/test_ModelPackage.py: 24%
128 statements
coverage.py v7.4.3, created at 2024-03-12 10:46 +0000
# This file is part of meas_transiNet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import unittest
23import torch
24import os
25import shutil
26import tempfile
28from lsst.meas.transiNet.modelPackages.nnModelPackage import NNModelPackage
29from lsst.meas.transiNet.modelPackages.storageAdapterLocal import StorageAdapterLocal
30from lsst.meas.transiNet.modelPackages.storageAdapterNeighbor import StorageAdapterNeighbor
31from lsst.meas.transiNet.modelPackages.storageAdapterButler import StorageAdapterButler
32from lsst.daf.butler import Butler
33from lsst.daf.butler.registry._exceptions import ConflictingDefinitionError
34import lsst.utils
# Directory of the optional rbClassifier_data package, which hosts
# "neighbor"-mode model packages. Stays None when that package is not set
# up; the neighbor-mode test case is skipped in that situation.
neighborDirectory = None
try:
    neighborDirectory = lsst.utils.getPackageDir("rbClassifier_data")
except LookupError:
    # Deliberately tolerated: rbClassifier_data is an optional dependency.
    pass
def sanity_check_dummy_model(test, model):
    """Assert that ``model`` carries the known weights of the dummy package.

    Parameters
    ----------
    test : `unittest.TestCase`
        Test case whose ``assertTupleEqual`` reports shape mismatches.
    model : `torch.nn.Module`
        Model loaded from the dummy model package.

    Raises
    ------
    AssertionError
        If the first parameter tensor has the wrong shape, or its first
        3x3 slice differs from the reference values.
    """
    first_param = next(iter(model.parameters()))

    # The first parameter tensor of the dummy network is expected to have
    # shape (16, 3, 3, 3).
    test.assertTupleEqual(first_param.shape, (16, 3, 3, 3))

    # Comparing a single 3x3 slice is enough: matching even one slice by
    # chance with random weights is extremely unlikely.
    reference_slice = torch.tensor([
        [0.14145353, -0.10257456, 0.17189537],
        [-0.03069756, -0.1093155, 0.15207087],
        [0.06509985, 0.11900973, -0.16013929],
    ])
    torch.testing.assert_close(first_param[0][0], reference_slice,
                               rtol=1e-8, atol=1e-8)
class TestModelPackageLocal(unittest.TestCase):
    """Tests for the 'dummy' model package in 'local' storage mode."""

    def setUp(self):
        self.model_package_name = 'dummy'
        self.package_storage_mode = 'local'

    def test_load(self):
        """Test loading of a local model package
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        model = model_package.load(device='cpu')
        sanity_check_dummy_model(self, model)

    def test_arch_weights_mismatch(self):
        """Test loading of a model package with mismatching architecture and
        weights.

        The adapter's architecture file is temporarily swapped for a fake one
        (a single fully connected layer) whose parameters do not match the
        package's stored weights; loading is expected to raise `RuntimeError`.
        The fake file is always removed and the adapter restored afterwards,
        whether or not the assertion passes.

        Does not use PyTorch's built-in serialization to be generic and
        independent of the backend.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        adapter = model_package.adapter

        # Place the fake architecture file next to the real one. Build the
        # path from its directory and basename instead of str.replace, which
        # would also rewrite a matching substring in the directory part.
        original_filename = adapter.model_filename
        arch_dir, arch_f = os.path.split(original_filename)
        fake_filename = os.path.join(arch_dir, 'fake_' + arch_f)

        # A dummy 1-layer fully connected network.
        fake_arch_source = (
            '__all__ = ["Net"]\n'
            'import torch\n'
            'import torch.nn as nn\n'
            'class Net(nn.Module):\n'
            '    def __init__(self):\n'
            '        super(Net, self).__init__()\n'
            '        self.fc1 = nn.Linear(3, 16)\n'
            '    def forward(self, x):\n'
            '        x = self.fc1(x)\n'
            '        return x\n'
        )

        adapter.model_filename = fake_filename
        try:
            with open(fake_filename, 'w') as f:
                f.write(fake_arch_source)

            # Now try to load the model.
            with self.assertRaises(RuntimeError):
                model_package.load(device='cpu')
        finally:
            # Clean up even if writing or the assertion above fails; otherwise
            # the fake file would be left inside the model package directory.
            adapter.model_filename = original_filename
            if os.path.exists(fake_filename):
                os.remove(fake_filename)

    def test_invalid_inputs(self):
        """Test invalid and missing inputs
        (of NNModelPackage constructor, as well as the load method)
        """
        with self.assertRaises(ValueError):
            NNModelPackage('dummy', 'invalid')

        with self.assertRaises(ValueError):
            NNModelPackage('invalid', None)

        with self.assertRaises(ValueError):
            NNModelPackage(None, 'local')

        with self.assertRaises(ValueError):
            NNModelPackage(None, 'invalid')

        with self.assertRaises(ValueError):
            NNModelPackage(None, None)

        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        with self.assertRaises(ValueError):
            model_package.load(device='invalid')

        with self.assertRaises(ValueError):
            model_package.load(device='gpu199')

        with self.assertRaises(ValueError):
            model_package.load(device=None)

    def test_metadata(self):
        """Test loading of metadata
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        # Test whether the metadata object exists.
        # (it should be automatically loaded when the model package
        # is constructed)
        self.assertTrue(hasattr(model_package, 'metadata'))

        # Test whether the metadata object is a dictionary.
        self.assertIsInstance(model_package.metadata, dict)

        # Test whether the metadata object contains the mandatory keys.
        self.assertListEqual(list(model_package.metadata.keys()),
                             ['version', 'description',
                              'input_shape', 'input_scale_factor'],
                             msg='Metadata object does not contain the mandatory keys.')

        # Test whether the metadata-related methods return the correct values
        # for the dummy model package.
        self.assertEqual(model_package.get_model_input_shape(), (256, 256, 3))
        self.assertEqual(model_package.get_input_scale_factors(), (1.0, 0.0033333333333333335, 1.0))
        with self.assertRaises(KeyError):
            model_package.get_boost_factor()  # No boost factor for dummy

        # Test whether the number of scale factor elements matches the number
        # of input channels.
        self.assertEqual(len(model_package.get_input_scale_factors()),
                         model_package.get_model_input_shape()[2])
@unittest.skipIf(neighborDirectory is None, "rbClassifier_data not setup")
class TestModelPackageNeighbor(unittest.TestCase):
    """Tests for the 'dummy' model package in 'neighbor' storage mode.

    Each test works on a temporary copy of the local 'dummy' package, placed
    under the neighbor storage adapter's base path and removed afterwards.
    """

    def setUp(self):
        # Create a dummy model package in the neighboring repository
        source_dir = os.path.join(StorageAdapterLocal.get_base_path(), 'dummy')
        self.temp_package_dir = os.path.join(StorageAdapterNeighbor.get_base_path(), 'dummy')

        try:
            shutil.copytree(source_dir, self.temp_package_dir)
        except FileExistsError as err:
            # Refuse to overwrite: the directory may be a real package or a
            # leftover from an interrupted run that must be removed by hand.
            raise RuntimeError('Dummy model package already exists in neighbor mode: '
                               f'{self.temp_package_dir}') from err
        # Register removal right away: cleanups still run if the rest of
        # setUp fails, unlike tearDown.
        self.addCleanup(shutil.rmtree, self.temp_package_dir)

        self.model_package_name = 'dummy'
        self.package_storage_mode = 'neighbor'

    def test_load(self):
        """Test loading of a model package of neighbor mode
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        model = model_package.load(device='cpu')

        # test to make sure the model package is loading from the
        # neighbor repository (neighborDirectory is the rbClassifier_data
        # package directory, resolved at module import).
        #
        # TODO: later if we move this test to the neighbor package itself, this
        # check needs to be updated.
        self.assertTrue(model_package.adapter.checkpoint_filename.startswith(neighborDirectory))

        sanity_check_dummy_model(self, model)

    def test_metadata(self):
        """Test loading of metadata
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)

        # Test whether the metadata object exists.
        # (it should be automatically loaded when the model package
        # is constructed)
        self.assertTrue(hasattr(model_package, 'metadata'))

        # Test whether the metadata object is a dictionary.
        self.assertIsInstance(model_package.metadata, dict)

        # Test whether the metadata object contains the mandatory keys.
        self.assertListEqual(list(model_package.metadata.keys()),
                             ['version', 'description',
                              'input_shape', 'input_scale_factor'],
                             msg='Metadata object does not contain the mandatory keys.')

        # Test whether the metadata-related methods return the correct values
        # for the dummy model package.
        self.assertEqual(model_package.get_model_input_shape(), (256, 256, 3))
        self.assertEqual(model_package.get_input_scale_factors(), (1.0, 0.0033333333333333335, 1.0))
        with self.assertRaises(KeyError):
            model_package.get_boost_factor()  # No boost factor for dummy

        # Test whether the number of scale factor elements matches the number
        # of input channels.
        self.assertEqual(len(model_package.get_input_scale_factors()),
                         model_package.get_model_input_shape()[2])
class TestModelPackageButler(unittest.TestCase):
    """Tests for ingesting the 'dummy' model package into a temporary Butler
    repository and loading it back in 'butler' storage mode.
    """

    def setUp(self):
        self.model_package_name = 'dummy'

        # Create a dummy butler repository (in a temporary directory).
        # Note that a test repo using makeTestRepo would not suffice
        # as we need to test the ingestion of a model package too.
        self.repo_root = tempfile.mkdtemp(prefix='butler_')
        # Register removal immediately: if makeRepo or the Butler constructor
        # raises, tearDown would not run and the directory would leak.
        self.addCleanup(shutil.rmtree, self.repo_root)
        Butler.makeRepo(root=self.repo_root)
        self.butler = Butler(self.repo_root, writeable=True)

    def ingest(self):
        """Ingest the local 'dummy' model package into the test repository
        under ``self.model_package_name``.
        """
        # Load a local model package, to transfer/ingest to
        # the butler repository.
        local_model_package = NNModelPackage('dummy', 'local')
        StorageAdapterButler.ingest(local_model_package,
                                    self.butler,
                                    model_package_name=self.model_package_name)

    def load_from_butler(self):
        """Return an `NNModelPackage` backed by the test repository."""
        # Load the model package from the butler repository.
        model_package = NNModelPackage(model_package_name=self.model_package_name,
                                       package_storage_mode='butler',
                                       butler=self.butler)
        return model_package

    def test_double_ingest(self):
        """Test whether redundant ingestion of a model package to the butler
        repository fails as expected.
        """
        self.ingest()
        # assert that the second one raises ConflictingDefinitionError
        with self.assertRaises(ConflictingDefinitionError):
            self.ingest()

    def test_ingest_load(self):
        """Test ingesting and loading of a model package of butler mode
        """
        self.ingest()
        model_package = self.load_from_butler()
        model = model_package.load(device='cpu')
        sanity_check_dummy_model(self, model)