How to execute backward() multiple times

2024-11-06 10:00:17
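
By default, PyTorch frees the intermediate values saved in the computation graph as soon as backward() is called, so running backward() a second time on the same graph raises a RuntimeError. The script below shows two ways around this: Method 1 keeps the graph alive with retain_graph=True, and Method 2 rebuilds the graph on every iteration so each backward() works on a fresh graph.
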
import torch

x = torch.tensor([2.], requires_grad=True)
w = torch.tensor([1.], requires_grad=True)

# Method 1: pass retain_graph=True so the graph is kept and backward() can be called again
a = torch.add(w, x)   # a = w + x = 3
b = torch.add(w, 1)   # b = w + 1 = 2
y = torch.mul(a, b)   # y = a * b, so dy/dw = a + b = 5

y.backward(retain_graph=True)
print(w.grad)  # tensor([5.])

w.grad.data.zero_()  # clear the accumulated gradient before the next backward
y.backward(retain_graph=True)
print(w.grad)  # tensor([5.])

y.backward(retain_graph=True)
print(w.grad)  # tensor([10.])  without zeroing, gradients accumulate

y.backward(retain_graph=True)
print(w.grad)  # tensor([15.])  another 5 is added on every call
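
# Side note (not in the original post): torch.autograd.grad also needs retain_graph=True
# to be called more than once on the same graph; unlike backward(), it returns the gradient
# instead of accumulating it into w.grad, which avoids the manual zeroing above.
g1 = torch.autograd.grad(y, w, retain_graph=True)  # (tensor([5.]),)
g2 = torch.autograd.grad(y, w, retain_graph=True)  # (tensor([5.]),)  no accumulation
print(g1, g2)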

# Method 2: rebuild the computation graph on every iteration, so retain_graph is not needed
for _ in range(4):
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    w.grad.zero_()  # w.grad already exists here because Method 1 ran first
    y.backward()
    print(w.grad)   # tensor([5.]) on every iteration
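
# If Method 2 is run on its own (without Method 1 first), w.grad is still None before the
# first backward(), so w.grad.zero_() would raise an AttributeError. A minimal sketch of a
# guard for that case:
for _ in range(4):
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    if w.grad is not None:  # grad is None until the first backward() populates it
        w.grad.zero_()
    y.backward()
    print(w.grad)  # tensor([5.]) on every iteration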

"""
y.backward()
y.backward() 2次执行 backward()会报错
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed).
Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().
Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

"""