用Python 实现 transformer attention

最后发布时间 : 2026-05-09 16:26:52 浏览量 :

下面我会从 “为什么需要 Attention” → “数学公式” → “Python 手写实现” → “训练” → “完整 Transformer Block” 一步一步讲。

目标不是“调用库”,而是你真正理解:

  • Attention 到底在算什么
  • Query / Key / Value 是什么
  • 为什么要除以 √d
  • softmax 在干什么
  • loss 怎么定义
  • 训练时梯度如何更新
  • Transformer 为什么能替代 RNN

我会先从最小版本开始,只用:

  • numpy
  • 后面再用 torch

这样你能真正理解原理。


一、Attention 到底解决什么问题

先看一句话:

"Tom likes cats because he has three of them"

模型在处理 "he" 时:

  • "he" 指的是谁?
  • 是 Tom
  • 不是 cats

传统 RNN:

  • 信息要一步一步传递
  • 距离远容易遗忘

Attention:

  • 直接“看”所有 token
  • 自动决定重点关注谁

所以:

Attention 本质:

“当前词应该关注输入中的哪些词”


二、Attention 的核心思想

假设输入:

I love transformer

token:

["I", "love", "transformer"]

每个 token 会变成向量:

I            -> x1
love         -> x2
transformer  -> x3

Attention 会让:

  • 每个词
  • 与所有词
  • 计算相关性

三、Q / K / V 到底是什么

这是 Transformer 最核心部分。

输入向量:

X

通过三个线性层:

Q = XWq
K = XWk
V = XWv

得到:

  • Query(我要找什么)
  • Key(我是什么)
  • Value(我的信息)

可以理解:

名称类比
Query搜索词
Key标签
Value内容

类似搜索引擎:

Query: "cat"

Key:
doc1: "dog"
doc2: "cat"

匹配度:
doc2 更高

于是更多关注 doc2。


四、Attention 数学公式

核心公式:

Attention(Q,K,V)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V

这是 Transformer 灵魂。

下面一步一步拆。


五、第一步:计算相似度

1. 点积

QK^T

表示:

  • Query 与 Key 做相似度计算

例如:

Q = [1, 2]
K = [2, 1]

点积:

1×2 + 2×1 = 4

越大:

  • 越相关

六、为什么除以 √d

公式:

\frac{QK^T}{\sqrt{d_k}}

原因:

维度大时:

  • 点积会越来越大
  • softmax 会进入极端区域
  • 梯度消失

举例:

[10, 11, 9]

softmax:

[0.24, 0.66, 0.09]

如果:

[1000, 1100, 900]

softmax:

[0, 1, 0]

梯度几乎没了。

所以:

\sqrt{d_k}

用于稳定训练。


七、softmax 在干什么

softmax:

softmax(x_i)=\frac{e^{x_i}}{\sum_j e^{x_j}}

作用:

把分数变成概率。

例如:

[2, 1, 0]

变成:

[0.67, 0.24, 0.09]

表示:

  • 67% 注意第一个词
  • 24% 注意第二个词

八、最后乘 V

公式最后:

softmax(...)V

表示:

  • 用 attention 权重
  • 加权求和 Value

例如:

weights = [0.8, 0.2]

V1 = [1, 1]
V2 = [10, 10]

结果:

0.8*V1 + 0.2*V2

九、手写最小 Attention(numpy)

现在开始真正实现。


十、第一版 Attention

import numpy as np

# 输入序列
X = np.array([
    [1, 0, 1, 0],
    [0, 2, 0, 2],
    [1, 1, 1, 1]
], dtype=np.float32)

# 随机初始化权重
Wq = np.random.randn(4, 4)
Wk = np.random.randn(4, 4)
Wv = np.random.randn(4, 4)

# 生成 Q K V
Q = X @ Wq
K = X @ Wk
V = X @ Wv

# attention score
scores = Q @ K.T

# 缩放
dk = K.shape[-1]
scores = scores / np.sqrt(dk)

# softmax
exp_scores = np.exp(scores)
attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

# 输出
output = attention_weights @ V

print(output)

十一、这里发生了什么

1. X

输入 token embedding:

[seq_len, hidden_dim]

这里:

[3, 4]

表示:

  • 3 个 token
  • 每个 4 维

2. Q K V

线性变换:

Q = XWq

本质:

神经网络全连接层。


3. score

Q @ K.T

结果:

[3, 3]

表示:

每个 token 对其他 token 的关注程度。


十二、Mask 是什么

语言模型训练时:

不能偷看未来。

例如:

I love ?

预测:

transformer

不能提前看到 transformer。

所以需要:

Causal Mask

矩阵:

\begin{bmatrix} 1 & 0 & 0 \ 1 & 1 & 0 \ 1 & 1 & 1 \end{bmatrix}

未来位置设为:

-1e9

softmax 后接近 0。


十三、加入 Mask

mask = np.triu(np.ones((3, 3)), k=1)

scores = scores - mask * 1e9

十四、Multi-Head Attention

为什么多头?

因为:

不同 head 学不同关系:

  • 一个学语法
  • 一个学位置
  • 一个学语义

公式:

head_i = Attention(Q_i,K_i,V_i)

最后:

Concat(head_1,...,head_h)W^O


十五、PyTorch 实现 Attention

现在开始真正深度学习版本。


十六、单头 Attention

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):

    def __init__(self, dim):
        super().__init__()

        self.Wq = nn.Linear(dim, dim)
        self.Wk = nn.Linear(dim, dim)
        self.Wv = nn.Linear(dim, dim)

    def forward(self, x):

        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)

        scores = Q @ K.transpose(-2, -1)

        scores = scores / (K.size(-1) ** 0.5)

        attn = F.softmax(scores, dim=-1)

        out = attn @ V

        return out

十七、测试

x = torch.randn(2, 5, 32)

model = SelfAttention(32)

y = model(x)

print(y.shape)

输出:

[2, 5, 32]

含义:

[batch, seq_len, hidden]

十八、Transformer Block

Transformer 不只是 attention。

完整结构:

Attention
   ↓
Add & Norm
   ↓
Feed Forward
   ↓
Add & Norm

十九、FFN 是什么

FFN:

Linear
ReLU
Linear

作用:

  • 非线性变换
  • 提升表达能力

公式:

FFN(x)=W_2(ReLU(W_1x))


二十、完整 Block

class TransformerBlock(nn.Module):

    def __init__(self, dim):
        super().__init__()

        self.attn = SelfAttention(dim)

        self.norm1 = nn.LayerNorm(dim)

        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.ReLU(),
            nn.Linear(dim * 4, dim)
        )

        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):

        x = x + self.attn(x)

        x = self.norm1(x)

        x = x + self.ffn(x)

        x = self.norm2(x)

        return x

二十一、为什么 Add & Norm

残差连接

x + attention(x)

作用:

避免深层网络梯度消失。


LayerNorm

标准化:

让训练稳定。


二十二、Loss 怎么定义

Transformer 语言模型:

本质:

下一个 token 预测

例如:

输入:

I love

目标:

transformer

二十三、输出 logits

模型输出:

[batch, seq_len, vocab_size]

例如:

[2, 5, 10000]

表示:

每个位置预测 10000 个词概率。


二十四、Cross Entropy Loss

核心公式:

Loss = -\log(p_{true})

例如:

真实词:

cat

模型预测:

[dog=0.1, cat=0.7, fish=0.2]

loss:
-\log(0.7)

如果预测正确:

  • loss 小

预测错误:

  • loss 大

二十五、训练代码

model = TransformerBlock(32)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3
)

x = torch.randn(8, 10, 32)

target = torch.randn(8, 10, 32)

for step in range(100):

    pred = model(x)

    loss = F.mse_loss(pred, target)

    optimizer.zero_grad()

    loss.backward()

    optimizer.step()

    print(loss.item())

二十六、真正 NLP 训练

真正 GPT:

输入:

I love transformer

target:

love transformer <eos>

即:

输入:  I
预测:  love

输入:  I love
预测:  transformer

二十七、Token Embedding

文本不能直接输入网络。

需要:

nn.Embedding

例如:

self.embedding = nn.Embedding(vocab_size, dim)

作用:

把 token id:

15

变成:

[0.1, -0.5, ...]

二十八、位置编码 Positional Encoding

Attention 本身不知道顺序。

所以:

需要加入位置。

经典公式:

PE(pos,2i)=sin(pos/10000^{2i/d})

PE(pos,2i+1)=cos(pos/10000^{2i/d})

作用:

给模型位置感。


二十九、为什么 Transformer 强

相比 RNN:

Transformer:

  • 并行训练
  • 长距离依赖强
  • GPU 利用率高

所以:

几乎统一 NLP。


三十、你接下来应该学什么

建议顺序:

第一阶段

真正吃透:

  • softmax
  • matrix multiply
  • cross entropy
  • backprop

第二阶段

手写:

  • multi-head attention
  • positional encoding
  • causal mask

第三阶段

实现:

  • mini GPT

第四阶段

理解:

  • KV Cache
  • FlashAttention
  • RoPE
  • RMSNorm
  • MoE

三十一、最重要的一句话

Transformer 本质:

用 Query 与 Key 计算相关性,再对 Value 做加权聚合。

公式:

softmax(QK^T/\sqrt{d})V

这一个公式,
基本统治了现代 AI。