Block Manager¶
约 1257 个字 226 行代码 预计阅读时间 7 分钟
block_manager.py 是 NanoVLLM 的 KV cache 块管理器。它不保存真正的 K/V 张量,只在 CPU 侧维护物理块元数据:哪些 block 空闲、哪些 block 正在被序列引用、哪些完整 token block 可以作为 prefix cache 复用。
核心目标有三个:
- 为 prefill 阶段分配
seq.block_table,把一条逻辑 sequence 映射到一组物理 KV cache block。 - 对已经完整写入的 token block 计算链式 hash,支持后续请求命中 prefix cache。
- 在 sequence 结束或被 preempt 时回收 block,同时维护引用计数和空闲队列。
Block¶
Block 是一个物理 KV cache block 的元数据对象。真正的 K/V 数据在 model_runner.kv_cache 这个 GPU 张量里,Block 只记录这块“是谁、被几个人引用、内容 hash 是什么”。
class Block:
def __init__(self, block_id):
self.block_id = block_id
self.ref_count = 0
self.hash = -1
self.token_ids = []
四个字段:
block_id:物理 block 编号,对应 KV cache 大张量中的 block 下标。ref_count:引用计数。prefix cache 命中时,多条 sequence 的block_table可以指向同一个物理 block。hash:完整 token block 的链式 hash,用来索引 prefix cache。token_ids:这个 block 覆盖的 token id 列表,用于 hash 冲突后的二次校验。
两个方法:
def update(self, hash: int, token_ids: list[int]):
self.hash = hash
self.token_ids = token_ids
def reset(self):
self.ref_count = 1
self.hash = -1
self.token_ids = []
update():某个 block 在 prefill 中被完整写满后,记录它的 hash 和 token 内容。reset():从空闲池重新分配 block 时调用,默认把ref_count设为 1,表示新分配后立刻被当前 sequence 持有。
BlockManager¶
BlockManager 管理全部物理 block 的生命周期。
class BlockManager:
def __init__(self, num_blocks: int, block_size: int):
self.block_size = block_size
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: dict[int, int] = dict()
self.free_block_ids: deque[int] = deque(range(num_blocks))
self.used_block_ids: set[int] = set()
几个核心结构:
blocks:所有物理 block 的元数据数组,block_id就是数组下标。hash_to_block_id:prefix cache 的索引,记录“某个链式 hash 对应哪个物理 block”。free_block_ids:当前可分配的空闲 block 队列。used_block_ids:当前正在被至少一条 sequence 引用的 block 集合。
链式 hash¶
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
h = xxhash.xxh64()
if prefix != -1:
h.update(prefix.to_bytes(8, "little"))
h.update(np.array(token_ids).tobytes())
return h.intdigest()
这里不是只 hash 当前 block 的 token,而是把上一个 block 的 hash 也混进去:
这样 h2 不只代表第 2 个 block 的内容,还隐含了完整前缀 block0 + block1 + block2。这对 prefix cache 很关键:两个请求只有从开头到当前 block 都一致,才应该共享这块 KV cache。
can_allocate¶
can_allocate() 在真正分配前做两件事:
- 从头扫描 sequence 的完整 token block,尽可能命中 prefix cache。
- 计算剩余需要新分配的 block 数,判断空闲 block 是否足够。
def can_allocate(self, seq: Sequence) -> int:
h = -1
num_cached_blocks = 0
num_new_blocks = seq.num_blocks
for i in range(seq.num_blocks - 1):
token_ids = seq.block(i)
h = self.compute_hash(token_ids, h)
block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
break
num_cached_blocks += 1
if block_id in self.used_block_ids:
num_new_blocks -= 1
if len(self.free_block_ids) < num_new_blocks:
return -1
return num_cached_blocks
几个细节:
- 只遍历
range(seq.num_blocks - 1),也就是不拿最后一个 block 做 prefix cache 命中。最后一个 block 很可能是不满的尾块,NanoVLLM 只缓存完整 block。 - 命中 hash 后还要比较
token_ids,这是为了避免 hash 冲突导致错误复用。 - 如果命中的 block 当前已经在
used_block_ids里,说明它正在被别的 sequence 使用;复用它只需要增加ref_count,不消耗新的空闲 block,所以num_new_blocks -= 1。 - 返回
-1表示空闲 block 不够,本轮不能调度这条 sequence;返回非负数表示可复用的 prefix block 数。
allocate¶
allocate() 真正填充 seq.block_table。
def allocate(self, seq: Sequence, num_cached_blocks: int):
assert not seq.block_table
h = -1
for i in range(num_cached_blocks):
token_ids = seq.block(i)
h = self.compute_hash(token_ids, h)
block_id = self.hash_to_block_id[h]
block = self.blocks[block_id]
if block_id in self.used_block_ids:
block.ref_count += 1
else:
block.ref_count = 1
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id)
seq.block_table.append(block_id)
for i in range(num_cached_blocks, seq.num_blocks):
seq.block_table.append(self._allocate_block())
seq.num_cached_tokens = num_cached_blocks * self.block_size
分两段处理:
- 前
num_cached_blocks个 block:从hash_to_block_id找到已有物理 block,写进seq.block_table。 - 后面的 block:从
free_block_ids里分配新的物理 block。
seq.num_cached_tokens = num_cached_blocks * self.block_size 表示这些 token 的 K/V 已经在 cache 中,不需要重新 prefill。
_allocate_block 与 _deallocate_block¶
def _allocate_block(self) -> int:
block_id = self.free_block_ids.popleft()
block = self.blocks[block_id]
assert block.ref_count == 0
if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
del self.hash_to_block_id[block.hash]
block.reset()
self.used_block_ids.add(block_id)
return block_id
def _deallocate_block(self, block_id: int):
assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
_allocate_block() 有一个容易忽略的点:如果这个空闲 block 之前带着 hash,并且 hash_to_block_id 还指向它,那么重新分配前要删掉旧 hash 映射。否则 prefix cache 可能命中一块已经被改写的物理 block。
_deallocate_block() 只在 ref_count == 0 时调用,把 block 从 used_block_ids 移回 free_block_ids。
deallocate¶
def deallocate(self, seq: Sequence):
for block_id in reversed(seq.block_table):
block = self.blocks[block_id]
block.ref_count -= 1
if block.ref_count == 0:
self._deallocate_block(block_id)
seq.num_cached_tokens = 0
seq.block_table.clear()
回收时反向遍历 seq.block_table,对每个物理 block 做引用计数减一。只有 ref_count 归零时,这个 block 才会真正回到空闲池。
这里不会立刻清空 block.hash 和 block.token_ids。这意味着一个不再被使用的 block 仍然可能作为 prefix cache 的候选存在;只要它还没被重新分配覆盖,后续请求仍可以通过 hash_to_block_id 找到并复用它。
decode 追加 block¶
decode 阶段每轮只追加一个新 token。只有当序列长度跨过 block 边界时,才需要分配新的物理 block。
def can_append(self, seq: Sequence) -> bool:
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
def may_append(self, seq: Sequence):
if len(seq) % self.block_size == 1:
seq.block_table.append(self._allocate_block())
为什么是 len(seq) % self.block_size == 1?
may_append() 在模型前向前调用,此时 seq 还没有 append 本轮新 token。若当前长度模 block size 等于 1,说明当前最后一个 token 已经是新 block 的第一个位置,block_table 需要先补上这个新 block,后续 KV 才有地方写。
hash_blocks¶
def hash_blocks(self, seq: Sequence):
start = seq.num_cached_tokens // self.block_size
end = (seq.num_cached_tokens + seq.num_scheduled_tokens) // self.block_size
if start == end: return
h = self.blocks[seq.block_table[start - 1]].hash if start > 0 else -1
for i in range(start, end):
block = self.blocks[seq.block_table[i]]
token_ids = seq.block(i)
h = self.compute_hash(token_ids, h)
block.update(h, token_ids)
self.hash_to_block_id[h] = block.block_id
hash_blocks() 在 postprocess() 中调用,用来给本轮新写满的完整 block 打 hash。
start:本轮前已经缓存到第几个 block。end:本轮后已经缓存到第几个 block。start == end:说明本轮没有新写满任何完整 block,直接返回。h = self.blocks[seq.block_table[start - 1]].hash:如果不是从第一个 block 开始,就接上前一个 block 的链式 hash。
只有完整 block 会被注册进 hash_to_block_id,尾部不满的 block 不会进入 prefix cache。
Source Code¶
from collections import deque
import xxhash
import numpy as np
from nanovllm.engine.sequence import Sequence
class Block:
def __init__(self, block_id):
self.block_id = block_id
self.ref_count = 0
self.hash = -1
self.token_ids = []
def update(self, hash: int, token_ids: list[int]):
self.hash = hash
self.token_ids = token_ids
def reset(self):
self.ref_count = 1
self.hash = -1
self.token_ids = []
class BlockManager:
def __init__(self, num_blocks: int, block_size: int):
self.block_size = block_size
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: dict[int, int] = dict()
self.free_block_ids: deque[int] = deque(range(num_blocks))
self.used_block_ids: set[int] = set()
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
h = xxhash.xxh64()
if prefix != -1:
h.update(prefix.to_bytes(8, "little"))
h.update(np.array(token_ids).tobytes())
return h.intdigest()
def _allocate_block(self) -> int:
block_id = self.free_block_ids.popleft()
block = self.blocks[block_id]
assert block.ref_count == 0
if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
del self.hash_to_block_id[block.hash]
block.reset()
self.used_block_ids.add(block_id)
return block_id
def _deallocate_block(self, block_id: int):
assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
def can_allocate(self, seq: Sequence) -> int:
h = -1
num_cached_blocks = 0
num_new_blocks = seq.num_blocks
for i in range(seq.num_blocks - 1):
token_ids = seq.block(i)
h = self.compute_hash(token_ids, h)
block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
break
num_cached_blocks += 1
if block_id in self.used_block_ids:
num_new_blocks -= 1
if len(self.free_block_ids) < num_new_blocks:
return -1
return num_cached_blocks
def allocate(self, seq: Sequence, num_cached_blocks: int):
assert not seq.block_table
for i in range(num_cached_blocks):
token_ids = seq.block(i)
h = self.compute_hash(token_ids, h)
block_id = self.hash_to_block_id[h]
block = self.blocks[block_id]
if block_id in self.used_block_ids:
block.ref_count += 1
else:
block.ref_count = 1
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id)
seq.block_table.append(block_id)
for i in range(num_cached_blocks, seq.num_blocks):
seq.block_table.append(self._allocate_block())
seq.num_cached_tokens = num_cached_blocks * self.block_size
def deallocate(self, seq: Sequence):
for block_id in reversed(seq.block_table):
block = self.blocks[block_id]
block.ref_count -= 1
if block.ref_count == 0:
self._deallocate_block(block_id)
seq.num_cached_tokens = 0
seq.block_table.clear()
def can_append(self, seq: Sequence) -> bool:
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
def may_append(self, seq: Sequence):
if len(seq) % self.block_size == 1:
seq.block_table.append(self._allocate_block())
def hash_blocks(self, seq: Sequence):
start = seq.num_cached_tokens // self.block_size
end = (seq.num_cached_tokens + seq.num_scheduled_tokens) // self.block_size
if start == end: return
h = self.blocks[seq.block_table[start - 1]].hash if start > 0 else -1
for i in range(start, end):
block = self.blocks[seq.block_table[i]]
token_ids = seq.block(i)
h = self.compute_hash(token_ids, h)
block.update(h, token_ids)
self.hash_to_block_id[h] = block.block_id
Last update: May 12, 2026
Discussion