# 定义网络 (Define the network)
class LinearRegressionModel(nn.Module):
    """Linear regression as a single fully connected layer: ``y = x @ W.T + b``.

    Args:
        input_shape: number of input features (``in_features`` of the layer).
        output_shape: number of outputs (``out_features`` of the layer).
    """

    def __init__(self, input_shape, output_shape):
        super().__init__()
        # Stored as ``linear`` so parameter names stay ``linear.weight`` /
        # ``linear.bias`` (these appear in state_dict keys).
        self.linear = nn.Linear(input_shape, output_shape)

    def forward(self, x):
        """Apply the affine layer; the last dimension of ``x`` must equal ``input_shape``."""
        return self.linear(x)
# 全部代码 (Full code)
import torch
import torch.nn as nn
class LinearRegressionModel(nn.Module):
    """Linear regression as a single fully connected layer: ``y = x @ W.T + b``.

    Args:
        input_shape: number of input features (``in_features`` of the layer).
        output_shape: number of outputs (``out_features`` of the layer).
    """

    def __init__(self, input_shape, output_shape):
        super().__init__()
        # Stored as ``linear`` so parameter names stay ``linear.weight`` /
        # ``linear.bias`` (these appear in state_dict keys).
        self.linear = nn.Linear(input_shape, output_shape)

    def forward(self, x):
        """Apply the affine layer; the last dimension of ``x`` must equal ``input_shape``."""
        return self.linear(x)
def train(model, inputs, targets, epochs=1000, learning_rate=0.01, log_every=50):
    """Fit ``model`` to ``(inputs, targets)`` with full-batch SGD on MSE loss.

    Each epoch performs exactly one gradient step over the whole batch.

    Args:
        model: an ``nn.Module`` whose output on ``inputs`` has the same shape
            as ``targets``.
        inputs: training features, passed to ``model`` once per epoch.
        targets: regression targets compared against the model output.
        epochs: number of gradient steps to take.
        learning_rate: SGD step size.
        log_every: print ``epoch <n>, loss <value>`` every this many epochs
            (epochs are counted from 1); ``0`` or ``None`` disables printing.

    Returns:
        The MSE loss of the last epoch as a Python float (measured before
        that epoch's parameter update), or ``None`` if ``epochs < 1``.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    model.train()
    loss = None
    # Count epochs from 1 directly instead of incrementing the loop variable.
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # .item() only when logging, so no per-step host sync otherwise.
        if log_every and epoch % log_every == 0:
            print(f'epoch {epoch}, loss {loss.item()}')
    return None if loss is None else loss.item()


def predict(model, inputs):
    """Run ``model`` on ``inputs`` for inference and return a NumPy array.

    Leaves ``model`` in eval mode. Gradients are disabled, so no autograd
    graph is built (this replaces the legacy ``.data`` idiom).
    """
    model.eval()
    with torch.no_grad():
        return model(inputs).cpu().numpy()


def main():
    """Train on random data, then print predictions for fresh random inputs."""
    x_train = torch.randn(100, 4)
    y_train = torch.randn(100, 1)
    model = LinearRegressionModel(x_train.shape[1], 1)
    train(model, x_train, y_train, epochs=1000, learning_rate=0.01, log_every=50)
    # Reuse the training feature count rather than repeating the literal 4.
    predicted = predict(model, torch.randn(100, x_train.shape[1]))
    print(predicted)


if __name__ == '__main__':
    main()