案例代码实现

 

一、代码说明

 

本代码实现了唐诗绝句补全任务,使用LSTM模型处理序列数据,解决“顺序依赖”和“长期记忆”问题。流程如下:

模拟数据:生成10条唐诗数据(输入序列+目标字);

数据预处理:构建词汇表、将文字转换为索引、统一序列长度(Padding);

模型构建:使用Embedding+LSTM+Dense结构,处理序列数据;

模型训练:用训练集训练模型,优化参数;

测试生成:输入唐诗前半部分,模型自动补全后半部分的字。

 

二、完整代码

# 导入必要的库
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import plot_model

# ------------------------------
# 1. 模拟唐诗数据(输入序列+目标字)
# ------------------------------
data = [
    ("床前明月光,疑是地上", "霜"),    # 李白《静夜思》
    ("空山新雨后,天气晚来", "秋"),    # 王维《山居秋暝》
    ("远看山有色,近听水无", "声"),    # 王维《画》
    ("独在异乡为异客,每逢佳节倍思", "亲"),  # 王维《九月九日忆山东兄弟》
    ("白日依山尽,黄河入海", "流"),    # 王之涣《登鹳雀楼》
    ("锄禾日当午,汗滴禾下", "土"),    # 李绅《悯农》
    ("春眠不觉晓,处处闻啼", "鸟"),    # 孟浩然《春晓》
    ("红豆生南国,春来发几", "枝"),    # 王维《相思》
    ("慈母手中线,游子身上", "衣"),    # 孟郊《游子吟》(测试集)
    ("离离原上草,一岁一枯", "荣")     # 白居易《赋得古原草送别》(测试集)
]

# ------------------------------
# 2. 数据预处理(构建词汇表+转换索引+Padding)
# ------------------------------
# 提取所有字符,构建词汇表(去重+排序)
chars = set()
for input_seq, target in data:
    chars.update(input_seq)  # 加入输入序列的所有字
    chars.add(target)        # 加入目标字
chars = sorted(chars)  # 排序保证索引稳定
vocab_size = len(chars)
print(f"【词汇表信息】大小:{vocab_size},包含字符:{chars}")

# 构建字符-索引映射(用于文字转索引)
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for i, c in enumerate(chars)}  # 索引转字符(用于结果解码)

# 计算输入序列的最大长度(统一序列长度,短序列补0)
max_len = max(len(input_seq) for input_seq, _ in data)
print(f"【序列信息】输入序列最大长度:{max_len}")

# 将输入序列转换为索引序列,并Padding到最大长度
X = []  # 输入序列(索引)
y = []  # 目标字(索引)
for input_seq, target in data:
    # 输入序列转索引(忽略未收录字符,此处假设所有字符都在词汇表中)
    input_idx = [char_to_idx[c] for c in input_seq]
    # Padding:短序列后面补0,长度统一为max_len
    input_idx_padded = pad_sequences([input_idx], maxlen=max_len, padding='post')[0]
    X.append(input_idx_padded)
    # 目标字转索引
    y.append(char_to_idx[target])

# 转换为numpy数组(Keras模型要求输入为数组)
X = np.array(X)
y = np.array(y)

# 划分训练集(80%)和测试集(20%)
split_idx = int(0.8 * len(X))
X_train, X_test = X[:split_idx], X[split_idx:]
y_train, y_test = y[:split_idx], y[split_idx:]
print(f"【数据集划分】训练集:{len(X_train)}条,测试集:{len(X_test)}条")

# ------------------------------
# 3. 构建LSTM模型(Embedding+LSTM+Dense)
# ------------------------------
model = Sequential([
    # 嵌入层:将字符索引转换为低维向量(输入维度=词汇表大小,输出维度=128,输入长度=最大序列长度)
    Embedding(input_dim=vocab_size, output_dim=128, input_length=max_len, name="embedding"),
    # LSTM层:128个隐藏单元(处理序列依赖,保留长期记忆)
    LSTM(128, name="lstm"),
    # 输出层:词汇表大小的神经元(预测每个字符的概率),激活函数用softmax(多分类)
    Dense(vocab_size, activation='softmax', name="output")
])

# 编译模型:优化器用Adam(常用优化器),损失函数用稀疏交叉熵(目标是整数索引)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']  # 监控准确率
)

# 打印模型结构(可选,需安装pydot和graphviz)
# plot_model(model, to_file='lstm_model.png', show_shapes=True)
print("【模型结构】")
model.summary()

# ------------------------------
# 4. 训练模型
# ------------------------------
print("\n【模型训练】开始训练...")
history = model.fit(
    X_train, y_train,          # 训练集输入和目标
    epochs=30,                 # 训练轮次(可调整,越大越可能过拟合)
    batch_size=2,              # 批次大小(小批量更稳定)
    validation_data=(X_test, y_test),  # 验证集(监控泛化能力)
    verbose=1                  # 显示训练过程(1=显示进度条,0=不显示)
)
print("【模型训练】结束训练!")

# ------------------------------
# 5. 定义补全函数(输入文本,返回预测的字)
# ------------------------------
def complete_poem(input_text):
    """
    唐诗补全函数:输入唐诗前半部分,返回模型预测的后半部分字。
    参数:input_text(str):唐诗前半部分(如“床前明月光,疑是地上”)
    返回:str:预测的字(如“霜”),失败返回None
    """
    # 1. 检查输入长度(超过最大长度则截断)
    if len(input_text) > max_len:
        input_text = input_text[:max_len]
        print(f"【提示】输入文本过长,截断为:{input_text}")
    
    # 2. 转换输入文本为索引序列(忽略未收录字符)
    input_idx = []
    for c in input_text:
        if c in char_to_idx:
            input_idx.append(char_to_idx[c])
        else:
            print(f"【错误】输入包含未收录字符:{c}")
            return None
    
    # 3. Padding到最大长度(后面补0)
    input_idx_padded = pad_sequences([input_idx], maxlen=max_len, padding='post')[0]
    
    # 4. 模型预测(输入为[1, max_len]的数组)
    predictions = model.predict(np.array([input_idx_padded]), verbose=0)[0]
    
    # 5. 取概率最大的字符索引(argmax返回最大值的位置)
    predicted_idx = np.argmax(predictions)
    
    # 6. 转换索引为字符(返回结果)
    return idx_to_char[predicted_idx]

# ------------------------------
# 6. 测试补全函数(用测试集数据验证)
# ------------------------------
print("\n【模型测试】补全结果(测试集):")
# 测试集为数据集中的后2条(索引8、9)
test_cases = [
    ("慈母手中线,游子身上", "衣"),    # 预期:衣(测试集第1条)
    ("离离原上草,一岁一枯", "荣")     # 预期:荣(测试集第2条)
]

for input_text, expected in test_cases:
    predicted = complete_poem(input_text)
    if predicted:
        print(f"输入:{input_text} → 预测:{predicted}(预期:{expected})")
    else:
        print(f"输入:{input_text} → 补全失败")

三、代码使用说明

 

1. 环境准备

 

安装Python(3.7+);

安装必要库:pip install tensorflow numpy。

 

2. 运行代码

 

将代码保存为tangshi_completion.py;

在终端运行:python tangshi_completion.py。

 

3. 结果解释

 

词汇表信息:显示模型能识别的所有字符(如“床”“前”“明”等);

模型结构:显示LSTM模型的层结构(嵌入层→LSTM层→输出层);

训练过程:显示每轮训练的损失(loss)和准确率(accuracy),验证集的损失(val_loss)和准确率(val_accuracy);

补全结果:显示测试集输入唐诗前半部分的预测结果(如输入“慈母手中线,游子身上”,预测“衣”)。

 

4. 调整建议

 

增加数据:模拟更多唐诗数据(如20-30条),提高模型泛化能力;

调整参数

epochs:训练轮次(如改为50,增加训练时间,可能提高准确率);

batch_size:批次大小(如改为4,加快训练速度);

LSTM层的units:隐藏单元数量(如改为256,增加模型复杂度);

加入主题标签:如果有主题标签(如“思乡”“山水”),可以将其作为额外特征加入模型(需修改模型结构为函数式API)。

 

四、注意事项

 

序列顺序:输入序列的顺序不能打乱(如“床前明月光”不能改为“光月明前床”),否则模型无法正确捕捉顺序依赖;

未收录字符:输入文本中的字符必须在词汇表中(否则会提示错误),可以通过扩展词汇表解决;

过拟合:如果训练集准确率很高(如100%)但测试集准确率很低(如50%),说明模型过拟合,可以减少训练轮次或增加数据。

通过本代码,你可以掌握LSTM处理序列数据的核心流程,解决“唐诗补全”这类需要长期记忆的任务。

 

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐