02. KV Cache 的作用是什么?为什么它能加速推理?
整理 KV Cache 的原理、显存开销与工程管理。
简单回答
KV Cache 是在自回归生成过程中缓存已计算的 Key 和 Value 矩阵,避免每生成一个新 token 时重新计算之前所有 token 的 KV。没有 KV Cache,生成第 N 个 token 需要重算前 N-1 个 token 的 KV,计算量是 ;有了 KV Cache,只需要计算新 token 的 KV,计算量降为 。代价是显存占用随序列长度线性增长,长上下文场景下 KV Cache 可能占用几 GB 甚至几十 GB 的显存。
详细解释
KV Cache 的原理
在 Transformer 的 Self-Attention 中,对于输入序列中的每个 token,需要计算它的 Query(Q)、Key(K)、Value(V),然后做注意力计算:
在自回归生成中,每次只新增一个 token。假设我们正在生成第 个 token,它的 Attention 计算需要:新 token 的 Q(),以及位置 1 到 所有 token 的 K 和 V。
问题在于:位置 1 到 的 K 和 V 在生成前一个 token 时已经算过了。如果不缓存,每次都要重新算,相当于每生成一个 token 就把之前所有 token 从头算一遍——计算量会按序列长度的平方增长。
KV Cache 的做法很直接——把每一步计算过的 K 和 V 保存下来。生成第 个 token 时,只需要计算新 token 的 ,把 追加到缓存的 K 和 V 矩阵末尾,然后用 和完整的 K 做 attention 计算。
显存占用分析
KV Cache 的大小取决于模型结构和序列长度。对于一个标准的 Transformer 模型:
其中 是层数, 是注意力头数, 是每个头的维度, 是序列长度,2 代表 K 和 V 两个矩阵。以 LLaMA-2-7B(32 层,32 头,每头维度 128)为例,FP16 下每个 token 的 KV Cache 占用约 。序列长度 4096 时就是约 2GB——一个请求就要占 2GB 显存用于 KV Cache。
在 batch serving 场景下,如果同时服务 16 个请求,KV Cache 就是 32GB,可能超过模型参数本身的显存占用。这就是为什么 KV Cache 管理是推理系统的核心挑战之一。
GQA / MQA 对 KV Cache 的影响
标准的 MHA(Multi-Head Attention)中每个头都有独立的 K 和 V。GQA(Grouped Query Attention)让多个 Q 头共享一组 K、V 头,MQA(Multi-Query Attention)更极端——所有 Q 头共享同一组 K、V。
共享 KV 头直接减少 KV Cache 的大小。LLaMA-2-70B 用了 GQA,8 个 KV 头对应 64 个 Q 头,KV Cache 大小只有 MHA 的 1/8。这是现代大模型广泛采用 GQA 的重要原因之一——不是为了减少计算量(KV 的计算量本来就不大),而是为了减少 KV Cache 占用的显存,从而支持更大的 batch 或更长的序列。
KV Cache 的工程管理
KV Cache 的管理是推理框架的核心功能之一。vLLM 的 PagedAttention 就是解决 KV Cache 管理问题的——把 KV Cache 分成固定大小的"页",按需分配和释放,避免预分配最大长度导致的显存浪费。后面 vLLM 那道题会详细讲。
还有一些 KV Cache 的压缩技术:量化 KV Cache(把 KV Cache 从 FP16 压缩到 INT8 甚至 INT4,减少一半到四分之三的显存占用);KV Cache 驱逐策略(对于超长序列,按注意力分数淘汰不太重要的 KV 条目);以及把部分 KV Cache offload 到 CPU 内存。
面试时可以这样答
KV Cache 的作用是避免自回归生成中的重复计算。每生成一个新 token,需要和之前所有 token 做 Attention,如果不缓存 KV,每步都要重算前面所有 token 的 KV,计算量按序列长度平方增长。KV Cache 把已经算过的 KV 存下来,每步只需要算新 token 的 KV 并追加到缓存里,计算量降到线性。
代价是显存占用。以 LLaMA-7B 为例,每个 token 的 KV Cache 约 512KB,4096 长度就是 2GB。Batch serving 时几十个请求同时在跑,KV Cache 占用可能超过模型参数本身。
这也是为什么 GQA 被广泛采用——不是为了减少计算量,而是为了减少 KV Cache 的显存占用。LLaMA-2-70B 用 GQA 把 KV Cache 缩小到 MHA 的八分之一。vLLM 的 PagedAttention 也是解决 KV Cache 管理问题的,把 KV Cache 按页分配避免显存碎片和浪费。
常见追问
- KV Cache 量化到 INT8 对生成质量有多大影响?
- 如果序列长度是 128K,KV Cache 的显存占用怎么算?
- 为什么 Prefill 阶段不需要 KV Cache?