# Preface
All tensor computations and reshapes in this post are implemented with the einops library. If you haven't used it yet, I strongly recommend spending five minutes learning it; it makes the code far more readable.
```python
# einops_install: pip install einops
from einops import rearrange, repeat, einsum
```
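As a quick illustration (my addition, with arbitrary shapes), `rearrange` and `einsum` replace the usual view/transpose/matmul bookkeeping with named axes:

```python
import torch
from einops import rearrange, einsum

x = torch.randn(2, 10, 512)                        # (batch, seq, d_model)
# split the hidden dim into 8 heads: one call instead of view + transpose
heads = rearrange(x, "b s (h d) -> b h s d", h=8)  # (2, 8, 10, 64)
# pairwise attention scores, with the contraction spelled out by name
scores = einsum(heads, heads, "b h s d, b h t d -> b h s t")
print(heads.shape, scores.shape)                   # (2, 8, 10, 64), (2, 8, 10, 10)
```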
# LLM Hand-Coding Interview Questions
Roughly ordered by importance (though in practice the interviewer usually just asks for MHA).
# Self-Attention
$$\text{Attn}(Q,K,V)=\text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V$$
There is also cross-attention, where Q comes from the query sequence and K/V come from another sequence; the only change needed is to pass k and v to the forward function separately (a sketch follows the code block below).
```python
# self_attention
import torch
import torch.nn as nn
from math import sqrt
from einops import einsum


class SelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        score = einsum(q, k, "b s d, b t d -> b s t") / sqrt(self.d_model)
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        attn = torch.softmax(score, dim=-1)
        out = einsum(attn, v, "b s t, b t d -> b s d")
        return self.w_o(out)


def main():
    torch.manual_seed(0)
    b, s, d = 2, 5, 32
    x = torch.randn(b, s, d)
    mask = torch.tril(torch.ones(s, s)).view(1, s, s)
    m = SelfAttention(d)
    y = m(x, mask)
    assert y.shape == (b, s, d)
    print("OK: self-attention shape")


if __name__ == "__main__":
    main()
```
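For reference, here is a minimal cross-attention sketch of the variant mentioned above (my addition, not from the original post): the module is identical to `SelfAttention` except that `forward` takes the key/value sequence separately.

```python
import torch
import torch.nn as nn
from math import sqrt
from einops import einsum


class CrossAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, context, mask=None):
        # Q from the query sequence, K/V from the other sequence
        q = self.w_q(x)
        k = self.w_k(context)
        v = self.w_v(context)
        score = einsum(q, k, "b s d, b t d -> b s t") / sqrt(self.d_model)
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        attn = torch.softmax(score, dim=-1)
        out = einsum(attn, v, "b s t, b t d -> b s d")
        return self.w_o(out)
```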
# Multi-Head Attention
Split the d-dimensional Q, K, and V into h groups of dimension d/h each and compute attention per group; in practice this is just a reshape.
```python
# multihead_attention
from math import sqrt

import torch
import torch.nn as nn
from einops import einsum, rearrange


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.dim_heads = d_model // num_heads
        self.num_heads = num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        # split d_model into (num_heads, dim_heads) and move heads ahead of seq
        q = rearrange(self.w_q(q), "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(self.w_k(k), "b s (h d) -> b h s d", h=self.num_heads)
        v = rearrange(self.w_v(v), "b s (h d) -> b h s d", h=self.num_heads)
        scores = einsum(q, k, "b h s d, b h t d -> b h s t") / sqrt(self.dim_heads)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weight = torch.softmax(scores, dim=-1)
        output = einsum(attn_weight, v, "b h s t, b h t d -> b h s d")
        output = rearrange(output, "b h s d -> b s (h d)")
        return self.w_o(output)


def main():
    d_model = 512
    num_heads = 8
    seq_len = 10
    batch_size = 2
    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
    q = torch.randn(batch_size, seq_len, d_model)
    k = torch.randn(batch_size, seq_len, d_model)
    v = torch.randn(batch_size, seq_len, d_model)
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
    output = mha(q, k, v, mask)
    print(f"input shape: {q.shape}")
    print(f"output shape: {output.shape}")


if __name__ == "__main__":
    main()
```
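As a sanity check (my addition, assuming PyTorch ≥ 2.0 and that the `MultiHeadAttention` class above is in scope), the hand-rolled version can be compared against `torch.nn.functional.scaled_dot_product_attention` using the same projections:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

torch.manual_seed(0)
mha = MultiHeadAttention(d_model=64, num_heads=4)
x = torch.randn(2, 10, 64)
mask = torch.tril(torch.ones(10, 10)).unsqueeze(0).unsqueeze(0)

with torch.no_grad():
    out_manual = mha(x, x, x, mask)
    # same projections, attention computed by the fused kernel with a causal mask
    q = rearrange(mha.w_q(x), "b s (h d) -> b h s d", h=4)
    k = rearrange(mha.w_k(x), "b s (h d) -> b h s d", h=4)
    v = rearrange(mha.w_v(x), "b s (h d) -> b h s d", h=4)
    sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out_sdpa = mha.w_o(rearrange(sdpa, "b h s d -> b s (h d)"))

assert torch.allclose(out_manual, out_sdpa, atol=1e-5)
print("OK: matches F.scaled_dot_product_attention")
```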
# Grouped-Query Attention
A variant of multi-head attention: queries keep num_heads heads, while keys and values use only num_kv_heads heads (num_heads must be larger than, and divisible by, num_kv_heads). At compute time you simply expand the K/V shapes to match the query heads; the small equivalence check after the code block shows what that expansion does.
```python
# group_query_attention
import torch
import torch.nn as nn
from einops import einsum, rearrange, repeat


class GroupQueryAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        assert num_heads % num_kv_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_heads
        self.kv_group = num_heads // num_kv_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, self.num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(d_model, self.num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_kv_heads)
        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_kv_heads)
        # expand each kv head so that every group of query heads has a matching kv head
        k_extend = repeat(k, "b h s d -> b (h g) s d", g=self.kv_group)
        v_extend = repeat(v, "b h s d -> b (h g) s d", g=self.kv_group)
        scores = einsum(q, k_extend, "b h s d, b h t d -> b h s t") / self.head_dim ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = torch.softmax(scores, dim=-1)
        out = einsum(attn, v_extend, "b h s t, b h t d -> b h s d")
        out = rearrange(out, "b h s d -> b s (h d)")
        return self.o_proj(out)


def main():
    B, S, D = 2, 8, 64
    head_q = 8
    head_kv = 2
    x = torch.randn(B, S, D)
    gqa = GroupQueryAttention(d_model=D, num_heads=head_q, num_kv_heads=head_kv)
    mask = torch.tril(torch.ones(S, S)).unsqueeze(0).unsqueeze(0)
    out = gqa(x, mask)
    print(f"input shape: {x.shape}")
    print(f"output shape: {out.shape}")


if __name__ == "__main__":
    main()
```
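To make the K/V expansion concrete (my addition), the einops `repeat` in the forward pass is the same as `torch.repeat_interleave` along the head dimension, which is how repeat_kv-style helpers are commonly written:

```python
import torch
from einops import repeat

k = torch.randn(2, 2, 8, 16)  # (batch, num_kv_heads, seq, head_dim)
g = 4                         # kv_group = num_heads // num_kv_heads

a = repeat(k, "b h s d -> b (h g) s d", g=g)
b = torch.repeat_interleave(k, repeats=g, dim=1)
assert torch.equal(a, b)
print("OK: einops repeat == repeat_interleave over heads")
```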
# Safe Softmax
Subtract the maximum before exponentiating so that large values do not overflow (the short demo after the code block shows what goes wrong without it).
$$\text{softmax}(x_i)=\frac{e^{x_i}}{\sum_j e^{x_j}}=\frac{e^{x_i-\max(x)}}{\sum_j e^{x_j-\max(x)}}$$
```python
# stable_softmax
import torch


def safe_softmax(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # shift by the max along `dim` so exp never sees large positive inputs
    x_max = scores.max(dim=dim, keepdim=True).values
    x_shifted = scores - x_max
    exp_x = torch.exp(x_shifted)
    return exp_x / exp_x.sum(dim=dim, keepdim=True)


def main():
    torch.manual_seed(0)
    x = torch.randn(4, 7) * 20
    y1 = safe_softmax(x, dim=-1)
    y2 = torch.softmax(x, dim=-1)
    assert torch.allclose(y1, y2, atol=1e-6, rtol=1e-6)
    print("OK: stable_softmax matches torch.softmax")


if __name__ == "__main__":
    main()
```
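A short demo of why the shift matters (my addition, reusing `safe_softmax` from the block above): a naive softmax overflows to inf/inf = nan for large logits, while the shifted version stays finite.

```python
import torch


def naive_softmax(x, dim=-1):
    e = torch.exp(x)  # overflows to inf once x is large
    return e / e.sum(dim=dim, keepdim=True)


x = torch.tensor([[1000.0, 1001.0, 1002.0]])
print(naive_softmax(x))  # tensor([[nan, nan, nan]])
print(safe_softmax(x))   # tensor([[0.0900, 0.2447, 0.6652]])
```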
# Entropy and Cross Entropy
Entropy is a central concept in information theory that measures the uncertainty of a random variable. For a discrete random variable X it is defined as:
$$H(X) = -\sum_{i} P(x_i) \log P(x_i)$$
where P(x_i) is the probability that X takes the value x_i.
The entropy function implemented here takes logits as input: logits are the raw, unnormalized outputs of the model's last layer, i.e. the scores before they are turned into probabilities.
The log_softmax used below fuses softmax and log into a single computation, which is numerically more stable.
Cross entropy (with one-hot targets):
$$\mathrm{CE}(y,p)=-\sum_i y_i \log p_i$$
```python
# entropy
import torch
import torch.nn.functional as F


def entropy(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    logp = F.log_softmax(logits, dim=dim)
    p = torch.exp(logp)
    return -(p * logp).sum(dim=dim)


def cross_entropy(logits: torch.Tensor, targets: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Args:
        logits: raw model outputs (before softmax), usually [batch_size, num_classes]
        targets: labels, either
            1. class indices, shape [batch_size], or
            2. one-hot vectors, shape [batch_size, num_classes]
        dim: dimension along which softmax is taken (default: last)
    Returns:
        the per-sample cross entropy, shape [batch_size].
        The official implementation returns the batch mean by default; either call it
        with reduction='none' or change the reduction here.
    """
    logp = F.log_softmax(logits, dim=dim)
    if targets.dim() == logits.dim() - 1:
        targets = F.one_hot(targets, num_classes=logits.size(dim)).float()
    return -(targets * logp).sum(dim=dim)


def main():
    torch.manual_seed(0)
    logits = torch.randn(10, 6)
    H = entropy(logits)
    assert H.shape == (10,)
    assert torch.isfinite(H).all()
    print("OK: entropy_from_logits shape & finite")

    target = torch.randint(0, 6, (10,))
    l1 = cross_entropy(logits, target)
    l2 = F.cross_entropy(logits, target, reduction="none")
    assert torch.allclose(l1, l2, atol=1e-6, rtol=1e-6)
    print("OK: cross_entropy_from_logits matches F.cross_entropy")


if __name__ == "__main__":
    main()
```
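A quick illustration of entropy as uncertainty (my addition, reusing `entropy` from the block above): uniform logits hit the maximum log(num_classes), while a sharply peaked distribution is close to zero.

```python
import math

import torch

uniform_logits = torch.zeros(1, 6)                     # all classes equally likely
peaked_logits = torch.tensor([[20.0, 0, 0, 0, 0, 0]])  # nearly one-hot

print(entropy(uniform_logits).item(), math.log(6))  # ~1.7918 vs 1.7918
print(entropy(peaked_logits).item())                # ~0
```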
# KL Divergence
The notation around KL divergence is somewhat inconsistent, so to be explicit: the kl_divergence implemented below uses the standard definition. For discrete probability distributions P and Q,
$$D_{KL}(P \| Q) = \sum_{i} P(x_i) \log \frac{P(x_i)}{Q(x_i)}$$
Note that in the official torch.nn.functional.kl_div API the first argument is log Q and the second is P; see the inputs in the main function below for an example.
The implementation:
```python
# kl_divergence
import torch
import torch.nn.functional as F


def kl_divergence(p, q):
    # D_KL(P || Q) = sum_i P(x_i) * log(P(x_i) / Q(x_i))
    eps = 1e-8
    log_ratio = torch.log((p + eps) / (q + eps))
    return torch.sum(p * log_ratio, dim=-1)


# Monte-Carlo estimators of KL[q, p] from Schulman's blog, evaluated at samples x ~ q
def k1_estimate(q, p):
    logr = torch.log(p) - torch.log(q)
    return -logr


def k2_estimate(q, p):
    logr = torch.log(p) - torch.log(q)
    return logr ** 2 / 2


def k3_estimate(q, p):
    logr = torch.log(p) - torch.log(q)
    return (logr.exp() - 1) - logr


def main():
    torch.manual_seed(0)
    p_logits = torch.randn(6, 10)
    q_logits = torch.randn(6, 10)
    p = torch.softmax(p_logits, dim=-1)
    q = torch.softmax(q_logits, dim=-1)
    logq = torch.log_softmax(q_logits, dim=-1)

    kld = kl_divergence(p, q)
    # F.kl_div expects log Q first and P second
    kl_div = F.kl_div(logq, p, reduction="none").sum(dim=-1)
    assert torch.allclose(kld, kl_div, atol=1e-6, rtol=1e-6)
    print("OK: kl_divergence matches F.kl_div")


if __name__ == "__main__":
    main()
```
The k1, k2, and k3 estimators follow Schulman's definitions in Approximating KL Divergence, where p and q appear in the opposite order from the usual convention. They are defined as follows:
$$\text{KL}[q,p]=\sum_x q(x)\log\frac{q(x)}{p(x)}=\mathbb{E}_{x\sim q}\left[\log\frac{q(x)}{p(x)}\right]$$

$$k_1=\log\frac{q(x)}{p(x)}=-\log r,\qquad r=\frac{p(x)}{q(x)}$$
Unbiased, but high variance (roughly 20× the true KL in the blog's example).
$$k_2=\frac{1}{2}\left(\log\frac{p(x)}{q(x)}\right)^2=\frac{1}{2}(\log r)^2$$
Biased, but with much smaller variance.
$$k_3=(r-1)-\log r$$
Unbiased and low-variance.
Here is the sample code given in that blog post:
```python
# joschu_kl_estimators
import torch.distributions as dis

p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
truekl = dis.kl_divergence(p, q)
print("true", truekl)

logr = p.log_prob(x) - q.log_prob(x)
k1 = -logr
k2 = logr ** 2 / 2
k3 = (logr.exp() - 1) - logr
for k in (k1, k2, k3):
    print((k.mean() - truekl) / truekl, k.std() / truekl)
```
# MLP with SwiGLU
SwiGLU is the activation used in today's mainstream models such as Qwen3; its formula is:
$$\text{SwiGLU}(x) = \text{Swish}(xW_1 + b_1) \odot (xW_2 + b_2)$$
Multiplying the SwiGLU output by one more weight matrix (the down projection) gives a complete MLP layer.
Mainstream models no longer use bias terms.
In engineering practice, the two linear layers are usually fused into a single layer that outputs twice the dimension and is then split in two; the same trick shows up in RoPE implementations. The equivalence check after the code block below makes this concrete.
```python
# mlp_swiglu
import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # fused gate/up projection: one linear layer, twice the output width
        self.gate_up_proj = nn.Linear(hidden_size, intermediate_size * 2)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

    def silu(self, x):
        sig = 1 / (1 + torch.exp(-x))
        return x * sig

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        x1, x2 = gate_up.chunk(2, dim=-1)
        activated = self.silu(x1) * x2  # SwiGLU: Swish(gate) ⊙ up
        return self.down_proj(activated)


def main():
    torch.manual_seed(0)
    batch_size = 2
    seq_len = 4
    hidden_size = 8
    intermediate_size = 32
    mlp = MLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
    x = torch.randn(batch_size, seq_len, hidden_size)
    out = mlp(x)
    print(f"input shape: {x.shape}")
    print(f"output shape: {out.shape}")


if __name__ == "__main__":
    main()
```
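To make the fused-projection remark concrete (my addition), splitting one double-width linear layer is numerically the same as running two separate gate/up projections whose weights are the two halves of the fused weight:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, inter = 8, 32
fused = nn.Linear(hidden, inter * 2)

# build separate gate/up projections from the two halves of the fused weights
gate = nn.Linear(hidden, inter)
up = nn.Linear(hidden, inter)
with torch.no_grad():
    gate.weight.copy_(fused.weight[:inter])
    gate.bias.copy_(fused.bias[:inter])
    up.weight.copy_(fused.weight[inter:])
    up.bias.copy_(fused.bias[inter:])

x = torch.randn(2, 4, hidden)
x1, x2 = fused(x).chunk(2, dim=-1)
assert torch.allclose(x1, gate(x), atol=1e-6)
assert torch.allclose(x2, up(x), atol=1e-6)
print("OK: fused gate_up_proj == separate gate/up projections")
```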
# RoPE
```python
import torch
from einops import einsum, rearrange, repeat
```
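The original post only kept the import stub above, so the following is a minimal sketch of the rotate-half RoPE used in Llama/Qwen-style models (the function names `build_rope_cache`, `rotate_half`, and `apply_rope` are my own; the check only verifies that the rotation preserves norms):

```python
import torch
from einops import einsum


def build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0):
    # theta_i = base^(-2i/d), one frequency per pair of dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = einsum(t, inv_freq, "s, d -> s d")  # (seq_len, head_dim/2)
    emb = torch.cat([freqs, freqs], dim=-1)     # (seq_len, head_dim)
    return emb.cos(), emb.sin()


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rope(q, k, cos, sin):
    # q, k: (b, h, s, d); cos/sin: (s, d), broadcast over batch and heads
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


def main():
    torch.manual_seed(0)
    b, h, s, d = 2, 4, 16, 32
    q, k = torch.randn(b, h, s, d), torch.randn(b, h, s, d)
    cos, sin = build_rope_cache(s, d)
    q_rot, k_rot = apply_rope(q, k, cos, sin)
    # a rotation preserves vector norms
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
    print("OK: RoPE shapes", q_rot.shape, k_rot.shape)


if __name__ == "__main__":
    main()
```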
# RMSNorm
$$\mathrm{rms}(x)=\sqrt{\frac{1}{d}\sum_j x_j^2+\epsilon}\,,\qquad \text{RMSNorm}(x)=\gamma\,\frac{x}{\mathrm{rms}(x)}$$
A production implementation would use more in-place operations, but an interview answer doesn't need to be that elaborate.
Also, since everything uses pre-norm nowadays, the residual is added directly onto the input: after the residual is added, the resulting x is passed along as the new residual for the next layer (see the usage sketch after the code).
```python
# rmsnorm
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-8):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x, residual=None):
        # pre-norm: fold the incoming residual into x first
        if residual is not None:
            x = x.add_(residual)
        rms = torch.sqrt((x * x).mean(dim=-1, keepdim=True) + self.eps)
        # return the normalized output and the updated residual for the next layer
        return (x / rms) * self.scale, x
```
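A short usage sketch of the pre-norm residual flow described above (my addition, reusing the `RMSNorm` class): the first return value feeds the sub-layer, the second is carried forward as the residual for the next layer.

```python
import torch

torch.manual_seed(0)
norm1 = RMSNorm(dim=8)
norm2 = RMSNorm(dim=8)
x = torch.randn(2, 4, 8)

h, residual = norm1(x)            # first layer: no incoming residual yet
h = h * 2.0                       # stand-in for the attention / MLP sub-layer
h, residual = norm2(h, residual)  # residual is folded in before the next norm
print(h.shape, residual.shape)
```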
# PPO
$$r(\theta)=\frac{\pi_\theta(a|s)}{\pi_{\theta_{\text{old}}}(a|s)}=\exp(\log\pi_\theta-\log\pi_{\text{old}})$$

$$\mathcal{L}_{\text{PPO}}=-\mathbb{E}\left[\min\left(rA,\ \text{clip}(r,1-\epsilon,1+\epsilon)A\right)\right]$$
```python
# ppo
import torch


def ppo_clip_loss(logp_new, logp_old, advantage, clip_eps=0.2):
    # importance ratio computed in log space for numerical stability
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage
    return -torch.minimum(unclipped, clipped).mean()


def main():
    logp_old = torch.zeros(4)
    logp_new = torch.log(torch.tensor([1.0, 1.5, 0.5, 10.0]))
    adv = torch.ones(4)
    loss = ppo_clip_loss(logp_new, logp_old, adv, clip_eps=0.2)
    # per-sample objective: min(rA, clip(r)A) = [1.0, 1.2, 0.5, 1.2]
    # (for r = 0.5 with A > 0 the pessimistic min keeps the unclipped 0.5)
    expected_obj = torch.tensor([1.0, 1.2, 0.5, 1.2]).mean()
    assert torch.allclose(-loss, expected_obj, atol=1e-6, rtol=1e-6)
    print("OK: PPO clipped objective matches expected")


if __name__ == "__main__":
    main()
```
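For LLM policies, the logp_new / logp_old arguments are per-token log-probabilities gathered from the model's logits; here is a hedged helper sketch (the name `token_log_probs` and the shapes are my own convention, not from the original post):

```python
import torch


def token_log_probs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """logits: (batch, seq, vocab); token_ids: (batch, seq) -> per-token log-probs (batch, seq)."""
    logp = torch.log_softmax(logits, dim=-1)
    return logp.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)


logits = torch.randn(2, 5, 100)
ids = torch.randint(0, 100, (2, 5))
print(token_log_probs(logits, ids).shape)  # torch.Size([2, 5])
```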
# DPO
$$\Delta_\pi = \log\pi(y^+|x)-\log\pi(y^-|x),\qquad \Delta_{\text{ref}} = \log\pi_{\text{ref}}(y^+|x)-\log\pi_{\text{ref}}(y^-|x)$$

$$\mathcal{L}_{\text{DPO}}= -\log\sigma\left(\beta(\Delta_\pi-\Delta_{\text{ref}})\right)$$
```python
# dpo_loss
import torch
import torch.nn.functional as F


def dpo_loss(logp_pi_c, logp_pi_r, logp_ref_c, logp_ref_r, beta=0.1):
    diff_pi = logp_pi_c - logp_pi_r
    diff_ref = logp_ref_c - logp_ref_r
    logits = beta * (diff_pi - diff_ref)
    return -F.logsigmoid(logits).mean()


def main():
    pi_c = torch.tensor([-1.0, -0.2, -0.1])
    pi_r = torch.tensor([-2.0, -1.0, -0.5])
    ref_c = torch.tensor([-1.2, -0.4, -0.2])
    ref_r = torch.tensor([-1.8, -0.9, -0.45])
    loss = dpo_loss(pi_c, pi_r, ref_c, ref_r, beta=0.5)
    assert loss.item() < 0.8
    print("OK: DPO loss sanity", loss.item())


if __name__ == "__main__":
    main()
```
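A quick check on the shape of the objective (my addition, reusing `dpo_loss` from above): the loss decreases monotonically as the policy's chosen-vs-rejected margin grows beyond the reference margin.

```python
import torch

ref_c, ref_r = torch.tensor([-1.0]), torch.tensor([-1.5])
pi_r = torch.tensor([-1.5])

losses = []
for margin in [0.0, 0.5, 1.0, 2.0]:
    # policy margin = reference margin + extra margin
    pi_c = pi_r + (ref_c - ref_r) + margin
    losses.append(dpo_loss(pi_c, pi_r, ref_c, ref_r, beta=0.5).item())

assert all(a > b for a, b in zip(losses, losses[1:]))
print("OK: DPO loss shrinks as the policy margin grows", losses)
```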