# nanovllm 详解
# layers
# SiluAndMul
class SiluAndMul(nn.Module):
    """SwiGLU activation core: split the last dim in half, return silu(first) * second.

    The input's last dimension is the concatenation [gate, up] (the Qwen3 MLP
    feeds it the output of a fused gate/up projection), so the output's last
    dimension is half the input's.
    """

    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First half -> gate branch (through SiLU); second half -> "up" branch.
        x, y = x.chunk(2, -1)
        return F.silu(x) * y
$\mathrm{SiLU}(x) = x \cdot \sigma(x)$,其中 $\sigma(x)=\frac{1}{1+e^{-x}}$
门控线性单元(Gated Linear Unit):$\mathrm{GLU}(x, W_1, W_2)=\sigma(x W_1) \odot (x W_2)$
$\odot$ 表示逐元素乘法
$\mathrm{SwiGLU}(x, W_1, W_2, W_3)=\big(\mathrm{SiLU}(x W_1) \odot (x W_3)\big)\, W_2$
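其中,$W_1$(gate)和 $W_3$(up)是两个输入投影,$W_2$(down)是作用在逐元素乘积之后的额外线性变换(输出投影)。SiluAndMul 只实现括号内的 $\mathrm{SiLU}(x W_1) \odot (x W_3)$:输入最后一维的前一半视为 $x W_1$,后一半视为 $x W_3$。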
# LayerNorm
layernorm 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 import torchfrom torch import nnclass RMSNorm (nn.Module): def __init__ ( self, hidden_size: int , eps: float = 1e-6 , ) -> None : super ().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(hidden_size)) @torch.compile def rms_forward ( self, x: torch.Tensor, ) -> torch.Tensor: orig_dtype = x.dtype x = x.float () var = x.pow (2 ).mean(dim=-1 , keepdim=True ) x.mul_(torch.rsqrt(var + self.eps)) x = x.to(orig_dtype).mul_(self.weight) return x @torch.compile def add_rms_forward ( self, x: torch.Tensor, residual: torch.Tensor, ) -> tuple [torch.Tensor, torch.Tensor]: orig_dtype = x.dtype x = x.float ().add_(residual.float ()) residual = x.to(orig_dtype) var = x.pow (2 ).mean(dim=-1 , keepdim=True ) x.mul_(torch.rsqrt(var + self.eps)) x = x.to(orig_dtype).mul_(self.weight) return x, residual def forward ( self, x: torch.Tensor, residual: torch.Tensor | None = None , ) -> torch.Tensor | tuple [torch.Tensor, torch.Tensor]: if residual is None : return self.rms_forward(x) else : return self.add_rms_forward(x, residual)
RMSNorm(Root Mean Square Normalization,均方根归一化)的公式为:
$$
\mathrm{RMSNorm}(a_i) = \frac{a_i}{\sqrt{\frac{1}{d} \sum_{j=1}^{d} a_j^2 + \epsilon}} \cdot g_i
$$
其中,$a_i$ 是输入向量的第 $i$ 个元素,$d$ 是输入向量的维度,$\epsilon$ 是一个防止除以零的小常数,$g_i$ 是可学习的权重参数,共有 $d$ 个。
而 LayerNorm 的公式为:
$$
\mathrm{LayerNorm}(a_i) = \frac{a_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot g_i + b_i
$$
其中,$\mu$ 是输入向量的均值,$\sigma^2$ 是输入向量的方差,$g_i$ 和 $b_i$ 分别是可学习的权重和偏置参数。
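与 LayerNorm 相比,RMSNorm 去掉了减均值(re-centering)和偏置 $b_i$,只保留按均方根缩放(re-scaling)这一步。因此它计算量更少,可学习参数也更少(只有 $g$)。RMSNorm 原论文的实验表明,去掉减均值后效果与 LayerNorm 相当。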
# Linear
linear 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 import torchfrom torch import nnimport torch.nn.functional as Fimport torch.distributed as distdef divide (numerator, denominator ): assert numerator % denominator == 0 return numerator // denominator class LinearBase (nn.Module): def __init__ ( self, input_size: int , output_size: int , bias: bool = False , tp_dim: int | None = None , ): super ().__init__() self.tp_dim = tp_dim self.tp_rank = dist.get_rank() self.tp_size = dist.get_world_size() self.weight = nn.Parameter(torch.empty(output_size, input_size)) self.weight.weight_loader = self.weight_loader if bias: self.bias = nn.Parameter(torch.empty(output_size)) self.bias.weight_loader = self.weight_loader else : self.register_parameter("bias" , None ) def forward (self, x: torch.Tensor ) -> torch.Tensor: raise NotImplementedError class ReplicatedLinear (LinearBase ): def __init__ ( self, input_size: int , output_size: int , bias: bool = False , ): super ().__init__(input_size, output_size, bias) def weight_loader (self, param: nn.Parameter, loaded_weight: torch.Tensor ): param.data.copy_(loaded_weight) def forward (self, x: torch.Tensor ) -> torch.Tensor: return F.linear(x, self.weight, self.bias) class ColumnParallelLinear (LinearBase ): def __init__ ( self, input_size: int , output_size: int , bias: bool = False , ): tp_size = dist.get_world_size() super ().__init__(input_size, divide(output_size, tp_size), bias, 0 ) def weight_loader (self, param: nn.Parameter, loaded_weight: torch.Tensor ): param_data = 
param.data shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) param_data.copy_(loaded_weight) def forward (self, x: torch.Tensor ) -> torch.Tensor: return F.linear(x, self.weight, self.bias) class MergedColumnParallelLinear (ColumnParallelLinear ): def __init__ ( self, input_size: int , output_sizes: list [int ], bias: bool = False , ): self.output_sizes = output_sizes super ().__init__(input_size, sum (output_sizes), bias) def weight_loader (self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int ): param_data = param.data shard_offset = sum (self.output_sizes[:loaded_shard_id]) // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank] param_data.copy_(loaded_weight) class QKVParallelLinear (ColumnParallelLinear ): def __init__ ( self, hidden_size: int , head_size: int , total_num_heads: int , total_num_kv_heads: int | None = None , bias: bool = False , ): tp_size = dist.get_world_size() total_num_kv_heads = total_num_kv_heads or total_num_heads self.head_size = head_size self.num_heads = divide(total_num_heads, tp_size) self.num_kv_heads = divide(total_num_kv_heads, tp_size) output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size super ().__init__(hidden_size, output_size, bias) def weight_loader (self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str ): param_data = param.data assert loaded_shard_id in ["q" , "k" , "v" ] if loaded_shard_id == "q" : shard_size = self.num_heads * self.head_size shard_offset = 0 elif loaded_shard_id == "k" : shard_size = self.num_kv_heads * self.head_size shard_offset = self.num_heads * self.head_size else : shard_size = self.num_kv_heads * self.head_size shard_offset = self.num_heads * self.head_size + 
self.num_kv_heads * self.head_size param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank] param_data.copy_(loaded_weight) class RowParallelLinear (LinearBase ): def __init__ ( self, input_size: int , output_size: int , bias: bool = False , ): tp_size = dist.get_world_size() super ().__init__(divide(input_size, tp_size), output_size, bias, 1 ) def weight_loader (self, param: nn.Parameter, loaded_weight: torch.Tensor ): param_data = param.data shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) param_data.copy_(loaded_weight) def forward (self, x: torch.Tensor ) -> torch.Tensor: y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None ) if self.tp_size > 1 : dist.all_reduce(y) return y
此处是张量并行(Tensor Parallelism, TP)的核心实现。dist 即 torch.distributed,是 PyTorch 的分布式通信包。
def linear(input, weight, bias=None):
    """Reference (eager) equivalent of F.linear: ``input @ weight.T + bias``.

    ``weight`` is stored as [out_dim, in_dim], hence the transpose.
    ``input`` mirrors F.linear's parameter name (it shadows the builtin).
    """
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    return output
线性层的数学定义:$y = x W^\top + b$
其中 $x: [N, \text{in\_dim}]$,$W: [\text{out\_dim}, \text{in\_dim}]$。
输入张量按行存储。权重矩阵 $W$ 的存储形状是 [out_dim, in_dim],即数学上 [in_dim, out_dim] 矩阵的转置,所以计算时使用 $W^\top$。
例如:
$$
\underbrace{\begin{bmatrix} x_1 & x_2 & x_3 \end{bmatrix}}_{1 \times 3}
\cdot
\underbrace{\begin{bmatrix} w_{11} & w_{12} \\ w_{21} & w_{22} \\ w_{31} & w_{32} \end{bmatrix}}_{\text{in dim} \times \text{out dim} \;=\; 3 \times 2}
$$
若按数学形式把 $W$ 存成 [in_dim, out_dim],计算每个输出元素时都要逐列读取 $W$。在行优先存储下这种访存不连续,对 cache 不友好。因此权重转置存储为 [out_dim, in_dim],每个输出元素对应 $W$ 的一行,可以连续读取:
$$
\underbrace{\begin{bmatrix} x_1 & x_2 & x_3 \end{bmatrix}}_{1 \times 3}
\cdot
\Bigg(
\left.\overbrace{\begin{bmatrix} w_{11} & w_{21} & w_{31} \\ w_{12} & w_{22} & w_{32} \end{bmatrix}}^{\text{in dim}}\right\}\text{out dim}
\Bigg)^{\top}
$$
ColumnParallelLinear 沿 out 维度(权重的第 0 维,tp_dim = 0)拆分权重:
$$
\begin{aligned}
& W: (\text{out}, \text{in}) = [4096, 1024] \\
& x: (B, S, \text{in}) = [8, 16, 1024] \\
& \text{tp\_size} = 4, \quad \text{tp\_dim} = 0 \\
& W = \left.\begin{bmatrix} w_0 \\ \vdots \\ w_3 \end{bmatrix}\right\} 1024 \times 4, \quad w_i: [1024, 1024] \\
& x W^\top \;\rightarrow\; x \cdot \begin{bmatrix} w_0^\top & \cdots & w_3^\top \end{bmatrix} \;\rightarrow\; \mathrm{concat}(y_0, \dots, y_3), \quad y_i = x w_i^\top: [8, 16, 1024]
\end{aligned}
$$
即先切分权重,每个 rank 用自己的分片单独计算,得到输出的一段 $y_i$ 并各自持有,此处不做通信;等到后续的 RowParallelLinear 层再进行聚合。输入 $x$ 不切分,每个 rank 都持有完整的 $x$。
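RowParallelLinear 沿 in 维度(权重的第 1 维,tp_dim = 1)拆分权重,因此输入 $x$ 也要沿最后一维相应拆分。在本模型中,这个已拆分的 $x$ 正是前一个列并行层各 rank 的输出:o_proj 之前是 qkv_proj,down_proj 之前是 gate_up_proj。因此拆分 $x$ 不需要额外通信。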
$$
\begin{aligned}
& W: (\text{out}, \text{in}) = [4096, 1024] \\
& x: (B, S, \text{in}) = [8, 16, 1024] \\
& \text{tp\_size} = 4, \quad \text{tp\_dim} = 1 \\
& W = \underbrace{\left[\begin{array}{c|c|c|c} w_0 & w_1 & w_2 & w_3 \end{array}\right]}_{256 \times 4}, \quad w_i: [4096, 256]
\end{aligned}
$$
$$
x = \underbrace{\left[\begin{array}{c|c|c|c} x_0 & x_1 & x_2 & x_3 \end{array}\right]}_{B \times S \times (4 \times 256)}
$$
$$
\underbrace{\begin{bmatrix} x_0 & x_1 & x_2 & x_3 \end{bmatrix}}_{B \times S \times (4 \times 256)}
\cdot
\underbrace{\begin{bmatrix} w_0^\top \\ w_1^\top \\ w_2^\top \\ w_3^\top \end{bmatrix}}_{4 \times [256 \times 4096]}
=
\underbrace{x_0 w_0^\top + x_1 w_1^\top + x_2 w_2^\top + x_3 w_3^\top}_{B \times S \times 4096}
$$

每个 rank 只计算其中一项 $y_i = x_i w_i^\top$(形状为 [B, S, 4096]),再通过 all_reduce 求和得到完整输出。
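MergedColumnParallelLinear 是 ColumnParallelLinear 的变体。它把多个共享同一输入的投影(如 gate_proj 和 up_proj)拼成一个权重矩阵,一次矩阵乘法就能算完,减少的是 kernel 调用次数,而不是通信(列并行层本身不通信)。它的 weight_loader 通过 loaded_shard_id 识别 checkpoint 中分开存放的各个权重,并把本 rank 的那一份写入合并权重的对应位置。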
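QKVParallelLinear 是针对自注意力中 q、k、v 投影的合并线性层,loaded_shard_id 取 "q"、"k"、"v"。三块的分片大小不同:q 为 num_heads * head_size,k、v 各为 num_kv_heads * head_size,从而支持 kv 头数少于 q 头数的情况(如 GQA)。weight_loader 据此计算各自在合并权重中的偏移并写入。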
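注意类之间的继承关系与方法绑定:LinearBase.__init__ 中执行了 self.weight.weight_loader = self.weight_loader。由于 self 是子类实例,这里绑定的是子类重写后的 weight_loader。因此加载权重时,每个参数都会按其所属子类的切分方式处理。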
注意,attention 层把 qkv 投影拆分到各卡分别计算,最后在 o 投影处合并;而 FFN 层把 gate 和 up 投影拆分到各卡计算,最后在 down 投影处合并。
因此只有 o_proj 和 down_proj 需要通信汇总。它们使用 RowParallelLinear,在 forward 中执行 all_reduce。
此处的 dist.all_reduce(y) 默认执行求和。以 4 卡为例,每张卡持有自己的部分结果 $y_i$;all_reduce 之后,每张卡上的 $y$ 都等于 $y_0 + y_1 + y_2 + y_3$,实现了结果的汇总。也正因为是求和,bias 只在 rank 0 上加(self.bias if self.tp_rank == 0 else None),避免被重复加 tp_size 次。
不同通信原语的功能讲解放在后文
# embed_head
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from nanovllm.utils.context import get_context


class VocabParallelEmbedding(nn.Module):
    """Embedding table sharded along the vocabulary dimension across TP ranks.

    Rank r owns rows [r * V/tp, (r + 1) * V/tp). Token ids outside that range
    produce zero vectors locally; a sum all-reduce then gives every rank the
    full embedding of every token.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        super().__init__()
        self.tp_rank = dist.get_rank()
        self.tp_size = dist.get_world_size()
        assert num_embeddings % self.tp_size == 0
        self.num_embeddings = num_embeddings
        self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
        self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
        self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
        self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(0)
        start_idx = self.tp_rank * shard_size
        # Keep only this rank's vocabulary rows.
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor):
        if self.tp_size > 1:
            mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
            # Owned ids -> local row index; foreign ids -> 0 (zeroed out below).
            x = mask * (x - self.vocab_start_idx)
        y = F.embedding(x, self.weight)
        if self.tp_size > 1:
            # unsqueeze(-1) broadcasts the mask over the embedding dim for any
            # input rank: [N] -> [N, 1] (same as unsqueeze(1)), [B, S] -> [B, S, 1].
            y = mask.unsqueeze(-1) * y
            dist.all_reduce(y)
        return y


class ParallelLMHead(VocabParallelEmbedding):
    """Output projection to vocabulary logits, sharded by vocabulary like the embedding.

    Each rank computes logits for its vocabulary slice; the slices are gathered
    on rank 0 and concatenated there. Non-zero ranks return None when tp_size > 1.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
    ):
        assert not bias
        super().__init__(num_embeddings, embedding_dim)

    def forward(self, x: torch.Tensor):
        context = get_context()
        if context.is_prefill:
            # Only the last token of each sequence needs logits:
            # cu_seqlens_q[1:] - 1 are their positions in the flattened batch.
            last_indices = context.cu_seqlens_q[1:] - 1
            x = x[last_indices].contiguous()
        logits = F.linear(x, self.weight)
        if self.tp_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
            dist.gather(logits, all_logits, 0)
            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
        return logits
def embedding(indices, weight):
    """Loop-based reference equivalent of F.embedding (for illustration).

    Returns a tensor of shape ``(*indices.shape, weight.shape[-1])`` whose
    entries are the rows of ``weight`` selected by ``indices``.
    """
    rows = [weight[idx] for idx in indices.flatten()]
    if not rows:
        # torch.stack rejects an empty list, and reshape(..., -1) is ambiguous
        # with zero elements, so build the empty result explicitly.
        return weight.new_empty(*indices.shape, weight.shape[-1])
    return torch.stack(rows).reshape(*indices.shape, -1)
每张卡只负责词表中属于自己的那一段 token_id,也只持有对应的那部分 embedding 权重。输入是 token_id,输出是对应的 embedding 向量。
这里要注意的是 forward 的过程
当检查到 tp_size > 1 时,首先创建一个 mask,标记输入的 token_id 是否在当前卡负责的范围内。
对于不在范围内的 token_id,mask 是 False,对应的 embedding 输出也应该是 0。
对于在范围内的 token_id,输入 x 需要减去 vocab_start_idx ,以映射到当前卡的 embedding 权重索引范围。
而后执行 F.embedding(x, self.weight) 。注意此处输入 x 的尺寸
nanovllm 中的输入是展平的。为了方便理解,先按不展平的情况来看:输入 x 的尺寸是 [B, S],输出 y 的尺寸就是 [B, S, embedding_dim]。
虽然 self.weight 的尺寸是 [num_embeddings_per_partition, embedding_dim] ,但由于输入 x 已经被映射到当前卡负责的 token_id 范围内,所以 F.embedding 会正确地返回对应的 embedding 向量。
而后,不展平情况下, mask 的尺寸是 [B, S] ,需要通过 mask.unsqueeze(-1) 将其扩展为 [B, S, 1] ,然后经过广播机制与 y 的尺寸 [B, S, embedding_dim] 进行逐元素乘法,最终得到的 y 中只有当前卡负责的 token_id 的 embedding 向量是非零的。
但如果展平了,输入 x 的尺寸是 [B*S] ,输出 y 的尺寸是 [B*S, embedding_dim] ,此时 mask 的尺寸是 [B*S] ,一维向量,那么 mask.unsqueeze(1) 的尺寸是 [B*S, 1] ,同样可以与 y 的尺寸 [B*S, embedding_dim] 进行逐元素乘法,达到同样的效果。
计算结束后执行 dist.all_reduce(y) ,将所有卡上的 y 进行求和汇总,最终每张卡上的 y 都包含了所有 token_id 的 embedding 向量。这里依然是求和,因为非当前卡负责的 token_id 的 embedding 向量是 0,所以求和后不会改变结果。
ParallelLMHead
和 VocabParallelEmbedding 类似,但它是输出层,输入是隐藏状态,输出是 vocab 大小的 logits。
prefill 时:x.shape = [total_tokens, hidden_dim]
decode 时:x.shape = [num_seqs, hidden_dim]
W_full.shape = [Vocab, hidden_dim]
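cu_seqlens_q 是各序列长度的累加和(前缀和)。例如两条长度分别为 3 和 2 的序列,对应 cu_seqlens_q = [0, 3, 5]。prefill 时用 cu_seqlens_q[1:] - 1 = [2, 4] 取出每条序列最后一个 token 的隐藏状态,只对它们计算 logits。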
计算时,每张卡拿到的权重为 W_part.shape = [vocab_per_partition, hidden_dim]
每张卡只负责计算自己那一段 vocab 的 logits,
logits.shape = [N, vocab_per_partition]
每张卡得到自己的部分 logits 后,用 dist.gather 把所有卡上的 logits 收集到 rank 0,再在 rank 0 上沿最后一维拼接成完整的 vocab 大小的 logits。其他 rank 返回 None。
# rotary_embedding
rotary_embedding 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 from functools import lru_cacheimport torchfrom torch import nndef apply_rotary_emb ( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> torch.Tensor: x1, x2 = torch.chunk(x.float (), 2 , dim=-1 ) y1 = x1 * cos - x2 * sin y2 = x2 * cos + x1 * sin return torch.cat((y1, y2), dim=-1 ).to(x.dtype) class RotaryEmbedding (nn.Module): def __init__ ( self, head_size: int , rotary_dim: int , max_position_embeddings: int , base: float , ) -> None : super ().__init__() self.head_size = head_size assert rotary_dim == head_size inv_freq = 1.0 / (base**(torch.arange(0 , rotary_dim, 2 , dtype=torch.float ) / rotary_dim)) t = torch.arange(max_position_embeddings, dtype=torch.float ) freqs = torch.einsum("i,j -> ij" , t, inv_freq) cos = freqs.cos() sin = freqs.sin() cache = torch.cat((cos, sin), dim=-1 ).unsqueeze_(1 ) self.register_buffer("cos_sin_cache" , cache, persistent=False ) @torch.compile def forward ( self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, ) -> tuple [torch.Tensor, torch.Tensor]: cos_sin = self.cos_sin_cache[positions] cos, sin = cos_sin.chunk(2 , dim=-1 ) query = apply_rotary_emb(query, cos, sin) key = apply_rotary_emb(key, cos, sin) return query, key @lru_cache(1 ) def get_rope ( head_size: int , rotary_dim: int , max_position: int , base: float , rope_scaling: dict | None = None , ): assert rope_scaling is None rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base) return rotary_emb
回忆二维旋转公式:
设原向量为 $\boldsymbol{v}=(x,y)$,它的长度为 $r=\sqrt{x^2+y^2}$,与 $x$ 轴的夹角为 $\alpha$,所以它的极坐标形式是:
$$
x = r\cos\alpha, \quad y = r\sin\alpha
$$
把它逆时针旋转 $\theta$ 角后,新的夹角是 $\alpha+\theta$,新坐标 $(x',y')$ 满足:
$$
\begin{cases}
x' = r\cos(\alpha+\theta) \\
y' = r\sin(\alpha+\theta)
\end{cases}
$$
用三角函数的和角公式展开:
$$
\begin{aligned}
x' &= r(\cos\alpha\cos\theta - \sin\alpha\sin\theta) = (r\cos\alpha)\cos\theta - (r\sin\alpha)\sin\theta = x\cos\theta - y\sin\theta \\
y' &= r(\sin\alpha\cos\theta + \cos\alpha\sin\theta) = (r\sin\alpha)\cos\theta + (r\cos\alpha)\sin\theta = x\sin\theta + y\cos\theta
\end{aligned}
$$
把这组线性关系写成矩阵乘法,就是:
$$
\begin{pmatrix} x' \\ y' \end{pmatrix} = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix} \begin{pmatrix} x \\ y \end{pmatrix}
$$
对于第 k 个二维子空间,在位置 i 上的旋转矩阵为:
$$
R_k^i =
\begin{bmatrix}
\cos \theta_{i,k} & -\sin \theta_{i,k} \\
\sin \theta_{i,k} & \cos \theta_{i,k}
\end{bmatrix}
$$
其中旋转角定义为:
$$
\theta_{i,k} = i \cdot \omega_k
\qquad\text{with}\qquad
\omega_k = \frac{1}{\text{base}^{2k/d}}
$$
等价地,
$$
\theta_{i,k} = \frac{i}{\text{base}^{2k/d}}
$$
base :RoPE 的底数
i :序列位置(seq index)
k :频率通道 / 二维块编号
d :rotary dimension
$R$ 是块对角矩阵,对角线上的每个方块都是上述的一个 $2 \times 2$ 旋转矩阵 $R_k^i$,因此:
$$
R \cdot x =
\begin{bmatrix}
R_0^i & & & \\
& R_1^i & & \\
& & \ddots & \\
& & & R_{d/2-1}^i
\end{bmatrix}
\begin{pmatrix} x_0 \\ x_1 \\ \vdots \\ x_{d-1} \end{pmatrix}
=
\begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ \vdots \\ x_{d-2} \\ x_{d-1} \end{pmatrix}
\odot
\begin{pmatrix} \cos\theta_{i,0} \\ \cos\theta_{i,0} \\ \cos\theta_{i,1} \\ \cos\theta_{i,1} \\ \vdots \\ \cos\theta_{i,d/2-1} \\ \cos\theta_{i,d/2-1} \end{pmatrix}
+
\begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ \vdots \\ -x_{d-1} \\ x_{d-2} \end{pmatrix}
\odot
\begin{pmatrix} \sin\theta_{i,0} \\ \sin\theta_{i,0} \\ \sin\theta_{i,1} \\ \sin\theta_{i,1} \\ \vdots \\ \sin\theta_{i,d/2-1} \\ \sin\theta_{i,d/2-1} \end{pmatrix}
$$
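但注意,代码实现的配对方式与上式不同。代码把 x 沿最后一维切成前后两半 x1、x2,把第 $j$ 维与第 $j + d/2$ 维配成一对来旋转:y1 = x1·cos − x2·sin,y2 = x2·cos + x1·sin。上式则是相邻的 $(2k, 2k+1)$ 两维配对。两者只是特征维度的排列不同,旋转本身相同,只要与模型训练时的约定一致即可。这种写法常被称为 rotate-half 风格。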
RoPE 只作用于 q 和 k,不作用于 v。
rope 的尺寸变化:
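cos_sin_cache 的尺寸变化如下:
- cos、sin 各为 [max_pos, rotary_dim / 2],拼接后为 [max_pos, rotary_dim],unsqueeze(1) 后为 [max_pos, 1, rotary_dim]。
- 按 positions 索引得到 [N, 1, rotary_dim],chunk 后 cos、sin 各为 [N, 1, rotary_dim / 2]。
- 再与 q/k 的一半 [N, num_heads, head_size / 2] 广播相乘。

这里的 unsqueeze 就是为了在 head 维上广播,以适配多头。注意它只适配 [N, num_heads, head_size] 这种 token 维在前的布局,不适配 [batch, num_heads, seq_len, head_dim] 这种形状的多头。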
# attention
@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    """Copy one token's key and value (D contiguous elements each) into its cache slot.

    Launched with one program per token; program ``idx`` handles token ``idx``.
    """
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    # slot == -1 marks a token that must not be written to the cache.
    if slot == -1:
        return
    # key_stride / value_stride: element distance between consecutive tokens
    # (may exceed D when key/value are views into a larger fused tensor).
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    # Each cache slot holds D contiguous elements.
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    """Scatter N tokens' key/value tensors ([N, num_heads, head_dim]) into the KV cache.

    ``slot_mapping[i]`` is the cache slot for token i, or -1 to skip it.
    """
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    # The kernel reads each token's num_heads * head_dim elements as one
    # contiguous run, so heads must be packed back-to-back within a token.
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    # The kernel addresses the cache as slot * D; only stride(1) is checked here.
    # NOTE(review): this also presumes earlier cache dims are contiguous — confirm.
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
两个 kvcache 代码要一起看:
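stride() 返回张量在某一维上相邻元素在内存中的间隔,单位是元素个数。stride(-1) 是最后一维的间隔,连续存储时为 1。对于连续(contiguous)张量,stride(i) 等于第 i 维之后所有维度 size 的乘积,例如形状为 [N, H, D] 的连续张量,stride 为 (H·D, D, 1)。注意 key、value 是从 qkv 中 split 出来的视图,stride(0) 一般大于 H·D,这正是 kernel 需要单独传入 key_stride / value_stride 的原因。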
slot_mapping 是一个长度为 N 的张量,记录了每个 token 对应的 slot 编号。slot 的范围是 [0, num_slots),其中 num_slots 是 k_cache 和 v_cache 中可用的 slot 数量。对于需要缓存的 token,slot_mapping 中对应位置的值是一个非负整数,表示该 token 应该存储到 k_cache 和 v_cache 中的哪个 slot;对于不需要缓存的 token,slot_mapping 中对应位置的值是 -1。
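kernel 的启动网格是 (N,),即每个 token 一个 program(不是一个线程)。每个 program 通过 tl.program_id(0) 得到 idx,负责第 idx 个 token,并从 slot_mapping 中读取它对应的 slot;slot 为 -1 时直接返回,不写入 cache。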
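key_offsets = idx * key_stride + tl.arange(0, D):其中 key_stride 即 key.stride(0),是相邻两个 token 在 key 张量中的内存间隔,不一定等于 D。tl.arange(0, D) 覆盖该 token 的 num_heads * head_dim = D 个元素,store_kvcache 中的 assert 保证了每个 token 内部这 D 个元素是连续存放的。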
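cache 中的偏移用类似方式计算:cache_offsets = slot * D + tl.arange(0, D),即每个 slot 在 cache 中连续占用 D 个元素。assert 只检查了 k_cache.stride(1) == D,这里隐含假设 cache 整体是连续存储的。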
kernel 函数把 kv 存放进 cache 里
$$
\begin{bmatrix} q_1 \\ q_2 \\ q_3 \end{bmatrix}
\begin{bmatrix} k_1^\top & k_2^\top & k_3^\top \end{bmatrix}
\begin{bmatrix} v_1 \\ v_2 \\ v_3 \end{bmatrix}
=
\begin{bmatrix}
q_1 k_1^\top & q_1 k_2^\top & q_1 k_3^\top \\
q_2 k_1^\top & q_2 k_2^\top & q_2 k_3^\top \\
q_3 k_1^\top & q_3 k_2^\top & q_3 k_3^\top
\end{bmatrix}
\begin{bmatrix} v_1 \\ v_2 \\ v_3 \end{bmatrix}
$$
$$
=\begin{bmatrix}
q_1 k_1^\top v_1 + q_1 k_2^\top v_2 + q_1 k_3^\top v_3 \\
q_2 k_1^\top v_1 + q_2 k_2^\top v_2 + q_2 k_3^\top v_3 \\
q_3 k_1^\top v_1 + q_3 k_2^\top v_2 + q_3 k_3^\top v_3
\end{bmatrix}
$$
$$
=\begin{bmatrix} q_1 \\ q_2 \\ q_3 \end{bmatrix}
\left( k_1^\top v_1 + k_2^\top v_2 + k_3^\top v_3 \right)
$$
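上式为了说明计算结构,省略了缩放因子、softmax 和因果掩码。最后一步提取公因式只在去掉 softmax 时成立。在真实的因果注意力中,第 $i$ 个 token 的 q 会与前 $i$ 个 token 的 k、v 计算。因此每加入一个新的 token,都会引入新的 q、k、v,新的 q 需要和之前所有 token 的 k、v 进行计算;把历史的 k、v 缓存下来,就无需重复计算,这就是 KV cache。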
class Attention(nn.Module):
    """Attention over a paged KV cache using FlashAttention kernels.

    ``k_cache`` / ``v_cache`` start as empty placeholder tensors; they are
    expected to be swapped for the real cache tensors by the model runner.
    While they are empty, nothing is written to the cache.
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            # Write this step's k/v into their cache slots first.
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        if context.is_prefill:
            if context.block_tables is not None:
                # Read k/v from the paged cache (via block_table) instead of the
                # freshly computed tensors.
                k, v = k_cache, v_cache
            # Variable-length batch: sequences are concatenated along dim 0 and
            # delimited by cu_seqlens_q / cu_seqlens_k.
            o = flash_attn_varlen_func(q, k, v,
                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:
            # Decode: one query token per sequence -> q becomes [B, 1, H, D];
            # keys/values come entirely from the cache.
            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                        cache_seqlens=context.context_lens, block_table=context.block_tables,
                                        softmax_scale=self.scale, causal=True)
        return o
初始使用空张量给 k_cache 和 v_cache 占位,后续在 model_runner 中进行替换绑定
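这里用到两种 flash attention。prefill 时使用 flash_attn_varlen_func,它支持变长输入:批内各序列的 q、k、v 拼接(展平)在一起,用 cu_seqlens_q / cu_seqlens_k 标出每条序列的边界。当 context.block_tables 不为 None 时(推测对应前缀缓存命中的情况),k、v 改为直接使用 k_cache、v_cache,并通过 block_table 按块读取。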
decode 时使用 flash_attn_with_kvcache。因为每条序列每步只处理一个新 token,q 通过 unsqueeze(1) 变为 [B, 1, H, D]。k、v 则直接使用 k_cache 和 v_cache,cache_seqlens 给出每条序列当前的上下文长度。
# sampler
class Sampler(nn.Module):
    """Temperature sampling via the exponential-noise argmax trick.

    ``argmax(probs / E)`` with i.i.d. ``E ~ Exp(1)`` draws index i with
    probability ``probs[i]``, using only elementwise ops and an argmax.
    """

    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        # copy=True: for float32 logits, .float() would return the caller's
        # tensor and the in-place div_ would overwrite it.
        logits = logits.to(torch.float32, copy=True).div_(temperatures.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1)
        # clamp_min_ avoids division by (near) zero noise.
        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
        return sample_tokens
首先将 logits 除以 temperature,得到调整后的 logits。
然后对调整后的 logits 进行 softmax,得到每个 token 的概率分布。
# Step-by-step form of the sampling line in Sampler.forward (probs: softmax output).
# One independent Exp(1) noise value per probability; the clamp avoids dividing by ~0.
noise = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
scores = probs / noise
# argmax over the vocabulary picks token i with probability probs[i].
sample_tokens = scores.argmax(dim=-1)
采样流程等价于先构造一个与 probs 同形状的噪声张量,噪声服从指数分布(exponential distribution),然后将 probs 除以这个噪声,得到一个新的分数张量 scores,最后在 scores 上取 argmax,得到采样的 token id。
这种方法和传统的 multinomial 采样等价,都是按概率采样。但它对 GPU 更友好:全部是逐元素操作加一次 argmax,也便于 torch.compile 优化。multinomial 采样通常依赖累积分布上的查找,对 GPU 和 torch.compile 都不够友好。
若 $E_i \sim \mathrm{Exp}(1)$ 相互独立(独立同分布的标准指数分布),则:
$$
\arg\max_i \frac{p_i}{E_i} \sim \mathrm{Categorical}(p) \quad \text{(分类分布)}
$$
即 $P\left( \arg\max_j \frac{p_j}{E_j} = i \right) = p_i$,其中 $p_i$ 为采样概率,满足 $\sum_i p_i = 1$。证明如下。
对 $E \sim \mathrm{Exp}(\lambda)$,有:
$$
P(E > t) = e^{-\lambda t}
$$
标准指数分布($\lambda = 1$)简化为:
$$
P(E > t) = e^{-t}
$$
$\arg\max_j \frac{p_j}{E_j} = i$ 等价于:
$$
\frac{p_i}{E_i} > \frac{p_j}{E_j} \quad (\forall j \neq i) \iff E_j > \frac{p_j}{p_i} E_i \quad (\forall j \neq i)
$$
以 $E_i = t$ 为条件,由于各 $E_j$ 相互独立,联合概率可拆分为乘积形式:
\begin{align*}
P\left( \arg\max_j \frac{p_j}{E_j} = i \right)
&= \int_0^\infty f_{E_i}(t) \prod_{j \neq i} P\left( E_j > \frac{p_j}{p_i} t \right) dt
\end{align*}
其中:
$f_{E_i}(t) = e^{-t}$(标准指数分布的概率密度)
$P\left( E_j > \frac{p_j}{p_i} t \right) = e^{-\frac{p_j}{p_i} t}$(代入指数分布的尾概率)
代入并化简指数项(利用 $\sum_j p_j = 1$,即 $\sum_{j \neq i} p_j = 1 - p_i$):
\begin{align*}
\text{原式} &= \int_0^\infty e^{-t} \cdot \prod_{j \neq i} e^{-\frac{p_j}{p_i} t} \, dt \\
&= \int_0^\infty e^{-t \left( 1 + \sum_{j \neq i} \frac{p_j}{p_i} \right)} \, dt \\
&= \int_0^\infty e^{-t \cdot \frac{p_i + (1 - p_i)}{p_i}} \, dt \\
&= \int_0^\infty e^{-\frac{t}{p_i}} \, dt = p_i
\end{align*}
# qwen3
class Qwen3Attention(nn.Module):
    """Tensor-parallel Qwen3 self-attention.

    Each rank holds num_heads / tp_size query heads and num_kv_heads / tp_size
    key/value heads. Forward: fused QKV projection -> per-head RMSNorm on q/k
    (only when qkv_bias is False) -> RoPE -> attention over the KV cache ->
    row-parallel output projection (all-reduced).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: tuple | None = None,
    ) -> None:
        super().__init__()
        tp_size = dist.get_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v parts of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim ** -0.5
        self.qkv_bias = qkv_bias

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )
        if not self.qkv_bias:
            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # [N, heads * head_dim] -> [N, heads, head_dim]
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        if not self.qkv_bias:
            q = self.q_norm(q)
            k = self.k_norm(k)
        q, k = self.rotary_emb(positions, q, k)
        o = self.attn(q, k, v)
        output = self.o_proj(o.flatten(1, -1))
        return output


class Qwen3MLP(nn.Module):
    """Tensor-parallel SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()
        # gate_proj and up_proj fused into one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
        )
        # Row-parallel: consumes the per-rank slice, all-reduces the result.
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x = self.down_proj(x)
        return x


class Qwen3DecoderLayer(nn.Module):
    """One pre-norm Qwen3 transformer layer (attention + MLP).

    The residual stream is threaded explicitly: RMSNorm called with a residual
    computes ``residual = hidden + residual`` and normalizes that sum, so the
    residual add is fused into the norm.
    """

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        super().__init__()
        self.self_attn = Qwen3Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, 'attention_bias', True),
            head_dim=getattr(config, 'head_dim', None),
            rope_theta=getattr(config, "rope_theta", 1000000),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = Qwen3MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            # First layer: nothing to add yet; the input itself becomes the residual.
            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        # The MLP output is added to residual by the next norm (next layer or final).
        return hidden_states, residual


class Qwen3Model(nn.Module):
    """Embedding -> stack of decoder layers -> final RMSNorm; returns hidden states."""

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        # Final norm also performs the last pending residual add.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class Qwen3ForCausalLM(nn.Module):
    """Qwen3 decoder plus vocabulary-parallel LM head."""

    # Checkpoint sub-module name -> (fused parameter name, shard id), consumed
    # by the weight loader to route separate q/k/v and gate/up tensors into
    # the fused qkv_proj / gate_up_proj weights.
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen3Config
    ) -> None:
        super().__init__()
        self.model = Qwen3Model(config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        if config.tie_word_embeddings:
            # Share storage between the input embedding and the output head.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        return self.lm_head(hidden_states)
# loader
def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    """Copy a checkpoint tensor into ``param`` unchanged (used when the
    parameter defines no ``weight_loader`` of its own)."""
    param.data.copy_(loaded_weight)


def load_model(model: nn.Module, path: str):
    """Load every ``*.safetensors`` file under ``path`` into ``model``.

    Checkpoint names that contain a key of ``model.packed_modules_mapping``
    (e.g. ``q_proj``) belong to a fused parameter (e.g. ``qkv_proj``); those are
    routed to the fused parameter's ``weight_loader`` together with a shard id.
    All other names are loaded 1:1, through the parameter's own
    ``weight_loader`` if it has one, else ``default_weight_loader``.
    """
    packed = getattr(model, "packed_modules_mapping", {})
    pattern = os.path.join(path, "*.safetensors")
    for shard_file in glob(pattern):
        with safe_open(shard_file, "pt", "cpu") as reader:
            for name in reader.keys():
                # First mapping key (in dict order) that occurs in the checkpoint name.
                match = next((key for key in packed if key in name), None)
                if match is None:
                    target = model.get_parameter(name)
                    loader = getattr(target, "weight_loader", default_weight_loader)
                    loader(target, reader.get_tensor(name))
                    continue
                fused_name, shard_id = packed[match]
                target = model.get_parameter(name.replace(match, fused_name))
                # Fused parameters must provide their own shard-aware loader.
                loader = getattr(target, "weight_loader")
                loader(target, reader.get_tensor(name), shard_id)
# engine
pipeline:
LLM.generate()
  -> LLMEngine.add_request()
       -> Scheduler.add()
  -> while not finished:
       LLMEngine.step()
         -> Scheduler.schedule()
         -> ModelRunner.run()
              -> prepare_prefill() / prepare_decode()
              -> model forward
              -> sampler
         -> Scheduler.postprocess()
即:用户 prompt -> Sequence -> Scheduler 决定本轮跑哪些请求 -> BlockManager 分配 KV cache block -> ModelRunner 准备张量并调用模型 -> Sampler 采样新 token -> Scheduler 更新状态 -> 完成后 decode 回文本
以下按照 sequence -> block manager -> scheduler -> model runner -> llm engine 的顺序介绍。
# sequence
# sequence
class SequenceStatus(Enum):
    """Lifecycle state of a request as tracked by the scheduler."""
    WAITING = auto()
    RUNNING = auto()
    FINISHED = auto()


class Sequence:
    """Token state of one generation request.

    Holds the prompt plus generated tokens, the KV-cache ``block_table`` that the
    block manager fills in, and the sampling settings copied from the sampling
    parameters.
    """

    # Tokens per KV-cache block; used to map token positions to blocks.
    block_size = 256
    # Class-wide id generator: each new Sequence takes the next integer.
    counter = count()

    def __init__(self, token_ids: list[int], sampling_params: "SamplingParams | None" = None):
        """Create a request from prompt ``token_ids``.

        Args:
            token_ids: prompt token ids. The list is copied, so later changes to
                the caller's list do not leak in. Must be non-empty
                (``token_ids[-1]`` is read).
            sampling_params: object exposing ``temperature``, ``max_tokens`` and
                ``ignore_eos``. ``None`` (the default) means a fresh
                ``SamplingParams()``.
        """
        # Build the default per call instead of one instance created at function
        # definition time and silently shared by every Sequence.
        if sampling_params is None:
            sampling_params = SamplingParams()
        self.seq_id = next(Sequence.counter)
        self.status = SequenceStatus.WAITING
        self.token_ids = copy(token_ids)
        self.last_token = token_ids[-1]
        self.num_tokens = len(self.token_ids)
        self.num_prompt_tokens = len(token_ids)
        # Leading tokens whose KV is reused from the prefix cache (set by the block manager).
        self.num_cached_tokens = 0
        # Physical KV-cache block id for each logical block of this sequence.
        self.block_table = []
        self.temperature = sampling_params.temperature
        self.max_tokens = sampling_params.max_tokens
        self.ignore_eos = sampling_params.ignore_eos

    def __len__(self):
        return self.num_tokens

    def __getitem__(self, key):
        return self.token_ids[key]

    @property
    def is_finished(self):
        return self.status == SequenceStatus.FINISHED

    @property
    def num_completion_tokens(self):
        """Number of generated (non-prompt) tokens."""
        return self.num_tokens - self.num_prompt_tokens

    @property
    def prompt_token_ids(self):
        return self.token_ids[:self.num_prompt_tokens]

    @property
    def completion_token_ids(self):
        return self.token_ids[self.num_prompt_tokens:]

    @property
    def num_cached_blocks(self):
        """Number of whole blocks covered by the prefix cache."""
        return self.num_cached_tokens // self.block_size

    @property
    def num_blocks(self):
        """Number of blocks needed for all current tokens (ceil division)."""
        return (self.num_tokens + self.block_size - 1) // self.block_size

    @property
    def last_block_num_tokens(self):
        """Number of tokens stored in the last (possibly partial) block."""
        return self.num_tokens - (self.num_blocks - 1) * self.block_size

    def block(self, i):
        """Return the token ids of logical block ``i``."""
        assert 0 <= i < self.num_blocks
        return self.token_ids[i*self.block_size: (i+1)*self.block_size]

    def append_token(self, token_id: int):
        """Record one newly sampled token."""
        self.token_ids.append(token_id)
        self.last_token = token_id
        self.num_tokens += 1

    def __getstate__(self):
        # Compact pickling: before any token is generated the full token_ids are
        # sent; afterwards only last_token is sent.
        # NOTE(review): an unpickled copy of a sequence that already has completion
        # tokens has no token_ids attribute, so anything indexing it (e.g. a
        # re-prefill after preemption on another process) would fail — confirm
        # whether that path is reachable.
        return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
                self.token_ids if self.num_completion_tokens == 0 else self.last_token)

    def __setstate__(self, state):
        self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
        if self.num_completion_tokens == 0:
            self.token_ids = state[-1]
        else:
            self.last_token = state[-1]
# block manager
# block_manager
class Block:
    """One physical KV-cache block plus its prefix-cache bookkeeping."""

    def __init__(self, block_id: int):
        self.block_id = block_id
        # Number of sequences whose block_table currently references this block.
        self.ref_count = 0
        # Chained hash of this block's tokens (see BlockManager.compute_hash);
        # -1 means no hash has been recorded (e.g. the block is not full yet).
        self.hash = -1
        # Tokens stored in the block; compared on lookup to reject stale or colliding hashes.
        self.token_ids = []

    def update(self, hash: int, token_ids: list[int]) -> None:
        """Register the block's content hash and tokens for prefix-cache lookups."""
        self.hash = hash
        self.token_ids = token_ids

    def reset(self) -> None:
        """Prepare the block for a fresh allocation.

        Only called when the block is handed out, so it starts with one reference.
        """
        self.ref_count = 1
        self.hash = -1
        self.token_ids = []


class BlockManager:
    """Allocates KV-cache blocks to sequences and implements prefix caching.

    Full blocks are indexed by a hash chained over the whole token prefix, so a
    block is reused only when every preceding token matches as well.
    """

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        # Prefix hash -> block id. Entries are not removed on deallocation, so a
        # freed block can still be found (and revived) while its content is intact.
        self.hash_to_block_id: dict[int, int] = dict()
        # Free pool: allocation takes from the left, freed blocks are appended on the right.
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()

    @classmethod
    def compute_hash(cls, token_ids: list[int], prefix: int = -1) -> int:
        """xxh64 hash of ``token_ids``, chained with the previous block's hash ``prefix``.

        ``prefix == -1`` means there is no previous block (first block of a sequence).
        """
        h = xxhash.xxh64()
        if prefix != -1:
            h.update(prefix.to_bytes(8, "little"))
        h.update(np.array(token_ids).tobytes())
        return h.intdigest()

    def _allocate_block(self, block_id: int) -> Block:
        """Move ``block_id`` from the free pool to the used set; ref_count becomes 1.

        reset() clears the block's hash/tokens. Any hash_to_block_id entry still
        pointing here then fails the token comparison in allocate().
        """
        block = self.blocks[block_id]
        assert block.ref_count == 0
        block.reset()
        self.free_block_ids.remove(block_id)
        self.used_block_ids.add(block_id)
        return self.blocks[block_id]

    def _deallocate_block(self, block_id: int) -> None:
        """Return an unreferenced block to the right end of the free pool (content is kept)."""
        assert self.blocks[block_id].ref_count == 0
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)

    def can_allocate(self, seq: Sequence) -> bool:
        # Conservative: assumes no prefix-cache hits.
        return len(self.free_block_ids) >= seq.num_blocks

    def allocate(self, seq: Sequence) -> None:
        """Build ``seq.block_table`` for prefill, reusing cached prefix blocks when possible.

        Increments ``seq.num_cached_tokens`` by ``block_size`` for every reused block.
        """
        assert not seq.block_table
        h = -1
        cache_miss = False
        for i in range(seq.num_blocks):
            token_ids = seq.block(i)
            # Only full blocks get a hash; a partial last block is never cached.
            h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
            block_id = self.hash_to_block_id.get(h, -1)
            # The token comparison rejects hash collisions and stale mappings to reset blocks.
            if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                cache_miss = True
            # cache_miss is never cleared: once one block misses, all later blocks miss too.
            if cache_miss:
                block_id = self.free_block_ids[0]
                block = self._allocate_block(block_id)
            else:
                seq.num_cached_tokens += self.block_size
                if block_id in self.used_block_ids:
                    # Shared with another live sequence.
                    block = self.blocks[block_id]
                    block.ref_count += 1
                else:
                    # Freed earlier but content still valid: take it back out of the free pool.
                    block = self._allocate_block(block_id)
            if h != -1:
                block.update(h, token_ids)
                self.hash_to_block_id[h] = block_id
            seq.block_table.append(block_id)

    def deallocate(self, seq: Sequence) -> None:
        """Drop ``seq``'s references to its blocks and clear its block table.

        Blocks are released last-to-first, so earlier (prefix) blocks are appended
        to the free pool later and therefore reallocated later.
        """
        for block_id in reversed(seq.block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1
            if block.ref_count == 0:
                self._deallocate_block(block_id)
        seq.num_cached_tokens = 0
        seq.block_table.clear()

    def can_append(self, seq: Sequence) -> bool:
        # The right-hand side is a bool compared as 0/1: one free block is needed
        # only when the newest token is the first token of a new block.
        return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)

    def may_append(self, seq: Sequence) -> None:
        """Keep ``seq.block_table`` in step with a decode step.

        Adds a block when the newest token opens a new block, and registers the
        last block in the prefix cache when it has just become full.
        """
        block_table = seq.block_table
        last_block = self.blocks[block_table[-1]]
        if len(seq) % self.block_size == 1:
            # Previous block is full (and was hashed); the newest token needs a new block.
            assert last_block.hash != -1
            block_id = self.free_block_ids[0]
            self._allocate_block(block_id)
            block_table.append(block_id)
        elif len(seq) % self.block_size == 0:
            # Last block just became full: hash it, chained with the previous block's hash.
            assert last_block.hash == -1
            token_ids = seq.block(seq.num_blocks-1)
            prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
            h = self.compute_hash(token_ids, prefix)
            last_block.update(h, token_ids)
            self.hash_to_block_id[h] = last_block.block_id
        else:
            # Middle of a block: nothing to do.
            assert last_block.hash == -1
BlockManager 负责 KV cache 块的分配和释放,维护每个序列的 block_table(序列第 i 个逻辑块对应的物理块编号),并基于块哈希实现前缀缓存(prefix cache)。
这里的 reset() 只在分配 block(_allocate_block)时调用:分配出去的 block 立刻被一个序列引用,所以引用计数直接置为 1。
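注意 def allocate(self, seq: Sequence) 调用 compute_hash 时,会把上一个块的哈希作为 prefix 传入,因此一个块的哈希实际上覆盖了从序列开头到该块为止的全部 token。只有整个前缀都一致时才会命中缓存,这就避免了"块内容相同、但前面的上下文不同"的块被错误复用。此外,按哈希查到 block 后还会再比较 self.blocks[block_id].token_ids 与当前块的 token_ids,以防哈希碰撞或映射已过期。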
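新的 sequence 到来时,调用 block_manager.allocate(seq),依次为它的每个块查找缓存。
- 命中:该块已满,且哈希(即整个前缀)和 token 内容都匹配。此时直接复用该块:若它仍被其他序列使用,则引用计数加 1;否则把它从空闲队列里重新取回。同时 num_cached_tokens 增加 block_size。
- 未命中:分配新块,并在块已满时登记到缓存。
一旦某个块未命中,cache_miss 会一直保持 True,后面的块都不再尝试复用。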
使用字典 hash_to_block_id 记录哈希值到 block_id 的映射,使用集合 used_block_ids 记录正在被使用的 block,使用 deque free_block_ids 记录空闲 block。
cache 管理是 block 粒度的,只有 token 数量达到 block_size 时才计算哈希并尝试缓存。
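释放时倒序释放。原因如下:
- 释放的块被 append 到 free_block_ids 的右端,而分配总是取 free_block_ids[0]。
- 倒序释放时,越靠前的块越晚进入空闲队列,因此越晚被重新分配。
- 靠后的块最不可能被其他序列共享;靠前的块更可能被前缀复用,这样它们的缓存内容能保留得更久。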
allocate(seq)
用于 prefill
给整段 prompt 分配 block
顺便做 prefix cache 命中
may_append(seq)
用于 decode
每轮生成时维护 block_table
必要时追加新 block
-> 当 seq_len % block_size == 1 时,说明上一个 block 刚填满,还多了一个 token,需要新 block
当一个 block 填满时,把它登记进 hash cache
deallocate(seq)
用于请求结束或被抢占
释放该请求引用的 block
# scheduler
# scheduler
class Scheduler:
    """Decides which sequences run in each engine step and whether it is prefill or decode.

    ``waiting`` holds new or preempted sequences that need (re-)prefill;
    ``running`` holds sequences that have been prefilled and are decoding.
    """

    def __init__(self, config: Config):
        self.max_num_seqs = config.max_num_seqs
        self.max_num_batched_tokens = config.max_num_batched_tokens
        self.eos = config.eos
        self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
        self.waiting: deque[Sequence] = deque()
        self.running: deque[Sequence] = deque()

    def is_finished(self) -> bool:
        return not self.waiting and not self.running

    def add(self, seq: Sequence) -> None:
        self.waiting.append(seq)

    def schedule(self) -> tuple[list[Sequence], bool]:
        """Pick the sequences for the next step.

        Returns ``(scheduled_seqs, is_prefill)``. Waiting sequences are admitted
        for prefill first; only if none can be admitted is a decode batch built
        from ``running``. A step never mixes prefill and decode.
        """
        # --- prefill: admit waiting sequences in FIFO order ---
        scheduled_seqs = []
        num_seqs = 0
        num_batched_tokens = 0
        while self.waiting and num_seqs < self.max_num_seqs:
            seq = self.waiting[0]
            # Admission uses the full len(seq): cache hits are only known after allocate().
            if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
                break
            num_seqs += 1
            self.block_manager.allocate(seq)
            # Only tokens not covered by the prefix cache are actually computed.
            num_batched_tokens += len(seq) - seq.num_cached_tokens
            seq.status = SequenceStatus.RUNNING
            self.waiting.popleft()
            self.running.append(seq)
            scheduled_seqs.append(seq)
        if scheduled_seqs:
            return scheduled_seqs, True

        # --- decode: one token per running sequence ---
        while self.running and num_seqs < self.max_num_seqs:
            seq = self.running.popleft()
            while not self.block_manager.can_append(seq):
                if self.running:
                    # Free KV cache by evicting the sequence at the tail of running.
                    self.preempt(self.running.pop())
                else:
                    # Nobody left to evict: give up this sequence's own blocks.
                    self.preempt(seq)
                    break
            else:
                # Runs only when the inner loop ends without break, i.e. seq has room.
                num_seqs += 1
                self.block_manager.may_append(seq)
                scheduled_seqs.append(seq)
        # NOTE(review): this fails if every candidate was preempted (e.g. a lone
        # running sequence that cannot get a new block) — confirm that cannot happen.
        assert scheduled_seqs
        # Put the scheduled sequences back at the front of running, in their original order.
        self.running.extendleft(reversed(scheduled_seqs))
        return scheduled_seqs, False

    def preempt(self, seq: Sequence) -> None:
        """Release ``seq``'s KV cache and queue it at the front of waiting for re-prefill."""
        seq.status = SequenceStatus.WAITING
        self.block_manager.deallocate(seq)
        self.waiting.appendleft(seq)

    def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> None:
        """Append each sampled token to its sequence and retire finished sequences.

        A sequence finishes on EOS (unless ``ignore_eos``) or on reaching
        ``max_tokens`` completion tokens; its blocks are then released.
        """
        for seq, token_id in zip(seqs, token_ids):
            seq.append_token(token_id)
            if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
                seq.status = SequenceStatus.FINISHED
                self.block_manager.deallocate(seq)
                self.running.remove(seq)
Scheduler.schedule 设计
schedule() 的职责是:每次 engine step 之前,决定这一轮要送给模型执行哪些 Sequence,以及这一轮是 prefill 还是 decode。
它返回:
text 1 return scheduled_seqs, is_prefill
其中:
text 1 2 scheduled_seqs: 本轮要执行的序列列表 is_prefill: True 表示本轮是 prefill,False 表示本轮是 decode
整个 scheduler 维护两个队列:
python 1 2 self.waiting: deque[Sequence ] self.running: deque[Sequence ]
含义是:
text 1 2 3 4 5 waiting: 新请求,或者被抢占后需要重新 prefill 的请求。 running: 已经完成 prefill,正在逐 token decode 的请求。
一、schedule 的总体策略
schedule() 分两段:
text 1 2 1. 优先调度 waiting 队列里的请求做 prefill 2. 如果没有任何 prefill 请求能被调度,再调度 running 队列里的请求做 decode
也就是说,这个实现里 prefill 优先级高于 decode 。
简化伪代码:
python 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def schedule (): scheduled_seqs = [] while waiting not empty and batch not full: if resource enough: allocate KV cache move seq from waiting to running scheduled_seqs.append(seq) else : break if scheduled_seqs: return scheduled_seqs, True while running not empty and batch not full: if can append KV cache: may_append KV cache scheduled_seqs.append(seq) else : preempt some seq return scheduled_seqs, False
二、Prefill 调度逻辑
代码:
python 1 2 3 4 5 6 7 8 9 10 11 12 13 while self.waiting and num_seqs < self.max_num_seqs: seq = self.waiting[0 ] if num_batched_tokens + len (seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq): break num_seqs += 1 self.block_manager.allocate(seq) num_batched_tokens += len (seq) - seq.num_cached_tokens seq.status = SequenceStatus.RUNNING self.waiting.popleft() self.running.append(seq) scheduled_seqs.append(seq)
这段处理的是新请求的 prompt。
调度一个 waiting seq 需要满足两个条件:
text 1 2 1. batch token 数不能超过 max_num_batched_tokens 2. KV cache block 必须足够,即 can_allocate(seq)
max_num_batched_tokens 控制 prefill 的总 token 数,因为 prefill 是一次处理整段 prompt,开销和 prompt 长度强相关。
max_num_seqs 控制本轮最多处理多少条序列。
如果满足条件,就调用:
python 1 self.block_manager.allocate(seq)
allocate() 会为这个序列的 prompt 分配 KV cache block,并尝试做 prefix cache 命中。
之后:
python 1 2 3 4 seq.status = RUNNING waiting.popleft() running.append(seq) scheduled_seqs.append(seq)
意思是:
text 1 2 3 这个请求已经从 waiting 进入 running。 它本轮会被送去模型做 prefill。 prefill 完成后,后续 decode 会继续从 running 队列调度。
这里有一个细节:
python 1 num_batched_tokens += len (seq) - seq.num_cached_tokens
如果 prefix cache 命中了前面一部分 token,那么这些 token 不需要重新 prefill,所以真正进入模型计算的是:
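len(seq) - seq.num_cached_tokens 个 token,这也正是累加到 num_batched_tokens 上的数量。注意准入判断用的是完整的 len(seq):判断发生在 allocate() 之前,此时还不知道能命中多少前缀缓存,所以这个判断是偏保守的。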
三、为什么 prefill 有结果就立刻返回
代码:
python 1 2 if scheduled_seqs: return scheduled_seqs, True
意思是:只要这一轮成功调度了任何 prefill 请求,就不再混入 decode 请求。
所以这个实现是:
text 1 2 一轮 step 要么全是 prefill 一轮 step 要么全是 decode
这样做实现简单,因为 prefill 和 decode 的输入组织方式完全不同:
text 1 2 3 4 5 6 7 prefill: 输入是每条请求的 prompt 未缓存部分,长度可变。 使用 cu_seqlens_q / cu_seqlens_k / slot_mapping。 decode: 输入是每条请求的 last_token,每条序列只处理 1 个 token。 使用 context_lens / block_tables / slot_mapping。
所以 ModelRunner 可以根据 is_prefill 分别走:
python 1 2 prepare_prefill(seqs) prepare_decode(seqs)
四、Decode 调度逻辑
如果没有 waiting 请求被调度,才进入 decode:
python 1 2 3 4 5 6 7 8 9 10 11 12 13 while self.running and num_seqs < self.max_num_seqs: seq = self.running.popleft() while not self.block_manager.can_append(seq): if self.running: self.preempt(self.running.pop()) else : self.preempt(seq) break else : num_seqs += 1 self.block_manager.may_append(seq) scheduled_seqs.append(seq)
decode 阶段处理的是已经 prefill 完成的序列。每条 running 序列每轮最多生成 1 个 token。
这里不是调用 allocate() ,而是调用:
python 1 2 can_append(seq) may_append(seq)
因为 decode 时序列长度每轮只增长 1。
can_append(seq) 检查:如果这次 decode 需要新 block,当前是否有空闲 block。
在 BlockManager 里:
python 1 return len (self.free_block_ids) >= (len (seq) % self.block_size == 1 )
这里的意思是:
text 1 2 3 如果 len(seq) % block_size == 1,说明当前 token 落在一个新 block 的第一个位置, 需要提前分配新 block。 否则当前 block 还没满,不需要新 block。
如果可以 append,就调用:
python 1 self.block_manager.may_append(seq)
may_append() 做两类事情:
text 1 2 1. 如果当前 token 需要新 block,就分配一个新 block 加到 block_table。 2. 如果某个 block 刚刚被填满,就计算 hash,把它加入 prefix cache。
然后把这个 seq 加入本轮 decode batch:
python 1 scheduled_seqs.append(seq)
五、抢占 preempt 逻辑
如果 decode 阶段发现 KV cache 不够:
python 1 2 3 4 5 6 while not self.block_manager.can_append(seq): if self.running: self.preempt(self.running.pop()) else : self.preempt(seq) break
这里的策略是:
text 1 2 优先抢占 running 队列尾部的序列。 如果没有其他 running 序列可以抢占,就抢占当前 seq 自己。
preempt() 做三件事:
python 1 2 3 seq.status = SequenceStatus.WAITING self.block_manager.deallocate(seq) self.waiting.appendleft(seq)
含义:
text 1 2 3 1. 把 seq 状态改回 WAITING。 2. 释放它占用的 KV cache block。 3. 放到 waiting 队列头部,后面优先重新 prefill。
抢占的本质是:
text 1 2 3 4 5 为了给当前要 decode 的请求腾出 KV cache, 牺牲另一个 running 请求的 KV cache。 被抢占的请求不会丢 token_ids, 但它的 KV cache 被释放了, 以后需要重新 prefill,或者通过 prefix cache 复用部分 KV。
这是一个简化版的 preemption 机制。
六、decode 后为什么要把 scheduled_seqs 放回 running 队列
decode 调度时,每取一个 seq 会先:
python 1 seq = self.running.popleft()
如果它被成功调度,就先放进:
scheduled_seqs.append(seq)
decode 调度结束后:
python 1 self.running.extendleft(reversed (scheduled_seqs))
这句的作用是把本轮成功 decode 的 seq 放回 running 队列头部,并保持原顺序。
例如:
text 1 2 3 scheduled_seqs = [A, B, C] reversed(scheduled_seqs) = [C, B, A] running.extendleft([C, B, A])
extendleft 会依次从左边插入,所以最终队列头部仍然是:
A, B, C(其后是本轮没有被调度到、仍留在 running 里的序列)
为什么要放回 running?
因为 decode 后这些请求通常还没结束。模型本轮只会为每条序列生成一个 token。生成完后, postprocess() 会:
python 1 seq.append_token(token_id)
如果没遇到 EOS,也没达到 max_tokens ,它们还要继续参与下一轮 decode。
如果某个 seq 结束了, postprocess() 会做:
python 1 2 3 seq.status = FINISHED self.block_manager.deallocate(seq) self.running.remove(seq)
所以 running 队列里只保留还没完成的请求。
七、schedule 和 postprocess 的配合
一次 engine step 是:
python 1 2 3 seqs, is_prefill = scheduler.schedule() token_ids = model_runner.call("run" , seqs, is_prefill) scheduler.postprocess(seqs, token_ids)
schedule() 只负责:
text 1 2 3 决定谁跑 分配或准备 KV cache block 维护 waiting/running 队列
它不负责真的追加生成 token。
真正追加 token 在:
Scheduler.postprocess(seqs, token_ids)
代码:
python 1 seq.append_token(token_id)
然后判断:
if (not seq.ignore_eos and token_id == eos) or seq.num_completion_tokens == seq.max_tokens:
    seq.status = FINISHED
    block_manager.deallocate(seq)
    running.remove(seq)
(设置了 ignore_eos 的请求遇到 EOS 也不会结束,只受 max_tokens 限制。)
所以 decode 阶段的一个完整循环是:
text 1 2 3 4 5 6 7 8 9 10 11 12 1. schedule: 选择 running seq 确保 KV cache 有空间 2. model_runner: 用 seq.last_token 做输入 计算 logits 采样出新 token 3. postprocess: 把新 token append 到 seq 如果结束,释放 KV cache
八、这个 Scheduler 的特点
这个实现是一个简化版 vLLM scheduler,特点是:
text 1 2 3 4 5 6 7 8 1. waiting 和 running 分离。 2. prefill 优先于 decode。 3. 每轮不混合 prefill 和 decode。 4. prefill 受 max_num_batched_tokens 和 max_num_seqs 限制。 5. decode 主要受 max_num_seqs 和 KV cache block 是否充足限制。 6. KV cache 不够时,会 preempt running seq。 7. 被抢占的 seq 会回到 waiting,之后重新 prefill。 8. prefix cache 可以降低重新 prefill 的成本。
可以用一句话总结:
text 1 schedule() 是一个两阶段调度器:先尽可能把 waiting 请求做 prefill;如果没有 prefill 可做,就从 running 请求里组 decode batch。调度过程中它通过 BlockManager 管理 KV cache block,必要时抢占 running 请求,把它们放回 waiting,以保证当前 batch 能继续执行。
# model runner
# model_runner
class ModelRunner:
    """Per-GPU executor: owns the model, its KV cache and (optionally) CUDA graphs.

    Rank 0 runs inside the engine process and forwards every ``call`` to the other
    ranks through a shared-memory segment. Ranks > 0 block in ``loop`` replaying
    those calls until ``"exit"`` arrives.
    """

    def __init__(self, config: Config, rank: int, event: Event | list[Event]):
        """Initialise one tensor-parallel rank.

        Args:
            config: engine configuration; ``config.num_kvcache_blocks`` is written
                here by ``allocate_kv_cache``.
            rank: tensor-parallel rank of this process, also used as its CUDA device.
            event: rank 0 receives the events of all worker ranks (a list);
                a worker rank receives its own single event.

        For ranks > 0 in multi-GPU mode this constructor does not return until
        ``"exit"`` is received, because it enters ``loop()``.
        """
        self.config = config
        hf_config = config.hf_config
        self.block_size = config.kvcache_block_size
        self.enforce_eager = config.enforce_eager
        self.world_size = config.tensor_parallel_size
        self.rank = rank
        self.event = event

        dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
        torch.cuda.set_device(rank)
        # Build the model directly on the GPU in the checkpoint dtype, then restore defaults.
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(hf_config.torch_dtype)
        torch.set_default_device("cuda")
        self.model = Qwen3ForCausalLM(hf_config)
        load_model(self.model, config.model)
        self.sampler = Sampler()
        # Warm-up first so allocate_kv_cache can account for peak activation memory.
        self.warmup_model()
        self.allocate_kv_cache()
        if not self.enforce_eager:
            self.capture_cudagraph()
        torch.set_default_device("cpu")
        torch.set_default_dtype(default_dtype)

        if self.world_size > 1:
            if rank == 0:
                # Rank 0 creates the segment; the barrier makes workers attach only after that.
                self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
                dist.barrier()
            else:
                dist.barrier()
                self.shm = SharedMemory(name="nanovllm")
                self.loop()

    def exit(self):
        """Release shared memory, CUDA graphs and the process group."""
        if self.world_size > 1:
            self.shm.close()
            dist.barrier()
            if self.rank == 0:
                self.shm.unlink()
        if not self.enforce_eager:
            del self.graphs, self.graph_pool
        torch.cuda.synchronize()
        dist.destroy_process_group()

    def loop(self):
        """Worker-rank main loop: execute the calls published by rank 0 until "exit"."""
        while True:
            method_name, args = self.read_shm()
            self.call(method_name, *args)
            if method_name == "exit":
                break

    def read_shm(self):
        """Wait for rank 0's signal, then decode ``[method_name, *args]`` from shared memory.

        Layout: 4-byte little-endian payload length followed by the pickled payload.
        """
        assert self.world_size > 1 and self.rank > 0
        self.event.wait()
        n = int.from_bytes(self.shm.buf[0:4], "little")
        method_name, *args = pickle.loads(self.shm.buf[4:n+4])
        self.event.clear()
        return method_name, args

    def write_shm(self, method_name, *args):
        """Publish a call to all worker ranks (length-prefixed pickle), then signal them.

        NOTE(review): the segment is 2**20 bytes; a larger payload would not fit.
        """
        assert self.world_size > 1 and self.rank == 0
        data = pickle.dumps([method_name, *args])
        n = len(data)
        self.shm.buf[0:4] = n.to_bytes(4, "little")
        self.shm.buf[4:n+4] = data
        for event in self.event:
            event.set()

    def call(self, method_name, *args):
        """Run ``method_name`` locally; on rank 0 also broadcast it to the workers first."""
        if self.world_size > 1 and self.rank == 0:
            self.write_shm(method_name, *args)
        method = getattr(self, method_name, None)
        return method(*args)

    def warmup_model(self):
        """Run one maximal prefill batch so the peak-memory statistics reflect real usage."""
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
        self.run(seqs, True)
        torch.cuda.empty_cache()

    def allocate_kv_cache(self):
        """Allocate as many KV-cache blocks as the memory budget allows and bind them to layers."""
        config = self.config
        hf_config = config.hf_config
        free, total = torch.cuda.mem_get_info()
        used = total - free
        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
        num_kv_heads = hf_config.num_key_value_heads // self.world_size
        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
        # Bytes per block: K and V (2) x layers x tokens per block x local KV heads x head_dim x element size.
        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
        # Budget = utilisation share of total memory, minus what is in use now, minus the
        # extra activation memory seen during warm-up (peak - current).
        config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
        assert config.num_kvcache_blocks > 0
        self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
        # Give every module exposing k_cache/v_cache its per-layer view of the shared tensor.
        layer_id = 0
        for module in self.model.modules():
            if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
                module.k_cache = self.kv_cache[0, layer_id]
                module.v_cache = self.kv_cache[1, layer_id]
                layer_id += 1

    def prepare_block_tables(self, seqs: list[Sequence]):
        """Stack the sequences' block tables into one int32 GPU tensor, padded with -1."""
        max_len = max(len(seq.block_table) for seq in seqs)
        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
        block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        return block_tables

    def prepare_prefill(self, seqs: list[Sequence]):
        """Flatten the uncached part of every prompt into one varlen batch and set the prefill context."""
        input_ids = []
        positions = []
        cu_seqlens_q = [0]
        cu_seqlens_k = [0]
        max_seqlen_q = 0
        max_seqlen_k = 0
        slot_mapping = []
        block_tables = None
        for seq in seqs:
            seqlen = len(seq)
            # Queries: only tokens not covered by the prefix cache; keys: the whole sequence.
            input_ids.extend(seq[seq.num_cached_tokens:])
            positions.extend(list(range(seq.num_cached_tokens, seqlen)))
            seqlen_q = seqlen - seq.num_cached_tokens
            seqlen_k = seqlen
            cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
            cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
            max_seqlen_q = max(seqlen_q, max_seqlen_q)
            max_seqlen_k = max(seqlen_k, max_seqlen_k)
            # Sequences without blocks (e.g. the warm-up sequences) get no cache slots.
            if not seq.block_table:
                continue
            # Cache slots for the uncached blocks; the last block may be partial.
            for i in range(seq.num_cached_blocks, seq.num_blocks):
                start = seq.block_table[i] * self.block_size
                if i != seq.num_blocks - 1:
                    end = start + self.block_size
                else:
                    end = start + seq.last_block_num_tokens
                slot_mapping.extend(list(range(start, end)))
        # More keys than queries means some prefix came from the cache, so attention needs block tables.
        if cu_seqlens_k[-1] > cu_seqlens_q[-1]:
            block_tables = self.prepare_block_tables(seqs)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
        return input_ids, positions

    def prepare_decode(self, seqs: list[Sequence]):
        """Build a one-token-per-sequence batch from each ``last_token`` and set the decode context."""
        input_ids = []
        positions = []
        slot_mapping = []
        context_lens = []
        for seq in seqs:
            input_ids.append(seq.last_token)
            positions.append(len(seq) - 1)
            context_lens.append(len(seq))
            # Slot of the last token: its offset inside the last block of the block table.
            slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        block_tables = self.prepare_block_tables(seqs)
        set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
        return input_ids, positions

    def prepare_sample(self, seqs: list[Sequence]):
        """Collect per-sequence temperatures as a float32 GPU tensor."""
        temperatures = []
        for seq in seqs:
            temperatures.append(seq.temperature)
        temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
        return temperatures

    @torch.inference_mode()
    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
        """Return logits, eagerly or by replaying the smallest captured CUDA graph that fits."""
        # 512 matches the batch-size cap used in capture_cudagraph.
        if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
            return self.model.compute_logits(self.model(input_ids, positions))
        else:
            bs = input_ids.size(0)
            context = get_context()
            graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
            graph_vars = self.graph_vars
            # Copy this step's inputs into the static buffers the graph was captured on;
            # padding rows get slot -1 and context length 0.
            graph_vars["input_ids"][:bs] = input_ids
            graph_vars["positions"][:bs] = positions
            graph_vars["slot_mapping"].fill_(-1)
            graph_vars["slot_mapping"][:bs] = context.slot_mapping
            graph_vars["context_lens"].zero_()
            graph_vars["context_lens"][:bs] = context.context_lens
            graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
            graph.replay()
            return self.model.compute_logits(graph_vars["outputs"][:bs])

    def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
        """Execute one step for ``seqs``; returns sampled token ids on rank 0, ``None`` elsewhere."""
        input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        logits = self.run_model(input_ids, positions, is_prefill)
        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        reset_context()
        return token_ids

    @torch.inference_mode()
    def capture_cudagraph(self):
        """Capture one decode CUDA graph per batch size in ``self.graph_bs``, sharing one memory pool."""
        config = self.config
        hf_config = config.hf_config
        max_bs = min(self.config.max_num_seqs, 512)
        max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
        input_ids = torch.zeros(max_bs, dtype=torch.int64)
        positions = torch.zeros(max_bs, dtype=torch.int64)
        slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
        context_lens = torch.zeros(max_bs, dtype=torch.int32)
        block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
        outputs = torch.zeros(max_bs, hf_config.hidden_size)
        # The list must end at max_bs. A decode batch can hold up to max_bs sequences,
        # and run_model picks the smallest captured size >= bs. The former
        # [1, 2, 4, 8] + multiples of 16 left the top sizes without a graph whenever
        # max_bs was not a multiple of 16 (e.g. max_num_seqs=20 or 100), so next()
        # raised StopIteration. For multiples of 16 the list is unchanged.
        self.graph_bs = [bs for bs in (1, 2, 4, 8) if bs < max_bs] + list(range(16, max_bs, 16)) + [max_bs]
        self.graphs = {}
        self.graph_pool = None

        # Largest first: its pool is then reused by the smaller graphs.
        for bs in reversed(self.graph_bs):
            graph = torch.cuda.CUDAGraph()
            set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # eager warm-up run
            with torch.cuda.graph(graph, self.graph_pool):
                outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # capture
            if self.graph_pool is None:
                self.graph_pool = graph.pool()
            self.graphs[bs] = graph
            torch.cuda.synchronize()
            reset_context()

        # Static buffers that run_model fills before each replay.
        self.graph_vars = dict(
            input_ids=input_ids,
            positions=positions,
            slot_mapping=slot_mapping,
            context_lens=context_lens,
            block_tables=block_tables,
            outputs=outputs,
        )
ModelRunner 可以理解成 engine 里真正 “跑模型” 的执行器。
前面几个类的分工是:
text 1 2 3 4 5 Sequence 管一条请求的 token 状态 BlockManager 管 KV cache block 分配 Scheduler 决定这一轮跑哪些 Sequence ModelRunner 把 Sequence 变成 tensor,调用模型 forward,再采样 token LLMEngine 把整个流程串起来
ModelRunner 的核心职责
它主要做 5 件事:
text 1 2 3 4 5 1. 初始化模型和 GPU 环境 2. 分配 KV cache 3. 把调度出来的 Sequence 转成模型输入 4. 调用模型 forward 5. 根据 logits 采样新 token
最核心入口是:
python 1 def run (self, seqs: list [Sequence ], is_prefill: bool ) -> list [int ]:
这个函数每个 step 都会被调用。
和 LLMEngine 的交互
在 llm_engine.py 里,一轮生成是:
python 1 2 3 seqs, is_prefill = self.scheduler.schedule() token_ids = self.model_runner.call("run" , seqs, is_prefill) self.scheduler.postprocess(seqs, token_ids)
所以调用关系是:
text 1 2 3 4 5 6 7 Scheduler 选出要跑的 seqs ↓ ModelRunner.run(seqs, is_prefill) ↓ 返回本轮采样出的 token_ids ↓ Scheduler.postprocess() 把 token append 回 seq
ModelRunner 本身不决定哪些请求能跑,它只执行 scheduler 给它的任务。
run () 的流程
核心代码:
def run(self, seqs, is_prefill):
    # Build this step's inputs: prompt layout for prefill, one token per sequence for decode.
    input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
    # Sampling happens only on rank 0, so only rank 0 needs the temperatures.
    temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
    logits = self.run_model(input_ids, positions, is_prefill)
    token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
    # Clear the attention context that prepare_prefill / prepare_decode installed.
    reset_context()
    return token_ids
简单说:
text 1 2 3 4 5 1. 根据 prefill/decode 准备 input_ids 和 positions 2. 准备采样参数,比如 temperature 3. 跑模型得到 logits 4. sampler 从 logits 里采样下一个 token 5. 返回 token_ids
Prefill 时怎么准备输入
prefill 是处理 prompt。
函数:
def prepare_prefill(self, seqs: list[Sequence]):
它会把多个请求的 prompt 拼成一个大 batch。
关键点:
python 1 input_ids.extend(seq[seq.num_cached_tokens:])
如果 prefix cache 命中了前面一部分,那么只需要跑没缓存的 token。
比如:
text 1 2 seq.token_ids = [A, B, C, D, E] seq.num_cached_tokens = 3
那么 prefill 只跑:
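[D, E],对应的 positions 为 [3, 4]。(这里为了举例假设 block_size = 3;实际代码中 num_cached_tokens 只按整块累加,总是 block_size 的整数倍。)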
同时它会准备:
positions:每个 token 的位置
cu_seqlens_q:query 的累积长度(只算未缓存的 token)
cu_seqlens_k:key 的累积长度(包含已缓存的前缀)
slot_mapping:每个新 token 的 KV 要写到哪个 cache slot(已缓存的 block 会跳过)
block_tables:每条 seq 的 block table。只有命中了 prefix cache(cu_seqlens_k[-1] > cu_seqlens_q[-1])时才准备,否则为 None
其中 positions 作为模型 forward 的参数传入(Qwen3Attention 里用于 rotary embedding)。其余几项不走 forward 参数,而是通过 set_context 放进全局 context,供 attention 层读取。
Decode 时怎么准备输入
decode 是每条请求只处理 1 个 token。
函数:
def prepare_decode(self, seqs: list[Sequence]):
核心逻辑:
python 1 2 3 4 input_ids.append(seq.last_token) positions.append(len (seq) - 1 ) context_lens.append(len (seq)) slot_mapping.append(seq.block_table[-1 ] * self.block_size + seq.last_block_num_tokens - 1 )
意思是:
text 1 2 decode 输入的是每条 seq 的 last_token。 模型根据 last_token 预测下一个 token。
易错点:
text 1 2 3 4 decode 时输入模型的 token,不是即将生成的新 token。 而是当前序列最后一个 token。 模型 forward 后 sampler 才得到新 token。 新 token 之后由 Scheduler.postprocess() append 到 seq。
和 BlockManager 的交互
ModelRunner 不直接分配 block。
block 分配发生在 Scheduler 里:
text 1 2 3 4 5 prefill: Scheduler -> BlockManager.allocate(seq) decode: Scheduler -> BlockManager.may_append(seq)
ModelRunner 只使用已经准备好的:
text 1 2 3 seq.block_table seq.num_cached_tokens seq.last_block_num_tokens
也就是说:
text 1 2 BlockManager 负责“KV cache 放哪” ModelRunner 负责“根据 block_table 把 tensor 写到正确位置”
和 Attention 的交互
ModelRunner 会调用:
set_context(...)
把当前这轮推理需要的元信息放到全局 context 里。
比如 prefill:
python 1 2 3 4 5 6 7 8 9 10 set_context( True , cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None , block_tables, )
decode:
python 1 2 3 4 5 6 set_context( False , slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables, )
然后 attention 层里会:
context = get_context()
拿到这些信息。
也就是说:
text 1 2 ModelRunner 准备 context Attention 使用 context
比如 attention 需要知道:
text 1 2 3 4 这个 token 的 K/V 写到哪个 cache slot? 这条序列的 block_table 是什么? 当前是 prefill 还是 decode? 每条序列长度是多少?
这些都由 ModelRunner 准备。
KV cache 是怎么绑定到模型里的
初始化时, ModelRunner 会调用:
text 1 self.allocate_kv_cache()
它会分配一个大 tensor:
text 1 2 3 4 5 6 7 8 self.kv_cache = torch.empty( 2 , num_layers, num_kvcache_blocks, block_size, num_kv_heads, head_dim )
第一维 2 表示:
text 1 2 0: key cache 1: value cache
然后遍历模型里的 attention 模块:
text 1 2 3 4 for module in self.model.modules(): if hasattr (module, "k_cache" ) and hasattr (module, "v_cache" ): module.k_cache = self.kv_cache[0 , layer_id] module.v_cache = self.kv_cache[1 , layer_id]
所以 attention 层里的:
text 1 2 self.k_cache self.v_cache
其实指向的是 ModelRunner 分配出来的全局 KV cache。
和 Sampler 的交互
模型 forward 后会得到 logits:
text 1 logits = self.run_model(...)
然后采样:
text 1 token_ids = self.sampler(logits, temperatures).tolist()
temperatures 来自每条 Sequence :
text 1 temperatures.append(seq.temperature)
所以:
text 1 2 3 Sequence 保存采样参数 ModelRunner 取出来变成 tensor Sampler 根据 logits + temperature 采样 token
一轮完整交互
把几个类串起来就是:
text 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 LLMEngine.step() ↓ Scheduler.schedule() 选出 seqs prefill 时 allocate block decode 时 may_append block ↓ ModelRunner.run(seqs, is_prefill) prepare_prefill / prepare_decode set_context model forward sampler reset_context ↓ Scheduler.postprocess(seqs, token_ids) append token 判断 EOS / max_tokens 结束则 deallocate block
最重要的易错点
ModelRunner 不负责调度,调度是 Scheduler 做的。
ModelRunner 不负责分配 block,分配是 BlockManager 做的。
ModelRunner 负责把 Sequence.block_table 转成 attention 能用的 tensor。
prefill 输入的是 prompt 未缓存部分。
decode 输入的是 seq.last_token ,不是新 token。
新 token 是 sampler 采样出来后,由 Scheduler.postprocess() 追加到 Sequence 。
一句话总结:
```text
ModelRunner 是执行层:Scheduler 告诉它跑哪些 seq,BlockManager 已经准备好 KV cache 布局,它把 seq 转成 GPU tensor,设置 attention 需要的 context,跑模型,采样出新 token,再交回 Scheduler 更新状态。
```
# LLM engine
llm_engine 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 class LLMEngine : def __init__ (self, model, **kwargs ): config_fields = {field.name for field in fields(Config)} config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} config = Config(model, **config_kwargs) self.ps = [] self.events = [] ctx = mp.get_context("spawn" ) for i in range (1 , config.tensor_parallel_size): event = ctx.Event() process = ctx.Process(target=ModelRunner, args=(config, i, event)) process.start() self.ps.append(process) self.events.append(event) self.model_runner = ModelRunner(config, 0 , self.events) self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True ) config.eos = self.tokenizer.eos_token_id self.scheduler = Scheduler(config) atexit.register(self.exit) def exit (self ): self.model_runner.call("exit" ) del self.model_runner for p in self.ps: p.join() def add_request (self, prompt: str | list [int ], sampling_params: SamplingParams ): if isinstance (prompt, str ): prompt = self.tokenizer.encode(prompt) seq = Sequence (prompt, sampling_params) self.scheduler.add(seq) def step (self ): seqs, is_prefill = self.scheduler.schedule() token_ids = self.model_runner.call("run" , seqs, is_prefill) self.scheduler.postprocess(seqs, token_ids) outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] num_tokens = sum (len (seq) for seq in seqs) if is_prefill else -len (seqs) return outputs, num_tokens def is_finished (self ): return self.scheduler.is_finished() def generate ( self, prompts: list [str ] | list [list [int ]], sampling_params: SamplingParams | list [SamplingParams], use_tqdm: bool = True , ) -> list [str ]: if use_tqdm: pbar = tqdm(total=len (prompts), 
desc="Generating" , dynamic_ncols=True ) if not isinstance (sampling_params, list ): sampling_params = [sampling_params] * len (prompts) for prompt, sp in zip (prompts, sampling_params): self.add_request(prompt, sp) outputs = {} prefill_throughput = decode_throughput = 0. while not self.is_finished(): t = perf_counter() output, num_tokens = self.step() if use_tqdm: if num_tokens > 0 : prefill_throughput = num_tokens / (perf_counter() - t) else : decode_throughput = -num_tokens / (perf_counter() - t) pbar.set_postfix({ "Prefill" : f"{int (prefill_throughput)} tok/s" , "Decode" : f"{int (decode_throughput)} tok/s" , }) for seq_id, token_ids in output: outputs[seq_id] = token_ids if use_tqdm: pbar.update(1 ) outputs = [outputs[seq_id] for seq_id in sorted (outputs.keys())] outputs = [{"text" : self.tokenizer.decode(token_ids), "token_ids" : token_ids} for token_ids in outputs] if use_tqdm: pbar.close() return outputs
engine 层可以按一条请求的生命周期理解,别按文件孤立看。
1. 用户请求进来
入口在 llm_engine.py 。
python 1 2 3 4 5 def add_request (self, prompt, sampling_params ): if isinstance (prompt, str ): prompt = self.tokenizer.encode(prompt) seq = Sequence (prompt, sampling_params) self.scheduler.add(seq)
这里做三件事:
```text
prompt 文本 -> token ids
token ids   -> Sequence
Sequence    -> 放进 Scheduler.waiting
```
所以 Sequence 就是一条请求的状态对象。
2. Sequence 记录请求状态
sequence.py 负责保存:
```text
token_ids          prompt + 已生成 token
last_token         当前最后一个 token
num_prompt_tokens  prompt 长度
num_cached_tokens  已经命中 prefix cache 的 token 数
block_table        这条请求使用哪些 KV cache block
status             WAITING / RUNNING / FINISHED
sampling 参数      temperature / max_tokens / ignore_eos
```
你可以把它理解成:
```text
Sequence = 一条请求的全部状态:token、prefix cache 命中情况、KV cache block 位置、调度状态、采样参数
```
3. Scheduler 决定这一轮跑谁
核心在 scheduler.py :
python 1 seqs, is_prefill = self.scheduler.schedule()
它维护两个队列:
```text
waiting:
  还没 prefill,或者被抢占后要重新 prefill 的请求

running:
  已经 prefill 完,正在 decode 的请求
```
schedule() 的规则很简单:
```text
先尝试从 waiting 里拿请求做 prefill。
如果没有 prefill 可做,再从 running 里拿请求做 decode。
```
prefill 时调用:
python 1 block_manager.allocate(seq)
decode 时调用:
python 1 block_manager.may_append(seq)
易错点:
```text
schedule 只决定谁跑,不跑模型,也不追加新 token。
```
4. BlockManager 管 KV cache 位置
block_manager.py 负责 KV cache block。
核心数据:
python 1 2 3 4 self.blocks self.free_block_ids self.used_block_ids self.hash_to_block_id
每条 Sequence 里有:
```python
seq.block_table
```
比如:
python 1 seq.block_table = [2 , 5 , 9 ]
意思是:
```text
这条序列的第 0 个逻辑 block 在物理 KV block 2
第 1 个逻辑 block 在物理 KV block 5
第 2 个逻辑 block 在物理 KV block 9
```
prefill 时:
```python
block_manager.allocate(seq)
```
负责给整段 prompt 分配 block,并通过 hash 尝试复用 prefix cache。
decode 时:
```python
block_manager.may_append(seq)
```
负责随着序列变长维护 block:
```text
需要新 block 时分配。
block 填满时计算 hash,加入 prefix cache。
```
结束或抢占时:
```python
block_manager.deallocate(seq)
```
释放这条 seq 引用的 block。
注意:
```text
释放不等于清空 KV cache。
它只是把每个 block 的 ref_count 减 1;只有 ref_count 降到 0 时,该 block 才放回 free list。
hash 和 token_ids 可能还保留,用于 prefix cache。
```
5. ModelRunner 真的跑模型
model_runner.py 是执行层。
scheduler 选好 seq 后, llm_engine 调:
python 1 token_ids = self.model_runner.call("run" , seqs, is_prefill)
run() 做:
```text
1. prepare_prefill 或 prepare_decode
2. 准备 temperature
3. model forward
4. sampler 采样 token
5. 返回 token_ids
```
prefill 输入是:
```text
prompt 中没有命中 prefix cache 的 token,即 token_ids[num_cached_tokens:]
```
decode 输入是:
```text
每条序列的 last_token(每条序列只输入 1 个 token)
```
易错点:
```text
decode 输入模型的是 last_token。
模型输出后 sampler 才得到 next_token。
next_token 不是在 ModelRunner 里 append 的。
```
6. Context 把 ModelRunner 和 Attention 连起来
ModelRunner 会调用:
```python
set_context(...)
```
把这些信息放到全局 context:
```text
is_prefill
slot_mapping
block_tables
cu_seqlens_q / cu_seqlens_k
context_lens
```
然后 attention 层里:
```python
context = get_context()
```
拿到这些东西。
所以关系是:
```text
ModelRunner 根据 Sequence.block_table 准备 context
Attention   根据 context 读写 KV cache
```
例如:
```text
slot_mapping:
  当前 token 的 K/V 应该写到 KV cache 的哪个 slot

block_tables:
  每条序列对应哪些 KV block
```
7. postprocess 更新请求状态
模型跑完后回到 llm_engine.py :
python 1 self.scheduler.postprocess(seqs, token_ids)
在 scheduler.py 里:
python 1 seq.append_token(token_id)
也就是:
```text
真正把新 token 加进 Sequence 的地方,是 postprocess。
```
然后判断是否结束(下面是简化写法;实际代码还会检查 `ignore_eos`,只有 `ignore_eos` 为 False 时遇到 EOS 才结束):
python 1 2 3 4 if token_id == eos or seq.num_completion_tokens == seq.max_tokens: seq.status = FINISHED block_manager.deallocate(seq) running.remove(seq)
8. generate 是完整外层循环
llm_engine.py 的 generate() :
python 1 2 while not self.is_finished(): output, num_tokens = self.step()
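step() 是核心(下面是简化版,省略了 `num_tokens` 的计算:prefill 时为本轮所有序列的 token 总数,decode 时为 `-len(seqs)`,用正负号区分两个阶段,供 generate() 分别统计吞吐):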
python 1 2 3 4 5 6 def step (self ): seqs, is_prefill = self.scheduler.schedule() token_ids = self.model_runner.call("run" , seqs, is_prefill) self.scheduler.postprocess(seqs, token_ids) outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] return outputs, num_tokens
所以整个 engine 可以压缩成一句:
```text
generate 不断调用 step;
每个 step 先 schedule,后 model_runner.run,最后 postprocess。
```
最简心智模型
```text
LLMEngine:
  总控,负责 generate 循环。

Sequence:
  单个请求的状态。

Scheduler:
  决定这一轮哪些请求跑 prefill / decode。

BlockManager:
  管 KV cache block 分配、释放、prefix cache。

ModelRunner:
  把 seq 转成 tensor,跑模型,采样 token。

Attention:
  根据 ModelRunner 设置的 context 读写 KV cache。
```
一条请求的完整过程:
```text
prompt
  -> tokenizer.encode
  -> Sequence
  -> waiting

prefill:
  -> Scheduler.schedule
  -> BlockManager.allocate
  -> ModelRunner.prepare_prefill
  -> model forward
  -> sampler
  -> postprocess append token

decode:
  -> Scheduler.schedule
  -> BlockManager.may_append
  -> ModelRunner.prepare_decode
  -> model forward
  -> sampler
  -> postprocess append token

finish:
  -> BlockManager.deallocate
  -> tokenizer.decode
```
抓住这条线,engine 层就很清楚了。
# 通信原语速查
约定:$p$ 张卡(rank $0, 1, \dots, p-1$),每张卡持有一个张量块。用 $A, B, C, D$ 表示数据块,下标表示来自哪个 rank(Reduce-Scatter 一节改用上标表示 rank,下文会单独说明)。$\oplus$ 表示归约操作(通常是 sum,也可以是 max/min/mean 等)。$N$ 表示参与通信的完整张量大小,下文的"流量"都以 $N$ 为单位。
下面所有原语都以 $p=4$ 为例做可视化。
# 1. Broadcast
语义 :把 rank src 的张量发给所有 rank。
$$
\underbrace{\begin{bmatrix} A \\ \cdot \\ \cdot \\ \cdot \end{bmatrix}}_{\text{before}}
\xrightarrow{\text{broadcast (src=0)}}
\underbrace{\begin{bmatrix} A \\ A \\ A \\ A \end{bmatrix}}_{\text{after}}
$$
broadcast 1 dist.broadcast(tensor, src=0 )
流量:src 逐个直接发送时,src 的出口流量为 $(p-1)N$;树形实现只需约 $\log_2 p$ 步;流水线/环形实现下每卡流量约为 $N$。
用处 :模型权重初始化同步、超参广播。
# 2. Scatter
语义:rank `src` 持有 $p$ 块,把第 $i$ 块发给 rank $i$。
$$
\underbrace{\begin{bmatrix} [A_0, A_1, A_2, A_3] \\ \cdot \\ \cdot \\ \cdot \end{bmatrix}}_{\text{before (only rank 0 has data)}}
\xrightarrow{\text{scatter}}
\underbrace{\begin{bmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \end{bmatrix}}_{\text{after}}
$$
scatter 1 dist.scatter(output, scatter_list=[A0, A1, A2, A3], src=0 )
流量:$N(p-1)/p$(全部由 src 发出)。
用处 :把一个完整 batch 切给各卡。
# 3. Gather
语义 :Scatter 的逆。每 rank 各有一块,全部汇到 rank dst 。
$$
\underbrace{\begin{bmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \end{bmatrix}}_{\text{before}}
\xrightarrow{\text{gather (dst=0)}}
\underbrace{\begin{bmatrix} [A_0, A_1, A_2, A_3] \\ \cdot \\ \cdot \\ \cdot \end{bmatrix}}_{\text{after (only rank 0)}}
$$
gather 1 dist.gather(input , gather_list=[...] if rank == 0 else None , dst=0 )
用处 :TP 里把 logits 汇到 rank 0 给 Sampler。
# 4. Reduce
语义 :各 rank 一块,按 ⊕ 归约到 rank dst 。
$$
\underbrace{\begin{bmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \end{bmatrix}}_{\text{before}}
\xrightarrow{\text{reduce (dst=0, op=sum)}}
\underbrace{\begin{bmatrix} \sum_{i=0}^{3} A_i \\ \cdot \\ \cdot \\ \cdot \end{bmatrix}}_{\text{after (only rank 0)}}
$$
reduce 1 dist.reduce(tensor, dst=0 , op=dist.ReduceOp.SUM)
流量:每卡约 $N$(与 broadcast 对偶:dst 的每个输出元素都至少要收到一份来自其他卡的部分和,所以 dst 至少要收 $N$;流水线/环形实现可以做到每卡约 $N$)。
用处 :集中统计(loss、norm 汇总到主卡)。
# 5. All-Gather
语义 :每 rank 一块,拼起来给所有 rank。
$$
\underbrace{\begin{bmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \end{bmatrix}}_{\text{before}}
\xrightarrow{\text{all-gather}}
\underbrace{\begin{bmatrix}
[A_0, A_1, A_2, A_3] \\
[A_0, A_1, A_2, A_3] \\
[A_0, A_1, A_2, A_3] \\
[A_0, A_1, A_2, A_3]
\end{bmatrix}}_{\text{after}}
$$
all_gather 1 dist.all_gather(tensor_list, input )
流量:$N(p-1)/p$($N$ 为拼接后的完整张量大小;若以单卡分块大小 $n = N/p$ 计,则每卡收 $(p-1)n$)。
用处 :Column-Parallel Linear 输出 → 拼成完整激活;SP / FSDP 里 unshard 权重。
# 6. Reduce-Scatter
语义:各 rank 持有完整的 $p$ 块;对每一块做跨卡归约,结果的第 $i$ 块留在 rank $i$。
$$
\underbrace{\begin{bmatrix}
[A_0^{(0)}, A_1^{(0)}, A_2^{(0)}, A_3^{(0)}] \\
[A_0^{(1)}, A_1^{(1)}, A_2^{(1)}, A_3^{(1)}] \\
[A_0^{(2)}, A_1^{(2)}, A_2^{(2)}, A_3^{(2)}] \\
[A_0^{(3)}, A_1^{(3)}, A_2^{(3)}, A_3^{(3)}]
\end{bmatrix}}_{\text{before; upper idx = rank, lower idx = chunk}}
\xrightarrow{\text{reduce-scatter (sum)}}
\underbrace{\begin{bmatrix}
\sum_r A_0^{(r)} \\
\sum_r A_1^{(r)} \\
\sum_r A_2^{(r)} \\
\sum_r A_3^{(r)}
\end{bmatrix}}_{\text{after: rank }i\text{ 只留 chunk }i}
$$
reduce_scatter 1 dist.reduce_scatter(output, input_list=[...], op=dist.ReduceOp.SUM)
流量:$N(p-1)/p$。
用处:FSDP 梯度聚合(反向时把完整梯度归约后按分片切回各卡);Sequence Parallel 中 TP 区域出口(RowParallel 之后)的前向,以及入口 all-gather 对应的反向。
# 7. All-Reduce
语义 :各 rank 一块,归约后所有 rank 都持有相同的归约结果。
$$
\underbrace{\begin{bmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \end{bmatrix}}_{\text{before}}
\xrightarrow{\text{all-reduce (sum)}}
\underbrace{\begin{bmatrix}
\sum_{i=0}^{3} A_i \\
\sum_{i=0}^{3} A_i \\
\sum_{i=0}^{3} A_i \\
\sum_{i=0}^{3} A_i
\end{bmatrix}}_{\text{after}}
$$
all_reduce 1 dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
等价分解 :
$$
\boxed{\ \text{all-reduce} \;\equiv\; \text{reduce-scatter} \;+\; \text{all-gather}\ }
$$
流量(Ring-AllReduce):$2N(p-1)/p \approx 2N$(带宽最优)。
用处 :DP 梯度同步;TP 的 RowParallel 出口、MLP down_proj 出口;Attention o_proj 出口。
# 8. All-to-All
语义:rank $i$ 把自己的张量切成 $p$ 块,第 $j$ 块发给 rank $j$;同时从每个 rank $j$ 收到对方的第 $i$ 块。本质是**"按块转置"**。
$$
\text{before:}\quad
\begin{bmatrix}
[A_0^{0}, A_0^{1}, A_0^{2}, A_0^{3}] \\
[A_1^{0}, A_1^{1}, A_1^{2}, A_1^{3}] \\
[A_2^{0}, A_2^{1}, A_2^{2}, A_2^{3}] \\
[A_3^{0}, A_3^{1}, A_3^{2}, A_3^{3}]
\end{bmatrix}
\xrightarrow{\text{all-to-all}}
\text{after:}\quad
\begin{bmatrix}
[A_0^{0}, A_1^{0}, A_2^{0}, A_3^{0}] \\
[A_0^{1}, A_1^{1}, A_2^{1}, A_3^{1}] \\
[A_0^{2}, A_1^{2}, A_2^{2}, A_3^{2}] \\
[A_0^{3}, A_1^{3}, A_2^{3}, A_3^{3}]
\end{bmatrix}
$$
下标 = 源 rank(before 中的行),上标 = 目标 rank。可以看成分块矩阵的转置 $T_{ij} \to T_{ji}$。
all_to_all 1 2 dist.all_to_all(output_list, input_list)
流量:$N(p-1)/p$(每 rank 发 / 收)。
用处 :MoE 的 dispatch /combine(token → expert 两次 all-to-all);Sequence Parallel 切 head 与切 seq 互转。
# 9. Barrier
语义 :无数据交换,纯同步点。
$$
\text{rank 0: } \cdots \longrightarrow \|_{\text{barrier}} \longrightarrow \cdots
$$
```python
dist.barrier()
```
# 10. 点对点:Send / Recv
语义 :有向的单播。
$$
\text{rank } s: X \xrightarrow{\text{send}} \text{rank } r: X
$$
send/recv 1 2 3 dist.send(tensor, dst=r) dist.recv(tensor, src=s)
用处 :Pipeline Parallel 相邻 rank 传激活 / 梯度。
# 11. 组合关系与流量对照
关键恒等式(第一条正是 Ring-AllReduce 的实现方式):

$$
\text{all-reduce} \;=\; \text{reduce-scatter} \;+\; \text{all-gather}
$$

$$
\text{all-gather} \;=\; \text{gather} \;+\; \text{broadcast}
$$

$$
\text{reduce} \;=\; \text{reduce-scatter} \;+\; \text{gather}
$$
| 原语 | 每卡流量(带宽最优实现) | 结果位置 | 归约? |
|---|---|---|---|
| broadcast | $\approx N$ | 所有 rank | ✗ |
| scatter | $N(p-1)/p$(from src) | 分散 | ✗ |
| gather | $N(p-1)/p$(to dst) | 单 rank | ✗ |
| reduce | $\approx N$ | 单 rank | ✓ |
| all-gather | $N(p-1)/p$ | 所有 rank | ✗ |
| reduce-scatter | $N(p-1)/p$ | 分散 | ✓ |
| all-reduce | $2N(p-1)/p$ | 所有 rank | ✓ |
| all-to-all | $N(p-1)/p$ | 重排 | ✗ |
| 场景 | 原语 | 数学形式 |
|---|---|---|
| DP 梯度同步 | all-reduce | $\bar g = \frac{1}{p}\sum_i g_i$ |
| TP o_proj / down_proj 出口 | all-reduce | $y = \sum_i x_i W_i^\top$ |
| TP VocabParallelEmbedding | all-reduce | $\sum_i \text{mask}_i \odot E_i(x)$ |
| TP ParallelLMHead 收 logits | gather | logits 汇到 rank 0 |
| Sequence Parallel Attn 入口 | all-gather | 把切 seq 的激活拼回 |
| Sequence Parallel Attn 出口 | reduce-scatter | 拼回后的激活归约再切 |
| FSDP 前向 unshard | all-gather | $W = [W_0, \dots, W_{p-1}]$ |
| FSDP 反向 re-shard grad | reduce-scatter | $g_i = \sum_r g_i^{(r)}$(rank $i$ 只保留第 $i$ 个分片的梯度和) |
| MoE dispatch / combine | all-to-all × 2 | token ↔ expert 两次转置 |
| Pipeline 前向传激活 | send / recv | $h_\ell$ 从 rank $\ell$ 到 $\ell+1$ |