import torch
# torch 中 stack 和 cat 异同
t1 = torch.ones(2, 3, 4)
t2 = torch.ones(2, 3, 4)
# stack
t3 = torch.stack((t1, t2), dim=0)
print(t3.size())
t3 = torch.stack((t1, t2), dim=1)
print(t3.size())
t3 = torch.stack((t1, t2), dim=2)
print(t3.size())
# cat
t3 = torch.cat((t1, t2), dim=0)
print(t3.size())
t3 = torch.cat((t1, t2), dim=1)
print(t3.size())
t3 = torch.cat((t1, t2), dim=2)
print(t3.size())
"""
torch.Size([2, 2, 3, 4])
torch.Size([2, 2, 3, 4])
torch.Size([2, 3, 2, 4])
torch.Size([4, 3, 4])
torch.Size([2, 6, 4])
torch.Size([2, 3, 8])
"""
import torch
# vstack 和 hstack 异同
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.vstack((a, b))
print(c, c.size())
# 等价于
c = torch.stack((a, b), 0)
print(c, c.size())
a = torch.tensor([[1], [2], [3]])
b = torch.tensor([[4], [5], [6]])
c = torch.vstack((a, b))
print(c, c.size())
# 等价于
c = torch.cat((a, b), 0)
print(c, c.size())
"""
tensor([[1, 2, 3],
[4, 5, 6]]) torch.Size([2, 3])
tensor([[1],
[2],
[3],
[4],
[5],
[6]]) torch.Size([6, 1])
"""
a = torch.tensor((1, 2, 3)) # 只有一个轴存在
b = torch.tensor((22, 33, 44))
c = torch.hstack((a, b))
print(c, c.size())
# 等价于
c = torch.cat((a, b), dim=0)
print(c, c.size())
a = torch.tensor([[1], [2], [3]])
b = torch.tensor([[22], [33], [44]])
c = torch.hstack((a, b))
print(c, c.size())
# 等价于
c = torch.cat((a, b), dim=1)
print(c, c.size())
# split 和 chunk 异同
import torch
x = torch.rand(4, 8, 6)
y = torch.split(x, 2, dim=0) # 按照2这个维度去分,每大块包含2个小块
for i in y:
print(i.size())
y = torch.split(x, 4, dim=1) # 按照4这个维度去分,每大块包含2个小块
for i in y:
print(i.size())
y = torch.split(x, 3, dim=2) # 按照3这个维度去分,每大块包含2个小块
for i in y:
print(i.size())
"""
torch.Size([2, 8, 6])
torch.Size([2, 8, 6])
torch.Size([4, 4, 6])
torch.Size([4, 4, 6])
torch.Size([4, 8, 3])
torch.Size([4, 8, 3])
"""
# chunk 目标是要切分成几个
x = torch.rand(4, 8, 6)
y = torch.chunk(x, 2, dim=0) # 按照要切分成2个去分,每大块包含2个小块
for i in y:
print(i.size())
y = torch.chunk(x, 4, dim=1) # 按照要切分成4个去分,每大块包含2个小块
for i in y:
print(i.size())
y = torch.chunk(x, 3, dim=2) # 按照要切分成3个去分,每大块包含2个小块
for i in y:
print(i.size())
"""
torch.Size([2, 8, 6])
torch.Size([2, 8, 6])
torch.Size([4, 2, 6])
torch.Size([4, 2, 6])
torch.Size([4, 2, 6])
torch.Size([4, 2, 6])
torch.Size([4, 8, 2])
torch.Size([4, 8, 2])
torch.Size([4, 8, 2])
"""
0条评论