【Picotron-Tutorial】上下文并行

理论分析

上下文并行核心思想是将序列并行的方法(也就是沿序列长度进行拆分)的思路应用到已经采用张量并行的模块上

对于上下文并行,就像序列并行一样,我们将沿序列维度拆分输入,但这次我们对整个模型进行拆分,而不仅仅是对之前Tensor+Sequence并行中涉及的部分模型。

拆分序列是横向的切割,所以不会影响大多数模块,如MLP和LayerNorm,因为它们对每个token的处理是独立的。

在计算梯度后,会启动一次all-reduce操作以在上下文并行组内同步梯度。

不过注意力模块需要特别注意,在注意力模块中,每个token需要访问来自所有其他序列token的键/值对,或者在因果注意力的情况下,至少需要关注每个前面的token。

由于上下文并行是沿序列维度将输入分布到各个GPU上,注意力模块将需要各个GPU之间进行充分通信,以交换必要的键/值数据。为了降低通信的影响衍生了各种同步方式。

环形注意力(Ring Attention)

假设我们有4个GPU和4个token的输入。最初,输入序列沿序列维度均匀拆分,因此每个GPU仅拥有一个token及其对应的Q/K/V值。

假设Q1、K1和V1分别表示第一个token的查询、键和值,并且它们位于第1个GPU上。

注意力计算需要4个时间步来完成。在每个时间步中,每个GPU依次执行以下三个操作:

  1. 以非阻塞的方式将“当前的键和值”发送给下一台机器(在非阻塞模式下的最后一个时间步除外),以便在此步骤尚未完成时即可开始下一步骤

  2. 在本地对已拥有的“当前键和值”计算注意力得分,这通常涉及执行 $$\frac{Softmax(QK^T)}{\sqrt[]{d}}∗V$$

  3. 等待接收来自上一台GPU的键和值,然后返回到步骤1,此时“当前的键和值”即为刚刚从上一台GPU接收到的键/值对。

不过有一个大问题,那就是环形注意力的简单实现会导致因果注意力矩阵形状造成的GPU间工作不平衡。让我们通过考虑带有因果注意力掩码的注意力得分矩阵来观察Softmax计算:

Softmax是按行计算的,这意味着每当一个GPU收到一整行的所有token时,就可以进行计算。我们看到GPU1可以立即计算,因为它一开始就拥有token 1-4,而GPU1实际上不需要从其他GPU接收任何信息。然而,GPU2需要等待第二轮,才能收到token 1-4,从而获得token 1-8的所有值。同时,GPU1的工作量明显比其他GPU要少得多。

之字形环形注意力(Zig-Zag Ring Attention) – 平衡计算的实现

为了更好地平衡计算负载,我们需要一种更好的方式来分配输入序列。

这可以通过不将token纯粹顺序地分配给各个GPU,而是稍微混合一下顺序,从而使每个GPU上都有较早和较晚的token。这种方法被称为之字形注意力,在这种新排列中,注意力掩码将显示出计算分布较为均匀。

与此同时,我们也会看到,为了完成所有行的计算,每个GPU都需要来自其他所有GPU的信息。

我们有两种常见方式来重叠计算和通信:一种是通过执行一次通用的all-gather操作,同时在每个GPU上重新组合所有KV(类似于Zero-3的方式);另一种是根据需要从每个GPU逐个收集KV对:

这两种实现方式的关键区别在于它们的通信模式和内存使用:

  1. AllGather实现:
  • 所有GPU同时收集来自其他所有GPU的完整键/值对

  • 需要更多的临时内存,因为每个GPU需要一次性存储完整的KV对

  • 通信在一步内完成,但伴随较大的内存开销

  • 全对全(环形)实现:
  • GPU以环形模式交换KV对,每次传输一个数据块

  • 更节省内存,因为每个GPU只需临时存储一个数据块

  • 通信被分散并与计算重叠,尽管由于多次通信步骤会带来一些额外的基础延迟

全对全方法通常在内存效率上更优,但其通信模式稍显复杂;而AllGather方法则更简单,但在注意力计算过程中需要更多的临时内存。

代码分析

ContextCommunicate

  • 初始化时有点对点的send_rank和recv_rank。

  • 对于send_recv函数,会创建一个异步的send_operation和recv_operation。然后会将这两个添加到_pending_operations中。

  • 对于commit函数会批量提交所有_pending_operations中待处理的操作。

  • 对于wait函数,它会等待所有已提交的任务都完成。

RingAttentionFunc

这里主要实现的是环形注意力机制。

  • 如果有Context parrallel那么就会通过apply_context_parallel设置环境变量”CONTEXT_PARALLEL”,然后根据环境变量在attention计算的时候如果是Context parrallel就会执行ring_attention(q, k, v, sm_scale, is_causal)函数。

  • 对于forward,它会遍历所有的world_size来依次进行处理:

    • 如果不是最后一次,那么就对k、v执行send_recv,得到next_k、next_v,然后commit。

    • 如果step <= comm.rank,那就说明当前的数据是在attention计算中是需要的。那么就调用ring_attention_forward来执行部分序列数据的标准的attention的计算

    • 然后调用update_out_and_lse来执行对这种部分序列的attention的累计更新

    • 如果不是最后一次,那么就将k、v替换为刚通过网络传输得到的next_k和next_v

  • 对于backward,首先会创建两个通信组ContextCommunicate,主要用于 K, V 环形通信的通信对象和 K, V 梯度的环形通信的通信对象,同样也是遍历world_size次来依次处理

    • 如果不是最后一次,那么就对k、v执行send_recv,得到next_k、next_v,然后commit。

    • 如果当前step<=kv_comm.rank,那么就需要通过ring_attention_backward计算梯度,这里实际上是在手动计算反向传播出来的梯度。

    • 得到梯度dq,还需要等待d_kv_comm,得到dk、dv

    • 更新k、v为next_k、next_v

    • 将dq、dv通过send_recv发送给下一个并进行接收

update_rope_for_context_parallel

由于现在每个GPU只有一部分的序列,所以在计算位置编码的时候不能依赖原始的位置,而是要加上当前rank之前的的。


【Picotron-Tutorial】上下文并行
http://example.com/2025/06/28/Picotron-Tutorial Context parallel/
作者
滑滑蛋
发布于
2025年6月28日
许可协议