import torch
from torch import nn
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16) -> None:
        super().__init__()
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        '''
        (1) in_planes / ratio yields a float, which makes the model raise an error;
        (2) in_planes // ratio uses floor division instead;
        (3) bias=False in the Conv2d layers is mainly to mimic a plain MLP (multi-layer perceptron).
        '''
        # Shared MLP built from 1x1 convolutions: squeeze to in_planes // ratio, then expand back
        self.mlp = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # Max-pooled and avg-pooled descriptors go through the same shared MLP
        x1 = self.max_pool(x)
        x1 = self.mlp(x1)
        x2 = self.avg_pool(x)
        x2 = self.mlp(x2)
        out = x1 + x2
        out = self.sigmoid(out)
        return out
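
# Usage sketch (added for illustration, not part of the original script): the
# channel attention map has shape (B, C, 1, 1), so it is applied to the input
# by broadcasting, e.g.
#   ca = ChannelAttention(32)   # 32 is an assumed channel count
#   y = x * ca(x)               # re-weights every channel of x
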
class SpatialAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 7x7 convolution over the 2-channel [avg, max] map; padding=3 keeps the spatial size
        self.conv2d = nn.Conv2d(2, 1, 7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        '''
        Note: these are not ordinary spatial max/mean pooling operations; the
        reduction is taken across the channel dimension (cross-channel pooling).
        '''
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_pool, max_pool], dim=1)
        out = self.conv2d(out)
        out = self.sigmoid(out)
        return out
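
# Sketch of how the two modules are typically chained into a CBAM-style block:
# channel attention first, then spatial attention, each applied by element-wise
# multiplication. The class name CBAMBlock and its defaults are illustrative
# assumptions, not part of the original script.
class CBAMBlock(nn.Module):
    def __init__(self, in_planes, ratio=16) -> None:
        super().__init__()
        self.channel_attention = ChannelAttention(in_planes, ratio)
        self.spatial_attention = SpatialAttention()
    def forward(self, x):
        x = x * self.channel_attention(x)   # (B, C, 1, 1) map broadcast over H, W
        x = x * self.spatial_attention(x)   # (B, 1, H, W) map broadcast over C
        return x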
if __name__ == '__main__':
    from torchinfo import summary
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    print('Channel Attention')
    layer = ChannelAttention(32).to(device)
    summary(layer, (1, 32, 224, 224))

    print('Spatial Attention')
    layer = SpatialAttention().to(device)
    summary(layer, (1, 32, 224, 224))
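
    # Quick check of the combined block sketched above (CBAMBlock is an
    # illustrative addition, not part of the original script)
    print('CBAM Block')
    layer = CBAMBlock(32).to(device)
    summary(layer, (1, 32, 224, 224))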
    print('done!')