【Picotron-Tutorial】Tensor Parallelism

Theoretical Analysis

The operation under analysis is $$Y = X @ W$$

Column Parallelism

Each GPU keeps a full copy of X (which is usually already available on every device), and W is split along its column dimension. Each GPU then produces a different slice of the output columns, and the partial results are concatenated with an all_gather to form the final result.
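
Below is a minimal single-process sketch (not Picotron code) that simulates the column-parallel scheme with plain tensor slicing instead of real GPUs and communication:

import torch

torch.manual_seed(0)
tp = 2                                      # simulated tensor-parallel world size
X = torch.randn(4, 8)                       # X is replicated on every "rank"
W = torch.randn(8, 6)                       # W is split along its columns

# Each rank multiplies X with its own column shard of W and produces a slice of Y's columns.
Y_shards = [X @ W_i for W_i in W.chunk(tp, dim=1)]
Y = torch.cat(Y_shards, dim=1)              # stands in for the all_gather
print(torch.allclose(Y, X @ W))             # True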

Row Parallelism

For row parallelism, since the number of rows of W shrinks, the number of columns of X must shrink accordingly. X is therefore first split along its column dimension across the GPUs, each shard is multiplied with the corresponding shard of W, and the partial results are summed with an all_reduce.
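
The corresponding sketch for the row-parallel scheme, again simulated with slicing (not Picotron code):

import torch

torch.manual_seed(0)
tp = 2                                      # simulated tensor-parallel world size
X = torch.randn(4, 8)
W = torch.randn(8, 6)

X_shards = X.chunk(tp, dim=1)               # X is split along its columns
W_shards = W.chunk(tp, dim=0)               # W is split along its rows
partial_Ys = [X_i @ W_i for X_i, W_i in zip(X_shards, W_shards)]
Y = torch.stack(partial_Ys).sum(dim=0)      # stands in for the all_reduce (sum)
print(torch.allclose(Y, X @ W, atol=1e-6))  # True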

Tensor-Parallel Strategy for the MLP Module

Taking the MLP module of a large language model as an example, its structure is typically:

  1. Matrix multiplication

  2. GeLU

  3. Matrix multiplication

Choosing the right tensor-parallel strategy for this structure is therefore important.

First, we want to apply the GeLU locally, fused with the first matrix multiplication. Row parallelism ends with an all_reduce that sums the partial results, and since $\text{GeLU}(Y_0) + \text{GeLU}(Y_1) \neq \text{GeLU}(Y_0 + Y_1)$, applying GeLU before the reduce would give the wrong answer. Column parallelism, in contrast, only concatenates the partial results, and because GeLU is elementwise it commutes with the concatenation. So the first matrix multiplication must use column parallelism.
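
A quick numerical check (illustrative only) that GeLU does not distribute over the sum produced by the row-parallel all_reduce:

import torch
import torch.nn.functional as F

Y0 = torch.tensor([1.0, -2.0, 0.3])
Y1 = torch.tensor([0.5, 3.0, -0.7])
print(F.gelu(Y0) + F.gelu(Y1))   # GeLU applied to the partial results, then summed
print(F.gelu(Y0 + Y1))           # the correct result: GeLU of the summed output
# The two tensors differ, so GeLU must come after the sum -- or column parallelism
# must be used so that no sum is needed before the activation.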

Next, we need to decide which parallel scheme to use for the second matrix multiplication:

  1. If we use column parallelism again, we first need an all_gather to assemble the intermediate result, then a broadcast to send it back to every GPU, and finally another all_gather to collect the output. That amounts to three communication operations.

  2. If we use row parallelism, no intermediate aggregation is needed: the row-parallel computation runs directly on the column-parallel output, followed by a single all_reduce. This way only one communication is required.

In summary, column parallelism for the first matmul followed by row parallelism for the second is the most suitable scheme.
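
Putting the two together, here is a minimal single-process sketch (simulated shards, not Picotron code) showing that column parallelism for the first matmul plus row parallelism for the second reproduces the unsharded MLP with a single summation at the end:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp = 2                                     # simulated tensor-parallel world size
X = torch.randn(4, 8)                      # input, replicated on every "rank"
W1 = torch.randn(8, 16)                    # first matmul, split by columns
W2 = torch.randn(16, 8)                    # second matmul, split by rows

ref = F.gelu(X @ W1) @ W2                  # unsharded reference MLP

partials = []
for rank in range(tp):
    W1_i = W1.chunk(tp, dim=1)[rank]       # column shard of W1
    W2_i = W2.chunk(tp, dim=0)[rank]       # row shard of W2
    H_i = F.gelu(X @ W1_i)                 # local GeLU is safe: shards are disjoint columns
    partials.append(H_i @ W2_i)            # partial product of the second matmul
Y = torch.stack(partials).sum(dim=0)       # the only communication: one all_reduce
print(torch.allclose(ref, Y, atol=1e-5))   # True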

Tensor-Parallel Strategy for the Attention Module

The main computation steps in the attention module are:

  1. Matrix multiplications with W_q, W_k, W_v to obtain Q, K, V

  2. Compute the attention output of each head

  3. Concatenate the per-head outputs and multiply with W_o to obtain the final attention output

The analysis is essentially the same as for the MLP module: we split W_q, W_k, W_v with column parallelism and W_o with row parallelism, so no intermediate aggregation is needed; each GPU computes its own heads directly, with a single all_reduce at the end.
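
A minimal single-process sketch along the same lines (heads simulated with slicing, not Picotron code): each rank owns a contiguous group of heads through its column shards of W_q, W_k, W_v, runs attention on those heads locally, multiplies by its row shard of W_o, and a single summation reproduces the full multi-head attention output:

import torch

torch.manual_seed(0)
tp, n_heads, d_model, seq = 2, 4, 16, 5
Wq, Wk, Wv, Wo = (torch.randn(d_model, d_model) for _ in range(4))
X = torch.randn(1, seq, d_model)                    # (batch, seq, hidden), replicated

def split_heads(t, n):                              # (b, s, n*d) -> (b, n, s, d)
    b, s, _ = t.shape
    return t.view(b, s, n, -1).transpose(1, 2)

def attn(q, k, v):                                  # plain scaled dot-product attention
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    return scores.softmax(dim=-1) @ v

# Unsharded reference
o = attn(split_heads(X @ Wq, n_heads), split_heads(X @ Wk, n_heads), split_heads(X @ Wv, n_heads))
ref = o.transpose(1, 2).reshape(1, seq, d_model) @ Wo

# Tensor parallel: each rank owns n_heads // tp heads
partials = []
for rank in range(tp):
    Wq_i, Wk_i, Wv_i = (W.chunk(tp, dim=1)[rank] for W in (Wq, Wk, Wv))  # column shards = whole heads
    Wo_i = Wo.chunk(tp, dim=0)[rank]                                     # row shard of the output projection
    h = n_heads // tp
    o_i = attn(split_heads(X @ Wq_i, h), split_heads(X @ Wk_i, h), split_heads(X @ Wv_i, h))
    partials.append(o_i.transpose(1, 2).reshape(1, seq, d_model // tp) @ Wo_i)
Y = torch.stack(partials).sum(dim=0)                # the only communication: one all_reduce
print(torch.allclose(ref, Y, atol=1e-5))            # True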

Tensor-Parallel Strategy for the Embedding Layer

The embedding layer looks up, for each token id, the corresponding row of the embedding matrix and uses it as the input representation.

When applying tensor parallelism, the embedding matrix can therefore only be split along its rows (the vocabulary dimension); the input token ids themselves are not split. A few extra steps are needed in practice (a worked example follows the list below):

  1. Since each GPU holds only the embeddings for a specific range of ids, we first subtract the start index of the local partition from every token id to obtain local indices.

  2. We then locate all tokens whose ids fall outside the local range and set their local indices to 0.

  3. We perform an ordinary embedding lookup with these local indices.

  4. We zero out the embedding rows that belong to the out-of-range tokens.

  5. Finally, we all_reduce the embeddings across the GPUs to obtain the final result.
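
A small worked example of these five steps (simulated with slicing on a single process, not Picotron code):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp, vocab_size, dim = 2, 8, 4
ids = torch.tensor([1, 5, 6])                   # token ids, replicated on every "rank"
E = torch.randn(vocab_size, dim)                # full embedding matrix

ref = F.embedding(ids, E)                       # unsharded reference lookup

partials = []
for rank in range(tp):
    start = rank * vocab_size // tp             # this rank's id range: [start, end)
    end = (rank + 1) * vocab_size // tp
    E_i = E[start:end]                          # row shard of the embedding matrix
    oov = (ids < start) | (ids >= end)          # step 2: tokens outside the local range
    local_ids = (ids - start).masked_fill(oov, 0)   # steps 1-2: local indices, masked to 0
    out = F.embedding(local_ids, E_i)           # step 3: ordinary lookup on the shard
    out[oov] = 0.0                              # step 4: zero out out-of-range rows
    partials.append(out)
Y = torch.stack(partials).sum(dim=0)            # step 5: stands in for the all_reduce
print(torch.allclose(ref, Y))                   # True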

Code Analysis

Overview

First, apply_tensor_parallel is called to replace certain layers of the model with their tensor-parallel counterparts. The parallel style of each layer is hard-coded: column parallelism first, then row parallelism, which saves the intermediate communication.

def apply_tensor_parallel(model):

    def _replace_module(_module, _linear_proj_name, _style, args={}):
        assert _style in ["column", "row", "vocab"]
        linear_layer = getattr(_module, _linear_proj_name)

        if _style == "column":
            new_linear_layer = ColumnParallelLinear(
                in_features=linear_layer.in_features,
                out_features=linear_layer.out_features,
                bias=linear_layer.bias is not None,
                gather_output=args.get("gather_output", False)
            )
        elif _style == "row":
            new_linear_layer = RowParallelLinear(
                in_features=linear_layer.in_features,
                out_features=linear_layer.out_features,
                bias=linear_layer.bias is not None,
            )
        else:
            new_linear_layer = VocabParallelEmbedding(
                num_embeddings=linear_layer.num_embeddings,
                embedding_dim=linear_layer.embedding_dim,
            )
        setattr(_module, _linear_proj_name, new_linear_layer)

    module_linear_name_stype_mapping_list = [
        ("attention", "q_proj", "column"),
        ("attention", "k_proj", "column"),
        ("attention", "v_proj", "column"),
        ("attention", "out_proj", "row"),
        ("mlp", "up_proj", "column"),
        ("mlp", "gate_proj", "column"),
        ("mlp", "down_proj", "row"),
    ]

    for layer in model.decoder_layers:
        for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
            _replace_module(getattr(layer, module_name), linear_proj_name, style)

    _replace_module(model, "embedding", "vocab")
    _replace_module(model, "final_proj", "column", args={"gather_output": True})

    return model

Column-Parallel Implementation

When initializing the parameters, a master weight with the original shape is created first, then split along the parallel dimension, and each rank keeps the slice corresponding to its own rank. Note that PyTorch's linear layer computes X @ W^T, so for column parallelism of the logical weight, the stored weight tensor is actually split along its rows.

The outputs are collected with an all_gather (when gather_output is enabled).

class ColumnParallelLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int, bias: bool, gather_output: bool = False):
        super(ColumnParallelLinear, self).__init__()

        self.tp_world_size = pgm.process_group_manager.tp_world_size
        self.tp_rank = pgm.process_group_manager.tp_rank

        self.in_features = in_features
        self.out_features = out_features
        assert out_features % self.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
        self.output_size_per_partition = out_features // self.tp_world_size
        self.gather_output = gather_output

        # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
        self.weight = nn.Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))  # W_i
        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.output_size_per_partition))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        # Initialize weight tensor with the default initialization method used for nn.Linear in PyTorch
        if self.tp_world_size == 1:
            # U(-sqrt(k), sqrt(k))
            k = 1 / self.weight.size(1)
            bound = math.sqrt(k)
            torch.nn.init.uniform_(self.weight, -bound, bound)
            return

        # When TP > 1, initialize a master weight first
        master_weight = torch.empty(self.out_features, self.in_features, dtype=self.weight.dtype, requires_grad=False)
        # Calculate bound based on master weight's input dimension. U(-sqrt(k), sqrt(k))
        k = 1 / master_weight.size(1)
        bound = math.sqrt(k)
        torch.nn.init.uniform_(master_weight, -bound, bound)

        # Split the master weight into chunks of self.output_size_per_partition rows and take this rank's partition
        weight_list = torch.split(master_weight, self.output_size_per_partition, dim=0)
        self.weight.data = weight_list[self.tp_rank].contiguous()

    def forward(self, input):
        input_parallel = Copy.apply(input)
        # X W_i^T + b_i, output is Y_i
        output = F.linear(input_parallel, self.weight, self.bias)
        if self.gather_output:
            output = Gather.apply(output)
        return output
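
The forward pass above relies on Copy.apply and Gather.apply (and the row-parallel and embedding layers below rely on Reduce.apply), autograd functions defined elsewhere in picotron and not shown in this section. As a rough orientation only, a minimal Megatron-style sketch of what such operators typically look like might be the following (assumed code; for brevity it uses the default process group, whereas the real implementation would use the TP group from pgm.process_group_manager):

import torch
import torch.distributed as dist

class Copy(torch.autograd.Function):
    """Identity in forward; all-reduce the gradient in backward (the 'f' operator)."""
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        if dist.get_world_size() > 1:
            dist.all_reduce(grad_output, op=dist.ReduceOp.SUM)
        return grad_output

class Gather(torch.autograd.Function):
    """All-gather along the last dim in forward; take the local slice of the gradient in backward."""
    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        if world_size == 1:
            return x
        tensors = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(tensors, x.contiguous())
        return torch.cat(tensors, dim=-1).contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size()
        if world_size == 1:
            return grad_output
        return grad_output.chunk(world_size, dim=-1)[dist.get_rank()].contiguous()

class Reduce(torch.autograd.Function):
    """All-reduce (sum) in forward; identity in backward (the 'g' operator)."""
    @staticmethod
    def forward(ctx, x):
        if dist.get_world_size() > 1:
            dist.all_reduce(x, op=dist.ReduceOp.SUM)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output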

Row-Parallel Implementation

This is largely the same as the column-parallel case, except that the stored weight is split along its columns and the results are combined with an all_reduce.

class RowParallelLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int, bias: bool):
        super(RowParallelLinear, self).__init__()

        self.tp_world_size = pgm.process_group_manager.tp_world_size
        self.tp_rank = pgm.process_group_manager.tp_rank

        self.in_features = in_features
        self.out_features = out_features
        assert in_features % self.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
        self.input_size_per_partition = in_features // self.tp_world_size

        self.weight = nn.Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.out_features))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        # Initialize weight tensor with the default initialization method used for nn.Linear in PyTorch
        if self.tp_world_size == 1:
            # U(-sqrt(k), sqrt(k))
            k = 1 / self.weight.size(1)
            bound = math.sqrt(k)
            torch.nn.init.uniform_(self.weight, -bound, bound)
            return

        # When TP > 1, initialize a master weight first
        master_weight = torch.empty(self.out_features, self.in_features, dtype=self.weight.dtype, requires_grad=False)
        # Calculate bound based on master weight's input dimension. U(-sqrt(k), sqrt(k))
        k = 1 / master_weight.size(1)
        bound = math.sqrt(k)
        torch.nn.init.uniform_(master_weight, -bound, bound)

        # Split the master weight into chunks of self.input_size_per_partition columns and take this rank's partition
        weight_list = torch.split(master_weight, self.input_size_per_partition, dim=1)
        self.weight.data = weight_list[self.tp_rank].contiguous()

    def forward(self, input):
        # X_i W_i^T (the bias is added only once, after the all-reduce)
        output_parallel = F.linear(input, self.weight)
        # All-reduce across all the partitions.
        output = Reduce.apply(output_parallel)
        return output if self.bias is None else output + self.bias

Embedding Parallelism

As discussed above: first build input_mask, then subtract the start index from the input ids to obtain masked_input, set the positions flagged by input_mask to 0, perform the embedding lookup, zero out the embeddings at the masked positions, and finally all_reduce to obtain the result.


class VocabParallelEmbedding(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False
    ):
        super(VocabParallelEmbedding, self).__init__()

        self.tp_world_size = pgm.process_group_manager.tp_world_size
        self.tp_rank = pgm.process_group_manager.tp_rank

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = self._vocab_range_from_global_vocab_size(
            self.num_embeddings, pgm.process_group_manager.tp_rank, pgm.process_group_manager.tp_world_size
        )
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

        self.weight = nn.Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))

        self.reset_parameters()

    def _vocab_range_from_global_vocab_size(self, global_vocab_size: int, rank: int, world_size: int):
        assert global_vocab_size % world_size == 0, f"{global_vocab_size} is not divisible by {world_size}"
        per_partition_vocab_size = global_vocab_size // world_size
        # vocab_range_from_per_partition_vocab_size
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    def reset_parameters(self):
        if self.tp_world_size == 1:
            # Initialize the vocab embedding with N(0, 1)
            torch.nn.init.normal_(self.weight, mean=0.0, std=1.0)
            return

        # When TP > 1, initialize a master weight first
        master_weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=self.weight.dtype, requires_grad=False)
        torch.nn.init.normal_(master_weight, mean=0.0, std=1.0)

        # Split the master weight into chunks of self.num_embeddings_per_partition rows and take this rank's partition
        weight_list = torch.split(master_weight, self.num_embeddings_per_partition, dim=0)
        self.weight.data = weight_list[self.tp_rank].contiguous()

    def forward(self, input):
        """
        Performs an embedding lookup for input tokens in the parallelized embedding layer
        1. Masks tokens that fall outside the specified vocabulary range and adjusts the input
        2. Performs embedding lookups for valid tokens, setting embeddings of out-of-vocabulary tokens to zero
        3. Reduces the embeddings across model parallel GPUs using all-reduce for synchronization
        """
        # Build the mask for out-of-vocabulary tokens.
        input_mask = (input < self.vocab_start_index) | (input >= self.vocab_end_index)
        # Mask the input.
        masked_input = input.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
        # Get the embeddings for the valid tokens.
        output_parallel = F.embedding(
            masked_input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Embedding of out-of-vocabulary tokens is set to 0.
        output_parallel[input_mask, :] = 0.0
        output = Reduce.apply(output_parallel)
        return output
