# nanovllm 详解

# layers

# SiluAndMul

siluandmul
1
2
3
4
5
6
7
8
9
10
class SiluAndMul(nn.Module):

def __init__(self):
super().__init__()

@torch.compile
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, y = x.chunk(2, -1) # 沿最后一维等分成两份
return F.silu(x) * y # 点积
# 用一个linear 代替两个,[hidden_size, intermediate_dim * 2]

silu=xσ(x)silu = x \cdot \sigma(x) , 其中 σ(x)=11+ex\sigma(x)=\frac{1}{1+e^{-x}}

Gate linear Unit: GLU(x,W1,W2)=σ(xW1)(xW2)\text{Gate linear Unit: } GLU(x, W_1, W_2)=\sigma(x W_1) \odot (x W_2)
\odot 表示逐元素乘法
SwiGlu(x,W1,W2,W3)=W2(Silu(xW1))(xW3)SwiGlu(x, W_1,W_2, W_3)=W_2(Silu(x W_1)) \odot (x W_3)
其中,W2W_2 是额外的线性变换矩阵

# LayerNorm

layernorm
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from torch import nn
class RMSNorm(nn.Module):

def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))

@torch.compile
def rms_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.float()
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
# 带_的函数表示原地操作,节省内存
x = x.to(orig_dtype).mul_(self.weight)
return x

@torch.compile
def add_rms_forward(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x, residual

def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None:
return self.rms_forward(x)
else:
return self.add_rms_forward(x, residual)

root mean square normalization
RMSNorm(ai)=ai1di=1dai2+ϵgiRMSNorm(a_i) = \frac{a_i}{\sqrt{\frac{1}{d} \sum_{i=1}^{d} a_i^2 + \epsilon}} \odot g_i
其中,aia_i 是输入向量的第 ii 个元素,dd 是输入向量的维度,ϵ\epsilon 是一个小常数,用于防止除以零,gig_i 是可学习的权重参数,共有 dd 个。

而 LayerNorm 的公式为:
LayerNorm(ai)=aiμσ2+ϵgi+biLayerNorm(a_i) = \frac{a_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot g_i + b_i
其中,μ\mu 是输入向量的均值,σ2\sigma^2 是输入向量的方差,gig_ibib_i 是可学习的权重和偏置参数。

rmsnorm 计算量少,可学习参数也少,同时避免均值归一化导致的梯度消失

# Linear

linear
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist


def divide(numerator, denominator):
assert numerator % denominator == 0
return numerator // denominator


class LinearBase(nn.Module):

def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
tp_dim: int | None = None,
):
super().__init__()
self.tp_dim = tp_dim # 张量并行的维度,0表示输出维度,1表示输入维度
self.tp_rank = dist.get_rank() # 当前显卡编号
self.tp_size = dist.get_world_size() # 张量并行的总显卡数量
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(output_size)) # 许多实现中没有bias,这个算可选项
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)

def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

# 全量复制线性层
class ReplicatedLinear(LinearBase):

def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size, bias)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)

# 列切分线性层,输出维度切分
class ColumnParallelLinear(LinearBase):

def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
tp_size = dist.get_world_size()
super().__init__(input_size, divide(output_size, tp_size), bias, 0)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(self.tp_dim)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
param_data.copy_(loaded_weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)


class MergedColumnParallelLinear(ColumnParallelLinear):

def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = False,
):
self.output_sizes = output_sizes
super().__init__(input_size, sum(output_sizes), bias)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
param_data = param.data
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
param_data.copy_(loaded_weight)


class QKVParallelLinear(ColumnParallelLinear):

def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: int | None = None,
bias: bool = False,
):
tp_size = dist.get_world_size()
total_num_kv_heads = total_num_kv_heads or total_num_heads
self.head_size = head_size
self.num_heads = divide(total_num_heads, tp_size)
self.num_kv_heads = divide(total_num_kv_heads, tp_size)
output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
super().__init__(hidden_size, output_size, bias)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
param_data = param.data
assert loaded_shard_id in ["q", "k", "v"]
if loaded_shard_id == "q":
shard_size = self.num_heads * self.head_size
shard_offset = 0
elif loaded_shard_id == "k":
shard_size = self.num_kv_heads * self.head_size
shard_offset = self.num_heads * self.head_size
else:
shard_size = self.num_kv_heads * self.head_size
shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
param_data.copy_(loaded_weight)


class RowParallelLinear(LinearBase):

def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
tp_size = dist.get_world_size()
super().__init__(divide(input_size, tp_size), output_size, bias, 1)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(self.tp_dim)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
param_data.copy_(loaded_weight)
# 偏置只在tp_rank=0的显卡上计算一次,其他显卡上不计算偏置,最终结果通过通信汇总
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
if self.tp_size > 1: # 多卡时进行通信汇总
dist.all_reduce(y)
return y

此处是张量并行的核心实现,dist 是 PyTorch 的分布式通信包。

linear
1
2
3
4
5
6
def linear(input, weight, bias=None):
# 数学公式:y = x @ W.T + b
output = input.matmul(weight.t())
if bias is not None:
output += bias
return output

线性层的数学定义: y=xWT+by = x W^T + b
其中, x:[N,In_dim],W:[Out_dim,In_dim]x: [N, In\_dim], W: [Out\_dim, In\_dim]
其中输入张量是行存储的,权重矩阵 W 的存储是转置的
例如:

[x1x2x3]1×3    [w11w12w21w22w31w32]in dimout dim\underbrace{ \begin{bmatrix} x_1 & x_2 & x_3 \end{bmatrix} }_{1\times 3} \;\cdot\; \underbrace{ \begin{bmatrix} w_{11} & w_{12} \\ w_{21} & w_{22} \\ w_{31} & w_{32} \end{bmatrix}}_{\text{in dim}}^{\text{out dim}}

对于矩阵 W 而言,计算时需要逐列读取,对 cache 不友好,故转置存储

[x1x2x3]1×3[w11w21w31w12w22w32]in dim}  out dim\underbrace{\begin{bmatrix} x_1 & x_2 & x_3 \end{bmatrix}}_{1\times 3} \cdot \overbrace{ \begin{bmatrix} w_{11} & w_{21} & w_{31} \\ w_{12} & w_{22} & w_{32} \end{bmatrix}^\top }^{\text{in dim}} \bigg\}\;\text{out dim}


  • ColumnParallelLinear 是对 out 维度进行拆分

W:(out,in):[40961024]x:(B,S,in):[8,16,1024]tp_size=4,tp_dim=0W=[w0w3]}1024×4xWx[w0    w3]concat(y0,,y3)W: (out, in):[4096, 1024] \\ x: (B, S, in): [8, 16, 1024] \\ tp\_size= 4, tp\_dim = 0 \\ W = \left. \left[ \begin{array}{c} w_0 \\ \vdots \\ w_3 \end{array} \right]\right\}1024 \times 4 \\ x * W^\top \quad \rightarrow \quad x \cdot \left[ w_0^\top \; \cdots \; w_3^\top \right] \quad \rightarrow \quad \mathrm{concat}(y_0, \dots, y_3)

是先切分,再每块单独计算,每块持有自己的计算结果,等到 rowparallel 线性层进行聚合。输入 xx 不切分,会传递到每一个 rank 中


  • RowParallelLinear 是对 in 维度进行拆分,也需要拆分 xx

W:(out,in):[40961024]x:(B,S,in):[8,16,1024]tp_size=4,tp_dim=1W=[w0w1w2w3]2564W: (out, in):[4096, 1024] \\ x: (B, S, in): [8, 16, 1024] \\ tp\_size= 4, tp\_dim = 1 \\ W = \underbrace{ \left[ \begin{array}{c|c|c|c} w_0 & w_1 & w_2 & w_3 \end{array} \right]}_{256 * 4}

[x0x1x2x3]BS[4256]\underbrace{ \left[ \begin{array}{c|c|c|c} x_0 & x_1 & x_2 & x_3 \end{array} \right]}_{B * S * [4 * 256]}

[x0x1x2x3]BS4256[w0w1w2w3]4×[2564096]=x0w0+x1w1+x2w2+x3w3BS4096\underbrace{ \begin{bmatrix} x_0 & x_1 & x_2 & x_3 \end{bmatrix}}_{B * S * 4 * 256} \cdot \underbrace{ \begin{bmatrix} w_0^\top \\ w_1^\top \\ w_2^\top \\ w_3^\top \end{bmatrix}}_{4 \times [256 * 4096]} =\underbrace{ x_0 w_0^\top + x_1 w_1^\top + x_2 w_2^\top + x_3 w_3^\top}_{B * S * 4096}


  • MergedColumnParallelLinear 是 ColumnParallelLinear 的变体,支持一次加载多个权重块,减少通信开销
  • QKVParallelLinear 是针对自注意力机制中查询、键、值矩阵的特殊线性层,支持同时加载多个权重块,并且根据块的类型进行不同的切分和计算
  • 注意类之间的继承关系与方法绑定

注意,attention 层是拆分计算 qkv,然后在 o 层进行合并计算。而 ffn 是层拆分 gate 和 up,最后在 down 层进行合并计算
则仅有 down 和 o 层需要通信汇总,使用 rowparallel 线性层,需要 all_reduce

此处的 dist.all_reduce(y) 默认情况下是求和操作。对于 4 卡而言,每张卡持有自己的 yiy_i , 通过 all_reduce 后,每张卡上的 yy 都是 y0+y1+y2+y3y_0 + y_1 + y_2 + y_3,实现了结果的汇总。

不同通信原语的功能讲解放在后文

# embed_head

embed_head
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from nanovllm.utils.context import get_context


class VocabParallelEmbedding(nn.Module):

def __init__(
self,
num_embeddings: int,
embedding_dim: int,
):
super().__init__()
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
assert num_embeddings % self.tp_size == 0
self.num_embeddings = num_embeddings
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
self.weight.weight_loader = self.weight_loader

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(0)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
param_data.copy_(loaded_weight)

def forward(self, x: torch.Tensor):
if self.tp_size > 1:
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
x = mask * (x - self.vocab_start_idx)
y = F.embedding(x, self.weight)
if self.tp_size > 1:
y = mask.unsqueeze(1) * y # 传进来的内容被 flatten 过,是一维的,所以这里等价于-1
dist.all_reduce(y)
return y


class ParallelLMHead(VocabParallelEmbedding):

def __init__(
self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
):
assert not bias
super().__init__(num_embeddings, embedding_dim)

def forward(self, x: torch.Tensor):
context = get_context()
if context.is_prefill:
last_indices = context.cu_seqlens_q[1:] - 1
x = x[last_indices].contiguous()
logits = F.linear(x, self.weight)
if self.tp_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
dist.gather(logits, all_logits, 0) # 将每张卡上的 logits 收集到 rank 0 的卡上
# 沿着最后一维拼接所有卡上的 logits,得到完整的 vocab 大小的 logits
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
return logits
embedding
1
2
3
4
5
6
7
8
def embedding(indices, weight):
# indices: 任意形状 [B, S]
# weight: [num_embeddings, embedding_dim]

output = []
for idx in indices.flatten():
output.append(weight[idx]) # 取第 idx 行
return torch.stack(output).reshape(*indices.shape, -1)
  • VocabParallelEmbedding

每张卡仅处理自己的那部分 token_id, 也只拿自己的那部分权重。输入是 token_id,输出是对应的 embedding 向量。

这里要注意的是 forward 的过程
当检查到 tp_size > 1 时,首先创建一个 mask,标记输入的 token_id 是否在当前卡负责的范围内。
对于不在范围内的 token_id,mask 是 False,对应的 embedding 输出也应该是 0。
对于在范围内的 token_id,输入 x 需要减去 vocab_start_idx ,以映射到当前卡的 embedding 权重索引范围。
而后执行 F.embedding(x, self.weight) 。注意此处输入 x 的尺寸
nanovllm 是把输入展平了,为了方便理解,先按不展平看来,输入 x 的尺寸是 [B, S] ,输出 y 的尺寸就会是 [B, S, embedding_dim]
虽然 self.weight 的尺寸是 [num_embeddings_per_partition, embedding_dim] ,但由于输入 x 已经被映射到当前卡负责的 token_id 范围内,所以 F.embedding 会正确地返回对应的 embedding 向量。

而后,不展平情况下, mask 的尺寸是 [B, S] ,需要通过 mask.unsqueeze(-1) 将其扩展为 [B, S, 1] ,然后经过广播机制与 y 的尺寸 [B, S, embedding_dim] 进行逐元素乘法,最终得到的 y 中只有当前卡负责的 token_id 的 embedding 向量是非零的。

但如果展平了,输入 x 的尺寸是 [B*S] ,输出 y 的尺寸是 [B*S, embedding_dim] ,此时 mask 的尺寸是 [B*S] ,一维向量,那么 mask.unsqueeze(1) 的尺寸是 [B*S, 1] ,同样可以与 y 的尺寸 [B*S, embedding_dim] 进行逐元素乘法,达到同样的效果。

计算结束后执行 dist.all_reduce(y) ,将所有卡上的 y 进行求和汇总,最终每张卡上的 y 都包含了所有 token_id 的 embedding 向量。这里依然是求和,因为非当前卡负责的 token_id 的 embedding 向量是 0,所以求和后不会改变结果。


  • ParallelLMHead
    和 VocabParallelEmbedding 类似,但它是输出层,输入是隐藏状态,输出是 vocab 大小的 logits。
    prefill 时:x.shape = [total_tokens, hidden_dim]
    decode 时:x.shape = [num_seqs, hidden_dim]
    W_full.shape = [Vocab, hidden_dim]
    cu_seqlens_q = [0, 3, 5],是每个 seq 的 length 累计长

计算时,每张卡拿到的权重为 W_part.shape = [vocab_per_partition, hidden_dim]
每张卡只负责计算自己那一段 vocab 的 logits,
logits.shape = [N, vocab_per_partition]

每张卡获得自己的部分 logits 后,使用 dist.gather 将所有卡上的 logits 收集到 rank 0 的卡上,最终在 rank 0 上将这些部分 logits 拼接成完整的 vocab 大小的 logits。

# rotary_embedding

rotary_embedding
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from functools import lru_cache
import torch
from torch import nn

def apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
y1 = x1 * cos - x2 * sin
y2 = x2 * cos + x1 * sin
return torch.cat((y1, y2), dim=-1).to(x.dtype)


class RotaryEmbedding(nn.Module):

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
) -> None:
super().__init__()
self.head_size = head_size
assert rotary_dim == head_size
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq) # 外积,得到一个 [max_position_embeddings, rotary_dim//2] 的矩阵
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
self.register_buffer("cos_sin_cache", cache, persistent=False)

@torch.compile
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query = apply_rotary_emb(query, cos, sin)
key = apply_rotary_emb(key, cos, sin)
return query, key


@lru_cache(1)
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: float,
rope_scaling: dict | None = None,
):
assert rope_scaling is None
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
return rotary_emb

回忆旋转角公式:

设原向量为 v=(x,y)\boldsymbol{v}=(x,y),它的长度为 r=x2+y2r=\sqrt{x^2+y^2},与 xx 轴的夹角为 α\alpha,所以它的极坐标形式是:
x=rcosα,y=rsinαx=r\cos\alpha, \quad y=r\sin\alpha

当我们把它逆时针旋转 θ\theta 角后,新的夹角是 α+θ\alpha+\theta,新坐标 (x,y)(x',y') 满足:

{x=rcos(α+θ)y=rsin(α+θ)\begin{cases} x' = r\cos(\alpha+\theta) \\ y' = r\sin(\alpha+\theta) \end{cases}

用三角函数的和角公式展开:

x=r(cosαcosθsinαsinθ)=(rcosα)cosθ(rsinα)sinθ=xcosθysinθy=r(sinαcosθ+cosαsinθ)=(rsinα)cosθ+(rcosα)sinθ=xsinθ+ycosθ\begin{aligned} x' &= r(\cos\alpha\cos\theta - \sin\alpha\sin\theta) = (r\cos\alpha)\cos\theta - (r\sin\alpha)\sin\theta = x\cos\theta - y\sin\theta \\ y' &= r(\sin\alpha\cos\theta + \cos\alpha\sin\theta) = (r\sin\alpha)\cos\theta + (r\cos\alpha)\sin\theta = x\sin\theta + y\cos\theta \end{aligned}

把这组线性关系写成矩阵乘法,就是:

(xy)=(cosθsinθsinθcosθ)(xy)\begin{pmatrix} x' \\ y' \end{pmatrix} = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix} \begin{pmatrix} x \\ y \end{pmatrix}

对于第 k 个二维子空间,在位置 i 上的旋转矩阵为:

Rki=[cosθi,ksinθi,ksinθi,kcosθi,k]R_k^i = \begin{bmatrix} \cos \theta_{i,k} & -\sin \theta_{i,k} \\ \sin \theta_{i,k} & \cos \theta_{i,k} \end{bmatrix}

其中旋转角定义为:

θi,k=iωkwithωk=1base2k/d\theta_{i,k} = i \cdot \omega_k \qquad\text{with}\qquad \omega_k = \frac{1}{\text{base}^{2k/d}}

等价地,

θi,k=ibase2k/d\theta_{i,k} = \frac{i}{\text{base}^{2k/d}}

  • base :RoPE 的底数
  • i :序列位置(seq index)
  • k :频率通道 / 二维块编号
  • d :rotary dimension

R 中的每个方块都是上述的一个二维方阵

Rx:[](x0x1xd1)=(x0x1x2x3xd1)(cosθ0cosθ0cosθ1cosθ1cosθd/21)+(x1x0x3x2xd2)(sinθ0sinθ0sinθ1sinθ1sinθd/21)R \cdot x : \begin{bmatrix} \square & & & \\ & \square & & \\ & & \ddots & \\ & & & \square \end{bmatrix} \begin{pmatrix} x_0 \\ x_1 \\ \vdots \\ x_{d-1} \end{pmatrix} =\begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ \vdots \\ x_{d-1} \end{pmatrix} \odot \begin{pmatrix} \cos\theta_0 \\ \cos\theta_0 \\ \cos\theta_1 \\ \cos\theta_1 \\ \vdots \\ \cos\theta_{d/2-1} \end{pmatrix} + \begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ \vdots \\ x_{d-2} \end{pmatrix} \odot \begin{pmatrix} \sin\theta_0 \\ \sin\theta_0 \\ \sin\theta_1 \\ \sin\theta_1 \\ \vdots \\ \sin\theta_{d/2-1} \end{pmatrix}

但注意代码的实现,并不是原始的 rope,而是向量 x 的前半部分乘以 cos,后半部分乘以 sin,然后再进行线性组合。

只对 qk 做 rope

rope 的尺寸变化:
[max_pos, rotary_dim] -> [max_pos, 1, rotary_dim] -> [max_pos, head_num, head_size], 这里的 unsqueeze 是为了适配多头。但不适配 [batch, num_heads, seq_len, head_dim] 这种形状的多头,注意尺寸定义

# attention

attention
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
@triton.jit
def store_kvcache_kernel(
key_ptr, # 指针
key_stride, # 每个 token 的 key 在内存中的跨度,单位是元素个数
value_ptr,
value_stride,
k_cache_ptr,
v_cache_ptr,
slot_mapping_ptr, # 传的是个张量
D: tl.constexpr, # 常量
):
idx = tl.program_id(0) # 每个线程获取自己的 idx,idx 的范围是 [0, N)
slot = tl.load(slot_mapping_ptr + idx) # slot = slot_mapping[idx]
if slot == -1: return
key_offsets = idx * key_stride + tl.arange(0, D)
value_offsets = idx * value_stride + tl.arange(0, D)
key = tl.load(key_ptr + key_offsets)
value = tl.load(value_ptr + value_offsets)
cache_offsets = slot * D + tl.arange(0, D)
tl.store(k_cache_ptr + cache_offsets, key)
tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
N, num_heads, head_dim = key.shape
D = num_heads * head_dim
assert key.stride(-1) == 1 and value.stride(-1) == 1 # 每个 token 的 key 和 value 在内存中是连续存储的
assert key.stride(1) == head_dim and value.stride(1) == head_dim # 每个 head 之间的跨度是 head_dim
assert k_cache.stride(1) == D and v_cache.stride(1) == D # k_cache 和 v_cache 中每个 slot 之间的跨度是 D
assert slot_mapping.numel() == N # slot_mapping 中每个 token 对应一个 slot
# 调用方式是 func[元组参数](普通参数),元组参数表示开启多少个线程。可以是多维的
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)


两个 kvcache 代码要一起看:
stride() 函数会获取对应维度在内存中的跨度,单位是元素个数。stride (-1) 获取最后一个维度的跨度,连续存储的话就是 1。stride (x) 一般就是 x+1 的 size

slot_mapping 是一个长度为 N 的张量,记录了每个 token 对应的 slot 编号。slot 的范围是 [0, num_slots),其中 num_slots 是 k_cache 和 v_cache 中可用的 slot 数量。对于需要缓存的 token,slot_mapping 中对应位置的值是一个非负整数,表示该 token 应该存储到 k_cache 和 v_cache 中的哪个 slot;对于不需要缓存的 token,slot_mapping 中对应位置的值是 -1。

在 kernel 函数中,每个线程根据自己的 idx 从 slot_mapping 中获取对应的 slot 编号。

key_offsets = idx * key_stride + tl.arange(0, D) 这里的 kv stride 就是每个 token 的内存跨度,num_heads * head_dim = D

同样的方法获取 cache 中的内存偏移。每个 token 的跨度依然是 D

kernel 函数把 kv 存放进 cache 里

[q1q2q3][k1,  k2,  k3][v1v2v3]=[q1k1q1k2q1k3q2k1q2k2q2k3q3k1q3k2q3k3][v1v2v3]\begin{bmatrix} q_1 \\ q_2 \\ q_3 \end{bmatrix} \begin{bmatrix} k_1^\top,\; k_2^\top,\; k_3^\top \end{bmatrix} \begin{bmatrix} v_1 \\ v_2 \\ v_3 \end{bmatrix} =\begin{bmatrix} q_1 k_1^\top & q_1 k_2^\top & q_1 k_3^\top \\ q_2 k_1^\top & q_2 k_2^\top & q_2 k_3^\top \\ q_3 k_1^\top & q_3 k_2^\top & q_3 k_3^\top \end{bmatrix} \begin{bmatrix} v_1 \\ v_2 \\ v_3 \end{bmatrix}

=[q1k1v1+q1k2v2+q1k3v3q2k1v1+q2k2v2+q2k3v3q3k1v1+q3k2v2+q3k3v3]=\begin{bmatrix} q_1 k_1^\top v_1 + q_1 k_2^\top v_2 + q_1 k_3^\top v_3 \\ q_2 k_1^\top v_1 + q_2 k_2^\top v_2 + q_2 k_3^\top v_3 \\ q_3 k_1^\top v_1 + q_3 k_2^\top v_2 + q_3 k_3^\top v_3 \end{bmatrix}

=[q1q2q3](k1v1+k2v2+k3v3)=\begin{bmatrix} q_1 \\ q_2 \\ q_3 \end{bmatrix} \left( k_1^\top v_1 + k_2^\top v_2 + k_3^\top v_3 \right)

每加入一个新的 token,则引入新的 qkv,新的 q 会和之前所有的 kv 进行计算,所以需要 kvcache

attention
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Attention(nn.Module):

def __init__(
self,
num_heads,
head_dim,
scale,
num_kv_heads,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])

def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.block_tables is not None: # prefix cache
k, v = k_cache, v_cache
# flash attention 变长版本,可以同时传输不同 batch,根据 cu_seqlens 来区分每个 batch 的实际长度
# block_table 是为了支持 prefix cache 的,是映射表
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else: # decode q 变成[B, 1, H, D],因为新来的只有一个token
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
return o

初始使用空张量给 k_cache 和 v_cache 占位,后续在 model_runner 中进行替换绑定

两种 flash attention,prefill 时使用 flash_attn_varlen_func,支持变长输入,输入的 qkv 都是展平的,使用 cu_seqlens 来区分每个 batch 的实际长度
decode 时使用 flash_attn_with_kvcache,输入的 q 是 [B, 1, H, D] 的形状,因为每次只处理一个新 token,kv 则直接使用 k_cache 和 v_cache

# sampler

sampler
1
2
3
4
5
6
7
8
9
10
11
class Sampler(nn.Module):

def __init__(self):
super().__init__()

@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1)
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1) # argmax 采样,得到id而不是最大值本身
return sample_tokens

首先将 logits 除以 temperature,得到调整后的 logits。
然后对调整后的 logits 进行 softmax,得到每个 token 的概率分布。

sampler
1
2
3
noise = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
scores = probs / noise
sample_tokens = scores.argmax(dim=-1)

采样流程等价于先构造一个与 probs 同形状的噪声张量,噪声服从指数分布(exponential distribution),然后将 probs 除以这个噪声,得到一个新的分数张量 scores,最后在 scores 上取 argmax,得到采样的 token id。

这种方法和传统的 multinomial 采样类似,都是循概率采样,但对 gpu 更友好,全是单次元素操作和 argmax,且支持 torch.compile 优化,而 multinomial 是控制流加顺序扫描,不太适合 gpu 和 torch.compile

EiExp(1)E_i \sim \text{Exp}(1)(独立同分布标准指数分布),则:

argmaxipiEi分类分布\arg\max_i \frac{p_i}{E_i} \sim \text{分类分布}

P(argmaxjpjEj=i)=piP\left( \arg\max_j \frac{p_j}{E_j} = i \right) = p_ipip_i 为采样概率,满足 ipi=1\sum_i p_i = 1)。

EExp(λ)E \sim \text{Exp}(\lambda),有:

P(E>t)=eλtP(E > t) = e^{-\lambda t}

标准指数分布(λ=1\lambda=1)简化为:

P(E>t)=etP(E > t) = e^{-t}

argmaxjpjEj=i\arg\max_j \frac{p_j}{E_j} = i 等价于:

piEi>pjEj(ji)    Ej>pjpiEi(ji)\frac{p_i}{E_i} > \frac{p_j}{E_j} \quad (\forall j \neq i) \implies E_j > \frac{p_j}{p_i} E_i \quad (\forall j \neq i)

EjE_j 相互独立,联合概率可拆分为乘积形式:

\begin{align*} P\left( \arg\max_j \frac{p_j}{E_j} = i \right) &= \int_0^\infty f_{E_i}(t) \prod_{j \neq i} P\left( E_j > \frac{p_j}{p_i} t \right) dt \end{align*}

其中:

  • fEi(t)=etf_{E_i}(t) = e^{-t}(标准指数分布概率密度)
  • P(Ej>pjpit)=epjpitP\left( E_j > \frac{p_j}{p_i} t \right) = e^{-\frac{p_j}{p_i} t}(代入指数分布性质)

代入并化简指数项(利用 jpj=1\sum_j p_j = 1):

\begin{align*} \text{原式} &= \int_0^\infty e^{-t} \cdot \prod_{j \neq i} e^{-\frac{p_j}{p_i} t} dt \\ &= \int_0^\infty e^{-t \left( 1 + \sum_{j \neq i} \frac{p_j}{p_i} \right)} dt \\ &= \int_0^\infty e^{-\frac{t}{p_i}} dt = p_i \end{align*}

# qwen3

qwen3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
class Qwen3Attention(nn.Module):

def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: int | None = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
rope_theta: float = 10000,
rope_scaling: tuple | None = None,
) -> None:
super().__init__()
tp_size = dist.get_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim ** -0.5
self.qkv_bias = qkv_bias

self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
self.num_kv_heads,
)
if not self.qkv_bias:
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
if not self.qkv_bias:
q = self.q_norm(q)
k = self.k_norm(k)
q, k = self.rotary_emb(positions, q, k)
o = self.attn(q, k, v)
output = self.o_proj(o.flatten(1, -1)) # 把第一维到最后一维压平
return output


class Qwen3MLP(nn.Module):

def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
)
assert hidden_act == "silu"
self.act_fn = SiluAndMul()

def forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.down_proj(x)
return x


class Qwen3DecoderLayer(nn.Module):

def __init__(
self,
config: Qwen3Config,
) -> None:
super().__init__()
self.self_attn = Qwen3Attention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
max_position=config.max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, 'attention_bias', True),
head_dim=getattr(config, 'head_dim', None),
rope_theta=getattr(config, "rope_theta", 1000000),
rope_scaling=getattr(config, "rope_scaling", None),
)
self.mlp = Qwen3MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions, hidden_states)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual


class Qwen3Model(nn.Module):

def __init__(
self,
config: Qwen3Config,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


class Qwen3ForCausalLM(nn.Module):
packed_modules_mapping = {
"q_proj": ("qkv_proj", "q"),
"k_proj": ("qkv_proj", "k"),
"v_proj": ("qkv_proj", "v"),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

def __init__(
self,
config: Qwen3Config
) -> None:
super().__init__()
self.model = Qwen3Model(config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if config.tie_word_embeddings:
self.lm_head.weight.data = self.model.embed_tokens.weight.data

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
return self.model(input_ids, positions)

def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
return self.lm_head(hidden_states)
loader
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)


def load_model(model: nn.Module, path: str):
packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
for file in glob(os.path.join(path, "*.safetensors")):
with safe_open(file, "pt", "cpu") as f:
for weight_name in f.keys():
for k in packed_modules_mapping:
if k in weight_name:
v, shard_id = packed_modules_mapping[k]
param_name = weight_name.replace(k, v)
param = model.get_parameter(param_name)
weight_loader = getattr(param, "weight_loader")
weight_loader(param, f.get_tensor(weight_name), shard_id)
break
else:
param = model.get_parameter(weight_name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, f.get_tensor(weight_name))

# engine

pipline
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
LLM.generate()
-> LLMEngine.add_request()
-> Scheduler.add()
-> while not finished:
LLMEngine.step()
-> Scheduler.schedule()
-> ModelRunner.run()
-> prepare_prefill() / prepare_decode()
-> model forward
-> sampler
-> Scheduler.postprocess()
即:
用户 prompt
-> Sequence
-> Scheduler 决定本轮跑哪些请求
-> BlockManager 分配 KV cache block
-> ModelRunner 准备张量并调用模型
-> Sampler 采样新 token
-> Scheduler 更新状态
-> 完成后 decode 回文本

以下按照 sequence -> block manager -> scheduler -> model runner -> llm engine 的顺序介绍。

# sequence

sequence
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# 定义一个枚举类SequenceStatus,用于表示序列的状态
class SequenceStatus(Enum):
WAITING = auto() # 等待状态
RUNNING = auto() # 运行状态
FINISHED = auto() # 完成状态


# 定义Sequence类,用于表示一个文本生成序列(如对话或文本生成任务中的token序列)
class Sequence:
block_size = 256 # 定义块大小为256,用于将token序列分块处理
counter = count() # 类级别的计数器,用于生成唯一的序列ID

def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
# 初始化序列ID,使用计数器的下一个值
self.seq_id = next(Sequence.counter)
# 初始状态设为等待
self.status = SequenceStatus.WAITING
# 复制输入的token_ids列表,避免外部修改影响内部
self.token_ids = copy(token_ids)
# 记录最后一个token(初始为输入序列的最后一个token)
self.last_token = token_ids[-1]
# 记录token总数(初始为输入序列的长度)
self.num_tokens = len(self.token_ids)
# 记录提示词(prompt)的token数量(初始等于输入序列长度,因为输入通常是提示词)
self.num_prompt_tokens = len(token_ids)
# 记录已缓存的token数量(初始为0)
self.num_cached_tokens = 0
# 块表,用于记录序列对应的缓存块信息
self.block_table = []
# 从采样参数中获取温度参数(用于控制生成的随机性)
self.temperature = sampling_params.temperature
# 从采样参数中获取最大生成token数
self.max_tokens = sampling_params.max_tokens
# 从采样参数中获取是否忽略结束符(eos)的标志
self.ignore_eos = sampling_params.ignore_eos

# 定义__len__方法,使得可以用len(sequence)获取token总数
def __len__(self):
return self.num_tokens

# 定义__getitem__方法,使得可以用sequence[key]获取指定位置的token
def __getitem__(self, key):
return self.token_ids[key]

# 定义属性is_finished,判断序列是否已完成(状态为FINISHED)
@property
def is_finished(self):
return self.status == SequenceStatus.FINISHED

# 定义属性num_completion_tokens,计算生成的补全(completion)token数量(总token数减去提示词token数)
@property
def num_completion_tokens(self):
return self.num_tokens - self.num_prompt_tokens

# 定义属性prompt_token_ids,获取提示词部分的token列表
@property
def prompt_token_ids(self):
return self.token_ids[:self.num_prompt_tokens]

# 定义属性completion_token_ids,获取补全部分的token列表
@property
def completion_token_ids(self):
return self.token_ids[self.num_prompt_tokens:]

# 定义属性num_cached_blocks,计算已缓存的块数量(已缓存token数除以块大小,整数除法)
@property
def num_cached_blocks(self):
return self.num_cached_tokens // self.block_size

# 定义属性num_blocks,计算总块数量(总token数除以块大小,向上取整)
@property
def num_blocks(self):
return (self.num_tokens + self.block_size - 1) // self.block_size

# 定义属性last_block_num_tokens,计算最后一个块中的token数量
@property
def last_block_num_tokens(self):
return self.num_tokens - (self.num_blocks - 1) * self.block_size

# 定义方法block,获取第i个块的token列表
def block(self, i):
# 断言i的范围有效(0 <= i < 总块数)
assert 0 <= i < self.num_blocks
# 返回第i个块的token(从i*block_size到(i+1)*block_size的切片)
return self.token_ids[i*self.block_size: (i+1)*self.block_size]

# 定义方法append_token,向序列添加一个新token
def append_token(self, token_id: int):
self.token_ids.append(token_id) # 添加到token列表
self.last_token = token_id # 更新最后一个token
self.num_tokens += 1 # 总token数加1

# 定义__getstate__方法,用于序列化对象时获取状态(配合pickle使用)
def __getstate__(self):
# 返回需要保存的状态:总token数、提示词token数、已缓存token数、块表,
# 以及token_ids(如果补全token数为0)或最后一个token(如果有补全token)
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
self.token_ids if self.num_completion_tokens == 0 else self.last_token)

# 定义__setstate__方法,用于反序列化时恢复对象状态(配合pickle使用)
def __setstate__(self, state):
# 从状态中恢复总token数、提示词token数、已缓存token数、块表
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
# 根据补全token数恢复token_ids或last_token
if self.num_completion_tokens == 0:
self.token_ids = state[-1]
else:
self.last_token = state[-1]

# block manager

block_manager
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# 定义Block类,用于表示一个块(存储token序列的基本单位)
class Block:

def __init__(self, block_id):
# 块的唯一标识符
self.block_id = block_id
# 引用计数(表示有多少个序列引用该块)
self.ref_count = 0
# 块的哈希值(用于缓存查找)
self.hash = -1
# 块中存储的token_id列表
self.token_ids = []

def update(self, hash: int, token_ids: list[int]):
# 更新块的哈希值和token_id列表
self.hash = hash
self.token_ids = token_ids

def reset(self):
# 重置块的状态:引用计数设为1,哈希值设为-1,清空token_id列表
self.ref_count = 1
self.hash = -1
self.token_ids = []


# 定义BlockManager类,用于管理块的分配、释放和缓存
class BlockManager:

def __init__(self, num_blocks: int, block_size: int):
# 每个块的大小(能存储的token数量)
self.block_size = block_size
# 初始化块列表:创建num_blocks个Block实例
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
# 哈希值到块ID的映射(用于快速查找缓存的块)
self.hash_to_block_id: dict[int, int] = dict()
# 空闲块ID的双端队列(用于高效分配和释放)
self.free_block_ids: deque[int] = deque(range(num_blocks))
# 已使用块ID的集合
self.used_block_ids: set[int] = set()

@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
# 类方法:计算token_ids的哈希值,可传入前缀哈希(用于链式计算)
h = xxhash.xxh64() # 创建64位哈希对象
if prefix != -1:
# 如果有前缀,先更新前缀的哈希(转为8字节小端字节序)
h.update(prefix.to_bytes(8, "little"))
# 将token_ids转为numpy数组的字节流并更新哈希
h.update(np.array(token_ids).tobytes())
return h.intdigest() # 返回哈希的整数形式

def _allocate_block(self, block_id: int) -> Block:
# 内部方法:分配指定ID的块(将其从空闲状态转为使用状态)
block = self.blocks[block_id]
# 断言:确保块当前是空闲的(引用计数为0)
assert block.ref_count == 0
block.reset() # 重置块状态
self.free_block_ids.remove(block_id) # 从空闲队列移除
self.used_block_ids.add(block_id) # 添加到已使用集合
return self.blocks[block_id] # 返回分配的块

def _deallocate_block(self, block_id: int) -> Block:
# 内部方法:释放指定ID的块(将其从使用状态转为空闲状态)
# 断言:确保块的引用计数已为0(没有序列引用)
assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id) # 从已使用集合移除
self.free_block_ids.append(block_id) # 添加到空闲队列

def can_allocate(self, seq: Sequence) -> bool:
# 检查是否可以为序列分配所需的块:空闲块数量 >= 序列需要的块数
return len(self.free_block_ids) >= seq.num_blocks

def allocate(self, seq: Sequence):
# 为序列分配块(填充其block_table)
# 断言:确保序列当前没有分配块
assert not seq.block_table
h = -1 # 用于存储上一个块的哈希(作为下一个块的前缀)
cache_miss = False # 标记是否发生缓存未命中
# 遍历序列需要的每个块
for i in range(seq.num_blocks):
# 获取当前块对应的token_ids
token_ids = seq.block(i)
# 计算当前块的哈希(如果是满块,使用上一个块的哈希作为前缀)
h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
# 尝试从哈希映射中查找块ID
block_id = self.hash_to_block_id.get(h, -1)
# 检查缓存是否命中:未找到ID或找到的块内容不匹配
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True
# 如果缓存未命中,从空闲块中分配新块
if cache_miss:
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
# 如果缓存命中
else:
# 增加序列的缓存token计数
seq.num_cached_tokens += self.block_size
# 如果块已在使用中,增加其引用计数
if block_id in self.used_block_ids:
block = self.blocks[block_id]
block.ref_count += 1
# 如果块是空闲的,重新分配它
else:
block = self._allocate_block(block_id)
# 如果哈希有效(当前是满块),更新块的哈希和token_ids,并更新哈希映射
if h != -1:
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
# 将块ID添加到序列的block_table中
seq.block_table.append(block_id)

def deallocate(self, seq: Sequence):
# 释放序列占用的块(减少引用计数,必要时回收)
# 反向遍历序列的block_table(从后往前释放)
for block_id in reversed(seq.block_table):
block = self.blocks[block_id]
block.ref_count -= 1 # 减少引用计数
# 如果引用计数为0,释放该块
if block.ref_count == 0:
self._deallocate_block(block_id)
# 重置序列的缓存token计数和block_table
seq.num_cached_tokens = 0
seq.block_table.clear()

def can_append(self, seq: Sequence) -> bool:
# 检查是否可以为序列追加一个token(需要时分配新块)
# 当序列长度模块大小为1时,说明需要新块(因为上一个块刚满)
# 空闲块数量 >= 是否需要新块
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)

def may_append(self, seq: Sequence):
# 处理序列追加token的逻辑(更新块或分配新块)
block_table = seq.block_table
# 获取最后一个块
last_block = self.blocks[block_table[-1]]
# 情况1:序列长度模块大小为1(需要新块,因为上一个块已满)
if len(seq) % self.block_size == 1:
# 断言:上一个块必须是满块(哈希有效)
assert last_block.hash != -1
# 从空闲块中分配新块并添加到block_table
block_id = self.free_block_ids[0]
self._allocate_block(block_id)
block_table.append(block_id)
# 情况2:序列长度模块大小为0(当前块刚填满,需要计算哈希)
elif len(seq) % self.block_size == 0:
# 断言:当前块之前未计算哈希(非满块状态)
assert last_block.hash == -1
# 获取当前块的token_ids
token_ids = seq.block(seq.num_blocks-1)
# 计算前缀哈希(上一个块的哈希,若存在)
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
# 计算当前块的哈希
h = self.compute_hash(token_ids, prefix)
# 更新块的哈希和token_ids,并更新哈希映射
last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id
# 其他情况:块未填满,无需特殊处理(哈希保持无效)
else:
assert last_block.hash == -1

block manager 负责管理块的分配和释放,以及序列的

这里的 reset () 是初始分配 block 用的,所以引用计数要置 1

注意 def allocate(self, seq: Sequence) 函数调用的 compute_hash 方法,他把上一个块的哈希作为前缀传入,保证了只有在前缀一致的情况下才会命中缓存,这样可以避免不同块内容相同但位置不同导致的哈希冲突问题。

sequence 新来时,调用 block_manager.allocate(sequence) ,会尝试为 sequence 的每个块分配缓存块,如果块内容和之前的块相同且位置相同,则命中缓存,直接复用块;否则分配新块并更新缓存。

使用 set 记录已使用的哈希值到块的映射。

cache 管理是 block 粒度的,只有 token 数量达到 block_size 时才计算哈希并尝试缓存。

释放时倒着释放,越靠后越不可能被其他序列共享,

allocate(seq)
用于 prefill
给整段 prompt 分配 block
顺便做 prefix cache 命中

may_append(seq)
用于 decode
每轮生成时维护 block_table
必要时追加新 block
-> 当 seq_len % block_size == 1 时,说明上一个 block 刚填满,还多了一个 token,需要新 block
当一个 block 填满时,把它登记进 hash cache

deallocate(seq)
用于请求结束或被抢占
释放该请求引用的 block

# scheduler

scheduler
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# 定义调度器类,负责管理序列的调度(预处理、解码、抢占等逻辑)
class Scheduler:

# 初始化调度器
def __init__(self, config: Config):
self.max_num_seqs = config.max_num_seqs # 最大并发序列数
self.max_num_batched_tokens = config.max_num_batched_tokens # 批处理的最大token数
self.eos = config.eos # 结束符token ID
# 初始化块管理器,参数为KV缓存块数量和每个块的大小
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
self.waiting: deque[Sequence] = deque() # 等待队列,存储待处理的序列
self.running: deque[Sequence] = deque() # 运行队列,存储正在处理的序列

# 判断所有序列是否都已处理完成(等待队列和运行队列都为空)
def is_finished(self):
return not self.waiting and not self.running

# 向等待队列添加一个序列
def add(self, seq: Sequence):
self.waiting.append(seq)

# 调度序列,返回本次调度的序列列表和是否为预处理阶段(prefill)的标志
def schedule(self) -> tuple[list[Sequence], bool]:
# 预处理阶段(prefill):处理新加入的序列,分配初始KV缓存
scheduled_seqs = [] # 存储本次调度的序列
num_seqs = 0 # 已调度的序列数
num_batched_tokens = 0 # 已调度的token总数

# 从等待队列取序列,直到达到最大序列数、超出最大token数或无法分配缓存块
while self.waiting and num_seqs < self.max_num_seqs:
seq = self.waiting[0] # 查看等待队列的第一个序列(不弹出)
# 检查是否超出最大token数,或块管理器无法为该序列分配初始缓存
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
break # 无法继续调度,退出循环
num_seqs += 1 # 增加已调度序列数
self.block_manager.allocate(seq) # 为序列分配KV缓存块
# 累加新增的token数(总长度减去已缓存的token数,即本次需要处理的新token)
num_batched_tokens += len(seq) - seq.num_cached_tokens
seq.status = SequenceStatus.RUNNING # 更新序列状态为运行中
self.waiting.popleft() # 从等待队列移除该序列
self.running.append(seq) # 添加到运行队列
scheduled_seqs.append(seq) # 添加到本次调度列表

# 如果预处理阶段有调度到序列,返回这些序列和True(表示是预处理)
if scheduled_seqs:
return scheduled_seqs, True

# 解码阶段(decode):处理已在运行队列中的序列,生成下一个token
# 从运行队列取序列,直到达到最大序列数
while self.running and num_seqs < self.max_num_seqs:
seq = self.running.popleft() # 弹出运行队列的第一个序列

# 检查是否可以为该序列追加缓存块(用于存储新生成的KV信息)
while not self.block_manager.can_append(seq):
# 如果运行队列还有其他序列,抢占最后一个序列的缓存
if self.running:
self.preempt(self.running.pop())
else:
# 运行队列无其他序列,抢占当前序列自身的缓存
self.preempt(seq)
break
else:
# 成功获取追加缓存的权限,继续调度
num_seqs += 1
self.block_manager.may_append(seq) # 准备为序列追加缓存块
scheduled_seqs.append(seq) # 添加到本次调度列表

# 断言本次调度一定有序列(否则会报错)
assert scheduled_seqs
# 将已调度的序列重新添加回运行队列(保持原有顺序)
self.running.extendleft(reversed(scheduled_seqs))
return scheduled_seqs, False # 返回调度列表和False(表示是解码阶段)

# 抢占序列的缓存块:将序列状态改为等待,释放其缓存,并加入等待队列前端
def preempt(self, seq: Sequence):
seq.status = SequenceStatus.WAITING # 更新状态为等待
self.block_manager.deallocate(seq) # 释放该序列的KV缓存块
self.waiting.appendleft(seq) # 添加到等待队列的最前面(优先处理)

# 后处理:更新序列的token,并检查是否完成(达到EOS或最大长度)
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
# 遍历序列和对应的生成token
for seq, token_id in zip(seqs, token_ids):
seq.append_token(token_id) # 向序列添加生成的token
# 检查是否终止:未忽略EOS且生成了EOS,或达到最大生成token数
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
seq.status = SequenceStatus.FINISHED # 更新状态为已完成
self.block_manager.deallocate(seq) # 释放缓存块
self.running.remove(seq) # 从运行队列移除该序列

Scheduler.schedule 设计

schedule() 的职责是:每次 engine step 之前,决定这一轮要送给模型执行哪些 Sequence,以及这一轮是 prefill 还是 decode。

它返回:

text
1
return scheduled_seqs, is_prefill

其中:

text
1
2
scheduled_seqs: 本轮要执行的序列列表
is_prefill: True 表示本轮是 prefill,False 表示本轮是 decode

整个 scheduler 维护两个队列:

python
1
2
self.waiting: deque[Sequence]
self.running: deque[Sequence]

含义是:

text
1
2
3
4
5
waiting:
新请求,或者被抢占后需要重新 prefill 的请求。

running:
已经完成 prefill,正在逐 token decode 的请求。

一、schedule 的总体策略

schedule() 分两段:

text
1
2
1. 优先调度 waiting 队列里的请求做 prefill
2. 如果没有任何 prefill 请求能被调度,再调度 running 队列里的请求做 decode

也就是说,这个实现里 prefill 优先级高于 decode

简化伪代码:

python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def schedule():
scheduled_seqs = []

# 1. 先尝试 prefill
while waiting not empty and batch not full:
if resource enough:
allocate KV cache
move seq from waiting to running
scheduled_seqs.append(seq)
else:
break

if scheduled_seqs:
return scheduled_seqs, True

# 2. 如果没有 prefill,再 decode
while running not empty and batch not full:
if can append KV cache:
may_append KV cache
scheduled_seqs.append(seq)
else:
preempt some seq

return scheduled_seqs, False

二、Prefill 调度逻辑

代码:

python
1
2
3
4
5
6
7
8
9
10
11
12
13
while self.waiting and num_seqs < self.max_num_seqs:
seq = self.waiting[0]

if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
break

num_seqs += 1
self.block_manager.allocate(seq)
num_batched_tokens += len(seq) - seq.num_cached_tokens
seq.status = SequenceStatus.RUNNING
self.waiting.popleft()
self.running.append(seq)
scheduled_seqs.append(seq)

这段处理的是新请求的 prompt。

调度一个 waiting seq 需要满足两个条件:

text
1
2
1. batch token 数不能超过 max_num_batched_tokens
2. KV cache block 必须足够,即 can_allocate(seq)

max_num_batched_tokens 控制 prefill 的总 token 数,因为 prefill 是一次处理整段 prompt,开销和 prompt 长度强相关。

max_num_seqs 控制本轮最多处理多少条序列。

如果满足条件,就调用:

python
1
self.block_manager.allocate(seq)

allocate() 会为这个序列的 prompt 分配 KV cache block,并尝试做 prefix cache 命中。

之后:

python
1
2
3
4
seq.status = RUNNING
waiting.popleft()
running.append(seq)
scheduled_seqs.append(seq)

意思是:

text
1
2
3
这个请求已经从 waiting 进入 running。
它本轮会被送去模型做 prefill。
prefill 完成后,后续 decode 会继续从 running 队列调度。

这里有一个细节:

python
1
num_batched_tokens += len(seq) - seq.num_cached_tokens

如果 prefix cache 命中了前面一部分 token,那么这些 token 不需要重新 prefill,所以真正进入模型计算的是:

text
1
序列总长度 - 已缓存 token 数

三、为什么 prefill 有结果就立刻返回

代码:

python
1
2
if scheduled_seqs:
return scheduled_seqs, True

意思是:只要这一轮成功调度了任何 prefill 请求,就不再混入 decode 请求。

所以这个实现是:

text
1
2
一轮 step 要么全是 prefill
一轮 step 要么全是 decode

这样做实现简单,因为 prefill 和 decode 的输入组织方式完全不同:

text
1
2
3
4
5
6
7
prefill:
输入是每条请求的 prompt 未缓存部分,长度可变。
使用 cu_seqlens_q / cu_seqlens_k / slot_mapping。

decode:
输入是每条请求的 last_token,每条序列只处理 1 个 token。
使用 context_lens / block_tables / slot_mapping。

所以 ModelRunner 可以根据 is_prefill 分别走:

python
1
2
prepare_prefill(seqs)
prepare_decode(seqs)

四、Decode 调度逻辑

如果没有 waiting 请求被调度,才进入 decode:

python
1
2
3
4
5
6
7
8
9
10
11
12
13
while self.running and num_seqs < self.max_num_seqs:
seq = self.running.popleft()

while not self.block_manager.can_append(seq):
if self.running:
self.preempt(self.running.pop())
else:
self.preempt(seq)
break
else:
num_seqs += 1
self.block_manager.may_append(seq)
scheduled_seqs.append(seq)

decode 阶段处理的是已经 prefill 完成的序列。每条 running 序列每轮最多生成 1 个 token。

这里不是调用 allocate() ,而是调用:

python
1
2
can_append(seq)
may_append(seq)

因为 decode 时序列长度每轮只增长 1。

can_append(seq) 检查:如果这次 decode 需要新 block,当前是否有空闲 block。

BlockManager 里:

python
1
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)

这里的意思是:

text
1
2
3
如果 len(seq) % block_size == 1,说明当前 token 落在一个新 block 的第一个位置,
需要提前分配新 block。
否则当前 block 还没满,不需要新 block。

如果可以 append,就调用:

python
1
self.block_manager.may_append(seq)

may_append() 做两类事情:

text
1
2
1. 如果当前 token 需要新 block,就分配一个新 block 加到 block_table。
2. 如果某个 block 刚刚被填满,就计算 hash,把它加入 prefix cache。

然后把这个 seq 加入本轮 decode batch:

python
1
scheduled_seqs.append(seq)

五、抢占 preempt 逻辑

如果 decode 阶段发现 KV cache 不够:

python
1
2
3
4
5
6
while not self.block_manager.can_append(seq):
if self.running:
self.preempt(self.running.pop())
else:
self.preempt(seq)
break

这里的策略是:

text
1
2
优先抢占 running 队列尾部的序列。
如果没有其他 running 序列可以抢占,就抢占当前 seq 自己。

preempt() 做三件事:

python
1
2
3
seq.status = SequenceStatus.WAITING
self.block_manager.deallocate(seq)
self.waiting.appendleft(seq)

含义:

text
1
2
3
1. 把 seq 状态改回 WAITING。
2. 释放它占用的 KV cache block。
3. 放到 waiting 队列头部,后面优先重新 prefill。

抢占的本质是:

text
1
2
3
4
5
为了给当前要 decode 的请求腾出 KV cache,
牺牲另一个 running 请求的 KV cache。
被抢占的请求不会丢 token_ids,
但它的 KV cache 被释放了,
以后需要重新 prefill,或者通过 prefix cache 复用部分 KV。

这是一个简化版的 preemption 机制。


六、decode 后为什么要把 scheduled_seqs 放回 running 队列

decode 调度时,每取一个 seq 会先:

python
1
seq = self.running.popleft()

如果它被成功调度,就先放进:

python
1
scheduled_seqs

decode 调度结束后:

python
1
self.running.extendleft(reversed(scheduled_seqs))

这句的作用是把本轮成功 decode 的 seq 放回 running 队列头部,并保持原顺序。

例如:

text
1
2
3
scheduled_seqs = [A, B, C]
reversed(scheduled_seqs) = [C, B, A]
running.extendleft([C, B, A])

extendleft 会依次从左边插入,所以最终队列头部仍然是:

text
1
A, B, C

为什么要放回 running?

因为 decode 后这些请求通常还没结束。模型本轮只会为每条序列生成一个 token。生成完后, postprocess() 会:

python
1
seq.append_token(token_id)

如果没遇到 EOS,也没达到 max_tokens ,它们还要继续参与下一轮 decode。

如果某个 seq 结束了, postprocess() 会做:

python
1
2
3
seq.status = FINISHED
self.block_manager.deallocate(seq)
self.running.remove(seq)

所以 running 队列里只保留还没完成的请求。


七、schedule 和 postprocess 的配合

一次 engine step 是:

python
1
2
3
seqs, is_prefill = scheduler.schedule()
token_ids = model_runner.call("run", seqs, is_prefill)
scheduler.postprocess(seqs, token_ids)

schedule() 只负责:

text
1
2
3
决定谁跑
分配或准备 KV cache block
维护 waiting/running 队列

它不负责真的追加生成 token。

真正追加 token 在:

python
1
postprocess()

代码:

python
1
seq.append_token(token_id)

然后判断:

python
1
2
3
4
if token_id == eos or seq.num_completion_tokens == seq.max_tokens:
seq.status = FINISHED
block_manager.deallocate(seq)
running.remove(seq)

所以 decode 阶段的一个完整循环是:

text
1
2
3
4
5
6
7
8
9
10
11
12
1. schedule:
选择 running seq
确保 KV cache 有空间

2. model_runner:
用 seq.last_token 做输入
计算 logits
采样出新 token

3. postprocess:
把新 token append 到 seq
如果结束,释放 KV cache

八、这个 Scheduler 的特点

这个实现是一个简化版 vLLM scheduler,特点是:

text
1
2
3
4
5
6
7
8
1. waiting 和 running 分离。
2. prefill 优先于 decode。
3. 每轮不混合 prefill 和 decode。
4. prefill 受 max_num_batched_tokens 和 max_num_seqs 限制。
5. decode 主要受 max_num_seqs 和 KV cache block 是否充足限制。
6. KV cache 不够时,会 preempt running seq。
7. 被抢占的 seq 会回到 waiting,之后重新 prefill。
8. prefix cache 可以降低重新 prefill 的成本。

可以用一句话总结:

text
1
schedule() 是一个两阶段调度器:先尽可能把 waiting 请求做 prefill;如果没有 prefill 可做,就从 running 请求里组 decode batch。调度过程中它通过 BlockManager 管理 KV cache block,必要时抢占 running 请求,把它们放回 waiting,以保证当前 batch 能继续执行。

# model runner

model_runner
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
class ModelRunner:

def __init__(self, config: Config, rank: int, event: Event | list[Event]):
self.config = config
hf_config = config.hf_config # huggingface模型配置
self.block_size = config.kvcache_block_size # kv缓存块大小
self.enforce_eager = config.enforce_eager # 即时执行模式
self.world_size = config.tensor_parallel_size
self.rank = rank # 传入参数,其实是进程编号
self.event = event # event 用于进程间同步
# 使用nccl后端初始化分布式环境
dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
torch.cuda.set_device(rank)
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(hf_config.torch_dtype)
torch.set_default_device("cuda")
self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model)
self.sampler = Sampler()
# 推理一次,初始化缓存和cuda内核
self.warmup_model()
self.allocate_kv_cache()
# 捕获cuda图,加速后续推理
if not self.enforce_eager:
self.capture_cudagraph()
torch.set_default_device("cpu")
torch.set_default_dtype(default_dtype)

if self.world_size > 1:
if rank == 0:
# 创建1mb共享内存区域,用于进程间通信,rpc请求
self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
dist.barrier()
else:
dist.barrier()
self.shm = SharedMemory(name="nanovllm")
self.loop() # 子进程进入循环,等待主进程指令

def exit(self):
if self.world_size > 1:
self.shm.close()
dist.barrier()
if self.rank == 0:
self.shm.unlink()
if not self.enforce_eager:
del self.graphs, self.graph_pool
torch.cuda.synchronize()
dist.destroy_process_group()

def loop(self):
while True:
method_name, args = self.read_shm()
self.call(method_name, *args)
if method_name == "exit":
break

def read_shm(self):
assert self.world_size > 1 and self.rank > 0
# 等待主进程发来信号
self.event.wait()
# 前 4 个字节保存了后面数据的长度 n(小端)
n = int.from_bytes(self.shm.buf[0:4], "little")
# 从 4 开始读取 n 个字节的数据,并反序列化
method_name, *args = pickle.loads(self.shm.buf[4:n+4])
# 清除事件,准备下一次等待
self.event.clear()
return method_name, args

def write_shm(self, method_name, *args):
assert self.world_size > 1 and self.rank == 0
data = pickle.dumps([method_name, *args])
n = len(data)
self.shm.buf[0:4] = n.to_bytes(4, "little")
self.shm.buf[4:n+4] = data
for event in self.event:
event.set()
# 主进程执行一份,通知子进程也执行
def call(self, method_name, *args):
if self.world_size > 1 and self.rank == 0:
self.write_shm(method_name, *args)
method = getattr(self, method_name, None)
return method(*args)

def warmup_model(self):
# 清一下缓存、统计,以便后面 allocate_kv_cache 正确估显存占用
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
# 估算一下需要开多少条最大序列长度
num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
self.run(seqs, True) # prefill
torch.cuda.empty_cache()

def allocate_kv_cache(self):
config = self.config
hf_config = config.hf_config
# 这张卡的空闲+总容量
free, total = torch.cuda.mem_get_info()
used = total - free
# warmup的时候,峰值占用
peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
# 当前分配
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
# 计算每层kv缓存所需字节数,估算能分配多少块
# 每个rank 持有的kv 头数
num_kv_heads = hf_config.num_key_value_heads // self.world_size
# 每个头的维度
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
# 一个block能容纳block_size个token, 全模型有num_hidden_layers层,每层有2个(k和v)矩阵
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
# 希望用于kv缓存的的开销,减去已用的,减去峰值和当前的差,剩余就是一个安全的范围,用来放 kv,除以每块字节数,得到最大 block 数
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes

assert config.num_kvcache_blocks > 0
# 真正分配
# 2:表示两个维度,0 是k,1 是v
# num_layers: 每层一份 cache
# num_kvcache_blocks: 每层的块数
# block_size: 每块的token数
# num_kv_heads,head_dim:kvshape
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
# 模型里的 Attention 模块之前定义了 self.k_cache / self.v_cache(初始为空 tensor)
# 这里遍历所有模块,凡是有这两个属性的,就给它们绑定对应层的切片
layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
module.k_cache = self.kv_cache[0, layer_id]
module.v_cache = self.kv_cache[1, layer_id]
layer_id += 1
# 每个seq.block_table 是一个块索引列表,表示该序列的每个块在kv缓存中的位置
# 是一个一位列表,表示这条序列用了那些block id,比如[3,5,6], 用了三个,最后一个可能不满
# 不同seq的block数量可能不一样, 所以要pad成一个二维张量
def prepare_block_tables(self, seqs: list[Sequence]):
max_len = max(len(seq.block_table) for seq in seqs)
# 统一所有的blocktable长度,短的用-1填充,表示无效位置,方便后续计算
block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
# 加速h2d 拷贝
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# 返回的是一个2 维张量,形状为 [批量大小, max_len](例如批量 10 个序列、max_len=3 → 形状 [10,3]),可直接传入注意力层,用于定位每个序列的 KV 缓存块。
return block_tables

# 构建可变长度的输入张量
def prepare_prefill(self, seqs: list[Sequence]):
input_ids = []
positions = []
cu_seqlens_q = [0]
cu_seqlens_k = [0]
max_seqlen_q = 0
max_seqlen_k = 0
slot_mapping = []
block_tables = None
for seq in seqs:
seqlen = len(seq) # 整条序列当前token数
input_ids.extend(seq[seq.num_cached_tokens:]) # 只取未缓存的部分
positions.extend(list(range(seq.num_cached_tokens, seqlen))) # 位置从缓存后开始
seqlen_q = seqlen - seq.num_cached_tokens # query长度
seqlen_k = seqlen # key长度,因为key要包含前缀
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
# 更新最大长度
max_seqlen_q = max(seqlen_q, max_seqlen_q)
max_seqlen_k = max(seqlen_k, max_seqlen_k)
if not seq.block_table: # warmup
continue
for i in range(seq.num_cached_blocks, seq.num_blocks):
start = seq.block_table[i] * self.block_size
if i != seq.num_blocks - 1:
end = start + self.block_size
else:
end = start + seq.last_block_num_tokens
# 对齐
slot_mapping.extend(list(range(start, end)))
# k 比 q 大,说明有已缓存的 k,有 cache,复用
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
block_tables = self.prepare_block_tables(seqs)
# 将输入和位置转换为张量,并将其移动到 GPU 上
# pin_memory=True 表示将数据放在固定内存中,cuda(non_blocking=True) 表示异步传输到 GPU
# 这样可以提高数据传输效率,尤其是在多 GPU 环境中。
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# 设置全局 context,让attention,lmhead,sampler 等模块可以访问ß
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
return input_ids, positions

def prepare_decode(self, seqs: list[Sequence]):
input_ids = [] # 上一个生成的token
# 位置是上一个token的索引
positions = []
slot_mapping = [] # 上一个token在kv缓存中的位置
context_lens = [] # 每条序列的总长度
for seq in seqs:
input_ids.append(seq.last_token)
positions.append(len(seq) - 1)
context_lens.append(len(seq))
slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
block_tables = self.prepare_block_tables(seqs)
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
return input_ids, positions

def prepare_sample(self, seqs: list[Sequence]):
temperatures = []
for seq in seqs:
temperatures.append(seq.temperature)
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
return temperatures
# 推理模式,关闭自动求导引擎,比 nograd 更彻底
@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
# 根据输入批量大小,选择使用 cuda 图还是即时执行
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
return self.model.compute_logits(self.model(input_ids, positions))
else: # 小batch
bs = input_ids.size(0)
context = get_context()
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
graph_vars = self.graph_vars
graph_vars["input_ids"][:bs] = input_ids
graph_vars["positions"][:bs] = positions
graph_vars["slot_mapping"].fill_(-1)
graph_vars["slot_mapping"][:bs] = context.slot_mapping
graph_vars["context_lens"].zero_()
graph_vars["context_lens"][:bs] = context.context_lens
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
graph.replay()
return self.model.compute_logits(graph_vars["outputs"][:bs])
# 每次step 会调用一次
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
# 生成
logits = self.run_model(input_ids, positions, is_prefill)
# 采样
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
reset_context()
return token_ids
# 在多gpu环境下捕获多个 batch 的 cuda graph
@torch.inference_mode()
def capture_cudagraph(self):
config = self.config
hf_config = config.hf_config
max_bs = min(self.config.max_num_seqs, 512) # bs:batch size
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
# 占位
input_ids = torch.zeros(max_bs, dtype=torch.int64)
positions = torch.zeros(max_bs, dtype=torch.int64)
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
context_lens = torch.zeros(max_bs, dtype=torch.int32)
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
outputs = torch.zeros(max_bs, hf_config.hidden_size)
self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
self.graphs = {}
self.graph_pool = None

for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph() # 空图占位
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
with torch.cuda.graph(graph, self.graph_pool):
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
if self.graph_pool is None:
self.graph_pool = graph.pool()
self.graphs[bs] = graph
torch.cuda.synchronize()
reset_context()

self.graph_vars = dict(
input_ids=input_ids,
positions=positions,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
outputs=outputs,
)

ModelRunner 可以理解成 engine 里真正 “跑模型” 的执行器。

前面几个类的分工是:

text
1
2
3
4
5
Sequence       管一条请求的 token 状态
BlockManager 管 KV cache block 分配
Scheduler 决定这一轮跑哪些 Sequence
ModelRunner 把 Sequence 变成 tensor,调用模型 forward,再采样 token
LLMEngine 把整个流程串起来

ModelRunner 的核心职责

它主要做 5 件事:

text
1
2
3
4
5
1. 初始化模型和 GPU 环境
2. 分配 KV cache
3. 把调度出来的 Sequence 转成模型输入
4. 调用模型 forward
5. 根据 logits 采样新 token

最核心入口是:

python
1
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:

这个函数每个 step 都会被调用。


和 LLMEngine 的交互

llm_engine.py 里,一轮生成是:

python
1
2
3
seqs, is_prefill = self.scheduler.schedule()
token_ids = self.model_runner.call("run", seqs, is_prefill)
self.scheduler.postprocess(seqs, token_ids)

所以调用关系是:

text
1
2
3
4
5
6
7
Scheduler 选出要跑的 seqs

ModelRunner.run(seqs, is_prefill)

返回本轮采样出的 token_ids

Scheduler.postprocess() 把 token append 回 seq

ModelRunner 本身不决定哪些请求能跑,它只执行 scheduler 给它的任务。


run () 的流程

核心代码:

python
1
2
3
4
5
6
7
def run(self, seqs, is_prefill):
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
reset_context()
return token_ids

简单说:

text
1
2
3
4
5
1. 根据 prefill/decode 准备 input_ids 和 positions
2. 准备采样参数,比如 temperature
3. 跑模型得到 logits
4. sampler 从 logits 里采样下一个 token
5. 返回 token_ids

Prefill 时怎么准备输入

prefill 是处理 prompt。

函数:

python
1
prepare_prefill(seqs)

它会把多个请求的 prompt 拼成一个大 batch。

关键点:

python
1
input_ids.extend(seq[seq.num_cached_tokens:])

如果 prefix cache 命中了前面一部分,那么只需要跑没缓存的 token。

比如:

text
1
2
seq.token_ids = [A, B, C, D, E]
seq.num_cached_tokens = 3

那么 prefill 只跑:

text
1
[D, E]

同时它会准备:

text
1
2
3
4
5
positions       每个 token 的位置
cu_seqlens_q query 的累积长度
cu_seqlens_k key 的累积长度
slot_mapping 每个 token 的 KV 要写到哪个 cache slot
block_tables 每条 seq 的 block table

这些不是给普通 transformer forward 用的,而是给 attention 层用的。


Decode 时怎么准备输入

decode 是每条请求只处理 1 个 token。

函数:

python
1
prepare_decode(seqs)

核心逻辑:

python
1
2
3
4
input_ids.append(seq.last_token)
positions.append(len(seq) - 1)
context_lens.append(len(seq))
slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)

意思是:

text
1
2
decode 输入的是每条 seq 的 last_token。
模型根据 last_token 预测下一个 token。

易错点:

text
1
2
3
4
decode 时输入模型的 token,不是即将生成的新 token。
而是当前序列最后一个 token。
模型 forward 后 sampler 才得到新 token。
新 token 之后由 Scheduler.postprocess() append 到 seq。

和 BlockManager 的交互

ModelRunner 不直接分配 block。

block 分配发生在 Scheduler 里:

text
1
2
3
4
5
prefill:
Scheduler -> BlockManager.allocate(seq)

decode:
Scheduler -> BlockManager.may_append(seq)

ModelRunner 只使用已经准备好的:

text
1
2
3
seq.block_table
seq.num_cached_tokens
seq.last_block_num_tokens

也就是说:

text
1
2
BlockManager 负责“KV cache 放哪”
ModelRunner 负责“根据 block_table 把 tensor 写到正确位置”

和 Attention 的交互

ModelRunner 会调用:

python
1
set_context(...)

把当前这轮推理需要的元信息放到全局 context 里。

比如 prefill:

python
1
2
3
4
5
6
7
8
9
10
set_context(
True,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
slot_mapping,
None,
block_tables,
)

decode:

python
1
2
3
4
5
6
set_context(
False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
)

然后 attention 层里会:

python
1
context = get_context()

拿到这些信息。

也就是说:

text
1
2
ModelRunner 准备 context
Attention 使用 context

比如 attention 需要知道:

text
1
2
3
4
这个 token 的 K/V 写到哪个 cache slot?
这条序列的 block_table 是什么?
当前是 prefill 还是 decode?
每条序列长度是多少?

这些都由 ModelRunner 准备。


KV cache 是怎么绑定到模型里的

初始化时, ModelRunner 会调用:

text
1
self.allocate_kv_cache()

它会分配一个大 tensor:

text
1
2
3
4
5
6
7
8
self.kv_cache = torch.empty(
2,
num_layers,
num_kvcache_blocks,
block_size,
num_kv_heads,
head_dim
)

第一维 2 表示:

text
1
2
0: key cache
1: value cache

然后遍历模型里的 attention 模块:

text
1
2
3
4
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
module.k_cache = self.kv_cache[0, layer_id]
module.v_cache = self.kv_cache[1, layer_id]

所以 attention 层里的:

text
1
2
self.k_cache
self.v_cache

其实指向的是 ModelRunner 分配出来的全局 KV cache。


和 Sampler 的交互

模型 forward 后会得到 logits:

text
1
logits = self.run_model(...)

然后采样:

text
1
token_ids = self.sampler(logits, temperatures).tolist()

temperatures 来自每条 Sequence

text
1
temperatures.append(seq.temperature)

所以:

text
1
2
3
Sequence 保存采样参数
ModelRunner 取出来变成 tensor
Sampler 根据 logits + temperature 采样 token

一轮完整交互

把几个类串起来就是:

text
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
LLMEngine.step()

Scheduler.schedule()
选出 seqs
prefill 时 allocate block
decode 时 may_append block

ModelRunner.run(seqs, is_prefill)
prepare_prefill / prepare_decode
set_context
model forward
sampler
reset_context

Scheduler.postprocess(seqs, token_ids)
append token
判断 EOS / max_tokens
结束则 deallocate block

最重要的易错点

  1. ModelRunner 不负责调度,调度是 Scheduler 做的。

  2. ModelRunner 不负责分配 block,分配是 BlockManager 做的。

  3. ModelRunner 负责把 Sequence.block_table 转成 attention 能用的 tensor。

  4. prefill 输入的是 prompt 未缓存部分。

  5. decode 输入的是 seq.last_token ,不是新 token。

  6. 新 token 是 sampler 采样出来后,由 Scheduler.postprocess() 追加到 Sequence

一句话总结:

text
1
ModelRunner 是执行层:Scheduler 告诉它跑哪些 seq,BlockManager 已经准备好 KV cache 布局,它把 seq 转成 GPU tensor,设置 attention 需要的 context,跑模型,采样出新 token,再交回 Scheduler 更新状态。

# LLM engine

llm_engine
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
class LLMEngine:

def __init__(self, model, **kwargs):
#
config_fields = {field.name for field in fields(Config)}
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
config = Config(model, **config_kwargs)
self.ps = [] # 进程列表
self.events = [] # 进程间时间列表,储存进程间同步信号
ctx = mp.get_context("spawn") # 使用spawn启动新进程,避免fork带来的问题,不共享父进程内存空间
# 启动tensor并行进程,从1开始,因为0号进程在主进程中运行
for i in range(1, config.tensor_parallel_size):
# event.wait,set 俩方法,分别用来阻塞和唤醒进程
event = ctx.Event()
# 子进程要执行的目标类,args 是传递给子进程的参数
process = ctx.Process(target=ModelRunner, args=(config, i, event))
process.start()
self.ps.append(process)
self.events.append(event)
self.model_runner = ModelRunner(config, 0, self.events)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
config.eos = self.tokenizer.eos_token_id
self.scheduler = Scheduler(config)
# 进程结束自动调用exit方法,清理资源
atexit.register(self.exit)

def exit(self):
# join 阻塞主进程,等待子进程结束
self.model_runner.call("exit")
del self.model_runner
for p in self.ps:
p.join()

def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
if isinstance(prompt, str):
prompt = self.tokenizer.encode(prompt)
seq = Sequence(prompt, sampling_params)
# 交给调度器处理
self.scheduler.add(seq)

def step(self):
seqs, is_prefill = self.scheduler.schedule()
token_ids = self.model_runner.call("run", seqs, is_prefill)
self.scheduler.postprocess(seqs, token_ids)
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
# 如果prefill,返回正数token数,否则返回负数表示decode的token数
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
return outputs, num_tokens

def is_finished(self):
return self.scheduler.is_finished()

# 输入可以是prompt,也可以是tokenid列表
# 采样配置可以是单个,也可以是列表,单个复制num_prompts份
# tqdm:生成进度条
def generate(
self,
prompts: list[str] | list[list[int]],
sampling_params: SamplingParams | list[SamplingParams],
use_tqdm: bool = True,
) -> list[str]:
# 总长度是请求数量,自适应端口宽度,进度条前缀文本是Generating
if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
# 处理采样参数
if not isinstance(sampling_params, list):
sampling_params = [sampling_params] * len(prompts)
# 添加请求
for prompt, sp in zip(prompts, sampling_params):
self.add_request(prompt, sp)
outputs = {} # 存储最终生成结果,key=seq_id(请求唯一标识),value=生成的Token ID列表
prefill_throughput = decode_throughput = 0. # 初始化吞吐量(每秒处理Token数)
while not self.is_finished(): # 循环条件:调度器中还有未完成的请求
# 记录当前时间,用于计算单轮生成耗时
t = perf_counter()

# 核心调用:执行一轮生成(调度请求→模型计算→处理结果)
output, num_tokens = self.step()

# ---------------------- 进度条更新:显示生成速度 ----------------------
if use_tqdm:
# 根据 num_tokens 的正负判断阶段(正数=Prefill,负数=Decode)
if num_tokens > 0:
# Prefill吞吐量 = 处理的总Token数 ÷ 耗时
prefill_throughput = num_tokens / (perf_counter() - t)
else:
# Decode吞吐量 = 生成的Token数(负号抵消) ÷ 耗时
decode_throughput = -num_tokens / (perf_counter() - t)
# 进度条后缀显示两个阶段的速度(单位:tok/s,Token每秒)
pbar.set_postfix({
"Prefill": f"{int(prefill_throughput)}tok/s",
"Decode": f"{int(decode_throughput)}tok/s",
})

# ---------------------- 收集已完成的请求结果 ----------------------
# output 是 step() 返回的“已完成请求列表”,每个元素是 (seq_id, 生成的Token ID列表)
for seq_id, token_ids in output:
outputs[seq_id] = token_ids # 用 seq_id 作为key,确保结果不重复
if use_tqdm:
pbar.update(1) # 每完成一个请求,进度条前进1格
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
if use_tqdm:
pbar.close()
return outputs

engine 层可以按一条请求的生命周期理解,别按文件孤立看。

1. 用户请求进来

入口在 llm_engine.py

python
1
2
3
4
5
def add_request(self, prompt, sampling_params):
if isinstance(prompt, str):
prompt = self.tokenizer.encode(prompt)
seq = Sequence(prompt, sampling_params)
self.scheduler.add(seq)

这里做三件事:

text
1
2
3
prompt 文本 -> token ids
token ids -> Sequence
Sequence -> 放进 Scheduler.waiting

所以 Sequence 就是一条请求的状态对象。


2. Sequence 记录请求状态

sequence.py 负责保存:

text
1
2
3
4
5
6
7
token_ids             prompt + 已生成 token
last_token 当前最后一个 token
num_prompt_tokens prompt 长度
num_cached_tokens 已经命中 prefix cache 的 token 数
block_table 这条请求使用哪些 KV cache block
status WAITING / RUNNING / FINISHED
sampling 参数 temperature / max_tokens / ignore_eos

你可以把它理解成:

text
1
一条请求在 engine 里的“档案袋”。

3. Scheduler 决定这一轮跑谁

核心在 scheduler.py

python
1
seqs, is_prefill = self.scheduler.schedule()

它维护两个队列:

text
1
2
3
4
5
waiting:
还没 prefill,或者被抢占后要重新 prefill 的请求

running:
已经 prefill 完,正在 decode 的请求

schedule() 的规则很简单:

text
1
2
先尝试从 waiting 里拿请求做 prefill。
如果没有 prefill 可做,再从 running 里拿请求做 decode。

prefill 时调用:

python
1
block_manager.allocate(seq)

decode 时调用:

python
1
block_manager.may_append(seq)

易错点:

text
1
schedule 只决定谁跑,不跑模型,也不追加新 token。

4. BlockManager 管 KV cache 位置

block_manager.py 负责 KV cache block。

核心数据:

python
1
2
3
4
self.blocks
self.free_block_ids
self.used_block_ids
self.hash_to_block_id

每条 Sequence 里有:

python
1
seq.block_table

比如:

python
1
seq.block_table = [2, 5, 9]

意思是:

text
1
2
3
这条序列的第 0 个逻辑 block 在物理 KV block 2
第 1 个逻辑 block 在物理 KV block 5
第 2 个逻辑 block 在物理 KV block 9

prefill 时:

python
1
allocate(seq)

负责给整段 prompt 分配 block,并通过 hash 尝试复用 prefix cache。

decode 时:

python
1
may_append(seq)

负责随着序列变长维护 block:

text
1
2
需要新 block 时分配。
block 填满时计算 hash,加入 prefix cache。

结束或抢占时:

python
1
deallocate(seq)

释放这条 seq 引用的 block。

注意:

text
1
2
3
释放不等于清空 KV cache。
它只是减少 ref_count,把 block 放回 free list。
hash 和 token_ids 可能还保留,用于 prefix cache。

5. ModelRunner 真的跑模型

model_runner.py 是执行层。

scheduler 选好 seq 后, llm_engine 调:

python
1
token_ids = self.model_runner.call("run", seqs, is_prefill)

run() 做:

text
1
2
3
4
5
1. prepare_prefill 或 prepare_decode
2. 准备 temperature
3. model forward
4. sampler 采样 token
5. 返回 token_ids

prefill 输入是:

text
1
prompt 里还没有缓存的部分

decode 输入是:

text
1
每条 seq 的 last_token

易错点:

text
1
2
3
decode 输入模型的是 last_token。
模型输出后 sampler 才得到 next_token。
next_token 不是在 ModelRunner 里 append 的。

6. Context 把 ModelRunner 和 Attention 连起来

ModelRunner 会调用:

python
1
set_context(...)

把这些信息放到全局 context:

text
1
2
3
4
5
is_prefill
slot_mapping
block_tables
cu_seqlens_q / cu_seqlens_k
context_lens

然后 attention 层里:

python
1
context = get_context()

拿到这些东西。

所以关系是:

text
1
2
ModelRunner 根据 Sequence.block_table 准备 context
Attention 根据 context 读写 KV cache

例如:

text
1
2
3
4
5
slot_mapping:
当前 token 的 K/V 应该写到 KV cache 的哪个 slot

block_tables:
每条序列对应哪些 KV block

7. postprocess 更新请求状态

模型跑完后回到 llm_engine.py

python
1
self.scheduler.postprocess(seqs, token_ids)

scheduler.py 里:

python
1
seq.append_token(token_id)

也就是:

text
1
真正把新 token 加进 Sequence 的地方,是 postprocess。

然后判断是否结束:

python
1
2
3
4
if token_id == eos or seq.num_completion_tokens == seq.max_tokens:
seq.status = FINISHED
block_manager.deallocate(seq)
running.remove(seq)

8. generate 是完整外层循环

llm_engine.pygenerate()

python
1
2
while not self.is_finished():
output, num_tokens = self.step()

step() 是核心:

python
1
2
3
4
5
6
def step(self):
seqs, is_prefill = self.scheduler.schedule()
token_ids = self.model_runner.call("run", seqs, is_prefill)
self.scheduler.postprocess(seqs, token_ids)
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
return outputs, num_tokens

所以整个 engine 可以压缩成一句:

text
1
2
generate 不断调用 step;
每个 step 先 schedule,后 model_runner.run,最后 postprocess。

最简心智模型

text
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
LLMEngine:
总控,负责 generate 循环。

Sequence:
单个请求的状态。

Scheduler:
决定这一轮哪些请求跑 prefill / decode。

BlockManager:
管 KV cache block 分配、释放、prefix cache。

ModelRunner:
把 seq 转成 tensor,跑模型,采样 token。

Attention:
根据 ModelRunner 设置的 context 读写 KV cache。

一条请求的完整过程:

text
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
prompt
-> tokenizer.encode
-> Sequence
-> waiting

prefill:
-> Scheduler.schedule
-> BlockManager.allocate
-> ModelRunner.prepare_prefill
-> model forward
-> sampler
-> postprocess append token

decode:
-> Scheduler.schedule
-> BlockManager.may_append
-> ModelRunner.prepare_decode
-> model forward
-> sampler
-> postprocess append token

finish:
-> BlockManager.deallocate
-> tokenizer.decode

抓住这条线,engine 层就很清楚了。

# 通信原语速查

约定:pp 张卡(rank 0,1,,p10, 1, \dots, p-1),每张卡持有一个张量块。用 A,B,C,DA, B, C, D 表示数据块,下标表示来自哪个 rank。⊕ 表示归约操作(通常是 sum,也可 max/min/mean 等)。

下面所有原语都以 p=4p=4 为例做可视化。


# 1. Broadcast

语义:把 rank src 的张量发给所有 rank。

[A]beforebroadcast (src=0)[AAAA]after\underbrace{\begin{bmatrix} A \\ \cdot \\ \cdot \\ \cdot \end{bmatrix}}_{\text{before}} \xrightarrow{\text{broadcast (src=0)}} \underbrace{\begin{bmatrix} A \\ A \\ A \\ A \end{bmatrix}}_{\text{after}}

broadcast
1
dist.broadcast(tensor, src=0)

流量(p1)N(p-1)N(树形实现可到 logp\log p 步)。
用处:模型权重初始化同步、超参广播。


# 2. Scatter

语义:rank src 持有 pp 块,把第 ii 块发给 rank ii

[[A0,A1,A2,A3]]before (only rank 0 has data)scatter[A0A1A2A3]after\underbrace{\begin{bmatrix} [A_0, A_1, A_2, A_3] \\ \cdot \\ \cdot \\ \cdot \end{bmatrix}}_{\text{before (only rank 0 has data)}} \xrightarrow{\text{scatter}} \underbrace{\begin{bmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \end{bmatrix}}_{\text{after}}

scatter
1
dist.scatter(output, scatter_list=[A0, A1, A2, A3], src=0)

流量N(p1)/pN(p-1)/p
用处:把一个完整 batch 切给各卡。


# 3. Gather

语义:Scatter 的逆。每 rank 各有一块,全部汇到 rank dst

[A0A1A2A3]beforegather (dst=0)[[A0,A1,A2,A3]]after (only rank 0)\underbrace{\begin{bmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \end{bmatrix}}_{\text{before}} \xrightarrow{\text{gather (dst=0)}} \underbrace{\begin{bmatrix} [A_0, A_1, A_2, A_3] \\ \cdot \\ \cdot \\ \cdot \end{bmatrix}}_{\text{after (only rank 0)}}

gather
1
dist.gather(input, gather_list=[...] if rank == 0 else None, dst=0)

用处:TP 里把 logits 汇到 rank 0 给 Sampler。


# 4. Reduce

语义:各 rank 一块,按 ⊕ 归约到 rank dst

[A0A1A2A3]beforereduce (dst=0, op=sum)[i=03Ai]after (only rank 0)\underbrace{\begin{bmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \end{bmatrix}}_{\text{before}} \xrightarrow{\text{reduce (dst=0, op=sum)}} \underbrace{\begin{bmatrix} \sum_{i=0}^{3} A_i \\ \cdot \\ \cdot \\ \cdot \end{bmatrix}}_{\text{after (only rank 0)}}

reduce
1
dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)

流量N(p1)/pN(p-1)/p
用处:集中统计(loss、norm 汇总到主卡)。


# 5. All-Gather

语义:每 rank 一块,拼起来给所有 rank。

[A0A1A2A3]beforeall-gather[[A0,A1,A2,A3][A0,A1,A2,A3][A0,A1,A2,A3][A0,A1,A2,A3]]after\underbrace{\begin{bmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \end{bmatrix}}_{\text{before}} \xrightarrow{\text{all-gather}} \underbrace{\begin{bmatrix} [A_0, A_1, A_2, A_3] \\ [A_0, A_1, A_2, A_3] \\ [A_0, A_1, A_2, A_3] \\ [A_0, A_1, A_2, A_3] \end{bmatrix}}_{\text{after}}

all_gather
1
dist.all_gather(tensor_list, input)

流量N(p1)N(p-1)
用处:Column-Parallel Linear 输出 → 拼成完整激活;SP / FSDP 里 unshard 权重。


# 6. Reduce-Scatter

语义:各 rank 持有完整的 pp 块;对每一块做跨卡归约,结果的第 ii 块留在 rank ii

[[A0(0),A1(0),A2(0),A3(0)][A0(1),A1(1),A2(1),A3(1)][A0(2),A1(2),A2(2),A3(2)][A0(3),A1(3),A2(3),A3(3)]]before; upper idx = rank, lower idx = chunkreduce-scatter (sum)[rA0(r)rA1(r)rA2(r)rA3(r)]after: rank i 只留 chunk i\underbrace{\begin{bmatrix} [A_0^{(0)}, A_1^{(0)}, A_2^{(0)}, A_3^{(0)}] \\ [A_0^{(1)}, A_1^{(1)}, A_2^{(1)}, A_3^{(1)}] \\ [A_0^{(2)}, A_1^{(2)}, A_2^{(2)}, A_3^{(2)}] \\ [A_0^{(3)}, A_1^{(3)}, A_2^{(3)}, A_3^{(3)}] \end{bmatrix}}_{\text{before; upper idx = rank, lower idx = chunk}} \xrightarrow{\text{reduce-scatter (sum)}} \underbrace{\begin{bmatrix} \sum_r A_0^{(r)} \\ \sum_r A_1^{(r)} \\ \sum_r A_2^{(r)} \\ \sum_r A_3^{(r)} \end{bmatrix}}_{\text{after: rank }i\text{ 只留 chunk }i}

reduce_scatter
1
dist.reduce_scatter(output, input_list=[...], op=dist.ReduceOp.SUM)

流量N(p1)/pN(p-1)/p
用处:FSDP 梯度聚合、Sequence Parallel 的反向。


# 7. All-Reduce

语义:各 rank 一块,归约后所有 rank 都持有相同的归约结果。

[A0A1A2A3]beforeall-reduce (sum)[i=03Aii=03Aii=03Aii=03Ai]after\underbrace{\begin{bmatrix} A_0 \\ A_1 \\ A_2 \\ A_3 \end{bmatrix}}_{\text{before}} \xrightarrow{\text{all-reduce (sum)}} \underbrace{\begin{bmatrix} \sum_{i=0}^{3} A_i \\ \sum_{i=0}^{3} A_i \\ \sum_{i=0}^{3} A_i \\ \sum_{i=0}^{3} A_i \end{bmatrix}}_{\text{after}}

all_reduce
1
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

等价分解

 all-reduce    reduce-scatter  +  all-gather \boxed{\ \text{all-reduce} \;\equiv\; \text{reduce-scatter} \;+\; \text{all-gather}\ }

流量(Ring-AllReduce):2N(p1)/p2N2N(p-1)/p \approx 2N(带宽最优)。
用处:DP 梯度同步;TP 的 RowParallel 出口、MLP down_proj 出口;Attention o_proj 出口。


# 8. All-to-All

语义:每个 rank 把自己的张量切成 pp 块,第 jj 块发给 rank jj;同时从 rank jj 收其第 ii 块。本质是 **“按块转置”**。

before:[[A00,A01,A02,A03][A10,A11,A12,A13][A20,A21,A22,A23][A30,A31,A32,A33]]all-to-allafter:[[A00,A10,A20,A30][A01,A11,A21,A31][A02,A12,A22,A32][A03,A13,A23,A33]]\text{before:}\quad \begin{bmatrix} [A_0^{0}, A_0^{1}, A_0^{2}, A_0^{3}] \\ [A_1^{0}, A_1^{1}, A_1^{2}, A_1^{3}] \\ [A_2^{0}, A_2^{1}, A_2^{2}, A_2^{3}] \\ [A_3^{0}, A_3^{1}, A_3^{2}, A_3^{3}] \end{bmatrix} \xrightarrow{\text{all-to-all}} \text{after:}\quad \begin{bmatrix} [A_0^{0}, A_1^{0}, A_2^{0}, A_3^{0}] \\ [A_0^{1}, A_1^{1}, A_2^{1}, A_3^{1}] \\ [A_0^{2}, A_1^{2}, A_2^{2}, A_3^{2}] \\ [A_0^{3}, A_1^{3}, A_2^{3}, A_3^{3}] \end{bmatrix}

行下标 = rank,列上标 = 目标 rank。可以看成分块矩阵的转置 TijTjiT_{ij} \to T_{ji}

all_to_all
1
2
dist.all_to_all(output_list, input_list)
# 或 all_to_all_single(output, input) # 等长情形

流量N(p1)/pN(p-1)/p(每 rank 发 / 收)。
用处:MoE 的 dispatch /combine(token → expert 两次 all-to-all);Sequence Parallel 切 head 与切 seq 互转。


# 9. Barrier

语义:无数据交换,纯同步点。

rank 0: barrier\text{rank 0: } \cdots \longrightarrow \|_{\text{barrier}} \longrightarrow \cdots

barrier
1
dist.barrier()

# 10. 点对点:Send / Recv

语义:有向的单播。

rank s:Xsendrank r:X\text{rank } s: X \xrightarrow{\text{send}} \text{rank } r:X

send/recv
1
2
3
dist.send(tensor, dst=r)
dist.recv(tensor, src=s)
# 或异步:dist.isend / dist.irecv

用处:Pipeline Parallel 相邻 rank 传激活 / 梯度。


# 11. 组合关系与流量对照

关键恒等式(Ring-AllReduce 正是这么实现的):

all-reduce  =  reduce-scatter  +  all-gather\text{all-reduce} \;=\; \text{reduce-scatter} \;+\; \text{all-gather}

all-gather  =  gather  +  broadcast\text{all-gather} \;=\; \text{gather} \;+\; \text{broadcast}

reduce  =  reduce-scatter  +  gather\text{reduce} \;=\; \text{reduce-scatter} \;+\; \text{gather}

原语 每卡流量 (Ring) 结果位置 归约?
broadcast N\approx N 所有 rank
scatter N(p1)/pN(p-1)/p (from src) 分散
gather N(p1)/pN(p-1)/p (to dst) 单 rank
reduce N(p1)/pN(p-1)/p 单 rank
all-gather N(p1)/pN(p-1)/p 所有 rank
reduce-scatter N(p1)/pN(p-1)/p 分散
all-reduce 2N(p1)/p2N(p-1)/p 所有 rank
all-to-all N(p1)/pN(p-1)/p 重排

# 12. 用在 Transformer 里的映射

场景 原语 数学形式
DP 梯度同步 all-reduce gˉ=1pigi\bar g = \frac{1}{p}\sum_i g_i
TP o_proj / down_proj 出口 all-reduce y=ixiWiy = \sum_i x_i W_i^\top
TP VocabParallelEmbedding all-reduce imaskiEi(x)\sum_i \text{mask}_i\odot E_i(x)
TP ParallelLMHead 收 logits gather logits 汇到 rank 0
Sequence Parallel Attn 入口 all-gather 把切 seq 的激活拼回
Sequence Parallel Attn 出口 reduce-scatter 拼回后的激活归约再切
FSDP 前向 unshard all-gather W=[W0,,Wp1]W = [W_0, \dots, W_{p-1}]
FSDP 反向 re-shard grad reduce-scatter g_i = \sum_j g_j^
MoE dispatch / combine all-to-all × 2 token ↔ expert 两次转置
Pipeline 前向传激活 send / recv hh_\ell 从 rank \ell+1\ell+1
更新于 阅读次数

请我喝[茶]~( ̄▽ ̄)~*

小春日和 微信支付

微信支付

小春日和 支付宝

支付宝

小春日和 wechat

wechat