Coverage for python/lsst/meas/transiNet/modelPackages/nnModelPackage.py: 27%
26 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-01 09:59 +0000
1# This file is part of meas_transiNet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ["NNModelPackage"]
24from .storageAdapterFactory import StorageAdapterFactory
26import torch
class NNModelPackage:
    """
    An interface to abstract physical storage of network architecture &
    pretrained models out of clients' code.

    All storage-specific work (fetching, decompression, etc., per need) is
    delegated to a storage adapter obtained from `StorageAdapterFactory`.
    This class validates the arguments and exposes the package metadata.
    It also assembles a ready-to-use "Model Package": a model architecture
    loaded with specific pretrained weights.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package; passed through to
        ``StorageAdapterFactory.create``.
    package_storage_mode : `str`
        Storage mode; must be one of the keys of
        ``StorageAdapterFactory.storageAdapterClasses``.

    Raises
    ------
    ValueError
        If either argument is None, or if ``package_storage_mode`` is not a
        supported storage mode.
    """

    def __init__(self, model_package_name, package_storage_mode):
        # Validate passed arguments. The None check must come first: if it
        # ran after the membership test, a None storage mode would be
        # misreported as "Unsupported storage mode: None".
        if model_package_name is None or package_storage_mode is None:
            raise ValueError("None is not a valid argument")
        if package_storage_mode not in StorageAdapterFactory.storageAdapterClasses:
            raise ValueError("Unsupported storage mode: %s" % package_storage_mode)

        self.model_package_name = model_package_name
        self.package_storage_mode = package_storage_mode

        # The adapter encapsulates all storage-specific logic. Metadata is
        # read once here, so the getters below are plain lookups.
        self.adapter = StorageAdapterFactory.create(self.model_package_name, self.package_storage_mode)
        self.metadata = self.adapter.load_metadata()

    def load(self, device):
        """Load model architecture and pretrained weights.

        This method handles all different modes of storage by delegating
        to the storage adapter created at construction time.

        Parameters
        ----------
        device : `str` or `torch.device`
            Device to create the model on, e.g. 'cpu' or 'cuda:0'.
            A `torch.device` is converted to its string form first, so the
            storage adapter always receives a `str`.

        Returns
        -------
        model : `torch.nn.Module`
            The neural network model, loaded with pretrained weights.
            Its type should be a subclass of nn.Module, defined by
            the architecture module.

        Raises
        ------
        ValueError
            If ``device`` is neither 'cpu' nor 'cuda:<i>' for an index
            ``i`` below ``torch.cuda.device_count()``.
        KeyError
            If the loaded network data has no 'state_dict' entry.
        RuntimeError
            Raised by ``load_state_dict`` if the weights do not match the
            architecture exactly.
        """
        # Accept torch.device objects as a convenience. Downstream code
        # still sees the same string form as before.
        if isinstance(device, torch.device):
            device = str(device)

        # Only the CPU and explicitly indexed, visible CUDA devices are valid.
        valid_devices = ['cpu'] + ['cuda:%d' % i for i in range(torch.cuda.device_count())]
        if device not in valid_devices:
            raise ValueError("Invalid device: %s" % device)

        # Load various components based on the storage mode.
        model = self.adapter.load_arch(device)
        network_data = self.adapter.load_weights(device)

        # strict=True: every parameter key must match, so a mismatched
        # architecture/weights pair fails loudly instead of half-loading.
        model.load_state_dict(network_data['state_dict'], strict=True)

        return model

    def get_model_input_shape(self):
        """Return the input shape of the model.

        Returns
        -------
        input_shape : `tuple`
            The ``'input_shape'`` metadata entry converted to a tuple.
            It is expected to be (height, width), with the other
            dimensions not stored.

        Raises
        ------
        KeyError
            If the input shape is not found in the metadata.
        """
        return tuple(self.metadata['input_shape'])

    def get_input_scale_factors(self):
        """Return the scale factors to be applied to the input data.

        Returns
        -------
        scale_factors : `tuple`
            The ``'input_scale_factor'`` metadata entry converted to a tuple.

        Raises
        ------
        KeyError
            If the scale factors are not found in the metadata.
        """
        return tuple(self.metadata['input_scale_factor'])

    def get_boost_factor(self):
        """Return the boost factor to be applied to the output data.

        If the boost factor is not found in the metadata, return None.
        It is the responsibility of the client to know whether this type
        of model requires a boost factor or not.

        Returns
        -------
        boost_factor : `float` or `None`
            The boost factor to be applied to the output data, or None if
            the metadata does not define one.
        """
        # EAFP rather than ``.get`` so any mapping-like metadata object
        # whose ``__getitem__`` raises KeyError is supported.
        try:
            return self.metadata['boost_factor']
        except KeyError:
            return None