深度学习之Attention:基于TensorFlow的Attention+ANN简单示例及说明

  本文以MNIST数据为例,介绍了在TensorFlow中实现Attention+ANN的注意力网络示例。

一、基本流程

  Attention注意力机制是大名鼎鼎的Transformer模型的核心,其主要用于seq2seq任务。与循环神经网络如LSTM/GRU相比,在处理序列到序列的问题时,Attention能更好的考虑整个输入序列的信息。
  1.数据准备
  此处依然使用MNIST手写数字集,加载方法如下:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

  同样需要注意的是,我们需要将数据样本整理成 序列长度 × 特征维度 序列长度 × 特征维度 序列长度×特征维度 的形式。

  2.模型设计
  在此示例中,我们设计的网络结构,包含一个输入层,然后接Attention层,之后再接一个包含100个神经节点的隐藏层,最后接输出层。网络结构的简易视图如下,其中为了简便起见,将注意力层和隐藏层、隐藏层和输出层的全连接都只用了一个箭头来表示。
在这里插入图片描述

# 设计网络结构
inputs  = tf.keras.Input(shape=(28*28,1))
atten   = tf.keras.layers.Attention(use_scale=True)([inputs,inputs],training=False)
flatten = tf.keras.layers.Flatten()(atten)	# 将Attention层的输出展平
hidden  = tf.keras.layers.Dense(100,activation="relu")(flatten)
outputs = tf.keras.layers.Dense(10,activation="sigmoid")(hidden)

# 构建模型
model = tf.keras.models.Model(inputs=inputs,outputs=outputs)

同样需要注意的是,Attention层出来后,需要将其输出展平(Flatten),再接到隐藏层,不然将导致网络输出格式不匹配。

  3.模型编译

# 设定模型训练配置
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) # 定义代价函数
model.compile(optimizer='adam',loss=loss_fn,metrics=['accuracy'])

# 设定提前停止条件
es_callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
                                               patience=3,
                                               start_from_epoch=0)

  在这部分,定义了代价函数、优化方法等配置,还通过tf.keras.callbacks.EarlyStopping设定了提前停止条件。关于早停机制的详细说明可参看《深度学习之Early Stopping:TensorFlow中的早停机制示例及说明》

  4.模型训练

# 训练模型
history = model.fit(x_train,y_train,
                    epochs=5,
                    batch_size=20,
                    validation_data=(x_test,y_test),
                    callbacks=es_callback)

  如下是训练过程中的模型输出。可以看到每一轮训练完时,训练数据上的损失、准确性和验证数据上的损失和准确性。其中可见每轮训练的耗时非常长,超过了450秒。所有在本次示例中将 epoch 的值设置为了 5 5 5,以减少训练需要的时间。
在这里插入图片描述

二、完整代码

import tensorflow as tf

def main():
    # 加载MNIST数据库
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # 整理输入数据[样本数,序列长度,特征维度]
    x_train,x_test = tf.reshape(x_train,[len(x_train),28*28,1]),tf.reshape(x_test,[len(x_test),28*28,1])

    # 设计网络结构
    inputs  = tf.keras.Input(shape=(28*28,1))
    atten   = tf.keras.layers.Attention(use_scale=True)([inputs,inputs],training=False)
    flatten = tf.keras.layers.Flatten()(atten)
    hidden  = tf.keras.layers.Dense(100,activation="relu")(flatten)
    outputs = tf.keras.layers.Dense(10,activation="sigmoid")(hidden)

    # 构建模型
    model = tf.keras.models.Model(inputs=inputs,outputs=outputs)

    # 设定模型训练配置
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) # 定义代价函数
    model.compile(optimizer='adam',loss=loss_fn,metrics=['accuracy'])

    # 设定提前停止条件
    es_callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
                                                   patience=3,
                                                   start_from_epoch=0)
    # 训练模型
    history = model.fit(x_train,y_train,
                        epochs=5,
                        batch_size=20,
                        validation_data=(x_test,y_test),
                        callbacks=es_callback)
    
    # 保存模型
    # tf.keras.models.save_model(model,filepath='./model/attention/')
    
if __name__=='__main__':
  main()

三、参考资料

[1] TensorFlow API 文档:注意力层
[2] << LEEDL TUTORIAL 深度学习详解:第六章 自注意力机制>>

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐