Coverage for python/lsst/meas/transiNet/modelPackages/storageAdapterNeighbor.py: 28%

31 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-23 03:57 -0700

1import os 

2import glob 

3 

4from .storageAdapterBase import StorageAdapterBase 

5 

# Names exported by ``from ... import *`` on this module.
__all__ = [
    "StorageAdapterNeighbor",
]

7 

8 

class StorageAdapterNeighbor(StorageAdapterBase):
    """An adapter for interfacing with ModelPackages stored in the
    'neighbor' mode.

    Neighbor mode means both the code and pretrained weights
    reside in the neighbor Git repository, namely rbClassifier_data.
    Each model package is assumed to be a directory under the
    "model_packages" folder, in the root of that repository.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package. It is joined onto the base path
        returned by `get_base_path` to form the package directory
        (presumably stored as ``self.model_package_name`` by the base
        class -- TODO confirm against `StorageAdapterBase`).
    """

    def __init__(self, model_package_name):
        super().__init__(model_package_name)

        # ``fetch`` is inherited from the base class; presumably it makes the
        # package available locally before we look for its files.
        self.fetch()
        self.model_filename, self.checkpoint_filename, self.metadata_filename = self.get_filenames()

    @staticmethod
    def get_base_path():
        """Return the base model packages storage path for this mode.

        Returns
        -------
        base_path : `str`
            The ``model_packages`` directory inside the directory named by
            the ``RBCLASSIFIER_DATA_DIR`` environment variable.

        Raises
        ------
        RuntimeError
            If the ``RBCLASSIFIER_DATA_DIR`` environment variable is not set.
        """
        try:
            base_path = os.environ['RBCLASSIFIER_DATA_DIR']
        except KeyError:
            # The bare KeyError adds nothing beyond this message, so drop it
            # from the traceback instead of chaining it implicitly.
            raise RuntimeError("Cannot find the lsst-dm/rbClassifier_data package; "
                               "is it downloaded and set up?\n"
                               "See https://pipelines.lsst.io/v/daily/modules/lsst.ap.verify/running.html "
                               "for details on setting up packages from GitHub.") from None

        return os.path.join(base_path, 'model_packages')

    @staticmethod
    def _find_unique(escaped_dir, pattern, kind, dir_name):
        """Return the single file in ``escaped_dir`` matching ``pattern``.

        Parameters
        ----------
        escaped_dir : `str`
            Package directory, already passed through `glob.escape`.
        pattern : `str`
            Glob pattern for the file name only, e.g. ``'arch*.py'``.
        kind : `str`
            Human-readable file category used in the error message.
        dir_name : `str`
            Unescaped package directory, used in the error message.

        Returns
        -------
        filename : `str`
            Path of the one matching file.

        Raises
        ------
        RuntimeError
            If zero or more than one file matches.
        """
        matches = glob.glob(os.path.join(escaped_dir, pattern))
        if len(matches) != 1:
            raise RuntimeError(f"Found {len(matches)} {kind} files, "
                               f"expected 1 in {dir_name}.")
        return matches[0]

    def get_filenames(self):
        """Find and return paths to the architecture, checkpoint and
        metadata files of this model package.

        Default file names are not assumed in 'neighbor' mode. Instead, the
        package directory must contain exactly one file matching each of
        these patterns:

        - ``arch*.py``: the model architecture;
        - ``*.pth.tar``: the saved checkpoint;
        - ``meta*.yaml``: the metadata.

        The returned paths are rooted at `get_base_path`. They are absolute
        only if ``RBCLASSIFIER_DATA_DIR`` is absolute.

        Returns
        -------
        model_filename : `str`
            The full path to the .py file containing the model architecture.
        checkpoint_filename : `str`
            The full path to the file containing the saved checkpoint.
        metadata_filename : `str`
            The full path to the .yaml file containing the model metadata.

        Raises
        ------
        RuntimeError
            If ``RBCLASSIFIER_DATA_DIR`` is not set, or if any of the patterns
            above matches zero or more than one file. A missing model package
            directory shows up as zero matches.
        """
        dir_name = os.path.join(self.get_base_path(), self.model_package_name)

        # Only the file-name part is meant to be a pattern. Escape the
        # directory so that glob metacharacters ("[", "*", "?") in the data
        # path or package name are matched literally.
        escaped_dir = glob.escape(dir_name)

        # Checked in this order so that the first reported problem matches
        # the historical behavior: model, then checkpoint, then metadata.
        model_filename = self._find_unique(escaped_dir, 'arch*.py', 'model', dir_name)
        checkpoint_filename = self._find_unique(escaped_dir, '*.pth.tar', 'checkpoint', dir_name)
        metadata_filename = self._find_unique(escaped_dir, 'meta*.yaml', 'metadata', dir_name)

        return model_filename, checkpoint_filename, metadata_filename