search user menu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

矩阵相乘6种优化方法

2024-09-25 09:31:37
6
0
import torch
import time

# Random benchmark operands: a (50, 784) left matrix and a (784, 10) right matrix.
# The tuple is evaluated left to right, so m1 is drawn before m2.
m1, m2 = torch.randn(50, 784), torch.randn(784, 10)


def cost_time(func):
    """Wrap *func* so every call prints how long it took.

    The wrapper forwards all positional and keyword arguments to *func*,
    measures the call with ``time.perf_counter`` and prints
    ``func <name> cost time:<seconds> s`` before returning *func*'s result
    unchanged.

    NOTE: PyTorch launches CUDA kernels asynchronously, so if the wrapped
    function works on GPU tensors the printed time may not include kernel
    execution unless the callee synchronizes.
    """
    import functools

    # wraps() copies __name__/__doc__ etc., so the decorated function keeps
    # its own identity instead of appearing as "fun".
    @functools.wraps(func)
    def fun(*args, **kwargs):
        t = time.perf_counter()
        result = func(*args, **kwargs)
        print(f'func {func.__name__} cost time:{time.perf_counter() - t:.8f} s')
        return result

    return fun


@cost_time
def matmulV1(a, b):
    """Multiply two 2-D tensors with three explicit Python loops (slowest baseline).

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2) equal to ``a @ b``, allocated with the
        promoted dtype of ``a`` and ``b`` on ``a``'s device.

    Raises:
        ValueError: if the inner dimensions of ``a`` and ``b`` differ
            (or, via tuple unpacking, if either input is not 2-D).
    """
    r1, c1 = a.shape
    r2, c2 = b.shape

    # An explicit raise is used instead of assert, which disappears under `python -O`.
    if c1 != r2:
        raise ValueError(
            f"inner dimensions must match, got {tuple(a.shape)} @ {tuple(b.shape)}"
        )

    # torch.zeros defaults to float32 on the CPU; match what a @ b would produce instead.
    rst = torch.zeros(r1, c2, dtype=torch.result_type(a, b), device=a.device)

    for i in range(r1):
        for j in range(c2):
            for k in range(c1):
                rst[i, j] += a[i, k] * b[k, j]
    return rst


matmulV1(m1, m2)


@cost_time
def matmulV2(a, b):
    """Multiply two 2-D tensors, vectorizing the innermost (k) loop.

    Each output element ``rst[i, j]`` is computed as the elementwise product
    of row ``a[i, :]`` and column ``b[:, j]`` summed in one tensor op.

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2) equal to ``a @ b``, allocated with the
        promoted dtype of ``a`` and ``b`` on ``a``'s device.

    Raises:
        ValueError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    r1, c1 = a.shape
    r2, c2 = b.shape

    if c1 != r2:
        raise ValueError(
            f"inner dimensions must match, got {tuple(a.shape)} @ {tuple(b.shape)}"
        )

    # Match the dtype/device a @ b would produce rather than float32 on the CPU.
    rst = torch.zeros(r1, c2, dtype=torch.result_type(a, b), device=a.device)

    for i in range(r1):
        for j in range(c2):
            # Changed from V1: the k-loop is replaced by a vectorized dot product.
            rst[i, j] = (a[i, :] * b[:, j]).sum()
    return rst


matmulV2(m1, m2)


@cost_time
def matmulV3(a, b):
    """Multiply two 2-D tensors one output row at a time.

    Row ``i`` of the result comes from a single broadcast: ``a[i, :]`` is
    reshaped to a (c1, 1) column, multiplied with ``b`` (c1, c2), and the
    product is summed over its first axis.

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2) equal to ``a @ b``, allocated with the
        promoted dtype of ``a`` and ``b`` on ``a``'s device.

    Raises:
        ValueError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    r1, c1 = a.shape
    r2, c2 = b.shape

    if c1 != r2:
        raise ValueError(
            f"inner dimensions must match, got {tuple(a.shape)} @ {tuple(b.shape)}"
        )

    # Match the dtype/device a @ b would produce rather than float32 on the CPU.
    rst = torch.zeros(r1, c2, dtype=torch.result_type(a, b), device=a.device)

    for i in range(r1):
        # (c1, 1) * (c1, c2) -> (c1, c2); summing axis 0 gives row i, shape (c2,).
        rst[i] = (a[i, :].unsqueeze(-1) * b).sum(0)
    return rst


matmulV3(m1, m2)


@cost_time
def matmulV4(a, b):
    """Multiply two 2-D tensors with one broadcast multiply-and-sum, no Python loops.

    ``a`` is viewed as (r1, c1, 1) and ``b`` as (1, c1, c2). Their product has
    shape (r1, c1, c2) and is summed over the shared axis 1. This means an
    r1 * c1 * c2 intermediate tensor is materialized.

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2) equal to ``a @ b``.

    Raises:
        ValueError: if the inner dimensions of ``a`` and ``b`` differ
            (or, via tuple unpacking, if either input is not 2-D).
    """
    # Only the inner dimensions are needed; unpacking still enforces 2-D inputs.
    _, c1 = a.shape
    r2, _ = b.shape

    if c1 != r2:
        raise ValueError(
            f"inner dimensions must match, got {tuple(a.shape)} @ {tuple(b.shape)}"
        )

    return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(1)


matmulV4(m1, m2)


@cost_time
def matmulV5(a, b):
    """Multiply ``a`` by ``b`` with ``torch.einsum``.

    The equation contracts the shared index ``k``: out[i, j] = sum_k a[i, k] * b[k, j].
    """
    spec = "ik,kj->ij"
    return torch.einsum(spec, a, b)


matmulV5(m1, m2)


@cost_time
def matmulV6(a, b):
    """Multiply ``a`` by ``b`` using Python's matrix-multiplication operator ``@``."""
    product = a @ b
    return product


matmulV6(m1, m2)



0条评论
作者已关闭评论
Top123
32文章数
3粉丝数
Top123
32 文章 | 3 粉丝
Top123
32文章数
3粉丝数
Top123
32 文章 | 3 粉丝
原创

矩阵相乘6种优化方法

2024-09-25 09:31:37
6
0
import torch
import time

# Random benchmark operands: a (50, 784) left matrix and a (784, 10) right matrix.
# The tuple is evaluated left to right, so m1 is drawn before m2.
m1, m2 = torch.randn(50, 784), torch.randn(784, 10)


def cost_time(func):
    """Wrap *func* so every call prints how long it took.

    The wrapper forwards all positional and keyword arguments to *func*,
    measures the call with ``time.perf_counter`` and prints
    ``func <name> cost time:<seconds> s`` before returning *func*'s result
    unchanged.

    NOTE: PyTorch launches CUDA kernels asynchronously, so if the wrapped
    function works on GPU tensors the printed time may not include kernel
    execution unless the callee synchronizes.
    """
    import functools

    # wraps() copies __name__/__doc__ etc., so the decorated function keeps
    # its own identity instead of appearing as "fun".
    @functools.wraps(func)
    def fun(*args, **kwargs):
        t = time.perf_counter()
        result = func(*args, **kwargs)
        print(f'func {func.__name__} cost time:{time.perf_counter() - t:.8f} s')
        return result

    return fun


@cost_time
def matmulV1(a, b):
    """Multiply two 2-D tensors with three explicit Python loops (slowest baseline).

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2) equal to ``a @ b``, allocated with the
        promoted dtype of ``a`` and ``b`` on ``a``'s device.

    Raises:
        ValueError: if the inner dimensions of ``a`` and ``b`` differ
            (or, via tuple unpacking, if either input is not 2-D).
    """
    r1, c1 = a.shape
    r2, c2 = b.shape

    # An explicit raise is used instead of assert, which disappears under `python -O`.
    if c1 != r2:
        raise ValueError(
            f"inner dimensions must match, got {tuple(a.shape)} @ {tuple(b.shape)}"
        )

    # torch.zeros defaults to float32 on the CPU; match what a @ b would produce instead.
    rst = torch.zeros(r1, c2, dtype=torch.result_type(a, b), device=a.device)

    for i in range(r1):
        for j in range(c2):
            for k in range(c1):
                rst[i, j] += a[i, k] * b[k, j]
    return rst


matmulV1(m1, m2)


@cost_time
def matmulV2(a, b):
    """Multiply two 2-D tensors, vectorizing the innermost (k) loop.

    Each output element ``rst[i, j]`` is computed as the elementwise product
    of row ``a[i, :]`` and column ``b[:, j]`` summed in one tensor op.

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2) equal to ``a @ b``, allocated with the
        promoted dtype of ``a`` and ``b`` on ``a``'s device.

    Raises:
        ValueError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    r1, c1 = a.shape
    r2, c2 = b.shape

    if c1 != r2:
        raise ValueError(
            f"inner dimensions must match, got {tuple(a.shape)} @ {tuple(b.shape)}"
        )

    # Match the dtype/device a @ b would produce rather than float32 on the CPU.
    rst = torch.zeros(r1, c2, dtype=torch.result_type(a, b), device=a.device)

    for i in range(r1):
        for j in range(c2):
            # Changed from V1: the k-loop is replaced by a vectorized dot product.
            rst[i, j] = (a[i, :] * b[:, j]).sum()
    return rst


matmulV2(m1, m2)


@cost_time
def matmulV3(a, b):
    """Multiply two 2-D tensors one output row at a time.

    Row ``i`` of the result comes from a single broadcast: ``a[i, :]`` is
    reshaped to a (c1, 1) column, multiplied with ``b`` (c1, c2), and the
    product is summed over its first axis.

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2) equal to ``a @ b``, allocated with the
        promoted dtype of ``a`` and ``b`` on ``a``'s device.

    Raises:
        ValueError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    r1, c1 = a.shape
    r2, c2 = b.shape

    if c1 != r2:
        raise ValueError(
            f"inner dimensions must match, got {tuple(a.shape)} @ {tuple(b.shape)}"
        )

    # Match the dtype/device a @ b would produce rather than float32 on the CPU.
    rst = torch.zeros(r1, c2, dtype=torch.result_type(a, b), device=a.device)

    for i in range(r1):
        # (c1, 1) * (c1, c2) -> (c1, c2); summing axis 0 gives row i, shape (c2,).
        rst[i] = (a[i, :].unsqueeze(-1) * b).sum(0)
    return rst


matmulV3(m1, m2)


@cost_time
def matmulV4(a, b):
    """Multiply two 2-D tensors with one broadcast multiply-and-sum, no Python loops.

    ``a`` is viewed as (r1, c1, 1) and ``b`` as (1, c1, c2). Their product has
    shape (r1, c1, c2) and is summed over the shared axis 1. This means an
    r1 * c1 * c2 intermediate tensor is materialized.

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2) equal to ``a @ b``.

    Raises:
        ValueError: if the inner dimensions of ``a`` and ``b`` differ
            (or, via tuple unpacking, if either input is not 2-D).
    """
    # Only the inner dimensions are needed; unpacking still enforces 2-D inputs.
    _, c1 = a.shape
    r2, _ = b.shape

    if c1 != r2:
        raise ValueError(
            f"inner dimensions must match, got {tuple(a.shape)} @ {tuple(b.shape)}"
        )

    return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(1)


matmulV4(m1, m2)


@cost_time
def matmulV5(a, b):
    """Multiply ``a`` by ``b`` with ``torch.einsum``.

    The equation contracts the shared index ``k``: out[i, j] = sum_k a[i, k] * b[k, j].
    """
    spec = "ik,kj->ij"
    return torch.einsum(spec, a, b)


matmulV5(m1, m2)


@cost_time
def matmulV6(a, b):
    """Multiply ``a`` by ``b`` using Python's matrix-multiplication operator ``@``."""
    product = a @ b
    return product


matmulV6(m1, m2)



文章来自个人专栏
云原生最佳实践
32 文章 | 1 订阅
0条评论
作者已关闭评论
作者已关闭评论
0
0