import torch
x = torch.tensor([2.], requires_grad=True)
w = torch.tensor([1.], requires_grad=True)
# Method 1: build the graph once and call backward() repeatedly with retain_graph=True
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
y.backward(retain_graph=True)  # retain_graph=True keeps the saved tensors so the graph can be backwarded again
print(w.grad) # tensor([5.])
w.grad.zero_()  # clear the accumulated gradient before the next backward()
y.backward(retain_graph=True)
print(w.grad) # tensor([5.])
y.backward(retain_graph=True)
print(w.grad) # tensor([10.])
y.backward(retain_graph=True)
print(w.grad) # tensor([15.])
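# Why w.grad is 5: y = (w + x) * (w + 1) with w = 1 and x = 2, so
#   dy/dw = (w + 1) + (w + x) = 2w + x + 1 = 2*1 + 2 + 1 = 5
# Without zeroing in between, every backward() call adds another 5 to w.grad,
# which is why the printed values accumulate: 5, 10, 15.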
# Method 2: rebuild the graph on every iteration, zero the gradient, then call backward() without retain_graph
for _ in range(4):
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)
    w.grad.zero_()   # clear the gradient accumulated so far
    y.backward()     # the graph was just rebuilt, so retain_graph is not needed
    print(w.grad)    # tensor([5.]) on every iteration
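# Note: if Method 2 is run on its own, w.grad is still None before the first backward(),
# and w.grad.zero_() raises AttributeError. A defensive sketch is to guard the call:
#     if w.grad is not None:
#         w.grad.zero_()
# In a real training loop, optimizer.zero_grad() serves the same purpose.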
"""
y.backward()
y.backward()  # calling backward() a second time like this raises an error:
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed).
Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().
Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
"""