之前一直关注BERT等语言模型,最近学习了一下fasttext的原理,在新浪微博的二分类数据集中进行试验,发现效果挺好,简单记录相关内容。

import fasttext

def train_model(input_path,output_path,dim=100,epoch=5,lr=0.1,loss='softmax'):
	model = fasttext.train_supervised(input_path,label='__label__',dim=dim,epoch=epoch,lr=lr,wordNgrams=2,loss=loss)
	model.save_model(output_path)#模型保存路径
	return model

#训练模型
model = train_model(input_path='./train.txt',output_path='./fasttext.model')

#验证集上进行效果验证
valid_result = model.test('./valid.txt')
print(valid_result)

#对测试数据进行预测,测试数据也要按照训练数据的方式进行分词处理再以空格连接,只是测试数据中不加标签
with open('./test.txt','r') as rf:
	lines = rf.readlines()
predict_lines = [line.replace('\n','') for line in lines]#一定要把去除每一行数据的换行符!

#普通预测只输出类别
lables = model.predict(predict_lines)
#输出类别和概率
labels_pro = model.predict_proba(predict_lines)
#输出前K个类别
labels_k = model.predict(predict_lines, k = 3)


fasttext对输入数据格式具有特殊的要求,以“今天是元旦”为例,需要将其整理成如下格式:
“今天 是 元旦 __label__0”
文本进行分词后以空格拼接,类别标签为‘0’,并且文本内容和标签之间是’\t’,
数据处理简要代码如下:

import jieba
for line in lines:
	con = line.split('\t')[0]#line = con + '\t' + label
	label = line.split('\t')[1]
	new_line = ' '.join(jieba.lcut(con)) + '\t' + '__label__' + label

在新浪微博的二分类数据集中准确率达到96%,后续尝试在多分类、长文本数据集中的效果

参考:
https://zhuanlan.zhihu.com/p/66739066
https://zhuanlan.zhihu.com/p/89802447(比较详细)
https://www.souhu.com/a/216556978_609569

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐