from torch import nn
import torch
class CnnNet(nn.Module):
    """Small convolutional classifier: three strided conv blocks, global pooling, MLP head.

    Pipeline (as built in ``__init__``):

    * ``model``  -- three ``Conv2d(3x3, stride 2, padding 1) -> BatchNorm2d -> ReLU``
      stages with 32, 64 and 32 output channels; each stage roughly halves
      the spatial resolution.
    * ``output`` -- ``AdaptiveAvgPool2d((1, 1))``, collapsing whatever spatial
      size remains to a single value per channel.
    * ``mlp``    -- ``Linear(32, 100) -> Dropout(0.5) -> Linear(100, num_classes)``.

    Args:
        in_ch: Number of channels in the input images.
        num_classes: Size of the final linear layer's output. Defaults to 10,
            which reproduces the original fixed-size head.
    """

    def __init__(self, in_ch: int = 1, num_classes: int = 10):
        super().__init__()
        # Attribute names and layer order are kept as-is so that state_dict
        # keys ("model.0.weight", "mlp.2.bias", ...) stay stable.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )
        # Global average pooling makes the head independent of input height/width.
        self.output = nn.AdaptiveAvgPool2d((1, 1))
        # NOTE(review): there is no activation between the two Linear layers,
        # so in eval mode the head is a composition of two affine maps.
        # Left unchanged to keep the architecture and checkpoint keys intact;
        # confirm whether a non-linearity was intended.
        self.mlp = nn.Sequential(
            nn.Linear(32, 100),
            nn.Dropout(0.5),
            nn.Linear(100, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: 4-D tensor laid out as ``(batch, in_ch, height, width)``
                (``BatchNorm2d`` rejects other ranks).

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding the raw outputs of
            the final linear layer (no softmax is applied here).
        """
        result = self.model(x)
        result = self.output(result)
        # (batch, 32, 1, 1) -> (batch, 32). Flattening from dim 1 keeps the
        # batch axis explicit instead of inferring it from a hard-coded
        # channel count via reshape(-1, 32).
        result = result.flatten(start_dim=1)
        result = self.mlp(result)
        return result
if __name__ == '__main__':
    # Smoke test: push one random single-channel 28x28 image through the
    # network and print the shape of what comes out.
    dummy_batch = torch.randn(1, 1, 28, 28)
    model = CnnNet(in_ch=1)
    prediction = model(dummy_batch)
    print(prediction.shape)