Text Classification on Sogou News Data with a TensorFlow-based TextRNN

A previous post, "Text Classification on Sogou News Data with a TensorFlow-based TextCNN", was a first attempt at using a CNN for text classification. After getting to know the LSTM, I learned that it is used very widely in natural language processing, shining in areas such as sentiment analysis, information extraction (named entity recognition), and machine translation. This post looks at how it actually performs on text classification.

Since the raw data was already processed in the previous attempt, this experiment only needs two pieces: the RNN model's parameter configuration and the script that runs the model.

RNN Configuration Parameters

class TRNNConfig(object):
    """RNN configuration parameters"""

    embedding_dim = 64        # embedding dimension
    seq_length = 1000         # sequence length
    num_classes = 11          # number of classes
    vocab_size = 5000         # vocabulary size

    num_layers = 2            # number of hidden layers
    hidden_dim = 128          # hidden layer dimension
    rnn = 'gru'               # memory cell: 'lstm' or 'gru'

    dropout_keep_prob = 0.8   # dropout keep probability
    learning_rate = 1e-3      # learning rate

    batch_size = 128          # training batch size
    num_epochs = 10           # number of training epochs

    print_per_batch = 100     # print results every N batches
    save_per_batch = 10       # write to tensorboard every N batches
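
The model class itself lives in rnn_model.py and is only imported by the run script below, so it is not reproduced in this post. For orientation, here is a minimal sketch of the graph these parameters drive, assuming the usual embedding → stacked RNN → fully-connected → softmax layout; the exact wiring and names are my assumptions, not the actual implementation:

import tensorflow as tf

class TextRNN(object):
    """Sketch of a TextRNN graph driven by TRNNConfig (assumed structure)."""

    def __init__(self, config):
        self.config = config
        self.input_x = tf.placeholder(tf.int32, [None, config.seq_length], name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, config.num_classes], name='input_y')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        def make_cell():
            # Pick the memory cell according to config.rnn, then wrap it with dropout
            if config.rnn == 'lstm':
                cell = tf.contrib.rnn.BasicLSTMCell(config.hidden_dim)
            else:
                cell = tf.contrib.rnn.GRUCell(config.hidden_dim)
            return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

        # Embedding layer: map character ids to dense vectors
        embedding = tf.get_variable('embedding', [config.vocab_size, config.embedding_dim])
        embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

        # Stacked RNN; only the output at the last time step feeds the classifier
        cells = [make_cell() for _ in range(config.num_layers)]
        rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
        outputs, _ = tf.nn.dynamic_rnn(rnn_cell, embedding_inputs, dtype=tf.float32)
        last = outputs[:, -1, :]

        # Fully-connected layer, then softmax classifier
        fc = tf.layers.dense(last, config.hidden_dim, activation=tf.nn.relu)
        fc = tf.nn.dropout(fc, self.keep_prob)
        self.logits = tf.layers.dense(fc, config.num_classes, name='fc2')
        self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)

        # Cross-entropy loss and Adam optimizer
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            logits=self.logits, labels=self.input_y)
        self.loss = tf.reduce_mean(cross_entropy)
        self.optim = tf.train.AdamOptimizer(config.learning_rate).minimize(self.loss)

        # Accuracy
        correct = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
        self.acc = tf.reduce_mean(tf.cast(correct, tf.float32))

The rnn field picks the memory cell, num_layers controls how many cells are stacked, and only the output at the final time step is fed to the classifier.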

RNN Training and Validation

The training and validation code for the RNN differs very little from the CNN version; only a few small things, such as the model save path, need to be changed.

# coding: utf-8

from __future__ import print_function

import os
import sys
import time
from datetime import timedelta

import numpy as np
import tensorflow as tf
from sklearn import metrics

from rnn_model import TRNNConfig, TextRNN
from data.sougounews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

base_dir = 'data/'
train_dir = os.path.join(base_dir, 'news_train.txt')
test_dir = os.path.join(base_dir, 'news_test.txt')
val_dir = os.path.join(base_dir, 'news_val.txt')
vocab_dir = os.path.join(base_dir, 'vocab.txt')

save_dir = 'checkpoints/text_rnn'
save_path = os.path.join(save_dir, 'best_validation')  # save path for the best validation result


def get_time_dif(start_time):
    """Compute the time elapsed since start_time."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


def feed_data(x_batch, y_batch, keep_prob):
    feed_dict = {
        model.input_x: x_batch,
        model.input_y: y_batch,
        model.keep_prob: keep_prob
    }
    return feed_dict


def evaluate(sess, x_, y_):
    """Evaluate loss and accuracy on the given data, weighted by batch size."""
    data_len = len(x_)
    batch_eval = batch_iter(x_, y_, 128)
    total_loss = 0.0
    total_acc = 0.0
    for x_batch, y_batch in batch_eval:
        batch_len = len(x_batch)
        feed_dict = feed_data(x_batch, y_batch, 1.0)
        loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
        total_loss += loss * batch_len
        total_acc += acc * batch_len

    return total_loss / data_len, total_acc / data_len


def train():
    print('Configuring TensorBoard and Saver...')
    tensorboard_dir = 'tensorboard/text_rnn'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    tf.summary.scalar('loss', model.loss)
    tf.summary.scalar('accuracy', model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)

    # Configure the Saver
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    print('Loading training and validation data...')
    # Load the training and validation sets
    start_time = time.time()
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
    time_dif = get_time_dif(start_time)
    print('Time usage:', time_dif)

    # Create the session
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    writer.add_graph(session.graph)

    print('Training and evaluating...')
    start_time = time.time()
    total_batch = 0             # total number of batches processed
    best_acc_val = 0.0          # best validation accuracy so far
    last_improved = 0           # last batch with an improvement
    require_improvement = 1000  # stop early after 1000 batches without improvement

    flag = False
    for epoch in range(config.num_epochs):
        print('Epoch:', epoch + 1)
        batch_train = batch_iter(x_train, y_train, config.batch_size)
        for x_batch, y_batch in batch_train:
            feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

            if total_batch % config.save_per_batch == 0:
                # Write the training results to tensorboard every save_per_batch batches
                s = session.run(merged_summary, feed_dict=feed_dict)
                writer.add_summary(s, total_batch)

            if total_batch % config.print_per_batch == 0:
                # Print training and validation results every print_per_batch batches
                feed_dict[model.keep_prob] = 1.0
                loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
                loss_val, acc_val = evaluate(session, x_val, y_val)

                if acc_val > best_acc_val:
                    # Save the best result so far
                    best_acc_val = acc_val
                    last_improved = total_batch
                    saver.save(sess=session, save_path=save_path)
                    improved_str = '*'
                else:
                    improved_str = ''

                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                      + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

            session.run(model.optim, feed_dict=feed_dict)  # run one optimization step
            total_batch += 1

            if total_batch - last_improved > require_improvement:
                # Validation accuracy has not improved for a long time; stop training early
                print('No optimization for a long time, auto-stopping...')
                flag = True
                break
        if flag:
            break


def test():
    print('Loading test data...')
    start_time = time.time()
    x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)  # restore the best model

    print('Testing...')
    loss_test, acc_test = evaluate(session, x_test, y_test)
    msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
    print(msg.format(loss_test, acc_test))

    batch_size = 128
    data_len = len(x_test)
    num_batch = int((data_len - 1) / batch_size) + 1

    y_test_cls = np.argmax(y_test, 1)
    y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # predicted labels
    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        feed_dict = {
            model.input_x: x_test[start_id:end_id],
            model.keep_prob: 1.0
        }
        y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

    # Evaluation
    print('Precision, Recall and F1-score...')
    print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

    # Confusion matrix
    print('Confusion Matrix...')
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print('Time usage:', time_dif)


if __name__ == '__main__':
    if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
        raise ValueError("""please use: python run_rnn.py [train / test]""")

    print('Configuring RNN model...')
    config = TRNNConfig()
    if not os.path.exists(vocab_dir):
        # Build the vocabulary if it does not exist yet
        build_vocab(train_dir, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    config.vocab_size = len(words)
    model = TextRNN(config)

    if sys.argv[1] == 'train':
        train()
    else:
        test()
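
With the script in place, training and testing are driven from the command line:

python run_rnn.py train   # train, validating as we go and saving the best checkpoint
python run_rnn.py test    # restore the best checkpoint and evaluate on the test set

The batch_iter generator imported from the data loader is not shown in this post; I assume it is a simple shuffle-and-yield iterator along these lines (a sketch, not the actual loader code):

import numpy as np

def batch_iter(x, y, batch_size=64):
    """Shuffle the data, then yield it one batch at a time."""
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1
    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]
    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]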

Training and Test Results

Judging from the runs, the RNN takes noticeably longer to train, roughly twice as long as the CNN, and its memory footprint is also clearly higher.

[image: training run console output]

GRU

After 10 epochs of training, the model reaches 96.09% accuracy on the training set and 86.59% on the validation set, still slightly worse than the CNN; as for why, I don't know yet (==). Because the training set is fairly small, the results of some epochs never got printed to the screen.

[image: GRU training log]

Now for the test set: overall accuracy is 83.96%, and apart from culture (cul), business, and society (news), every class is above 80%. Society news performs worst at only 65%, but I suspect the quality of the Sogou news data is partly to blame: articles whose content doesn't match their title, or whose category is ambiguous, are very common, especially in the society class, where many pieces could just as well be filed under culture.

[image: GRU test-set classification report]

LSTM

Next, swap the RNN's cell for an LSTM (a one-line config change, shown below) and see how it does. Training still takes a long time, half an hour in all. Over the 10 epochs, the validation set peaked in round 8, with a best training accuracy of 95.31% and a validation accuracy of 85.66%.
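
Since TRNNConfig exposes the cell type as a field, the swap only touches one line:

rnn = 'lstm'  # in TRNNConfig: use an LSTM cell instead of the default GRU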

On the test set, accuracy is 81.79%; this time the society and automobile classes are predicted worst, so the parameters need further tuning.

[image: LSTM training log]

[image: LSTM test-set classification report]

Thoughts

Possible reasons why the RNN model underperforms on this text classification task:

  • RNN training itself needs a lot of data, while the dataset used here has only 550 examples per class.
  • No pre-trained word vectors were used for this news classification, which led to overfitting (see the sketch after this list).
  • Character-level text classification may simply perform worse than word-level?
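
On the second point, a minimal sketch of how pre-trained word vectors could be fed into the embedding layer, assuming a pre-computed matrix aligned with the vocabulary (the file path and names here are hypothetical, not part of the project):

import numpy as np
import tensorflow as tf

# Hypothetical (vocab_size, embedding_dim) matrix built from e.g. word2vec,
# whose row i matches word id i in vocab.txt.
embedding_matrix = np.load('data/pretrained_embedding.npy')

input_x = tf.placeholder(tf.int32, [None, 1000], name='input_x')
embedding = tf.get_variable(
    'embedding',
    shape=embedding_matrix.shape,
    initializer=tf.constant_initializer(embedding_matrix),
    trainable=False)  # freeze the vectors to curb overfitting on small data
embedding_inputs = tf.nn.embedding_lookup(embedding, input_x)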