# Preface

All tensor computations and reshapes in this post are written with the einops library. If you are not familiar with it yet, I strongly recommend spending five minutes learning it; it greatly improves code readability.

einops_install
```
pip install einops
# Notation used throughout this post:
# b: batch size
# s: sequence length
# d: dimension
# h: number of heads
# t: stand-in when a symbol is already taken (e.g. a second sequence-length axis)
from einops import rearrange, repeat, einsum
```
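As a quick taste of the three functions imported above (my own minimal example, not from the original post): axes get explicit names, so the intent of each reshape is visible in the pattern string.

```python
import torch
from einops import rearrange, repeat, einsum

x = torch.randn(2, 6, 8)                           # (b, s, d)
heads = rearrange(x, "b s (h d) -> b h s d", h=2)  # split d into 2 heads of size 4
tiled = repeat(x, "b s d -> b (g s) d", g=3)       # tile the sequence 3 times
scores = einsum(x, x, "b s d, b t d -> b s t")     # batched pairwise dot products
print(heads.shape, tiled.shape, scores.shape)
# torch.Size([2, 2, 6, 4]) torch.Size([2, 18, 8]) torch.Size([2, 6, 6])
```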

# LLM-Related Hand-Written Coding Problems

Ordered roughly by importance (though in practice you usually just get asked to write MHA).

# Self-Attention

$$\text{Attn}(Q,K,V)=\text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V$$

There is also cross-attention, where Q comes from the query sequence while K/V come from another sequence; you only need to change `forward` to accept k and v separately and run the same computation (see the sketch after the code below).

self_attention
```python
import torch
import torch.nn as nn
from math import sqrt
from einops import einsum

class SelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        score = einsum(q, k, "b s d, b t d -> b s t") / sqrt(self.d_model)
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        attn = torch.softmax(score, dim=-1)
        out = einsum(attn, v, "b s t, b t d -> b s d")
        return self.w_o(out)

def main():
    torch.manual_seed(0)
    b, s, d = 2, 5, 32
    x = torch.randn(b, s, d)
    mask = torch.tril(torch.ones(s, s)).view(1, s, s)
    m = SelfAttention(d)
    y = m(x, mask)
    assert y.shape == (b, s, d)
    print("OK: self-attention shape")

if __name__ == "__main__":
    main()
```
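Here is a minimal cross-attention sketch (my addition, not part of the original post): the math is identical, only K/V are projected from a separate context sequence `ctx`, whose length may differ from the query sequence.

```python
import torch
import torch.nn as nn
from math import sqrt
from einops import einsum

class CrossAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, ctx, mask=None):
        q = self.w_q(x)      # (b, s, d)  from the query sequence
        k = self.w_k(ctx)    # (b, t, d)  from the other sequence
        v = self.w_v(ctx)
        score = einsum(q, k, "b s d, b t d -> b s t") / sqrt(self.d_model)
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        attn = torch.softmax(score, dim=-1)
        return self.w_o(einsum(attn, v, "b s t, b t d -> b s d"))

if __name__ == "__main__":
    x, ctx = torch.randn(2, 5, 32), torch.randn(2, 9, 32)
    y = CrossAttention(32)(x, ctx)
    assert y.shape == (2, 5, 32)
    print("OK: cross-attention shape")
```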

# Multi-Head Attention

Split the d-dimensional Q, K, V into h groups of dimension d/h, compute attention within each group, and concatenate the results; in practice this is just a reshape.

multihead_attention
```python
from math import sqrt
import torch
import torch.nn as nn
from einops import einsum, rearrange

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.dim_heads = d_model // num_heads
        self.num_heads = num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        # project, then split the feature dim into (heads, head_dim)
        q = rearrange(self.w_q(q), "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(self.w_k(k), "b s (h d) -> b h s d", h=self.num_heads)
        v = rearrange(self.w_v(v), "b s (h d) -> b h s d", h=self.num_heads)

        scores = einsum(q, k, "b h s d, b h t d -> b h s t") / sqrt(self.dim_heads)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weight = torch.softmax(scores, dim=-1)
        output = einsum(attn_weight, v, "b h s t, b h t d -> b h s d")
        output = rearrange(output, "b h s d -> b s (h d)")  # merge heads back
        return self.w_o(output)


def main():
    d_model = 512
    num_heads = 8
    seq_len = 10
    batch_size = 2

    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
    q = torch.randn(batch_size, seq_len, d_model)
    k = torch.randn(batch_size, seq_len, d_model)
    v = torch.randn(batch_size, seq_len, d_model)
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
    output = mha(q, k, v, mask)
    print(f"input shape: {q.shape}")
    print(f"output shape: {output.shape}")

if __name__ == "__main__":
    main()
```
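An optional sanity check (my addition; assumes PyTorch ≥ 2.0, which provides `F.scaled_dot_product_attention`): the per-head math in `forward` matches the fused kernel when the linear projections are bypassed.

```python
import torch
import torch.nn.functional as F
from math import sqrt
from einops import einsum

torch.manual_seed(0)
b, h, s, d = 2, 4, 6, 8
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
causal = torch.tril(torch.ones(s, s, dtype=torch.bool))

# manual scaled dot-product attention, same as inside MultiHeadAttention.forward
scores = einsum(q, k, "b h s d, b h t d -> b h s t") / sqrt(d)
scores = scores.masked_fill(~causal, float("-inf"))
manual = einsum(torch.softmax(scores, dim=-1), v, "b h s t, b h t d -> b h s d")

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(manual, ref, atol=1e-5)
print("OK: manual attention matches F.scaled_dot_product_attention")
```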

# Grouped-Query Attention

A variant of multi-head attention in which the queries use num_heads heads while the keys and values use num_kv_heads heads. num_heads must be larger than (and a multiple of) num_kv_heads; at compute time you simply expand the K/V head dimension to match (see the small check after the code).

group_query_attention
```python
import torch
import torch.nn as nn
from einops import einsum, rearrange, repeat

class GroupQueryAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        assert num_heads % num_kv_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_heads
        self.kv_group = num_heads // num_kv_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, self.num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(d_model, self.num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_kv_heads)
        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_kv_heads)

        # each kv head is shared by kv_group consecutive query heads
        k_extend = repeat(k, "b h s d -> b (h g) s d", g=self.kv_group)
        v_extend = repeat(v, "b h s d -> b (h g) s d", g=self.kv_group)

        scores = einsum(q, k_extend, "b h s d, b h t d -> b h s t") / self.head_dim ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = torch.softmax(scores, dim=-1)

        out = einsum(scores, v_extend, "b h i j, b h j d -> b h i d")
        out = rearrange(out, "b h s d -> b s (h d)")
        return self.o_proj(out)

def main():
    B, S, D = 2, 8, 64
    head_q = 8
    head_kv = 2
    x = torch.randn(B, S, D)
    gqa = GroupQueryAttention(d_model=D, num_heads=head_q, num_kv_heads=head_kv)
    mask = torch.tril(torch.ones(S, S)).unsqueeze(0).unsqueeze(0)
    out = gqa(x, mask)
    print(f"input shape: {x.shape}")
    print(f"output shape: {out.shape}")

if __name__ == "__main__":
    main()
```
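A small check (my addition): the `repeat` pattern used above to expand K/V is the same contiguous duplication as `torch.repeat_interleave` over the head dimension, which is what the usual `repeat_kv` helper in GQA implementations does.

```python
import torch
from einops import repeat

k = torch.randn(2, 2, 4, 8)  # (b, num_kv_heads, s, head_dim)
a = repeat(k, "b h s d -> b (h g) s d", g=4)
b = torch.repeat_interleave(k, repeats=4, dim=1)
assert torch.equal(a, b)
print("OK: einops repeat matches torch.repeat_interleave over the head dim")
```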

# safe softmax

Subtract the maximum value before exponentiating, so that large inputs cannot overflow (a toy overflow example follows the code below).

$$\text{softmax}(x_i)=\frac{e^{x_i}}{\sum_j e^{x_j}}=\frac{e^{x_i-\max(x)}}{\sum_j e^{x_j-\max(x)}}$$

stable_softmax
```python
import torch

def safe_softmax(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    x_max = scores.max(dim=dim, keepdim=True).values
    x_shifted = scores - x_max          # shifting does not change the softmax result
    exp_x = torch.exp(x_shifted)
    return exp_x / exp_x.sum(dim=dim, keepdim=True)

def main():
    torch.manual_seed(0)
    x = torch.randn(4, 7) * 20
    y1 = safe_softmax(x, dim=-1)
    y2 = torch.softmax(x, dim=-1)
    assert torch.allclose(y1, y2, atol=1e-6, rtol=1e-6)
    print("OK: stable_softmax matches torch.softmax")

if __name__ == "__main__":
    main()
```
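To see why the shift matters, here is a toy example of my own: the naive formula overflows once the logits are large, while the shifted version (or torch.softmax, which does the same thing internally) stays finite.

```python
import torch

x = torch.tensor([[1000.0, 1001.0, 1002.0]])
naive = torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)  # exp(1000) -> inf, inf/inf -> nan
print(naive)                      # tensor([[nan, nan, nan]])
print(torch.softmax(x, dim=-1))   # tensor([[0.0900, 0.2447, 0.6652]]), same as safe_softmax(x)
```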

# Entropy and Cross-Entropy

Entropy is a core concept in information theory that measures the uncertainty of a random variable. For a discrete random variable X, it is defined as:

$$H(X) = -\sum_{i} P(x_i) \log P(x_i)$$

where P(x_i) is the probability that X takes the value x_i.
The entropy function implemented here takes logits as input: the raw, unnormalized outputs of the model's final layer, i.e. the scores before they are turned into probabilities.
It uses log_softmax, which fuses the softmax and the log into a single, more numerically stable computation.
Cross-entropy (with one-hot targets):

$$\mathrm{CE}(y,p)=-\sum_i y_i \log p_i$$

entropy
```python
import torch
import torch.nn.functional as F

def entropy(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    logp = F.log_softmax(logits, dim=dim)
    p = torch.exp(logp)
    return -(p * logp).sum(dim=dim)

def cross_entropy(logits: torch.Tensor, targets: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Args:
    - logits: raw model outputs (before softmax), usually of shape [batch_size, num_classes]
    - targets: labels, in either of two forms:
        1. class indices (integers), shape [batch_size]
        2. one-hot vectors, shape [batch_size, num_classes]
    - dim: dimension along which softmax is computed, last dim by default
    Returns:
    - the per-sample cross-entropy, shape [batch_size]
    The official implementation returns the batch mean by default; either call it
    with reduction='none' or change this implementation accordingly.
    """
    logp = F.log_softmax(logits, dim=dim)
    if targets.dim() == logits.dim() - 1:
        targets = F.one_hot(targets, num_classes=logits.size(dim)).float()
    return -(targets * logp).sum(dim=dim)
    # To match the official default (mean over the batch):
    # return -(targets * logp).sum(dim=dim).mean()

def main():
    torch.manual_seed(0)
    logits = torch.randn(10, 6)
    H = entropy(logits)
    assert H.shape == (10,)
    assert torch.isfinite(H).all()
    print("OK: entropy_from_logits shape & finite")

    target = torch.randint(0, 6, (10,))
    l1 = cross_entropy(logits, target)
    l2 = F.cross_entropy(logits, target, reduction='none')
    assert torch.allclose(l1, l2, atol=1e-6, rtol=1e-6)
    print("OK: cross_entropy_from_logits matches F.cross_entropy")

if __name__ == "__main__":
    main()
```
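One property worth remembering (my addition; assumes the `entropy` function above is in scope): entropy is maximized by the uniform distribution, so all-equal logits give exactly log(num_classes).

```python
import math
import torch

logits = torch.zeros(3, 6)   # identical logits -> uniform distribution over 6 classes
H = entropy(logits)
assert torch.allclose(H, torch.full((3,), math.log(6.0)))
print("OK: uniform logits give entropy log(6)")
```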

# KL Divergence

The notation around KL divergence can be confusing, so to be explicit:
the kl_divergence implemented below uses the standard definition.
For discrete probability distributions P and Q,

$$D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{i} P(x_i) \log \frac{P(x_i)}{Q(x_i)}$$

In the official torch.nn.functional.kl_div API, by contrast, the first argument is log Q and the second is P; see the example inputs in the main function below.
The implementation:

kl_divergence
```python
import torch

def kl_divergence(p, q):
    # KL(P || Q) = sum_i p_i * log(p_i / q_i)
    eps = 1e-8
    log_ratio = torch.log((p + eps) / (q + eps))
    kl = torch.sum(p * log_ratio, dim=-1)
    return kl

# Monte-Carlo estimators of KL[q, p] from samples x ~ q (Schulman's k1/k2/k3)
def k1_estimate(q, p):
    logr = torch.log(p) - torch.log(q)
    return -logr

def k2_estimate(q, p):
    logr = torch.log(p) - torch.log(q)
    return logr ** 2 / 2

def k3_estimate(q, p):
    logr = torch.log(p) - torch.log(q)
    return (logr.exp() - 1) - logr

def main():
    torch.manual_seed(0)
    p_logits = torch.randn(6, 10)
    q_logits = torch.randn(6, 10)

    p = torch.softmax(p_logits, dim=-1)
    q = torch.softmax(q_logits, dim=-1)
    logp = torch.log_softmax(p_logits, dim=-1)
    logq = torch.log_softmax(q_logits, dim=-1)

    # my implementation
    kld = kl_divergence(p, q)
    # official implementation; note the argument order (log Q first, then P)
    # my PyTorch version is old and lacks reduction='none', so use reduction=0 and sum
    kl_div = torch.kl_div(logq, p, reduction=0).sum(dim=-1)
    assert torch.allclose(kld, kl_div, atol=1e-6, rtol=1e-6)
    print("OK: kl_divergence matches torch.kl_div")

if __name__ == "__main__":
    main()
```

The k1, k2, k3 estimators above follow Schulman's definitions in Approximating KL Divergence; note that p and q appear in the opposite order from the usual convention.

The definitions are as follows:

  • k1 estimator

$$\text{KL}[q,p]=\sum_x q(x)\log\frac{q(x)}{p(x)}=\mathbb{E}_{x\sim q}\left[\log\frac{q(x)}{p(x)}\right],\qquad k_1=\log\frac{q(x)}{p(x)}=-\log r$$

Unbiased, but high variance (standard deviation around 20 times the true KL in the blog's example).

  • k2 estimator

$$k_2=\frac{1}{2}\left(\log\frac{p(x)}{q(x)}\right)^2=\frac{1}{2}(\log r)^2$$

Biased, but lower variance.

  • k3 estimator

$$k_3=(r-1)-\log r$$

Unbiased and low variance.

Here is the sample code from that blog post:

joschu_kl_estimators
```python
import torch.distributions as dis

p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
truekl = dis.kl_divergence(p, q)
print("true", truekl)
logr = p.log_prob(x) - q.log_prob(x)
k1 = -logr
k2 = logr ** 2 / 2
k3 = (logr.exp() - 1) - logr
for k in (k1, k2, k3):
    print((k.mean() - truekl) / truekl, k.std() / truekl)
```

# MLP with swiglu

SwiGLU is the activation commonly used in today's mainstream models such as Qwen3; its formula is:

$$\text{SwiGLU}(x) = \text{Swish}(xW_1 + b_1) \odot (xW_2 + b_2)$$

Multiplying the SwiGLU output by one more weight matrix gives a complete MLP layer.
Mainstream models no longer use biases.
In engineering practice, the two linear layers are usually fused into a single layer that outputs twice the dimension and is then split; a similar trick also shows up around RoPE.

mlp_swiglu
```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int
    ):
        super().__init__()
        # gate and up projections fused into one linear layer, split in forward
        # (mainstream models would additionally set bias=False)
        self.gate_up_proj = nn.Linear(hidden_size, intermediate_size * 2)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

    def silu(self, x):
        sig = 1 / (1 + torch.exp(-x))
        return x * sig

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        x1, x2 = gate_up.chunk(2, dim=-1)
        activated = self.silu(x1) * x2   # SwiGLU: Swish(gate) * up
        out = self.down_proj(activated)
        return out

def main():
    torch.manual_seed(0)
    batch_size = 2
    seq_len = 4
    hidden_size = 8
    intermediate_size = 32

    mlp = MLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
    x = torch.randn(batch_size, seq_len, hidden_size)
    out = mlp(x)
    print(f"input shape: {x.shape}")
    print(f"output shape: {out.shape}")

if __name__ == "__main__":
    main()
```
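A quick check (my addition; assumes the MLP class above is in scope): the hand-written `silu` matches `F.silu`, so in an interview you can also just call the built-in.

```python
import torch
import torch.nn.functional as F

mlp = MLP(hidden_size=8, intermediate_size=16)
x = torch.randn(4, 8)
assert torch.allclose(mlp.silu(x), F.silu(x), atol=1e-6)
print("OK: hand-written silu matches F.silu")
```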

# RoPE

```python
import torch
from einops import einsum, rearrange, repeat
```
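The post leaves this section as a stub, so here is a minimal RoPE sketch of my own (an assumption, not the author's code). RoPE encodes position by rotating each (even, odd) pair of query/key dimensions by an angle pos · θ_i with frequencies θ_i = 10000^(−2i/d); the rotation is orthogonal, so vector norms are preserved.

```python
import torch
from einops import rearrange

def rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (b, h, s, d) with d even; rotates consecutive pairs (x0, x1), (x2, x3), ...
    b, h, s, d = x.shape
    half = d // 2
    freqs = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))   # (half,)
    angles = torch.outer(torch.arange(s, dtype=torch.float32), freqs)          # (s, half)
    cos = angles.cos()[None, None]   # (1, 1, s, half), broadcast over batch and heads
    sin = angles.sin()[None, None]
    x1, x2 = x[..., 0::2], x[..., 1::2]            # even / odd dims, each (b, h, s, half)
    out = torch.stack((x1 * cos - x2 * sin,        # 2D rotation of each pair
                       x1 * sin + x2 * cos), dim=-1)
    return rearrange(out, "b h s half two -> b h s (half two)")

def main():
    x = torch.randn(2, 4, 6, 8)
    y = rope(x)
    assert y.shape == x.shape
    # rotations are orthogonal, so per-token norms are unchanged
    assert torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-5)
    print("OK: rope shape & norm preserved")

if __name__ == "__main__":
    main()
```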

# RMSNorm

$$\mathrm{rms}(x)=\sqrt{\frac{1}{d}\sum_j x_j^2+\epsilon},\qquad \text{RMSNorm}(x)=\gamma\,\frac{x}{\mathrm{rms}(x)}$$

A production implementation would lean more on in-place operations, but an interview hand-write does not need to be that elaborate.
Also, since pre-norm is the standard now, the residual is added directly onto the input.
The input x, after the residual has been added, is returned as the new residual for the next layer (see the usage sketch after the code).

rmsnorm
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-8):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x, residual=None):
        if residual is not None:
            x = x.add_(residual)   # pre-norm: fold the residual into the input
        rms = torch.sqrt((x * x).mean(dim=-1, keepdim=True) + self.eps)
        # return (normalized output, updated residual for the next layer)
        return (x / rms) * self.scale, x
```
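A short usage sketch (my addition; assumes the RMSNorm class above is in scope), showing the (output, new residual) contract and that, with the default scale of ones, the normalized output has RMS ≈ 1:

```python
import torch

def main():
    torch.manual_seed(0)
    b, s, d = 2, 4, 16
    x = torch.randn(b, s, d)
    residual = torch.randn(b, s, d)
    norm = RMSNorm(d)
    out, new_residual = norm(x.clone(), residual)   # clone: forward adds the residual in place
    assert out.shape == (b, s, d)
    rms_out = out.pow(2).mean(dim=-1).sqrt()
    assert torch.allclose(rms_out, torch.ones_like(rms_out), atol=1e-3)
    assert torch.allclose(new_residual, x + residual, atol=1e-6)
    print("OK: RMSNorm shape / scale / residual contract")

if __name__ == "__main__":
    main()
```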

# PPO

$$r(\theta)=\frac{\pi_\theta(a\mid s)}{\pi_{\theta_{\text{old}}}(a\mid s)}=\exp(\log\pi_\theta-\log\pi_{\text{old}}),\qquad \mathcal{L}_{\text{PPO}}=-\mathbb{E}\left[\min\bigl(rA,\ \text{clip}(r,1-\epsilon,1+\epsilon)A\bigr)\right]$$

ppo
```python
import torch

def ppo_clip_loss(logp_new, logp_old, advantage, clip_eps=0.2):
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage
    return -torch.minimum(unclipped, clipped).mean()

def main():
    logp_old = torch.zeros(4)
    logp_new = torch.log(torch.tensor([1.0, 1.5, 0.5, 10.0]))  # ratios
    adv = torch.ones(4)
    loss = ppo_clip_loss(logp_new, logp_old, adv, clip_eps=0.2)
    # per-sample objective: min(r*A, clip(r)*A) = [1.0, 1.2, 0.5, 1.2]
    # (for r=0.5 with A>0 the min keeps the unclipped 0.5, not the clipped 0.8)
    expected_obj = torch.tensor([1.0, 1.2, 0.5, 1.2]).mean()
    assert torch.allclose(-loss, expected_obj, atol=1e-6, rtol=1e-6)
    print("OK: PPO clipped objective matches expected")

if __name__ == "__main__":
    main()
```

# DPO

$$\Delta_\pi=\log\pi(y^+\mid x)-\log\pi(y^-\mid x),\qquad \Delta_{\text{ref}}=\log\pi_{\text{ref}}(y^+\mid x)-\log\pi_{\text{ref}}(y^-\mid x),\qquad \mathcal{L}_{\text{DPO}}=-\log\sigma\bigl(\beta(\Delta_\pi-\Delta_{\text{ref}})\bigr)$$

dpo_loss
```python
import torch
import torch.nn.functional as F

def dpo_loss(logp_pi_c, logp_pi_r, logp_ref_c, logp_ref_r, beta=0.1):
    # all inputs: (batch,) sequence log-probs for chosen (c) / rejected (r)
    diff_pi = logp_pi_c - logp_pi_r
    diff_ref = logp_ref_c - logp_ref_r
    logits = beta * (diff_pi - diff_ref)
    return -F.logsigmoid(logits).mean()

def main():
    # the policy prefers chosen more strongly than the reference does,
    # so logits > 0 and the loss should fall below log(2) ≈ 0.693
    pi_c = torch.tensor([-1.0, -0.2, -0.1])
    pi_r = torch.tensor([-2.0, -1.0, -0.5])
    ref_c = torch.tensor([-1.2, -0.4, -0.2])
    ref_r = torch.tensor([-1.8, -0.9, -0.45])
    loss = dpo_loss(pi_c, pi_r, ref_c, ref_r, beta=0.5)
    assert loss.item() < 0.693
    print("OK: DPO loss sanity", loss.item())

if __name__ == "__main__":
    main()
```