Coverage for python/lsst/meas/transiNet/modelPackages/storageAdapterNeighbor.py: 28%
31 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 03:52 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 03:52 -0700
1import os
2import glob
4from .storageAdapterBase import StorageAdapterBase
__all__ = ["StorageAdapterNeighbor"]


class StorageAdapterNeighbor(StorageAdapterBase):
    """An adapter for interfacing with ModelPackages stored in the
    'neighbor' mode.

    Neighbor mode means both the code and pretrained weights
    reside in the neighbor Git repository, namely rbClassifier_data.
    Each model package is assumed to be a directory under the
    "model_packages" folder, in the root of that repository.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package directory under the base path returned
        by `get_base_path`. It is forwarded to the base class, which is
        presumably what exposes it as ``self.model_package_name`` (read by
        `get_filenames`) -- confirm against StorageAdapterBase.
    """
    def __init__(self, model_package_name):
        super().__init__(model_package_name)

        # ``fetch`` is inherited (it is not defined in this class) and is
        # called before the package files are located.
        # NOTE(review): presumably makes the package available locally --
        # confirm against StorageAdapterBase.
        self.fetch()
        self.model_filename, self.checkpoint_filename, self.metadata_filename = self.get_filenames()

    @staticmethod
    def get_base_path():
        """Return the base model packages storage path for this mode.

        Returns
        -------
        base_path : `str`
            The ``model_packages`` folder inside the directory named by the
            ``RBCLASSIFIER_DATA_DIR`` environment variable.

        Raises
        ------
        RuntimeError
            If the ``RBCLASSIFIER_DATA_DIR`` environment variable is not set.
        """
        try:
            base_path = os.environ['RBCLASSIFIER_DATA_DIR']
        except KeyError:
            # The KeyError itself carries no extra information for the user,
            # so suppress it from the traceback.
            raise RuntimeError("Cannot find the lsst-dm/rbClassifier_data package; "
                               "is it downloaded and set up?\n"
                               "See https://pipelines.lsst.io/v/daily/modules/lsst.ap.verify/running.html "
                               "for details on setting up packages from GitHub.") from None

        return os.path.join(base_path, 'model_packages')

    def get_filenames(self):
        """Find and return paths to the architecture, checkpoint and metadata
        files of this model package.

        We do not assume default file names in the 'neighbor' mode. Instead
        the package directory must contain exactly one file matching each of
        these patterns:

        - ``arch*.py``: the model architecture;
        - ``*.pth.tar``: the saved checkpoint;
        - ``meta*.yaml``: the metadata.

        Returns
        -------
        model_filename : `str`
            The full path to the .py file containing the model architecture.
        checkpoint_filename : `str`
            The full path to the file containing the saved checkpoint.
        metadata_filename : `str`
            The full path to the .yaml file containing the model metadata.

        Notes
        -----
        The returned paths are rooted at `get_base_path`, so they are
        absolute only if ``RBCLASSIFIER_DATA_DIR`` is absolute.

        Raises
        ------
        RuntimeError
            If the environment is not set up (see `get_base_path`), or if the
            number of files matching any of the patterns is not exactly one.
            A missing model package directory also ends up here, since no
            files are found in it.
        """
        dir_name = os.path.join(self.get_base_path(), self.model_package_name)

        # Checked in this order, so the first reported problem is the same
        # as it always was (model, then checkpoint, then metadata).
        model_filename = self._find_unique_file(dir_name, 'arch*.py', 'model')
        checkpoint_filename = self._find_unique_file(dir_name, '*.pth.tar', 'checkpoint')
        metadata_filename = self._find_unique_file(dir_name, 'meta*.yaml', 'metadata')

        return model_filename, checkpoint_filename, metadata_filename

    @staticmethod
    def _find_unique_file(dir_name, pattern, kind):
        """Return the single file in ``dir_name`` matching ``pattern``.

        Parameters
        ----------
        dir_name : `str`
            Directory to search; taken literally, not as a glob pattern.
        pattern : `str`
            Glob pattern for the file name inside ``dir_name``.
        kind : `str`
            Human-readable file category used in the error message.

        Returns
        -------
        filename : `str`
            Path of the only matching file.

        Raises
        ------
        RuntimeError
            If zero or more than one file matches.
        """
        # Escape the directory so glob metacharacters ('*', '?', '[') that
        # happen to appear in the data path are matched literally; only the
        # file-name pattern acts as a wildcard.
        matches = glob.glob(os.path.join(glob.escape(dir_name), pattern))
        if len(matches) != 1:
            raise RuntimeError(f"Found {len(matches)} {kind} files, "
                               f"expected 1 in {dir_name}.")
        return matches[0]