Coverage for python/lsst/meas/transiNet/modelPackages/storageAdapterButler.py: 21% (91 statements)

import io
import zipfile

import torch
import yaml

from lsst.daf.butler import DatasetType
from lsst.meas.transiNet.modelPackages.formatters import NNModelPackagePayload

from . import utils
from .storageAdapterBase import StorageAdapterBase

__all__ = ["StorageAdapterButler"]


class StorageAdapterButler(StorageAdapterBase):
    """An adapter for interfacing with Butler model packages.

    In this mode, all components of a model package are stored in a
    Butler repository.

    Parameters
    ----------
    model_package_name : `str`
        The name of the model package to be loaded.
    butler : `lsst.daf.butler.Butler`, optional
        The butler instance used for loading the model package.
        This is used in the "offline" mode, where the model package
        is not preloaded but is fetched from the butler repository
        manually.
    butler_loaded_package : `NNModelPackagePayload`, optional
        The package pre-loaded by the graph builder: a payload wrapping
        a `pretrainedModelPackage` dataset loaded directly from the
        butler repository. It is only set in the "online" mode of
        operation.
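
    Examples
    --------
    A minimal sketch of the two construction modes; the package name
    and repository path are hypothetical, and ``payload`` stands for a
    `NNModelPackagePayload` previously loaded by the graph builder:

    >>> from lsst.daf.butler import Butler
    >>> offline = StorageAdapterButler('dummy_package',
    ...                                butler=Butler('/repo/main'))  # doctest: +SKIP
    >>> online = StorageAdapterButler('dummy_package',
    ...                               butler_loaded_package=payload)  # doctest: +SKIP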

34 """ 

35 

    dataset_type_name = 'pretrainedModelPackage'
    packages_parent_collection = 'pretrained_models'

    def __init__(self, model_package_name, butler=None, butler_loaded_package=None):
        super().__init__(model_package_name)

        self.model_package_name = model_package_name
        self.butler = butler

        self.model_file = self.checkpoint_file = self.metadata_file = None

        # butler and butler_loaded_package are mutually exclusive.
        if butler is not None and butler_loaded_package is not None:
            raise ValueError('butler and butler_loaded_package are mutually exclusive')

        # Use the butler_loaded_package if it is provided ("online" mode).
        if butler_loaded_package is not None:
            self.from_payload(butler_loaded_package)

        # If the butler is provided, we are in the "offline" mode: go
        # and fetch the model package from the butler repository.
        if butler is not None:
            self.fetch()

    @classmethod
    def from_other(cls, other, use_name=None):
        """
        Create a new instance of this class from another instance,
        which can be of a different mode.

        Parameters
        ----------
        other : `StorageAdapterBase`
            The instance to create a new instance from.
        use_name : `str`, optional
            If provided, the name to use for the new instance instead
            of ``other.model_package_name``.

        Returns
        -------
        instance : `StorageAdapterButler`
            The newly created instance.
        """
        instance = cls(model_package_name=use_name or other.model_package_name)

        # If the other adapter already holds in-memory files, reuse them;
        # otherwise read the components from disk into in-memory files.
        if hasattr(other, 'model_file'):
            instance.model_file = other.model_file
            instance.checkpoint_file = other.checkpoint_file
            instance.metadata_file = other.metadata_file
        else:
            with open(other.model_filename, mode="rb") as f:
                instance.model_file = io.BytesIO(f.read())
            with open(other.checkpoint_filename, mode="rb") as f:
                instance.checkpoint_file = io.BytesIO(f.read())
            with open(other.metadata_filename, mode="rb") as f:
                instance.metadata_file = io.BytesIO(f.read())

        return instance

    def from_payload(self, payload):
        """
        Decompress the payload into memory and store each component
        as an in-memory file.

        Parameters
        ----------
        payload : `NNModelPackagePayload`
            The payload to create the instance from.
        """
        with zipfile.ZipFile(payload.bytes, mode="r") as zf:
            with zf.open('checkpoint') as f:
                self.checkpoint_file = io.BytesIO(f.read())
            with zf.open('architecture') as f:
                self.model_file = io.BytesIO(f.read())
            with zf.open('metadata') as f:
                self.metadata_file = io.BytesIO(f.read())

    def to_payload(self):
        """
        Compress the model package into a payload.

        Returns
        -------
        payload : `NNModelPackagePayload`
            The payload containing the compressed model package.
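
        Examples
        --------
        A round trip through the payload format (sketch; the package
        name is hypothetical and the adapter is assumed populated):

        >>> payload = adapter.to_payload()  # doctest: +SKIP
        >>> restored = StorageAdapterButler('dummy_package')  # doctest: +SKIP
        >>> restored.from_payload(payload)  # doctest: +SKIP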

115 """ 

116 

        payload = NNModelPackagePayload()

        with zipfile.ZipFile(payload.bytes, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
            zf.writestr('checkpoint', self.checkpoint_file.read())
            zf.writestr('architecture', self.model_file.read())
            zf.writestr('metadata', self.metadata_file.read())

        return payload

    def fetch(self):
        """Fetch the model package from the butler repository, decompress
        it into memory, and store each component as an in-memory file.

        If the package components were already loaded (the "online"
        mode, which is the normal case), fetching is already done and
        this method returns immediately.
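
        Examples
        --------
        An explicit offline-mode fetch (sketch; the repository path and
        package name are hypothetical):

        >>> adapter = StorageAdapterButler('dummy_package',
        ...                                butler=Butler('/repo/main'))  # doctest: +SKIP
        >>> adapter.fetch()  # doctest: +SKIP

        Here the second call is a no-op, since the constructor has
        already triggered a fetch.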

132 """ 

133 

        # If we have already loaded the package, there's nothing left to do.
        if self.model_file is not None:
            return

        # Fetching needs a butler object.
        if self.butler is None:
            raise ValueError('The `butler` object is required for fetching the model package')

        # Fetch the model package from the butler repository.
        results = self.butler.registry.queryDatasets(
            StorageAdapterButler.dataset_type_name,
            collections=f'{StorageAdapterButler.packages_parent_collection}/{self.model_package_name}')
        payload = self.butler.get(list(results)[0])
        self.from_payload(payload)

    def load_arch(self, device):
        """
        Load and return the model architecture, without loading any
        pre-trained weights.

        Parameters
        ----------
        device : `torch.device`
            Device to load the model on.

        Returns
        -------
        model : `torch.nn.Module`
            The model architecture. The exact type of this object
            is model-specific.

        See Also
        --------
        load_weights
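
        Examples
        --------
        Instantiating the architecture and then fetching its checkpoint
        (sketch; the checkpoint's key layout is model-specific):

        >>> model = adapter.load_arch(device='cpu')  # doctest: +SKIP
        >>> network_data = adapter.load_weights(device='cpu')  # doctest: +SKIP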

167 """ 

168 

        module = utils.load_module_from_memory(self.model_file)
        model = utils.import_model_from_module(module).to(device)
        return model

    def load_weights(self, device):
        """
        Load and return a checkpoint of a neural network model.

        Parameters
        ----------
        device : `torch.device`
            Device to load the pretrained weights on. Only loading on
            the CPU is supported in this (Butler) mode.

        Returns
        -------
        network_data : `dict`
            Dictionary containing a saved network state in PyTorch
            format, composed of the trained weights, optimizer state,
            and other useful metadata.

        See Also
        --------
        load_arch
        """
        if device != 'cpu':
            raise RuntimeError('StorageAdapterButler only supports loading on CPU')
        network_data = torch.load(self.checkpoint_file, map_location=device)
        return network_data

    def load_metadata(self):
        """
        Load and return the metadata associated with the model package.

        Returns
        -------
        metadata : `dict`
            Dictionary containing the metadata associated with the model.
        """
        metadata = yaml.safe_load(self.metadata_file)
        return metadata

    @staticmethod
    def ingest(model_package, butler, model_package_name=None):
        """
        Ingest a model package into the butler repository.

        Parameters
        ----------
        model_package : `NNModelPackage`
            The model package to be ingested. It must not already be of
            the Butler type.
        butler : `lsst.daf.butler.Butler`
            The butler instance to use for ingesting.
        model_package_name : `str`, optional
            The name under which to ingest the model package. Defaults
            to the name of ``model_package`` itself.
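
        Examples
        --------
        Ingesting a locally stored package (a sketch; the repository
        path, the package names, and the `NNModelPackage` constructor
        arguments shown are illustrative):

        >>> butler = Butler('/repo/main', writeable=True)  # doctest: +SKIP
        >>> pkg = NNModelPackage('dummy_package', 'local')  # doctest: +SKIP
        >>> StorageAdapterButler.ingest(pkg, butler,
        ...                             model_package_name='dummy_package')  # doctest: +SKIP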

226 """ 

227 

        # Check that the input model package is of a proper type.
        if isinstance(model_package.adapter, StorageAdapterButler):
            raise ValueError('The input model package cannot be of the butler type')

        # Choose the name under which the model package is ingested.
        if model_package_name is None:
            the_name = model_package.model_package_name
        else:
            the_name = model_package_name

        # Create the destination run collection.
        run_collection = f"{StorageAdapterButler.packages_parent_collection}/{the_name}"
        butler.registry.registerRun(run_collection)

        # Create the dataset type.
        data_id = {}
        dataset_type = DatasetType(StorageAdapterButler.dataset_type_name,
                                   dimensions=[],
                                   storageClass="NNModelPackagePayload",
                                   universe=butler.registry.dimensions)

        # Register the dataset type, unless it is already registered.
        try:
            butler.registry.getDatasetType(StorageAdapterButler.dataset_type_name)
        except KeyError:
            butler.registry.registerDatasetType(dataset_type)

        # Create a Butler-mode adapter from the input package and ingest
        # its compressed payload.
        payload = StorageAdapterButler.from_other(model_package.adapter).to_payload()
        butler.put(payload,
                   dataset_type,
                   data_id,
                   run=run_collection)