Coverage for model_packages/dummy/arch.py: 100%
26 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 02:47 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 02:47 -0700
1__all__ = ["TestModel"]
3import torch
4import torch.nn as nn
5import torch.nn.functional as F
class TestModel(nn.Module):
    """A trivial pytorch model for testing the interface.

    Architecture: two unpadded 3x3 convolutions (16 channels each, ReLU),
    2x2 max pooling, dropout, a 32-unit fully connected layer (ReLU,
    dropout) and a final fully connected layer with a single output.

    Args:
        in_channels: Number of channels of the input images. Defaults to 3.
        image_size: Height and width of the square input images. The first
            fully connected layer is sized from this value, so inputs passed
            to ``forward`` must match it. Defaults to 256, which reproduces
            the original 254016-feature ``fc1`` layer.

    Raises:
        ValueError: If ``image_size`` is too small to leave any spatial
            extent after the two convolutions and the pooling step.
    """

    # The class name starts with "Test"; this stops pytest from trying to
    # collect it as a test class (which would only produce a warning).
    __test__ = False

    def __init__(self, *, in_channels: int = 3, image_size: int = 256):
        super().__init__()

        # Each unpadded 3x3 conv with stride 1 shrinks the side by 2, and the
        # 2x2 max pool halves it (rounding down): side -> (side - 4) // 2.
        pooled = (image_size - 4) // 2
        if pooled < 1:
            raise ValueError(f"image_size must be at least 6, got {image_size}")

        # Layers are created in the original order so that, with default
        # arguments, parameter shapes (and seeded initialisation) are unchanged.
        self.conv1 = nn.Conv2d(in_channels, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 16, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(16 * pooled * pooled, 32)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(N, in_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(N, 1)``; no activation is applied to the
            output of ``fc2``.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Collapse (C, H, W) into one feature axis, keeping the batch axis.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x