Coverage for python/lsst/meas/transiNet/modelPackages/storageAdapterLocal.py: 28%
31 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-01 12:25 +0000
1import os
2import glob
4from .storageAdapterBase import StorageAdapterBase
# Public API of this module: only the local-mode storage adapter is exported.
__all__ = [
    "StorageAdapterLocal",
]
class StorageAdapterLocal(StorageAdapterBase):
    """An adapter for interfacing with ModelPackages stored in the
    'local' mode.

    Local mode means both the code and pretrained weights reside in
    the same repository as that of rbTransiNetInterface.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package. It is joined onto the path returned by
        `get_base_path` (via ``self.model_package_name``, which the base
        class is assumed to set) to locate the package directory.
    """
    def __init__(self, model_package_name):
        super().__init__(model_package_name)

        # ``fetch`` is not defined in this class (presumably provided by
        # StorageAdapterBase); it runs before the file lookup below.
        self.fetch()
        self.model_filename, self.checkpoint_filename, self.metadata_filename = self.get_filenames()

    @staticmethod
    def get_base_path():
        """
        Return the base model packages storage path for this mode.

        Returns
        -------
        `str`
            The base path to the model packages storage, i.e.
            ``$MEAS_TRANSINET_DIR/model_packages``.

        Raises
        ------
        RuntimeError
            If the ``MEAS_TRANSINET_DIR`` environment variable is not set.
        """
        try:
            base_path = os.environ['MEAS_TRANSINET_DIR']
        except KeyError:
            # The bare KeyError context adds nothing beyond this message,
            # so suppress it to keep the traceback readable.
            raise RuntimeError("The environment variable MEAS_TRANSINET_DIR is not set.") from None

        return os.path.join(base_path, 'model_packages')

    def get_filenames(self):
        """
        Find and return the paths to the architecture, checkpoint and
        metadata files of this model package.

        We do not assume default file names in the 'local' mode. Instead,
        the package directory must contain exactly one file matching each
        of the patterns ``arch*.py``, ``*.pth.tar`` and ``meta*.yaml``.
        The returned paths are absolute only if ``MEAS_TRANSINET_DIR`` is.

        Returns
        -------
        model_filename : `str`
            The full path to the .py file containing the model architecture.
        checkpoint_filename : `str`
            The full path to the file containing the saved checkpoint.
        metadata_filename : `str`
            The full path to the YAML metadata file.

        Raises
        ------
        RuntimeError
            If ``MEAS_TRANSINET_DIR`` is not set, or if the number of files
            matching any of the patterns is not exactly one. A missing
            model package directory yields zero matches and therefore also
            raises RuntimeError.
        """
        dir_name = os.path.join(self.get_base_path(),
                                self.model_package_name)

        # Checked in this order, so the first reported error is for the
        # model file, then the checkpoint, then the metadata.
        model_filename = self._find_unique_file(dir_name, 'arch*.py', 'model')
        checkpoint_filename = self._find_unique_file(dir_name, '*.pth.tar', 'checkpoint')
        metadata_filename = self._find_unique_file(dir_name, 'meta*.yaml', 'metadata')

        return model_filename, checkpoint_filename, metadata_filename

    @staticmethod
    def _find_unique_file(dir_name, pattern, label):
        """Return the single file in ``dir_name`` matching ``pattern``.

        Parameters
        ----------
        dir_name : `str`
            Directory to search (matched literally, not as a pattern).
        pattern : `str`
            Glob pattern for the file name, relative to ``dir_name``.
        label : `str`
            Human-readable category used in the error message.

        Returns
        -------
        `str`
            Path of the unique matching file.

        Raises
        ------
        RuntimeError
            If zero or more than one file matches.
        """
        # Escape the directory part so glob metacharacters ('*', '?', '[')
        # in the path are matched literally; only ``pattern`` is a wildcard.
        matches = glob.glob(os.path.join(glob.escape(dir_name), pattern))
        if len(matches) != 1:
            raise RuntimeError(f"Found {len(matches)} {label} files, "
                               f"expected 1 in {dir_name}.")
        return matches[0]