TextCNN算法流程
- 整体流程是将每个词的词向量按顺序堆叠,一句话构成一个二维特征图
- 根据卷积核得到多个特征向量
- 每个特征向量全局池化,选最大的特征作为这个特征向量的值
- 拼接特征值,得到句子的特征向量
- 全连接后得到目标维度的结果
完整代码
import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.nn.functional as F
class TextCnnModel(nn.Module):
    """TextCNN sentence classifier.

    A sentence is treated as a one-channel (seq_len, embedding_size) "image".
    Conv2d kernels of several heights slide over it, each resulting feature
    map is globally max-pooled to a single value per channel, and the pooled
    features from all kernel sizes are concatenated and projected to the
    class scores by a fully connected layer.
    """

    def __init__(self, embedding_size, output_size, channels=256, filter_sizes=(2, 3, 4), dropout=0.5):
        """
        :param embedding_size: length of each word-embedding vector
        :param output_size: number of target classes
        :param channels: number of convolution kernels per filter size
        :param filter_sizes: kernel heights, i.e. word-window sizes
        :param dropout: dropout probability applied before the classifier
        """
        super(TextCnnModel, self).__init__()
        # One Conv2d per window size; every kernel spans the full embedding width,
        # so each convolution slides only along the word (time) axis.
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, channels, (size, embedding_size)) for size in filter_sizes]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(channels * len(filter_sizes), output_size)

    def conv_and_pool(self, x, conv):
        """Convolve + ReLU, then global max-pool over the time dimension."""
        feature_map = F.relu(conv(x)).squeeze(3)        # (batch, channels, seq')
        pooled = F.max_pool1d(feature_map, feature_map.size(2)).squeeze(2)
        return pooled                                    # (batch, channels)

    def forward(self, x):
        """Map (batch, seq_len, embedding_size) input to (batch, output_size) scores."""
        image = x.unsqueeze(1)  # insert the single input channel
        per_size = [self.conv_and_pool(image, conv) for conv in self.convs]
        sentence_vec = torch.cat(per_size, 1)
        return self.fc(self.dropout(sentence_vec))
def get_total_train_data(word_embedding_size, class_count):
    """Build a random stand-in training set (replace with real data).

    :param word_embedding_size: embedding length of each word vector
    :param class_count: labels are drawn uniformly from [0, class_count)
    :return: (x_train, y_train) where x_train is a float tensor of shape
             (1000, 20, word_embedding_size) with values in [0, 1), and
             y_train is a (1000, 1) long tensor of class indices
    """
    # Use torch's own RNG directly instead of round-tripping through numpy
    # (same uniform-[0,1) features and integer labels, one less dependency).
    x_train = torch.rand(1000, 20, word_embedding_size)
    y_train = torch.randint(0, class_count, (1000, 1), dtype=torch.long)
    return x_train, y_train
if __name__ == '__main__':
    # Training hyper-parameters for the demo run.
    epochs = 1000
    batch_size = 30
    embedding_size = 350
    output_class = 14

    x_train, y_train = get_total_train_data(embedding_size, output_class)
    train_loader = Data.DataLoader(
        dataset=Data.TensorDataset(x_train, y_train),
        batch_size=batch_size,
        shuffle=True,
        num_workers=6,
        drop_last=True,
    )

    model = TextCnnModel(embedding_size=embedding_size, output_size=output_class)
    cross_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for i in range(epochs):
        for seq, labels in train_loader:
            optimizer.zero_grad()
            y_pred = model(seq)
            # CrossEntropyLoss expects (batch,) class indices; labels are (batch, 1).
            # squeeze(1) (not squeeze()) keeps the batch dim even if batch_size were 1.
            single_loss = cross_loss(y_pred, labels.squeeze(1))
            single_loss.backward()
            optimizer.step()
        # .item() extracts the scalar without a numpy round-trip and works on any device.
        print(f"Step: {i} loss : {single_loss.item()}")