原理可以参考:使用textCNN进行文本分类的原理
Keras的另一个实现可以参考:Keras实现textCNN文本分类

模型构建与训练

定义网络结构

定义一个textCNN类, 代码为tensorflow2.x版本。

from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Embedding, Dense, Conv1D, GlobalMaxPooling1D, Concatenate, Dropout


class TextCNN(object):
    def __init__(self, maxlen, max_features, embedding_dims,
                 class_num=5,
                 last_activation='softmax'):
        self.maxlen = maxlen               # 句子最大长度
        self.max_features = max_features   # 词表大小
        self.embedding_dims = embedding_dims
        self.class_num = class_num
        self.last_activation = last_activation  # 激活函数

    def get_model(self):
        input = Input((self.maxlen,))   # batch_size留空
        embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.maxlen)(input)
        convs = []
        for kernel_size in [3, 4, 5]:
            c = Conv1D(128, kernel_size, activation='relu')(embedding)
            c = GlobalMaxPooling1D()(c)
            convs.append(c)
        x = Concatenate()(convs)  # 拼接

        output = Dense(self.class_num, activation=self.last_activation)(x)
        model = Model(inputs=input, outputs=output)
        return model

定义通用的工具函数

注意:此处在utils.py中定义了一些工具函数
utils.py的代码如下所示:

# coding: utf-8

import sys
from collections import Counter
import numpy as np
import tensorflow.keras as kr
import os

if sys.version_info[0] > 2:
    is_py3 = True
else:
    reload(sys)
    sys.setdefaultencoding("utf-8")
    is_py3 = False

def open_file(filename, mode='r'):
    """
    常用文件操作,可在python2和python3间切换.
    mode: 'r' or 'w' for read or write
    """
    if is_py3:
        return open(filename, mode, encoding='utf-8', errors='ignore')
    else:
        return open(filename, mode)

def read_file(filename):
    """读取单个文件,文件中包含多个类别"""
    contents = []
    labels = []
    with open_file(filename) as f:
        for line in f:
            try:
                raw = line.strip().split("\t")
                content = raw[1].split(' ')
                if content:
                    contents.append(content)
                    labels.append(raw[0])
            except:
                pass
    return contents, labels    

def read_single_file(filename):
    """读取单个文件,文件为单一类别"""
    contents = []    
    label = filename.split('/')[-1].split('.')[0]

    with open_file(filename) as f:
        for line in f:
            try:
                content = line.strip().split(' ')
                if content:
                    contents.append(content)
            except:
                pass
    return contents, label

def read_files(dirname):
    """读取文件夹"""
    contents = []
    labels = []
    files = [f for f in os.listdir(dirname) if f.endswith(".txt")]
    for filename in files:
        content, label = read_single_file(os.path.join(dirname, filename))
        contents.extend(content)
        labels.extend([label]*len(content))
    return contents, labels

def build_vocab(train_dir, vocab_file, vocab_size=5000):
    """根据训练集构建词汇表,存储"""
    data_train, _ = read_files(train_dir)

    all_data = []
    for content in data_train:
        all_data.extend(content)

    counter = Counter(all_data)
    count_pairs = counter.most_common(vocab_size - 1)
    words, _ = list(zip(*count_pairs))
    # 添加一个 <PAD> 来将所有文本pad为同一长度
    words = ['<PAD>'] + list(words)
    open_file(vocab_file, mode='w').write('\n'.join(words) + '\n')


def read_vocab(vocab_file):
    """读取词汇表"""
    # words = open_file(vocab_dir).read().strip().split('\n')
    with open_file(vocab_file) as fp:
        # 如果是py2 则每个值都转化为unicode
        words = [_.strip() for _ in fp.readlines()]
    word_to_id = dict(zip(words, range(len(words))))
    return words, word_to_id


def read_category():
    """读取分类,编码"""
    categories = ['car', 'entertainment', 'military', 'sports', 'technology']
    cat_to_id = dict(zip(categories, range(len(categories))))
    return categories, cat_to_id

def encode_cate(content, words):
    """将id表示的内容转换为文字"""
    return [(words[x] if x in words else 40000) for x in content]

def encode_sentences(contents, words):
    """将id表示的内容转换为文字"""
    return [encode_cate(x,words) for x in contents]

def process_file(filename, word_to_id, cat_to_id, max_length=600):
    """将文件转换为id表示"""
    contents, labels = read_file(filename)

    data_id, label_id = [], []
    for i in range(len(contents)):
        data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
        label_id.append(cat_to_id[labels[i]])

    # 使用keras提供的pad_sequences来将文本pad为固定长度
    x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
    y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # 将标签转换为one-hot表示

    return x_pad, y_pad


def batch_iter(x, y, batch_size=64):
    """生成批次数据"""
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1

    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]

    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

数据处理与训练

from tensorflow.keras.preprocessing import sequence
import random
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.utils import to_categorical
import sys
sys.path.append('../data/lesson2_data')
from utils import *

# 路径等配置
data_dir = "../data/lesson2_data/data"
vocab_file = "../data/lesson2_data/vocab/vocab.txt"
vocab_size = 40000

# 神经网络配置
max_features = 40001
maxlen = 100
batch_size = 256
embedding_dims = 50
epochs = 8

print('数据预处理与加载数据...')
# 如果不存在词汇表,重建
if not os.path.exists(vocab_file):  
    build_vocab(data_dir, vocab_file, vocab_size)
# 获得 词汇/类别 与id映射字典
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_file)

# 全部数据
x, y = read_files(data_dir)
data = list(zip(x,y))
del x,y
# 乱序
random.shuffle(data)
# 切分训练集和测试集
train_data, test_data = train_test_split(data)
# 对文本的词id和类别id进行编码
x_train = encode_sentences([content[0] for content in train_data], word_to_id)
y_train = to_categorical(encode_cate([content[1] for content in train_data], cat_to_id))
x_test = encode_sentences([content[0] for content in test_data], word_to_id)
y_test = to_categorical(encode_cate([content[1] for content in test_data], cat_to_id))

print('对序列做padding,保证是 samples*timestep 的维度')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

print('构建模型...')
model = TextCNN(maxlen, max_features, embedding_dims).get_model()
model.compile('adam', 'categorical_crossentropy', metrics=['accuracy'])

print('训练...')
# 设定callbacks回调函数
my_callbacks = [
    ModelCheckpoint('../../tmp/cnn_model.h5', verbose=1),
    EarlyStopping(monitor='val_accuracy', patience=2, mode='max')
]

# fit拟合数据
history = model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          callbacks=my_callbacks,
          validation_data=(x_test, y_test))

#print('对测试集预测...')
#result = model.predict(x_test)
训练中间信息输出(画图)
import matplotlib.pyplot as plt
plt.switch_backend('agg')
%matplotlib inline

fig1 = plt.figure()
plt.plot(history.history['loss'],'r',linewidth=3.0)
plt.plot(history.history['val_loss'],'b',linewidth=3.0)
plt.legend(['Training loss', 'Validation Loss'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16)
plt.ylabel('Loss',fontsize=16)
plt.title('Loss Curves :CNN',fontsize=16)
fig1.savefig('../../tmp/loss_cnn.png')
plt.show()
fig2=plt.figure()
plt.plot(history.history['accuracy'],'r',linewidth=3.0)
plt.plot(history.history['val_accuracy'],'b',linewidth=3.0)
plt.legend(['Training Accuracy', 'Validation Accuracy'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16)
plt.ylabel('Accuracy',fontsize=16)
plt.title('Accuracy Curves : CNN',fontsize=16)
fig2.savefig('../../tmp/accuracy_cnn.png')
plt.show()

注意:Windows下pycharm中运行需要把%matplotlib inlineplt.switch_backend('agg')都注释掉才能出图。

模型结构打印
from tensorflow.keras.utils import plot_model
# model.summary()
plot_model(model, show_shapes=True, show_layer_names=True)
遇到问题
  1. TensorFlow版本为1.13。
    TensorFlow1.x版本中没有集成keras,需要把tensorflow.keras.xx改为keras.xx

  2. 运行报错:ValueError: Error when checking target: expected dense_1 to have shape (5,) but got array with shape (40001,)
    出现这个问题很可能是多分类的label设置不当导致的。通过检查代码发现, 经过to_categorical之后的y有40001维,且前40000维均为0,最后一维为1。定位到utils.py中的encode_cate函数,发现是label中带有文件夹名,无法映射到cat_to_id,再次定位到read_single_file,发现是Windows和linux分隔符不同导致的。于是将label = filename.split('/')[-1].split('.')[0]修改为:

    import platform
    if platform.system()=='Linux':   # Windows will be : Windows, Linux will be : Linux
        label = filename.split('/')[-1].split('.')[0]
    else:
        label = filename.split('\\')[-1].split('.')[0]
    
  3. 预测时plt.plot(history.history['accuracy'],'r',linewidth=3.0)处报错KeyError: 'accuracy'
    原因是keras库版本不同。因为keras库老版本中的参数不是accuracy,而是acc,将参数accuracy替换为acc即可。同理,将val_accuracy替换为val_acc。
    打印print(history.history.keys())可得,history的四个参数为dict_keys(['val_loss', 'val_acc', 'loss', 'acc'])
    修改代码为:

    fig2=plt.figure()
    plt.plot(history.history['acc'],'r',linewidth=3.0)
    plt.plot(history.history['val_acc'],'b',linewidth=3.0)
    
  4. 画结构图时plot_model(model, show_shapes=True, show_layer_names=True)处报错ImportError: Failed to import ``pydot``. Please install ``pydot``. For example with ``pip install pydot``.然后又报错'pydotfailed to call GraphViz.' OSError: ``pydot`` failed to call GraphViz.Please install GraphViz (https://www.graphviz.org/) and ensure that its executables are in the $PATH.
    解决:使用plot_model得先安装好另外两个库,graphviz和pydot。

    pip3 install graphviz
    pip3 install pydot
    

    电脑还得安装graphviz并添加环境变量。下载地址http://www.graphviz.org/download/,选择自己操作系统对应的版本的.msi文件,安装完成后将安装的Graphviz2.38/bin添加到环境变量,添加完后最好重启电脑。如果在pycharm中不显示图片,就用to_file写入到文件里。
    也可以在代码中添加环境变量,如下:

    import os
    os.environ["PATH"] += os.pathsep + 'E:\Program Files (x86)\Graphviz2.38\\bin'
    plot_model(model, show_shapes=True, show_layer_names=True,to_file='test.png')
    
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐