Coverage for python/lsst/meas/transiNet/modelPackages/storageAdapterNeighbor.py: 24%
31 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-06-06 10:39 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-06-06 10:39 +0000
1import os
2import glob
4from .storageAdapterBase import StorageAdapterBase
# Public API of this module: only the neighbor-mode storage adapter is exported.
__all__ = ["StorageAdapterNeighbor"]
class StorageAdapterNeighbor(StorageAdapterBase):
    """An adapter for interfacing with ModelPackages stored in the
    'neighbor' mode.

    Neighbor mode means both the code and pretrained weights
    reside in the neighbor Git repository, namely rbClassifier_data.
    Each model package is assumed to be a directory under the
    "model_packages" folder, in the root of that repository. The
    repository root is taken from the ``RBCLASSIFIER_DATA_DIR``
    environment variable.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package, i.e. the name of its directory under
        ``$RBCLASSIFIER_DATA_DIR/model_packages``.
    """

    def __init__(self, model_package_name):
        super().__init__(model_package_name)

        # fetch() is inherited from StorageAdapterBase; it must run before
        # the package files are located below.
        self.fetch()
        self.model_filename, self.checkpoint_filename, self.metadata_filename = self.get_filenames()

    @staticmethod
    def get_base_path():
        """Return the base model packages storage path for this mode.

        Returns
        -------
        base_path : `str`
            The base path to the model packages storage, i.e.
            ``$RBCLASSIFIER_DATA_DIR/model_packages``.

        Raises
        ------
        RuntimeError
            If the ``RBCLASSIFIER_DATA_DIR`` environment variable is not set.
        """
        try:
            base_path = os.environ['RBCLASSIFIER_DATA_DIR']
        except KeyError:
            # Translate to the RuntimeError callers already expect, without
            # attaching the uninformative KeyError traceback.
            raise RuntimeError("The environment variable RBCLASSIFIER_DATA_DIR is not set.") from None

        return os.path.join(base_path, 'model_packages')

    def get_filenames(self):
        """Find and return paths to the architecture, checkpoint and
        metadata files of the model package.

        We do not assume default file names in the 'neighbor' mode.
        Instead, a pattern-matching approach is used: the package
        directory must contain exactly one file matching each of

        - ``arch*.py`` (model architecture),
        - ``*.pth.tar`` (saved checkpoint),
        - ``meta*.yaml`` (model metadata).

        Returns
        -------
        model_filename : `str`
            The full path to the .py file containing the model architecture.
        checkpoint_filename : `str`
            The full path to the file containing the saved checkpoint.
        metadata_filename : `str`
            The full path to the YAML file containing the model metadata.

        The paths are absolute only if ``RBCLASSIFIER_DATA_DIR`` is.

        Raises
        ------
        RuntimeError
            If ``RBCLASSIFIER_DATA_DIR`` is not set, or if the number of
            files matching any of the patterns above is not exactly one.
            Zero matches are also what a missing model package directory
            produces.
        """
        dir_name = os.path.join(self.get_base_path(), self.model_package_name)

        # Checked in this order so the first reported problem is the same
        # as before: model, then checkpoint, then metadata.
        model_filename = self._find_unique_file(dir_name, 'arch*.py', 'model')
        checkpoint_filename = self._find_unique_file(dir_name, '*.pth.tar', 'checkpoint')
        metadata_filename = self._find_unique_file(dir_name, 'meta*.yaml', 'metadata')

        return model_filename, checkpoint_filename, metadata_filename

    @staticmethod
    def _find_unique_file(dir_name, pattern, kind):
        """Return the single file directly under ``dir_name`` matching
        the glob ``pattern``.

        Parameters
        ----------
        dir_name : `str`
            Directory to search (not recursive).
        pattern : `str`
            Glob pattern for the file's base name.
        kind : `str`
            Human-readable file category, used in the error message.

        Returns
        -------
        filename : `str`
            Path to the unique matching file.

        Raises
        ------
        RuntimeError
            If zero or more than one file matches.
        """
        # Escape the directory part so that glob metacharacters ('[', '*',
        # '?') in the data path or package name are matched literally;
        # only ``pattern`` is meant to act as a wildcard.
        matches = glob.glob(os.path.join(glob.escape(dir_name), pattern))
        if len(matches) != 1:
            raise RuntimeError(f"Found {len(matches)} {kind} files, "
                               f"expected 1 in {dir_name}.")
        return matches[0]