Coverage for python/lsst/meas/transiNet/modelPackages/storageAdapterBase.py: 50%
20 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 02:48 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 02:48 -0700
1from . import utils
2import torch
3import yaml
class StorageAdapterBase:
    """
    Base class for storage adapters.

    Derived classes supply the file locations of a model package; this
    base class loads the model architecture, pretrained weights and
    metadata from those files.

    Parameters
    ----------
    model_package_name : `str`
        The name of the model package, e.g. "my_model".

    Notes
    -----
    `load_arch`, `load_weights` and `load_metadata` read the instance
    attributes ``model_filename``, ``checkpoint_filename`` and
    ``metadata_filename`` respectively. This base class never sets them,
    so derived classes must provide them (e.g. in ``__init__`` or in
    `fetch`) before the corresponding loader is called; otherwise an
    `AttributeError` is raised.
    """

    model_package_name = None
    """Name of the model package (`str`).
    """

    def __init__(self, model_package_name):
        self.model_package_name = model_package_name

    def fetch(self):
        """
        Derived classes must implement any potentially
        needed fetching operation in this method.
        This is the place to implement any sort of task
        that needs to be done before loading the model and weights.
        Multiple calls to this method must not result in multiple
        fetches.

        The base implementation does nothing.
        """
        pass

    def load_arch(self, device):
        """
        Load and return the model architecture
        (no loading of pre-trained weights).

        Parameters
        ----------
        device : `torch.device`
            Device to load the model on.

        Returns
        -------
        model : `torch.nn.Module`
            The model architecture. The exact type of this object
            is model-specific.

        Raises
        ------
        AttributeError
            If ``model_filename`` has not been set by a derived class.

        See Also
        --------
        load_weights
        """
        # Build the architecture from the package's model file, then move
        # it to the requested device.
        model = utils.import_model(self.model_filename).to(device)
        return model

    def load_weights(self, device):
        """
        Load and return a checkpoint of a neural network model.

        Parameters
        ----------
        device : `torch.device`
            Device to load the pretrained weights on.

        Returns
        -------
        network_data : `dict`
            Dictionary containing a saved network state in PyTorch format,
            composed of the trained weights, optimizer state, and other
            useful metadata.

        Raises
        ------
        AttributeError
            If ``checkpoint_filename`` has not been set by a derived class.

        See Also
        --------
        load_arch
        """
        # NOTE(review): torch.load is pickle-based (unless weights_only=True)
        # and can execute arbitrary code embedded in a malicious checkpoint.
        # Only load checkpoints from trusted storage. weights_only is not
        # forced here because checkpoints may carry non-tensor objects
        # (e.g. optimizer state); consider enabling it once all model
        # packages are known to be compatible.
        network_data = torch.load(self.checkpoint_filename, map_location=device)
        return network_data

    def load_metadata(self):
        """
        Load and return the metadata associated with the model package.

        Returns
        -------
        metadata : `dict`
            Dictionary containing the metadata associated with the model.

        Raises
        ------
        AttributeError
            If ``metadata_filename`` has not been set by a derived class.
        """
        # YAML is UTF-8 by default; pin the encoding so parsing does not
        # depend on the platform's locale-dependent default for open().
        with open(self.metadata_filename, 'r', encoding='utf-8') as f:
            metadata = yaml.safe_load(f)
        return metadata