[Megatron-LM Source Code Analysis (5)] - Tensor Parallelism

Theoretical Foundations

Obtaining the Training Data

In pretrain_gpt.py, the get_batch function contains dedicated TP-aware data handling, as shown below:

def get_batch(data_iterator):
    """Generate a batch."""

    # TODO: this is pretty hacky, find a better way
    if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
        not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    ):
        return None, None, None, None, None

    # get batches based on the TP rank you are on
    batch = get_batch_on_this_tp_rank(data_iterator)

    # slice batch along sequence dimension for context parallelism
    batch = get_batch_on_this_cp_rank(batch)

    return batch.values()

Digging one level deeper, get_batch_on_this_tp_rank is shown below: the worker with TP rank 0 fetches one micro-batch from the data loader, assembles it into a batch dict, and broadcasts it to the other workers in its TP group.

def get_batch_on_this_tp_rank(data_iterator):

args = get_args()

def _broadcast(item):
if item is not None:
torch.distributed.broadcast(
item,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group(),
)

if mpu.get_tensor_model_parallel_rank() == 0:

if data_iterator is not None:
data = next(data_iterator)
else:
data = None

batch = {
'tokens': data["tokens"].cuda(non_blocking=True),
'labels': data["labels"].cuda(non_blocking=True),
'loss_mask': data["loss_mask"].cuda(non_blocking=True),
'attention_mask': (
None
if "attention_mask" not in data
else data["attention_mask"].cuda(non_blocking=True)
),
'position_ids': data["position_ids"].cuda(non_blocking=True),
}

if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])

elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])

elif mpu.is_pipeline_last_stage():
# Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
# Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
# to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
if args.mtp_num_layers is not None:
_broadcast(batch['tokens'])
_broadcast(batch['position_ids'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])

else:

tokens = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)
labels = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)
loss_mask = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.float32,
device=torch.cuda.current_device(),
)
if args.create_attention_mask_in_dataloader:
attention_mask = torch.empty(
(args.micro_batch_size, 1, args.seq_length, args.seq_length),
dtype=torch.bool,
device=torch.cuda.current_device(),
)
else:
attention_mask = None
position_ids = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)

if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)

elif mpu.is_pipeline_first_stage():
labels = None
loss_mask = None

_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)

elif mpu.is_pipeline_last_stage():
# Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
# Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
# to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
if args.mtp_num_layers is not None:
_broadcast(tokens)
_broadcast(position_ids)
else:
tokens = None
position_ids = None

_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)

batch = {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids,
}

return batch

As the code shows, five tensors are broadcast: tokens, labels, loss_mask, attention_mask, and position_ids, which a torch profiler trace also confirms: exactly five broadcast calls occur.
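
The pattern is easy to reproduce in isolation. The sketch below uses a hypothetical helper (not Megatron code) to show why the non-rank-0 branch in the source above pre-allocates torch.empty buffers: torch.distributed.broadcast writes into an existing tensor in place, so every receiving rank must already hold a tensor of matching shape and dtype.

import torch
import torch.distributed as dist

def broadcast_batch_from_tp_rank0(batch_or_none, src_rank, tp_group, shapes_and_dtypes):
    """Hypothetical helper mirroring get_batch_on_this_tp_rank's broadcast pattern.
    src_rank is the global rank of the TP group's rank-0 worker."""
    if dist.get_rank(group=tp_group) == 0:
        batch = batch_or_none  # only TP rank 0 read real data from the data loader
    else:
        # Receivers allocate empty buffers that broadcast() fills in place.
        batch = {
            name: torch.empty(shape, dtype=dtype, device=torch.cuda.current_device())
            for name, (shape, dtype) in shapes_and_dtypes.items()
        }
    for name in sorted(batch):  # identical iteration order on every rank
        dist.broadcast(batch[name], src=src_rank, group=tp_group)
    return batch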

Code Related to Tensor Parallelism

Model Construction

The entry point for model construction is the model_provider function in pretrain_gpt.py. Its default execution path is shown below (some branches that are irrelevant here have been removed; in particular, the construction of the mtp_block_spec passed to GPTModel is elided).

def model_provider(
pre_process=True, post_process=True, vp_stage: Optional[int] = None
) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
"""Builds the model.

If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.

Args:
pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True.
post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True.

Returns:
Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
"""
args = get_args()

use_te = args.transformer_impl == "transformer_engine"

print_rank_0('building GPT model ...')
# Experimental loading arguments from yaml
if args.yaml_cfg is not None:
config = core_transformer_config_from_yaml(args, "language_model")
else:
config = core_transformer_config_from_args(args)

# Define the decoder layer spec
transformer_layer_spec = _get_transformer_layer_spec(use_te, config)

model = GPTModel(
config=config,
transformer_layer_spec=transformer_layer_spec,
vocab_size=args.padded_vocab_size,
max_sequence_length=args.max_position_embeddings,
pre_process=pre_process,
post_process=post_process,
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
parallel_output=True,
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent,
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling,
mtp_block_spec=mtp_block_spec,
vp_stage=vp_stage,
)

return model

The _get_transformer_layer_spec function is implemented as follows:

def _get_transformer_layer_spec(use_te, config):
    """Get transformer layer specification based on configuration.

    Args:
        use_te (bool): Whether to use Transformer Engine
        args: Training arguments
        config: Model configuration

    Returns:
        transformer_layer_spec: The transformer layer specification
    """
    args = get_args()
    if use_te:
        return get_gpt_layer_with_transformer_engine_spec(
            args.num_experts,
            args.moe_grouped_gemm,
            args.qk_layernorm,
            args.multi_latent_attention,
            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
            qk_l2_norm=args.qk_l2_norm,
            use_kitchen=config.use_kitchen,
        )
    else:
        return get_gpt_layer_local_spec(
            args.num_experts,
            args.moe_grouped_gemm,
            args.qk_layernorm,
            args.multi_latent_attention,
            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
            normalization=args.normalization,
            use_kitchen=config.use_kitchen,
        )

With the default arguments, use_te is True, i.e. Transformer Engine (which provides optimizations such as operator fusion) is used, so execution takes the get_gpt_layer_with_transformer_engine_spec branch rather than Megatron-LM's local get_gpt_layer_local_spec branch. get_gpt_layer_with_transformer_engine_spec is shown below:

def get_gpt_layer_with_transformer_engine_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
qk_l2_norm: Optional[bool] = False,
use_te_op_fuser: Optional[bool] = False,
use_kitchen: bool = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).

Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
use_te_op_fuser (bool, optional): Use Transformer Engine's operation-based API, which may
enable certain operation fusions. Defaults to False.

Returns:
ModuleSpec: Module specification with TE modules

"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
" and will be removed soon. Please update your code accordingly."
)

if use_kitchen:
assert HAVE_KITCHEN
backend: BackendSpecProvider = KitchenSpecProvider(fallback=TESpecProvider())
if use_te_op_fuser:
raise AssertionError("use_te_op_fuser not compatible with using kitchen in mlp.")
else:
backend = TESpecProvider()

mlp = get_mlp_module_spec_for_backend(
backend=backend,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
use_te_op_fuser=use_te_op_fuser,
)

if multi_latent_attention:
assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
linear_q_up_proj = (
backend.column_parallel_layer_norm_linear()
if qk_layernorm
else backend.column_parallel_linear()
)
linear_kv_up_proj = (
backend.column_parallel_layer_norm_linear()
if qk_layernorm
else backend.column_parallel_linear()
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=backend.layer_norm(),
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=backend.column_parallel_linear(),
linear_q_down_proj=backend.linear(),
linear_q_up_proj=linear_q_up_proj,
linear_kv_down_proj=backend.linear(),
linear_kv_up_proj=linear_kv_up_proj,
core_attention=backend.core_attention(),
linear_proj=backend.row_parallel_linear(),
q_layernorm=IdentityOp,
kv_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
else:
qk_norm = backend.layer_norm(for_qk=True)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=backend.column_parallel_layer_norm_linear(),
core_attention=backend.core_attention(),
linear_proj=backend.row_parallel_linear(),
q_layernorm=(
L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
),
k_layernorm=(
L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
),
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
"mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
"mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
"mlp.1.basic_ops.0.weight": "mlp.linear_fc1.weight",
"mlp.1.basic_ops.1.bias": "mlp.linear_fc1.bias",
"mlp.3.basic_ops.0.weight": "mlp.linear_fc2.weight",
"mlp.3.basic_ops.1.bias": "mlp.linear_fc2.bias",
},
),
)

Note that use_kitchen matters here. It defaults to False, so backend = TESpecProvider(): Transformer Engine is used to provide the TransformerLayer submodules, rather than NVIDIA Kitchen acting as the backend that supplies (some of) the Transformer submodule implementations/specs. Megatron-LM additionally wraps the relevant Transformer Engine modules with thin adapters so that they support features such as tensor parallelism.
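
To make the spec-driven construction concrete, here is a simplified, self-contained sketch of the pattern (toy stand-ins, not megatron.core's actual ModuleSpec/build_module): a spec records which class to instantiate and with which construction arguments, so swapping the backend (TE vs. local vs. Kitchen) only changes the classes referenced in the spec, not the layer code that consumes it.

from dataclasses import dataclass, field
from typing import Any, Optional

import torch

@dataclass
class ToyModuleSpec:
    """Toy stand-in for megatron.core's ModuleSpec: a class plus its construction arguments."""
    module: type
    params: dict = field(default_factory=dict)
    submodules: Optional[Any] = None

def toy_build_module(spec: ToyModuleSpec, **extra_kwargs) -> torch.nn.Module:
    """Instantiate the class named in the spec, merging spec params with call-site kwargs."""
    return spec.module(**spec.params, **extra_kwargs)

# A "backend" is then just a factory of such specs; switching backends swaps the classes only.
linear_spec = ToyModuleSpec(module=torch.nn.Linear, params={"bias": False})
fc1 = toy_build_module(linear_spec, in_features=1024, out_features=4096)
print(fc1)  # Linear(in_features=1024, out_features=4096, bias=False)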

Megatron-LM's Local gpt_layer Implementation

Because the Transformer Engine path is a proprietary wrapper and rather involved, we turn instead to Megatron-LM's local implementation. get_gpt_layer_local_spec is shown below; the backend in this case is LocalSpecProvider.

def get_gpt_layer_local_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
normalization: Optional[str] = None,
qk_l2_norm: Optional[bool] = False,
use_kitchen: bool = False,
) -> ModuleSpec:
"""Use this spec for an implementation using only modules in Megatron-Core.

Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.

Returns:
ModuleSpec: Module specification with Megatron-Core modules
"""

if use_kitchen:
assert HAVE_KITCHEN
backend = KitchenSpecProvider(fallback=LocalSpecProvider())
else:
backend = LocalSpecProvider()
# Adjust for RMS norm.
if normalization == "RMSNorm":
layer_norm = backend.layer_norm(rms_norm=True, for_qk=False)
qk_norm = backend.layer_norm(rms_norm=True, for_qk=True)
else:
layer_norm = backend.layer_norm(rms_norm=False, for_qk=False)
qk_norm = backend.layer_norm(rms_norm=False, for_qk=True)

if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
" and will be removed soon. Please update your code accordingly."
)

mlp = get_mlp_module_spec_for_backend(
backend=backend,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)

if multi_latent_attention:
assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=layer_norm,
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=backend.column_parallel_linear(),
linear_q_down_proj=backend.column_parallel_linear(),
linear_q_up_proj=backend.column_parallel_linear(),
linear_kv_down_proj=backend.column_parallel_linear(),
linear_kv_up_proj=backend.column_parallel_linear(),
core_attention=backend.core_attention(),
linear_proj=backend.row_parallel_linear(),
q_layernorm=qk_norm if qk_layernorm else IdentityOp,
kv_layernorm=qk_norm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=layer_norm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
else:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=layer_norm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=backend.column_parallel_linear(),
core_attention=backend.core_attention(),
linear_proj=backend.row_parallel_linear(),
q_layernorm=(
L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
),
k_layernorm=(
L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
),
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=layer_norm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
"input_layernorm.": "self_attention.linear_qkv.layer_norm_",
"pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_",
},
),
)

  • The layer is assembled from TransformerLayer. Its initialization code is shown below, and the submodules it builds, in order, are listed after the code:

    def __init__(
    self,
    config: TransformerConfig,
    submodules: TransformerLayerSubmodules,
    layer_number: int = 1,
    hidden_dropout: Optional[float] = None,
    model_comm_pgs: Optional[ModelCommProcessGroups] = None,
    vp_stage: Optional[int] = None,
    ):
    super().__init__(config=config)

    # Enable cuda graphs.
    if (
    config.enable_cuda_graph and config.cuda_graph_scope != "full_iteration"
    ) or config.external_cuda_graph:
    assert not (
    config.enable_cuda_graph and config.external_cuda_graph
    ), "Cudagraphs and external cudagraphs cannot be enabled at the same time"
    if config.enable_cuda_graph and config.cuda_graph_scope != "full_iteration":
    if not self.training:
    # Cudagraphs for inference are only enabled with the flash decoding kernel
    assert (
    self.config.flash_decode
    ), "--flash-decode is required to use CUDA graphs during inference"
    self.cudagraph_manager = CudaGraphManager(config, vp_stage=vp_stage)
    else:
    # List to store CUDA graphs. A list of `N` CUDA graphs for this layer where N is
    # the number of microbatches. Multiple CUDA graphs per layer is required to support
    # pipelining which requires running FWD graph of multiple microbatches before BWD
    # graph. To enable CUDA graph, this list should be populated in the model training
    # script with the graphs returned by make_graphed_callables API before the first
    # training step.
    self.cuda_graphs = []
    # List to store forward pre-hooks. Forward pre-hooks are not captured into CUDA
    # graphs. Those hooks and args are collected in this list and should be manually
    # triggered before CUDA Graph running. This is required to ensure the correct param
    # all-gather overlap with forward compute.
    self.cuda_graph_manual_hooks = []
    self.current_microbatch = -1

    if model_comm_pgs is None:
    model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups()

    self.submodules_config = submodules
    self.layer_number = layer_number + get_transformer_layer_offset(self.config, vp_stage)
    self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout

    # [Module 1: Input Layernorm] Optional Layernorm on the input data
    # TODO: add pytorch only layernorm
    self.input_layernorm = build_module(
    submodules.input_layernorm,
    config=self.config,
    hidden_size=self.config.hidden_size,
    eps=self.config.layernorm_epsilon,
    )

    attention_optional_kwargs = {}
    if config.context_parallel_size > 1 and config.cp_comm_type is not None:
    if isinstance(config.cp_comm_type, list):
    attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type[self.layer_number]
    else:
    attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type

    attention_optional_kwargs["model_comm_pgs"] = model_comm_pgs

    # [Module 2: SelfAttention]
    self.self_attention = build_module(
    submodules.self_attention,
    config=self.config,
    layer_number=self.layer_number,
    **attention_optional_kwargs,
    )

    # [Module 3: BiasDropoutFusion]
    self.self_attn_bda = build_module(submodules.self_attn_bda)

    # [Module 4: Post SelfAttention] Optional Layernorm after self-attn
    self.pre_cross_attn_layernorm = build_module(
    submodules.pre_cross_attn_layernorm,
    config=self.config,
    hidden_size=self.config.hidden_size,
    eps=self.config.layernorm_epsilon,
    )

    # [Module 5: CrossAttention]
    self.cross_attention = build_module(
    submodules.cross_attention,
    config=self.config,
    layer_number=self.layer_number,
    **attention_optional_kwargs,
    )

    # [Module 6: BiasDropoutFusion]
    self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config)

    # [Module 7: Pre MLP] Optional Layernorm before MLP
    self.pre_mlp_layernorm = build_module(
    submodules.pre_mlp_layernorm,
    config=self.config,
    hidden_size=self.config.hidden_size,
    eps=self.config.layernorm_epsilon,
    )
    # [Module 8: MLP block]
    additional_mlp_kwargs = {}
    # import here to avoid circular import
    from megatron.core.extensions.transformer_engine import TEFusedMLP
    from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
    from megatron.core.transformer.moe.moe_layer import MoELayer

    # MLP expects tp_group but MoELayer expects model_comm_pgs to be passed in.
    # We can change MLP to accept model_comm_pgs but it makes the logic implicit
    # The conditional below is to make the logic explicit
    # if submodules.mlp is not a ModuleSpec,we dont have to handle passing additional kwargs
    if isinstance(submodules.mlp, ModuleSpec):
    if submodules.mlp.module in (MoELayer, GroupedMLP, TEGroupedMLP, SequentialMLP):
    additional_mlp_kwargs["model_comm_pgs"] = model_comm_pgs
    elif submodules.mlp.module == MLP:
    assert hasattr(
    model_comm_pgs, 'tp'
    ), 'TP process group is required for MLP in TransformerLayer'
    additional_mlp_kwargs["tp_group"] = model_comm_pgs.tp
    elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP:
    assert hasattr(
    model_comm_pgs, 'tp'
    ), 'TP process group is required for TEFusedMLP in TransformerLayer'
    additional_mlp_kwargs["tp_group"] = model_comm_pgs.tp
    else:
    log_single_rank(
    logger,
    logging.WARNING,
    f"Unknown MLP type: {type(submodules.mlp)}. Using default kwargs.",
    )
    self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs)
    if hasattr(self.mlp, 'set_layer_number'):
    self.mlp.set_layer_number(self.layer_number)

    # [Module 9: BiasDropoutFusion]
    self.mlp_bda = build_module(submodules.mlp_bda)

    self.recompute_input_layernorm = False
    self.recompute_pre_mlp_layernorm = False
    self.recompute_mlp = False
    if self.config.recompute_granularity == 'selective':
    if "layernorm" in self.config.recompute_modules:
    if (
    not isinstance(self.input_layernorm, IdentityOp)
    and not self.config.external_cuda_graph
    ):
    self.recompute_input_layernorm = True
    if self.config.fp8:
    self.self_attention.set_for_recompute_input_layernorm()
    if not isinstance(self.pre_mlp_layernorm, IdentityOp):
    self.recompute_pre_mlp_layernorm = True
    if self.config.fp8:
    if isinstance(self.mlp, MoELayer):
    self.mlp.set_for_recompute_pre_mlp_layernorm()
    else:
    from megatron.core.extensions.transformer_engine import (
    set_save_original_input,
    )

    set_save_original_input(self.mlp.linear_fc1)
    if "mlp" in self.config.recompute_modules:
    if not isinstance(self.mlp, MoELayer):
    self.recompute_mlp = True

    # @jcasper how should we handle nvfuser?
    # Set bias+dropout+add fusion grad_enable execution handler.
    # TORCH_MAJOR = int(torch.__version__.split('.')[0])
    # TORCH_MINOR = int(torch.__version__.split('.')[1])
    # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
    # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
    self.bias_dropout_add_exec_handler = torch.enable_grad

    1. Input Layernorm

    2. SelfAttention

    3. BiasDropoutFusion

    4. Post SelfAttention

    5. CrossAttention

    6. BiasDropoutFusion

    7. Pre MLP

    8. MLP block

    9. BiasDropoutFusion

  • Its forward pass is a fairly standard implementation, shown below:

def forward(self, *args, **kwargs):
"""
Perform a forward pass through the transformer layer.

This method calls the core computation of a transformer layer, including
self-attention, cross-attention (if applicable), and feed-forward operations.
"""
hidden_states, context = self._forward_attention(*args, **kwargs)
output = self._forward_mlp(hidden_states, kwargs.get("inference_context", None))
return output, context

def _forward_attention(
self,
hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
context: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
rotary_pos_emb: Optional[Tensor] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
inference_context: Optional[Any] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[Tensor] = None,
*,
inference_params: Optional[Any] = None,
):
"""
Perform a forward pass through the attention layer and the layernorms before and after
the attention operations.

Args:
hidden_states (Tensor): Input tensor of shape [s, b, h] where s is sequence length,
b is batch size, and h is hidden size.
attention_mask (Tensor): Mask tensor for self-attention.
context (Tensor, optional): Context tensor for cross-attention.
context_mask (Tensor, optional): Mask tensor for cross-attention.
rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
attention_bias (Tensor, optional): Bias tensor for Q * K.T.
inference_context (object, optional): Parameters for inference-time optimizations.
packed_seq_params (object, optional): Parameters for packed sequence processing.
sequence_len_offset (Tensor, optional): Offset along sequence dimension
during inference.

Returns:
Tuple[Tensor, Tensor]: A tuple containing:
hidden_states (Tensor): Transformed hidden states before the MLP layernorm.
context (Tensor): Updated context tensor if cross-attention is used,
otherwise None.
"""

inference_context = deprecate_inference_params(inference_context, inference_params)

# Residual connection.
residual = hidden_states

# Optional Input Layer norm
if self.recompute_input_layernorm:
self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
self.input_layernorm, hidden_states
)
else:
input_layernorm_output = self.input_layernorm(hidden_states)

# Self attention.
nvtx_range_push(suffix="self_attention")
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_context=inference_context,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
nvtx_range_pop(suffix="self_attention")

if self.recompute_input_layernorm:
# discard the output of the input layernorm and register the recompute
# as a gradient hook of attention_output_with_bias[0]
self.input_layernorm_checkpoint.discard_output_and_register_recompute(
attention_output_with_bias[0]
)

# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="self_attn_bda")
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
nvtx_range_pop(suffix="self_attn_bda")

# Residual connection.
residual = hidden_states

# Optional Layer norm after self-attention
pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)

# Cross attention.
attention_output_with_bias = self.cross_attention(
pre_cross_attn_layernorm_output,
attention_mask=context_mask,
key_value_states=context,
inference_context=inference_context,
)

if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
context = attention_output_with_bias["context"]

# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)

return hidden_states, context

def _forward_mlp(self, hidden_states, inference_context=None):
"""
Perform a forward pass through the feed-forward layer.

Args:
hidden_states (Tensor): Transformed hidden states before the MLP layernorm.

Returns:
output (Tensor): Transformed hidden states of shape [s, b, h].
"""

# Residual connection.
residual = hidden_states

# Optional Layer norm post the cross-attention.
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
self.pre_mlp_layernorm, hidden_states
)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

nvtx_range_push(suffix="mlp")
# Potentially chunk the MLP computation during prefill to minimize the peak activation size
should_chunk_mlp_for_prefill = (
self.config.mlp_chunks_for_prefill > 1
and inference_context is not None
and not inference_context.is_decode_only()
and not isinstance(self.mlp, IdentityOp)
)

if self.recompute_mlp:
if self.config.fp8:
# import here to avoid circular import
from megatron.core.extensions.transformer_engine import te_checkpoint

mlp_output_with_bias = te_checkpoint(
self.mlp,
False,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
pre_mlp_layernorm_output,
)
else:
mlp_output_with_bias = tensor_parallel.checkpoint(
self.mlp, False, pre_mlp_layernorm_output
)
elif should_chunk_mlp_for_prefill:
# Chunk input along sequence dimension
num_chunks = min(self.config.mlp_chunks_for_prefill, pre_mlp_layernorm_output.shape[0])
chunks = pre_mlp_layernorm_output.chunk(num_chunks, dim=0)

# Compute outputs for each chunk
outputs = [self.mlp(chunk) for chunk in chunks]

# Aggregate chunk outputs
mlp_output = torch.cat([out for out, _ in outputs], dim=0)
bias_chunks = [bias for _, bias in outputs if bias is not None]
bias_output = torch.stack(bias_chunks, dim=0).sum(dim=0) if bias_chunks else None
mlp_output_with_bias = (mlp_output, bias_output)

else:
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)

if self.recompute_pre_mlp_layernorm:
# discard the output of the pre-mlp layernorm and register the recompute
# as a gradient hook of mlp_output_with_bias[0]
self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
mlp_output_with_bias[0]
)
nvtx_range_pop(suffix="mlp")

# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="mlp_bda")
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
nvtx_range_pop(suffix="mlp_bda")

# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)

return output

The MLP Module

An MLP block typically applies a fully connected layer, then an activation such as GELU, then a second fully connected layer. Under tensor parallelism the first linear layer is usually column-parallel and the second row-parallel, so the activation can be computed shard-locally and only one all-reduce is needed at the end.
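
The sketch below (a single-process simulation with tp=2, illustrative only) shows why this split works: the first weight A is split by columns so the activation can be applied per shard, the second weight B is split by rows, and summing the per-rank partial outputs (the all-reduce performed by the row-parallel layer) reproduces the unsharded result.

import torch
import torch.nn.functional as F

# Simulate tensor parallelism of degree 2 on a single process (illustrative only).
torch.manual_seed(0)
s, b, h, ffn, tp = 4, 2, 8, 32, 2
X = torch.randn(s, b, h)
A = torch.randn(h, ffn)   # fc1 weight, split column-wise: A = [A_1, A_2]
B = torch.randn(ffn, h)   # fc2 weight, split row-wise:    B = [B_1; B_2]

reference = F.gelu(X @ A) @ B                    # unsharded computation

A_shards = A.chunk(tp, dim=1)                    # each "rank" keeps one column block
B_shards = B.chunk(tp, dim=0)                    # each "rank" keeps one row block
partials = [F.gelu(X @ A_i) @ B_i for A_i, B_i in zip(A_shards, B_shards)]
output = sum(partials)                           # the all-reduce done by RowParallelLinear

print(torch.allclose(reference, output, atol=1e-5))  # True: GELU applies per shard safely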

The local-backend code that builds the MLP spec is as follows:

def get_mlp_module_spec_for_backend(
    backend: BackendSpecProvider,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
    use_te_op_fuser: Optional[bool] = False,
) -> ModuleSpec:
    """Helper function to get module spec for MLP/MoE"""

    linear_fc2 = backend.row_parallel_linear()

    if num_experts is None:
        # Dense MLP w/ or w/o TE modules.
        if use_te_op_fuser:
            return ModuleSpec(module=TEFusedMLP)
        elif backend.fuse_layernorm_and_linear():
            linear_fc1 = backend.column_parallel_layer_norm_linear()
            assert linear_fc1 is not None
        else:
            linear_fc1 = backend.column_parallel_linear()
        return ModuleSpec(
            module=MLP, submodules=MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2)
        )
    else:
        # Mixture of experts with modules in megatron core.
        return get_moe_module_spec_for_backend(
            backend=backend,
            num_experts=num_experts,
            moe_grouped_gemm=moe_grouped_gemm,
            moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        )

In the common case the two linear layers are column_parallel_linear and row_parallel_linear respectively, and the MLP module is built on top of them. The MLP code is shown below:

class MLP(MegatronModule):
"""
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.

Returns an output and a bias to be added to the output.
If config.add_bias_linear is False, the bias returned is None.

We use the following notation:
h: hidden size
p: number of tensor model parallel partitions
b: batch size
s: sequence length
"""

def __init__(
self,
config: TransformerConfig,
submodules: MLPSubmodules,
is_expert: bool = False,
input_size: Optional[int] = None,
ffn_hidden_size: int = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
super().__init__(config=config)

self.config: TransformerConfig = config

self.input_size = input_size if input_size != None else self.config.hidden_size

tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
if ffn_hidden_size is None:
if is_expert:
raise ValueError("MoE MLP requires `ffn_hidden_size`, but it was not provided.")
warnings.warn(
"MLP requires ffn_hidden_size, but it was not provided. Using \
config.ffn_hidden_size by default.",
DeprecationWarning,
stacklevel=2,
)
ffn_hidden_size = self.config.ffn_hidden_size

# If this is a gated linear unit we double the output width
# see https://arxiv.org/pdf/2002.05202.pdf
if self.config.gated_linear_unit:
ffn_hidden_size *= 2

self.linear_fc1 = build_module(
submodules.linear_fc1,
self.input_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=is_expert,
tp_comm_buffer_name="fc1",
tp_group=tp_group,
)

self.activation_func = self.config.activation_func

self.linear_fc2 = build_module(
submodules.linear_fc2,
self.config.ffn_hidden_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=is_expert,
tp_comm_buffer_name="fc2",
tp_group=tp_group,
)

def forward(self, hidden_states, per_token_scale=None):
"""Perform the forward pass through the MLP block."""
# [s, b, 4 * h/p]
nvtx_range_push(suffix="linear_fc1")
intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)
nvtx_range_pop(suffix="linear_fc1")

nvtx_range_push(suffix="activation")
if self.config.bias_activation_fusion:
if per_token_scale is not None:
if self.activation_func == F.silu and self.config.gated_linear_unit:
# dtype is handled inside the fused kernel
intermediate_parallel = weighted_bias_swiglu_impl(
intermediate_parallel,
bias_parallel,
per_token_scale.unsqueeze(-1),
self.config.activation_func_fp8_input_store,
)
else:
raise ValueError("Only support fusion of swiglu with per_token_scale in MLP.")
else:
if self.activation_func == F.gelu:
if self.config.gated_linear_unit:
intermediate_parallel = bias_geglu_impl(
intermediate_parallel, bias_parallel
)
else:
assert self.config.add_bias_linear is True
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
elif self.activation_func == F.silu and self.config.gated_linear_unit:
intermediate_parallel = bias_swiglu_impl(
intermediate_parallel,
bias_parallel,
self.config.activation_func_fp8_input_store,
self.config.cpu_offloading
and self.config.cpu_offloading_activations
and HAVE_TE,
)
else:
raise ValueError("Only support fusion of gelu and swiglu")
else:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
if self.config.gated_linear_unit:

def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]

intermediate_parallel = glu(intermediate_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel)

if per_token_scale is not None:
original_dtype = intermediate_parallel.dtype
intermediate_parallel = intermediate_parallel * per_token_scale.unsqueeze(-1)
intermediate_parallel = intermediate_parallel.to(original_dtype)
nvtx_range_pop(suffix="activation")

# [s, b, h]
nvtx_range_push(suffix="linear_fc2")
output, output_bias = self.linear_fc2(intermediate_parallel)
nvtx_range_pop(suffix="linear_fc2")

if per_token_scale is not None:
assert output_bias is None, "Bias is not supported with per_token_scale"

return output, output_bias

# pylint: disable=missing-function-docstring
def sharded_state_dict(
self, prefix: str = "", sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
sharded_state_dict = {}
singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
for name, module in self._modules.items():
sub_sd = module.sharded_state_dict(f"{prefix}{name}.", sharded_offsets, metadata)
if self.config.gated_linear_unit and name == "linear_fc1":
for k, v in sub_sd.items():
if k in (f"{prefix}{name}.weight", f"{prefix}{name}.bias"):
sub_sd[k] = apply_swiglu_sharded_factory(
v, sharded_offsets, singleton_local_shards
)
sharded_state_dict.update(sub_sd)
return sharded_state_dict

def backward_dw(self):
self.linear_fc2.backward_dw()
self.linear_fc1.backward_dw()

  • At initialization:

    • It reads the config to obtain ffn_hidden_size, the tp_group, and related parameters.

    • It then builds linear_fc1 as a column_parallel_linear and linear_fc2 as a row_parallel_linear, together with the activation_func required by the config.

  • In the forward pass:

    • The whole flow is annotated with nvtx_range_push/nvtx_range_pop so that each phase can be attributed precisely in Nsight Systems (see the NVTX sketch after this list).

    • linear_fc1 is called first, then the activation, then linear_fc2.
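
For reference, the sketch below shows the underlying PyTorch NVTX calls (torch.cuda.nvtx.range_push / torch.cuda.nvtx.range_pop) that such annotations boil down to. Megatron wraps them in its own nvtx_range_push/nvtx_range_pop helpers, so this is an illustration of the idea rather than the actual helper implementation.

import torch

def annotated_mlp_forward(linear_fc1, activation, linear_fc2, hidden_states):
    """Illustration only: each phase shows up as a named range in an Nsight Systems timeline."""
    torch.cuda.nvtx.range_push("linear_fc1")
    intermediate = linear_fc1(hidden_states)
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("activation")
    intermediate = activation(intermediate)
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("linear_fc2")
    output = linear_fc2(intermediate)
    torch.cuda.nvtx.range_pop()
    return output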

column_parallel_linear

Megatron-LM's locally implemented ColumnParallelLinear is shown below:

class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].

Args:
input_size:
first dimension of matrix A.
output_size:
second dimension of matrix A.
bias:
If true, add bias
gather_output:
If true, call all-gather on output and make Y available to all GPUs,
otherwise, every GPU will have its output which is Y_i = XA_i
init_method:
method to initialize weights. Note that bias is always set to zero.
stride:
For the strided linear layers.
keep_master_weight_for_test:
This was added for testing and should be set to False. It
returns the master weights used for initialization.
skip_bias_add:
If True, do not add the bias term, instead return it to be added by the
caller. This enables performance optimations where bias can be fused with other
elementwise operations.
skip_weight_param_allocation:
If True, weight parameter is not allocated and must be passed
as a keyword argument `weight` during the forward pass. Note that this does not
affect bias, which will be allocated if bias is True. Defaults to False.
embedding_activation_buffer:
This buffer holds the input activations of the final embedding
linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
grad_output_buffer:
This buffer holds the gradient outputs of the final embedding linear
layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
is_expert:
If True, the layer is treated as an MoE expert layer.
config:
ModelParallelConfig object
tp_comm_buffer_name:
Communication buffer name is not used in non-Transformer-Engine modules.
disable_grad_reduce:
If True, reduction of output gradients across tensor-parallel ranks
will be disabled. Defaults to False. This feature is used by Lora Adapter in Nemo to
delay and fuse reduction along with other gradients for performance optimization.
"""

def __init__(
self,
input_size,
output_size,
*,
config: ModelParallelConfig,
init_method: Callable,
bias=True,
gather_output=False,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
skip_weight_param_allocation: bool = False,
embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
disable_grad_reduce: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
super(ColumnParallelLinear, self).__init__()

# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
self.skip_bias_add = skip_bias_add
self.is_expert = is_expert
self.expert_parallel = config.expert_model_parallel_size > 1
self.embedding_activation_buffer = embedding_activation_buffer
self.grad_output_buffer = grad_output_buffer
self.config = config
self.disable_grad_reduce = disable_grad_reduce
self.tp_group = tp_group

self.tp_group = get_tensor_model_parallel_group_if_none(
self.tp_group, is_expert=self.is_expert
)
world_size = get_pg_size(self.tp_group)
rank = get_pg_rank(self.tp_group)
self.explicit_expert_comm = self.is_expert and (world_size > 1 or self.expert_parallel)
self.output_size_per_partition = divide(output_size, world_size)

# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if not skip_weight_param_allocation:
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition, self.input_size, dtype=config.params_dtype
)
)
if config.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.output_size_per_partition,
0,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
rank=rank,
world_size=world_size,
)
else:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight,
init_method,
partition_dim=0,
stride=stride,
is_expert=self.is_expert,
)

setattr(self.weight, "allreduce", not (self.is_expert and self.expert_parallel))
else:
self.weight = None

if bias:
if config.use_cpu_initialization:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=config.params_dtype)
)
else:
self.bias = Parameter(
torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
if config.perform_initialization:
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, "allreduce", not (self.is_expert and self.expert_parallel))
else:
self.register_parameter("bias", None)

self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and world_size <= 1:
warnings.warn(
"`sequence_parallel` is set to `True`, but tensor model parallel size "
f"is {world_size}. Disabling sequence parallel."
)
self.sequence_parallel = False

self.allreduce_dgrad = (
world_size > 1 and not self.sequence_parallel and not self.disable_grad_reduce
)

if config.gradient_accumulation_fusion and not _grad_accum_fusion_available:
raise RuntimeError(
"ColumnParallelLinear was called with gradient_accumulation_fusion set "
"to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
"module is not found. To use gradient_accumulation_fusion you must "
"install APEX with --cpp_ext and --cuda_ext. For example: "
'pip install --global-option="--cpp_ext" --global-option="--cuda_ext ." '
"Note that the extension requires CUDA>=11. Otherwise, you must turn off "
"gradient accumulation fusion."
)
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion

if self.allreduce_dgrad and self.sequence_parallel:
raise RuntimeError(
"`allreduce_dgrad` and `sequence_parallel` cannot be enabled at the same time."
)

# Hook adding a default empty _extra_state for state dict
self._register_load_state_dict_pre_hook(
lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
f"{prefix}_extra_state"
)
)

def _forward_impl(self, input, weight, *args, **kwargs):
if not weight.requires_grad:
return linear_with_frozen_weight(input, weight, *args, **kwargs)
else:
return linear_with_grad_accumulation_and_async_allreduce(input, weight, *args, **kwargs)

def forward(
self,
input_: torch.Tensor,
weight: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
):
"""Forward of ColumnParallelLinear

Args:
input_:
3D tensor whose order of dimension is [sequence, batch, hidden]
weight (optional):
weight tensor to use, compulsory when skip_weight_param_allocation is True.
runtime_gather_output (bool): Gather output at runtime. Default None means
`gather_output` arg in the constructor will be used.

Returns:
- output
- bias

"""
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to ColumnParallelLinear forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
else:
# Check the weight passed in is the correct shape
expected_shape = (self.output_size_per_partition, self.input_size)
if weight.shape != expected_shape:
raise RuntimeError(
f"supplied weight's shape is {tuple(weight.shape)}, "
f"not {expected_shape} as expected"
)

bias = self.bias if not self.skip_bias_add else None

if (
self.allreduce_dgrad
or self.sequence_parallel
or self.explicit_expert_comm
or self.disable_grad_reduce
):
input_parallel = input_
else:
input_parallel = copy_to_tensor_model_parallel_region(input_, group=self.tp_group)

if self.config.defer_embedding_wgrad_compute:
if (
self.config.wgrad_deferral_limit == 0
or len(self.embedding_activation_buffer) < self.config.wgrad_deferral_limit
):
self.embedding_activation_buffer.append(input_parallel)

# Matrix multiply.
allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad

if self.config._cpu_offloading_context is not None:
if self.config._cpu_offloading_context.inside_context is True:
if not HAVE_TE:
assert (
self.config.cpu_offloading is False
), "CPU Offloading cannot be enabled while TE is not present"
else:
input_parallel.activation_offloading = self.config.cpu_offloading_activations

output_parallel = self._forward_impl(
input=input_parallel,
weight=weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
allreduce_dgrad=allreduce_dgrad,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=(
self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None
),
wgrad_deferral_limit=(
self.config.wgrad_deferral_limit
if self.config.defer_embedding_wgrad_compute
else None
),
tp_group=self.tp_group,
)

gather_output = self.gather_output
# Use the runtime gather output if it's set explicitly.
if runtime_gather_output is not None:
gather_output = runtime_gather_output

if gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel, group=self.tp_group)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix="", keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {"weight": 0, "bias": 0}, sharded_offsets
)

def set_extra_state(self, state: Any):
"""Extra state is ignored"""

def get_extra_state(self) -> None:
"""Keep compatibility with TE state dict."""
return None

def __repr__(self):
tp = self.output_size // self.output_size_per_partition
use_bias = self.bias is not None and self.bias is True
return (
f"{type(self).__name__}(in_features={self.input_size}, "
f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
)

  • At initialization:

    • It first computes self.output_size_per_partition = divide(output_size, world_size) for TP column parallelism, and uses that to initialize the weight shard: self.weight = Parameter(torch.empty(self.output_size_per_partition, self.input_size, ...)).

    • It also records whether the input gradient needs an all-reduce (allreduce_dgrad), required when world_size > 1 and not self.sequence_parallel and not self.disable_grad_reduce; under sequence parallelism the input gradient is reduce-scattered instead of all-reduced, so the two must not be enabled together.

  • The forward pass proceeds as follows:

    1. If no weight argument is supplied, the module's own weight is used (an error is raised if neither exists); a supplied weight is checked against the expected shape.

    2. For column parallelism the typical scheme is that the input is identical (replicated) on every TP rank and each rank computes Y_i = X @ W_i^T with its own shard W_i. When TP > 1, copy_to_tensor_model_parallel_region is the op that keeps the input consistent across TP ranks (an identity in the forward pass, with the communication deferred to backward). If certain modes are enabled (sequence_parallel / allreduce_dgrad / explicit expert communication / disable_grad_reduce), this copy path is skipped and the incoming input_ is used directly, because in those modes the input has already been prepared under other semantics or the communication is handled elsewhere.

    3. It then calls _forward_impl to compute the result. The several layers of wrapping mainly handle the sequence_parallel case: when sequence_parallel is True, the full sequence is first recovered with an all-gather of the input before the GEMM.

      • Note that the wrapped autograd function also defines the backward behavior:

        • If ctx.allreduce_dgrad is True, it issues torch.distributed.all_reduce(grad_input, async_op=True),
          the typical dgrad communication overlap under TP.

        • If ctx.sequence_parallel is True, grad_input is reduce-scattered back into the sequence-parallel layout.

    4. It then checks gather_output / runtime_gather_output to decide whether an all-gather is needed to reassemble the full output. Note that the MLP forward above does not set runtime_gather_output (and fc1 is built with gather_output=False), so no all-gather happens between the two linear layers, exactly as TP requires.

    5. Finally it returns output, output_bias.

Note that ColumnParallelLinear does not define a backward method of its own. As analyzed earlier, column parallelism requires an all-reduce (sum) when computing $$\frac{\partial L}{\partial X}$$ in the backward pass; that communication is carried out in the backward of the custom autograd functions mentioned above (copy_to_tensor_model_parallel_region / linear_with_grad_accumulation_and_async_allreduce), while the rest of the backward graph is generated automatically by PyTorch autograd. A minimal sketch of this identity-forward / all-reduce-backward pattern follows.
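
The sketch below is an illustrative stand-in (not Megatron-LM's actual implementation) for that op: a no-op in forward, because the input is already replicated on every TP rank, and a sum all-reduce of the incoming gradient in backward.

import torch
import torch.distributed as dist

class CopyToTensorParallelRegion(torch.autograd.Function):
    """Illustrative sketch of the column-parallel input op:
    identity in forward, all-reduce of the input gradient in backward."""

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        return input_  # input is already replicated on every TP rank

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        # dL/dX must be summed over TP ranks: each rank only multiplied by its own W_i shard.
        dist.all_reduce(grad_input, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_input, None

# Usage: replicated_input = CopyToTensorParallelRegion.apply(hidden_states, tp_group)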

row_parallel_linear

The RowParallelLinear code is shown below:

class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
along its second dimension. A = transpose([A_1 .. A_p]) X = [X_1, ..., X_p]

Args:
input_size:
first dimension of matrix A.
output_size:
second dimension of matrix A.
bias:
If true, add bias. Note that bias is not parallelized.
input_is_parallel:
If true, we assume that the input is already split across the GPUs
and we do not split again.
init_method:
method to initialize weights. Note that bias is always set to zero.
stride:
For the strided linear layers.
keep_master_weight_for_test:
This was added for testing and should be set to False. It returns the master weights
used for initialization.
skip_bias_add:
If True, do not add the bias term, instead return it to be added by the
caller. This enables performance optimations where bias can be fused with other
elementwise operations.
is_expert:
If True, the layer is treated as an MoE expert layer
tp_comm_buffer_name:
Communication buffer name. Not used in non-Transformer-Engine modules.
config:
ModelParallelConfig object

"""

def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
stride: int = 1,
keep_master_weight_for_test: bool = False,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
super(RowParallelLinear, self).__init__()

# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.skip_bias_add = skip_bias_add
self.config = config
self.is_expert = is_expert
self.expert_parallel = config.expert_model_parallel_size > 1
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
self.sequence_parallel = config.sequence_parallel
self.tp_group = tp_group

if self.sequence_parallel and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")

# Divide the weight matrix along the last dimension.
self.tp_group = get_tensor_model_parallel_group_if_none(
self.tp_group, is_expert=self.is_expert
)

world_size = get_pg_size(self.tp_group)
rank = get_pg_rank(self.tp_group)
self.explicit_expert_comm = self.is_expert and (world_size > 1 or self.expert_parallel)

self.input_size_per_partition = divide(input_size, world_size)

# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size, self.input_size_per_partition, dtype=config.params_dtype
)
)
if config.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.input_size_per_partition,
1,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
params_dtype=config.params_dtype,
rank=rank,
world_size=world_size,
)
else:
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight,
init_method,
partition_dim=1,
stride=stride,
is_expert=self.is_expert,
)
setattr(self.weight, "allreduce", not (self.is_expert and self.expert_parallel))

if bias:
if config.use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype))
else:
self.bias = Parameter(
torch.empty(
self.output_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)

if config.perform_initialization:
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, "allreduce", not (self.is_expert and self.expert_parallel))
setattr(self.bias, "sequence_parallel", self.sequence_parallel)
else:
self.register_parameter("bias", None)

# Hook adding a default empty _extra_state for state dict
self._register_load_state_dict_pre_hook(
lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
f"{prefix}_extra_state"
)
)

def _forward_impl(self, input, weight, *args, **kwargs):
if not weight.requires_grad:
return linear_with_frozen_weight(input, weight, *args, **kwargs)
else:
return linear_with_grad_accumulation_and_async_allreduce(input, weight, *args, **kwargs)

def forward(self, input_):
"""Forward of RowParallelLinear

Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]

Returns:
- output
- bias
"""

# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_, group=self.tp_group)
# Matrix multiply.
allreduce_dgrad = False

if self.config._cpu_offloading_context is not None:
if self.config._cpu_offloading_context.inside_context is True:
if not HAVE_TE:
assert (
self.config.cpu_offloading is False
), "CPU Offloading cannot be enabled while TE is not present"
else:
input_parallel.activation_offloading = self.config.cpu_offloading_activations

output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
allreduce_dgrad=allreduce_dgrad,
sequence_parallel=False,
tp_group=None,
grad_output_buffer=None,
)

# All-reduce across all the partitions.
if self.explicit_expert_comm:
assert self.skip_bias_add
output_ = output_parallel
elif self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(
output_parallel, group=self.tp_group
)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel, group=self.tp_group)
if not self.skip_bias_add:
output = (output_ + self.bias) if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Sharding along axis 1, bias not sharded"""
state_dict = self.state_dict(prefix="", keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {"weight": 1}, sharded_offsets
)

def set_extra_state(self, state: Any):
"""Extra state is ignored"""

def get_extra_state(self) -> None:
"""Keep compatibility with TE state dict."""
return None

def __repr__(self):
tp = self.input_size // self.input_size_per_partition
use_bias = self.bias is not None and self.bias is True
return (
f"{type(self).__name__}(in_features={self.input_size}, "
f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
)

  • At initialization:

    • The parameters largely mirror those of ColumnParallelLinear. The differences are the input_is_parallel flag, which records whether the input has already been split across the TP group, and the constraint that if self.sequence_parallel is set, self.input_is_parallel must be True.

    • The weight is partitioned along the input dimension (input_size_per_partition = input_size / tp_world_size), in contrast to ColumnParallelLinear, which partitions the output dimension.

  • In forward, the flow is as follows:

    1. It first checks input_is_parallel; if the input has not been split yet, it calls scatter_to_tensor_model_parallel_region to partition it within the TP group.

    2. It then calls _forward_impl to run the local matmul. Note that, unlike ColumnParallelLinear, it passes sequence_parallel=False and allreduce_dgrad=False here: the input is already partitioned, and all TP communication happens after the matmul.

    3. It then performs the appropriate communication on the partial output to obtain output:

      • Normal case (neither expert nor sequence_parallel):
        calls reduce_from_tensor_model_parallel_region
        => essentially a **TP all-reduce (sum)** that sums each rank's Y_i into the complete Y (every rank ends up with the same Y).

      • sequence_parallel=True:
        calls reduce_scatter_to_sequence_parallel_region
        => the sum is reduce-scattered directly into the layout sequence parallelism needs, avoiding the extra cost of an all-reduce followed by a split.

      • Explicit expert communication (MoE): no reduce is done here; the local output_parallel is returned as-is, because the MoE token dispatcher is responsible for cross-rank aggregation/routing.

    4. Finally it returns output and output_bias (a minimal numerical sketch of the reduction in step 3 follows this list).
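
To make step 3 concrete, here is a minimal single-process sketch (plain PyTorch, no torch.distributed; the tensor sizes and tp_size=2 split are illustrative assumptions, not Megatron defaults) showing that summing the per-rank partial products of a row-split weight reproduces the full X @ A, which is exactly what the TP all-reduce does:

```python
import torch

torch.manual_seed(0)
tp_size = 2
in_features, out_features = 8, 4            # toy sizes
X = torch.randn(3, in_features)             # [tokens, input_size]
A = torch.randn(in_features, out_features)  # full (unpartitioned) weight

# Row parallelism: A is split along its input (first) dimension,
# and X is split along its last dimension to match.
A_shards = A.chunk(tp_size, dim=0)
X_shards = X.chunk(tp_size, dim=1)

# Each "rank" computes only a partial product Y_i = X_i @ A_i ...
partials = [x_i @ a_i for x_i, a_i in zip(X_shards, A_shards)]

# ... and the all-reduce in RowParallelLinear.forward sums them into the full result.
Y = sum(partials)

assert torch.allclose(Y, X @ A, atol=1e-5)
print(Y.shape)  # torch.Size([3, 4])
```

With sequence_parallel=True the same sum is formed, but via reduce-scatter, so each rank keeps only its own slice of the sequence dimension instead of the full Y.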

Transformer module

When building the Transformer module, the multi_latent_attention flag determines whether the self-attention submodule of each GPT layer uses the standard SelfAttention or the MLA (Multi-Latent Attention) variant.

Here we look directly at the most standard implementation; the code is as follows:

class TransformerLayer(MegatronModule, BaseTransformerLayer):
"""A single transformer layer.

Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""

def __init__(
self,
config: TransformerConfig,
submodules: TransformerLayerSubmodules,
layer_number: int = 1,
hidden_dropout: Optional[float] = None,
model_comm_pgs: Optional[ModelCommProcessGroups] = None,
vp_stage: Optional[int] = None,
):
super().__init__(config=config)

# Enable cuda graphs.
if (
config.enable_cuda_graph and config.cuda_graph_scope != "full_iteration"
) or config.external_cuda_graph:
assert not (
config.enable_cuda_graph and config.external_cuda_graph
), "Cudagraphs and external cudagraphs cannot be enabled at the same time"
if config.enable_cuda_graph and config.cuda_graph_scope != "full_iteration":
if not self.training:
# Cudagraphs for inference are only enabled with the flash decoding kernel
assert (
self.config.flash_decode
), "--flash-decode is required to use CUDA graphs during inference"
self.cudagraph_manager = CudaGraphManager(config, vp_stage=vp_stage)
else:
# List to store CUDA graphs. A list of `N` CUDA graphs for this layer where N is
# the number of microbatches. Multiple CUDA graphs per layer is required to support
# pipelining which requires running FWD graph of multiple microbatches before BWD
# graph. To enable CUDA graph, this list should be populated in the model training
# script with the graphs returned by make_graphed_callables API before the first
# training step.
self.cuda_graphs = []
# List to store forward pre-hooks. Forward pre-hooks are not captured into CUDA
# graphs. Those hooks and args are collected in this list and should be manually
# triggered before CUDA Graph running. This is required to ensure the correct param
# all-gather overlap with forward compute.
self.cuda_graph_manual_hooks = []
self.current_microbatch = -1

if model_comm_pgs is None:
model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups()

self.submodules_config = submodules
self.layer_number = layer_number + get_transformer_layer_offset(self.config, vp_stage)
self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout

# [Module 1: Input Layernorm] Optional Layernorm on the input data
# TODO: add pytorch only layernorm
self.input_layernorm = build_module(
submodules.input_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)

attention_optional_kwargs = {}
if config.context_parallel_size > 1 and config.cp_comm_type is not None:
if isinstance(config.cp_comm_type, list):
attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type[self.layer_number]
else:
attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type

attention_optional_kwargs["model_comm_pgs"] = model_comm_pgs

# [Module 2: SelfAttention]
self.self_attention = build_module(
submodules.self_attention,
config=self.config,
layer_number=self.layer_number,
**attention_optional_kwargs,
)

# [Module 3: BiasDropoutFusion]
self.self_attn_bda = build_module(submodules.self_attn_bda)

# [Module 4: Post SelfAttention] Optional Layernorm after self-attn
self.pre_cross_attn_layernorm = build_module(
submodules.pre_cross_attn_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)

# [Module 5: CrossAttention]
self.cross_attention = build_module(
submodules.cross_attention,
config=self.config,
layer_number=self.layer_number,
**attention_optional_kwargs,
)

# [Module 6: BiasDropoutFusion]
self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config)

# [Module 7: Pre MLP] Optional Layernorm before MLP
self.pre_mlp_layernorm = build_module(
submodules.pre_mlp_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
# [Module 8: MLP block]
additional_mlp_kwargs = {}
# import here to avoid circular import
from megatron.core.extensions.transformer_engine import TEFusedMLP
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from megatron.core.transformer.moe.moe_layer import MoELayer

# MLP expects tp_group but MoELayer expects model_comm_pgs to be passed in.
# We can change MLP to accept model_comm_pgs but it makes the logic implicit
# The conditional below is to make the logic explicit
# if submodules.mlp is not a ModuleSpec,we dont have to handle passing additional kwargs
if isinstance(submodules.mlp, ModuleSpec):
if submodules.mlp.module in (MoELayer, GroupedMLP, TEGroupedMLP, SequentialMLP):
additional_mlp_kwargs["model_comm_pgs"] = model_comm_pgs
elif submodules.mlp.module == MLP:
assert hasattr(
model_comm_pgs, 'tp'
), 'TP process group is required for MLP in TransformerLayer'
additional_mlp_kwargs["tp_group"] = model_comm_pgs.tp
elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP:
assert hasattr(
model_comm_pgs, 'tp'
), 'TP process group is required for TEFusedMLP in TransformerLayer'
additional_mlp_kwargs["tp_group"] = model_comm_pgs.tp
else:
log_single_rank(
logger,
logging.WARNING,
f"Unknown MLP type: {type(submodules.mlp)}. Using default kwargs.",
)
self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs)
if hasattr(self.mlp, 'set_layer_number'):
self.mlp.set_layer_number(self.layer_number)

# [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)

self.recompute_input_layernorm = False
self.recompute_pre_mlp_layernorm = False
self.recompute_mlp = False
if self.config.recompute_granularity == 'selective':
if "layernorm" in self.config.recompute_modules:
if (
not isinstance(self.input_layernorm, IdentityOp)
and not self.config.external_cuda_graph
):
self.recompute_input_layernorm = True
if self.config.fp8:
self.self_attention.set_for_recompute_input_layernorm()
if not isinstance(self.pre_mlp_layernorm, IdentityOp):
self.recompute_pre_mlp_layernorm = True
if self.config.fp8:
if isinstance(self.mlp, MoELayer):
self.mlp.set_for_recompute_pre_mlp_layernorm()
else:
from megatron.core.extensions.transformer_engine import (
set_save_original_input,
)

set_save_original_input(self.mlp.linear_fc1)
if "mlp" in self.config.recompute_modules:
if not isinstance(self.mlp, MoELayer):
self.recompute_mlp = True

# @jcasper how should we handle nvfuser?
# Set bias+dropout+add fusion grad_enable execution handler.
# TORCH_MAJOR = int(torch.__version__.split('.')[0])
# TORCH_MINOR = int(torch.__version__.split('.')[1])
# use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
# self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
self.bias_dropout_add_exec_handler = torch.enable_grad

@staticmethod
def _get_layer_offset(config: TransformerConfig):
"""
Get the layer offset for the current pipeline stage.

Deprecated: please use `get_transformer_layer_offset` instead.
"""

warnings.warn(
"TransformerLayer._get_layer_offset is deprecated."
"Please use get_transformer_layer_offset instead."
)
return get_transformer_layer_offset(config)

def forward(self, *args, **kwargs):
"""
Perform a forward pass through the transformer layer.

This method calls the core computation of a transformer layer, including
self-attention, cross-attention (if applicable), and feed-forward operations.
"""
hidden_states, context = self._forward_attention(*args, **kwargs)
output = self._forward_mlp(hidden_states, kwargs.get("inference_context", None))
return output, context

def _forward_attention(
self,
hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
context: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
rotary_pos_emb: Optional[Tensor] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
inference_context: Optional[Any] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[Tensor] = None,
*,
inference_params: Optional[Any] = None,
):
"""
Perform a forward pass through the attention layer and the layernorms before and after
the attention operations.

Args:
hidden_states (Tensor): Input tensor of shape [s, b, h] where s is sequence length,
b is batch size, and h is hidden size.
attention_mask (Tensor): Mask tensor for self-attention.
context (Tensor, optional): Context tensor for cross-attention.
context_mask (Tensor, optional): Mask tensor for cross-attention.
rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
attention_bias (Tensor, optional): Bias tensor for Q * K.T.
inference_context (object, optional): Parameters for inference-time optimizations.
packed_seq_params (object, optional): Parameters for packed sequence processing.
sequence_len_offset (Tensor, optional): Offset along sequence dimension
during inference.

Returns:
Tuple[Tensor, Tensor]: A tuple containing:
hidden_states (Tensor): Transformed hidden states before the MLP layernorm.
context (Tensor): Updated context tensor if cross-attention is used,
otherwise None.
"""

inference_context = deprecate_inference_params(inference_context, inference_params)

# Residual connection.
residual = hidden_states

# Optional Input Layer norm
if self.recompute_input_layernorm:
self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
self.input_layernorm, hidden_states
)
else:
input_layernorm_output = self.input_layernorm(hidden_states)

# Self attention.
nvtx_range_push(suffix="self_attention")
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_context=inference_context,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
nvtx_range_pop(suffix="self_attention")

if self.recompute_input_layernorm:
# discard the output of the input layernorm and register the recompute
# as a gradient hook of attention_output_with_bias[0]
self.input_layernorm_checkpoint.discard_output_and_register_recompute(
attention_output_with_bias[0]
)

# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="self_attn_bda")
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
nvtx_range_pop(suffix="self_attn_bda")

# Residual connection.
residual = hidden_states

# Optional Layer norm after self-attention
pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)

# Cross attention.
attention_output_with_bias = self.cross_attention(
pre_cross_attn_layernorm_output,
attention_mask=context_mask,
key_value_states=context,
inference_context=inference_context,
)

if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
context = attention_output_with_bias["context"]

# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)

return hidden_states, context

def _forward_mlp(self, hidden_states, inference_context=None):
"""
Perform a forward pass through the feed-forward layer.

Args:
hidden_states (Tensor): Transformed hidden states before the MLP layernorm.

Returns:
output (Tensor): Transformed hidden states of shape [s, b, h].
"""

# Residual connection.
residual = hidden_states

# Optional Layer norm post the cross-attention.
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
self.pre_mlp_layernorm, hidden_states
)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

nvtx_range_push(suffix="mlp")
# Potentially chunk the MLP computation during prefill to minimize the peak activation size
should_chunk_mlp_for_prefill = (
self.config.mlp_chunks_for_prefill > 1
and inference_context is not None
and not inference_context.is_decode_only()
and not isinstance(self.mlp, IdentityOp)
)

if self.recompute_mlp:
if self.config.fp8:
# import here to avoid circular import
from megatron.core.extensions.transformer_engine import te_checkpoint

mlp_output_with_bias = te_checkpoint(
self.mlp,
False,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
pre_mlp_layernorm_output,
)
else:
mlp_output_with_bias = tensor_parallel.checkpoint(
self.mlp, False, pre_mlp_layernorm_output
)
elif should_chunk_mlp_for_prefill:
# Chunk input along sequence dimension
num_chunks = min(self.config.mlp_chunks_for_prefill, pre_mlp_layernorm_output.shape[0])
chunks = pre_mlp_layernorm_output.chunk(num_chunks, dim=0)

# Compute outputs for each chunk
outputs = [self.mlp(chunk) for chunk in chunks]

# Aggregate chunk outputs
mlp_output = torch.cat([out for out, _ in outputs], dim=0)
bias_chunks = [bias for _, bias in outputs if bias is not None]
bias_output = torch.stack(bias_chunks, dim=0).sum(dim=0) if bias_chunks else None
mlp_output_with_bias = (mlp_output, bias_output)

else:
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)

if self.recompute_pre_mlp_layernorm:
# discard the output of the pre-mlp layernorm and register the recompute
# as a gradient hook of mlp_output_with_bias[0]
self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
mlp_output_with_bias[0]
)
nvtx_range_pop(suffix="mlp")

# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="mlp_bda")
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
nvtx_range_pop(suffix="mlp_bda")

# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)

return output

def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""
Generate a sharded state dictionary for the transformer layer.

Args:
prefix (str, optional): Prefix to be added to all keys in the state dict.
sharded_offsets (tuple, optional): Tuple of sharding offsets.
metadata (Optional[dict], optional): Additional metadata for sharding.

Returns:
ShardedStateDict: A dictionary containing the sharded state of the transformer layer.
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
prefixed_map = {
f'{prefix}{k}': f'{prefix}{v}'
for k, v in self.submodules_config.sharded_state_dict_keys_map.items()
}
if prefixed_map:
apply_prefix_mapping(sharded_state_dict, prefixed_map)
return sharded_state_dict

def get_layer_static_inputs(self, seq_length, micro_batch_size):
"""
Get the static inputs for the transformer layer.

Returns:
Dict[str, torch.Tensor]: A dictionary containing the static inputs for the layer.
"""
# Calculate data shape related values.
context_parallel_size = self.config.context_parallel_size
slen_per_cp = seq_length // context_parallel_size
sequence_parallel = self.config.sequence_parallel
tensor_model_parallel_size = self.config.tensor_model_parallel_size
slen_per_cptp = (
slen_per_cp // tensor_model_parallel_size if sequence_parallel else slen_per_cp
)

static_inputs = {}
static_inputs["hidden_states"] = torch.ones(
(slen_per_cptp, micro_batch_size, self.config.hidden_size),
dtype=torch.bfloat16,
requires_grad=True,
device=torch.cuda.current_device(),
)
static_inputs["attention_mask"] = (
~(torch.tril(torch.ones((slen_per_cp, seq_length))).bool())
.to(torch.cuda.current_device())
.reshape(1, 1, slen_per_cp, seq_length)
.tile(micro_batch_size, 1, 1, 1)
)
return static_inputs

def setup_manual_hooks(self, make_hook_func):
"""
Set CUDA Graph manual hooks for the modules that contain direct parameters and are
covered by cudagraphs.
"""
self.cuda_graph_manual_hooks = []

# Select the modules who contain direct parameters and are covered by cudagraphs.
# Add these modules to the `cuda_graph_manual_hooks` because their hooks will not
# be automatically triggered when they go through the CUDA Graph path.
if self.config.cuda_graph_scope == 'full':
high_level_modules = [self]
else:
assert (
self.config.cuda_graph_scope == 'attn'
), "Invalid cuda_graph_scope ${self.config.cuda_graph_scope}"
high_level_modules = [
self.input_layernorm,
self.self_attention,
self.pre_cross_attn_layernorm,
self.cross_attention,
]

param_modules = []
for module in high_level_modules:
for submodule in module.modules():
if next(submodule.parameters(recurse=False), None) is not None:
# Module contains direct parameters.
param_modules.append(submodule)
continue
if len(param_modules) > 0:
for module in param_modules:
self.cuda_graph_manual_hooks.append((make_hook_func(), (module,)))

def _cuda_graph_capture(self, *args, **kwargs):
"""
CUDA Graph capture for this layer. There are some differences from the normal pass:
1. In some conditions CUDA graph cannot cover the entire layer. The `cuda_graph_scope`
attribute can be set to control the scope of the CUDA graph.
2. If context is None, it cannot be returned as output.
"""
hidden_states, context = self._forward_attention(*args, **kwargs)

if self.config.cuda_graph_scope == "full":
hidden_states = self._forward_mlp(hidden_states)
cuda_graph_outputs = [hidden_states]

if context is not None:
cuda_graph_outputs.append(context)
return tuple(cuda_graph_outputs)

def _cuda_graph_replay(self, *args, **kwargs):
"""
CUDA graph replay for this layer and microbatch
`self.current_microbatch`. TransformerEngine versions>=1.10
allow keyword arguments with CUDA graph. However, CUDA graph
acccepts only Tensor inputs and Tensor outputs. Hence,
`inference_context` and `packed_seq_params` are excluded from
input list while output is limited to `hidden_states`.
"""

def _check_cuda_graph_replay_args(*args, **kwargs):
"""Helper function to get optional tensor arguments for CUDA graph."""

assert len(args) <= 1, "At most one positional argument `hidden_states` is expected."
if len(args) == 1:
hidden_states = args[0]
else:
hidden_states = kwargs.pop("hidden_states")
cudagraph_args = [hidden_states]

optional_inputs = kwargs.copy()
optional_inputs['is_first_microbatch'] = self.current_microbatch == 0
try:
import transformer_engine.pytorch as te # pylint: disable=unused-import

def get_zero_attention_mask(slen_per_tpcp, micro_batch_size):
sequence_parallel = self.config.sequence_parallel
tensor_model_parallel_size = self.config.tensor_model_parallel_size
slen_per_cp = (
slen_per_tpcp * tensor_model_parallel_size
if sequence_parallel
else slen_per_tpcp
)
slen = slen_per_cp * self.config.context_parallel_size
return torch.zeros(
(micro_batch_size, 1, slen_per_cp, slen),
dtype=torch.bool,
device=torch.cuda.current_device(),
)

if not is_te_min_version("1.10.0"):
# TE version < 1.10.0 does not support keyword arguments with CUDA graph.
for k, v in kwargs.items():
if k == "attention_mask":
if v is not None:
cudagraph_args.append(v)
optional_inputs[k] = None
else:
cudagraph_args.append(
get_zero_attention_mask(
hidden_states.size(0), hidden_states.size(1)
)
)
else:
assert v is None, "Keyword Arguments not supported with CUDA graph."
elif optional_inputs['attention_mask'] is None:
# The attention_mask can be None when there is no padding to the input sequence.
# However, an attention_mask Tensor must be passed into cudagraph for replay, so
# we create an equivalent zero Tensor as the attention_mask.
optional_inputs["attention_mask"] = get_zero_attention_mask(
hidden_states.size(0), hidden_states.size(1)
)
except ImportError:
raise RuntimeError("CUDAGraph requires TransformerEngine, but not installed")
return tuple(cudagraph_args), optional_inputs

cg_index = self.current_microbatch % len(self.cuda_graphs)
assert ('inference_context' not in kwargs or kwargs['inference_context'] is None) and (
'packed_seq_params' not in kwargs or kwargs['packed_seq_params'] is None
), "CUDA graph accepts only Tensor inputs."
cudagraph_args, cudagraph_kwargs = _check_cuda_graph_replay_args(*args, **kwargs)

for hook, hook_args in self.cuda_graph_manual_hooks:
hook(*hook_args)
cuda_graph_output = self.cuda_graphs[cg_index](*cudagraph_args, **cudagraph_kwargs)

if cudagraph_kwargs.get('context') is not None:
context = cuda_graph_output[-1]
cuda_graph_output = cuda_graph_output[:-1]
else:
context = None
if self.config.cuda_graph_scope == "attn":
# CUDA Graph only covers the attention layer. Feed-forward
# layer still goes through the normal pass.
output = self._forward_mlp(*cuda_graph_output)
else:
output = cuda_graph_output[0]
return output, context

def __call__(self, *args, **kwargs):
# Training and validation mode CUDA graphs
if hasattr(self, 'cudagraph_manager') and kwargs.get('inference_context') is None:
return self.cudagraph_manager(self, args, kwargs)
# Inference mode. CUDA graphs are used in the decode phase only, when attn mask is None
elif not self.training and (
hasattr(self, 'cudagraph_manager')
and kwargs['attention_mask'] is None
and (
(
kwargs.get('inference_context') is not None
and kwargs['inference_context'].is_decode_only()
)
or (
kwargs.get('inference_params') is not None
and kwargs['inference_params'].is_decode_only()
)
)
):
assert (
kwargs.get('attention_mask') is None
), f"Attention mask must not be set when using CUDA graphs for decode"
return self.cudagraph_manager(self, args, kwargs)
elif (
self.config.external_cuda_graph
and self.training
and (is_graph_capturing() or self.cuda_graphs)
):
if not self.cuda_graphs:
# Do CUDA Graphs capture.
cuda_graph_func = self._cuda_graph_capture
else:
# Do CUDA Graphs replay.
cuda_graph_func = self._cuda_graph_replay
return cuda_graph_func(*args, **kwargs)
return super(MegatronModule, self).__call__(*args, **kwargs)

  • Its __init__ builds the following modules (see the condensed forward sketch after this list for how they compose):

    1. Input Layernorm: optional layer normalization of the input

    2. SelfAttention

    3. BiasDropoutFusion

    4. Post SelfAttention: optional layer normalization after self-attention

    5. CrossAttention

    6. BiasDropoutFusion

    7. Pre MLP: optional layer normalization before the MLP

    8. MLP block

    9. BiasDropoutFusion
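
As a reading aid, the following pseudocode condenses _forward_attention and _forward_mlp above into the layernorm → attention/MLP → bias-dropout-add residual pattern that these nine modules form. It is a simplified sketch only: recompute, CUDA graphs, inference branches and MLP chunking are all omitted.

```python
def transformer_layer_forward(layer, hidden_states, attention_mask, context=None):
    """Simplified flow of TransformerLayer.forward (training path, no recompute/CUDA graphs)."""
    # 1-3: input layernorm -> self-attention -> bias+dropout+add onto the residual
    residual = hidden_states
    x = layer.input_layernorm(hidden_states)
    attn_out_with_bias = layer.self_attention(x, attention_mask=attention_mask)
    hidden_states = layer.self_attn_bda(layer.training, layer.config.bias_dropout_fusion)(
        attn_out_with_bias, residual, layer.hidden_dropout
    )

    # 4-6: optional cross-attention block (IdentityOp modules in a plain GPT decoder)
    residual = hidden_states
    x = layer.pre_cross_attn_layernorm(hidden_states)
    attn_out_with_bias = layer.cross_attention(x, attention_mask=None, key_value_states=context)
    hidden_states = layer.cross_attn_bda(layer.training, layer.config.bias_dropout_fusion)(
        attn_out_with_bias, residual, layer.hidden_dropout
    )

    # 7-9: pre-MLP layernorm -> MLP -> bias+dropout+add onto the residual
    residual = hidden_states
    x = layer.pre_mlp_layernorm(hidden_states)
    mlp_out_with_bias = layer.mlp(x)
    hidden_states = layer.mlp_bda(layer.training, layer.config.bias_dropout_fusion)(
        mlp_out_with_bias, residual, layer.hidden_dropout
    )
    return hidden_states
```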

Next, let us look at how the SelfAttention module is designed, paying particular attention to the parts related to TP parallelism.

SelfAttention

The code for SelfAttention is shown below:

class SelfAttention(Attention):
"""Self-attention layer class

Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""

def __init__(
self,
config: TransformerConfig,
submodules: SelfAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
cp_comm_type: str = None,
model_comm_pgs: ModelCommProcessGroups = None,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
attention_type="self",
cp_comm_type=cp_comm_type,
model_comm_pgs=model_comm_pgs,
)

self.linear_qkv = build_module(
submodules.linear_qkv,
self.config.hidden_size,
self.query_projection_size + 2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name='qkv',
tp_group=self.model_comm_pgs.tp,
)

if submodules.q_layernorm is not None:
self.q_layernorm = build_module(
submodules.q_layernorm,
hidden_size=self.hidden_size_per_attention_head,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.q_layernorm = None

if submodules.k_layernorm is not None:
self.k_layernorm = build_module(
submodules.k_layernorm,
hidden_size=self.hidden_size_per_attention_head,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.k_layernorm = None

def run_realtime_tests(self):
"""Performs a consistency check.

This function makes sure that tensors across devices are the same during an experiment.
This is often not guaranteed to be so because of silent hardware failures (eg, memory
corruption loading a checkpoint, network traffic corruption encountered during
data transmission).

(TODO) In the future, more tensors should be checked across the training run and
checked every X iterations. This is left for future work. Equality of tensors is probably
not required; transmitting hashes is sufficient."""

if not self.config.qk_layernorm:
return

# check that all tensor parallel and data parallel ranks have the same
# Q & K layernorm parameters.
rank = get_data_parallel_rank()
inputs = torch.stack(
[
self.q_layernorm.weight.data,
self.q_layernorm.bias.data,
self.k_layernorm.weight.data,
self.k_layernorm.bias.data,
]
)
dp_list = [torch.empty_like(inputs) for _ in range(get_data_parallel_world_size())]
dp_list[rank] = inputs
torch.distributed.all_gather(dp_list, inputs, group=get_data_parallel_group())

def _compare(srcs, tgts, names, parallelism):
assert len(srcs) == len(tgts) == len(names)
for src, tgt, name in zip(srcs, tgts, names):
assert torch.all(src == tgt), (
f"Discrepancy between {name} in {parallelism} ranks {i} and {rank}. "
f"Diff: {torch.norm(src - tgt)}"
)

for i, dp in enumerate(dp_list):
q_w, q_b, k_w, k_b = torch.unbind(dp)
_compare(
[q_w, q_b, k_w, k_b],
[
self.q_layernorm.weight.data,
self.q_layernorm.bias.data,
self.k_layernorm.weight.data,
self.k_layernorm.bias.data,
],
["q_w", "q_b", "k_w", "k_b"],
"DP",
)

rank = get_tensor_model_parallel_rank()
tp_list = [torch.empty_like(inputs) for _ in range(get_tensor_model_parallel_world_size())]
tp_list[rank] = inputs
torch.distributed.all_gather(tp_list, inputs, group=get_tensor_model_parallel_group())

for i, tp in enumerate(tp_list):
q_w, q_b, k_w, k_b = torch.unbind(tp)
_compare(
[q_w, q_b, k_w, k_b],
[
self.q_layernorm.weight.data,
self.q_layernorm.bias.data,
self.k_layernorm.weight.data,
self.k_layernorm.bias.data,
],
["q_w", "q_b", "k_w", "k_b"],
"TP",
)

def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
"""
Derives `query`, `key` and `value` tensors from `hidden_states`.
"""
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
mixed_qkv, _ = self.linear_qkv(hidden_states)

# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
new_tensor_shape = mixed_qkv.size()[:-1] + (
self.num_query_groups_per_partition,
(
(self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
* self.hidden_size_per_attention_head
),
)
mixed_qkv = mixed_qkv.view(*new_tensor_shape)

split_arg_list = [
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
]

if SplitAlongDim is not None:

# [sq, b, ng, (np/ng + 2) * hn]
# --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
else:

# [sq, b, ng, (np/ng + 2) * hn]
# --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)

# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)

if self.q_layernorm is not None:
query = self.q_layernorm(query)

if self.k_layernorm is not None:
key = self.k_layernorm(key)

if self.config.test_mode:
self.run_realtime_tests()

return query, key, value

def backward_dw(self) -> NoReturn:
"""Execute weight update operations"""
self._backward_qkv_proj()
self._backward_output_proj()

def _backward_qkv_proj(self):
"""Update weights for QKV projection layer"""
self.linear_qkv.backward_dw()

def _backward_output_proj(self):
"""Update weights for output projection layer"""
self.linear_proj.backward_dw()

def set_for_recompute_input_layernorm(self):
"""Set the attention layer for recompute input_layernorm. Only needed for fp8."""
from megatron.core.extensions.transformer_engine import set_save_original_input

set_save_original_input(self.linear_qkv)

  • At initialization:

    • Since it extends the Attention class, it first initializes Attention; the Attention initialization code is shown below:

      class Attention(MegatronModule, ABC):
      """Attention layer abstract class.

      This layer only contains common modules required for the "self attn" and
      "cross attn" specializations.
      """

      def __init__(
      self,
      config: TransformerConfig,
      submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
      layer_number: int,
      attn_mask_type: AttnMaskType,
      attention_type: str,
      cp_comm_type: str = None,
      model_comm_pgs: ModelCommProcessGroups = None,
      ):
      super().__init__(config=config)

      self.config = config
      self.layer_number = layer_number
      self.attn_mask_type = attn_mask_type
      self.attention_type = attention_type

      # For normal attention without groups, num_query_groups == num_attention_heads,
      # so these two will be the same
      self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
      self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups

      if model_comm_pgs is None:
      model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups(
      required_pgs=['tp', 'cp']
      )
      else:
      assert hasattr(
      model_comm_pgs, 'tp'
      ), "Attention model_comm_pgs must have tp process group"
      assert hasattr(
      model_comm_pgs, 'cp'
      ), "Attention model_comm_pgs must have cp process group"
      self.model_comm_pgs = model_comm_pgs

      # Per attention head and per partition values
      world_size = get_pg_size(self.model_comm_pgs.tp)
      self.hidden_size_per_attention_head = divide(
      self.query_projection_size, self.config.num_attention_heads
      )
      self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
      self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)

      # To support both CUDA Graphs and key value with different hidden size
      self.key_hidden_size = self.hidden_size_per_attention_head
      self.val_hidden_size = self.hidden_size_per_attention_head

      self.core_attention = build_module(
      submodules.core_attention,
      config=self.config,
      layer_number=self.layer_number,
      attn_mask_type=self.attn_mask_type,
      attention_type=self.attention_type,
      cp_comm_type=cp_comm_type,
      softmax_scale=self.config.softmax_scale,
      model_comm_pgs=self.model_comm_pgs,
      )

      self.checkpoint_core_attention = (
      self.config.recompute_granularity == 'selective'
      and "core_attn" in self.config.recompute_modules
      )

      # Output.
      self.linear_proj = build_module(
      submodules.linear_proj,
      self.query_projection_size,
      self.config.hidden_size,
      config=self.config,
      init_method=self.config.output_layer_init_method,
      bias=self.config.add_bias_linear,
      input_is_parallel=True,
      skip_bias_add=True,
      is_expert=False,
      tp_comm_buffer_name='proj',
      tp_group=self.model_comm_pgs.tp,
      )

      if (
      HAVE_TE
      and self.config.fp8
      and self.config.fp8_recipe != 'delayed'
      and is_te_min_version("2.6.0dev0")
      and isinstance(self.linear_proj, TELinear)
      ):
      # For fp8 training, the output of the fused core_attn is saved by itself, and
      # linear_proj also saves the quantized tensor of this output. Here we set the
      # linear_proj to save the original input tensors to avoid the extra memory usage of
      # the quantized tensor.
      set_save_original_input(self.linear_proj)

      • When computing the output dimensions of q, k and v, it computes the q dimension separately (self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads) from the k/v dimension (self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups), because under GQA/MQA self.config.num_attention_heads and self.config.num_query_groups may differ.

      • It then divides self.config.num_attention_heads (for q) and self.config.num_query_groups (for k/v) by the TP world size. Note that divide raises an error if either is not evenly divisible, so any configuration that actually runs necessarily gives every TP rank an even share of q, k and v heads (a worked shape example is given at the end of this section).

      • It then builds the core attention; in the local (non-Transformer-Engine) mode this is DotProductAttention, shown below, whose main job in forward is to compute the attention result from the given q, k, v, attention_mask, etc.

      class DotProductAttention(MegatronModule):
      """
      Region where selective activation recomputation is applied.
      This region is memory intensive but less compute intensive which
      makes activation checkpointing more efficient for LLMs (20B+).
      See Reducing Activation Recomputation in Large Transformer Models:
      https://arxiv.org/abs/2205.05198 for more details.

      We use the following notation:
      h: hidden size
      n: number of attention heads
      p: number of tensor model parallel partitions
      b: batch size
      s: sequence length
      """

      def __init__(
      self,
      config: TransformerConfig,
      layer_number: int,
      attn_mask_type: AttnMaskType,
      attention_type: str,
      attention_dropout: float = None,
      softmax_scale: float = None,
      cp_comm_type: str = None,
      model_comm_pgs: ModelCommProcessGroups = None,
      ):
      super().__init__(config=config)

      self.config: TransformerConfig = config

      assert (
      self.config.context_parallel_size == 1
      ), "Context parallelism is only supported by TEDotProductAttention!"

      assert (
      self.config.window_size is None
      ), "Sliding Window Attention is only supported by TEDotProductAttention!"

      self.layer_number = max(1, layer_number)
      self.attn_mask_type = attn_mask_type
      self.attention_type = attention_type # unused for now

      projection_size = self.config.kv_channels * self.config.num_attention_heads

      # Per attention head and per partition values.
      if model_comm_pgs is None:
      # For backward compatibility, remove in v0.14 and raise error
      # raise ValueError("DotProductAttention was called without ModelCommProcessGroups")
      model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups(required_pgs=['tp'])
      else:
      assert hasattr(
      model_comm_pgs, 'tp'
      ), "DotProductAttention model_comm_pgs must have tp process group"

      world_size = model_comm_pgs.tp.size()
      self.hidden_size_per_partition = divide(projection_size, world_size)
      self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
      self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
      self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)

      coeff = None
      if softmax_scale is None:
      self.softmax_scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head)
      else:
      self.softmax_scale = softmax_scale

      if self.config.apply_query_key_layer_scaling:
      coeff = self.layer_number
      self.softmax_scale /= coeff

      self.scale_mask_softmax = FusedScaleMaskSoftmax(
      input_in_fp16=self.config.fp16,
      input_in_bf16=self.config.bf16,
      attn_mask_type=self.attn_mask_type,
      scaled_masked_softmax_fusion=self.config.masked_softmax_fusion,
      mask_func=attention_mask_func,
      softmax_in_fp32=self.config.attention_softmax_in_fp32,
      scale=coeff,
      )

      # Dropout. Note that for a single iteration, this layer will generate
      # different outputs on different number of parallel partitions but
      # on average it should not be partition dependent.
      self.attention_dropout = torch.nn.Dropout(
      self.config.attention_dropout if attention_dropout is None else attention_dropout
      )

      def forward(
      self,
      query: Tensor,
      key: Tensor,
      value: Tensor,
      attention_mask: Tensor,
      attn_mask_type: AttnMaskType = None,
      attention_bias: Tensor = None,
      packed_seq_params: Optional[PackedSeqParams] = None,
      ):
      """Forward."""
      assert packed_seq_params is None, (
      "Packed sequence is not supported by DotProductAttention."
      "Please use TEDotProductAttention instead."
      )
      assert attention_bias is None, "Attention bias is not supported for DotProductAttention."

      # ===================================
      # Raw attention scores. [b, n/p, s, s]
      # ===================================

      # expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn]
      # This is a noop for normal attention where ng == np. When using group query attention this
      # creates a view that has the keys and values virtually repeated along their dimension to
      # match the number of queries.

      # attn_mask_type is not used.
      if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
      key = key.repeat_interleave(
      self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
      )
      value = value.repeat_interleave(
      self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
      )

      # [b, np, sq, sk]
      output_size = (query.size(1), query.size(2), query.size(0), key.size(0))

      # [sq, b, np, hn] -> [sq, b * np, hn]
      # This will be a simple view when doing normal attention, but in group query attention
      # the key and value tensors are repeated to match the queries so you can't use
      # simple strides to extract the queries.
      query = query.reshape(output_size[2], output_size[0] * output_size[1], -1)
      # [sk, b, np, hn] -> [sk, b * np, hn]
      key = key.view(output_size[3], output_size[0] * output_size[1], -1)

      # preallocting input tensor: [b * np, sq, sk]
      matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
      (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu"
      )

      # Raw attention scores. [b * np, sq, sk]
      matmul_result = torch.baddbmm(
      matmul_input_buffer,
      query.transpose(0, 1), # [b * np, sq, hn]
      key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
      beta=0.0,
      alpha=self.softmax_scale,
      )

      # change view to [b, np, sq, sk]
      attention_scores = matmul_result.view(*output_size)

      # ===========================
      # Attention probs and dropout
      # ===========================

      # attention scores and attention mask [b, np, sq, sk]
      attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)

      # This is actually dropping out entire tokens to attend to, which might
      # seem a bit unusual, but is taken from the original Transformer paper.

      if not self.config.sequence_parallel:
      with tensor_parallel.get_cuda_rng_tracker().fork():
      attention_probs = self.attention_dropout(attention_probs)
      else:
      attention_probs = self.attention_dropout(attention_probs)

      # =========================
      # Context layer. [sq, b, hp]
      # =========================

      # value -> context layer.
      # [sk, b, np, hn] --> [b, np, sq, hn]

      # context layer shape: [b, np, sq, hn]
      output_size = (value.size(1), value.size(2), query.size(0), value.size(3))

      # change view [sk, b * np, hn]
      value = value.view(value.size(0), output_size[0] * output_size[1], -1)

      # change view [b * np, sq, sk]
      attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

      # matmul: [b * np, sq, hn]
      context = torch.bmm(attention_probs, value.transpose(0, 1))

      # change view [b, np, sq, hn]
      context = context.view(*output_size)

      # [b, np, sq, hn] --> [sq, b, np, hn]
      context = context.permute(2, 0, 1, 3).contiguous()

      # [sq, b, np, hn] --> [sq, b, hp]
      new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,)
      context = context.view(*new_context_shape)

      return context

      • It then builds linear_proj. Note that this is a row_parallel_linear and that its arguments explicitly state the input is already parallel (input_is_parallel=True), consistent with the usual column-parallel-then-row-parallel pattern.

    • It creates linear_qkv:

      • linear_qkv is a column_parallel_linear.

      • The input dimension of linear_qkv is the standard self.config.hidden_size, and its output dimension is self.query_projection_size + 2 * self.kv_projection_size, because linear_qkv has to project the hidden states into the three base tensors q, k and v.

      • It is also worth noting that gather_output is deliberately set to False, since the whole point is to use column parallelism to compute the attention heads in parallel.

    • It also builds submodules.q_layernorm and submodules.k_layernorm.

  • The forward pass goes entirely through Attention's forward, quoted in full after this list. Following the nvtx_range_push markers, the flow can be divided into:

    1. Compute q, k and v for the current sequence:

      • The code is as follows:
      def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
      """
      Derives `query`, `key` and `value` tensors from `hidden_states`.
      """
      # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
      mixed_qkv, _ = self.linear_qkv(hidden_states)

      # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
      new_tensor_shape = mixed_qkv.size()[:-1] + (
      self.num_query_groups_per_partition,
      (
      (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
      * self.hidden_size_per_attention_head
      ),
      )
      mixed_qkv = mixed_qkv.view(*new_tensor_shape)

      split_arg_list = [
      (
      self.num_attention_heads_per_partition
      // self.num_query_groups_per_partition
      * self.hidden_size_per_attention_head
      ),
      self.hidden_size_per_attention_head,
      self.hidden_size_per_attention_head,
      ]

      if SplitAlongDim is not None:

      # [sq, b, ng, (np/ng + 2) * hn]
      # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
      (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
      else:

      # [sq, b, ng, (np/ng + 2) * hn]
      # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
      (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)

      # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
      query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)

      if self.q_layernorm is not None:
      query = self.q_layernorm(query)

      if self.k_layernorm is not None:
      key = self.k_layernorm(key)

      if self.config.test_mode:
      self.run_realtime_tests()

      return query, key, value

      • First, mixed_qkv, _ = self.linear_qkv(hidden_states) produces mixed_qkv. Because self.linear_qkv is column-parallel and was built with gather_output=False, mixed_qkv is this rank's TP-partitioned partial result; thanks to the earlier divisibility checks, its last dimension decomposes exactly into the per-rank q, k and v dimensions. The shape therefore goes from [sq, b, h] to [sq, b, per_tp_num_query_groups * (per_tp_num_heads / per_tp_num_query_groups + 2) * head_dim], where the last dimension is the sum of the q, k and v dimensions.

      • The tensor is then reshaped and split, so that q ends up with shape [sq, b, per_tp_num_heads, head_dim] while k and v both end up with shape [sq, b, per_tp_num_query_groups, head_dim] (see the worked shape example after this list).

      Why introduce the num_query_groups dimension? Because it is the key to supporting GQA/MQA:

      • Plain attention: num_query_groups == num_heads, each group has exactly 1 query head, and the correspondence is direct.

      • GQA: num_query_groups < num_heads, so several query heads share the same set of K/V (within the same group).

    2. Adjust key and value (the adjust_key_value range, mainly KV-cache handling for inference).

    3. Apply the rotary positional embedding (rotary_pos_emb).

    4. Call core_attention to compute the attention output.

    5. Call linear_proj to obtain the final result; since it is a row_parallel_linear, the complete result is produced at the end by an all-reduce (or a reduce-scatter when sequence parallelism is enabled).
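
To make the per-rank shape bookkeeping above concrete, here is a small self-contained sketch. The configuration numbers (32 heads, 8 query groups, 128-dim heads, TP=4) are illustrative assumptions for this example, not Megatron defaults:

```python
import torch

# Illustrative configuration (assumed for this example, not Megatron defaults)
num_attention_heads = 32
num_query_groups = 8      # GQA: 4 query heads share each K/V group
kv_channels = 128         # per-head dimension (hn)
tp_size = 4
hidden_size = num_attention_heads * kv_channels  # input dimension of linear_qkv

# Attention.__init__ arithmetic
query_projection_size = kv_channels * num_attention_heads   # 4096
kv_projection_size = kv_channels * num_query_groups         # 1024
heads_per_rank = num_attention_heads // tp_size              # 8  (np per partition)
groups_per_rank = num_query_groups // tp_size                # 2  (ng per partition)
head_dim = query_projection_size // num_attention_heads      # 128

# linear_qkv is column-parallel with output q + 2*kv, so each rank holds 1/tp of it
qkv_out_per_rank = (query_projection_size + 2 * kv_projection_size) // tp_size  # 1536

sq, b = 16, 2
mixed_qkv = torch.randn(sq, b, qkv_out_per_rank)  # stand-in for linear_qkv's local output

# get_query_key_value_tensors reshaping: group dim first, then (np/ng + 2) * hn
mixed_qkv = mixed_qkv.view(
    sq, b, groups_per_rank, (heads_per_rank // groups_per_rank + 2) * head_dim
)
split_sizes = [heads_per_rank // groups_per_rank * head_dim, head_dim, head_dim]
query, key, value = torch.split(mixed_qkv, split_sizes, dim=3)
query = query.reshape(sq, b, -1, head_dim)

print(query.shape)  # torch.Size([16, 2, 8, 128]) = [sq, b, heads_per_rank, hn]
print(key.shape)    # torch.Size([16, 2, 2, 128]) = [sq, b, groups_per_rank, hn]
print(value.shape)  # torch.Size([16, 2, 2, 128])
```

Each TP rank thus owns 8 of the 32 query heads and 2 of the 8 K/V groups, and the per-rank q/k/v widths add up exactly to the per-rank output of the column-parallel linear_qkv. The full Attention.forward that drives the five steps above is quoted next.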

def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Optional[Tensor] = None,
inference_context: Optional[BaseInferenceContext] = None,
rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
) -> Tuple[Tensor, Tensor]:
"""
Perform a forward pass through the attention module.

Args:
hidden_states (Tensor): Hidden states.
attention_mask (Tensor): Attention mask.
key_value_states (Optional[Tensor]): Key/value states (for cross attention).
inference_context (Optional[BaseInferenceContext]): Inference context that manages
KV cache.
rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary
embedding tensor(s).
rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine.
rotary_pos_sin (Optional[Tensor]): Rotary embedding sine.
attention_bias (Optional[Tensor]): Attention bias.
packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format.
sequence_len_offset (Optional[int]): Sequence length offset used for
inference CUDA graphs.

Return:
(Tuple[Tensor, Tensor]) Attention output and bias.

"""
# Check if we need to skip RoPE
# no_rope is 0-indexed array and self.layer_number is 1-indexed
no_rope = (
self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False
)
if no_rope:
rotary_pos_emb = None

inference_context = deprecate_inference_params(inference_context, inference_params)

if inference_context and inference_context.is_dynamic_batching():
assert HAVE_FA3 or is_fa_min_version(
"2.7.3"
), "flash attn verion v2.7.3 and above is required for dynamic batching."

# hidden_states: [sq, b, h]
if self.config.flash_decode and not self.training and inference_context is not None:
rotary_pos_emb = None
else:
assert rotary_pos_cos is None and rotary_pos_sin is None

# For self attention we just duplicate the rotary_pos_emb if it isn't already
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2

# =====================
# Query, Key, and Value
# =====================
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
nvtx_range_push(suffix="qkv")
query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
nvtx_range_pop(suffix="qkv")

# ===================================================
# Adjust key, value, and rotary_pos_emb for inference
# ===================================================

in_decode_mode = (
inference_context is not None
and inference_context.is_decode_only()
and not self.training
)

# This branch only runs in the decode phase of flash decoding and returns after the linear
# projection. This conditional is not used in the prefill phase or non-flash-decoding cases.
nvtx_range_push(suffix="adjust_key_value")
if in_decode_mode and self.config.flash_decode:
assert self.layer_number in inference_context.key_value_memory_dict
assert inference_context.sequence_len_offset is not None
inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[
self.layer_number
]
output = self.flash_decode(
sequence_len_offset=sequence_len_offset,
query_layer=query,
key_layer=key,
value_layer=value,
inference_key_memory=inference_key_memory,
inference_value_memory=inference_value_memory,
rotary_cos=rotary_pos_cos,
rotary_sin=rotary_pos_sin,
rotary_interleaved=self.config.rotary_interleaved,
)
out = output.transpose(0, 1).contiguous()
context_layer = out.view(out.size(0), out.size(1), -1)
output, bias = self.linear_proj(context_layer)
return output, bias

if (
in_decode_mode
and self.config.enable_cuda_graph
and self.config.cuda_graph_scope != "full_iteration"
and inference_context.is_static_batching()
):
raise ValueError(f"CUDA graphs must use flash decode with static batching!")

query, key, value, rotary_pos_emb, attn_mask_type, block_table = (
self._adjust_key_value_for_inference(
inference_context,
query,
key,
value,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
sequence_len_offset,
)
)

if packed_seq_params is not None:
query = query.squeeze(1)
key = key.squeeze(1)
value = value.squeeze(1)
nvtx_range_pop(suffix="adjust_key_value")

# ================================================
# relative positional embedding (rotary embedding)
# ================================================
nvtx_range_push(suffix="rotary_pos_emb")
if rotary_pos_emb is not None and not self.config.flash_decode:
q_pos_emb, k_pos_emb = rotary_pos_emb

if packed_seq_params is not None:
if packed_seq_params.cu_seqlens_q_padded is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded
else:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
if packed_seq_params.cu_seqlens_kv_padded is not None:
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded
else:
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
else:
cu_seqlens_q = cu_seqlens_kv = None

if q_pos_emb is not None:
# TODO VIJAY: simplify
if inference_context is None or inference_context.is_static_batching():
query = apply_rotary_pos_emb(
query,
q_pos_emb,
config=self.config,
cu_seqlens=cu_seqlens_q,
cp_group=self.model_comm_pgs.cp,
)
else:
query = inference_context.apply_rotary_emb_query(
query, q_pos_emb, self.config, cu_seqlens_q, self.model_comm_pgs.cp
)
if k_pos_emb is not None:
key = apply_rotary_pos_emb(
key,
k_pos_emb,
config=self.config,
cu_seqlens=cu_seqlens_kv,
cp_group=self.model_comm_pgs.cp,
)

# TODO, can apply positional embedding to value_layer so it has
# absolute positional embedding.
# otherwise, only relative positional embedding takes effect
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
nvtx_range_pop(suffix="rotary_pos_emb")

# ==================================
# core attention computation
# ==================================

nvtx_range_push(suffix="core_attention")
if self.checkpoint_core_attention and self.training:
core_attn_out = self._checkpointed_attention_forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
)
else:
if inference_context is None or inference_context.is_static_batching():
# Static batching attention kernel.
core_attn_out = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
)

else:
# Dynamic batching attention kernel.
q, k, v = (query, key, value)
cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths()
cu_kv_lengths, kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths()

core_attn_out = self.flash_decode_and_prefill(
q,
k,
v,
max_seqlen_q,
max_seqlen_k,
cu_query_lengths,
cu_kv_lengths,
kv_lengths,
block_table,
)
core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)')

if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
# reshape to same output shape as unpacked case
# (t, np, hn) -> (t, b=1, h=np*hn)
# t is the pack size = sum (sq_i)
# note that batch is a dummy dimension in the packed case
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
nvtx_range_pop(suffix="core_attention")

# =================
# Output. [sq, b, h]
# =================

nvtx_range_push(suffix="linear_proj")
output, bias = self.linear_proj(core_attn_out)
nvtx_range_pop(suffix="linear_proj")

return output, bias

Embedding

When building GPTModel, the embedding layer is initialized with LanguageModelEmbedding, as shown below:

if self.pre_process or self.mtp_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
tp_group=self.model_comm_pgs.tp,
)

LanguageModelEmbedding also involves TP partitioning: the vocabulary embedding table may be too large to fit on a single GPU, so it is split along the vocabulary dimension. Each GPU keeps only its slice of the table, during the forward pass each GPU only looks up the tokens that fall inside its own vocabulary range, and an all-reduce then assembles the complete embeddings.

LanguageModelEmbedding

The code of LanguageModelEmbedding is shown below:

class LanguageModelEmbedding(MegatronModule):
"""Language model embeddings.

Args:
config (TransformerConfig): config object with all necessary configs for TransformerBlock
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This
is used for positional embedding
add_position_embedding (bool): Add a position embedding.
embedding_dropout_prob (float): dropout probability for embeddings
num_tokentypes (int): Set to 0 without binary head, and 2 with a binary head. Defaults to 0.
scatter_to_sequence_parallel (bool): Set to False to disable scatter of embedding
across sequence parallel region. Defaults to True.
"""

def __init__(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
num_tokentypes: int = 0,
scatter_to_sequence_parallel: bool = True,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
super().__init__(config=config)

self.config: TransformerConfig = config
self.vocab_size: int = vocab_size
self.max_sequence_length: int = max_sequence_length
self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
self.num_tokentypes = num_tokentypes
self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
self.tp_group = get_tensor_model_parallel_group_if_none(tp_group)
self.reduce_scatter_embeddings = (
(not self.add_position_embedding)
and self.num_tokentypes <= 0
and self.config.sequence_parallel
and self.scatter_to_sequence_parallel
)

# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
num_embeddings=self.vocab_size,
embedding_dim=self.config.hidden_size,
init_method=self.config.embedding_init_method,
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
config=self.config,
tp_group=self.tp_group,
)

# Position embedding (serial).
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
self.max_sequence_length, self.config.hidden_size
)

# Initialize the position embeddings.
if self.config.perform_initialization:
self.config.embedding_init_method(self.position_embeddings.weight)

if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(
self.num_tokentypes, self.config.hidden_size
)
# Initialize the token-type embeddings.
if self.config.perform_initialization:
self.config.embedding_init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None

# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)

def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True

@nvtx_decorator()
def forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: int = None) -> Tensor:
"""Forward pass of the embedding module.

Args:
input_ids (Tensor): The input tokens
position_ids (Tensor): The position id's used to calculate position embeddings
tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
set to True. Defaults to None

Returns:
Tensor: The output embeddings
"""
word_embeddings = self.word_embeddings(input_ids)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
else:
embeddings = word_embeddings

if not self.reduce_scatter_embeddings:
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()

if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
# [b s h] -> [s b h] (So that it can be added with embeddings)
tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
embeddings = embeddings + tokentype_embedding
else:
assert self.tokentype_embeddings is None

# If the input flag for fp32 residual connection is set, convert for float.
if self.config.fp32_residual_connection:
embeddings = embeddings.float()

# Dropout.
if self.config.sequence_parallel:
if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
embeddings, group=self.tp_group
)
# `scatter_to_sequence_parallel_region` returns a view, which prevents
# the original tensor from being garbage collected. Clone to facilitate GC.
# Has a small runtime cost (~0.5%).
if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
embeddings = embeddings.clone()
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)

return embeddings

  • During initialization, it creates the word embeddings with tensor_parallel.VocabParallelEmbedding.

    • The code of VocabParallelEmbedding is shown below:
    class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.

    Args:
    num_embeddings: vocabulary size.
    embedding_dim: size of hidden state.
    reduce_scatter_embeddings: Decides whether to perform ReduceScatter after embedding lookup

    Keyword Args:
    config: A megatron.core.ModelParallelConfig object
    """

    def __init__(
    self,
    num_embeddings: int,
    embedding_dim: int,
    *,
    init_method: Callable,
    reduce_scatter_embeddings: bool = False,
    config: ModelParallelConfig,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
    super(VocabParallelEmbedding, self).__init__()
    # Keep the input dimensions.
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    self.reduce_scatter_embeddings = reduce_scatter_embeddings
    self.tp_group = tp_group

    self.tp_group = get_tensor_model_parallel_group_if_none(self.tp_group)

    (self.vocab_start_index, self.vocab_end_index) = (
    VocabUtility.vocab_range_from_global_vocab_size(
    self.num_embeddings, get_pg_rank(self.tp_group), get_pg_size(self.tp_group)
    )
    )
    self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
    self.deterministic_mode = config.deterministic_mode

    # Allocate weights and initialize.
    if config.use_cpu_initialization:
    self.weight = Parameter(
    torch.empty(
    self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
    )
    )
    if config.perform_initialization:
    _initialize_affine_weight_cpu(
    self.weight,
    self.num_embeddings,
    self.embedding_dim,
    self.num_embeddings_per_partition,
    0,
    init_method,
    params_dtype=config.params_dtype,
    rank=get_pg_rank(self.tp_group),
    world_size=get_pg_size(self.tp_group),
    )
    else:
    self.weight = Parameter(
    torch.empty(
    self.num_embeddings_per_partition,
    self.embedding_dim,
    device=torch.cuda.current_device(),
    dtype=config.params_dtype,
    )
    )
    if config.perform_initialization:
    _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)

    def forward(self, input_):
    """Forward.

    Args:
    input_ (torch.Tensor): Input tensor.
    """
    if self.tp_group.size() > 1:
    # Build the mask.
    input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
    # Mask the input.
    masked_input = input_.clone() - self.vocab_start_index
    masked_input[input_mask] = 0
    else:
    masked_input = input_
    # Get the embeddings.
    if self.deterministic_mode:
    output_parallel = self.weight[masked_input]
    else:
    # F.embedding currently has a non-deterministic backward function
    output_parallel = F.embedding(masked_input, self.weight)
    # Mask the output embedding.
    if self.tp_group.size() > 1:
    output_parallel[input_mask, :] = 0.0

    if self.reduce_scatter_embeddings:
    # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
    output_parallel = output_parallel.transpose(0, 1).contiguous()
    output = reduce_scatter_to_sequence_parallel_region(
    output_parallel, group=self.tp_group
    )
    else:
    # Reduce across all the model parallel GPUs.
    output = reduce_from_tensor_model_parallel_region(output_parallel, group=self.tp_group)
    return output

    def sharded_state_dict(
    self,
    prefix: str = "",
    sharded_offsets: Tuple[Tuple[int, int, int]] = (),
    metadata: Optional[dict] = None,
    ) -> ShardedStateDict:
    """Non-default implementation for embeddings due to `allow_shape_mismatch` param"""
    state_dict = self.state_dict(prefix="", keep_vars=True)

    weight_prefix = f"{prefix}weight"
    return {
    weight_prefix: make_tp_sharded_tensor_for_checkpoint(
    tensor=state_dict["weight"],
    key=weight_prefix,
    allow_shape_mismatch=True,
    prepend_offsets=sharded_offsets,
    )
    }

    • During initialization, VocabParallelEmbedding first partitions the vocabulary across the TP group, obtaining the start index self.vocab_start_index and end index self.vocab_end_index of the slice owned by the current rank.

    • In VocabParallelEmbedding's forward pass (see the sketch after this list):

      1. It first builds input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index), then shifts the ids with masked_input = input_.clone() - self.vocab_start_index, and finally zeroes out the ids that fall outside this rank's range via masked_input[input_mask] = 0.

      2. masked_input now holds the shifted (local) token ids; they are used to index into weight to obtain output_parallel, and the rows of tokens that do not belong to this rank are zeroed out.

      3. Two output strategies are then chosen between, based on reduce_scatter_embeddings; note that

        reduce_scatter_embeddings = ((not self.add_position_embedding) and self.num_tokentypes <= 0 and self.config.sequence_parallel and self.scatter_to_sequence_parallel)

        • reduce_scatter_embeddings=True (used together with sequence parallelism)

          • The layout is first changed from [b, s, h] to [s, b, h], because Megatron's sequence-parallel paths work primarily in [seq, batch, hidden] (which makes splitting and concatenating along the sequence dimension easier).

          • reduce_scatter_to_sequence_parallel_region is then called:

            • Semantically it is equivalent to first sum-reducing output_parallel over the TP group and then scattering the result to the ranks along the sequence dimension.

            • The benefit: it directly produces the sharded output that sequence parallelism needs, avoiding the extra overhead and memory peak of "all-reduce the full tensor, then slice it manually".

        • reduce_scatter_embeddings=False (the default, more straightforward path)

          • reduce_from_tensor_model_parallel_region

            • Semantically it simply all-reduces (sums) output_parallel over the TP group;

            • every TP rank then holds the complete embedding output (identical to the case where the vocabulary is not partitioned).
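
The mask, local lookup, zero-out and all-reduce steps described above can be simulated in a few lines. The following is a toy single-process sketch (not Megatron code) that emulates tp=2 vocabulary shards and uses a plain sum in place of the TP all-reduce; all sizes are illustrative assumptions.

# Toy simulation of VocabParallelEmbedding's mask -> local lookup -> zero -> all-reduce flow.
# No real process groups: the loop plays the role of the tp ranks, sum() plays the all-reduce.
import torch

vocab_size, hidden, tp = 8, 4, 2
full_weight = torch.randn(vocab_size, hidden)            # conceptual full embedding table
input_ids = torch.tensor([[0, 3, 5, 7]])                  # [b, s]

partials = []
for rank in range(tp):
    # vocab_range_from_global_vocab_size: contiguous, equal slices per rank
    start = rank * (vocab_size // tp)
    end = start + vocab_size // tp
    weight_shard = full_weight[start:end]                 # this rank's slice of the table

    input_mask = (input_ids < start) | (input_ids >= end)
    masked_input = input_ids.clone() - start              # shift global ids to local ids
    masked_input[input_mask] = 0                          # any valid local index will do
    out = weight_shard[masked_input]                      # local lookup
    out[input_mask, :] = 0.0                              # zero rows of foreign tokens
    partials.append(out)

# reduce_from_tensor_model_parallel_region == all-reduce(sum): every rank
# would end up with the same complete embeddings.
output = sum(partials)
assert torch.allclose(output, full_weight[input_ids])

With sequence parallelism enabled, the same partial results are instead combined with a reduce-scatter along the sequence dimension, so each rank directly receives its [s/tp, b, h] shard rather than the full [s, b, h] tensor.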

Tensor Parallelism Experiment

The experiment uses the GPT-3 857M model; the run script is shown below. Note that --transformer-impl in GPT_MODEL_ARGS is set to local, i.e. the native Megatron-LM gpt_layer implementation is used instead of transformer_engine, matching the code paths discussed above, and the TP size is set to 4.

#!/bin/bash

# Runs the "857m" parameter model

export CUDA_DEVICE_MAX_CONNECTIONS=1

GPUS_PER_NODE=4
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NUM_NODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))

CHECKPOINT_PATH=$1 #<Specify path>
TENSORBOARD_LOGS_PATH=$2 #<Specify path>
VOCAB_FILE=$3 #<Specify path to file>/gpt2-vocab.json
MERGE_FILE=$4 #<Specify path to file>/gpt2-merges.txt
DATA_PATH=$5 #<Specify path and file prefix>_text_document
USE_NSYS=0
if [[ ${6:-} == "--nsys" ]]; then
USE_NSYS=1
fi

DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NUM_NODES
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)

GPT_MODEL_ARGS=(
--num-layers 24
--hidden-size 1024
--num-attention-heads 16
--seq-length 2048
--max-position-embeddings 2048
--transformer-impl local
)

TRAINING_ARGS=(
--micro-batch-size 4
--global-batch-size 16
# --rampup-batch-size 16 16 5859375
--train-iters 20000
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.95
--init-method-std 0.006
--clip-grad 1.0
--fp16
--lr 6.0e-5
--lr-decay-style cosine
--min-lr 6.0e-6
--lr-warmup-fraction .001
--lr-decay-iters 20000
)

MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 4
--pipeline-model-parallel-size 1
)

DATA_ARGS=(
--data-path $DATA_PATH
--vocab-file $VOCAB_FILE
--merge-file $MERGE_FILE
--split 949,50,1
)

EVAL_AND_LOGGING_ARGS=(
--log-interval 200
--save-interval 10000
--eval-interval 1000
--save $CHECKPOINT_PATH
--load $CHECKPOINT_PATH
--eval-iters 10
--tensorboard-dir $TENSORBOARD_LOGS_PATH
)

PROFILER_ARGS=(
--profile
--use-pytorch-profiler
--profile-step-start 110
--profile-step-end 112
--profile-ranks 0
)

# Build command as an array (no string concatenation)
CMD=(
torchrun
"${DISTRIBUTED_ARGS[@]}"
pretrain_gpt.py
"${GPT_MODEL_ARGS[@]}"
"${TRAINING_ARGS[@]}"
"${MODEL_PARALLEL_ARGS[@]}"
"${DATA_ARGS[@]}"
"${EVAL_AND_LOGGING_ARGS[@]}"
"${PROFILER_ARGS[@]}"
)

if [[ "$USE_NSYS" -eq 1 ]]; then
NSIGHT_PREFIX="./nsight_profile/gpt3_857m"
echo "Running with Nsight profiling, output prefix: ${NSIGHT_PREFIX}"
exec nsys profile \
-s none -t nvtx,cuda \
--cudabacktrace=all \
--cuda-graph-trace=node \
--python-backtrace=cuda \
--wait all \
-o "${NSIGHT_PREFIX}" \
--force-overwrite true \
--capture-range=cudaProfilerApi \
--capture-range-end=stop \
"${CMD[@]}"
else
exec "${CMD[@]}"
fi

The command used to launch the run is:

bash examples/gpt3/train_gpt3_857m_distributed.sh     /workspace/megatron-lm/model_ckpt/gpt3_857m_tp4     /workspace/megatron-lm/tb_logs/gpt3_857m_profiler_tp4     /workspace/megatron-lm/data/tokenizer/gpt2-vocab.json     /workspace/megatron-lm/data/tokenizer/gpt2-merges.txt     /workspace/megatron-lm/data/TinyStoriesV2-GPT4-train_text_document      > gpt3_857m_tp4.log 2>&1 &
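
Before turning to the log, one number reported there can be checked by hand: each of the four TP ranks prints 90763264 parameters. The sketch below reproduces that figure for this configuration (padded vocab 50688, hidden 1024, ffn 4096, 24 layers, TP 4); the per-module breakdown is a back-of-the-envelope derivation assuming the local GPT layer layout discussed above, not output of the training script.

# Back-of-the-envelope per-TP-rank parameter count for the run below.
# Assumes the local GPT layer layout discussed above: LayerNorms with bias,
# biased column/row-parallel linears, tied input/output embeddings.
h, ffn, layers, tp = 1024, 4096, 24, 4
padded_vocab = 50688                        # reported by the tokenizer padding step

word_emb = padded_vocab // tp * h           # vocab-parallel embedding shard
pos_emb = 2048 * h                          # learned_absolute, replicated on every rank

per_layer = (
    2 * h                                   # input LayerNorm weight + bias
    + 3 * h * h // tp + 3 * h // tp         # column-parallel QKV weight + bias
    + h * h // tp + h                       # row-parallel proj weight + full bias
    + 2 * h                                 # pre-MLP LayerNorm weight + bias
    + h * ffn // tp + ffn // tp             # column-parallel fc1 weight + bias
    + ffn * h // tp + h                     # row-parallel fc2 weight + full bias
)

total = word_emb + pos_emb + layers * per_layer + 2 * h   # plus final LayerNorm
print(total)                                # 90763264, matching the log below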

The run log is shown below:

W0103 04:22:27.318000 1772056 torch/distributed/run.py:766] 
W0103 04:22:27.318000 1772056 torch/distributed/run.py:766] *****************************************
W0103 04:22:27.318000 1772056 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0103 04:22:27.318000 1772056 torch/distributed/run.py:766] *****************************************
using world size: 4, data-parallel size: 1, context-parallel size: 1, hierarchical context-parallel sizes: None, tensor-model-parallel size: 4, pipeline-model-parallel size: 1
Number of virtual stages per pipeline stage: None
WARNING: Setting args.check_for_nan_in_loss_and_grad to False since dynamic loss scaling is being used
using torch.float16 for parameters ...
------------------------ arguments ------------------------
account_for_embedding_in_pipeline_split ......... False
account_for_loss_in_pipeline_split .............. False
accumulate_allreduce_grads_in_fp32 .............. False
adam_beta1 ...................................... 0.9
adam_beta2 ...................................... 0.95
adam_eps ........................................ 1e-08
add_bias_linear ................................. True
add_position_embedding .......................... True
add_qkv_bias .................................... True
adlr_autoresume ................................. False
adlr_autoresume_interval ........................ 1000
align_grad_reduce ............................... True
align_param_gather .............................. False
app_tag_run_name ................................ None
app_tag_run_version ............................. 0.0.0
apply_layernorm_1p .............................. False
apply_query_key_layer_scaling ................... False
apply_residual_connection_post_layernorm ........ False
apply_rope_fusion ............................... False
async_save ...................................... None
async_tensor_model_parallel_allreduce ........... True
attention_backend ............................... AttnBackend.auto
attention_dropout ............................... 0.1
attention_softmax_in_fp32 ....................... False
auto_detect_ckpt_format ......................... False
barrier_with_L1_time ............................ True
bert_binary_head ................................ True
bert_embedder_type .............................. megatron
bert_load ....................................... None
bf16 ............................................ False
bias_dropout_fusion ............................. True
bias_gelu_fusion ................................ True
bias_swiglu_fusion .............................. True
biencoder_projection_dim ........................ 0
biencoder_shared_query_context_model ............ False
block_data_path ................................. None
cache_mla_latents ............................... False
calc_ft_timeouts ................................ False
calculate_per_token_loss ........................ False
check_for_large_grads ........................... False
check_for_nan_in_loss_and_grad .................. False
check_for_spiky_loss ............................ False
check_weight_hash_across_dp_replicas_interval ... None
ckpt_assume_constant_structure .................. False
ckpt_convert_format ............................. None
ckpt_convert_save ............................... None
ckpt_convert_update_legacy_dist_opt_format ...... False
ckpt_format ..................................... torch_dist
ckpt_fully_parallel_load ........................ False
ckpt_fully_parallel_save ........................ True
ckpt_fully_parallel_save_deprecated ............. False
ckpt_step ....................................... None
classes_fraction ................................ 1.0
clip_grad ....................................... 1.0
clone_scatter_output_in_embedding ............... True
config_logger_dir ...............................
consumed_train_samples .......................... 0
consumed_valid_samples .......................... 0
context_parallel_size ........................... 1
cp_comm_type .................................... ['p2p']
create_attention_mask_in_dataloader ............. True
cross_entropy_fusion_impl ....................... native
cross_entropy_loss_fusion ....................... False
cuda_graph_scope ................................ full
cuda_graph_warmup_steps ......................... 3
data_args_path .................................. None
data_cache_path ................................. None
data_parallel_random_init ....................... False
data_parallel_sharding_strategy ................. no_shard
data_parallel_size .............................. 1
data_path ....................................... ['/workspace/megatron-lm/data/TinyStoriesV2-GPT4-train_text_document']
data_per_class_fraction ......................... 1.0
data_sharding ................................... True
dataloader_type ................................. single
ddp_average_in_collective ....................... False
ddp_bucket_size ................................. None
ddp_num_buckets ................................. None
ddp_pad_buckets_for_high_nccl_busbw ............. False
decoder_first_pipeline_num_layers ............... None
decoder_last_pipeline_num_layers ................ None
decoder_num_layers .............................. None
decoder_seq_length .............................. None
decoupled_lr .................................... None
decoupled_min_lr ................................ None
decrease_batch_size_if_needed ................... False
defer_embedding_wgrad_compute ................... False
delay_wgrad_compute ............................. False
deprecated_use_mcore_models ..................... False
deterministic_mode .............................. False
dino_bottleneck_size ............................ 256
dino_freeze_last_layer .......................... 1
dino_head_hidden_size ........................... 2048
dino_local_crops_number ......................... 10
dino_local_img_size ............................. 96
dino_norm_last_layer ............................ False
dino_teacher_temp ............................... 0.07
dino_warmup_teacher_temp ........................ 0.04
dino_warmup_teacher_temp_epochs ................. 30
disable_bf16_reduced_precision_matmul ........... False
disable_mamba_mem_eff_path ...................... False
disable_straggler_on_startup .................... False
dist_ckpt_format_deprecated ..................... None
dist_ckpt_strictness ............................ assume_ok_unexpected
distribute_saved_activations .................... False
distributed_backend ............................. nccl
distributed_timeout_minutes ..................... 10
embedding_init_method_std ....................... None
embedding_path .................................. None
empty_unused_memory_level ....................... 0
enable_cuda_graph ............................... False
enable_experimental ............................. False
enable_ft_package ............................... False
enable_full_sharding_in_hsdp .................... False
enable_gloo_process_groups ...................... True
enable_msc ...................................... True
enable_one_logger ............................... True
encoder_num_layers .............................. 24
encoder_seq_length .............................. 2048
end_weight_decay ................................ 0.1
eod_mask_loss ................................... False
error_injection_rate ............................ 0
error_injection_type ............................ transient_error
eval_interval ................................... 1000
eval_iters ...................................... 10
evidence_data_path .............................. None
exit_duration_in_mins ........................... None
exit_interval ................................... None
exit_on_missing_checkpoint ...................... False
exit_signal_handler ............................. False
exp_avg_dtype ................................... torch.float32
exp_avg_sq_dtype ................................ torch.float32
expert_model_parallel_size ...................... 1
expert_tensor_parallel_size ..................... 4
export_force_local_attention .................... False
export_kd_cfg ................................... None
export_kd_teacher_ckpt_format ................... None
export_kd_teacher_load .......................... None
export_kv_cache_quant ........................... False
export_legacy_megatron .......................... False
export_model_type ............................... GPTModel
export_moe_apply_probs_on_input ................. False
export_qk_l2_norm ............................... False
export_quant_cfg ................................ None
export_real_quant_cfg ........................... None
export_te_mcore_model ........................... False
external_cuda_graph ............................. False
ffn_hidden_size ................................. 4096
finetune ........................................ False
finetune_data_split ............................. train
finetune_hf_dataset ............................. None
first_last_layers_bf16 .......................... False
flash_decode .................................... False
fp16 ............................................ True
fp16_lm_cross_entropy ........................... False
fp32_residual_connection ........................ False
fp8 ............................................. None
fp8_amax_compute_algo ........................... most_recent
fp8_amax_history_len ............................ 1
fp8_interval .................................... 1
fp8_margin ...................................... 0
fp8_param_gather ................................ False
fp8_recipe ...................................... delayed
fp8_wgrad ....................................... True
fsdp_double_buffer .............................. False
full_validation ................................. False
global_batch_size ............................... 16
grad_reduce_in_bf16 ............................. False
gradient_accumulation_fusion .................... True
gradient_reduce_div_fusion ...................... True
group_query_attention ........................... False
head_lr_mult .................................... 1.0
heterogeneous_layers_config_encoded_json ........ None
heterogeneous_layers_config_path ................ None
hidden_dropout .................................. 0.1
hidden_size ..................................... 1024
hierarchical_context_parallel_sizes ............. None
high_priority_stream_groups ..................... []
hybrid_attention_ratio .......................... 0.0
hybrid_mlp_ratio ................................ 0.0
hybrid_override_pattern ......................... None
hysteresis ...................................... 2
ict_head_size ................................... None
ict_load ........................................ None
img_h ........................................... 224
img_w ........................................... 224
indexer_batch_size .............................. 128
indexer_log_interval ............................ 1000
inference_batch_times_seqlen_threshold .......... -1
inference_dynamic_batching ...................... False
inference_dynamic_batching_buffer_guaranteed_fraction 0.2
inference_dynamic_batching_buffer_overflow_factor None
inference_dynamic_batching_buffer_size_gb ....... 40.0
inference_dynamic_batching_chunk_size ........... 256
inference_dynamic_batching_max_requests_override None
inference_dynamic_batching_max_tokens_override .. None
inference_dynamic_batching_num_cuda_graphs ...... 16
inference_max_batch_size ........................ 8
inference_max_seq_length ........................ 2560
inference_rng_tracker ........................... False
init_method_std ................................. 0.006
init_method_xavier_uniform ...................... False
init_model_with_meta_device ..................... False
initial_loss_scale .............................. 4294967296
inprocess_active_world_size ..................... 4
inprocess_barrier_timeout ....................... 120
inprocess_completion_timeout .................... 120
inprocess_empty_cuda_cache ...................... False
inprocess_granularity ........................... node
inprocess_hard_timeout .......................... 90
inprocess_heartbeat_interval .................... 30
inprocess_heartbeat_timeout ..................... 60
inprocess_last_call_wait ........................ 1
inprocess_max_iterations ........................ None
inprocess_monitor_process_interval .............. 1.0
inprocess_monitor_thread_interval ............... 1.0
inprocess_progress_watchdog_interval ............ 1.0
inprocess_restart ............................... False
inprocess_soft_timeout .......................... 60
inprocess_termination_grace_time ................ 1
is_hybrid_model ................................. False
iter_per_epoch .................................. 1250
iterations_to_skip .............................. []
keep_fp8_transpose_cache ........................ False
kitchen_config_file ............................. None
kitchen_recipe_number ........................... None
kv_channels ..................................... 64
kv_lora_rank .................................... 32
lazy_mpu_init ................................... None
load ............................................ /workspace/megatron-lm/model_ckpt/gpt3_857m_tp4
load_main_params_from_ckpt ...................... None
load_model_opt_format ........................... False
local_rank ...................................... 0
log_energy ...................................... False
log_interval .................................... 200
log_loss_scale_to_tensorboard ................... True
log_memory_to_tensorboard ....................... False
log_num_zeros_in_grad ........................... False
log_params_norm ................................. False
log_progress .................................... False
log_straggler ................................... False
log_throughput .................................. False
log_timers_to_tensorboard ....................... False
log_validation_ppl_to_tensorboard ............... False
log_world_size_to_tensorboard ................... False
logging_level ................................... None
loss_scale ...................................... None
loss_scale_window ............................... 1000
lr .............................................. 6e-05
lr_decay_iters .................................. 20000
lr_decay_samples ................................ None
lr_decay_style .................................. cosine
lr_warmup_fraction .............................. 0.001
lr_warmup_init .................................. 0.0
lr_warmup_iters ................................. 0
lr_warmup_samples ............................... 0
lr_wsd_decay_iters .............................. None
lr_wsd_decay_samples ............................ None
lr_wsd_decay_style .............................. exponential
main_grads_dtype ................................ torch.float32
main_params_dtype ............................... torch.float32
make_vocab_size_divisible_by .................... 128
mamba_head_dim .................................. 64
mamba_num_groups ................................ 8
mamba_num_heads ................................. None
mamba_state_dim ................................. 128
manual_gc ....................................... False
manual_gc_eval .................................. True
manual_gc_interval .............................. 0
mask_factor ..................................... 1.0
mask_prob ....................................... 0.15
mask_type ....................................... random
masked_softmax_fusion ........................... True
max_position_embeddings ......................... 2048
max_tokens_to_oom ............................... 12000
memory_snapshot_path ............................ snapshot.pickle
merge_file ...................................... /workspace/megatron-lm/data/tokenizer/gpt2-merges.txt
micro_batch_size ................................ 4
microbatch_group_size_per_vp_stage .............. None
mid_level_dataset_surplus ....................... 0.005
min_loss_scale .................................. 1.0
min_lr .......................................... 6e-06
mlp_chunks_for_prefill .......................... 1
mmap_bin_files .................................. True
mock_data ....................................... False
moe_apply_probs_on_input ........................ False
moe_aux_loss_coeff .............................. 0.0
moe_deepep_num_sms .............................. 20
moe_enable_deepep ............................... False
moe_expert_capacity_factor ...................... None
moe_extended_tp ................................. False
moe_ffn_hidden_size ............................. None
moe_grouped_gemm ................................ False
moe_input_jitter_eps ............................ None
moe_layer_freq .................................. 1
moe_layer_recompute ............................. False
moe_pad_expert_input_to_capacity ................ False
moe_per_layer_logging ........................... False
moe_permute_fusion .............................. False
moe_router_bias_update_rate ..................... 0.001
moe_router_dtype ................................ None
moe_router_enable_expert_bias ................... False
moe_router_force_load_balancing ................. False
moe_router_fusion ............................... False
moe_router_group_topk ........................... None
moe_router_load_balancing_type .................. aux_loss
moe_router_num_groups ........................... None
moe_router_padding_for_fp8 ...................... False
moe_router_pre_softmax .......................... False
moe_router_score_function ....................... softmax
moe_router_topk ................................. 2
moe_router_topk_scaling_factor .................. None
moe_shared_expert_intermediate_size ............. None
moe_shared_expert_overlap ....................... False
moe_token_dispatcher_type ....................... allgather
moe_token_drop_policy ........................... probs
moe_upcycling_granularity ....................... 1
moe_use_legacy_grouped_gemm ..................... False
moe_use_upcycling ............................... False
moe_z_loss_coeff ................................ None
mrope_section ................................... None
mscale .......................................... 1.0
mscale_all_dim .................................. 0.0
mtp_loss_scaling_factor ......................... 0.1
mtp_num_layers .................................. None
multi_latent_attention .......................... False
multiple_validation_sets ........................ False
nccl_all_reduce_for_prefill ..................... False
nccl_communicator_config_path ................... None
nccl_ub ......................................... False
no_load_optim ................................... None
no_load_rng ..................................... None
no_persist_layer_norm ........................... False
no_rope_freq .................................... None
no_save_optim ................................... None
no_save_rng ..................................... None
non_persistent_ckpt_type ........................ None
non_persistent_global_ckpt_dir .................. None
non_persistent_local_ckpt_algo .................. fully_parallel
non_persistent_local_ckpt_dir ................... None
non_persistent_save_interval .................... None
norm_epsilon .................................... 1e-05
normalization ................................... LayerNorm
num_attention_heads ............................. 16
num_channels .................................... 3
num_classes ..................................... 1000
num_dataset_builder_threads ..................... 1
num_distributed_optimizer_instances ............. 1
num_experts ..................................... None
num_layers ...................................... 24
num_layers_at_end_in_bf16 ....................... 1
num_layers_at_start_in_bf16 ..................... 1
num_layers_per_virtual_pipeline_stage ........... None
num_query_groups ................................ 1
num_virtual_stages_per_pipeline_rank ............ None
num_workers ..................................... 2
object_storage_cache_path ....................... None
one_logger_async ................................ False
one_logger_project .............................. megatron-lm
one_logger_run_name ............................. None
onnx_safe ....................................... None
openai_gelu ..................................... False
optimizer ....................................... adam
optimizer_cpu_offload ........................... False
optimizer_offload_fraction ...................... 1.0
output_bert_embeddings .......................... False
overlap_cpu_optimizer_d2h_h2d ................... False
overlap_grad_reduce ............................. False
overlap_moe_expert_parallel_comm ................ False
overlap_p2p_comm ................................ False
overlap_p2p_comm_warmup_flush ................... False
overlap_param_gather ............................ False
overlap_param_gather_with_optimizer_step ........ False
override_opt_param_scheduler .................... False
padded_vocab_size ............................... None
params_dtype .................................... torch.float16
patch_dim ....................................... 16
per_split_data_args_path ........................ None
perform_initialization .......................... True
pin_cpu_grads ................................... True
pin_cpu_params .................................. True
pipeline_model_parallel_comm_backend ............ None
pipeline_model_parallel_layout .................. None
pipeline_model_parallel_size .................... 1
position_embedding_type ......................... learned_absolute
pretrained_checkpoint ........................... None
profile ......................................... True
profile_ranks ................................... [0]
profile_step_end ................................ 112
profile_step_start .............................. 110
q_lora_rank ..................................... None
qk_head_dim ..................................... 128
qk_l2_norm ...................................... False
qk_layernorm .................................... False
qk_pos_emb_head_dim ............................. 64
query_in_block_prob ............................. 0.1
rampup_batch_size ............................... None
rank ............................................ 0
recompute_granularity ........................... None
recompute_method ................................ None
recompute_modules ............................... None
recompute_num_layers ............................ None
record_memory_history ........................... False
relative_attention_max_distance ................. 128
relative_attention_num_buckets .................. 32
replication ..................................... False
replication_factor .............................. 2
replication_jump ................................ None
rerun_mode ...................................... validate_results
reset_attention_mask ............................ False
reset_position_ids .............................. False
result_rejected_tracker_filename ................ None
retriever_report_topk_accuracies ................ []
retriever_score_scaling ......................... False
retriever_seq_length ............................ 256
retro_add_retriever ............................. False
retro_attention_gate ............................ 1
retro_cyclic_train_iters ........................ None
retro_encoder_attention_dropout ................. 0.1
retro_encoder_hidden_dropout .................... 0.1
retro_encoder_layers ............................ 2
retro_num_neighbors ............................. 2
retro_num_retrieved_chunks ...................... 2
retro_project_dir ............................... None
retro_verify_neighbor_count ..................... True
reuse_grad_buf_for_mxfp8_param_ag ............... False
rope_scaling_factor ............................. 8.0
rope_type ....................................... None
rotary_base ..................................... 10000
rotary_interleaved .............................. False
rotary_percent .................................. 1.0
rotary_scaling_factor ........................... 1.0
rotary_seq_len_interpolation_factor ............. None
run_workload_inspector_server ................... False
sample_rate ..................................... 1.0
save ............................................ /workspace/megatron-lm/model_ckpt/gpt3_857m_tp4
save_interval ................................... 10000
save_retain_interval ............................ None
scatter_gather_tensors_in_pipeline .............. True
seed ............................................ 1234
seq_length ...................................... 2048
sequence_parallel ............................... False
sft ............................................. False
sft_tokenizer_prompt_format ..................... nemotron-h-aligned
sgd_momentum .................................... 0.9
sharp_enabled_group ............................. None
short_seq_prob .................................. 0.1
skip_train ...................................... False
skipped_train_samples ........................... 0
spec ............................................ None
split ........................................... 949,50,1
squared_relu .................................... False
start_weight_decay .............................. 0.1
straggler_ctrlr_port ............................ 65535
straggler_minmax_count .......................... 1
strict_fsdp_dtensor_load ........................ True
suggested_communication_unit_size ............... None
swiglu .......................................... False
swin_backbone_type .............................. tiny
symmetric_ar_type ............................... None
te_rng_tracker .................................. False
tensor_model_parallel_size ...................... 4
tensorboard_dir ................................. /workspace/megatron-lm/tb_logs/gpt3_857m_profiler_tp4
tensorboard_log_interval ........................ 1
tensorboard_queue_size .......................... 1000
test_data_path .................................. None
test_mode ....................................... False
tiktoken_num_special_tokens ..................... 1000
tiktoken_pattern ................................ None
tiktoken_special_tokens ......................... None
timing_log_level ................................ 0
timing_log_option ............................... minmax
titles_data_path ................................ None
tokenizer_model ................................. None
tokenizer_type .................................. GPT2BPETokenizer
torch_fsdp2_reshard_after_forward ............... True
tp_comm_bootstrap_backend ....................... nccl
tp_comm_bulk_dgrad .............................. True
tp_comm_bulk_wgrad .............................. True
tp_comm_overlap ................................. False
tp_comm_overlap_ag .............................. True
tp_comm_overlap_cfg ............................. None
tp_comm_overlap_rs .............................. True
tp_comm_overlap_rs_dgrad ........................ False
tp_comm_split_ag ................................ True
tp_comm_split_rs ................................ True
train_data_path ................................. None
train_iters ..................................... 20000
train_samples ................................... None
train_sync_interval ............................. None
transformer_impl ................................ local
transformer_pipeline_model_parallel_size ........ 1
untie_embeddings_and_output_weights ............. False
use_checkpoint_args ............................. False
use_checkpoint_opt_param_scheduler .............. False
use_cpu_initialization .......................... None
use_dist_ckpt ................................... True
use_dist_ckpt_deprecated ........................ False
use_distributed_optimizer ....................... False
use_flash_attn .................................. False
use_fused_weighted_squared_relu ................. False
use_legacy_models ............................... False
use_megatron_fsdp ............................... False
use_mp_args_from_checkpoint_args ................ False
use_one_sent_docs ............................... False
use_persistent_ckpt_worker ...................... False
use_precision_aware_optimizer ................... False
use_pytorch_profiler ............................ True
use_ring_exchange_p2p ........................... False
use_rope_scaling ................................ False
use_rotary_position_embeddings .................. False
use_sharp ....................................... False
use_tokenizer_model_from_checkpoint_args ........ True
use_torch_fsdp2 ................................. False
use_torch_optimizer_for_cpu_offload ............. False
use_tp_pp_dp_mapping ............................ False
v_head_dim ...................................... 128
valid_data_path ................................. None
variable_seq_lengths ............................ False
virtual_pipeline_model_parallel_size ............ None
vision_backbone_type ............................ vit
vision_pretraining .............................. False
vision_pretraining_type ......................... classify
vocab_extra_ids ................................. 0
vocab_file ...................................... /workspace/megatron-lm/data/tokenizer/gpt2-vocab.json
vocab_size ...................................... None
wandb_exp_name ..................................
wandb_project ...................................
wandb_save_dir ..................................
weight_decay .................................... 0.1
weight_decay_incr_style ......................... constant
wgrad_deferral_limit ............................ 0
world_size ...................................... 4
yaml_cfg ........................................ None
-------------------- end of arguments ---------------------
INFO:megatron.core.num_microbatches_calculator:setting number of microbatches to constant 4
> building GPT2BPETokenizer tokenizer ...
> padded vocab (size: 50257) with 431 dummy tokens (new size: 50688)
WARNING:megatron.core.rerun_state_machine:RerunStateMachine initialized in mode RerunMode.VALIDATE_RESULTS
> initializing torch distributed ...
> initialized tensor model parallel with size 4
> initialized pipeline model parallel with size 1
> setting random seeds to 1234 ...
> compiling dataset index builder ...
make: Entering directory '/workspace/megatron-lm/megatron/core/datasets'
[rank2]:[W103 04:22:34.420203100 ProcessGroupNCCL.cpp:4751] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank1]:[W103 04:22:34.605196852 ProcessGroupNCCL.cpp:4751] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
> setting tensorboard ...
WARNING: one_logger package is required to enable e2e metrics tracking. please go to https://confluence.nvidia.com/display/MLWFO/Package+Repositories for details to install it
make: Nothing to be done for 'default'.
make: Leaving directory '/workspace/megatron-lm/megatron/core/datasets'
>>> done with dataset index builder. Compilation time: 0.207 seconds
> compiling and loading fused kernels ...
[rank3]:[W103 04:22:34.713259577 ProcessGroupNCCL.cpp:4751] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank0]:[W103 04:22:34.732739062 ProcessGroupNCCL.cpp:4751] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
>>> done with compiling and loading fused kernels. Compilation time: 0.324 seconds
time to initialize megatron (seconds): 2.393
[after megatron is initialized] datetime: 2026-01-03 04:22:36
building GPT model ...
> number of parameters on (tensor, pipeline) model parallel rank (3, 0): 90763264
> number of parameters on (tensor, pipeline) model parallel rank (0, 0): 90763264
> number of parameters on (tensor, pipeline) model parallel rank (2, 0): 90763264
> number of parameters on (tensor, pipeline) model parallel rank (1, 0): 90763264
INFO:megatron.core.distributed.distributed_data_parallel:Setting up DistributedDataParallel with config DistributedDataParallelConfig(grad_reduce_in_fp32=False, overlap_grad_reduce=False, overlap_param_gather=False, align_param_gather=False, use_distributed_optimizer=False, num_distributed_optimizer_instances=1, check_for_nan_in_grad=False, check_for_large_grads=False, bucket_size=None, pad_buckets_for_high_nccl_busbw=False, average_in_collective=False, fp8_param_gather=False, reuse_grad_buf_for_mxfp8_param_ag=False, use_megatron_fsdp=False, use_custom_fsdp=False, data_parallel_sharding_strategy='no_shard', gradient_reduce_div_fusion=True, suggested_communication_unit_size=None, preserve_fp32_weights=True, keep_fp8_transpose_cache=False, nccl_ub=False, fsdp_double_buffer=False, outer_dp_sharding_strategy='no_shard', disable_symmetric_registration=False, delay_wgrad_compute=False)
INFO:megatron.core.distributed.param_and_grad_buffer:Number of buckets for gradient all-reduce / reduce-scatter: 1
Params for bucket 1 (90763264 elements, 90763264 padded size):
module.decoder.layers.22.self_attention.linear_qkv.weight
module.decoder.layers.20.mlp.linear_fc2.weight
module.decoder.layers.18.self_attention.linear_proj.weight
module.decoder.layers.16.input_layernorm.bias
module.decoder.layers.8.mlp.linear_fc1.weight
module.decoder.layers.7.pre_mlp_layernorm.bias
module.decoder.layers.7.self_attention.linear_proj.weight
module.decoder.layers.3.input_layernorm.bias
module.decoder.layers.1.mlp.linear_fc1.weight
module.decoder.layers.21.self_attention.linear_qkv.bias
module.decoder.layers.14.self_attention.linear_proj.weight
module.decoder.layers.12.mlp.linear_fc1.bias
module.decoder.layers.11.pre_mlp_layernorm.weight
module.decoder.layers.8.self_attention.linear_qkv.bias
module.decoder.layers.7.mlp.linear_fc2.bias
module.decoder.layers.7.self_attention.linear_qkv.bias
module.decoder.layers.6.mlp.linear_fc1.bias
module.decoder.layers.5.self_attention.linear_qkv.bias
module.decoder.layers.2.mlp.linear_fc1.bias
module.decoder.layers.21.mlp.linear_fc1.weight
module.decoder.layers.20.self_attention.linear_proj.weight
module.decoder.layers.18.input_layernorm.bias
module.decoder.layers.15.mlp.linear_fc2.bias
module.decoder.layers.11.self_attention.linear_qkv.weight
module.decoder.layers.6.self_attention.linear_qkv.weight
module.decoder.layers.23.self_attention.linear_qkv.bias
module.decoder.layers.22.self_attention.linear_proj.bias
module.decoder.layers.14.input_layernorm.bias
module.decoder.layers.13.pre_mlp_layernorm.weight
module.decoder.layers.10.self_attention.linear_qkv.bias
module.decoder.layers.9.input_layernorm.weight
module.decoder.layers.2.mlp.linear_fc2.bias
module.decoder.layers.1.input_layernorm.weight
module.decoder.layers.23.mlp.linear_fc1.weight
module.decoder.layers.22.pre_mlp_layernorm.bias
module.decoder.layers.20.input_layernorm.bias
module.decoder.layers.17.mlp.linear_fc2.bias
module.decoder.layers.16.mlp.linear_fc1.bias
module.decoder.layers.15.pre_mlp_layernorm.weight
module.decoder.layers.13.self_attention.linear_qkv.weight
module.decoder.layers.10.mlp.linear_fc1.weight
module.decoder.layers.5.pre_mlp_layernorm.bias
module.decoder.layers.4.input_layernorm.weight
module.decoder.layers.22.input_layernorm.weight
module.decoder.layers.15.self_attention.linear_qkv.weight
module.decoder.layers.12.self_attention.linear_qkv.bias
module.decoder.layers.11.self_attention.linear_proj.bias
module.decoder.layers.9.self_attention.linear_qkv.bias
module.decoder.layers.6.mlp.linear_fc1.weight
module.decoder.layers.2.mlp.linear_fc2.weight
module.decoder.layers.2.self_attention.linear_proj.bias
module.decoder.layers.0.mlp.linear_fc2.bias
module.decoder.layers.0.self_attention.linear_qkv.bias
module.embedding.word_embeddings.weight
module.decoder.layers.22.mlp.linear_fc2.weight
module.decoder.layers.19.mlp.linear_fc2.bias
module.decoder.layers.18.mlp.linear_fc1.bias
module.decoder.layers.17.pre_mlp_layernorm.weight
module.decoder.layers.12.mlp.linear_fc1.weight
module.decoder.layers.11.pre_mlp_layernorm.bias
module.decoder.layers.9.mlp.linear_fc1.bias
module.decoder.layers.5.self_attention.linear_qkv.weight
module.decoder.layers.4.self_attention.linear_proj.weight
module.decoder.layers.17.self_attention.linear_qkv.weight
module.decoder.layers.14.mlp.linear_fc1.bias
module.decoder.layers.13.self_attention.linear_proj.bias
module.decoder.layers.11.input_layernorm.weight
module.decoder.layers.8.pre_mlp_layernorm.bias
module.decoder.layers.7.mlp.linear_fc1.weight
module.decoder.layers.7.self_attention.linear_qkv.weight
module.decoder.final_layernorm.weight
module.decoder.layers.22.self_attention.linear_proj.weight
module.decoder.layers.20.mlp.linear_fc1.bias
module.decoder.layers.19.pre_mlp_layernorm.weight
module.decoder.layers.16.self_attention.linear_qkv.bias
module.decoder.layers.15.self_attention.linear_proj.bias
module.decoder.layers.13.pre_mlp_layernorm.bias
module.decoder.layers.11.mlp.linear_fc2.weight
module.decoder.layers.9.input_layernorm.bias
module.decoder.layers.7.mlp.linear_fc2.weight
module.decoder.layers.0.mlp.linear_fc2.weight
module.decoder.layers.19.self_attention.linear_qkv.weight
module.decoder.layers.16.mlp.linear_fc1.weight
module.decoder.layers.15.pre_mlp_layernorm.bias
module.decoder.layers.13.input_layernorm.weight
module.decoder.layers.6.pre_mlp_layernorm.bias
module.decoder.layers.6.pre_mlp_layernorm.weight
module.decoder.layers.5.input_layernorm.weight
module.decoder.layers.4.mlp.linear_fc1.weight
module.decoder.layers.3.mlp.linear_fc1.weight
module.decoder.layers.1.pre_mlp_layernorm.bias
module.decoder.layers.0.pre_mlp_layernorm.weight
module.decoder.layers.22.input_layernorm.bias
module.decoder.layers.18.self_attention.linear_qkv.bias
module.decoder.layers.17.self_attention.linear_proj.bias
module.decoder.layers.15.input_layernorm.weight
module.decoder.layers.13.mlp.linear_fc2.weight
module.decoder.layers.13.mlp.linear_fc1.bias
module.decoder.layers.11.self_attention.linear_proj.weight
module.decoder.layers.8.mlp.linear_fc2.bias
module.decoder.layers.0.input_layernorm.bias
module.decoder.layers.18.mlp.linear_fc1.weight
module.decoder.layers.17.pre_mlp_layernorm.bias
module.decoder.layers.15.mlp.linear_fc2.weight
module.decoder.layers.14.self_attention.linear_qkv.bias
module.decoder.layers.8.input_layernorm.bias
module.decoder.layers.4.pre_mlp_layernorm.weight
module.decoder.layers.4.input_layernorm.bias
module.decoder.layers.2.input_layernorm.weight
module.decoder.layers.1.self_attention.linear_proj.weight
module.embedding.position_embeddings.weight
module.decoder.layers.1.mlp.linear_fc1.bias
module.decoder.final_layernorm.bias
module.decoder.layers.21.mlp.linear_fc2.bias
module.decoder.layers.20.self_attention.linear_qkv.bias
module.decoder.layers.19.self_attention.linear_proj.bias
module.decoder.layers.17.input_layernorm.weight
module.decoder.layers.14.mlp.linear_fc1.weight
module.decoder.layers.13.self_attention.linear_proj.weight
module.decoder.layers.11.input_layernorm.bias
module.decoder.layers.6.self_attention.linear_proj.bias
module.decoder.layers.3.self_attention.linear_proj.bias
module.decoder.layers.17.mlp.linear_fc2.weight
module.decoder.layers.20.mlp.linear_fc1.weight
module.decoder.layers.19.pre_mlp_layernorm.bias
module.decoder.layers.15.self_attention.linear_proj.weight
module.decoder.layers.8.pre_mlp_layernorm.weight
module.decoder.layers.23.mlp.linear_fc2.bias
module.decoder.layers.22.mlp.linear_fc1.bias
module.decoder.layers.21.pre_mlp_layernorm.weight
module.decoder.layers.19.input_layernorm.weight
module.decoder.layers.13.input_layernorm.bias
module.decoder.layers.10.mlp.linear_fc2.bias
module.decoder.layers.9.pre_mlp_layernorm.bias
module.decoder.layers.5.mlp.linear_fc1.bias
module.decoder.layers.3.mlp.linear_fc1.bias
module.decoder.layers.1.self_attention.linear_qkv.weight
module.decoder.layers.21.self_attention.linear_qkv.weight
module.decoder.layers.19.mlp.linear_fc2.weight
module.decoder.layers.17.self_attention.linear_proj.weight
module.decoder.layers.15.input_layernorm.bias
module.decoder.layers.13.mlp.linear_fc2.bias
module.decoder.layers.4.mlp.linear_fc1.bias
module.decoder.layers.2.input_layernorm.bias
module.decoder.layers.0.self_attention.linear_proj.bias
module.decoder.layers.0.pre_mlp_layernorm.bias
module.decoder.layers.23.pre_mlp_layernorm.weight
module.decoder.layers.12.mlp.linear_fc2.bias
module.decoder.layers.11.mlp.linear_fc1.bias
module.decoder.layers.10.pre_mlp_layernorm.weight
module.decoder.layers.9.mlp.linear_fc2.bias
module.decoder.layers.8.self_attention.linear_proj.bias
module.decoder.layers.5.self_attention.linear_proj.bias
module.decoder.layers.5.input_layernorm.bias
module.decoder.layers.2.pre_mlp_layernorm.bias
module.decoder.layers.0.input_layernorm.weight
module.decoder.layers.23.self_attention.linear_qkv.weight
module.decoder.layers.19.self_attention.linear_proj.weight
module.decoder.layers.17.input_layernorm.bias
module.decoder.layers.14.mlp.linear_fc2.bias
module.decoder.layers.10.self_attention.linear_qkv.weight
module.decoder.layers.6.mlp.linear_fc2.weight
module.decoder.layers.4.self_attention.linear_proj.bias
module.decoder.layers.3.self_attention.linear_qkv.weight
module.decoder.layers.0.self_attention.linear_proj.weight
module.decoder.layers.22.self_attention.linear_qkv.bias
module.decoder.layers.21.self_attention.linear_proj.bias
module.decoder.layers.12.pre_mlp_layernorm.weight
module.decoder.layers.8.mlp.linear_fc2.weight
module.decoder.layers.6.input_layernorm.weight
module.decoder.layers.4.pre_mlp_layernorm.bias
module.decoder.layers.3.mlp.linear_fc2.weight
module.decoder.layers.2.self_attention.linear_qkv.bias
module.decoder.layers.22.mlp.linear_fc1.weight
module.decoder.layers.21.pre_mlp_layernorm.bias
module.decoder.layers.19.input_layernorm.bias
module.decoder.layers.16.mlp.linear_fc2.bias
module.decoder.layers.15.mlp.linear_fc1.bias
module.decoder.layers.12.self_attention.linear_qkv.weight
module.decoder.layers.9.self_attention.linear_proj.bias
module.decoder.layers.7.input_layernorm.weight
module.decoder.layers.5.pre_mlp_layernorm.weight
module.decoder.layers.4.self_attention.linear_qkv.weight
module.decoder.layers.23.self_attention.linear_proj.bias
module.decoder.layers.21.input_layernorm.weight
module.decoder.layers.11.self_attention.linear_qkv.bias
module.decoder.layers.10.self_attention.linear_proj.bias
module.decoder.layers.9.self_attention.linear_proj.weight
module.decoder.layers.5.mlp.linear_fc2.weight
module.decoder.layers.5.mlp.linear_fc1.weight
module.decoder.layers.2.self_attention.linear_proj.weight
module.decoder.layers.17.mlp.linear_fc1.bias
module.decoder.layers.23.pre_mlp_layernorm.bias
module.decoder.layers.21.mlp.linear_fc2.weight
module.decoder.layers.18.mlp.linear_fc2.bias
module.decoder.layers.16.pre_mlp_layernorm.weight
module.decoder.layers.11.mlp.linear_fc1.weight
module.decoder.layers.10.pre_mlp_layernorm.bias
module.decoder.layers.9.mlp.linear_fc1.weight
module.decoder.layers.5.self_attention.linear_proj.weight
module.decoder.layers.2.self_attention.linear_qkv.weight
module.decoder.layers.23.input_layernorm.weight
module.decoder.layers.16.self_attention.linear_qkv.weight
module.decoder.layers.14.mlp.linear_fc2.weight
module.decoder.layers.13.self_attention.linear_qkv.bias
module.decoder.layers.12.self_attention.linear_proj.bias
module.decoder.layers.10.input_layernorm.weight
module.decoder.layers.9.mlp.linear_fc2.weight
module.decoder.layers.8.mlp.linear_fc1.bias
module.decoder.layers.7.pre_mlp_layernorm.weight
module.decoder.layers.6.input_layernorm.bias
module.decoder.layers.23.mlp.linear_fc2.weight
module.decoder.layers.21.self_attention.linear_proj.weight
module.decoder.layers.19.mlp.linear_fc1.bias
module.decoder.layers.18.pre_mlp_layernorm.weight
module.decoder.layers.15.self_attention.linear_qkv.bias
module.decoder.layers.13.mlp.linear_fc1.weight
module.decoder.layers.12.pre_mlp_layernorm.bias
module.decoder.layers.10.mlp.linear_fc2.weight
module.decoder.layers.8.self_attention.linear_qkv.weight
module.decoder.layers.1.mlp.linear_fc2.bias
module.decoder.layers.18.self_attention.linear_qkv.weight
module.decoder.layers.15.mlp.linear_fc1.weight
module.decoder.layers.14.pre_mlp_layernorm.weight
module.decoder.layers.12.input_layernorm.weight
module.decoder.layers.9.pre_mlp_layernorm.weight
module.decoder.layers.7.mlp.linear_fc1.bias
module.decoder.layers.5.mlp.linear_fc2.bias
module.decoder.layers.4.mlp.linear_fc2.bias
module.decoder.layers.3.pre_mlp_layernorm.bias
module.decoder.layers.1.self_attention.linear_qkv.bias
module.decoder.layers.23.self_attention.linear_proj.weight
module.decoder.layers.21.input_layernorm.bias
module.decoder.layers.20.pre_mlp_layernorm.weight
module.decoder.layers.17.self_attention.linear_qkv.bias
module.decoder.layers.16.self_attention.linear_proj.bias
module.decoder.layers.14.self_attention.linear_qkv.weight
module.decoder.layers.12.mlp.linear_fc2.weight
module.decoder.layers.10.self_attention.linear_proj.weight
module.decoder.layers.3.input_layernorm.weight
module.decoder.layers.1.mlp.linear_fc2.weight
module.decoder.layers.20.self_attention.linear_qkv.weight
module.decoder.layers.17.mlp.linear_fc1.weight
module.decoder.layers.16.pre_mlp_layernorm.bias
module.decoder.layers.4.mlp.linear_fc2.weight
module.decoder.layers.2.pre_mlp_layernorm.weight
module.decoder.layers.1.self_attention.linear_proj.bias
module.decoder.layers.23.input_layernorm.bias
module.decoder.layers.19.self_attention.linear_qkv.bias
module.decoder.layers.18.self_attention.linear_proj.bias
module.decoder.layers.16.input_layernorm.weight
module.decoder.layers.12.self_attention.linear_proj.weight
module.decoder.layers.10.input_layernorm.bias
module.decoder.layers.0.mlp.linear_fc1.bias
module.decoder.layers.1.pre_mlp_layernorm.weight
module.decoder.layers.18.pre_mlp_layernorm.bias
module.decoder.layers.20.mlp.linear_fc2.bias
module.decoder.layers.19.mlp.linear_fc1.weight
module.decoder.layers.16.mlp.linear_fc2.weight
module.decoder.layers.14.self_attention.linear_proj.bias
module.decoder.layers.7.self_attention.linear_proj.bias
module.decoder.layers.6.self_attention.linear_proj.weight
module.decoder.layers.3.self_attention.linear_proj.weight
module.decoder.layers.2.mlp.linear_fc1.weight
module.decoder.layers.22.mlp.linear_fc2.bias
module.decoder.layers.21.mlp.linear_fc1.bias
module.decoder.layers.20.self_attention.linear_proj.bias
module.decoder.layers.18.input_layernorm.weight
module.decoder.layers.14.pre_mlp_layernorm.bias
module.decoder.layers.12.input_layernorm.bias
module.decoder.layers.9.self_attention.linear_qkv.weight
module.decoder.layers.6.self_attention.linear_qkv.bias
module.decoder.layers.3.pre_mlp_layernorm.weight
module.decoder.layers.0.mlp.linear_fc1.weight
module.decoder.layers.20.pre_mlp_layernorm.bias
module.decoder.layers.18.mlp.linear_fc2.weight
module.decoder.layers.16.self_attention.linear_proj.weight
module.decoder.layers.14.input_layernorm.weight
module.decoder.layers.8.input_layernorm.weight
module.decoder.layers.6.mlp.linear_fc2.bias
module.decoder.layers.4.self_attention.linear_qkv.bias
module.decoder.layers.0.self_attention.linear_qkv.weight
module.decoder.layers.23.mlp.linear_fc1.bias
module.decoder.layers.22.pre_mlp_layernorm.weight
module.decoder.layers.20.input_layernorm.weight
module.decoder.layers.11.mlp.linear_fc2.bias
module.decoder.layers.10.mlp.linear_fc1.bias
module.decoder.layers.8.self_attention.linear_proj.weight
module.decoder.layers.7.input_layernorm.bias
module.decoder.layers.3.mlp.linear_fc2.bias
module.decoder.layers.3.self_attention.linear_qkv.bias
module.decoder.layers.1.input_layernorm.bias
INFO:megatron.core.optimizer:Setting up optimizer with config OptimizerConfig(optimizer='adam', lr=6e-05, min_lr=6e-06, decoupled_lr=None, decoupled_min_lr=None, weight_decay=0.1, fp8_recipe='delayed', fp16=True, bf16=False, reuse_grad_buf_for_mxfp8_param_ag=False, params_dtype=torch.float16, use_precision_aware_optimizer=False, store_param_remainders=True, main_grads_dtype=torch.float32, main_params_dtype=torch.float32, exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.float32, loss_scale=None, initial_loss_scale=4294967296, min_loss_scale=1.0, loss_scale_window=1000, hysteresis=2, adam_beta1=0.9, adam_beta2=0.95, adam_eps=1e-08, sgd_momentum=0.9, use_distributed_optimizer=False, overlap_param_gather=False, overlap_param_gather_with_optimizer_step=False, use_megatron_fsdp=False, optimizer_cpu_offload=False, optimizer_offload_fraction=1.0, use_torch_optimizer_for_cpu_offload=False, overlap_cpu_optimizer_d2h_h2d=False, pin_cpu_grads=True, pin_cpu_params=True, clip_grad=1.0, log_num_zeros_in_grad=False, barrier_with_L1_time=True, timers=<megatron.core.timers.Timers object at 0x7f59b73c1430>, config_logger_dir='')
INFO:megatron.core.optimizer_param_scheduler:> learning rate decay style: cosine
WARNING: could not find the metadata file /workspace/megatron-lm/model_ckpt/gpt3_857m_tp4/latest_checkpointed_iteration.txt
will not load any checkpoints and will start from random
(min, max) time across ranks (ms):
load-checkpoint ................................: (0.29, 0.30)
[after model, optimizer, and learning rate scheduler are built] datetime: 2026-01-03 04:22:36
> building train, validation, and test datasets ...
> datasets target sizes (minimum size):
train: 320000
validation: 3360
test: 160
INFO:megatron.core.datasets.blended_megatron_dataset_config:Let split_matrix = [(0, 0.949), (0.949, 0.999), (0.999, 1.0)]
> building train, validation, and test datasets for GPT ...
INFO:megatron.core.datasets.blended_megatron_dataset_builder:Building GPTDataset splits with sizes=(320000, 3360, 160) and config=GPTDatasetConfig(random_seed=1234, sequence_length=2048, blend=(['/workspace/megatron-lm/data/TinyStoriesV2-GPT4-train_text_document'], None), blend_per_split=None, multiple_validation_sets=False, full_validation=False, split='949,50,1', split_matrix=[(0, 0.949), (0.949, 0.999), (0.999, 1.0)], num_dataset_builder_threads=1, path_to_cache=None, mmap_bin_files=True, mock=False, tokenizer=<megatron.training.tokenizer.tokenizer._GPT2BPETokenizer object at 0x7f59b702e240>, mid_level_dataset_surplus=0.005, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False, create_attention_mask=True, drop_last_partial_validation_sequence=True, add_extra_token_to_sequence=True, object_storage_cache_path=None)
INFO:megatron.core.datasets.indexed_dataset:Load the _IndexReader from /workspace/megatron-lm/data/TinyStoriesV2-GPT4-train_text_document.idx
INFO:megatron.core.datasets.indexed_dataset: Extract the sequence lengths
INFO:megatron.core.datasets.indexed_dataset: Extract the sequence pointers
INFO:megatron.core.datasets.indexed_dataset: Extract the document indices
INFO:megatron.core.datasets.indexed_dataset:> total number of sequences: 14548094
INFO:megatron.core.datasets.indexed_dataset:> total number of documents: 14548094
INFO:megatron.core.datasets.gpt_dataset:Load the GPTDataset train indices
INFO:megatron.core.datasets.gpt_dataset: Load the document index from c376e20e5de541283d4ccc974c960cb8-GPTDataset-train-document_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the sample index from c376e20e5de541283d4ccc974c960cb8-GPTDataset-train-sample_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the shuffle index from c376e20e5de541283d4ccc974c960cb8-GPTDataset-train-shuffle_index.npy
INFO:megatron.core.datasets.gpt_dataset:> total number of samples: 521301
INFO:megatron.core.datasets.gpt_dataset:Load the GPTDataset valid indices
INFO:megatron.core.datasets.gpt_dataset: Load the document index from f975a8258f34477c465b869135b1a202-GPTDataset-valid-document_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the sample index from f975a8258f34477c465b869135b1a202-GPTDataset-valid-sample_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the shuffle index from f975a8258f34477c465b869135b1a202-GPTDataset-valid-shuffle_index.npy
INFO:megatron.core.datasets.gpt_dataset:> total number of samples: 13728
INFO:megatron.core.datasets.gpt_dataset:Load the GPTDataset test indices
INFO:megatron.core.datasets.gpt_dataset: Load the document index from b12f62104fc19a6d3c6c6402fedd7e04-GPTDataset-test-document_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the sample index from b12f62104fc19a6d3c6c6402fedd7e04-GPTDataset-test-sample_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the shuffle index from b12f62104fc19a6d3c6c6402fedd7e04-GPTDataset-test-shuffle_index.npy
INFO:megatron.core.datasets.gpt_dataset:> total number of samples: 273
> finished creating GPT datasets ...
[after dataloaders are built] datetime: 2026-01-03 04:22:36
done with setup ...
(min, max) time across ranks (ms):
model-and-optimizer-setup ......................: (162.43, 174.52)
train/valid/test-data-iterators-setup ..........: (26.90, 152.19)
training ...
Overwriting rerun_state_machine.current_iteration from -1 to 0...
[before the start of training step] datetime: 2026-01-03 04:22:36
Number of parameters in transformer block in billions: 0.30
Number of parameters in embedding layers in billions: 0.05
Total number of parameters in billions: 0.35
Number of parameters in most loaded shard in billions: 0.0885
Theoretical memory footprints: weight and optimizer=1519.18 MB
[2026-01-03 04:29:07] iteration 200/ 20000 | consumed samples: 3200 | elapsed time per iteration (ms): 1952.1 | learning rate: 5.999146E-05 | global batch size: 16 | lm loss: 5.665651E+00 | loss scale: 8192.0 | grad norm: 0.833 | number of skipped iterations: 20 | number of nan iterations: 0 |
[Rank 1] (after 200 iterations) memory (MB) | allocated: 1935.56005859375 | max allocated: 12972.46923828125 | reserved: 14548.0 | max reserved: 14548.0
[Rank 2] (after 200 iterations) memory (MB) | allocated: 1911.18505859375 | max allocated: 12948.21923828125 | reserved: 14516.0 | max reserved: 14516.0
[Rank 0] (after 200 iterations) memory (MB) | allocated: 1923.56005859375 | max allocated: 12960.71923828125 | reserved: 14812.0 | max reserved: 14812.0
[Rank 3] (after 200 iterations) memory (MB) | allocated: 1911.18505859375 | max allocated: 12947.96923828125 | reserved: 14532.0 | max reserved: 14532.0

Profiler trace file

The figure below shows the initial LanguageModelEmbedding. Since the TP size is 4 and sequence parallelism is not enabled, each rank only produces the partial embeddings for its own vocabulary shard, so reduce_from_tensor_model_parallel_region is called afterwards to all-reduce the partial results into the complete token embeddings.
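
To make it concrete what that all-reduce is combining, here is a minimal sketch (under the assumption of Megatron's vocab-parallel embedding layout, not the actual implementation): each TP rank holds one vocabulary shard, looks up only the tokens that fall inside it, and the partial embeddings are then summed across the TP group, which is exactly the collective that reduce_from_tensor_model_parallel_region issues in the forward pass. The names vocab_parallel_embedding_forward and tp_group are illustrative, not Megatron identifiers.

```python
import torch
import torch.distributed as dist

def vocab_parallel_embedding_forward(input_ids, weight_shard, vocab_start, vocab_end, tp_group):
    # Illustrative sketch: each TP rank only owns embedding rows [vocab_start, vocab_end).
    # Tokens outside this shard are redirected to index 0 and zeroed out afterwards.
    mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
    local_ids = (input_ids - vocab_start).masked_fill(mask, 0)
    partial = torch.nn.functional.embedding(local_ids, weight_shard)
    partial = partial.masked_fill(mask.unsqueeze(-1), 0.0)
    # Without sequence parallelism, the per-rank partial embeddings are simply
    # summed with one all-reduce over the TP group.
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=tp_group)
    return partial
```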

The figure below shows the final step of the MHA computation, where the multiplication with linear_proj maps the per-rank partition back to the full hidden dimension. Because linear_proj is row-parallel, it ends with an all-reduce to obtain the complete result.
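
The math behind that collective can be sketched as a simplified stand-in for a row-parallel linear layer (assuming the input is already split along the feature dimension by the preceding column-parallel linear_qkv; the function name and arguments here are hypothetical):

```python
import torch
import torch.distributed as dist

def row_parallel_linear_forward(x_local, weight_shard, bias, tp_group):
    # x_local:      [..., in_features // tp_size], this rank's slice of the input
    #               (for MHA, the outputs of the locally held attention heads).
    # weight_shard: [out_features, in_features // tp_size], this rank's rows of W.
    partial = torch.matmul(x_local, weight_shard.t())
    # Summing the per-rank partial products reconstructs the full GEMM result;
    # this is the all-reduce that appears after linear_proj in the trace.
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=tp_group)
    # The bias is not sharded, so it is added once, after the reduction.
    return partial if bias is None else partial + bias
```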

The figure below shows where the MLP module performs its all-reduce after the final row-parallel linear layer.
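
For completeness, a simplified sketch of the TP MLP forward pattern (assuming, as in Megatron, that linear_fc1 is column-parallel and linear_fc2 is row-parallel; tp_mlp_forward and the weight-shard arguments are illustrative names). The forward pass needs only this single all-reduce, at the very end.

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist

def tp_mlp_forward(hidden, fc1_shard, fc2_shard, tp_group):
    # Column-parallel linear_fc1: each rank computes its slice of the intermediate
    # (4h) dimension; no communication is needed here in the forward pass.
    intermediate = F.gelu(torch.matmul(hidden, fc1_shard.t()))
    # Row-parallel linear_fc2: partial product over the local 4h slice ...
    partial = torch.matmul(intermediate, fc2_shard.t())
    # ... completed by one all-reduce across the TP group -- the collective
    # captured in the profiler trace.
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=tp_group)
    return partial
```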