Coverage for python/lsst/meas/transiNet/modelPackages/utils.py: 26%
35 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-24 03:18 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-24 03:18 -0700
1# This file is part of meas_transiNet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22__all__ = ["import_model"]
24import importlib
25import importlib.abc
26import importlib.machinery
27import importlib.util
28import torch.nn
def load_module_from_memory(file_like_object, name='model'):
    """Load a module from the source code held in a file-like object.

    Parameters
    ----------
    file_like_object : file-like object
        Object containing the module source code. Objects providing
        ``getvalue()`` (e.g. `io.StringIO`, `io.BytesIO`) are read in full
        with it; any other object is read with ``read()``. The content may
        be `str` or `bytes`.
    name : `str`, optional
        Name to give to the module. Default: 'model'.

    Returns
    -------
    module : `module`
        The module object, after its code has been executed. This function
        does not add the module to `sys.modules`.
    """

    class InMemoryLoader(importlib.abc.SourceLoader):
        """Source loader that serves module code already held in memory."""

        def __init__(self, data):
            self.data = data

        def get_data(self, path):
            # 'path' is ignored: the source is already held in memory.
            return self.data

        def get_filename(self, fullname):
            # Required by SourceLoader; there is no real file, so a
            # placeholder name is returned.
            return '<in-memory>'

    # Prefer getvalue() so the whole buffer is used regardless of the current
    # stream position; fall back to read() for other file-like objects.
    getvalue = getattr(file_like_object, 'getvalue', None)
    content = getvalue() if getvalue is not None else file_like_object.read()

    loader = InMemoryLoader(content)
    spec = importlib.util.spec_from_loader(name, loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
def load_module_from_file(path, name='model'):
    """Load a module from the specified path and return the module object.

    Parameters
    ----------
    path : `str`
        Path to the module file.
    name : `str`, optional
        Name to give to the module. Default: 'model'.

    Returns
    -------
    module : `module`
        The loaded module, after its code has been executed.

    Raises
    ------
    ImportError
        Raised if no import loader can be determined for ``path``.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    # spec_from_file_location returns None when the file suffix is not one
    # that any import loader recognizes; fail with a clear ImportError rather
    # than an AttributeError from module_from_spec.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot determine how to load a module from {path!r}.")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
def import_model_from_module(module):
    """Instantiate the pytorch neural network architecture defined in a module.

    The module must export exactly one name through ``__all__``, and that
    name must refer to a subclass of `torch.nn.Module`.

    Parameters
    ----------
    module : `module`
        The module containing the neural network architecture.

    Returns
    -------
    model : `torch.nn.Module`
        An instance of the model class, constructed with no arguments.

    Raises
    ------
    ImportError
        Raised if the module has no ``__all__``, does not export exactly one
        name through it, does not define the exported name, or if that name
        is not a `torch.nn.Module` subclass.
    """
    if not hasattr(module, '__all__'):
        raise ImportError(f"No __all__ defined in {module}: cannot find model class.")

    names = list(module.__all__)
    if not names:
        raise ImportError(f"No entries in {module}: cannot find model class.")
    if len(names) > 1:
        raise ImportError(f"Multiple entries in {module}: cannot find model class.")

    try:
        model = getattr(module, names[0])
    except AttributeError as err:
        raise ImportError(f"{module} lists {names[0]!r} in __all__ but does not define it.") from err

    # Accept indirect subclasses too (e.g. a model deriving from a shared base
    # class). The isinstance(..., type) guard rejects non-class exports, and
    # torch.nn.Module itself is not a model architecture.
    is_network = (isinstance(model, type)
                  and issubclass(model, torch.nn.Module)
                  and model is not torch.nn.Module)
    if not is_network:
        raise ImportError(f"Loaded class {model}, from {module}, is not a pytorch neural network module.")

    return model()
def import_model(path):
    """Import and instantiate a pytorch neural network from a source file.

    The file is executed as a module; the single name it exports through
    ``__all__`` must be a `torch.nn.Module` subclass, which is then
    instantiated with no arguments.

    Parameters
    ----------
    path : `str`
        Path to the model file.

    Returns
    -------
    model : `torch.nn.Module`
        An instance of the model class defined in the file (not the class
        itself).

    Raises
    ------
    ImportError
        Raised if a valid pytorch model cannot be found in loaded module.
        Exceptions raised while executing the file's code are not caught
        here and propagate to the caller.
    """
    module = load_module_from_file(path)
    return import_model_from_module(module)