Skip to content
Self-Knowing

Model Runner

约 1469 个字 40 行代码 预计阅读时间 5 分钟

ModelRunner 是 GPU 侧的完整执行器——加载模型、管 KV cache、拼装 attention 元数据、跑 forward、采样。TP > 1 时每个 GPU 一个进程,rank 0 做主控,其余 worker 循环等待指令。


prepare Prefill

Prepare Prefill 就是把 CPU 侧一些零散的 Sequence 状态,拼装成适合 GPU 的张量形式。

Block_table : 记录 sequence 的 KV 对应的物理块号。

因为 block 分配发生在 schedule 阶段,在模型跑 prefill 之前

调度和执行的顺序:

step() 里:
  ① scheduler.schedule()        ← 此时分配 block_table
     ├─ can_allocate(seq)       检查 prefix cache,看要几个块
     ├─ allocate(seq, ...)      把物理块号写进 seq.block_table  ← 这里写的!
     └─ num_scheduled_tokens=44 决定本轮算多少 token

  ② model_runner.run()          ← block_table 已经有值了
     └─ prepare_prefill(seqs)   用 block_table 算 slot_mapping

allocate 是 CPU 侧的簿记——BlockManager 说"这几个物理块归你了",写在 block_table 里。这时候 GPU 上还没跑任何计算,KV cache 里可能还是旧数据或零。

真正往 KV cache 里写数据,是在第 ② 步模型 forward 跑 flash attention 的时候。flash attention 根据 slot_mapping 把新 token 的 K/V 写道 kv_cache 的对应位置上。

所以 block_table 先有值(CPU 侧分好地盘),prefill 才拿着这些值去算 K/V 该写哪(GPU 侧填数据)。

prefix cache

cu_seqlens_q = num_scheduled_tokens 
cu_seqlens_k = num_scheduled_tokens + num_cached_tokens

if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
            block_tables = self.prepare_block_tables(seqs)

cu_seqlens_q:当前 batch 中,每条样本的 Q 序列长度是多少

cu_seqlens_k:当前 batch 中,每条样本的 K 序列长度是多少

prefix cache 的条件中,两种写法是等价的

  • sum(seq.num_cached_tokens for seq in seqs) > 0
  • cu_seqlens_k[-1] > cu_seqlens_q[-1]

prefill and decode

Prefill 的 n:prompt token 数量。

prompt = "今天天气怎么样"
n = 6 个 token
→ 6 个 token 的 K/V 全部算出,写入 KV cache

prefill 后:cache 里有 n 份 K/V,序列总 token 数 = n。


Decode 的 n:这一轮开始时,序列里已有的 token 总数。

decode:n = 6(prefill 刚结束)
  cache 里已有 n-1 = 5 份 K/V(前 5 个 prompt token)
  第 6 个 token(last_token)进来,当场算它的 K/V,也写进 cache
  Q 去 attend 到全部 6 个 K/V → 产第 7 个 token

所以 decode 里 "n-1 个来自 cache,1 个现算",合计 n 份 K/V,产出第 n+1 个 token。n 每轮涨 1。


生成文本时,每个新 token 都依赖前一个 token 的采样结果。你不知道 token N+1 是什么,必须先算出 token N 并采样,才能把 token N 喂进去。所以生成阶段天生是逐 token 的,每轮只算 1 个。

Prefill: 我知道 prompt 的所有 token → 一次性并行全算 → compute-bound
Decode:  我不知道下一个 token 是什么 → 必须逐轮算 → memory-bound

两者瓶颈完全不同,所以用不同的 kernel 分开优化:

  • Prefill 用 varlen:并行处理不等的 seq,矩阵乘法密集
  • Decode 用 with_kvcache:针对 "读 cache > 读 Q > 写 cache" 的模式,专门优化 cache 访问路径

run model

def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
        input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        logits = self.run_model(input_ids, positions, is_prefill)
        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        reset_context()
        return token_ids

run() 是 ModelRunner 对外暴露的唯一入口,做完一轮完整推理。

它被 scheduler 调用,拿到一批 seq 和一个 is_prefill 标志。

  • 第一步是拼装:根据 prefill 还是 decode 选不同的函数,把分散在各 seq 对象里的 last_tokennum_cached_tokensblock_table 等信息翻译成 input_idspositions 两个 GPU 张量,同时把 slot_mappingcontext_lensblock_tables 等 attention 元数据贴到公告板上。

  • 第二步取出每条 seq 的采样温度:拼成 [batch_size] 的 GPU 张量。rank 1 以上不做采样,跳过。

  • 第三步跑模型:prefill 因为序列长度不固定,直接调 model(input_ids, positions);decode 形状固定(每 seq 1 token),走 CUDA graph replay 省 kernel launch overhead。所有 rank 都参与,TP 模式下每层 forward 自动做 all-reduce。模型内部,hidden_states 从 embedding 出发,逐层经 QKV 投影、RoPE、flash attention(从公告板取元数据读写 KV cache)、output projection、MLP,最后 lm_headhidden_dim 映射到 vocab_size,产出一组原始 logits 分数。

  • 第四步采样:只有 rank 0 执行:logits 除以 temperature 调节随机程度,softmax 转概率,Gumbel-max 随机选出一个整数 token id。这是真正的"生成"动作——从一个 vocab_size 的分布中选一个词。

  • 最后清空公告板,把 token id 列表返回给 scheduler。scheduler 的 postprocess 决定这些 token 是里程碑(chunked prefill 未完成时不追 token)还是终点(EOS 则终止序列)。

allocated kv cache

block_bytes = 2 * num_layers * block_size * num_kv_heads * head_dim * itemsize

self.kv_cache = torch.empty(2, num_layers, num_blocks, 256, num_kv_heads, head_dim)

各轴含义

大小 含义
0 2 K=0, V=1
1 num_layers 第几层(每层独立)
2 num_blocks 第几个物理 block
3 256 block 内的第几个 token
4 num_kv_heads TP 后每卡分到的 KV head 数
5 head_dim 每个 head 的维度

block_bytes 的计算

以 Qwen3-0.6B + TP=1 为例:

2 × 28 × 256 × 16 × 64 × 2 = 29,360,128 bytes ≈ 28 MB / block
↑   ↑    ↑     ↑    ↑   ↑
K+V 层  256   16   64  float16
    数  token head dim  2字节

一个物理 block 在每层 hidden layer 都有,各占 block_size 个 token 。

这个公式就是 KV Cache 的显存占用。


如何定位一个 token 的 K/V

token 位置 299,block_table = [5, 10]

逻辑块 = 299 // 256 = 1
物理块 = block_table[1] = 10
块内偏移 = 299 % 256 = 43
slot = 10 * 256 + 43 = 2603

绑定到层时:

module.k_cache = self.kv_cache[0, layer_id]   # 形状 [num_blocks, 256, heads, dim]

多进程通信

四个方法组成一个简易 RPC 系统。

Rank 0                       Rank 1+ (in loop)
──────                       ─────────────────
call("run", seqs, True)
  ├─ write_shm(...)              event.wait()         ← 阻塞等信号
  │    pickle → shm              read shm → unpickle
  │    event.set() ──────────→    event.clear()
  │                              call("run", seqs)    ← 执行
  └─ self.run(seqs, True)        self.run(seqs)       ← 所有 rank 并行跑

Rank 0 通知所有 worker 要执行什么方法、什么参数,然后自己也执行。所有 rank 跑同样的代码——TP 的 all-reduce 保证了结果一致。


loop — worker 的生命

def loop(self):
    while True:
        method_name, args = self.read_shm()   # 阻塞等命令
        self.call(method_name, *args)          # 执行
        if method_name == "exit":
            break

无限循环等命令,直到收到 "exit"__init__ 里 rank > 0 最后一行就是 self.loop(),创建后就陷进去不出来了。


shared memory 布局

buf[0:4]:   数据长度 n(4 字节,小端序)
buf[4:n+4]: pickle.dumps([method_name, *args])

总共 1MB,decode 时 Sequence 压缩后只传 last_token,够用。

read_shm / write_shm

# 写(rank 0)
data = pickle.dumps([method_name, *args])
n = len(data)
self.shm.buf[0:4] = n.to_bytes(4, "little")     # 长度
self.shm.buf[4:n+4] = data                        # 内容
for event in self.event:
    event.set()                                    # 唤醒所有 worker

# 读(rank 1+)
self.event.wait()                                  # 阻塞等信号
n = int.from_bytes(self.shm.buf[0:4], "little")
method_name, *args = pickle.loads(self.shm.buf[4:n+4])
self.event.clear()                                 # 复位,准备下一轮

一对 wait/set 保证不会读到半截数据。


call — 分发器

def call(self, method_name, *args):
    if self.world_size > 1 and self.rank == 0:
        self.write_shm(method_name, *args)   # 通知 worker
    method = getattr(self, method_name, None)
    return method(*args)                      # 自己执行

Rank 0 先通知后执行,worker 收到通知后执行。getattr(self, method_name) 动态分发——传 "run" 就调 self.run,传 "exit" 就调 self.exit。所有 rank 执行完全相同的代码。


Created: May 12, 2026
Last update: May 12, 2026

Discussion