构造迭代读取的Dataloader
,首先需要可迭代的DataSet
,这一部分详细请参考:pytorch构造可迭代的Dataset——IterableDataset(pytorch Data学习二),下面直接开始封装到DataLoader
中
文章目录
- 封装IterableDataset到DataLoader
-
- 1. 一般文本封装方法
- 2. pandas read_xxx封装方法
封装IterableDataset到DataLoader
比如文件test_csv.csv
:
1,2,3,4,1
1,2,3,4,2
1,2,3,4,3
1,2,3,4,4
1,2,3,4,5
1. 一般文本封装方法
由于DataLoader得到的迭代数据都是Tensor
格式的数据,因此需要将文本转换为tensor格式,修改dataset的__iter__
方法为:
import torch
from torch.utils.data import IterableDataset, DataLoader
import numpy as np
class MyIterableDataset(IterableDataset):
def __init__(self, file_path):
self.file_path = file_path
def __iter__(self):
with open(self.file_path, 'r') as file_obj:
for line in file_obj:
line_data = line.strip('\n').split(',')
yield torch.from_numpy(np.array(line_data, dtype='int')) # 这里按照自己的代码看格式哈
然后封装即可:
if __name__ == '__main__':
dataset = MyIterableDataset('test_csv.csv')
dataloader = DataLoader(dataset, batch_size=3) # 这里batch_size=3,意味着每次读取dataloader都会循环三次dataset
for data in dataloader:
print(data)
完整代码:
import torch
from torch.utils.data import IterableDataset, DataLoader
import numpy as np
class MyIterableDataset(IterableDataset):
def __init__(self, file_path):
self.file_path = file_path
def __iter__(self):
with open(self.file_path, 'r') as file_obj:
for line in file_obj:
line_data = line.strip('\n').split(',')
yield torch.from_numpy(np.array(line_data, dtype='int'))
if __name__ == '__main__':
dataset = MyIterableDataset('test_csv.csv')
dataloader = DataLoader(dataset, batch_size=3)
for data in dataloader:
print(data)
2. pandas read_xxx封装方法
思路同上,代码如下:
class PandasIterableDataset(IterableDataset):
def __init__(self, file_path):
import pandas as pd
self.data_iter = pd.read_csv(file_path, iterator=True, header=None, chunksize=1)
def __iter__(self):
for data in self.data_iter:
yield torch.from_numpy(np.array(data).flatten())
if __name__ == '__main__':
dataset = PandasIterableDataset('test_csv.csv')
dataloader = DataLoader(dataset, batch_size=3)
for data in dataloader:
print(data)