# -*- coding: utf-8 -*-
"""
@Time : 2021/7/23 上午8:00
@Auth : 陈伟峰
@File :2.py
@phone: 15882085601
@IDE :PyCharm
@Motto:ABC(Always Be Coding)
"""
from torch import nn
import torch
import torch.nn.functional as F
class VggBase(nn.Module):
    """Generic VGG network parameterised by the number of conv blocks per stage.

    The feature extractor has five stages with 64, 128, 256, 512 and 512
    output channels. Stage ``i`` stacks ``conf[i]`` blocks of
    3x3 convolution (stride 1, padding 1, so spatial size is preserved)
    -> BatchNorm2d -> ReLU, and :meth:`forward` follows every stage with a
    2x2 max-pool. The classifier is Linear -> ReLU -> Linear -> ReLU ->
    Linear -> Softmax, so the network returns class probabilities.

    Args:
        conf: Sequence of five ints, the number of conv blocks in each stage.
        num_classes: Width of the final linear layer (number of classes).

    Raises:
        ValueError: if ``conf`` is ``None`` or does not have exactly 5 entries.
    """

    def __init__(self, conf=None, num_classes=1000):
        super(VggBase, self).__init__()
        # Check for None *before* calling len(): otherwise a missing config
        # surfaces as len(None)'s TypeError. Explicit raises (instead of
        # assert) also keep the validation active under ``python -O``.
        if conf is None:
            raise ValueError("no config files")
        if len(conf) != 5:
            raise ValueError("the length of config file is 5!")
        # Running input-channel count; __make_layer advances it stage by stage.
        self.in_channels = 3
        self.conf = conf
        self.conv_64c = self.__make_layer(64, conf[0])
        self.conv_128c = self.__make_layer(128, conf[1])
        self.conv_256c = self.__make_layer(256, conf[2])
        self.conv_512ac = self.__make_layer(512, conf[3])
        self.conv_512bc = self.__make_layer(512, conf[4])
        self.output_layer = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_classes),
            # NOTE: the model therefore emits probabilities, not logits.
            # nn.CrossEntropyLoss applies log-softmax itself, so pair these
            # outputs with torch.log(...) + nn.NLLLoss instead.
            nn.Softmax(dim=-1),
        )

    def __make_layer(self, channels, num, action=nn.ReLU, BN=True):
        """Build one stage of ``num`` conv blocks.

        Each block is Conv2d(3x3, stride 1, padding 1), optionally followed by
        ``nn.BatchNorm2d`` and an activation.

        Args:
            channels: Output channels of every convolution in the stage.
            num: Number of conv blocks to stack.
            action: Activation module class instantiated after each conv;
                ``None`` omits the activation.
            BN: When truthy, insert ``nn.BatchNorm2d(channels)`` after each
                conv; when falsy (``False``/``None``), no normalisation.

        Returns:
            ``nn.Sequential`` containing the stage's layers.

        Side effect:
            Sets ``self.in_channels`` to ``channels`` (when ``num > 0``) so the
            next stage's first conv chains from this one.
        """
        layers = []
        for _ in range(num):
            layers.append(nn.Conv2d(in_channels=self.in_channels, out_channels=channels,
                                    kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
            if BN:
                layers.append(nn.BatchNorm2d(channels))
            if action is not None:
                layers.append(action())
            self.in_channels = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the feature stages and the classifier.

        Args:
            x: Image batch of shape (N, 3, H, W); the first conv expects 3
                input channels. Each of the five 2x2 max-pools halves H and W,
                so H and W must be at least 32.

        Returns:
            Tensor of shape (N, num_classes) with softmax probabilities.
        """
        result = x
        for stage in (self.conv_64c, self.conv_128c, self.conv_256c,
                      self.conv_512ac, self.conv_512bc):
            result = F.max_pool2d(stage(result), (2, 2))
        # For a 224x224 input the map is already 7x7 and this is an exact
        # identity. For other sizes it resamples to the 7x7 the classifier
        # expects, instead of letting reshape(-1, 512*7*7) silently fold
        # spatial positions into the batch dimension.
        result = F.adaptive_avg_pool2d(result, (7, 7))
        result = torch.flatten(result, 1)
        return self.output_layer(result)
class vgg11(VggBase):
    """VGG-11: (1, 1, 2, 2, 2) conv blocks per stage -> 8 conv + 3 linear layers."""

    def __init__(self):
        super().__init__(conf=[1, 1, 2, 2, 2])

    def forward(self, x):
        """Delegate unchanged to the shared VggBase forward pass."""
        return super().forward(x)
class vgg13(VggBase):
    """VGG-13: (2, 2, 2, 2, 2) conv blocks per stage -> 10 conv + 3 linear layers."""

    def __init__(self):
        super().__init__(conf=[2, 2, 2, 2, 2])

    def forward(self, x):
        """Delegate unchanged to the shared VggBase forward pass."""
        return super().forward(x)
class vgg16(VggBase):
    """VGG-16: (2, 2, 3, 3, 3) conv blocks per stage -> 13 conv + 3 linear layers."""

    def __init__(self):
        super().__init__(conf=[2, 2, 3, 3, 3])

    def forward(self, x):
        """Delegate unchanged to the shared VggBase forward pass."""
        return super().forward(x)
class vgg19(VggBase):
    """VGG-19: (2, 2, 4, 4, 4) conv blocks per stage -> 16 conv + 3 linear layers."""

    def __init__(self):
        super().__init__(conf=[2, 2, 4, 4, 4])

    def forward(self, x):
        """Delegate unchanged to the shared VggBase forward pass."""
        return super().forward(x)
if __name__ == '__main__':
    # Smoke test: build each variant and print its output shape for a single
    # random 224x224 RGB image. This is pure inference, so use eval() and
    # no_grad(): no autograd graph is kept for these large models, and
    # BatchNorm uses its running statistics instead of batch statistics.
    x = torch.randn(1, 3, 224, 224)
    for model_cls in (vgg11, vgg13, vgg16, vgg19):
        net = model_cls().eval()
        with torch.no_grad():
            print(net(x).shape)