Coverage for python/lsst/meas/transiNet/modelPackages/storageAdapterButler.py: 21% (91 statements)
from .storageAdapterBase import StorageAdapterBase
from lsst.meas.transiNet.modelPackages.formatters import NNModelPackagePayload
from lsst.daf.butler import DatasetType
from . import utils

import torch
import zipfile
import io
import yaml

__all__ = ["StorageAdapterButler"]


class StorageAdapterButler(StorageAdapterBase):
    """An adapter for interfacing with butler model packages.

    In this mode, all components of a model package are stored in
    a Butler repository.

    Parameters
    ----------
    model_package_name : `str`
        The name of the ModelPackage to be loaded.
    butler : `lsst.daf.butler.Butler`, optional
        The butler instance used for loading the model package.
        This is used in the "offline" mode, where the model package
        is not preloaded but is fetched from the butler repository
        manually.
    butler_loaded_package : `io.BytesIO`, optional
        The package pre-loaded by the graph builder.
        This is a data blob representing a `pretrainedModelPackage`
        dataset directly loaded from the butler repository.
        It is only set when we are in the "online" mode of operation.
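
    Examples
    --------
    A minimal sketch of the "offline" mode, assuming a butler repository
    that already contains an ingested model package; the repository path
    and package name are hypothetical placeholders.

    >>> from lsst.daf.butler import Butler
    >>> butler = Butler("/repo/example")
    >>> adapter = StorageAdapterButler("my_model", butler=butler)
    >>> model = adapter.load_arch(device="cpu")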
34 """

    dataset_type_name = 'pretrainedModelPackage'
    packages_parent_collection = 'pretrained_models'

    def __init__(self, model_package_name, butler=None, butler_loaded_package=None):
        super().__init__(model_package_name)

        self.model_package_name = model_package_name
        self.butler = butler

        self.model_file = self.checkpoint_file = self.metadata_file = None

        # butler and butler_loaded_package are mutually exclusive.
        if butler is not None and butler_loaded_package is not None:
            raise ValueError('butler and butler_loaded_package are mutually exclusive')

        # Use the butler_loaded_package if it is provided ("online" mode).
        if butler_loaded_package is not None:
            self.from_payload(butler_loaded_package)

        # If the butler is provided, we are in the "offline" mode. Go and
        # fetch the model package from the butler repository.
        if butler is not None:
            self.fetch()

    @classmethod
    def from_other(cls, other, use_name=None):
        """Create a new instance of this class from another instance,
        which can be of a different mode.

        Parameters
        ----------
        other : `StorageAdapterBase`
            The instance to create the new instance from.
        use_name : `str`, optional
            Name to use for the new instance instead of
            ``other.model_package_name``.

        Returns
        -------
        instance : `StorageAdapterButler`
            The newly created instance.
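
        Examples
        --------
        A hedged sketch: ``local_adapter`` stands for any non-butler
        adapter instance (for example one backed by files on disk) and the
        package name is a hypothetical placeholder.

        >>> adapter = StorageAdapterButler.from_other(
        ...     local_adapter, use_name="my_model")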
70 """

        instance = cls(model_package_name=use_name or other.model_package_name)

        # Adapters that already hold their components as in-memory files
        # can be copied directly; others store them on disk and need to be
        # read in.
        if hasattr(other, 'model_file'):
            instance.model_file = other.model_file
            instance.checkpoint_file = other.checkpoint_file
            instance.metadata_file = other.metadata_file
        else:
            with open(other.model_filename, mode="rb") as f:
                instance.model_file = io.BytesIO(f.read())
            with open(other.checkpoint_filename, mode="rb") as f:
                instance.checkpoint_file = io.BytesIO(f.read())
            with open(other.metadata_filename, mode="rb") as f:
                instance.metadata_file = io.BytesIO(f.read())

        return instance

    def from_payload(self, payload):
        """Decompress the payload into memory and save each component
        as an in-memory file.

        Parameters
        ----------
        payload : `NNModelPackagePayload`
            The payload to create the instance from.
        """
        with zipfile.ZipFile(payload.bytes, mode="r") as zf:
            with zf.open('checkpoint') as f:
                self.checkpoint_file = io.BytesIO(f.read())
            with zf.open('architecture') as f:
                self.model_file = io.BytesIO(f.read())
            with zf.open('metadata') as f:
                self.metadata_file = io.BytesIO(f.read())

    def to_payload(self):
        """Compress the model package into a payload.

        Returns
        -------
        payload : `NNModelPackagePayload`
            The payload containing the compressed model package.
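
        Examples
        --------
        A round-trip sketch: a payload produced here can be passed back to
        `from_payload` to reconstruct the in-memory components.

        >>> payload = adapter.to_payload()
        >>> adapter.from_payload(payload)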
115 """

        payload = NNModelPackagePayload()

        with zipfile.ZipFile(payload.bytes, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
            zf.writestr('checkpoint', self.checkpoint_file.read())
            zf.writestr('architecture', self.model_file.read())
            zf.writestr('metadata', self.metadata_file.read())

        return payload

    def fetch(self):
        """Fetch the model package from the butler repository, decompress
        it into memory and save each component as an in-memory file.

        If the components are already loaded (e.g. the package was passed
        in as ``butler_loaded_package`` at construction time, which is the
        "normal" case), there is nothing to fetch.
        """
        # If we have already loaded the package, there's nothing left to do.
        if self.model_file is not None:
            return

        # Fetching needs a butler object.
        if self.butler is None:
            raise ValueError('The `butler` object is required for fetching the model package')

        # Fetch the model package from the butler repository.
        results = self.butler.registry.queryDatasets(StorageAdapterButler.dataset_type_name,
                                                     collections=f'{StorageAdapterButler.packages_parent_collection}/{self.model_package_name}')  # noqa: E501
        payload = self.butler.get(list(results)[0])
        self.from_payload(payload)

    def load_arch(self, device):
        """Load and return the model architecture
        (no loading of pre-trained weights).

        Parameters
        ----------
        device : `torch.device`
            Device to load the model on.

        Returns
        -------
        model : `torch.nn.Module`
            The model architecture. The exact type of this object
            is model-specific.

        See Also
        --------
        load_weights
        """

        module = utils.load_module_from_memory(self.model_file)
        model = utils.import_model_from_module(module).to(device)
        return model

    def load_weights(self, device):
        """Load and return a checkpoint of a neural network model.

        Parameters
        ----------
        device : `torch.device`
            Device to load the pretrained weights on.
            Only loading on the CPU is supported in this (Butler) mode.

        Returns
        -------
        network_data : `dict`
            Dictionary containing a saved network state in PyTorch format,
            composed of the trained weights, optimizer state, and other
            useful metadata.

        See Also
        --------
        load_arch
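
        Examples
        --------
        A sketch of the usual pairing with `load_arch`; the key that holds
        the weights is model-specific, so ``'state_dict'`` here is a
        hypothetical example.

        >>> model = adapter.load_arch(device="cpu")
        >>> network_data = adapter.load_weights(device="cpu")
        >>> model.load_state_dict(network_data['state_dict'])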
194 """
195 if device != 'cpu':
196 raise RuntimeError('storageAdapterButler only supports loading on CPU')
197 network_data = torch.load(self.checkpoint_file, map_location=device)
198 return network_data

    def load_metadata(self):
        """Load and return the metadata associated with the model package.

        Returns
        -------
        metadata : `dict`
            Dictionary containing the metadata associated with the model.
        """

        metadata = yaml.safe_load(self.metadata_file)
        return metadata

    @staticmethod
    def ingest(model_package, butler, model_package_name=None):
        """Ingest a model package into the butler repository.

        Parameters
        ----------
        model_package : `NNModelPackage`
            The model package to be ingested.
        butler : `lsst.daf.butler.Butler`
            The butler instance to use for ingesting.
        model_package_name : `str`, optional
            The name to ingest the model package under. If `None`, the
            name of the input model package is used.
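
        Examples
        --------
        A hedged sketch: ``package`` stands for a model package backed by a
        non-butler adapter, and the repository path and name are
        hypothetical placeholders.

        >>> from lsst.daf.butler import Butler
        >>> butler = Butler("/repo/example", writeable=True)
        >>> StorageAdapterButler.ingest(package, butler,
        ...                             model_package_name="my_model")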
226 """

        # Check that the input model package is of a proper type.
        if isinstance(model_package.adapter, StorageAdapterButler):
            raise ValueError('The input model package cannot be of the butler type')

        # Choose the name of the model package to be ingested.
        if model_package_name is None:
            the_name = model_package.model_package_name
        else:
            the_name = model_package_name

        # Create the destination run collection.
        run_collection = f"{StorageAdapterButler.packages_parent_collection}/{the_name}"
        butler.registry.registerRun(run_collection)

        # Create the dataset type (and register it, just in case).
        data_id = {}
        dataset_type = DatasetType(StorageAdapterButler.dataset_type_name,
                                   dimensions=[],
                                   storageClass="NNModelPackagePayload",
                                   universe=butler.registry.dimensions)

        # Register the dataset type.
        def register_dataset_type(butler, dataset_type_name, dataset_type):
            try:  # Do nothing if the dataset type is already registered.
                butler.registry.getDatasetType(dataset_type_name)
            except KeyError:
                butler.registry.registerDatasetType(dataset_type)
        register_dataset_type(butler, StorageAdapterButler.dataset_type_name, dataset_type)

        # Create an instance of StorageAdapterButler, and ingest its payload.
        payload = StorageAdapterButler.from_other(model_package.adapter).to_payload()
        butler.put(payload,
                   dataset_type,
                   data_id,
                   run=run_collection)