Coverage for python / lsst / meas / transiNet / modelPackages / storageAdapterButler.py: 21%
91 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 09:09 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 09:09 +0000
1from .storageAdapterBase import StorageAdapterBase
2from lsst.meas.transiNet.modelPackages.formatters import NNModelPackagePayload
3from lsst.daf.butler import DatasetType
4from . import utils
6import torch
7import zipfile
8import io
9import yaml
11__all__ = ["StorageAdapterButler"]
class StorageAdapterButler(StorageAdapterBase):
    """An adapter for interfacing with Butler model packages.

    In this mode, all components of a model package are stored in a
    Butler repository.

    Parameters
    ----------
    model_package_name : `str`
        The name of the ModelPackage to be loaded.
    butler : `lsst.daf.butler.Butler`, optional
        The butler instance used for loading the model package.
        This is used in the "offline" mode, where the model package
        is not preloaded, but is fetched from the butler repository
        manually.
    butler_loaded_package : `io.BytesIO`, optional
        The package pre-loaded by the graph builder.
        This is a data blob representing a `pretrainedModelPackage` dataset
        directly loaded from the butler repository.
        It is only set when we are in the "online" mode of functionality.
    """

    # Name of the Butler dataset type under which model packages are stored.
    dataset_type_name = 'pretrainedModelPackage'
    # Parent collection; each package lives in a run collection named
    # f"{packages_parent_collection}/{model_package_name}".
    packages_parent_collection = 'pretrained_models'
39 def __init__(self, model_package_name, butler=None, butler_loaded_package=None):
40 super().__init__(model_package_name)
42 self.model_package_name = model_package_name
43 self.butler = butler
45 self.model_file = self.checkpoint_file = self.metadata_file = None
47 # butler and butler_loaded_package are mutually exclusive.
48 if butler is not None and butler_loaded_package is not None:
49 raise ValueError('butler and butler_loaded_package are mutually exclusive')
51 # Use the butler_loaded_package if it is provided.
52 if butler_loaded_package is not None:
53 self.from_payload(butler_loaded_package)
55 # If the butler is provided, we are in the "offline" mode. Let's go
56 # and fetch the model package from the butler repository.
57 if butler is not None:
58 self.fetch()
60 @classmethod
61 def from_other(cls, other, use_name=None):
62 """
63 Create a new instance of this class from another instance, which
64 can be of a different mode.
66 Parameters
67 ----------
68 other : `StorageAdapterBase`
69 The instance to create a new instance from.
70 """
72 instance = cls(model_package_name=use_name or other.model_package_name)
74 if hasattr(other, 'model_file'):
75 instance.model_file = other.model_file
76 instance.checkpoint_file = other.checkpoint_file
77 instance.metadata_file = other.metadata_file
78 else:
79 with open(other.model_filename, mode="rb") as f:
80 instance.model_file = io.BytesIO(f.read())
81 with open(other.checkpoint_filename, mode="rb") as f:
82 instance.checkpoint_file = io.BytesIO(f.read())
83 with open(other.metadata_filename, mode="rb") as f:
84 instance.metadata_file = io.BytesIO(f.read())
86 return instance
88 def from_payload(self, payload):
89 """
90 Decompress the payload into the memory and save each component
91 as an in-memory file.
93 Parameters
94 ----------
95 payload : `NNModelPackagePayload`
96 The payload to create the instance from.
98 """
99 with zipfile.ZipFile(payload.bytes, mode="r") as zf:
100 with zf.open('checkpoint') as f:
101 self.checkpoint_file = io.BytesIO(f.read())
102 with zf.open('architecture') as f:
103 self.model_file = io.BytesIO(f.read())
104 with zf.open('metadata') as f:
105 self.metadata_file = io.BytesIO(f.read())
107 def to_payload(self):
108 """
109 Compress the model package into a payload.
111 Returns
112 -------
113 payload : `NNModelPackagePayload`
114 The payload containing the compressed model package.
115 """
117 payload = NNModelPackagePayload()
119 with zipfile.ZipFile(payload.bytes, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
120 zf.writestr('checkpoint', self.checkpoint_file.read())
121 zf.writestr('architecture', self.model_file.read())
122 zf.writestr('metadata', self.metadata_file.read())
124 return payload
126 def fetch(self):
127 """Fetch the model package from the butler repository, decompress it
128 into the memory and save each component as an in-memory file.
130 In case self.preloaded_package is not None, the fetching from the
131 butler repository is already done, which is the "normal" case.
132 """
134 # If we have already loaded the package, there's nothing left to do.
135 if self.model_file is not None:
136 return
138 # Fetching needs a butler object.
139 if self.butler is None:
140 raise ValueError('The `butler` object is required for fetching the model package')
142 # Fetch the model package from the butler repository.
143 results = self.butler.registry.queryDatasets(StorageAdapterButler.dataset_type_name,
144 collections=f'{StorageAdapterButler.packages_parent_collection}/{self.model_package_name}') # noqa: E501
145 payload = self.butler.get(list(results)[0])
146 self.from_payload(payload)
148 def load_arch(self, device):
149 """
150 Load and return the model architecture
151 (no loading of pre-trained weights).
153 Parameters
154 ----------
155 device : `torch.device`
156 Device to load the model on.
158 Returns
159 -------
160 model : `torch.nn.Module`
161 The model architecture. The exact type of this object
162 is model-specific.
164 See Also
165 --------
166 load_weights
167 """
169 module = utils.load_module_from_memory(self.model_file)
170 model = utils.import_model_from_module(module).to(device)
171 return model
173 def load_weights(self, device):
174 """
175 Load and return a checkpoint of a neural network model.
177 Parameters
178 ----------
179 device : `torch.device`
180 Device to load the pretrained weights on.
181 Only loading on CPU can be used in this case (Butler mode).
184 Returns
185 -------
186 network_data : `dict`
187 Dictionary containing a saved network state in PyTorch format,
188 composed of the trained weights, optimizer state, and other
189 useful metadata.
191 See Also
192 --------
193 load_arch
194 """
195 if device != 'cpu':
196 raise RuntimeError('storageAdapterButler only supports loading on CPU')
197 network_data = torch.load(self.checkpoint_file, map_location=device,
198 weights_only=True)
199 return network_data
201 def load_metadata(self):
202 """
203 Load and return the metadata associated with the model package.
205 Returns
206 -------
207 metadata : `dict`
208 Dictionary containing the metadata associated with the model.
209 """
211 metadata = yaml.safe_load(self.metadata_file)
212 return metadata
214 @staticmethod
215 def ingest(model_package, butler, model_package_name=None):
216 """
217 Ingest a model package to the butler repository.
219 Parameters
220 ----------
221 model_package : nnModelPackage
222 The model package to be ingested.
223 butler : `lsst.daf.butler.Butler`
224 The butler instance to use for ingesting.
225 model_package_name : `str`, optional
226 The name of the model package to be ingested.
227 """
229 # Check if the input model package is of a proper type.
230 if model_package.adapter is StorageAdapterButler:
231 raise ValueError('The input model package cannot be of the butler type')
233 # Choose the name of the model package to be ingested.
234 if model_package_name is None:
235 the_name = model_package.model_package_name
236 else:
237 the_name = model_package_name
239 # Create the destination run collection.
240 run_collection = f"{StorageAdapterButler.packages_parent_collection}/{the_name}"
241 butler.registry.registerRun(run_collection)
243 # Create the dataset type (and register it, just in case).
244 data_id = {}
245 dataset_type = DatasetType(StorageAdapterButler.dataset_type_name,
246 dimensions=[],
247 storageClass="NNModelPackagePayload",
248 universe=butler.registry.dimensions)
250 # Register the dataset type.
251 def register_dataset_type(butler, dataset_type_name, dataset_type):
252 try: # Do nothing if the dataset type is already registered
253 butler.registry.getDatasetType(dataset_type_name)
254 except KeyError:
255 butler.registry.registerDatasetType(dataset_type)
256 register_dataset_type(butler, StorageAdapterButler.dataset_type_name, dataset_type)
258 # Create an instance of StorageAdapterButler, and ingest its payload.
259 payload = StorageAdapterButler.from_other(model_package.adapter).to_payload()
260 butler.put(payload,
261 dataset_type,
262 data_id,
263 run=run_collection)