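import torch
from torch import nn


class ChannelAttention(nn.Module):
    # ratio: reduction factor for the hidden layer of the shared MLP
    # (the hidden layer has in_planes // ratio channels).
    def __init__(self, in_planes, ratio=16) -> None:
        super().__init__()
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        '''
        (1) in_planes / ratio yields a float, and nn.Conv2d raises an error
            when given a non-integer channel count;
        (2) in_planes // ratio uses floor division and yields an int
            (it should be at least 1, i.e. choose ratio <= in_planes);
        (3) on a 1x1 pooled map, a 1x1 Conv2d computes the same thing as a
            fully connected layer, so this Sequential acts as the shared MLP.
            bias=False is a design choice: the conv would still behave like
            an FC layer with a bias.
        '''
        self.mlp = nn.Sequential(  # modules are passed as separate arguments, not as a list
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),  # 1x1 conv used as an FC layer
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False),  # // (not /) keeps the channel count an int
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = self.max_pool(x)  # B x C x 1 x 1
        x1 = self.mlp(x1)
        x2 = self.avg_pool(x)  # B x C x 1 x 1
        x2 = self.mlp(x2)  # both branches share the same MLP weights
        # The two branches are summed element-wise, not concatenated
        # (i.e. not torch.cat([x1, x2], dim=1)).
        out = x1 + x2
        out = self.sigmoid(out)
        return out  # channel attention map, B x C x 1 x 1


class SpatialAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 2 input channels (avg map + max map) -> 1 output channel, 7x7 kernel;
        # padding=3 keeps H and W unchanged.
        self.conv2d = nn.Conv2d(2, 1, 7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        '''
        Note: these are not ordinary spatial max/avg pooling operations;
        they pool across the channel dimension (dim=1) at every pixel.
        '''
        avg_pool = torch.mean(x, dim=1, keepdim=True)  # B x 1 x H x W
        # With dim given, torch.max returns (values, indices);
        # forgetting the `_` here is an easy mistake.
        max_pool, _ = torch.max(x, dim=1, keepdim=True)  # B x 1 x H x W
        out = torch.cat([avg_pool, max_pool], dim=1)  # B x 2 x H x W
        out = self.conv2d(out)
        out = self.sigmoid(out)
        return out  # spatial attention map, B x 1 x H x W


if __name__ == '__main__':
    from torchinfo import summary
    # import hiddenlayer as h

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    print('Channel Attention')
    layer = ChannelAttention(32).to(device)
    summary(layer, (1, 32, 224, 224))

    print('Spatial Attention')
    layer = SpatialAttention().to(device)
    summary(layer, (1, 32, 224, 224))

    # graph = h.build_graph(layer, torch.zeros([1, 32, 224, 224]))
    # graph.theme = h.graph.THEMES['blue'].copy()
    # graph.save('test.png')

    print('done!')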
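

# Sketch (not part of the original modules): checks the claim in
# ChannelAttention that a 1x1 Conv2d applied to a B x C x 1 x 1 pooled tensor
# computes the same result as nn.Linear on the flattened B x C vector, with or
# without a bias. The helper name `conv1x1_matches_linear` and its default
# sizes are hypothetical, chosen only for illustration.
import torch
from torch import nn


def conv1x1_matches_linear(in_ch=32, out_ch=2, bias=False, atol=1e-5):
    """Return True if a 1x1 conv and a Linear layer with copied weights agree."""
    conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=bias)
    fc = nn.Linear(in_ch, out_ch, bias=bias)
    with torch.no_grad():
        # Conv weight has shape (out, in, 1, 1); Linear weight has shape (out, in).
        fc.weight.copy_(conv.weight.reshape(out_ch, in_ch))
        if bias:
            fc.bias.copy_(conv.bias)
    x = torch.randn(4, in_ch, 1, 1)  # stands in for AdaptiveAvgPool2d((1, 1)) output
    y_conv = conv(x).flatten(1)      # B x out_ch
    y_fc = fc(x.flatten(1))          # B x out_ch
    return torch.allclose(y_conv, y_fc, atol=atol)

# Usage: conv1x1_matches_linear() and conv1x1_matches_linear(bias=True) both
# return True, so disabling the bias is not what makes the conv "MLP-like".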