问题
方法
"""Demonstrate concatenating two convolution outputs along the channel axis."""
import torch
from torch import nn


def concat_conv_features(
    x: torch.Tensor,
    conv_a: nn.Module,
    conv_b: nn.Module,
    dim: int = 1,
) -> torch.Tensor:
    """Apply two modules to the same input and concatenate their outputs.

    Args:
        x: Input tensor passed to both ``conv_a`` and ``conv_b``.
        conv_a: First module; its output occupies the leading slice along ``dim``.
        conv_b: Second module; its output follows along ``dim``.
        dim: Axis to concatenate on. The default ``1`` is the channel axis for
            ``nn.Conv2d`` outputs of shape ``(N, C, H, W)``.

    Returns:
        ``torch.cat([conv_a(x), conv_b(x)], dim=dim)``. Every dimension except
        ``dim`` must match between the two outputs; ``dim`` becomes their sum.
    """
    return torch.cat([conv_a(x), conv_b(x)], dim=dim)


def main() -> None:
    """Run the shape demo and print the concatenated output shape."""
    # kernel_size=3 with stride=1 and padding=1 preserves H and W, so both
    # outputs share every dimension except the channel count.
    conv1 = nn.Conv2d(
        in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
    )
    conv2 = nn.Conv2d(
        in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
    )
    x = torch.rand(128, 3, 224, 224)
    # conv1(x) -> [128, 32, 224, 224]; conv2(x) -> [128, 16, 224, 224]
    # Concatenate along dim=1 (channels); all other dimensions stay unchanged.
    out = concat_conv_features(x, conv1, conv2, dim=1)
    print(out.shape)  # torch.Size([128, 48, 224, 224])


# Guarded so importing this module does not run the (large) demo forward pass.
if __name__ == "__main__":
    main()