Model Runner¶
约 1469 个字 40 行代码 预计阅读时间 5 分钟
ModelRunner 是 GPU 侧的完整执行器——加载模型、管 KV cache、拼装 attention 元数据、跑 forward、采样。TP > 1 时每个 GPU 一个进程,rank 0 做主控,其余 worker 循环等待指令。
prepare Prefill¶
Prepare Prefill 就是把 CPU 侧一些零散的 Sequence 状态,拼装成适合 GPU 的张量形式。
Block_table : 记录 sequence 的 KV 对应的物理块号。
因为 block 分配发生在 schedule 阶段,在模型跑 prefill 之前。
调度和执行的顺序:
step() 里:
① scheduler.schedule() ← 此时分配 block_table
│
├─ can_allocate(seq) 检查 prefix cache,看要几个块
├─ allocate(seq, ...) 把物理块号写进 seq.block_table ← 这里写的!
└─ num_scheduled_tokens=44 决定本轮算多少 token
② model_runner.run() ← block_table 已经有值了
│
└─ prepare_prefill(seqs) 用 block_table 算 slot_mapping
allocate 是 CPU 侧的簿记——BlockManager 说"这几个物理块归你了",写在 block_table 里。这时候 GPU 上还没跑任何计算,KV cache 里可能还是旧数据或零。
真正往 KV cache 里写数据,是在第 ② 步模型 forward 跑 flash attention 的时候。flash attention 根据 slot_mapping 把新 token 的 K/V 写道 kv_cache 的对应位置上。
所以 block_table 先有值(CPU 侧分好地盘),prefill 才拿着这些值去算 K/V 该写哪(GPU 侧填数据)。
prefix cache¶
cu_seqlens_q = num_scheduled_tokens
cu_seqlens_k = num_scheduled_tokens + num_cached_tokens
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
block_tables = self.prepare_block_tables(seqs)
cu_seqlens_q:当前 batch 中,每条样本的 Q 序列长度是多少
cu_seqlens_k:当前 batch 中,每条样本的 K 序列长度是多少
prefix cache 的条件中,两种写法是等价的
sum(seq.num_cached_tokens for seq in seqs) > 0cu_seqlens_k[-1] > cu_seqlens_q[-1]
prefill and decode¶
Prefill 的 n:prompt token 数量。
prefill 后:cache 里有 n 份 K/V,序列总 token 数 = n。
Decode 的 n:这一轮开始时,序列里已有的 token 总数。
decode:n = 6(prefill 刚结束)
cache 里已有 n-1 = 5 份 K/V(前 5 个 prompt token)
第 6 个 token(last_token)进来,当场算它的 K/V,也写进 cache
Q 去 attend 到全部 6 个 K/V → 产第 7 个 token
所以 decode 里 "n-1 个来自 cache,1 个现算",合计 n 份 K/V,产出第 n+1 个 token。n 每轮涨 1。
生成文本时,每个新 token 都依赖前一个 token 的采样结果。你不知道 token N+1 是什么,必须先算出 token N 并采样,才能把 token N 喂进去。所以生成阶段天生是逐 token 的,每轮只算 1 个。
Prefill: 我知道 prompt 的所有 token → 一次性并行全算 → compute-bound
Decode: 我不知道下一个 token 是什么 → 必须逐轮算 → memory-bound
两者瓶颈完全不同,所以用不同的 kernel 分开优化:
- Prefill 用 varlen:并行处理不等的 seq,矩阵乘法密集
- Decode 用 with_kvcache:针对 "读 cache > 读 Q > 写 cache" 的模式,专门优化 cache 访问路径
run model¶
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
reset_context()
return token_ids
run() 是 ModelRunner 对外暴露的唯一入口,做完一轮完整推理。
它被 scheduler 调用,拿到一批 seq 和一个 is_prefill 标志。
-
第一步是拼装:根据 prefill 还是 decode 选不同的函数,把分散在各 seq 对象里的
last_token、num_cached_tokens、block_table等信息翻译成input_ids、positions两个 GPU 张量,同时把slot_mapping、context_lens、block_tables等 attention 元数据贴到公告板上。 -
第二步取出每条 seq 的采样温度:拼成
[batch_size]的 GPU 张量。rank 1 以上不做采样,跳过。 -
第三步跑模型:prefill 因为序列长度不固定,直接调
model(input_ids, positions);decode 形状固定(每 seq 1 token),走 CUDA graph replay 省 kernel launch overhead。所有 rank 都参与,TP 模式下每层 forward 自动做 all-reduce。模型内部,hidden_states 从 embedding 出发,逐层经 QKV 投影、RoPE、flash attention(从公告板取元数据读写 KV cache)、output projection、MLP,最后lm_head把hidden_dim映射到vocab_size,产出一组原始 logits 分数。 -
第四步采样:只有 rank 0 执行:logits 除以 temperature 调节随机程度,softmax 转概率,Gumbel-max 随机选出一个整数 token id。这是真正的"生成"动作——从一个
vocab_size的分布中选一个词。 -
最后清空公告板,把 token id 列表返回给 scheduler。scheduler 的
postprocess决定这些 token 是里程碑(chunked prefill 未完成时不追 token)还是终点(EOS 则终止序列)。
allocated kv cache¶
block_bytes = 2 * num_layers * block_size * num_kv_heads * head_dim * itemsize
self.kv_cache = torch.empty(2, num_layers, num_blocks, 256, num_kv_heads, head_dim)
各轴含义:
| 轴 | 大小 | 含义 |
|---|---|---|
| 0 | 2 | K=0, V=1 |
| 1 | num_layers | 第几层(每层独立) |
| 2 | num_blocks | 第几个物理 block |
| 3 | 256 | block 内的第几个 token |
| 4 | num_kv_heads | TP 后每卡分到的 KV head 数 |
| 5 | head_dim | 每个 head 的维度 |
block_bytes 的计算:
以 Qwen3-0.6B + TP=1 为例:
2 × 28 × 256 × 16 × 64 × 2 = 29,360,128 bytes ≈ 28 MB / block
↑ ↑ ↑ ↑ ↑ ↑
K+V 层 256 16 64 float16
数 token head dim 2字节
一个物理 block 在每层 hidden layer 都有,各占 block_size 个 token 。
这个公式就是 KV Cache 的显存占用。
如何定位一个 token 的 K/V:
token 位置 299,block_table = [5, 10]
逻辑块 = 299 // 256 = 1
物理块 = block_table[1] = 10
块内偏移 = 299 % 256 = 43
slot = 10 * 256 + 43 = 2603
绑定到层时:
多进程通信¶
四个方法组成一个简易 RPC 系统。
Rank 0 Rank 1+ (in loop)
────── ─────────────────
call("run", seqs, True)
├─ write_shm(...) event.wait() ← 阻塞等信号
│ pickle → shm read shm → unpickle
│ event.set() ──────────→ event.clear()
│ call("run", seqs) ← 执行
└─ self.run(seqs, True) self.run(seqs) ← 所有 rank 并行跑
Rank 0 通知所有 worker 要执行什么方法、什么参数,然后自己也执行。所有 rank 跑同样的代码——TP 的 all-reduce 保证了结果一致。
loop — worker 的生命¶
def loop(self):
while True:
method_name, args = self.read_shm() # 阻塞等命令
self.call(method_name, *args) # 执行
if method_name == "exit":
break
无限循环等命令,直到收到 "exit"。__init__ 里 rank > 0 最后一行就是 self.loop(),创建后就陷进去不出来了。
shared memory 布局¶
总共 1MB,decode 时 Sequence 压缩后只传 last_token,够用。
read_shm / write_shm
# 写(rank 0)
data = pickle.dumps([method_name, *args])
n = len(data)
self.shm.buf[0:4] = n.to_bytes(4, "little") # 长度
self.shm.buf[4:n+4] = data # 内容
for event in self.event:
event.set() # 唤醒所有 worker
# 读(rank 1+)
self.event.wait() # 阻塞等信号
n = int.from_bytes(self.shm.buf[0:4], "little")
method_name, *args = pickle.loads(self.shm.buf[4:n+4])
self.event.clear() # 复位,准备下一轮
一对 wait/set 保证不会读到半截数据。
call — 分发器¶
def call(self, method_name, *args):
if self.world_size > 1 and self.rank == 0:
self.write_shm(method_name, *args) # 通知 worker
method = getattr(self, method_name, None)
return method(*args) # 自己执行
Rank 0 先通知后执行,worker 收到通知后执行。getattr(self, method_name) 动态分发——传 "run" 就调 self.run,传 "exit" 就调 self.exit。所有 rank 执行完全相同的代码。
Last update: May 12, 2026
Discussion