Attention 机制简单来说就是给定 Q(Query), K(Key), V(Value),通过 Query 和 Key 的匹配程度来决定从 Value 中提取多少信息(就是一个加权求和的过程)。这个可以参考数据库中的查询,根据查询键 Q 去匹配数据库中的键 K,找到对应的记录并取出该记录的值 V。Attention 机制与此类似,先通过 $ Q \cdot K^T $ 来计算相似度,由于$Q \cdot K^T$的结果是一个实数值向量,它的取值范围可能会很大,所以跟 V 相乘之前还需要先进行 softmax,softmax 会把它们归一化成一个和为 1 的概率分布,因此可以写成 $softmax(Q \cdot K^T) \cdot V$。不过 Transformer 的原论文《Attention Is All You Need》中还对$Q \cdot K^T$后的结果除了一个缩放因子$\sqrt{d_k}$,因为随着 $d_k$ 增大,点积的方差会变大,softmax 更容易饱和,导致梯度变小,而除以 $\sqrt{d_k}$ 可以避免这种饱和。由此可以得出最终的公式为 $Attention(Q,K,V)=softmax(\frac {Q \cdot K^T}{\sqrt{d_k}}) \cdot V$。由于加了“缩放因子”,所以这种Attention机制也叫Scaled Dot-Product Attention机制。

Self-Attention vs. Cross-Attention

Self-Attention 之所以叫 “Self”,就是因为它的 Query、Key、Value 都是由同一个 x 分别通过不同的线性层(k_proj、q_proj、v_proj)来得到的。而 Cross-Attention 的 “Cross” 则是因为在 Transformer 中,它的 Key、Value 来自 Encoder 中的输出,而 Query 则是 Decoder 中的输入。

Multi-Head Attention(多头注意力机制, MHA)

Transformer 中用的其实是 MHA,他就是把 Q、K、V 都分成多个头,在每个头中都分别进行一次 Attention 操作,最后再把所有头的结果拼起来。这样的好处是可以不同的头关注不同的子空间,从而提取出不同的特征信息。具体的做法如下:

import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads               # 头数
        self.head_dim = hidden_size // num_heads # 每个头的维度

        # Q、K、V 投影矩阵
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        
        # 输出线性层
        self.o_proj = nn.Linear(hidden_size, hidden_size)
    
    def forward(self, x, attention_mask=None):
        batch_size = x.size()[0]

        # 输入的 x 形状为(batch_size, seq_len, hidden_size)

        # 计算 Q、K、V
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # 经过投影层后,Q、K、V 的形状为(batch_size, seq_len, hidden_size)

        # 把 Q、K、V 都分成多个头
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # 经过分头操作后,Q、K、V 的形状为(batch_size, num_heads, seq_len, head_dim)

        # 计算 Attention
        attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if attention_mask is not None: # attention_mask 矩阵中为 0 的位置不关注
            attn_scores = attn_scores.masked_fill(attention_mask == 0, float("-inf"))
        # attn_scores 的形状为(batch_size, num_heads, seq_len, seq_len)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        # attn_probs 的形状为(batch_size, num_heads, seq_len, seq_len)
        output = attn_probs @ v
        # output 的形状为(batch_size, num_heads, seq_len, head_dim)

        # 合并所有头的结果
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_size)
        # output 的形状变回(batch_size, seq_len, hidden_size)
        output = self.o_proj(output)
        return output

MQA(Multi-Query Attention)

MQA 是 MHA 的一个变体,它的多个 Query 共享同一套 Key、Value,这样只需要保留一组 Key、Value 即可,可以减少参数量、显著减少推理时的 KV cache 和内存带宽开销,但表达能力可能会降低。具体的实现如下(大部分跟 MHA 是一样的,除了 k_proj/v_proj 的输出维度和分头操作):

import math
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Q、K、V 投影矩阵
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, self.head_dim)
        self.v_proj = nn.Linear(hidden_size, self.head_dim)

        # 输出线性层
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        batch_size = x.size()[0]
        # 输入的 x 形状为(batch_size, seq_len, hidden_size)

        # 计算 Q、K、V
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # 经过投影后,Q 的形状为(batch_size, seq_len, hidden_size)
        # K、V 的形状为(batch_size, seq_len, head_dim)

        # 把 Q 分成多个头, K、V reshape 成一个共享头(之后广播给所有的 query head 使用)
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, 1, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, 1, self.head_dim).transpose(1, 2)
        # 经过分头操作后,Q 的形状为(batch_size, num_heads, seq_len, head_dim)
        # K、V 的形状为(batch_size, 1, seq_len, head_dim)

        # 计算 Attention
        attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        # 计算 q @ k^T 时,k^T 会在 num_heads 维度进行广播
        # 最后得到的 attn_scores 的形状为(batch_size, num_heads, seq_len, seq_len)
        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(attention_mask == 0, float("-inf"))
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = attn_probs @ v

        # 合并所有头的结果
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_size)
        output = self.o_proj(output)
        return output

GQA(Group-Query Attention)

GQA 实际是 MHA 和 MQA 之间的一种 trade-off,它把 Query 分成多个 group,每个 group 里的 Query 共享同一套 Key、Value。 具体的实现如下所示

import math
import torch
import torch.nn as nn

class GroupQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, num_groups):
        super().__init__()
        self.hidden_size = hidden_size           # 隐藏层维度
        self.num_heads = num_heads               # 头数
        self.num_groups = num_groups             # 组数
        self.head_dim = hidden_size // num_heads # 每个头的维度
        # 每个组有 self.num_heads // self.num_groups 个头

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, self.head_dim * num_groups)
        self.v_proj = nn.Linear(hidden_size, self.head_dim * num_groups)

        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        batch_size = x.size()[0]
        
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = (k.view(batch_size, -1, self.num_groups, 1, self.head_dim)
            .expand(batch_size, -1, self.num_groups, self.num_heads // self.num_groups, self.head_dim)
            .reshape(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2))
        v = (v.view(batch_size, -1, self.num_groups, 1, self.head_dim)
            .expand(batch_size, -1, self.num_groups, self.num_heads // self.num_groups, self.head_dim)
            .reshape(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2))
        """
        # 直接写成下面这种方法更简洁
        k = self.k_proj(x).view(batch_size, -1, self.num_groups, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, -1, self.num_groups, self.head_dim).transpose(1, 2)
        
        k = k.repeat_interleave(self.num_heads // self.num_groups, dim=1)
        v = v.repeat_interleave(self.num_heads // self.num_groups, dim=1)
        """

        # 计算 Attention
        attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(attention_mask == 0, float("-inf"))
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = attn_probs @ v

        # 合并所有头的结果
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_size)
        output = self.o_proj(output)
        return output