Coverage for python/lsst/meas/transiNet/modelPackages/storageAdapterButler.py: 21%

91 statements  

coverage.py v7.13.5, created at 2026-04-24 08:27 +0000

from .storageAdapterBase import StorageAdapterBase
from lsst.meas.transiNet.modelPackages.formatters import NNModelPackagePayload
from lsst.daf.butler import DatasetType
from . import utils

import torch
import zipfile
import io
import yaml

__all__ = ["StorageAdapterButler"]


class StorageAdapterButler(StorageAdapterBase):
    """An adapter for interfacing with butler model packages.

    In this mode, all components of a model package are stored in
    a Butler repository.

    Parameters
    ----------
    model_package_name : `str`
        The name of the ModelPackage to be loaded.
    butler : `lsst.daf.butler.Butler`
        The butler instance used for loading the model package.
        This is used in the "offline" mode, where the model package
        is not preloaded, but is fetched from the butler repository
        manually.
    butler_loaded_package : `io.BytesIO`
        The package pre-loaded by the graph builder.
        This is a data blob representing a `pretrainedModelPackage` dataset
        directly loaded from the butler repository.
        It is only set when we are in the "online" mode of functionality.
    """

    dataset_type_name = 'pretrainedModelPackage'
    packages_parent_collection = 'pretrained_models'

    def __init__(self, model_package_name, butler=None, butler_loaded_package=None):
        super().__init__(model_package_name)

        self.model_package_name = model_package_name
        self.butler = butler

        self.model_file = self.checkpoint_file = self.metadata_file = None

        # butler and butler_loaded_package are mutually exclusive.
        if butler is not None and butler_loaded_package is not None:
            raise ValueError('butler and butler_loaded_package are mutually exclusive')

        # Use the butler_loaded_package if it is provided.
        if butler_loaded_package is not None:
            self.from_payload(butler_loaded_package)

        # If the butler is provided, we are in the "offline" mode. Let's go
        # and fetch the model package from the butler repository.
        if butler is not None:
            self.fetch()
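        # The two construction modes, illustrated with a hypothetical
        # package name (comment sketch, not executed):
        #   StorageAdapterButler('rbResnet50-DC2', butler=butler)
        #       -> "offline" mode; the package is fetched from the repository now.
        #   StorageAdapterButler('rbResnet50-DC2', butler_loaded_package=payload)
        #       -> "online" mode; the payload preloaded by the graph builder is used.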

    @classmethod
    def from_other(cls, other, use_name=None):
        """
        Create a new instance of this class from another instance, which
        can be of a different mode.

        Parameters
        ----------
        other : `StorageAdapterBase`
            The instance to create a new instance from.
        use_name : `str`, optional
            The name to use for the new instance. If `None`, the name of
            `other` is reused.
        """

        instance = cls(model_package_name=use_name or other.model_package_name)

        if hasattr(other, 'model_file'):
            instance.model_file = other.model_file
            instance.checkpoint_file = other.checkpoint_file
            instance.metadata_file = other.metadata_file
        else:
            with open(other.model_filename, mode="rb") as f:
                instance.model_file = io.BytesIO(f.read())
            with open(other.checkpoint_filename, mode="rb") as f:
                instance.checkpoint_file = io.BytesIO(f.read())
            with open(other.metadata_filename, mode="rb") as f:
                instance.metadata_file = io.BytesIO(f.read())

        return instance

    def from_payload(self, payload):
        """
        Decompress the payload into memory and save each component
        as an in-memory file.

        Parameters
        ----------
        payload : `NNModelPackagePayload`
            The payload to create the instance from.
        """
        with zipfile.ZipFile(payload.bytes, mode="r") as zf:
            with zf.open('checkpoint') as f:
                self.checkpoint_file = io.BytesIO(f.read())
            with zf.open('architecture') as f:
                self.model_file = io.BytesIO(f.read())
            with zf.open('metadata') as f:
                self.metadata_file = io.BytesIO(f.read())

    def to_payload(self):
        """
        Compress the model package into a payload.

        Returns
        -------
        payload : `NNModelPackagePayload`
            The payload containing the compressed model package.
        """

        payload = NNModelPackagePayload()

        with zipfile.ZipFile(payload.bytes, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
            zf.writestr('checkpoint', self.checkpoint_file.read())
            zf.writestr('architecture', self.model_file.read())
            zf.writestr('metadata', self.metadata_file.read())

        return payload

    def fetch(self):
        """Fetch the model package from the butler repository, decompress it
        into memory and save each component as an in-memory file.

        If the components are already loaded (the "normal" case, in which a
        preloaded package was passed to the constructor), there is nothing to
        fetch and this method returns immediately.
        """

        # If we have already loaded the package, there's nothing left to do.
        if self.model_file is not None:
            return

        # Fetching needs a butler object.
        if self.butler is None:
            raise ValueError('The `butler` object is required for fetching the model package')

        # Fetch the model package from the butler repository.
        results = self.butler.registry.queryDatasets(StorageAdapterButler.dataset_type_name,
                                                     collections=f'{StorageAdapterButler.packages_parent_collection}/{self.model_package_name}')  # noqa: E501
        payload = self.butler.get(list(results)[0])
        self.from_payload(payload)

    def load_arch(self, device):
        """
        Load and return the model architecture
        (no loading of pre-trained weights).

        Parameters
        ----------
        device : `torch.device`
            Device to load the model on.

        Returns
        -------
        model : `torch.nn.Module`
            The model architecture. The exact type of this object
            is model-specific.

        See Also
        --------
        load_weights
        """

        module = utils.load_module_from_memory(self.model_file)
        model = utils.import_model_from_module(module).to(device)
        return model

    def load_weights(self, device):
        """
        Load and return a checkpoint of a neural network model.

        Parameters
        ----------
        device : `torch.device`
            Device to load the pretrained weights on.
            Only loading on the CPU is supported in this (Butler) mode.

        Returns
        -------
        network_data : `dict`
            Dictionary containing a saved network state in PyTorch format,
            composed of the trained weights, optimizer state, and other
            useful metadata.

        See Also
        --------
        load_arch
        """
        if device != 'cpu':
            raise RuntimeError('storageAdapterButler only supports loading on CPU')
        network_data = torch.load(self.checkpoint_file, map_location=device,
                                  weights_only=True)
        return network_data

    def load_metadata(self):
        """
        Load and return the metadata associated with the model package.

        Returns
        -------
        metadata : `dict`
            Dictionary containing the metadata associated with the model.
        """

        metadata = yaml.safe_load(self.metadata_file)
        return metadata

    @staticmethod
    def ingest(model_package, butler, model_package_name=None):
        """
        Ingest a model package into the butler repository.

        Parameters
        ----------
        model_package : `nnModelPackage`
            The model package to be ingested.
        butler : `lsst.daf.butler.Butler`
            The butler instance to use for ingesting.
        model_package_name : `str`, optional
            The name under which to ingest the model package. If `None`,
            the name of `model_package` is used.
        """

        # Check if the input model package is of a proper type.
        if isinstance(model_package.adapter, StorageAdapterButler):
            raise ValueError('The input model package cannot be of the butler type')

        # Choose the name of the model package to be ingested.
        if model_package_name is None:
            the_name = model_package.model_package_name
        else:
            the_name = model_package_name

        # Create the destination run collection.
        run_collection = f"{StorageAdapterButler.packages_parent_collection}/{the_name}"
        butler.registry.registerRun(run_collection)

        # Create the dataset type (and register it, just in case).
        data_id = {}
        dataset_type = DatasetType(StorageAdapterButler.dataset_type_name,
                                   dimensions=[],
                                   storageClass="NNModelPackagePayload",
                                   universe=butler.registry.dimensions)

        # Register the dataset type.
        def register_dataset_type(butler, dataset_type_name, dataset_type):
            try:  # Do nothing if the dataset type is already registered
                butler.registry.getDatasetType(dataset_type_name)
            except KeyError:
                butler.registry.registerDatasetType(dataset_type)
        register_dataset_type(butler, StorageAdapterButler.dataset_type_name, dataset_type)

        # Create an instance of StorageAdapterButler, and ingest its payload.
        payload = StorageAdapterButler.from_other(model_package.adapter).to_payload()
        butler.put(payload,
                   dataset_type,
                   data_id,
                   run=run_collection)
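
For reference, a minimal usage sketch of the "offline" mode follows. It is illustrative only: the repository path and the package name are hypothetical, and it assumes the package was previously ingested (for example with StorageAdapterButler.ingest) into the corresponding pretrained_models/<name> run collection.

from lsst.daf.butler import Butler
from lsst.meas.transiNet.modelPackages.storageAdapterButler import StorageAdapterButler

# Hypothetical repository and package name.
butler = Butler("/repo/main")

# "Offline" mode: the constructor fetches and unpacks the payload immediately.
adapter = StorageAdapterButler("rbResnet50-DC2", butler=butler)

model = adapter.load_arch(device="cpu")            # bare architecture, no weights
network_data = adapter.load_weights(device="cpu")  # checkpoint dict (contents are model-specific)
metadata = adapter.load_metadata()                 # metadata parsed from YAML

Ingestion runs the same machinery in reverse: StorageAdapterButler.ingest converts a package held in another storage mode via from_other and to_payload, then writes the resulting payload with butler.put into the package's run collection.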