import torch
# Re-seed torch's global RNG with a non-deterministic value at import time,
# so every run draws different training data.
torch.seed()
class PolRegression(torch.nn.Module):
    """Regression model made of a single linear layer with one output.

    The polynomial terms are expected to be precomputed by the caller and
    passed in as the ``input_dim`` feature columns.
    """

    def __init__(self, input_dim):
        """Build the layer mapping ``input_dim`` features to one value."""
        super().__init__()
        # Attribute name ``linear`` is part of the state-dict keys.
        self.linear = torch.nn.Linear(input_dim, 1)

    def forward(self, x):
        """Return the linear layer's prediction for the feature batch ``x``."""
        output = self.linear(x)
        return output
def generate_data(batch_size=32):
    """Create one random batch for fitting y = 2x^3 + 3x^2 + 4x + 5.

    Args:
        batch_size: Number of samples to draw.

    Returns:
        A tuple ``(x_data, y)``. ``x_data`` has shape ``(batch_size, 3)``
        with columns ``[x, x^2, x^3]``; ``y`` has shape ``(batch_size, 1)``
        and holds the polynomial value plus small Gaussian noise.
    """
    # Integer-valued samples from [-5, 5) (upper bound is exclusive);
    # multiplying by 1.0 converts them to float.
    x = torch.randint(-5, 5, (batch_size,)) * 1.0
    x = x.unsqueeze(1)
    # The noise is created as (batch_size, 1) on purpose: a 1-D noise tensor
    # would broadcast against x and yield a (batch_size, batch_size) result.
    y = 2.0 * x ** 3 + 3.0 * x ** 2 + 4.0 * x + 5.0 + torch.randn(batch_size, 1) / 100
    x_data = torch.cat([x ** i for i in range(1, 4)], 1)
    return x_data, y
# Location of the newest checkpoint; older rotated copies get a numeric
# suffix inserted before the extension (e.g. ``poly-regression-model.1.pth``).
checkpoint_path = 'poly-regression-model.pth'
def train(num_epoch=1000):
    """Train the polynomial regression model, resuming from a checkpoint.

    The epoch to start from is whatever ``read_checkpoint`` returns (0 when
    no checkpoint could be loaded). A checkpoint is written every 10 epochs
    and once more after the loop finishes.

    Args:
        num_epoch: Epoch number at which training stops (exclusive upper
            bound of the loop index).

    Raises:
        SystemExit: With code 1 when the epoch counter hits a multiple of
            10240, to simulate a crash for checkpoint-recovery testing.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = PolRegression(3).to(device)
    loss_func = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    start_epoch = read_checkpoint(model, optimizer, checkpoint_path)

    for i in range(start_epoch, num_epoch):
        x, y = generate_data(32)
        x = x.to(device)
        y = y.to(device)

        predicted = model(x)
        loss = loss_func(predicted, y)
        loss.backward()
        optimizer.step()
        # Clear gradients so they do not accumulate into the next step.
        optimizer.zero_grad()
        print(f"Epoch {i + 1} , loss {loss.item()}")

        if (i + 1) % 10 == 0:
            write_checkpoint(i + 1, model, optimizer, checkpoint_path)
        if (i + 1) % 10240 == 0:
            print("just for simulator panic")  # simulate a failure
            # SystemExit is used instead of the site-provided exit() helper,
            # which is not guaranteed to exist in every interpreter setup.
            raise SystemExit(1)

    print(f'weight: {model.linear.weight.data}')
    print(f'bias: {model.linear.bias.data}')
    write_checkpoint(num_epoch, model, optimizer, checkpoint_path)
def save_checkpoint(epoch, model, optimizer, path):
    """Write the training state to ``path`` using ``torch.save``.

    The saved object is a dict with the keys ``"epoch"``,
    ``"model_state_dict"`` and ``"optimizer_state_dict"``.
    """
    state = {}
    state["epoch"] = epoch
    state["model_state_dict"] = model.state_dict()
    state["optimizer_state_dict"] = optimizer.state_dict()
    torch.save(state, path)
import os
import shutil
def rotate_save_file(file_path, max_versions=5):
    """Shift numbered backups of ``file_path`` up by one and free its name.

    For ``name.ext`` the backups are ``name.1.ext`` (newest) through
    ``name.<max_versions>.ext`` (oldest). Each existing backup ``i`` is
    renamed to ``i + 1``, overwriting whatever was there, and finally
    ``file_path`` itself becomes backup 1. Nothing happens when
    ``file_path`` does not exist.

    Args:
        file_path: Path of the current file to rotate away.
        max_versions: Highest backup number to keep.
    """
    if not os.path.exists(file_path):
        return

    # Split into directory, base name and extension to build backup names.
    file_dir, file_name = os.path.split(file_path)
    file_base, file_ext = os.path.splitext(file_name)

    def versioned(number):
        return os.path.join(file_dir, f"{file_base}.{number}{file_ext}")

    # Walk from the oldest slot downwards so no backup is clobbered before it
    # has been moved. os.replace overwrites an existing destination on every
    # platform, whereas a plain rename fails on Windows in that case.
    for i in range(max_versions - 1, 0, -1):
        old_file = versioned(i)
        if os.path.exists(old_file):
            os.replace(old_file, versioned(i + 1))

    # The current file becomes backup number 1.
    os.replace(file_path, versioned(1))
def write_checkpoint(epoch, model, optimizer, path):
    """Rotate any existing checkpoint at ``path``, then save a new one there.

    Args:
        epoch: Epoch number to record in the checkpoint.
        model: Model whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        path: Destination of the new checkpoint file.
    """
    # Rotate the file we are about to overwrite: the ``path`` argument,
    # not the module-level default location.
    rotate_save_file(path)
    save_checkpoint(epoch, model, optimizer, path)
def load_checkpoint(model, optimizer, path):
    """Restore model and optimizer state from the checkpoint at ``path``.

    Args:
        model: Module that receives ``checkpoint["model_state_dict"]``.
        optimizer: Optimizer that receives
            ``checkpoint["optimizer_state_dict"]``.
        path: Checkpoint file to read.

    Returns:
        The epoch number stored in the checkpoint.

    Raises:
        Whatever ``torch.load`` or ``load_state_dict`` raise for a missing,
        corrupt or incompatible file.
    """
    # Load tensors onto the CPU first so a checkpoint written on a CUDA
    # machine can still be read on a CPU-only one; load_state_dict then
    # copies the values onto whatever device the parameters live on.
    # NOTE(review): torch.load unpickles the file; only load trusted files.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]
    print(f"checkpoint loaded, starting at epoch {epoch}")
    return epoch
def read_checkpoint(model, optimizer, checkpoint_path, max_versions=5):
    """Load the newest usable checkpoint and return its epoch.

    Candidates are tried in order: ``checkpoint_path`` itself, then the
    rotated copies numbered 1 through ``max_versions``. A candidate that
    does not exist is skipped; one that fails to load is reported and the
    next one is tried.

    Returns:
        The epoch returned by ``load_checkpoint`` for the first candidate
        that loads, or 0 when none does.
    """
    folder, filename = os.path.split(checkpoint_path)
    stem, suffix = os.path.splitext(filename)

    for version in range(max_versions + 1):
        # Version 0 stands for the un-numbered, most recent file.
        if version == 0:
            candidate = checkpoint_path
        else:
            candidate = os.path.join(folder, f"{stem}.{version}{suffix}")

        if not os.path.exists(candidate):
            continue

        try:
            start_epoch = load_checkpoint(model, optimizer, candidate)
        except Exception as e:
            # Best effort: fall back to the next older copy.
            print(f"{candidate} have exception {e}")
            continue

        print(f"load {candidate} success")
        return start_epoch

    print("not found checkpoint, starting training from scratch.")
    return 0
def test_eval():
    """Load the latest checkpoint and print predictions for x = 1, 2 and 6.

    An optimizer is created only because ``read_checkpoint`` restores
    optimizer state alongside the model; it is not used afterwards.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = PolRegression(3).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    read_checkpoint(model, optimizer, checkpoint_path)
    model.eval()

    # Create the inputs on the model's device; a CPU tensor fed to a CUDA
    # model raises a device-mismatch error.
    data = torch.tensor([1.0, 2.0, 6.0], device=device).unsqueeze(1)
    # Same feature layout as in training: columns [x, x^2, x^3].
    x_data = torch.cat([data ** i for i in range(1, 4)], 1)
    with torch.no_grad():
        y_predict = model(x_data)
    print(y_predict)
if __name__ == "__main__":
    # Script entry point: train up to epoch 10000, then run the evaluation.
    train(num_epoch=10000)
    test_eval()
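'''
Scratch experiment (not executed): adding 1-D noise of shape (batch_size,)
to a column vector of shape (batch_size, 1) broadcasts the result to
(batch_size, batch_size).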
import torch
batch_size = 10
x = torch.randint(-5, 5, (batch_size,)) * 1.0 # convert to float
x = x.unsqueeze(1)
print(x.size())
y = 2 * x
print(y.size())
y = 2 * x + torch.randn(batch_size)
print(y.size())
z = torch.randn((batch_size, 1))
print(z.size())
'''
'''
Output of the scratch experiment above (third line shows the unwanted broadcast):
torch.Size([10, 1])
torch.Size([10, 1])
torch.Size([10, 10])
torch.Size([10, 1])
'''