Cross-Attention 公式详解

1. 定义

Cross-Attention(交叉注意力)常用于 Transformer 解码器、Encoder-Decoder 结构中,用于将一个序列(如目标语言)对另一个序列(如源语言)的表示进行对齐和关注。

2. 公式结构

我们有:
• 查询向量来自 Decoder 的当前输入(Q)
• 键和值向量来自 Encoder 的输出(K 和 V)

3. 公式

Cross-Attention 核心计算过程

  1. 输入维度:

    • 查询矩阵:Q∈RTq×dQ \in \mathbb{R}^{T_q \times d}QRTq×d
    • 键矩阵:K∈RTk×dK \in \mathbb{R}^{T_k \times d}KRTk×d
    • 值矩阵:V∈RTk×dV \in \mathbb{R}^{T_k \times d}VRTk×d
  2. 计算注意力权重:

Attention(Q,K,V)=softmax(QK⊤dk)V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dk QK)V

  • QQQ 是 Decoder 的投影输出
  • K,VK, VK,V 是 Encoder 的输出进行线性变换后的结果
  • dkd_kdk 是键向量的维度(通常是 d/hd / hd/hhhh 是头数)
  1. 多头注意力(Multi-Head Attention):

MultiHead(Q,K,V)=Concat(head1,...,headh)WO \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中每个 head 的计算:

headi=Attention(QWiQ,KWiK,VWiV) \text{head}_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i) headi=Attention(QWiQ,KWiK,VWiV)

  • WiQ,WiK,WiV∈Rd×dhW^Q_i, W^K_i, W^V_i \in \mathbb{R}^{d \times d_h}WiQ,WiK,WiVRd×dh 是每个头独立的投影矩阵
  • WO∈Rhdh×dW^O \in \mathbb{R}^{hd_h \times d}WORhdh×d 是输出投影矩阵

4. 使用场景说明

•	Encoder-Decoder:Decoder 的每一层利用 Cross-Attention 获取对 Encoder 输出的关注
•	多模态:文本 Query + 图像 Key/Value,实现跨模态融合
•	检索增强:Query 结合外部知识库的 Key/Value 进行对齐匹配

5.代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # Linear projection for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        query: [batch_size, tgt_len, embed_dim] - from decoder
        key:   [batch_size, src_len, embed_dim] - from encoder
        value: [batch_size, src_len, embed_dim] - from encoder
        mask:  [batch_size, tgt_len, src_len] (optional)
        """

        B, T_q, _ = query.size()
        T_k = key.size(1)

        # Project Q, K, V
        Q = self.q_proj(query)  # [B, T_q, embed_dim]
        K = self.k_proj(key)    # [B, T_k, embed_dim]
        V = self.v_proj(value)  # [B, T_k, embed_dim]

        # Split into heads
        Q = Q.view(B, T_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, T_q, head_dim]
        K = K.view(B, T_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, T_k, head_dim]
        V = V.view(B, T_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, T_k, head_dim]

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [B, heads, T_q, T_k]

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)  # [B, heads, T_q, T_k]
        attn = self.dropout(attn)

        context = torch.matmul(attn, V)  # [B, heads, T_q, head_dim]
        context = context.transpose(1, 2).contiguous().view(B, T_q, self.embed_dim)  # [B, T_q, embed_dim]

        output = self.out_proj(context)  # [B, T_q, embed_dim]
        return output

6.使用实例

decoder_query = torch.randn(8, 10, 512)  # 来自 decoder
encoder_key_value = torch.randn(8, 20, 512)  # 来自 encoder

cross_attn = CrossAttention(embed_dim=512, num_heads=8)
out = cross_attn(decoder_query, encoder_key_value, encoder_key_value)
print(out.shape)  # [8, 10, 512]
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐