Coverage for python/lsst/meas/transiNet/modelPackages/utils.py: 24%
13 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-12 02:22 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-12 02:22 -0700
1# This file is part of meas_transiNet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22__all__ = ["import_model"]
24import importlib
26import torch.nn
def import_model(path):
    """Import a model from the specified path and return an instance of it.

    The file at ``path`` is executed as a Python module. That module must
    export exactly one name through ``__all__``, and that name must refer
    to a class derived (directly or indirectly) from `torch.nn.Module`.

    Parameters
    ----------
    path : `str`
        Path to the model file.

    Returns
    -------
    model : `torch.nn.Module`
        A new instance of the model class, constructed with no arguments.

    Raises
    ------
    ImportError
        Raised if a valid pytorch model cannot be found in loaded module.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is bound as an attribute, so import it explicitly.
    import importlib.util

    spec = importlib.util.spec_from_file_location('model', path)
    # No spec (or no loader) is produced when the file type is not
    # recognised as importable Python source.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create a module spec for {path}: cannot find model class.")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    exported = getattr(module, "__all__", None)
    if exported is None:
        raise ImportError(f"No __all__ defined in {module}: cannot find model class.")
    if len(exported) == 0:
        raise ImportError(f"No entries in __all__ of {module}: cannot find model class.")
    if len(exported) != 1:
        raise ImportError(f"Multiple entries in {module}: cannot find model class.")

    model = getattr(module, exported[0], None)
    # issubclass() accepts models that inherit from torch.nn.Module through
    # intermediate base classes; the isinstance() guard keeps non-class
    # exports (functions, constants, missing names) from raising TypeError.
    if not (isinstance(model, type) and issubclass(model, torch.nn.Module)):
        raise ImportError(f"Loaded class {model}, from {module}, is not a pytorch neural network module.")

    return model()