【Picotron-Tutorial】数据并行

原生数据并行

理论分析

在原生的数据并行中,每个数据并行的组都会自己处理自己的数据,这带来的一个问题在于我们需要及时同步训练过程中的梯度以及优化器的状态。

最原生的方法就是我们在前向传播后,在对每一个层进行反向传播后进行一次同步,如下图所示。由于梯度得到了及时的同步,所有优化器的状态自然也就会变得相同。

代码分析

  1. 修改dataloader为分布式,从而使得每个dp进程每次获取到的数据batch是不相同的,其主要修改是加入DistributedSampler
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
self.sampler = DistributedSampler(
self.tokenized_dataset,
num_replicas=pgm.process_group_manager.dp_world_size,
rank=pgm.process_group_manager.dp_rank,
seed=seed,
shuffle=False
)

super().__init__(
self.tokenized_dataset,
batch_size=micro_batch_size,
collate_fn=self.collate_batch,
pin_memory=True,
num_workers=num_workers,
sampler=self.sampler,
shuffle=False,
)
  • 对于原本的model需要包裹一个DataParallelNaive,即:model = DataParallelNaive(model),这一层包裹会给每一个需要计算梯度的参数注册一个勾子函数,该函数的作用是如果model的require_backward_grad_sync=true,那么就会进行一次all_reduce获取到其他进程上的参数,然后进行平均,得到平均参数,如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
### begin Data Parallel (naive)
class DataParallelNaive(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
# whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
self.require_backward_grad_sync = True
self.register_backward_hook(self._allreduce_grads)

def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)

def register_backward_hook(self, hook):
"""Registers a backward hook for all parameters of the model that require gradients."""
for p in self.module.parameters():
if p.requires_grad is True:
p.register_hook(hook)

def _allreduce_grads(self, grad):
"""Performs an all-reduce operation to synchronize gradients across multiple processes."""
# No synchronization needed during gradient accumulation, except at the final accumulation step.
if self.require_backward_grad_sync:
dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
grad /= pgm.process_group_manager.dp_world_size
return grad
### end Data Parallel (naive)
  • 修改原本的训练进程,添加一段对于model.require_backward_grad_sync赋值的控制,使得在最后一个dataloader.grad_acc_steps时会进行梯度平均。
1
2
if requires_grad_sync:
model.require_backward_grad_sync = (i == dataloader.grad_acc_steps - 1)

带bucket的数据并行

理论分析

对于原生的数据并行,其最大的问题在于每层进行一次反向传播的时候都需要一个网络传输,这导致整体的速度被拖慢了。所以有提出带bucket带数据并行的方案,其特点在于将多层作为一个bucket,然后在反向传播时,只有当bucket中的所有反向传播都结束的时候才进行一次同步,同时设置该同步为异步的同步,这样就不会阻塞整体的反向传播的进程了。

代码分析

主要是需要构建3个类:

  • Bucket

  • BucketManager

  • DataParallelBucket

Bucket

Bucket 类代表一个梯度桶,它管理一组模型参数及其对应的梯度,并负责这些梯度的同步。

  • 其包含了一个grad_data来存储梯度信息,这个grad_data由多个参数的grad拼接而成。

  • 然后有一个params_with_grad_ready来记录哪些对应的参数已经完成了梯度计算,并通过一个函数来支持标记params_with_grad_ready。如果对应的参数都完成梯度计算后,它支持通过异步的all-reduce操作来同步梯度,并支持通过wait函数来等待梯度同步完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# ... existing code ...
class Bucket:
def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor, process_group: torch.distributed.ProcessGroup) -> None:
# params: 这个桶包含的参数集合。
self.params = set(params)
# params_with_grad_ready: 记录桶内哪些参数的梯度已经计算完毕并准备好同步。
self.params_with_grad_ready = set()
# grad_data: 一个预分配的张量,用于存储这个桶内所有参数的梯度。参数的梯度会被拷贝到这个张量中进行 all-reduce。
self.grad_data = grad_data
# process_group: 用于梯度同步的分布式进程组 (通常是数据并行组)。
self.process_group = process_group
self.process_group_size = dist.get_world_size(group=self.process_group)
# handle: 异步 all-reduce 操作的句柄,用于后续等待操作完成。
self.handle = None

self.reset() # 初始化状态

def sync_gradient(self) -> None:
"""发起一个异步的 all-reduce 操作来同步梯度。"""
assert self.handle is None # 确保没有正在进行的同步操作
# 在 all-reduce 求和之前,先将梯度除以进程组大小,这样 all-reduce 之后就直接是平均梯度。
self.grad_data /= self.process_group_size
# 发起异步 all-reduce,对 grad_data 进行求和操作。
self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)

def reset(self) -> None:
"""重置桶的状态,通常在梯度同步完成后调用。"""
self.handle = None # 清除句柄
self.params_with_grad_ready.clear() # 清空已准备好的参数集合
self.grad_data.zero_() # 将梯度存储张量清零,为下一次迭代做准备

def wait(self) -> None:
"""等待 all-reduce 操作完成。"""
assert self.handle is not None, "You should launch an allreduce operation before waiting for it to finish"
self.handle.wait() # 阻塞等待异步操作完成

def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
"""标记一个参数的梯度已准备好进行同步。当桶内所有参数都准备好时,启动梯度同步。"""
assert param in self.params and param not in self.params_with_grad_ready
self.params_with_grad_ready.add(param)
# 如果桶内所有参数的梯度都已准备好
if len(self.params_with_grad_ready) == len(self.params):
self.sync_gradient() # 则开始同步这个桶的梯度
# ... existing code ...

BucketManager

BucketManager 负责将模型的所有参数划分到多个 Bucket 中,并管理这些桶。

  • 用户需要指定每个桶的最大容量,以元素数量记。

  • 会遍历模型中的各个参数,尝试将其放入桶中,如果放入不了就再新建一个桶放入

  • 然后为每个桶创建一个连续内存来存储梯度,并将其与参数的grad进行映射,保障两个修改是同步的

  • 支持标记param梯度计算完毕并将消息传递给对应的桶

  • 支持等待所有的桶都同步完成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# ... existing code ...
class BucketManager:
def __init__(self, params: List[torch.nn.Parameter], process_group: torch.distributed.ProcessGroup, bucket_size: int, grad_type: torch.dtype = torch.float32) -> None:
self.params = list(params) # 模型的所有参数
self.buckets = [] # 存储所有 Bucket 对象的列表
self.process_group = process_group
self.process_group_size = dist.get_world_size(group=self.process_group)
# params_to_bucket_location: 一个字典,映射每个参数到它所在的桶的索引以及在桶内梯度张量中的位置 (start, end, bucket_idx)。
self.params_to_bucket_location = {}
self.bucket_size = bucket_size # 用户指定的每个桶的最大容量 (以元素数量计)。
self.bucket_sizes = None # 实际每个桶的大小
self.grad_data_list = [] # 存储每个桶的梯度数据张量 (grad_data) 的列表。
self.grad_type = grad_type # 梯度的数据类型,通常是 float32 以保证精度。

self._initialize_buckets() # 初始化分桶逻辑

def _initialize_buckets(self) -> None:
"""根据 bucket_size 将模型参数划分到不同的桶中。"""
cur_bucket_size = 0
cur_bucket_idx = 0

# 遍历所有需要梯度的参数,将它们分配到桶中
for param in self.params:
if not param.requires_grad:
continue

num_elements = param.numel() # 参数的元素数量
# 如果当前桶是空的,或者当前参数加入后会超过桶的容量,则创建一个新桶
if cur_bucket_size == 0: # 新桶的第一个参数
self.params_to_bucket_location[param] = (0, num_elements, cur_bucket_idx)
cur_bucket_size = num_elements
elif cur_bucket_size + num_elements > self.bucket_size: # 当前桶放不下,开新桶
cur_bucket_idx += 1
self.params_to_bucket_location[param] = (0, num_elements, cur_bucket_idx)
cur_bucket_size = num_elements
else: # 可以放入当前桶
self.params_to_bucket_location[param] = (cur_bucket_size, cur_bucket_size + num_elements, cur_bucket_idx)
cur_bucket_size += num_elements

# 收集每个桶的实际大小和包含的参数
num_total_buckets = cur_bucket_idx + 1
actual_bucket_sizes = [0] * num_total_buckets
buckets_to_params_list = [[] for _ in range(num_total_buckets)]
for param, (start, end, idx) in self.params_to_bucket_location.items():
actual_bucket_sizes[idx] = max(actual_bucket_sizes[idx], end) # 桶的实际大小是最后一个参数的结束位置
buckets_to_params_list[idx].append(param)

self.bucket_sizes = actual_bucket_sizes

# 为每个桶创建梯度存储张量 (grad_data) 和 Bucket 对象
for i in range(len(self.bucket_sizes)):
# 为每个桶预分配一块连续的内存来存储梯度
grad_tensor = torch.zeros(self.bucket_sizes[i], dtype=self.grad_type, device='cuda')
self.grad_data_list.append(grad_tensor)
self.buckets.append(Bucket(buckets_to_params_list[i], grad_tensor, self.process_group))

# 为每个参数创建一个指向其对应桶中梯度存储区视图的 'main_grad' 属性。
# 参数的梯度会先累加到这个 'main_grad' 中。
# 注意这里是倒序遍历参数,这与 PyTorch 反向传播计算梯度的顺序有关,
# 使得参数的梯度视图 (param.main_grad) 在其梯度实际计算出来之前就被创建。
for param in self.params[::-1]:
if not param.requires_grad:
continue
data_start_index, data_end_index, bucket_id = self.params_to_bucket_location[param]
# param.main_grad 是一个视图 (view),它指向 self.grad_data_list[bucket_id] 中的特定区域。
# 对 param.main_grad 的修改会直接反映在 grad_data_list[bucket_id] 上。
param.main_grad = self._get_view_from_tensor(self.grad_data_list[bucket_id], param.shape, data_start_index, data_end_index)

def _get_view_from_tensor(self, tensor: torch.Tensor, shape: torch.Size, start: int, end: int) -> torch.Tensor:
"""从一个大张量中获取一个特定形状的视图。"""
return tensor[start:end].view(shape)

def reset(self) -> None:
"""重置所有桶的状态。"""
for bucket in self.buckets:
bucket.reset()

def wait(self) -> None:
"""等待所有桶的梯度同步完成。"""
for bucket in self.buckets:
if bucket.handle is not None: # 只等待已经启动了 all_reduce 的桶
bucket.wait()

def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
"""标记一个参数的梯度已准备好,并通知其所在的桶。"""
bucket_idx = self.params_to_bucket_location[param][2]
self.buckets[bucket_idx].mark_param_as_ready(param)
# ... existing code ...

DataParallelBucket

这是梯度分桶数据并行策略的顶层封装,它继承自 nn.Module ,可以像普通的 PyTorch模块一样使用。它主要用来包装原始的model。

  • 它负责初始化bucketManager

  • 给各个参数注册一个勾子函数,其负责

    • 累加梯度到main_grad

    • 如果是acc_grad中的最后一个计算,就:

      • 注册一个post_backward函数,该函数会在在整个反向传播结束后被调用,其负责等待桶中所有的梯度同步完成,然后将同步后的梯度 (存储在 param.main_grad) 复制回 param.grad,以便优化器使用。

      • 还会告诉bucketManager参数已经准备好,从而让bucketManager判断是否需要开始收集各个参数的grad

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# ... existing code ...
class DataParallelBucket(nn.Module):
def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
super().__init__()
self.module = module # 被包装的原始模型
self.require_backward_grad_sync = True # 控制是否进行梯度同步,用于梯度累积

# 计算每个桶的大小 (以元素数量计)
# grad_size: 假设梯度是 bfloat16 (2字节) 或 float32 (4字节)。这里代码写的是2,对应bfloat16
# bucket_cap_mb: 用户指定的桶容量上限 (MB)
grad_element_size = torch.tensor([], dtype=grad_type).element_size() # 获取梯度类型对应的字节数
bucket_size_in_elements = bucket_cap_mb * 1024 * 1024 // grad_element_size

self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.dp_group, bucket_size_in_elements, grad_type)
self.register_backward_hook() # 注册反向传播钩子
self._post_backward_callback_set = False # 标记是否已经注册了 post_backward 回调

def forward(self, *inputs, **kwargs):
# 前向传播直接调用原始模块
return self.module(*inputs, **kwargs)

# backward 和 get_flops 方法是可选的,取决于原始模块是否需要它们
# def backward(self, input_tensor, output_tensor, output_tensor_grad):
# return self.module.backward(input_tensor, output_tensor, output_tensor_grad)

# def get_flops(self, *args, **kwargs):
# return self.module.get_flops(*args, **kwargs)

def register_backward_hook(self):
"""
为每个需要梯度的参数注册一个钩子 (hook)。
这个钩子会在该参数的梯度计算完成后被调用。
"""
self.grad_accs = [] # 存储梯度累加器函数,防止被垃圾回收
for param in self.module.parameters():
if param.requires_grad:
# param_tmp.grad_fn.next_functions[0][0] 是获取参数对应的梯度累加器节点 (AccumulateGrad object)
param_tmp = param.expand_as(param) # 确保 param_tmp 有 grad_fn
grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
# 为梯度累加器节点注册钩子
grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
self.grad_accs.append(grad_acc_fn)

def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
"""创建一个参数特定的钩子函数。"""
def param_hook(*unused):
"""
当参数 param 的梯度计算完成后,这个钩子会被调用。
1. 将计算得到的 param.grad 累加到 param.main_grad (即桶的梯度存储区)。
2. 将 param.grad 清空 (因为梯度已经拷贝到 main_grad)。
3. 如果需要同步,则标记该参数已准备好,并可能触发桶的同步。
4. 注册一个 _post_backward 回调,确保在整个反向传播完成后执行某些操作。
"""
if param.requires_grad:
assert param.grad is not None, f"Gradient for {param.name} is None." # 确保梯度存在
# 1. 累加梯度到 main_grad (桶的存储区)
param.main_grad.add_(param.grad.data)
# 2. 清空原始梯度,因为已经复制到 main_grad
param.grad = None

if self.require_backward_grad_sync: # 如果不是梯度累积的中间步骤
# 3. 注册 _post_backward 回调 (如果还没注册的话)
# 这个回调会在整个 backward() 调用完成后执行。
if not self._post_backward_callback_set:
torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
self._post_backward_callback_set = True

# 4. 标记参数已准备好,通知 BucketManager
bucket_manager.mark_param_as_ready(param)
return param_hook

def _post_backward(self):
"""
在整个反向传播过程结束后执行的回调。
1. 等待所有桶的梯度同步完成。
2. 将同步后的梯度 (存储在 param.main_grad) 复制回 param.grad,以便优化器使用。
"""
# 1. 等待所有桶的 all-reduce 操作完成
self.bucket_manager.wait()
self._post_backward_callback_set = False # 重置标记,为下一次 backward 做准备

# 2. 将同步并平均后的梯度从 param.main_grad 复制回 param.grad
# 优化器 (如 AdamW) 会读取 param.grad 来更新参数。
for p in self.module.parameters():
if p.requires_grad:
# 需要确保数据类型匹配,优化器通常期望 param.grad 和 param.data 类型一致
p.grad = p.main_grad.to(p.dtype)

def reset(self):
"""重置 BucketManager 的状态,主要是清零所有桶的梯度。"""
self.bucket_manager.reset()
# ... existing code ...

【Picotron-Tutorial】数据并行
http://example.com/2025/06/14/Picotron-Tutorial 数据并行/
作者
滑滑蛋
发布于
2025年6月14日
许可协议