1. 基本知识
torch.unsqueeze
是 PyTorch 中的一个函数,用于在指定的维度上插入一个大小为1的维度
对于改变张量的形状(形状变换)非常有用,特别是在需要对张量的形状进行匹配以便进行后续操作时
大致用法如下:
torch.unsqueeze(input, dim)
input
:输入张量。dim
:要插入的维度索引
索引范围是[-input.dim()-1, input.dim()]
负索引将从末尾开始计算
2. Demo
示例1:在零维度上插入一个新维度
import torch
# 创建一个1D张量
x = torch.tensor([1, 2, 3])
print("原始张量:", x)
print("原始张量形状:", x.shape)
# 在0维度上插入一个新的维度
x_unsqueezed = torch.unsqueeze(x, 0)
print("插入新维度后的张量:", x_unsqueezed)
print("插入新维度后的张量形状:", x_unsqueezed.shape)
截图如下:
示例2:在第一维度上插入一个新维度
import torch
# 创建一个1D张量
x = torch.tensor([1, 2, 3])
print("原始张量:", x)
print("原始张量形状:", x.shape)
# 在1维度上插入一个新的维度
x_unsqueezed = torch.unsqueeze(x, 1)
print("插入新维度后的张量:", x_unsqueezed)
print("插入新维度后的张量形状:", x_unsqueezed.shape)
截图如下:
示例3:在二维张量中插入一个新维度
import torch
# 创建一个2D张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("原始张量:", x)
print("原始张量形状:", x.shape)
# 在0维度上插入一个新的维度
x_unsqueezed_0 = torch.unsqueeze(x, 0)
print("在0维度上插入新维度后的张量:", x_unsqueezed_0)
print("在0维度上插入新维度后的张量形状:", x_unsqueezed_0.shape)
# 在1维度上插入一个新的维度
x_unsqueezed_1 = torch.unsqueeze(x, 1)
print("在1维度上插入新维度后的张量:", x_unsqueezed_1)
print("在1维度上插入新维度后的张量形状:", x_unsqueezed_1.shape)
# 在2维度上插入一个新的维度
x_unsqueezed_2 = torch.unsqueeze(x, 2)
print("在2维度上插入新维度后的张量:", x_unsqueezed_2)
print("在2维度上插入新维度后的张量形状:", x_unsqueezed_2.shape)
截图如下:
如果是3维的维度插入,会出错,注意其范围所在区域:
Traceback (most recent call last):
File "F:\python_project\test\Father\child\file3.py", line 19, in <module>
x_unsqueezed_2 = torch.unsqueeze(x, 3)
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
截图如下: