Coverage for python/lsst/meas/transiNet/modelPackages/storageAdapterBase.py: 50%

20 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-20 03:07 -0700

1from . import utils 

2import torch 

3import yaml 

4 

5 

class StorageAdapterBase:
    """
    Base class for storage adapters.

    A storage adapter knows where the files of a model package are and how
    to load them. The loading methods of this class read attributes that
    are not defined here and must be provided by derived classes:

    - ``model_filename``: passed to ``utils.import_model`` by `load_arch`.
    - ``checkpoint_filename``: passed to ``torch.load`` by `load_weights`.
    - ``metadata_filename``: path of a YAML file read by `load_metadata`.

    Calling a loading method on an adapter that does not provide the
    corresponding attribute raises `AttributeError`.

    Parameters
    ----------
    model_package_name : `str`
        The name of the model package, e.g. "my_model".
    """

    model_package_name = None
    """Name of the model package (`str`).
    """

    def __init__(self, model_package_name):
        self.model_package_name = model_package_name

    def fetch(self):
        """
        Perform any work needed before the model and weights are loaded.

        Derived classes should override this method to implement whatever
        fetching operation their storage requires; the base implementation
        does nothing. Multiple calls to this method must not result in
        multiple fetches.
        """
        pass

    def load_arch(self, device):
        """
        Load and return the model architecture
        (no loading of pre-trained weights).

        Parameters
        ----------
        device : `torch.device`
            Device to load the model on.

        Returns
        -------
        model : `torch.nn.Module`
            The model architecture. The exact type of this object
            is model-specific.

        Raises
        ------
        AttributeError
            If the adapter does not define ``model_filename``.

        See Also
        --------
        load_weights
        """
        model = utils.import_model(self.model_filename).to(device)
        return model

    def load_weights(self, device):
        """
        Load and return a checkpoint of a neural network model.

        Parameters
        ----------
        device : `torch.device`
            Device to load the pretrained weights on.

        Returns
        -------
        network_data : `dict`
            Dictionary containing a saved network state in PyTorch format,
            composed of the trained weights, optimizer state, and other
            useful metadata.

        Raises
        ------
        AttributeError
            If the adapter does not define ``checkpoint_filename``.

        See Also
        --------
        load_arch
        """
        # NOTE(review): ``torch.load`` unpickles the checkpoint, so only use
        # it on trusted files. Depending on the installed PyTorch version,
        # ``weights_only`` may default to True, which restricts what the
        # checkpoint is allowed to contain.
        network_data = torch.load(self.checkpoint_filename, map_location=device)
        return network_data

    def load_metadata(self):
        """
        Load and return the metadata associated with the model package.

        Returns
        -------
        metadata : `dict`
            Dictionary containing the metadata associated with the model.

        Raises
        ------
        AttributeError
            If the adapter does not define ``metadata_filename``.
        """
        # YAML streams default to UTF-8; decode explicitly rather than with
        # the platform's locale-dependent default encoding.
        with open(self.metadata_filename, 'r', encoding='utf-8') as f:
            metadata = yaml.safe_load(f)
        return metadata