searchusermenu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

torch中stack和cat、hstack和vstack、split和chunk 分析

2024-11-04 09:32:44
1
0
import torch

# torch 中 stack 和 cat 异同

t1 = torch.ones(2, 3, 4)
t2 = torch.ones(2, 3, 4)

# stack
t3 = torch.stack((t1, t2), dim=0)
print(t3.size())

t3 = torch.stack((t1, t2), dim=1)
print(t3.size())

t3 = torch.stack((t1, t2), dim=2)
print(t3.size())

# cat
t3 = torch.cat((t1, t2), dim=0)
print(t3.size())

t3 = torch.cat((t1, t2), dim=1)
print(t3.size())

t3 = torch.cat((t1, t2), dim=2)
print(t3.size())

"""
torch.Size([2, 2, 3, 4])
torch.Size([2, 2, 3, 4])
torch.Size([2, 3, 2, 4])
torch.Size([4, 3, 4])
torch.Size([2, 6, 4])
torch.Size([2, 3, 8])
"""

import torch


# vstack 和 hstack 异同
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.vstack((a, b))
print(c, c.size())
# 等价于
c = torch.stack((a, b), 0)
print(c, c.size())

a = torch.tensor([[1], [2], [3]])
b = torch.tensor([[4], [5], [6]])
c = torch.vstack((a, b))
print(c, c.size())
# 等价于
c = torch.cat((a, b), 0)
print(c, c.size())

"""
tensor([[1, 2, 3],
[4, 5, 6]]) torch.Size([2, 3])
tensor([[1],
[2],
[3],
[4],
[5],
[6]]) torch.Size([6, 1])
"""

a = torch.tensor((1, 2, 3)) # 只有一个轴存在
b = torch.tensor((22, 33, 44))
c = torch.hstack((a, b))
print(c, c.size())
# 等价于
c = torch.cat((a, b), dim=0)
print(c, c.size())

a = torch.tensor([[1], [2], [3]])
b = torch.tensor([[22], [33], [44]])
c = torch.hstack((a, b))
print(c, c.size())
# 等价于
c = torch.cat((a, b), dim=1)
print(c, c.size())

# split 和 chunk 异同
import torch

x = torch.rand(4, 8, 6)
y = torch.split(x, 2, dim=0) # 按照2这个维度去分,每大块包含2个小块
for i in y:
print(i.size())

y = torch.split(x, 4, dim=1) # 按照4这个维度去分,每大块包含2个小块
for i in y:
print(i.size())

y = torch.split(x, 3, dim=2) # 按照3这个维度去分,每大块包含2个小块
for i in y:
print(i.size())

"""
torch.Size([2, 8, 6])
torch.Size([2, 8, 6])
torch.Size([4, 4, 6])
torch.Size([4, 4, 6])
torch.Size([4, 8, 3])
torch.Size([4, 8, 3])
"""

# chunk 目标是要切分成几个
x = torch.rand(4, 8, 6)
y = torch.chunk(x, 2, dim=0) # 按照要切分成2个去分,每大块包含2个小块
for i in y:
print(i.size())

y = torch.chunk(x, 4, dim=1) # 按照要切分成4个去分,每大块包含2个小块
for i in y:
print(i.size())

y = torch.chunk(x, 3, dim=2) # 按照要切分成3个去分,每大块包含2个小块
for i in y:
print(i.size())

"""
torch.Size([2, 8, 6])
torch.Size([2, 8, 6])
torch.Size([4, 2, 6])
torch.Size([4, 2, 6])
torch.Size([4, 2, 6])
torch.Size([4, 2, 6])
torch.Size([4, 8, 2])
torch.Size([4, 8, 2])
torch.Size([4, 8, 2])
"""
0条评论
作者已关闭评论
Top123
29文章数
3粉丝数
Top123
29 文章 | 3 粉丝
Top123
29文章数
3粉丝数
Top123
29 文章 | 3 粉丝
原创

torch中stack和cat、hstack和vstack、split和chunk 分析

2024-11-04 09:32:44
1
0
import torch

# torch 中 stack 和 cat 异同

t1 = torch.ones(2, 3, 4)
t2 = torch.ones(2, 3, 4)

# stack
t3 = torch.stack((t1, t2), dim=0)
print(t3.size())

t3 = torch.stack((t1, t2), dim=1)
print(t3.size())

t3 = torch.stack((t1, t2), dim=2)
print(t3.size())

# cat
t3 = torch.cat((t1, t2), dim=0)
print(t3.size())

t3 = torch.cat((t1, t2), dim=1)
print(t3.size())

t3 = torch.cat((t1, t2), dim=2)
print(t3.size())

"""
torch.Size([2, 2, 3, 4])
torch.Size([2, 2, 3, 4])
torch.Size([2, 3, 2, 4])
torch.Size([4, 3, 4])
torch.Size([2, 6, 4])
torch.Size([2, 3, 8])
"""

import torch


# vstack 和 hstack 异同
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.vstack((a, b))
print(c, c.size())
# 等价于
c = torch.stack((a, b), 0)
print(c, c.size())

a = torch.tensor([[1], [2], [3]])
b = torch.tensor([[4], [5], [6]])
c = torch.vstack((a, b))
print(c, c.size())
# 等价于
c = torch.cat((a, b), 0)
print(c, c.size())

"""
tensor([[1, 2, 3],
[4, 5, 6]]) torch.Size([2, 3])
tensor([[1],
[2],
[3],
[4],
[5],
[6]]) torch.Size([6, 1])
"""

a = torch.tensor((1, 2, 3)) # 只有一个轴存在
b = torch.tensor((22, 33, 44))
c = torch.hstack((a, b))
print(c, c.size())
# 等价于
c = torch.cat((a, b), dim=0)
print(c, c.size())

a = torch.tensor([[1], [2], [3]])
b = torch.tensor([[22], [33], [44]])
c = torch.hstack((a, b))
print(c, c.size())
# 等价于
c = torch.cat((a, b), dim=1)
print(c, c.size())

# split 和 chunk 异同
import torch

x = torch.rand(4, 8, 6)
y = torch.split(x, 2, dim=0) # 按照2这个维度去分,每大块包含2个小块
for i in y:
print(i.size())

y = torch.split(x, 4, dim=1) # 按照4这个维度去分,每大块包含2个小块
for i in y:
print(i.size())

y = torch.split(x, 3, dim=2) # 按照3这个维度去分,每大块包含2个小块
for i in y:
print(i.size())

"""
torch.Size([2, 8, 6])
torch.Size([2, 8, 6])
torch.Size([4, 4, 6])
torch.Size([4, 4, 6])
torch.Size([4, 8, 3])
torch.Size([4, 8, 3])
"""

# chunk 目标是要切分成几个
x = torch.rand(4, 8, 6)
y = torch.chunk(x, 2, dim=0) # 按照要切分成2个去分,每大块包含2个小块
for i in y:
print(i.size())

y = torch.chunk(x, 4, dim=1) # 按照要切分成4个去分,每大块包含2个小块
for i in y:
print(i.size())

y = torch.chunk(x, 3, dim=2) # 按照要切分成3个去分,每大块包含2个小块
for i in y:
print(i.size())

"""
torch.Size([2, 8, 6])
torch.Size([2, 8, 6])
torch.Size([4, 2, 6])
torch.Size([4, 2, 6])
torch.Size([4, 2, 6])
torch.Size([4, 2, 6])
torch.Size([4, 8, 2])
torch.Size([4, 8, 2])
torch.Size([4, 8, 2])
"""
文章来自个人专栏
云原生最佳实践
29 文章 | 1 订阅
0条评论
作者已关闭评论
作者已关闭评论
0
0