Scheduler¶
约 1162 个字 134 行代码 预计阅读时间 6 分钟
Scheduler 是 NanoVLLM 的推理调度器,负责在 waiting 和 running 两个队列之间推进请求,并协调 BlockManager 完成 KV cache 块的分配、复用和回收。
schedule():决定本轮跑哪些 sequence,以及每条 sequence 跑多少 token。postprocess():处理模型输出,更新 KV cache 进度、追加新 token,并在结束时释放块。
class Scheduler:
def __init__(self, config: Config):
...
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
self.waiting: deque[Sequence] = deque() # 等待 prefill 的序列
self.running: deque[Sequence] = deque() # 正在 decode 的序列
Schedule¶
为什么采用 prefill-first?¶
在传统在线推理调度中,系统往往优先处理 prefill,因为只有先完成 prefill 才能生成第一个 token,有助于优化 TTFT。但系统中通常同时存在 prefill 和 decode 请求,如果某个请求的 prefill 很长,它会持续占用 token budget,拖慢其他请求的 decode。
NanoVLLM 的策略是:只要 waiting 里还有请求,就先做 prefill;只有 waiting 清空后,才做 running 里的 decode。
NanoVLLM 采用 prefill-first + FIFO + 最简 chunked prefill,主要优点是实现简单,并且能优先降低 TTFT。
- 对新请求来说,只有 prefill 完成后才能生成第一个 token。prefill-first 会尽快把 prompt 的 KV cache 写好,让请求更快进入 decode,所以首 token 延迟更低。
- FIFO 让行为很直观:谁先来谁先处理,没有复杂优先级、预算分配、抢占排序。
- 最简 chunked prefill 可以避免超长 prompt 永久卡住队头。如果一个 prompt 长度超过
max_num_batched_tokens,不切的话它永远进不了 batch;切成多轮后,至少系统能继续推进。
代价是 decode 稳定性差:只要 waiting 队列持续有请求,调度器就会一直做 prefill,running 里的 decode 请求会被阻塞,导致 ITL 变差、decode 吞吐抖动。
Trade-off:NanoVLLM 选择的是“简单、低 TTFT、可推进”的调度,而不是生产级的 decode 公平性和吞吐稳定性。
block_table¶
if not seq.block_table:
num_cached_blocks = self.block_manager.can_allocate(seq)
if num_cached_blocks == -1:
break
num_tokens = seq.num_tokens - num_cached_blocks * self.block_size
else:
num_tokens = seq.num_tokens - seq.num_cached_tokens
block_table 为空时:新序列第一次进调度,或者被 preempt 过(deallocate 把 block_table 清掉了)。
- 需要先问 BlockManager “空块够吗、能命中多少 prefix cache”,返回
-1就是现有空块不够,本轮做不了,直接停。 - 够就算出还要处理多少 token:总长减掉 cache 命中的部分。
block_table 不为空只有一种情况:chunked prefill 的续跑。
- 序列上轮跑了一部分没跑完,没弹出
waiting,block_table还是上轮分的那些块。 - 不需要再问 BlockManager,直接拿
num_cached_tokens算剩余。 - 区别在于不再检查空块够不够:块已经分过了,只是没填满。
chunked prefill¶
if remaining < num_tokens and scheduled_seqs: # only allow chunked prefill for the first seq
break
...
seq.num_scheduled_tokens = min(num_tokens, remaining)
remaining < num_tokens:本轮 token budget 装不下这条序列需要的全部 token。scheduled_seqs 不为空意味着本轮已经有其他序列在跑。两个条件同时为真就停——后面的序列等下一轮。
只有第一个序列(scheduled_seqs 空)时,这个 break 不会触发,即使 token 不够也继续往下走,用 min(num_tokens, remaining) 切一块。这就是只对队头做 chunk。
prefill 结束条件¶
假设 max_num_batched_tokens=20,A(25 token)、B(8)、C(6) 依次入队:
- 第一轮:A 队头,需 25 但
remaining=20,scheduled_seqs空 → chunk 切 20,budget 用光,A 剩 5 token 留在waiting。 - 第二轮:A 队头,需 5 ≤
remaining=20→ 全跑,弹出进running。剩 15 budget,B(8)、C(6) 都塞下,waiting清空。 - 第三轮起:
waiting空,走 decode,每序列每轮各 1 token。
prefill 的 while 有四个退出条件:
while self.waiting and len(scheduled_seqs) < self.max_num_seqs:
# ...
break # ← 1, 2, 3
# ...
# loop end # ← 4
remaining == 0:budget 用光,比如 A 切 20 token 后剩余 0,下一轮继续。can_allocate返回-1:空闲 block 不够。- chunk 闸门触发:
remaining < num_tokens and scheduled_seqs,非队头长序列。 waiting空了或scheduled_seqs达上限:while条件自然结束。
while 结束后看 scheduled_seqs:
scheduled_seqs 是本轮要跑的序列集合,决定当前 step 进入 prefill 还是 decode。
decode¶
while not self.block_manager.can_append(seq):
if self.running:
self.preempt(self.running.pop())
else:
self.preempt(seq)
break
else:
seq.num_scheduled_tokens = 1
seq.is_prefill = False
self.block_manager.may_append(seq)
scheduled_seqs.append(seq)
当 block 不够的时候,优先挤掉 running 队尾其他 sequence 的 KV;如果 running 只剩自己,就挤掉自己并退回去重做 prefill。
挤别人时,当前 sequence 可以继续跑,被挤的人承担代价;挤自己时,当前 sequence 自己承担代价。谁被 preempt,谁重来。
- block 采用懒分配:只在真正需要新 block 时才分配。每个 decode step 追加一个 token,每
block_size个 token 才会跨一次 block 边界,所以需要分配新 block 的概率约为 \(\frac{1}{\text{block\_size}}\)。
被挤掉的序列通过 preempt 回到 waiting 队头,KV cache 被回收,下一轮重新 prefill:
def preempt(self, seq: Sequence):
seq.status = SequenceStatus.WAITING
seq.is_prefill = True
self.block_manager.deallocate(seq)
self.waiting.appendleft(seq)
Trade-off:用低概率的 preempt 换更高的显存利用率。
postprocess¶
模型前向跑完后,LLMEngine.step() 把结果交给 postprocess()。
hash_blocks() 每轮都调,但内部 start == end 时会直接 return,相当于空操作。只有整块填满的少数轮次才真正跑 compute_hash 并注册进 hash_to_block_id,大部分轮次不会产生额外工作。
Source Code¶
from collections import deque
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence, SequenceStatus
from nanovllm.engine.block_manager import BlockManager
class Scheduler:
def __init__(self, config: Config):
self.max_num_seqs = config.max_num_seqs
self.max_num_batched_tokens = config.max_num_batched_tokens
self.eos = config.eos
self.block_size = config.kvcache_block_size
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
self.waiting: deque[Sequence] = deque() # 等待 prefill 的序列
self.running: deque[Sequence] = deque() # 正在 decode 的序列
def is_finished(self):
return not self.waiting and not self.running
def add(self, seq: Sequence):
self.waiting.append(seq)
def schedule(self) -> tuple[list[Sequence], bool]:
scheduled_seqs = []
num_batched_tokens = 0
# prefill 阶段:从 waiting 取序列,受 token 总量和 seq 数约束
while self.waiting and len(scheduled_seqs) < self.max_num_seqs:
seq = self.waiting[0]
remaining = self.max_num_batched_tokens - num_batched_tokens
if remaining == 0:
break
if not seq.block_table:
num_cached_blocks = self.block_manager.can_allocate(seq)
if num_cached_blocks == -1:
break
num_tokens = seq.num_tokens - num_cached_blocks * self.block_size
else:
num_tokens = seq.num_tokens - seq.num_cached_tokens
if remaining < num_tokens and scheduled_seqs: # only allow chunked prefill for the first seq
break
if not seq.block_table:
self.block_manager.allocate(seq, num_cached_blocks)
seq.num_scheduled_tokens = min(num_tokens, remaining)
num_batched_tokens += seq.num_scheduled_tokens
if seq.num_cached_tokens + seq.num_scheduled_tokens == seq.num_tokens:
seq.status = SequenceStatus.RUNNING
self.waiting.popleft()
self.running.append(seq)
scheduled_seqs.append(seq)
if scheduled_seqs:
return scheduled_seqs, True # 本轮做 prefill
# decode 阶段:waiting 空了,从 running 取序列,每序列 1 个 token
while self.running and len(scheduled_seqs) < self.max_num_seqs:
seq = self.running.popleft()
while not self.block_manager.can_append(seq):
if self.running:
self.preempt(self.running.pop())
else:
self.preempt(seq)
break
else:
seq.num_scheduled_tokens = 1
seq.is_prefill = False
self.block_manager.may_append(seq)
scheduled_seqs.append(seq)
assert scheduled_seqs
self.running.extendleft(reversed(scheduled_seqs)) # 放回队头,保持 FIFO
return scheduled_seqs, False # 本轮做 decode
def preempt(self, seq: Sequence):
# 回收块的 KV cache,重新排队等 prefill
seq.status = SequenceStatus.WAITING
seq.is_prefill = True
self.block_manager.deallocate(seq)
self.waiting.appendleft(seq)
def postprocess(self, seqs: list[Sequence], token_ids: list[int], is_prefill: bool):
for seq, token_id in zip(seqs, token_ids):
self.block_manager.hash_blocks(seq) # 新填满的块打 hash 注册
seq.num_cached_tokens += seq.num_scheduled_tokens
seq.num_scheduled_tokens = 0
if is_prefill and seq.num_cached_tokens < seq.num_tokens:
continue # chunked prefill 未完成,不追加 token
seq.append_token(token_id)
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
seq.status = SequenceStatus.FINISHED
self.block_manager.deallocate(seq)
self.running.remove(seq)
Last update: May 5, 2026
Discussion