大模型 kvcache 浅析

需要清楚知道它的限制,即限制在:

  • 推理阶段

  • decoder-only架构,单向注意力

推理回顾

假设模型最终生成了“遥遥领先”4个字。

当模型生成第一个“遥”字时,input=”<s>”, “<s>”是起始字符。Attention的计算如下:

为了看上去方便,我们暂时忽略scale项根号d, 但是要注意这个scale面试时经常考。

如上图所示,最终Attention的计算公式如下,(softmaxed 表示已经按行进行了softmax):

以此类推生成第四个字的时候如下:

公式为:

需要注意的一点在于QK^T这个矩阵的长宽都等于历史字符的数量,随着对话轮数的增加,这部分空间暂用是十分恐怖的

KV Cache

这里有一个特点,如果是对于一般的自回归系统,我们在生成下一个字的时候往往需要参考过去所有的字,但是实际上因为mask的存在,我们的Attention_4只需要拿当前新增字符的Q与过去生成的字符(也包括自己)的K、V相乘就可得到,所以我们可以将过去的K、V进行缓存以进行更高速的计算

下图展示了使用KV Cache和不使用的对比

而在实际实现时的做法就较为简单了,直接和缓存起来的过去key和value与新生成的字符的key和value拼接就可以得到新的key和value。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)

if use_cache is True:
present = (key, value)
else:
present = None

if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

显存代价

KV cache实际上可以认为是一种以空间换时间的操作,其内存消耗为:

$$2 \times batch \times context_length \times n_layers \times n_head \times d_heads \times Pa$$

  • 2代表是k与v两类缓存

  • **batch**:批量大小(batch size),表示一次训练或推理过程中输入的样本数量。

  • **context_length**:上下文长度,也就是输入序列的长度。在语言模型中,通常对应于文本的词数或标记的数量。

  • **n_layers**:模型的层数,即 Transformer 模型中的层数,各个层有自己独立的K、V。例如,BERT 或 GPT 中的层数。

  • **n_heads**:注意力头的数量。在多头注意力机制中,模型会将注意力计算分为多个头,每个头独立地进行计算,最后将结果合并。

  • **d_heads**:每个注意力头的维度。在多头注意力中,每个头计算的 keyvalue 的维度,通常是整个模型的嵌入维度(如 hidden_size)除以 n_heads

  • **Pa**:一个常数,用于表示每个 keyvalue 向量的每个元素占用的内存量(如浮点数的字节数)。通常,如果是 32 位浮点数,则为 4 字节;如果是 16 位浮点数,则为 2 字节。

以一个batch_size=32, context_length=2048, n_layer=32, n_head=32, d_head=128, float32类型,则需要占用的显存为: 2 * 32 * 2048 * 32 * 32 * 4096 * 4 / 1024/1024/1024 = 64G。

优化速度

有过一些实验,对于hugging face等推理库:

  • 使用kvcache耗时11s

  • 不使用kvcache耗时56s

参考资料


大模型 kvcache 浅析
http://example.com/2025/04/25/kvcacheStudy/
作者
滑滑蛋
发布于
2025年4月25日
许可协议