import torch.nn as nn
import torch.optim as optim
import torch
# Define a simple network
class MyNet(nn.Module):
    def __init__(self, num_class=5):
        super(MyNet, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc1.weight = nn.Parameter(torch.ones((4, 8), dtype=torch.float))
        self.fc2 = nn.Linear(4, num_class)
        self.fc2.weight = nn.Parameter(torch.ones((num_class, 4), dtype=torch.float))

    def forward(self, x):
        return self.fc2(self.fc1(x))
model = MyNet()
loss_fn = nn.CrossEntropyLoss()
choice = 3
if choice == 1:
    # Case 1: no parameters are frozen
    optimizer = optim.SGD(model.parameters(), lr=1e-2)  # all parameters are passed to the optimizer
if choice == 2:
    # Case 2 (method 1): freeze the fc1 layer by setting requires_grad=False
    for name, param in model.named_parameters():
        if "fc1" in name:
            param.requires_grad = False
    optimizer = optim.SGD(model.parameters(), lr=1e-2)  # the optimizer still receives all parameters
if choice == 3:
    # Case 3 (method 2): freeze the fc1 layer by passing only fc2's parameters to the optimizer
    optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2)
if choice == 4:
    # Case 4: the best practice is to set requires_grad=False on the parameters that
    # should not be updated AND to leave them out of the optimizer
    # Freeze the fc1 layer
    for name, param in model.named_parameters():
        if "fc1" in name:
            param.requires_grad = False
    # Use filter so that only parameters with requires_grad=True are passed to the optimizer
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-2)
# Model parameters before training
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0, 5, [3]).long()
    output = model(x)
    loss = loss_fn(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Model parameters after training
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
"""
结论
1. 最优写法能够节省显存和提升速度:
2. 节省显存:不将不更新的参数传入optimizer
3. 提升速度:将不更新的参数的requires_grad设置为False,节省了计算这部分参数梯度的时间
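
# Both effects can be observed directly (a sketch, reusing the model/optimizer from above):
# with requires_grad=False (Cases 2 and 4) backward() never populates .grad for fc1,
# and the optimizer only holds (and allocates state for) the parameters it was given.
for name, param in model.named_parameters():
    print(name, "requires_grad =", param.requires_grad, "| grad is None:", param.grad is None)
n_opt = sum(p.numel() for group in optimizer.param_groups for p in group["params"])
n_all = sum(p.numel() for p in model.parameters())
print(f"parameters held by the optimizer: {n_opt} / {n_all}")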
# element_size() returns the size in bytes of a single element; nelement() returns the number of elements
a = torch.zeros([128, 128])
print(a.element_size() * a.nelement())  # 128 * 128 * 4 bytes = 65536
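
# The same arithmetic gives a rough estimate of a model's parameter memory
# (parameters only; gradients, optimizer state and activations are extra):
param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
print(f"MyNet parameter memory: {param_bytes} bytes")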
# Inspect a model's total parameter count, memory usage and FLOPs with torchstat
from torchstat import stat
import torchvision.models as models

resnet = models.resnet152()
stat(resnet, (3, 224, 224))

# DNN_printer can be used for a similar per-layer summary