search user menu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

PyTorch 训练全流程

2024-11-14 09:41:58
3
0
import torch

# torch.seed() re-seeds the default RNG with a non-deterministic value and
# returns that seed; the return value is discarded here, so runs are NOT
# reproducible. NOTE(review): if reproducibility was intended, use
# torch.manual_seed(<fixed int>) instead.
torch.seed()


class PolRegression(torch.nn.Module):
    """Linear regression over caller-supplied polynomial features.

    The module does not build polynomial terms itself; it learns one weight
    per input feature plus a single bias.

    Args:
        input_dim: number of input features per sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The attribute name ``linear`` determines the state_dict keys
        # ("linear.weight", "linear.bias") stored in checkpoints.
        self.linear = torch.nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
        """Map a (..., input_dim) tensor to (..., 1) predictions."""
        out = self.linear(x)
        return out


def generate_data(batch_size=32):
    """Sample a synthetic batch for the cubic target y = 2x^3 + 3x^2 + 4x + 5.

    ``x`` is drawn uniformly from the integers -5..4 (``torch.randint``'s
    upper bound is exclusive) and converted to the default floating dtype.
    Gaussian noise with standard deviation 0.01 is added to the target.

    Args:
        batch_size: number of samples to draw.

    Returns:
        (x_data, y): ``x_data`` has shape (batch_size, 3) and holds the
        features [x, x**2, x**3]; ``y`` has shape (batch_size, 1).
    """
    # Integer-valued inputs; multiplying by 1.0 converts them to float.
    x = torch.randint(-5, 5, (batch_size,)) * 1.0
    x = x.unsqueeze(1)

    # The noise must be (batch_size, 1), not (batch_size,): adding a
    # (batch_size,) tensor to the (batch_size, 1) polynomial would
    # broadcast y to (batch_size, batch_size).
    noise = torch.randn(batch_size, 1) / 100
    y = 2.0 * x ** 3 + 3.0 * x ** 2 + 4.0 * x + 5.0 + noise

    x_data = torch.cat([x ** i for i in range(1, 4)], 1)
    return x_data, y


# Default checkpoint file, relative to the current working directory; rotated
# backups live next to it as poly-regression-model.<n>.pth.
checkpoint_path = "poly-regression-model.pth"


def train(num_epoch=1000):
    """Train a PolRegression model, resuming from the newest readable checkpoint.

    Each iteration ("epoch" here) draws one fresh batch of 32 samples from
    ``generate_data`` and takes one Adam step. A checkpoint is written every
    10 epochs. One more is written at the end, unless the loop has just
    saved that same epoch.

    Args:
        num_epoch: total epoch count to reach. If the resumed checkpoint is
            already at or past this epoch, no training steps are run.

    Side effects:
        Reads/writes ``checkpoint_path`` (and its rotated backups), prints
        progress, and exits the process with status 1 on every 10240th epoch
        to simulate a crash.
    """
    import sys  # sys.exit: the bare ``exit`` builtin is only injected by the site module

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = PolRegression(3).to(device)

    loss_func = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    start_epoch = read_checkpoint(model, optimizer, checkpoint_path)
    last_saved_epoch = None  # epoch written by the most recent in-loop save
    model.train()
    for i in range(start_epoch, num_epoch):
        epoch = i + 1
        x, y = generate_data(32)
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        predicted = model(x)
        loss = loss_func(predicted, y)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch} , loss {loss.item()}")

        if epoch % 10 == 0:
            write_checkpoint(epoch, model, optimizer, checkpoint_path)
            last_saved_epoch = epoch

        if epoch % 10240 == 0:
            print("just for simulator panic")  # simulated failure
            sys.exit(1)

    print(f'weight: {model.linear.weight.data}')
    print(f'bias: {model.linear.bias.data}')

    # Label the final checkpoint with the epoch the weights actually belong
    # to. If the resumed checkpoint was already past num_epoch, no training
    # ran, and labelling it num_epoch would rewind the epoch counter.
    final_epoch = max(start_epoch, num_epoch)
    # Avoid a duplicate save when the loop just wrote this epoch; the
    # duplicate would also rotate one older backup out of the history.
    if last_saved_epoch != final_epoch:
        write_checkpoint(final_epoch, model, optimizer, checkpoint_path)


def save_checkpoint(epoch, model, optimizer, path):
    """Serialize ``epoch`` plus model and optimizer state to ``path``.

    The checkpoint is a dict with keys ``"epoch"``, ``"model_state_dict"``
    and ``"optimizer_state_dict"``.

    When ``path`` is a filesystem path, the data is first written to
    ``path + ".tmp"`` and then renamed over ``path`` with ``os.replace``
    (atomic on POSIX). An interrupted save therefore never leaves a
    truncated checkpoint under the real name. Any other target (e.g. a
    file-like object) is passed straight to ``torch.save``.
    """
    import os  # local import: the module-level ``import os`` appears later in the file

    tmp_dict = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }

    if not isinstance(path, (str, bytes, os.PathLike)):
        torch.save(tmp_dict, path)
        return

    final_path = os.fsdecode(path)
    tmp_path = final_path + ".tmp"
    torch.save(tmp_dict, tmp_path)
    os.replace(tmp_path, final_path)


import os
import shutil


def rotate_save_file(file_path, max_versions=5):
    """Shift the numbered backups of ``file_path`` up one slot, then move
    ``file_path`` itself into slot 1.

    For ``name.ext`` the backups are ``name.1.ext`` ... ``name.<max_versions>.ext``.
    Whatever was in slot ``max_versions`` is overwritten (discarded). Does
    nothing when ``file_path`` does not exist.
    """
    if os.path.exists(file_path):
        directory, filename = os.path.split(file_path)
        stem, suffix = os.path.splitext(filename)

        def slot(n):
            # Path of backup number ``n``.
            return os.path.join(directory, f"{stem}.{n}{suffix}")

        # Walk from the highest slot downwards so each move lands on a slot
        # whose previous occupant has already been shifted up.
        for n in reversed(range(1, max_versions)):
            src = slot(n)
            if os.path.exists(src):
                shutil.move(src, slot(n + 1))

        shutil.move(file_path, slot(1))


def write_checkpoint(epoch, model, optimizer, path, max_versions=5):
    """Rotate the existing checkpoint at ``path`` into backup slot 1, then save a new one there.

    Args:
        epoch: epoch number stored in the checkpoint.
        model: object whose ``state_dict()`` is saved as the model state.
        optimizer: object whose ``state_dict()`` is saved as the optimizer state.
        path: checkpoint file to write. Its previous content becomes backup
            ``.1`` and older backups shift up one slot.
        max_versions: number of backup slots kept by the rotation.
    """
    # Rotate the file that is about to be overwritten, i.e. ``path`` itself,
    # not the module-level default checkpoint path.
    rotate_save_file(path, max_versions)
    save_checkpoint(epoch, model, optimizer, path)


def load_checkpoint(model, optimizer, path):
    """Restore model and optimizer state from the checkpoint at ``path``.

    Tensors are first loaded onto the CPU (``map_location="cpu"``), so a
    checkpoint written on a CUDA machine can still be read where CUDA is
    unavailable. ``load_state_dict`` then copies the weights into the
    model's existing parameters, and the optimizer moves its state to the
    devices of its parameters.

    Returns:
        The epoch number stored in the checkpoint.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing,
        truncated or incompatible file. Callers can catch these to fall back
        to an older backup.
    """
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]
    print(f"checkpoint loaded, starting at epoch {epoch}")
    return epoch


def read_checkpoint(model, optimizer, checkpoint_path, max_versions=5):
    """Load the newest checkpoint that can be read successfully.

    Candidates are tried in order: ``checkpoint_path`` itself, then the
    rotated backups ``<base>.1<ext>`` through ``<base>.<max_versions><ext>``.
    Missing files are skipped. A file that fails to load is reported, and
    the next candidate is tried.

    Returns:
        The epoch stored in the first checkpoint that loads, or 0 when none
        does (training then starts from scratch).
    """
    head, tail = os.path.split(checkpoint_path)
    base, ext = os.path.splitext(tail)

    candidates = [checkpoint_path]
    candidates += [os.path.join(head, f"{base}.{n}{ext}")
                   for n in range(1, max_versions + 1)]

    # filter() is lazy, so each existence check happens right before that
    # candidate is tried.
    for candidate in filter(os.path.exists, candidates):
        try:
            epoch = load_checkpoint(model, optimizer, candidate)
        except Exception as e:  # deliberate best-effort: fall back to the next backup
            print(f"{candidate} have exception {e}")
            continue
        print(f"load {candidate} success")
        return epoch

    print("not found checkpoint, starting training from scratch.")
    return 0


def test_eval():
    """Load the latest checkpoint and print predictions for x = 1, 2, 6.

    The inputs are expanded to the same [x, x**2, x**3] features used in
    training and created on the model's device before the forward pass.

    NOTE(review): the ``test_`` prefix means pytest will collect this
    function if the module is ever imported by pytest, and it will then read
    checkpoint files from the working directory.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = PolRegression(3).to(device)
    # read_checkpoint requires an optimizer to restore into; only the model
    # weights matter for evaluation.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    read_checkpoint(model, optimizer, checkpoint_path)
    model.eval()

    # The inputs must live on the same device as the model, otherwise the
    # forward pass fails whenever CUDA is available.
    data = torch.tensor([1.0, 2.0, 6.0], device=device).unsqueeze(1)
    x_data = torch.cat([data ** i for i in range(1, 4)], 1)
    with torch.no_grad():
        y_predict = model(x_data)
    print(y_predict)


# Script entry point: train up to 10000 epochs (resuming from a checkpoint if
# one exists), then print predictions from the trained model.
# NOTE(review): the simulated crash in train() fires every 10240 epochs, so it
# never triggers with num_epoch=10000.
if __name__ == "__main__":
    train(10000)
    test_eval()


# Scratch experiment kept as an unused module-level string literal. It shows
# why the noise term in generate_data has shape (batch_size, 1): adding a (10,)
# tensor to a (10, 1) tensor broadcasts to (10, 10), as the third recorded size
# below shows.
'''
import torch

batch_size = 10
x = torch.randint(-5, 5, (batch_size,)) * 1.0 # 改成float
x = x.unsqueeze(1)
print(x.size())

y = 2 * x
print(y.size())

y = 2 * x + torch.randn(batch_size)
print(y.size())

z = torch.randn((batch_size, 1))
print(z.size())
'''

# Output recorded from running the snippet above (also an unused string literal).
'''
torch.Size([10, 1])
torch.Size([10, 1])
torch.Size([10, 10])
torch.Size([10, 1])

'''
0条评论
作者已关闭评论
Top123
33文章数
3粉丝数
Top123
33 文章 | 3 粉丝
Top123
33文章数
3粉丝数
Top123
33 文章 | 3 粉丝
原创

PyTorch 训练全流程

2024-11-14 09:41:58
3
0
import torch

# torch.seed() re-seeds the default RNG with a non-deterministic value and
# returns that seed; the return value is discarded here, so runs are NOT
# reproducible. NOTE(review): if reproducibility was intended, use
# torch.manual_seed(<fixed int>) instead.
torch.seed()


class PolRegression(torch.nn.Module):
    """Linear regression over caller-supplied polynomial features.

    The module does not build polynomial terms itself; it learns one weight
    per input feature plus a single bias.

    Args:
        input_dim: number of input features per sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The attribute name ``linear`` determines the state_dict keys
        # ("linear.weight", "linear.bias") stored in checkpoints.
        self.linear = torch.nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
        """Map a (..., input_dim) tensor to (..., 1) predictions."""
        out = self.linear(x)
        return out


def generate_data(batch_size=32):
    """Sample a synthetic batch for the cubic target y = 2x^3 + 3x^2 + 4x + 5.

    ``x`` is drawn uniformly from the integers -5..4 (``torch.randint``'s
    upper bound is exclusive) and converted to the default floating dtype.
    Gaussian noise with standard deviation 0.01 is added to the target.

    Args:
        batch_size: number of samples to draw.

    Returns:
        (x_data, y): ``x_data`` has shape (batch_size, 3) and holds the
        features [x, x**2, x**3]; ``y`` has shape (batch_size, 1).
    """
    # Integer-valued inputs; multiplying by 1.0 converts them to float.
    x = torch.randint(-5, 5, (batch_size,)) * 1.0
    x = x.unsqueeze(1)

    # The noise must be (batch_size, 1), not (batch_size,): adding a
    # (batch_size,) tensor to the (batch_size, 1) polynomial would
    # broadcast y to (batch_size, batch_size).
    noise = torch.randn(batch_size, 1) / 100
    y = 2.0 * x ** 3 + 3.0 * x ** 2 + 4.0 * x + 5.0 + noise

    x_data = torch.cat([x ** i for i in range(1, 4)], 1)
    return x_data, y


# Default checkpoint file, relative to the current working directory; rotated
# backups live next to it as poly-regression-model.<n>.pth.
checkpoint_path = "poly-regression-model.pth"


def train(num_epoch=1000):
    """Train a PolRegression model, resuming from the newest readable checkpoint.

    Each iteration ("epoch" here) draws one fresh batch of 32 samples from
    ``generate_data`` and takes one Adam step. A checkpoint is written every
    10 epochs. One more is written at the end, unless the loop has just
    saved that same epoch.

    Args:
        num_epoch: total epoch count to reach. If the resumed checkpoint is
            already at or past this epoch, no training steps are run.

    Side effects:
        Reads/writes ``checkpoint_path`` (and its rotated backups), prints
        progress, and exits the process with status 1 on every 10240th epoch
        to simulate a crash.
    """
    import sys  # sys.exit: the bare ``exit`` builtin is only injected by the site module

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = PolRegression(3).to(device)

    loss_func = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    start_epoch = read_checkpoint(model, optimizer, checkpoint_path)
    last_saved_epoch = None  # epoch written by the most recent in-loop save
    model.train()
    for i in range(start_epoch, num_epoch):
        epoch = i + 1
        x, y = generate_data(32)
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        predicted = model(x)
        loss = loss_func(predicted, y)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch} , loss {loss.item()}")

        if epoch % 10 == 0:
            write_checkpoint(epoch, model, optimizer, checkpoint_path)
            last_saved_epoch = epoch

        if epoch % 10240 == 0:
            print("just for simulator panic")  # simulated failure
            sys.exit(1)

    print(f'weight: {model.linear.weight.data}')
    print(f'bias: {model.linear.bias.data}')

    # Label the final checkpoint with the epoch the weights actually belong
    # to. If the resumed checkpoint was already past num_epoch, no training
    # ran, and labelling it num_epoch would rewind the epoch counter.
    final_epoch = max(start_epoch, num_epoch)
    # Avoid a duplicate save when the loop just wrote this epoch; the
    # duplicate would also rotate one older backup out of the history.
    if last_saved_epoch != final_epoch:
        write_checkpoint(final_epoch, model, optimizer, checkpoint_path)


def save_checkpoint(epoch, model, optimizer, path):
    """Serialize ``epoch`` plus model and optimizer state to ``path``.

    The checkpoint is a dict with keys ``"epoch"``, ``"model_state_dict"``
    and ``"optimizer_state_dict"``.

    When ``path`` is a filesystem path, the data is first written to
    ``path + ".tmp"`` and then renamed over ``path`` with ``os.replace``
    (atomic on POSIX). An interrupted save therefore never leaves a
    truncated checkpoint under the real name. Any other target (e.g. a
    file-like object) is passed straight to ``torch.save``.
    """
    import os  # local import: the module-level ``import os`` appears later in the file

    tmp_dict = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }

    if not isinstance(path, (str, bytes, os.PathLike)):
        torch.save(tmp_dict, path)
        return

    final_path = os.fsdecode(path)
    tmp_path = final_path + ".tmp"
    torch.save(tmp_dict, tmp_path)
    os.replace(tmp_path, final_path)


import os
import shutil


def rotate_save_file(file_path, max_versions=5):
    """Shift the numbered backups of ``file_path`` up one slot, then move
    ``file_path`` itself into slot 1.

    For ``name.ext`` the backups are ``name.1.ext`` ... ``name.<max_versions>.ext``.
    Whatever was in slot ``max_versions`` is overwritten (discarded). Does
    nothing when ``file_path`` does not exist.
    """
    if os.path.exists(file_path):
        directory, filename = os.path.split(file_path)
        stem, suffix = os.path.splitext(filename)

        def slot(n):
            # Path of backup number ``n``.
            return os.path.join(directory, f"{stem}.{n}{suffix}")

        # Walk from the highest slot downwards so each move lands on a slot
        # whose previous occupant has already been shifted up.
        for n in reversed(range(1, max_versions)):
            src = slot(n)
            if os.path.exists(src):
                shutil.move(src, slot(n + 1))

        shutil.move(file_path, slot(1))


def write_checkpoint(epoch, model, optimizer, path, max_versions=5):
    """Rotate the existing checkpoint at ``path`` into backup slot 1, then save a new one there.

    Args:
        epoch: epoch number stored in the checkpoint.
        model: object whose ``state_dict()`` is saved as the model state.
        optimizer: object whose ``state_dict()`` is saved as the optimizer state.
        path: checkpoint file to write. Its previous content becomes backup
            ``.1`` and older backups shift up one slot.
        max_versions: number of backup slots kept by the rotation.
    """
    # Rotate the file that is about to be overwritten, i.e. ``path`` itself,
    # not the module-level default checkpoint path.
    rotate_save_file(path, max_versions)
    save_checkpoint(epoch, model, optimizer, path)


def load_checkpoint(model, optimizer, path):
    """Restore model and optimizer state from the checkpoint at ``path``.

    Tensors are first loaded onto the CPU (``map_location="cpu"``), so a
    checkpoint written on a CUDA machine can still be read where CUDA is
    unavailable. ``load_state_dict`` then copies the weights into the
    model's existing parameters, and the optimizer moves its state to the
    devices of its parameters.

    Returns:
        The epoch number stored in the checkpoint.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing,
        truncated or incompatible file. Callers can catch these to fall back
        to an older backup.
    """
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]
    print(f"checkpoint loaded, starting at epoch {epoch}")
    return epoch


def read_checkpoint(model, optimizer, checkpoint_path, max_versions=5):
    """Load the newest checkpoint that can be read successfully.

    Candidates are tried in order: ``checkpoint_path`` itself, then the
    rotated backups ``<base>.1<ext>`` through ``<base>.<max_versions><ext>``.
    Missing files are skipped. A file that fails to load is reported, and
    the next candidate is tried.

    Returns:
        The epoch stored in the first checkpoint that loads, or 0 when none
        does (training then starts from scratch).
    """
    head, tail = os.path.split(checkpoint_path)
    base, ext = os.path.splitext(tail)

    candidates = [checkpoint_path]
    candidates += [os.path.join(head, f"{base}.{n}{ext}")
                   for n in range(1, max_versions + 1)]

    # filter() is lazy, so each existence check happens right before that
    # candidate is tried.
    for candidate in filter(os.path.exists, candidates):
        try:
            epoch = load_checkpoint(model, optimizer, candidate)
        except Exception as e:  # deliberate best-effort: fall back to the next backup
            print(f"{candidate} have exception {e}")
            continue
        print(f"load {candidate} success")
        return epoch

    print("not found checkpoint, starting training from scratch.")
    return 0


def test_eval():
    """Load the latest checkpoint and print predictions for x = 1, 2, 6.

    The inputs are expanded to the same [x, x**2, x**3] features used in
    training and created on the model's device before the forward pass.

    NOTE(review): the ``test_`` prefix means pytest will collect this
    function if the module is ever imported by pytest, and it will then read
    checkpoint files from the working directory.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = PolRegression(3).to(device)
    # read_checkpoint requires an optimizer to restore into; only the model
    # weights matter for evaluation.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    read_checkpoint(model, optimizer, checkpoint_path)
    model.eval()

    # The inputs must live on the same device as the model, otherwise the
    # forward pass fails whenever CUDA is available.
    data = torch.tensor([1.0, 2.0, 6.0], device=device).unsqueeze(1)
    x_data = torch.cat([data ** i for i in range(1, 4)], 1)
    with torch.no_grad():
        y_predict = model(x_data)
    print(y_predict)


# Script entry point: train up to 10000 epochs (resuming from a checkpoint if
# one exists), then print predictions from the trained model.
# NOTE(review): the simulated crash in train() fires every 10240 epochs, so it
# never triggers with num_epoch=10000.
if __name__ == "__main__":
    train(10000)
    test_eval()


# Scratch experiment kept as an unused module-level string literal. It shows
# why the noise term in generate_data has shape (batch_size, 1): adding a (10,)
# tensor to a (10, 1) tensor broadcasts to (10, 10), as the third recorded size
# below shows.
'''
import torch

batch_size = 10
x = torch.randint(-5, 5, (batch_size,)) * 1.0 # 改成float
x = x.unsqueeze(1)
print(x.size())

y = 2 * x
print(y.size())

y = 2 * x + torch.randn(batch_size)
print(y.size())

z = torch.randn((batch_size, 1))
print(z.size())
'''

# Output recorded from running the snippet above (also an unused string literal).
'''
torch.Size([10, 1])
torch.Size([10, 1])
torch.Size([10, 10])
torch.Size([10, 1])

'''
文章来自个人专栏
云原生最佳实践
33 文章 | 1 订阅
0条评论
作者已关闭评论
作者已关闭评论
0
0