import functools
import time

import torch
# Sample operands shared by the benchmark calls: a (50, 784) matrix and a
# (784, 10) matrix of standard-normal random values (m1 is drawn first).
m1, m2 = torch.randn(50, 784), torch.randn(784, 10)
def cost_time(func):
    """Decorate *func* so that every call prints how long it took.

    The elapsed time is measured with ``time.perf_counter`` around the call
    and printed with eight decimal places. The wrapped function's return
    value is passed through unchanged. If *func* raises, the exception
    propagates and nothing is printed.

    Args:
        func: The callable to time.

    Returns:
        A wrapper with the same call signature as *func*.
    """

    # functools.wraps copies __name__, __doc__, etc. onto the wrapper (and sets
    # __wrapped__), so decorated functions keep their identity.
    @functools.wraps(func)
    def fun(*args, **kwargs):
        t = time.perf_counter()
        result = func(*args, **kwargs)
        print(f'func {func.__name__} cost time:{time.perf_counter() - t:.8f} s')
        return result

    return fun
@cost_time
def matmulV1(a, b):
    """Multiply two 2-D tensors with three explicit Python loops.

    This is the naive baseline. Each output element is accumulated one
    scalar product at a time, for r1 * c2 * c1 Python-level iterations.

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2). Its dtype is ``torch.result_type(a, b)``
        and it lives on ``a.device``.

    Raises:
        ValueError: if the inner dimensions differ, or if either input is
            not 2-D (the shape unpacking fails).
    """
    r1, c1 = a.shape
    r2, c2 = b.shape
    # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
    # With c1 < r2 the k-loop would then silently ignore the last rows of b.
    if c1 != r2:
        raise ValueError(f"matmul shape mismatch: {tuple(a.shape)} @ {tuple(b.shape)}")
    # Follow the inputs' dtype/device instead of the global float32/CPU default.
    rst = torch.zeros(r1, c2, dtype=torch.result_type(a, b), device=a.device)
    for i in range(r1):
        for j in range(c2):
            for k in range(c1):
                # One multi-dimensional index instead of chained a[i][k] lookups.
                rst[i, j] += a[i, k] * b[k, j]
    return rst
# Time the naive triple-loop version on the sample operands; the product is discarded.
matmulV1(a=m1, b=m2)
@cost_time
def matmulV2(a, b):
    """Multiply two 2-D tensors, vectorising the innermost loop.

    Each output element is computed as one dot product: the elementwise
    product of row ``a[i, :]`` and column ``b[:, j]``, then summed.

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2). Its dtype is ``torch.result_type(a, b)``
        and it lives on ``a.device``.

    Raises:
        ValueError: if the inner dimensions differ, or if either input is
            not 2-D (the shape unpacking fails).
    """
    r1, c1 = a.shape
    r2, c2 = b.shape
    # Explicit check rather than ``assert``, which ``python -O`` strips.
    if c1 != r2:
        raise ValueError(f"matmul shape mismatch: {tuple(a.shape)} @ {tuple(b.shape)}")
    # Follow the inputs' dtype/device instead of the global float32/CPU default.
    rst = torch.zeros(r1, c2, dtype=torch.result_type(a, b), device=a.device)
    for i in range(r1):
        for j in range(c2):
            # Changed here: the whole k-loop becomes one vectorised dot product.
            rst[i, j] = (a[i, :] * b[:, j]).sum()
    return rst
# Time the row-times-column dot-product version on the sample operands.
matmulV2(a=m1, b=m2)
@cost_time
def matmulV3(a, b):
    """Multiply two 2-D tensors, computing one full output row per step.

    Row ``a[i, :]`` is reshaped to a (c1, 1) column. It is broadcast against
    ``b`` of shape (c1, c2) and summed over axis 0, which yields all c2
    entries of output row i at once.

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2). Its dtype is ``torch.result_type(a, b)``
        and it lives on ``a.device``.

    Raises:
        ValueError: if the inner dimensions differ, or if either input is
            not 2-D (the shape unpacking fails).
    """
    r1, c1 = a.shape
    r2, c2 = b.shape
    # Explicit check rather than ``assert``: under ``python -O`` a c1 == 1
    # input would broadcast against b and silently give a wrong result.
    if c1 != r2:
        raise ValueError(f"matmul shape mismatch: {tuple(a.shape)} @ {tuple(b.shape)}")
    # Follow the inputs' dtype/device instead of the global float32/CPU default.
    rst = torch.zeros(r1, c2, dtype=torch.result_type(a, b), device=a.device)
    for i in range(r1):
        rst[i] = (a[i, :].unsqueeze(-1) * b).sum(0)
    return rst
# Time the one-row-per-iteration broadcast version on the sample operands.
matmulV3(a=m1, b=m2)
@cost_time
def matmulV4(a, b):
    """Multiply two 2-D tensors with a single broadcast and no Python loops.

    ``a`` is viewed as (r1, c1, 1) and ``b`` as (1, c1, c2). Their
    elementwise product, of shape (r1, c1, c2), is summed over axis 1. The
    intermediate holds r1 * c1 * c2 elements, so memory grows with all three
    dimensions.

    Args:
        a: 2-D tensor of shape (r1, c1).
        b: 2-D tensor of shape (c1, c2).

    Returns:
        Tensor of shape (r1, c2).

    Raises:
        ValueError: if the inner dimensions differ, or if either input is
            not 2-D (the shape unpacking fails).
    """
    # Only the inner dimensions are needed; unpacking still enforces 2-D inputs.
    _, c1 = a.shape
    r2, _ = b.shape
    # Explicit check rather than ``assert``: under ``python -O`` c1 == 1 or
    # r2 == 1 would broadcast and silently produce a wrong result.
    if c1 != r2:
        raise ValueError(f"matmul shape mismatch: {tuple(a.shape)} @ {tuple(b.shape)}")
    return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(1)
# Time the fully broadcast (loop-free) version on the sample operands.
matmulV4(a=m1, b=m2)
@cost_time
def matmulV5(a, b):
    """Compute the matrix product as an Einstein summation.

    The subscripts ``ik,kj->ij`` contract the shared index k of two 2-D
    operands, which gives the ordinary (r1, c1) @ (c1, c2) product.
    """
    product = torch.einsum("ik,kj->ij", a, b)
    return product
# Time the einsum version on the sample operands.
matmulV5(a=m1, b=m2)
@cost_time
def matmulV6(a, b):
    """Compute the matrix product with PyTorch's built-in matmul.

    For tensor operands this is the same operation the ``@`` operator
    dispatches to.
    """
    return torch.matmul(a, b)
# Time the built-in matmul on the sample operands as the reference point.
matmulV6(a=m1, b=m2)
# (scraped web-page footer: "0条评论" = "0 comments")