09. 显存优化有哪些常见手段?

整理推理部署中的常见显存优化手段。

简单回答

显存占用主要来自三部分:模型参数、KV Cache、以及计算过程中的中间激活值。对应的优化手段包括:量化(减小参数和 KV Cache 的精度)、PagedAttention(减少 KV Cache 的碎片和浪费)、GQA/MQA(从架构层面减少 KV Cache 大小)、FlashAttention(减少中间激活值的显存占用)、Offloading(把部分数据卸载到 CPU 内存或磁盘)、梯度检查点(训练场景,用计算换显存)、以及模型并行(把模型分到多张 GPU 上)。

详细解释

显存占用分析

优化显存要先搞清楚显存被谁占了。推理场景下主要有两大块。

模型参数是固定开销。7B 模型 FP16 约 14GB,70B 约 140GB。INT4 量化后分别降到约 3.5GB 和 35GB。参数在整个服务运行期间常驻显存。

KV Cache 是动态开销,随着并发请求数和序列长度增长。前面分析过,单个请求在 LLaMA-7B 上 4096 长度的 KV Cache 约 2GB。并发 16 个请求就是 32GB。KV Cache 通常是推理场景显存的最大消耗者——它决定了系统的最大并发数。

训练场景还有额外的大头:优化器状态(Adam 每个参数额外存两个状态,FP32 下是参数的 8 倍)、梯度(和参数同样大小)、中间激活值(正向传播的中间结果,反向传播需要用到)。这也是为什么训练比推理要消耗多得多的显存。

各优化手段详解

量化前面详细讲过了。参数量化直接减小模型参数的显存占用。KV Cache 也可以量化——从 FP16 量化到 INT8 或 FP8,KV Cache 大小减半,能支持的并发翻倍。

PagedAttention减少的是 KV Cache 的碎片和浪费,不减少每个 token 的 KV 实际大小,但把显存利用率从二三十提升到九十以上,等效于"多出"两三倍的可用显存。

GQA/MQA是从模型架构层面减少 KV Cache。GQA 让多个 Q 头共享 KV 头,KV 头数从 32 降到 8,KV Cache 缩小到 1/4。这个优化需要在模型训练时就确定,推理时无法改变。

FlashAttention减少的是 Attention 计算过程中的中间激活值占用。标准 Attention 需要 materialize 整个 的注意力矩阵( 是序列长度),FlashAttention 通过分块计算避免了这个 的显存开销。对推理影响最大的是 Prefill 阶段(Decode 阶段序列长度为 1 没有这个问题)。

Offloading是把部分数据从 GPU 显存卸载到 CPU 内存甚至磁盘。比如不活跃的 KV Cache 块 offload 到 CPU,需要时再搬回来。MoE 模型可以把不活跃的专家 offload 到 CPU。Offloading 的代价是增加了数据搬运延迟(PCIe 带宽远低于 HBM),所以通常用在对延迟容忍度较高的场景,或者用流水线手段把搬运和计算重叠。

梯度检查点(Gradient Checkpointing / Activation Recomputation) 是训练场景的优化。正向传播时不保存所有中间激活值,反向传播需要时重新计算。用约 33% 的额外计算换约 60~70% 的显存节省。是训练大模型的标配技巧。

模型并行(TP/PP)把模型分到多张 GPU 上,每张 GPU 只负责一部分参数和对应的 KV Cache。这是"用更多 GPU"来解决单 GPU 显存不足的问题。

实际部署的优先级

推理部署中,优化显存的优先级通常是:先量化(效果最直接,INT8 几乎无损)→ 再用 PagedAttention(升级推理框架到 vLLM)→ 确认模型是否用了 GQA(如果没有且可以换模型就选 GQA 模型)→ KV Cache 量化 → Offloading(最后手段)。

如果这些都做了显存还是不够,要么加 GPU 卡数做 TP,要么换更小的模型。

面试时可以这样答

推理场景显存主要被两块占——模型参数和 KV Cache。KV Cache 通常是更大的瓶颈,因为它随并发数和序列长度线性增长。

优化手段按优先级来说:量化是首选——模型参数 INT4/INT8,KV Cache INT8/FP8,直接砍掉一半到四分之三。PagedAttention 把 KV Cache 的显存利用率从二三十提到九十以上,等效翻倍。GQA 从架构层面缩小 KV Cache,选模型时就应该优先选 GQA 模型。FlashAttention 减少 Prefill 阶段的中间激活值占用。实在不够就 Offloading 到 CPU,或者加卡做张量并行。

训练场景还有梯度检查点、混合精度训练、ZeRO 优化器等,用计算或通信换显存。

常见追问

  1. KV Cache 量化对生成质量的影响有多大?
  2. FlashAttention 具体是怎么减少显存占用的?
  3. 你实际部署中显存是怎么分配和规划的?