Coverage for model_packages/dummy/arch.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-19 05:06 -0700

# Public API of this module: only the test model is exported.
__all__ = ["TestModel"]

2 

3import torch 

4import torch.nn as nn 

5import torch.nn.functional as F 

6 

7 

class TestModel(nn.Module):
    """A trivial PyTorch model for testing the interface.

    Architecture: two unpadded 3x3 convolutions (3 -> 16 -> 16 channels) with
    ReLU, a 2x2 max-pool, dropout, then a two-layer MLP head (-> 32 -> 1).

    Args:
        input_size: Spatial size of the images passed to ``forward``. Pass
            either a single int (square images, ``H == W``) or an ``(H, W)``
            pair. The default of 256 reproduces the original hard-coded
            ``fc1`` input width of 254016 (= 16 * 126 * 126).

    Raises:
        ValueError: If either spatial dimension is too small (< 6). In that
            case no spatial extent would remain after the conv/pool stack.
    """

    def __init__(self, input_size=256):
        super().__init__()

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        # Creation order is kept identical to the original so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 16, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(self._flat_features(height, width), 32)
        self.fc2 = nn.Linear(32, 1)

    @staticmethod
    def _flat_features(height, width):
        """Return the flattened feature count produced before ``fc1``."""
        # Each unpadded 3x3 conv with stride 1 trims 2 pixels per dimension
        # (two convs -> 4 pixels), then max_pool2d(2) halves with flooring.
        pooled_h = (height - 4) // 2
        pooled_w = (width - 4) // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                "input_size must be at least 6 in each spatial dimension, "
                f"got {(height, width)}"
            )
        # conv2 outputs 16 channels.
        return 16 * pooled_h * pooled_w

    def forward(self, x):
        """Map a batch of images to one output value per sample.

        Assumes ``x`` is (N, 3, H, W), with (H, W) matching the
        ``input_size`` given at construction. The 3 channels are fixed by
        ``conv1``; the flattened width must match ``fc1``. Returns a tensor
        of shape (N, 1). Dropout is only active in training mode.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten channels and spatial dims.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        return self.fc2(x)