Coverage for python/lsst/meas/transiNet/modelPackages/storageAdapterLocal.py: 28%

31 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-17 03:09 -0700

1import os 

2import glob 

3 

4from .storageAdapterBase import StorageAdapterBase 

5 

__all__ = ["StorageAdapterLocal"]


class StorageAdapterLocal(StorageAdapterBase):
    """An adapter for interfacing with ModelPackages stored in the
    'local' mode.

    Local mode means both the code and pretrained weights reside in
    the same repository as that of rbTransiNetInterface, under
    ``$MEAS_TRANSINET_DIR/model_packages/<model_package_name>``.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package; used as the name of the package's
        subdirectory under the path returned by `get_base_path`.
    """
    def __init__(self, model_package_name):
        super().__init__(model_package_name)

        # fetch() is not defined in this file (presumably provided by
        # StorageAdapterBase or overridden elsewhere) -- TODO confirm what it
        # does in local mode. It runs before the file lookup below.
        self.fetch()
        self.model_filename, self.checkpoint_filename, self.metadata_filename = self.get_filenames()

    @staticmethod
    def get_base_path():
        """Return the base model packages storage path for this mode.

        Returns
        -------
        `str`
            The base path to the model packages storage, i.e.
            ``os.path.join($MEAS_TRANSINET_DIR, 'model_packages')``.

        Raises
        ------
        RuntimeError
            If the environment variable ``MEAS_TRANSINET_DIR`` is not set.
        """
        try:
            base_path = os.environ['MEAS_TRANSINET_DIR']
        except KeyError:
            # Suppress the KeyError context: the message below says all
            # there is to say about the failure.
            raise RuntimeError("The environment variable MEAS_TRANSINET_DIR is not set.") from None

        return os.path.join(base_path, 'model_packages')

    def get_filenames(self):
        """Find and return full paths to the architecture, checkpoint and
        metadata files of this model package.

        The paths are absolute whenever ``MEAS_TRANSINET_DIR`` is absolute.

        Returns
        -------
        model_filename : `str`
            The full path to the .py file containing the model architecture.
        checkpoint_filename : `str`
            The full path to the file containing the saved checkpoint.
        metadata_filename : `str`
            The full path to the .yaml file containing the package metadata.

        Raises
        ------
        RuntimeError
            If the package directory does not contain exactly one file of
            each kind. A missing package directory also ends up here, because
            it yields zero matches.
        """
        dir_name = os.path.join(self.get_base_path(),
                                self.model_package_name)

        # We do not assume default file names in case of the 'local' mode.
        # For now we rely on a hacky pattern matching approach:
        # There should be one and only one file named arch*.py under the dir.
        # There should be one and only one file named *.pth.tar under the dir.
        # There should be one and only one file named meta*.yaml under the dir.
        #
        # Escape the directory part so that any glob metacharacters ('[', '*',
        # '?') in the base path or package name are matched literally; only
        # the file-name part is meant to be a pattern.
        search_dir = glob.escape(dir_name)

        found = []
        for label, pattern in (('model', 'arch*.py'),
                               ('checkpoint', '*.pth.tar'),
                               ('metadata', 'meta*.yaml')):
            matches = glob.glob(os.path.join(search_dir, pattern))
            # Check that there's exactly one file for each of the three categories.
            if len(matches) != 1:
                raise RuntimeError(f"Found {len(matches)} {label} files, "
                                   f"expected 1 in {dir_name}.")
            found.append(matches[0])

        model_filename, checkpoint_filename, metadata_filename = found
        return model_filename, checkpoint_filename, metadata_filename