Part 22. Cross-Attention: Attention Mechanism Explained with Formulas and Code Implementation (II)
This article explains the principle, formulas, and implementation of Cross-Attention in detail. Cross-Attention is used in the Transformer decoder and in Encoder-Decoder architectures; it aligns two sequences through the interaction of query vectors (Q) with key and value vectors (K, V). The core formula is softmax(QK^T/√d_k)V, where Q comes from the Decoder and K, V come from the Encoder. The article also presents the multi-head attention formulas and a PyTorch implementation, covering the Q/K/V projections, attention-weight computation, and output projection.
Cross-Attention Formulas in Detail
1. Definition
Cross-Attention is commonly used in the Transformer decoder and in Encoder-Decoder architectures: one sequence (e.g., the target language) is aligned with and attends to the representation of another sequence (e.g., the source language).
2. Formula Structure
We have:
• The query vectors (Q) come from the Decoder's current input
• The key and value vectors (K and V) come from the Encoder's output
3. Formulas
Core computation of Cross-Attention
- Input dimensions:
  - Query matrix: $Q \in \mathbb{R}^{T_q \times d}$
  - Key matrix: $K \in \mathbb{R}^{T_k \times d}$
  - Value matrix: $V \in \mathbb{R}^{T_k \times d}$
- Attention weights (a minimal tensor sketch of this computation follows the list):
  $$\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V$$
  - $Q$ is the projected output of the Decoder
  - $K, V$ are linear transformations of the Encoder output
  - $d_k$ is the dimension of the key vectors (usually $d / h$, where $h$ is the number of heads)
- Multi-Head Attention:
  $$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O$$
  where each head is computed as
  $$\text{head}_i = \text{Attention}(Q W^Q_i, K W^K_i, V W^V_i)$$
  - $W^Q_i, W^K_i, W^V_i \in \mathbb{R}^{d \times d_h}$ are per-head projection matrices
  - $W^O \in \mathbb{R}^{h d_h \times d}$ is the output projection matrix
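As a warm-up for the full module in Section 5, here is a minimal sketch of the scaled dot-product formula above, computed directly on tensors for a single head (so d_k = d). The sizes T_q = 4, T_k = 6, d = 8 are arbitrary values chosen only for illustration:

import torch
import torch.nn.functional as F

T_q, T_k, d = 4, 6, 8          # illustrative sizes: query length, key length, model dim
Q = torch.randn(T_q, d)        # queries, e.g. from the Decoder
K = torch.randn(T_k, d)        # keys, e.g. from the Encoder
V = torch.randn(T_k, d)        # values, e.g. from the Encoder

# softmax(Q K^T / sqrt(d_k)) V, with d_k = d in this single-head sketch
scores = Q @ K.transpose(0, 1) / (d ** 0.5)   # [T_q, T_k]
weights = F.softmax(scores, dim=-1)           # each row sums to 1 over the T_k key positions
output = weights @ V                          # [T_q, d]
print(weights.shape, output.shape)            # torch.Size([4, 6]) torch.Size([4, 8])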
4. Usage Scenarios
• Encoder-Decoder: each Decoder layer uses Cross-Attention to attend to the Encoder output (a sketch with PyTorch's built-in module follows this list)
• Multimodal: text Query + image Key/Value, for cross-modal fusion
• Retrieval augmentation: the Query is aligned and matched against the Key/Value of an external knowledge base
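For the Encoder-Decoder case, PyTorch's built-in nn.MultiheadAttention can already act as a cross-attention layer simply by passing different tensors as query and key/value; the module in Section 5 reimplements the same computation explicitly. A minimal sketch, with illustrative tensor sizes:

import torch
import torch.nn as nn

# Cross-attention with the built-in module: query from the decoder side,
# key/value from the encoder side (batch_first=True -> [batch, seq, dim]).
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
decoder_states = torch.randn(8, 10, 512)   # illustrative decoder hidden states
encoder_states = torch.randn(8, 20, 512)   # illustrative encoder outputs

out, attn_weights = mha(query=decoder_states, key=encoder_states, value=encoder_states)
print(out.shape)           # torch.Size([8, 10, 512])
print(attn_weights.shape)  # torch.Size([8, 10, 20]), averaged over heads by default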
5. Code Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.1):
super(CrossAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
# Linear projection for Q, K, V
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
# Output projection
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
"""
query: [batch_size, tgt_len, embed_dim] - from decoder
key: [batch_size, src_len, embed_dim] - from encoder
value: [batch_size, src_len, embed_dim] - from encoder
mask: [batch_size, tgt_len, src_len] (optional)
"""
B, T_q, _ = query.size()
T_k = key.size(1)
# Project Q, K, V
Q = self.q_proj(query) # [B, T_q, embed_dim]
K = self.k_proj(key) # [B, T_k, embed_dim]
V = self.v_proj(value) # [B, T_k, embed_dim]
# Split into heads
Q = Q.view(B, T_q, self.num_heads, self.head_dim).transpose(1, 2) # [B, heads, T_q, head_dim]
K = K.view(B, T_k, self.num_heads, self.head_dim).transpose(1, 2) # [B, heads, T_k, head_dim]
V = V.view(B, T_k, self.num_heads, self.head_dim).transpose(1, 2) # [B, heads, T_k, head_dim]
# Scaled Dot-Product Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5) # [B, heads, T_q, T_k]
        if mask is not None:
            # Broadcast the [B, T_q, T_k] mask over the heads dimension before filling
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))
attn = F.softmax(scores, dim=-1) # [B, heads, T_q, T_k]
attn = self.dropout(attn)
context = torch.matmul(attn, V) # [B, heads, T_q, head_dim]
context = context.transpose(1, 2).contiguous().view(B, T_q, self.embed_dim) # [B, T_q, embed_dim]
output = self.out_proj(context) # [B, T_q, embed_dim]
return output
6. Usage Example
decoder_query = torch.randn(8, 10, 512)       # from the decoder
encoder_key_value = torch.randn(8, 20, 512)   # from the encoder
cross_attn = CrossAttention(embed_dim=512, num_heads=8)
out = cross_attn(decoder_query, encoder_key_value, encoder_key_value)
print(out.shape) # [8, 10, 512]
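The mask argument of CrossAttention can be used to ignore padded encoder positions. A minimal sketch continuing the example above, assuming (purely for illustration) that the last 5 of the 20 source positions are padding; the mask is expanded to [batch, tgt_len, src_len], and masked positions receive zero attention weight:

# Masked usage: 1 = attend, 0 = ignore (e.g. padded encoder positions)
src_valid = torch.ones(8, 20)
src_valid[:, 15:] = 0                            # assume the last 5 source positions are padding
mask = src_valid[:, None, :].expand(8, 10, 20)   # [batch, tgt_len, src_len]

out_masked = cross_attn(decoder_query, encoder_key_value, encoder_key_value, mask=mask)
print(out_masked.shape)  # [8, 10, 512]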