30天打牢数模基础-循环神经网络讲解
显示每轮训练的损失(loss)和准确率(accuracy),验证集的损失(val_loss)和准确率(val_accuracy);:如果训练集准确率很高(如100%)但测试集准确率很低(如50%),说明模型过拟合,可以减少训练轮次或增加数据。:如果有主题标签(如“思乡”“山水”),可以将其作为额外特征加入模型(需修改模型结构为函数式API)。:显示测试集输入唐诗前半部分的预测结果(如输入“慈母手中
案例代码实现
一、代码说明
本代码实现了唐诗绝句补全任务,使用LSTM模型处理序列数据,解决“顺序依赖”和“长期记忆”问题。流程如下:
模拟数据:生成10条唐诗数据(输入序列+目标字);
数据预处理:构建词汇表、将文字转换为索引、统一序列长度(Padding);
模型构建:使用Embedding+LSTM+Dense结构,处理序列数据;
模型训练:用训练集训练模型,优化参数;
测试生成:输入唐诗前半部分,模型自动补全后半部分的字。
二、完整代码
# 导入必要的库
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import plot_model
# ------------------------------
# 1. 模拟唐诗数据(输入序列+目标字)
# ------------------------------
data = [
("床前明月光,疑是地上", "霜"), # 李白《静夜思》
("空山新雨后,天气晚来", "秋"), # 王维《山居秋暝》
("远看山有色,近听水无", "声"), # 王维《画》
("独在异乡为异客,每逢佳节倍思", "亲"), # 王维《九月九日忆山东兄弟》
("白日依山尽,黄河入海", "流"), # 王之涣《登鹳雀楼》
("锄禾日当午,汗滴禾下", "土"), # 李绅《悯农》
("春眠不觉晓,处处闻啼", "鸟"), # 孟浩然《春晓》
("红豆生南国,春来发几", "枝"), # 王维《相思》
("慈母手中线,游子身上", "衣"), # 孟郊《游子吟》(测试集)
("离离原上草,一岁一枯", "荣") # 白居易《赋得古原草送别》(测试集)
]
# ------------------------------
# 2. 数据预处理(构建词汇表+转换索引+Padding)
# ------------------------------
# 提取所有字符,构建词汇表(去重+排序)
chars = set()
for input_seq, target in data:
chars.update(input_seq) # 加入输入序列的所有字
chars.add(target) # 加入目标字
chars = sorted(chars) # 排序保证索引稳定
vocab_size = len(chars)
print(f"【词汇表信息】大小:{vocab_size},包含字符:{chars}")
# 构建字符-索引映射(用于文字转索引)
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for i, c in enumerate(chars)} # 索引转字符(用于结果解码)
# 计算输入序列的最大长度(统一序列长度,短序列补0)
max_len = max(len(input_seq) for input_seq, _ in data)
print(f"【序列信息】输入序列最大长度:{max_len}")
# 将输入序列转换为索引序列,并Padding到最大长度
X = [] # 输入序列(索引)
y = [] # 目标字(索引)
for input_seq, target in data:
# 输入序列转索引(忽略未收录字符,此处假设所有字符都在词汇表中)
input_idx = [char_to_idx[c] for c in input_seq]
# Padding:短序列后面补0,长度统一为max_len
input_idx_padded = pad_sequences([input_idx], maxlen=max_len, padding='post')[0]
X.append(input_idx_padded)
# 目标字转索引
y.append(char_to_idx[target])
# 转换为numpy数组(Keras模型要求输入为数组)
X = np.array(X)
y = np.array(y)
# 划分训练集(80%)和测试集(20%)
split_idx = int(0.8 * len(X))
X_train, X_test = X[:split_idx], X[split_idx:]
y_train, y_test = y[:split_idx], y[split_idx:]
print(f"【数据集划分】训练集:{len(X_train)}条,测试集:{len(X_test)}条")
# ------------------------------
# 3. 构建LSTM模型(Embedding+LSTM+Dense)
# ------------------------------
model = Sequential([
# 嵌入层:将字符索引转换为低维向量(输入维度=词汇表大小,输出维度=128,输入长度=最大序列长度)
Embedding(input_dim=vocab_size, output_dim=128, input_length=max_len, name="embedding"),
# LSTM层:128个隐藏单元(处理序列依赖,保留长期记忆)
LSTM(128, name="lstm"),
# 输出层:词汇表大小的神经元(预测每个字符的概率),激活函数用softmax(多分类)
Dense(vocab_size, activation='softmax', name="output")
])
# 编译模型:优化器用Adam(常用优化器),损失函数用稀疏交叉熵(目标是整数索引)
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'] # 监控准确率
)
# 打印模型结构(可选,需安装pydot和graphviz)
# plot_model(model, to_file='lstm_model.png', show_shapes=True)
print("【模型结构】")
model.summary()
# ------------------------------
# 4. 训练模型
# ------------------------------
print("\n【模型训练】开始训练...")
history = model.fit(
X_train, y_train, # 训练集输入和目标
epochs=30, # 训练轮次(可调整,越大越可能过拟合)
batch_size=2, # 批次大小(小批量更稳定)
validation_data=(X_test, y_test), # 验证集(监控泛化能力)
verbose=1 # 显示训练过程(1=显示进度条,0=不显示)
)
print("【模型训练】结束训练!")
# ------------------------------
# 5. 定义补全函数(输入文本,返回预测的字)
# ------------------------------
def complete_poem(input_text):
"""
唐诗补全函数:输入唐诗前半部分,返回模型预测的后半部分字。
参数:input_text(str):唐诗前半部分(如“床前明月光,疑是地上”)
返回:str:预测的字(如“霜”),失败返回None
"""
# 1. 检查输入长度(超过最大长度则截断)
if len(input_text) > max_len:
input_text = input_text[:max_len]
print(f"【提示】输入文本过长,截断为:{input_text}")
# 2. 转换输入文本为索引序列(忽略未收录字符)
input_idx = []
for c in input_text:
if c in char_to_idx:
input_idx.append(char_to_idx[c])
else:
print(f"【错误】输入包含未收录字符:{c}")
return None
# 3. Padding到最大长度(后面补0)
input_idx_padded = pad_sequences([input_idx], maxlen=max_len, padding='post')[0]
# 4. 模型预测(输入为[1, max_len]的数组)
predictions = model.predict(np.array([input_idx_padded]), verbose=0)[0]
# 5. 取概率最大的字符索引(argmax返回最大值的位置)
predicted_idx = np.argmax(predictions)
# 6. 转换索引为字符(返回结果)
return idx_to_char[predicted_idx]
# ------------------------------
# 6. 测试补全函数(用测试集数据验证)
# ------------------------------
print("\n【模型测试】补全结果(测试集):")
# 测试集为数据集中的后2条(索引8、9)
test_cases = [
("慈母手中线,游子身上", "衣"), # 预期:衣(测试集第1条)
("离离原上草,一岁一枯", "荣") # 预期:荣(测试集第2条)
]
for input_text, expected in test_cases:
predicted = complete_poem(input_text)
if predicted:
print(f"输入:{input_text} → 预测:{predicted}(预期:{expected})")
else:
print(f"输入:{input_text} → 补全失败")
三、代码使用说明
1. 环境准备
安装Python(3.7+);
安装必要库:pip install tensorflow numpy。
2. 运行代码
将代码保存为tangshi_completion.py;
在终端运行:python tangshi_completion.py。
3. 结果解释
词汇表信息:显示模型能识别的所有字符(如“床”“前”“明”等);
模型结构:显示LSTM模型的层结构(嵌入层→LSTM层→输出层);
训练过程:显示每轮训练的损失(loss)和准确率(accuracy),验证集的损失(val_loss)和准确率(val_accuracy);
补全结果:显示测试集输入唐诗前半部分的预测结果(如输入“慈母手中线,游子身上”,预测“衣”)。
4. 调整建议
增加数据:模拟更多唐诗数据(如20-30条),提高模型泛化能力;
调整参数:
epochs:训练轮次(如改为50,增加训练时间,可能提高准确率);
batch_size:批次大小(如改为4,加快训练速度);
LSTM层的units:隐藏单元数量(如改为256,增加模型复杂度);
加入主题标签:如果有主题标签(如“思乡”“山水”),可以将其作为额外特征加入模型(需修改模型结构为函数式API)。
四、注意事项
序列顺序:输入序列的顺序不能打乱(如“床前明月光”不能改为“光月明前床”),否则模型无法正确捕捉顺序依赖;
未收录字符:输入文本中的字符必须在词汇表中(否则会提示错误),可以通过扩展词汇表解决;
过拟合:如果训练集准确率很高(如100%)但测试集准确率很低(如50%),说明模型过拟合,可以减少训练轮次或增加数据。
通过本代码,你可以掌握LSTM处理序列数据的核心流程,解决“唐诗补全”这类需要长期记忆的任务。
更多推荐
所有评论(0)