searchusermenu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

pytorch中冻结训练(上)

2024-11-15 09:17:46
0
0
import torch.nn as nn
import torch.optim as optim
import torch


# 定义一个简单的网络

class MyNet(nn.Module):
def __init__(self, num_class=5):
super(MyNet, self).__init__()
self.fc1 = nn.Linear(8, 4)
self.fc1.weight = nn.Parameter(torch.ones((4,8), dtype=torch.float))
self.fc2 = nn.Linear(4, num_class)
self.fc2.weight = nn.Parameter(torch.ones((num_class,4), dtype=torch.float))

def forward(self, x):
return self.fc2(self.fc1(x))


model = MyNet()

loss_fn = nn.CrossEntropyLoss()

choice = 3
if choice == 1:
# 情况一:不冻结参数时
optimizer = optim.SGD(model.parameters(), lr=1e-2) # 传入的是所有的参数

if choice == 2:
# 情况二:采用方式一冻结fc1层时
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False
optimizer = optim.SGD(model.parameters(), lr=1e-2) # 优化器传入的是所有的参数

if choice == 3:
# 情况三:采用方式二冻结fc1层时, 优化器只传入fc2的参数
optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2)

if choice == 4:
# 情况4: 最优做法是将不更新的参数的requires_grad设置为False,同时不将该参数传入optimizer
# 冻结fc1层的参数
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False

# 定义一个 filter ,只传入requires_grad=True的模型参数
optimizer = optim.SGD(filter(lambda p : p.requires_grad, model.parameters()), lr=1e-2)


# 训练前的模型参数
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
x = torch.randn((3, 8))
label = torch.randint(0, 5, [3]).long()
output = model(x)

loss = loss_fn(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 训练后的模型参数
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

"""
结论
1. 最优写法能够节省显存和提升速度:
2. 节省显存:不将不更新的参数传入optimizer
3. 提升速度:将不更新的参数的requires_grad设置为False,节省了计算这部分参数梯度的时间

# element_size返回单个元素的字节大小,nelement返回元素个数
import torch
a = torch.zeros([128, 128])
print(a.element_size() * a.nelement())

# pytorch查看模型的参数总量、占用显存量以及flops
from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))
使用DNN_printer
"""
0条评论
作者已关闭评论
Top123
29文章数
3粉丝数
Top123
29 文章 | 3 粉丝
Top123
29文章数
3粉丝数
Top123
29 文章 | 3 粉丝
原创

pytorch中冻结训练(上)

2024-11-15 09:17:46
0
0
import torch.nn as nn
import torch.optim as optim
import torch


# 定义一个简单的网络

class MyNet(nn.Module):
def __init__(self, num_class=5):
super(MyNet, self).__init__()
self.fc1 = nn.Linear(8, 4)
self.fc1.weight = nn.Parameter(torch.ones((4,8), dtype=torch.float))
self.fc2 = nn.Linear(4, num_class)
self.fc2.weight = nn.Parameter(torch.ones((num_class,4), dtype=torch.float))

def forward(self, x):
return self.fc2(self.fc1(x))


model = MyNet()

loss_fn = nn.CrossEntropyLoss()

choice = 3
if choice == 1:
# 情况一:不冻结参数时
optimizer = optim.SGD(model.parameters(), lr=1e-2) # 传入的是所有的参数

if choice == 2:
# 情况二:采用方式一冻结fc1层时
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False
optimizer = optim.SGD(model.parameters(), lr=1e-2) # 优化器传入的是所有的参数

if choice == 3:
# 情况三:采用方式二冻结fc1层时, 优化器只传入fc2的参数
optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2)

if choice == 4:
# 情况4: 最优做法是将不更新的参数的requires_grad设置为False,同时不将该参数传入optimizer
# 冻结fc1层的参数
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False

# 定义一个 filter ,只传入requires_grad=True的模型参数
optimizer = optim.SGD(filter(lambda p : p.requires_grad, model.parameters()), lr=1e-2)


# 训练前的模型参数
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
x = torch.randn((3, 8))
label = torch.randint(0, 5, [3]).long()
output = model(x)

loss = loss_fn(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 训练后的模型参数
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

"""
结论
1. 最优写法能够节省显存和提升速度:
2. 节省显存:不将不更新的参数传入optimizer
3. 提升速度:将不更新的参数的requires_grad设置为False,节省了计算这部分参数梯度的时间

# element_size返回单个元素的字节大小,nelement返回元素个数
import torch
a = torch.zeros([128, 128])
print(a.element_size() * a.nelement())

# pytorch查看模型的参数总量、占用显存量以及flops
from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))
使用DNN_printer
"""
文章来自个人专栏
云原生最佳实践
29 文章 | 1 订阅
0条评论
作者已关闭评论
作者已关闭评论
0
0