Coverage for python/lsst/meas/transiNet/modelPackages/nnModelPackage.py: 24%
28 statements
coverage.py v7.5.1, created at 2024-05-11 11:28 +0000
# This file is part of meas_transiNet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["NNModelPackage"]

import torch

from .storageAdapterFactory import StorageAdapterFactory


class NNModelPackage:
    """An interface that abstracts the physical storage of network
    architectures and pretrained models away from client code.

    It handles all required tasks, such as fetching and decompression,
    as needed, and produces a ready-to-use model package: a model
    architecture loaded with specific pretrained weights.
    """

    def __init__(self, model_package_name, package_storage_mode, **kwargs):
        # Validate passed arguments.
        if None in (model_package_name, package_storage_mode):
            raise ValueError("None is not a valid argument")
        if package_storage_mode not in StorageAdapterFactory.storageAdapterClasses:
            raise ValueError("Unsupported storage mode: %s" % package_storage_mode)

        self.model_package_name = model_package_name
        self.package_storage_mode = package_storage_mode

        self.adapter = StorageAdapterFactory.create(self.model_package_name,
                                                    self.package_storage_mode,
                                                    **kwargs)

        self.metadata = self.adapter.load_metadata()

    def load(self, device):
        """Load the model architecture and pretrained weights.

        This method handles all supported storage modes.

        Parameters
        ----------
        device : `str`
            Device to create the model on, e.g. 'cpu' or 'cuda:0'.

        Returns
        -------
        model : `torch.nn.Module`
            The neural network model, loaded with pretrained weights.
            Its concrete type is a subclass of `torch.nn.Module`, defined
            by the architecture module.
        """

        # Check if the specified device is valid.
        if device not in ['cpu'] + ['cuda:%d' % i for i in range(torch.cuda.device_count())]:
            raise ValueError("Invalid device: %s" % device)

        # Load various components.
        # Note that because of the way the StorageAdapterButler works,
        # the model architecture and the pretrained weights are loaded
        # into CPU memory, and only then moved to the target device.
        model = self.adapter.load_arch(device='cpu')
        network_data = self.adapter.load_weights(device='cpu')

        # Load pretrained weights into the model.
        model.load_state_dict(network_data['state_dict'], strict=True)

        # Move the model to the specified device, if it is not already there.
        if device != 'cpu':
            model = model.to(device)

        return model
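
    # A minimal usage sketch for `load` (illustrative only, not part of the
    # module): the package name "example_package" and the storage mode
    # "local" are assumptions, not values guaranteed to exist in any
    # given installation of meas_transiNet.
    #
    #     device = "cuda:0" if torch.cuda.is_available() else "cpu"
    #     model_package = NNModelPackage("example_package", "local")
    #     model = model_package.load(device)
    #     model.eval()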

    def get_model_input_shape(self):
        """Return the input shape of the model.

        Returns
        -------
        input_shape : `tuple`
            The input shape of the model as (height, width); any other
            dimensions are ignored.

        Raises
        ------
        KeyError
            If the input shape is not found in the metadata.
        """
        return tuple(self.metadata['input_shape'])

    def get_input_scale_factors(self):
        """Return the scale factors to be applied to the input data.

        Returns
        -------
        scale_factors : `tuple`
            The scale factors to be applied to the input data.

        Raises
        ------
        KeyError
            If the scale factors are not found in the metadata.
        """
        return tuple(self.metadata['input_scale_factor'])

    def get_boost_factor(self):
        """Return the boost factor to be applied to the output data.

        If the boost factor is not present in the metadata, a `KeyError`
        is raised. It is the responsibility of the client to know whether
        this type of model requires a boost factor or not.

        Returns
        -------
        boost_factor : `float`
            The boost factor to be applied to the output data.

        Raises
        ------
        KeyError
            If the boost factor is not found in the metadata.
        """
        return self.metadata['boost_factor']
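
# Usage sketch (illustrative only, not part of the module): how a client
# might combine the accessors above to prepare input for inference. The
# package name "example_package", the storage mode "local", and the way the
# scale factors map onto input planes are assumptions for illustration.
#
#     import torch
#     from lsst.meas.transiNet.modelPackages.nnModelPackage import NNModelPackage
#
#     model_package = NNModelPackage("example_package", "local")
#     model = model_package.load(device="cpu")
#
#     height, width = model_package.get_model_input_shape()
#     scale_factors = model_package.get_input_scale_factors()
#
#     # Scale each (hypothetical) input plane by its corresponding factor.
#     planes = [torch.rand(height, width) * f for f in scale_factors]
#     batch = torch.stack(planes).unsqueeze(0)  # shape: (1, n_planes, H, W)
#
#     with torch.no_grad():
#         output = model(batch)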