深度学习之Attention:基于TensorFlow的Attention+ANN简单示例及说明
本文以MNIST数据为例,介绍了在TensorFlow中实现Attention+ANN的注意力网络示例。
深度学习之Attention:基于TensorFlow的Attention+ANN简单示例及说明
本文以MNIST数据为例,介绍了在TensorFlow中实现Attention+ANN的注意力网络示例。
一、基本流程
Attention注意力机制是大名鼎鼎的Transformer模型的核心,其主要用于seq2seq任务。与循环神经网络如LSTM/GRU相比,在处理序列到序列的问题时,Attention能更好的考虑整个输入序列的信息。
1.数据准备
此处依然使用MNIST手写数字集,加载方法如下:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
同样需要注意的是,我们需要将数据样本整理成 序列长度 × 特征维度 序列长度 × 特征维度 序列长度×特征维度 的形式。
2.模型设计
在此示例中,我们设计的网络结构,包含一个输入层,然后接Attention层,之后再接一个包含100个神经节点的隐藏层,最后接输出层。网络结构的简易视图如下,其中为了简便起见,将注意力层和隐藏层、隐藏层和输出层的全连接都只用了一个箭头来表示。
# 设计网络结构
inputs = tf.keras.Input(shape=(28*28,1))
atten = tf.keras.layers.Attention(use_scale=True)([inputs,inputs],training=False)
flatten = tf.keras.layers.Flatten()(atten) # 将Attention层的输出展平
hidden = tf.keras.layers.Dense(100,activation="relu")(flatten)
outputs = tf.keras.layers.Dense(10,activation="sigmoid")(hidden)
# 构建模型
model = tf.keras.models.Model(inputs=inputs,outputs=outputs)
同样需要注意的是,Attention层出来后,需要将其输出展平(Flatten),再接到隐藏层,不然将导致网络输出格式不匹配。
3.模型编译
# 设定模型训练配置
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) # 定义代价函数
model.compile(optimizer='adam',loss=loss_fn,metrics=['accuracy'])
# 设定提前停止条件
es_callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
patience=3,
start_from_epoch=0)
在这部分,定义了代价函数、优化方法等配置,还通过tf.keras.callbacks.EarlyStopping
设定了提前停止条件。关于早停机制的详细说明可参看《深度学习之Early Stopping:TensorFlow中的早停机制示例及说明》。
4.模型训练
# 训练模型
history = model.fit(x_train,y_train,
epochs=5,
batch_size=20,
validation_data=(x_test,y_test),
callbacks=es_callback)
如下是训练过程中的模型输出。可以看到每一轮训练完时,训练数据上的损失、准确性和验证数据上的损失和准确性。其中可见每轮训练的耗时非常长,超过了450秒。所有在本次示例中将 epoch
的值设置为了 5 5 5,以减少训练需要的时间。
二、完整代码
import tensorflow as tf
def main():
# 加载MNIST数据库
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 整理输入数据[样本数,序列长度,特征维度]
x_train,x_test = tf.reshape(x_train,[len(x_train),28*28,1]),tf.reshape(x_test,[len(x_test),28*28,1])
# 设计网络结构
inputs = tf.keras.Input(shape=(28*28,1))
atten = tf.keras.layers.Attention(use_scale=True)([inputs,inputs],training=False)
flatten = tf.keras.layers.Flatten()(atten)
hidden = tf.keras.layers.Dense(100,activation="relu")(flatten)
outputs = tf.keras.layers.Dense(10,activation="sigmoid")(hidden)
# 构建模型
model = tf.keras.models.Model(inputs=inputs,outputs=outputs)
# 设定模型训练配置
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) # 定义代价函数
model.compile(optimizer='adam',loss=loss_fn,metrics=['accuracy'])
# 设定提前停止条件
es_callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
patience=3,
start_from_epoch=0)
# 训练模型
history = model.fit(x_train,y_train,
epochs=5,
batch_size=20,
validation_data=(x_test,y_test),
callbacks=es_callback)
# 保存模型
# tf.keras.models.save_model(model,filepath='./model/attention/')
if __name__=='__main__':
main()
三、参考资料
[1] TensorFlow API 文档:注意力层
[2] << LEEDL TUTORIAL 深度学习详解:第六章 自注意力机制>>
更多推荐
所有评论(0)