基本思路
- 所有的数据都封装为一个 `Dataset`
- 负责加载数据的组件叫 `DataLoader`
- 每次加载的数据条数叫 `batch_size`
所以构造训练数据的步骤是:首先把数据加载为 `Dataset`,然后用 `DataLoader` 按批次依次把数据传递到模型中即可。
全部代码
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    """Dataset backed by a delimited text file of numeric rows.

    For every row, the first two columns form the feature vector and the
    third column is the label. The whole file is loaded into memory once,
    at construction time.
    """

    def __init__(self, path='./my_data.txt', delimiter=','):
        """Read the file and split it into feature and label tensors.

        Args:
            path: Text file to load. Defaults to './my_data.txt'.
            delimiter: Column separator handed to ``np.loadtxt``.
        """
        super().__init__()
        # ndmin=2 keeps a single-row file two-dimensional; without it
        # np.loadtxt returns a 1-D array and the column slices below fail.
        txt_data = np.loadtxt(path, delimiter=delimiter, ndmin=2)
        self._x = torch.from_numpy(txt_data[:, :2])
        self._y = torch.from_numpy(txt_data[:, 2])
        self._len = len(txt_data)

    def __getitem__(self, item):
        """Return the ``(features, label)`` pair stored at index ``item``."""
        return self._x[item], self._y[item]

    def __len__(self):
        """Return the number of rows that were loaded."""
        return self._len
# Wrap the dataset in a loader that yields fixed-size, in-order batches and
# drops a trailing batch that has fewer than three samples.
data = MyDataset()
dataloader = DataLoader(
    dataset=data,
    batch_size=3,
    shuffle=False,
    drop_last=True,
    num_workers=0,
)

# n stays 0 when the loader yields nothing; otherwise enumerate leaves it at
# the number of batches that were iterated.
n = 0
for n, (x_data, y_label) in enumerate(dataloader, start=1):
    print('x:', x_data)
    print("y:", y_label)
print('迭代次数:', n)