[Megatron-LM Source Code Analysis (2)] The GPT Model pretrain Flow

The Megatron-LM version examined here is core_r0.14.0, and the GPT training file under study is pretrain_gpt.py.

Entry Function

The code of the main entry is as follows:

if __name__ == "__main__":

    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    # Optionally enable inprocess restart on pretrain
    pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain)

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
        extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
        store=store,
    )

Its main responsibilities are:

  1. A temporary flag that tells the dataset builder this is a distributed training environment, so dataset construction must be coordinated across multiple processes

  2. Optionally enable in-process restart, wrapping the training function with fault-recovery capability so that it can restart automatically after a GPU failure without taking down the whole job

  3. Call the core pretrain function, passing the user-defined callbacks as arguments to drive training

In-Process Restart

The call goes through the maybe_wrap_for_inprocess_restart function, shown below:

def maybe_wrap_for_inprocess_restart(pretrain):

    args = arguments.parse_args(ignore_unknown_args=True)

    if args.inprocess_restart:
        pretrain = inprocess_restart(pretrain, args)

        store = torch.distributed.TCPStore(
            host_name=os.environ['MASTER_ADDR'],          # master node IP address
            port=int(os.environ['MASTER_PORT']) + 1,      # port (offset to avoid clashing with the main job)
            world_size=int(os.getenv('WORLD_SIZE', '1')), # total number of processes
            is_master=(int(os.getenv('RANK', '0')) == 0), # whether this is the master process
            timeout=timedelta(seconds=300),               # connection timeout (5 minutes)
            wait_for_workers=True,                        # wait for all workers to connect
            use_libuv=True,                               # use libuv for better performance
        )
    else:
        store = None

    return pretrain, store

Its main job is to check for the inprocess_restart launch flag. If it is absent, nothing happens; if it is present, the function:

  • wraps the key pretrain function with inprocess_restart

  • creates a TCPStore. The TCPStore is essentially a distributed key-value store that acts as the control plane (a standalone usage sketch follows below). Its roles:

    • It runs over plain TCP, so it keeps working even if NCCL or the training communicators break. Note that it uses port int(os.environ['MASTER_PORT']) + 1 to avoid colliding with the main communication port

    • wait_for_workers=True makes it wait until every worker is up and connected

    • It is used to guarantee that all workers have entered the new training round

    • Its fault-tolerant flow looks roughly like this

1. Training starts → all processes connect to the TCPStore
2. Process A detects a GPU failure → reports it to the TCPStore
3. The TCPStore broadcasts the failure state → all processes pause
4. The restart decision is coordinated → the faulty process is isolated
5. New processes start → state is synchronized through the TCPStore
6. All processes acknowledge → training resumes

Note, however, that if a node is genuinely broken, this mechanism cannot find a replacement node on its own; it will simply keep restarting unless enough hot-spare nodes are available.
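To make the "distributed KV store / control plane" idea concrete, here is a minimal standalone sketch of using torch.distributed.TCPStore directly; the key names and environment-variable defaults are made up for illustration and are not part of Megatron's code:

import os
from datetime import timedelta

import torch.distributed as dist

# Every rank connects to the same TCPStore and can use it as a tiny KV store,
# independently of NCCL and of the training process groups.
store = dist.TCPStore(
    host_name=os.environ.get('MASTER_ADDR', 'localhost'),
    port=int(os.environ.get('MASTER_PORT', '29500')) + 1,
    world_size=int(os.environ.get('WORLD_SIZE', '1')),
    is_master=(int(os.environ.get('RANK', '0')) == 0),
    timeout=timedelta(seconds=300),
)

rank = int(os.environ.get('RANK', '0'))
world_size = int(os.environ.get('WORLD_SIZE', '1'))

# Each rank publishes a status key; any rank can read the others.
store.set(f'status/rank{rank}', 'alive')

# wait() blocks until the listed keys exist, which is the building block for
# "make sure every worker has entered the new round" style coordination.
store.wait([f'status/rank{r}' for r in range(world_size)])
print(store.get('status/rank0'))  # b'alive'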

The code that wraps the key pretrain function with inprocess_restart is as follows:

def inprocess_restart(train, args):
    if inprocess is None:
        warnings.warn('In-process restart is not available')
        return train

    if 'TORCH_CPP_LOG_LEVEL' not in os.environ or os.environ['TORCH_CPP_LOG_LEVEL'] not in (
        'error',
        'fatal',
    ):
        warnings.warn(
            'Set TORCH_CPP_LOG_LEVEL=error to suppress c10d waitForInput timeout warning messages'
        )

    # Layers represents a configuration for a layer of branches at a certain
    # depth in a topology tree constructed by inprocess.rank_assignment.Tree.
    # First layer contains all ranks and it's the root of the topology tree,
    # the second optional layer groups ranks by nodes.
    layers = [
        inprocess.rank_assignment.Layer(
            min_ranks=args.inprocess_active_world_size,
            max_ranks=args.inprocess_active_world_size,
            flag=inprocess.rank_assignment.LayerFlag.RESERVE,
        )
    ]
    if args.inprocess_granularity == 'node':
        device_count = torch.cuda.device_count()

        layers.append(
            inprocess.rank_assignment.Layer(
                min_ranks=device_count,
                max_ranks=device_count,
                key_or_fn=lambda _: socket.gethostname(),
                flag=inprocess.rank_assignment.LayerFlag.RESERVE,
            )
        )

    finalize = [
        inprocess.finalize.ThreadedFinalize(timeout=timedelta(seconds=10), fn=destroy_state)
    ]

    if args.inprocess_empty_cuda_cache:
        finalize.append(
            inprocess.finalize.ThreadedFinalize(
                timeout=timedelta(seconds=10), fn=torch.cuda.empty_cache
            )
        )

    initialize = inprocess.Compose(
        inprocess.initialize.RetryController(min_world_size=args.inprocess_active_world_size),
        inprocess.nested_restarter.NestedRestarterHandlingCompleted(),
    )
    abort = inprocess.Compose(
        inprocess.abort.AbortTransformerEngine(),
        inprocess.abort.AbortTorchDistributed(),
        inprocess.nested_restarter.NestedRestarterHandlingStarting(),
    )
    completion = inprocess.nested_restarter.NestedRestarterFinalized()
    terminate = inprocess.nested_restarter.NestedRestarterAborted()

    train = inprocess.Wrapper(
        store_kwargs={
            'timeout': timedelta(seconds=300),
            'port': int(os.environ['MASTER_PORT']) + 2,
        },
        initialize=initialize,
        abort=abort,
        completion=completion,
        terminate=terminate,
        health_check=inprocess.health_check.CudaHealthCheck(timeout=timedelta(seconds=10)),
        rank_assignment=inprocess.rank_assignment.Tree(layers=layers),
        finalize=inprocess.Compose(*finalize),
        heartbeat_interval=timedelta(seconds=args.inprocess_heartbeat_interval),
        heartbeat_timeout=timedelta(seconds=args.inprocess_heartbeat_timeout),
        barrier_timeout=timedelta(seconds=args.inprocess_barrier_timeout),
        completion_timeout=timedelta(seconds=args.inprocess_completion_timeout),
        monitor_process_interval=timedelta(seconds=args.inprocess_monitor_process_interval),
        monitor_thread_interval=timedelta(seconds=args.inprocess_monitor_thread_interval),
        last_call_wait=timedelta(seconds=args.inprocess_last_call_wait),
        soft_timeout=timedelta(seconds=args.inprocess_soft_timeout),
        hard_timeout=timedelta(seconds=args.inprocess_hard_timeout),
        termination_grace_time=timedelta(seconds=args.inprocess_termination_grace_time),
        enabled=True,
    )(train)

    return train

Its main steps are:

  1. Check whether inprocess was successfully imported via import nvidia_resiliency_ext.inprocess as inprocess; if not, return the training function unchanged

  2. Warn the user to set the logging level

  3. Build the Layers (I have not fully worked out what these Layers do):

    1. Set the minimum / maximum number of surviving ranks and whether to use RESERVE mode

    2. With node granularity, an additional node-level layer is built so that removal can happen at the node level

  4. Build the finalize cleanup logic that runs after abort and before restart; it consists of

    1. destroy_state

      • destroy the process group

      • release the NCCL communicators

      • clear Megatron's internal global state

    2. empty_cache (optional):

      • empties the CUDA cache

      • very useful in OOM scenarios

  5. Build the four states of the state machine, Initialize / Abort / Completion / Terminate:

    1. initialize: wait until at least min_world_size ranks are available

    2. abort (triggered on failure): shuts down Transformer Engine, aborts torch.distributed, and notifies the nested restarter that a restart is beginning

    3. completion (normal exit): marks this round as finished and does not trigger a restart

    4. terminate (unrecoverable failure): terminates outright, with no further recovery attempts

  6. Wrap the training function:

    1. with the state machine described above

    2. with the many timeouts

    3. with the port set to int(os.environ['MASTER_PORT']) + 2 to avoid port conflicts

pretrain Arguments

pretrain is the core entry point for training. It acts more like a driver of the training flow: the user supplies the data, the model, the loss computation, and so on through its arguments, while pretrain assembles them and then applies the distributed training strategy, checkpointing, logging, and the rest.

Its signature is shown below; the arguments are discussed in groups afterwards:

def pretrain(
    train_valid_test_dataset_provider,
    model_provider,
    model_type,
    forward_step_func,
    process_non_loss_data_func=None,
    extra_args_provider=None,
    args_defaults={},
    get_embedding_ranks=None,
    get_position_embedding_ranks=None,
    non_loss_data_func=None,
    store=None,
    inprocess_call_wrapper: Optional[CallWrapper] = None,
):

Data-Related Arguments

  • train_valid_test_dataset_provider tells Megatron how to build the three datasets train_ds, valid_ds, and test_ds.

    • The function passed in this example is the following:
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train test and validation datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples in train test and validation.
    """
    args = get_args()

    config = core_gpt_dataset_config_from_args(args)

    if args.sft:
        dataset_type = SFTDataset
    else:
        if args.mock_data:
            dataset_type = MockGPTDataset
        else:
            dataset_type = GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config
    ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds
    • Its input train_val_test_num_samples is, as the docstring says, a list with the number of samples for train, valid, and test

    • Inside the function, dataset_type is chosen based on the arguments

    • The function is_dataset_built_on_rank is also passed as an argument; it decides whether the dataset is built on the current process. As shown below, the net effect is that the dataset is only built on ranks that sit in the first or last pipeline stage and have tensor-parallel rank 0.

def is_dataset_built_on_rank():
    return (
        parallel_state.is_pipeline_first_stage(ignore_virtual=True)
        or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    ) and parallel_state.get_tensor_model_parallel_rank() == 0
    • It then constructs the BlendedMegatronDatasetBuilder dataset-builder class and calls its build() method to obtain train_ds, valid_ds, and test_ds. The main responsibilities of BlendedMegatronDatasetBuilder are listed below (a detailed walkthrough is left for a later post; a rough usage sketch follows this list)

      • builds a blended dataset from multiple data sources

      • coordinates dataset construction in a distributed training environment

      • provides efficient dataset caching and parallel construction
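As a rough usage sketch (the sample counts below are invented, not values from the source; in real training they are derived from train_iters, the global batch size, eval_interval, and so on), the provider is simply called with the per-split sample counts and returns the three datasets:

# Hypothetical per-split sample counts: [train, valid, test]
train_val_test_num_samples = [10_000, 1_000, 1_000]

train_ds, valid_ds, test_ds = train_valid_test_datasets_provider(train_val_test_num_samples)

# On ranks where is_dataset_built_on_rank() returns False, the builder skips
# construction, and those ranks later obtain the batches they need through
# broadcast / pipeline communication instead of reading the data themselves.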

Model-Related Arguments

  • The model_provider argument supplies a raw model: a plain model on the CPU, with no fp16 conversion and no DDP wrapping applied yet.

    • The function passed in this example is the following:
    def model_provider(
    pre_process=True, post_process=True, vp_stage: Optional[int] = None
    ) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.

    Args:
    pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True.
    post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True.

    Returns:
    Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
    """
    args = get_args()

    if has_nvidia_modelopt and modelopt_args_enabled(args): # [ModelOpt]
    return model_provider_modelopt(pre_process, post_process)

    use_te = args.transformer_impl == "transformer_engine"

    if args.record_memory_history:
    torch.cuda.memory._record_memory_history(
    True,
    # keep 100,000 alloc/free events from before the snapshot
    trace_alloc_max_entries=100000,
    # record stack information for the trace events
    trace_alloc_record_context=True,
    )

    def oom_observer(device, alloc, device_alloc, device_free):
    # snapshot right after an OOM happened
    print('saving allocated state during OOM')
    snapshot = torch.cuda.memory._snapshot()
    from pickle import dump

    dump(
    snapshot,
    open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'),
    )

    torch._C._cuda_attach_out_of_memory_observer(oom_observer)

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
    config = core_transformer_config_from_yaml(args, "language_model")
    else:
    config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
    model = megatron.legacy.model.GPTModel(
    config,
    num_tokentypes=0,
    parallel_output=True,
    pre_process=pre_process,
    post_process=post_process,
    )
    else: # using core models
    if args.spec is not None:
    transformer_layer_spec = import_module(args.spec)
    else:
    if args.num_experts:
    # Define the decoder block spec
    transformer_layer_spec = get_gpt_decoder_block_spec(
    config, use_transformer_engine=use_te, normalization=args.normalization, qk_l2_norm=args.qk_l2_norm, vp_stage=vp_stage
    )
    elif args.heterogeneous_layers_config_path is not None:
    transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
    else:
    # Define the decoder layer spec
    transformer_layer_spec = _get_transformer_layer_spec(use_te, config)
    mtp_block_spec = None
    if args.mtp_num_layers is not None:
    if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0:
    # Get the decoder layer spec explicitly if no decoder layer in the last stage,
    # Only happens with block spec (TransformerBlockSubmodules) when using MoE.
    transformer_layer_spec_for_mtp = _get_transformer_layer_spec(use_te, config)
    else:
    transformer_layer_spec_for_mtp = transformer_layer_spec
    mtp_block_spec = get_gpt_mtp_block_spec(
    config, transformer_layer_spec_for_mtp, use_transformer_engine=use_te, vp_stage=vp_stage
    )

    model = GPTModel(
    config=config,
    transformer_layer_spec=transformer_layer_spec,
    vocab_size=args.padded_vocab_size,
    max_sequence_length=args.max_position_embeddings,
    pre_process=pre_process,
    post_process=post_process,
    fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
    parallel_output=True,
    share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
    position_embedding_type=args.position_embedding_type,
    rotary_percent=args.rotary_percent,
    rotary_base=args.rotary_base,
    rope_scaling=args.use_rope_scaling,
    mtp_block_spec=mtp_block_spec,
    vp_stage=vp_stage,
    )

    return model
    • The function's parameters are:

      • pre_process: whether to compute the embedding layer (input processing)

      • post_process: whether to compute the output logits / loss

      • vp_stage: the virtual pipeline stage (used with the interleaved pipeline schedule)

    • It fetches the training args and, if NVIDIA ModelOpt is enabled, delegates to the dedicated ModelOpt model provider

    • Based on the args it decides whether to use Transformer Engine (NVIDIA's high-performance transformer implementation)

    • Based on the args it decides whether to record memory history for debugging OOM issues, and whether to automatically save a memory snapshot to a file on OOM

    • It builds the config either from a YAML file or from the parsed args

    • It then enters the core-model branch and derives transformer_layer_spec and mtp_block_spec from the args

    • Finally it constructs the GPTModel from all of these pieces

  • model_type: tells Megatron the model's "topological role"; there are three options (a sketch of the enum follows this list):

    • encoder_or_decoder = 1: a conventional encoder-decoder model, or a decoder-only autoregressive model

    • retro_encoder = 2: the encoder component of a Retrieval-Enhanced Transformer (RETRO) model; RETRO is a transformer variant that augments generation with retrieval from an external knowledge base

    • retro_decoder = 3: the decoder component of a RETRO model, responsible for the final text generation

  • get_embedding_ranks: specifies which ranks hold the word embeddings

  • get_position_embedding_ranks: specifies which ranks hold the position embeddings
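The enum itself is tiny. A sketch matching the three values listed above (the exact module path may differ between versions; in recent Megatron core it lives under megatron.core.enums):

import enum

class ModelType(enum.Enum):
    # Regular encoder-decoder models, or decoder-only autoregressive models (the GPT case).
    encoder_or_decoder = 1
    # Encoder component of a RETRO model.
    retro_encoder = 2
    # Decoder component of a RETRO model.
    retro_decoder = 3

# pretrain_gpt.py passes ModelType.encoder_or_decoder to pretrain().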

Forward-Step Arguments

  • forward_step_func: the most central training callback; it defines the "forward pass + loss computation" of a single iteration.

    • Its responsibilities are to:

      1. pull a batch from data_iterator

      2. call model(...)

      3. compute the loss

      4. return:

        • the loss (a scalar tensor)

        • a dict of metrics for logging

    • The code in this example is shown below. It largely matches the description above, but the return value differs; presumably the code evolved while the pretrain docstring was not updated. It returns the model output together with a partial function that will compute the loss, and it adds some timing / metric collection plus a dedicated branch for the MoE overlapped-communication case:

def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = False):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
        return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    with stimer:
        if args.use_legacy_models:
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        else:
            if return_schedule_plan:
                # MoE expert-parallel overlapped-communication mode
                assert args.overlap_moe_expert_parallel_comm, \
                    "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
                schedule_plan = model.build_schedule_plan(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )
                return schedule_plan, partial(loss_func, loss_mask, model=model)
            else:
                # Standard forward pass
                output_tensor = model(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model=model)

      • The get_batch used here is also user-defined, shown below; it fetches data according to the parallel groups the rank belongs to.
def get_batch(data_iterator):
    """Generate a batch."""

    # TODO: this is pretty hacky, find a better way
    if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
        not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    ):
        return None, None, None, None, None

    # get batches based on the TP rank you are on
    batch = get_batch_on_this_tp_rank(data_iterator)

    # slice batch along sequence dimension for context parallelism
    batch = get_batch_on_this_cp_rank(batch)

    return batch.values()
  • process_non_loss_data_func: an optional hook for processing data that does not take part in backpropagation, for example dumping some tensors to TensorBoard (a hypothetical sketch follows this list)

  • non_loss_data_func: runs custom logic during the evaluation phase
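As a purely hypothetical sketch of what such a hook could look like (the function body is invented; the signature follows the description in the pretrain docstring, which says the hook receives the collected data, the current iteration index, and the TensorBoard writer):

def process_non_loss_data(collected_data, iteration, tb_writer):
    # collected_data: a list of tensors gathered during evaluation that do not
    # feed into the loss; here we simply log their means as scalars.
    if tb_writer is None:
        return
    for i, tensor in enumerate(collected_data):
        tb_writer.add_scalar(f'non_loss/metric_{i}', tensor.float().mean().item(), iteration)

# It would then be passed as pretrain(..., process_non_loss_data_func=process_non_loss_data).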

Argument and Configuration Extensions

  • extra_args_provider: lets the application code inject its own arguments into Megatron's argparse parser (a sketch follows this list)

  • args_defaults: overrides / pre-sets default values of Megatron arguments
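A minimal sketch of an extra_args_provider; the --my-extra-flag argument is made up for illustration, while the args_defaults value shown in the comment is the one pretrain_gpt.py actually uses:

import argparse

def add_my_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical application-specific arguments injected into Megatron's parser;
    # they become visible to the application through get_args() after parsing.
    group = parser.add_argument_group(title='my application')
    group.add_argument('--my-extra-flag', action='store_true',
                       help='Example flag consumed by application code.')
    return parser

# pretrain(..., extra_args_provider=add_my_args,
#          args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})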

Distributed / Fault-Tolerance Arguments

  • store: an external control-plane handle; as described in the in-process restart section above, the example passes the TCPStore control plane through this argument

  • inprocess_call_wrapper: the call wrapper injected automatically by in-process restart. It catches Python exceptions and CUDA errors and reports them to the inprocess controller, which then decides whether to retry, abort, terminate, and so on. Ordinary users never pass it; it takes effect automatically when inprocess_restart is enabled

Overview of the pretrain Flow

  • The pretrain code is shown below. Its arguments were covered above; here the overall flow is walked through.
def pretrain(
train_valid_test_dataset_provider,
model_provider,
model_type,
forward_step_func,
process_non_loss_data_func=None,
extra_args_provider=None,
args_defaults={},
get_embedding_ranks=None,
get_position_embedding_ranks=None,
non_loss_data_func=None,
store=None,
inprocess_call_wrapper: Optional[CallWrapper] = None,
):
"""Main training program.

This function will run the followings in the order provided:
1) initialize Megatron.
2) setup model, optimizer and lr schedule using the model_provider.
3) call train_val_test_data_provider to get train/val/test datasets.
4) train the model using the forward_step_func.

Args:
train_valid_test_dataset_provider: a function that takes the size of
train/valid/test dataset and returns `train, valid, test` datasets.
model_provider: a function that returns a vanilla version of the
model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
model_type: an enum that specifies the type of model being trained.
forward_step_func: a function that takes a `data iterator` and `model`,
and returns a `loss` scalar with a dictionary with key:values being
the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add
`batch generator` to the timers class.
process_non_loss_data_func: a function to post process outputs of the
network. It can be used for dumping output tensors (e.g images) to
tensorboard. It takes `collected data`(list of tensors),
`current iteration index` and `tensorboard writer` as arguments.
extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It
to set already parse arguments.
get_embedding_ranks (TODO):
get_position_embedding_ranks (TODO):
non_loss_data_func (callable): A custom function to call during evaluation.
It can run e.g. benchmarks.
store: an optional instance of torch.distributed.Store, to be used by
torch.distributed.init_process_group
inprocess_call_wrapper: an optional instance of inprocess.CallWrapper,
it is automatically injected when in-process restart is in use
"""

if inprocess_call_wrapper is not None:
iteration = inprocess_call_wrapper.iteration
store = torch.distributed.PrefixStore(str(iteration), store)

# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron(
extra_args_provider=extra_args_provider,
args_defaults=args_defaults,
get_embedding_ranks=get_embedding_ranks,
get_position_embedding_ranks=get_position_embedding_ranks,
store=store,
)

args = get_args()
timers = get_timers()

if args.log_progress:
append_to_progress_log("Starting job")

# Initialize fault tolerance
# NOTE: ft_integration functions other than `setup` are no-op if the FT is not initialized
if args.enable_ft_package:
ft_integration.setup(args)
ft_integration.maybe_setup_simulated_fault()

# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()

# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
# image ... launches.
global _TRAIN_START_TIME
start_time_tensor = torch.tensor([_TRAIN_START_TIME], dtype=torch.double, device='cuda')
torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()

app_metrics = {}
app_metrics['app_start_time'] = round(_TRAIN_START_TIME * 1000.0)
app_metrics['app_model_init_start_time'] = round(_TRAIN_START_TIME * 1000.0)

print_rank_0(
'time to initialize megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME)
)
print_datetime('after megatron is initialized')
app_metrics['app_model_init_finish_time'] = one_logger_utils.get_timestamp_in_ms()

# Track E2E metrics on pretrain start
one_logger_utils.on_pretrain_start()

# Context used for persisting some state between checkpoint saves.
if args.non_persistent_ckpt_type == 'local':
try:
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
LocalCheckpointManager,
)
from nvidia_resiliency_ext.checkpointing.local.replication.group_utils import (
parse_group_sequence,
GroupWrapper,
)
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
CliqueReplicationStrategy,
)
except ModuleNotFoundError:
raise RuntimeError(
"The 'nvidia_resiliency_ext' module is required for local "
"checkpointing but was not found. Please ensure it is installed."
)

if args.replication:
repl_strategy = CliqueReplicationStrategy.from_replication_params(
args.replication_jump, args.replication_factor
)
else:
repl_strategy = None

checkpointing_context = {
'local_checkpoint_manager': LocalCheckpointManager(
args.non_persistent_local_ckpt_dir, repl_strategy=repl_strategy
)
}
else:
checkpointing_context = {}

# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
model_provider, model_type, checkpointing_context=checkpointing_context
)

timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate ' 'scheduler are built')
config = get_model_config(model[0])

# Data stuff.
app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms()
timers('train/valid/test-data-iterators-setup', log_level=0).start(barrier=True)
if args.virtual_pipeline_model_parallel_size is not None:
train_data_iterator = []
valid_data_iterator = []
test_data_iterator = []
for i in range(len(model)):
iterators = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
train_data_iterator.append(iterators[0])
valid_data_iterator.append(iterators[1])
test_data_iterator.append(iterators[2])
else:
train_data_iterator, valid_data_iterator, test_data_iterator = (
build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
)
timers('train/valid/test-data-iterators-setup').stop()
print_datetime('after dataloaders are built')
app_metrics['app_build_dataiters_finish_time'] = one_logger_utils.get_timestamp_in_ms()

# Track if training is enabled. Can only be done once args.do_train is assigned after dataloader is built.
one_logger_utils.track_config_flags(
args.train_iters,
args.skip_train,
args.do_train,
args.do_valid,
args.do_test,
args.dataloader_type,
args.retro_project_dir,
args.retro_cyclic_train_iters,
)

# Print setup timing.
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'], barrier=True)

one_logger = get_one_logger()
one_logger and one_logger.log_metrics(app_metrics)

if not args.skip_train:
print_rank_0('training ...')

if args.dataloader_type == 'cyclic' and args.retro_project_dir:
assert args.retro_cyclic_train_iters is not None
args.train_iters = args.retro_cyclic_train_iters
print_rank_0("retro cyclic train iters : %d" % args.train_iters)

iteration = 0
if args.do_train and args.train_iters > 0:
iteration, num_floating_point_operations_so_far = train(
forward_step_func,
model,
optimizer,
opt_param_scheduler,
train_data_iterator,
valid_data_iterator,
process_non_loss_data_func,
config,
checkpointing_context,
non_loss_data_func,
)

print_datetime('after training is done')

if args.save and iteration != 0 and iteration % args.save_interval != 0:
save_checkpoint(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
preprocess_common_state_dict_fn=preprocess_common_state_dict,
)

one_logger and one_logger.log_metrics(
{'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()}
)

else:
print_rank_0('skipping training (--skip-train is on) ...')

iteration = args.iteration

if args.do_valid:
prefix = f'iteration {iteration} on validation set'
evaluate_and_print_results(
prefix,
forward_step_func,
valid_data_iterator,
model,
iteration,
process_non_loss_data_func,
config,
verbose=True,
write_to_tensorboard=not args.skip_train,
non_loss_data_func=non_loss_data_func,
)

if args.do_test:
prefix = f'iteration {iteration} on test set'
evaluate_and_print_results(
prefix,
forward_step_func,
test_data_iterator,
model,
iteration,
process_non_loss_data_func,
config,
verbose=True,
write_to_tensorboard=not args.skip_train,
non_loss_data_func=non_loss_data_func,
)

wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()

ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)

one_logger and one_logger.log_metrics(
{'app_finish_time': one_logger_utils.get_timestamp_in_ms()}
)

ft_integration.shutdown()
one_logger_utils.finish()

Initialization

  1. If inprocess_call_wrapper is not None, fault tolerance is active. When pretrain is re-entered after a restart, the store is re-wrapped under the namespace inprocess_call_wrapper.iteration so that the job connects to a fresh control-plane namespace instead of the stale one (see the PrefixStore sketch after this list)

  2. Initialize Megatron-LM's communication groups, parallel settings, key arguments, and so on; covered in detail below

  3. Fetch the global arguments; Megatron-LM follows a singleton design, so global state is always obtained through the get_* accessors

  4. FT (Fault Tolerance) initialization; FT leans more on checkpoints for recovery, whereas inprocess handles fault tolerance of the running processes themselves

  5. Set the PyTorch JIT fusion options for operator fusion and, if necessary, warm up the JIT-compiled functions

  6. All-reduce the training start time with a MIN op to obtain the earliest start time across ranks, so it can be logged

  7. If the args request non-persistent, in-memory checkpoints, import the relevant packages and set up the corresponding checkpointing context

  8. Build model, optimizer, and opt_param_scheduler according to the parallelization strategy; covered in detail below

  9. Build the data iterators; with virtual pipeline parallelism, each pipeline model chunk gets its own data iterator
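For step 1, a minimal standalone sketch of how PrefixStore namespacing behaves; a HashStore stands in for the TCPStore control plane purely for illustration:

import torch.distributed as dist

base = dist.HashStore()  # stand-in for the TCPStore control plane

# Restart attempt 0 and restart attempt 1 see disjoint key namespaces,
# even though they share the same underlying store.
store_iter0 = dist.PrefixStore("0", base)
store_iter1 = dist.PrefixStore("1", base)

store_iter0.set("barrier", "ready")
store_iter1.set("barrier", "ready-after-restart")

print(store_iter0.get("barrier"))  # b'ready'
print(store_iter1.get("barrier"))  # b'ready-after-restart'

# This is why pretrain re-wraps the store with str(inprocess_call_wrapper.iteration):
# each restart attempt gets a clean namespace on the shared control plane.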

Training

  • If the args say to skip training, the train step is skipped and execution moves on; otherwise iteration, num_floating_point_operations_so_far = train(...) runs the training loop (covered in detail below). If the final iteration count is not a multiple of the checkpoint save interval, the final model state is saved explicitly

Wrap-Up

  • If validation is requested, evaluate_and_print_results is called with valid_data_iterator to run one validation pass; covered in detail below

  • If testing is requested, evaluate_and_print_results is called with test_data_iterator to run one test pass

  • Get the wandb writer handle and close it

  • Make sure async checkpoint saving has finished and all IO is flushed

  • Shut down FT and the logger

Dissecting the Core pretrain Steps

initialize_megatron

The initialize_megatron code is as follows:

def initialize_megatron(
extra_args_provider=None,
args_defaults={},
ignore_unknown_args=False,
allow_no_cuda=False,
skip_mpu_initialization=False,
get_embedding_ranks=None,
get_position_embedding_ranks=None,
parsed_args=None,
store=None,
):
"""Set global variables, initialize distributed, and
set autoresume and random seeds.
`allow_no_cuda` should not be set unless using megatron for cpu only
data processing. In general this arg should not be set unless you know
what you are doing.
Returns a function to finalize distributed env initialization
(optionally, only when args.lazy_mpu_init == True)
"""
if not allow_no_cuda:
# Make sure cuda is available.
assert torch.cuda.is_available(), "Megatron requires CUDA."

# Parse arguments
if parsed_args is None:
args = parse_args(extra_args_provider, ignore_unknown_args)
else:
args = parsed_args

# Prep for checkpoint conversion.
if args.ckpt_convert_format is not None:
assert args.ckpt_convert_save is not None
assert args.load is not None
args.exit_on_missing_checkpoint = True

if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
assert args.load is not None, "--use-checkpoint-args requires --load argument"
assert args.non_persistent_ckpt_type != "local", (
"--use-checkpoint-args is not supported with --non_persistent_ckpt_type=local. "
"Two-stage checkpoint loading is not implemented, and all arguments must be defined "
"before initializing LocalCheckpointManager."
)
load_args_from_checkpoint(args)

if args.async_save and args.use_persistent_ckpt_worker:
init_persistent_async_worker()

if args.yaml_cfg is not None:
args = validate_yaml(args, args_defaults)
else:
validate_args(args, args_defaults)

# set global args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables(args)

# set logging level
setup_logging()

# init rerun state
def state_save_func():
return {'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}

def state_restore_func(state_dict):
if state_dict['rng_tracker_states']:
tensor_parallel.get_cuda_rng_tracker().set_states(state_dict['rng_tracker_states'])

args = get_args()
initialize_rerun_state_machine(
state_save_func=state_save_func,
state_restore_func=state_restore_func,
mode=RerunMode(args.rerun_mode),
error_injector=RerunErrorInjector(
error_injection_rate=args.error_injection_rate,
error_injection_type=RerunDiagnostic(args.error_injection_type),
),
result_rejected_tracker_filename=args.result_rejected_tracker_filename,
)

# torch.distributed initialization
def finish_mpu_init():
args = get_args()
# Pytorch distributed.
_initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, store)

# Random seeds for reproducibility.
if args.rank == 0:
print("> setting random seeds to {} ...".format(args.seed))
_set_random_seed(
args.seed,
args.data_parallel_random_init,
args.te_rng_tracker,
args.inference_rng_tracker,
use_cudagraphable_rng=args.enable_cuda_graph or args.external_cuda_graph,
)

# Setup MoE aux loss scale value.
if args.num_experts is not None:
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler

MoEAuxLossAutoScaler.set_loss_scale(torch.ones(1, device=torch.cuda.current_device()))

if skip_mpu_initialization:
return None

args = get_args()
if args.lazy_mpu_init:
# TODO is this still a necessary option?
args.use_cpu_initialization = True
# delayed initialization of DDP-related stuff
# We only set basic DDP globals
mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
# and return function for external DDP manager
# to call when it has DDP initialized
mpu.set_tensor_model_parallel_rank(args.rank)
return finish_mpu_init
else:
# Megatron's MPU is the master. Complete initialization right away.
finish_mpu_init()

# Autoresume.
_init_autoresume()

# Compile dependencies.
_compile_dependencies()

if args.tp_comm_overlap:
# TODO: Should this be activated with just decoder-tp-comm-overlap too?
_initialize_tp_communicators()

# No continuation function
return None

The initialize_megatron flow is:

  1. Check that CUDA is available

  2. Parse the arguments; note that the extra_args_provider passed down from pretrain is used here

  3. Handle checkpoint format conversion and optionally load training arguments from the checkpoint; with async checkpointing, this step also starts the persistent IO worker that saves checkpoints

  4. Validate the arguments and set the global args

  5. Initialize logging

  6. Initialize the fault-tolerant rerun state machine

  7. With lazy_mpu_init, only some model-parallel settings are configured and finish_mpu_init is returned so that an external caller can finish the initialization later (see the sketch after this list)

  8. Without lazy_mpu_init, finish_mpu_init is called right away; then autoresume is set up, dependencies are pre-compiled, and TP communication-overlap initialization is performed.
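For the lazy case in step 7, a rough sketch of how an external DDP manager might drive it. The caller code is hypothetical, and the import path reflects recent versions; the only facts relied on are that initialize_megatron returns the continuation function when args.lazy_mpu_init is set and None otherwise:

from megatron.training.initialize import initialize_megatron

# Hypothetical external driver (assumes lazy_mpu_init is turned on via
# args_defaults or the command line).
finish = initialize_megatron(
    extra_args_provider=None,
    args_defaults={'lazy_mpu_init': True},
)

# ... the external framework sets up its own DDP machinery here ...

if finish is not None:
    # Only non-None when lazy_mpu_init was requested: now finalize
    # torch.distributed initialization, seeding, and the MoE loss-scale setup.
    finish()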

finish_mpu_init

finish_mpu_init is the core of this initialization; its code is:

# torch.distributed initialization
def finish_mpu_init():
    args = get_args()
    # Pytorch distributed.
    _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, store)

    # Random seeds for reproducibility.
    if args.rank == 0:
        print("> setting random seeds to {} ...".format(args.seed))
    _set_random_seed(
        args.seed,
        args.data_parallel_random_init,
        args.te_rng_tracker,
        args.inference_rng_tracker,
        use_cudagraphable_rng=args.enable_cuda_graph or args.external_cuda_graph,
    )

    # Setup MoE aux loss scale value.
    if args.num_experts is not None:
        from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler

        MoEAuxLossAutoScaler.set_loss_scale(torch.ones(1, device=torch.cuda.current_device()))

The finish_mpu_init flow is:

  1. Call _initialize_distributed to set up the communication groups

  2. Set the random seeds; Megatron's RNG scheme is that seeds may differ across DP ranks, while they must stay consistent within TP / PP (see the sketch after this list)

  3. With expert parallelism, the MoE auxiliary-loss scale also has to be set
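For step 2, the per-purpose separation of random states is exposed through a CUDA RNG tracker. A rough sketch of the usage pattern inside model code (it assumes distributed / model-parallel state and the seeds have already been initialized; the dropout call is only an illustration):

import torch
from megatron.core import tensor_parallel

x = torch.randn(4, 8, device='cuda')

# Ops that must behave identically across tensor-parallel ranks run under the
# default CUDA RNG state; ops on partitioned activations (e.g. dropout inside a
# column/row-parallel region) fork into the dedicated "model parallel" RNG state
# kept by the tracker, so each TP rank gets its own dropout pattern.
with tensor_parallel.get_cuda_rng_tracker().fork():
    y = torch.nn.functional.dropout(x, p=0.1, training=True)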

_initialize_distributed

The code of the key _initialize_distributed function is as follows:

def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, store):
    """Initialize torch.distributed and core model parallel."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print(
                "torch distributed is already initialized, " "skipping initialization ...",
                flush=True,
            )
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print("> initializing torch distributed ...", flush=True)
        # Manually set the device ids.
        if device_count > 0:
            torch.cuda.set_device(args.local_rank)
            device_id = torch.device(f'cuda:{args.local_rank}')
        else:
            device_id = None

        # Set to non-default stream for cudagraph capturing.
        if args.external_cuda_graph:
            torch.cuda.set_stream(torch.cuda.Stream())

        # Call the init process
        init_process_group_kwargs = {
            'backend': args.distributed_backend,
            'store': store,
            'world_size': args.world_size,
            'rank': args.rank,
            'timeout': timedelta(minutes=args.distributed_timeout_minutes),
        }

        torch.distributed.init_process_group(**init_process_group_kwargs)
        inprocess_restart.maybe_force_nccl_backend_init(device_id)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print("model parallel is already initialized")
        else:
            mpu.initialize_model_parallel(
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
                pipeline_model_parallel_comm_backend=args.pipeline_model_parallel_comm_backend,
                use_sharp=args.use_sharp,
                context_parallel_size=args.context_parallel_size,
                hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
                expert_model_parallel_size=args.expert_model_parallel_size,
                num_distributed_optimizer_instances=args.num_distributed_optimizer_instances,
                expert_tensor_parallel_size=args.expert_tensor_parallel_size,
                distributed_timeout_minutes=args.distributed_timeout_minutes,
                nccl_communicator_config_path=args.nccl_communicator_config_path,
                order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-cp-ep-pp-dp',
                get_embedding_ranks=get_embedding_ranks,
                get_position_embedding_ranks=get_position_embedding_ranks,
                create_gloo_process_groups=args.enable_gloo_process_groups,
                high_priority_stream_groups=args.high_priority_stream_groups,
                sharp_enabled_group=args.sharp_enabled_group,
            )
            if args.rank == 0:
                print(
                    f"> initialized tensor model parallel with size "
                    f"{mpu.get_tensor_model_parallel_world_size()}"
                )
                print(
                    f"> initialized pipeline model parallel with size "
                    f"{mpu.get_pipeline_model_parallel_world_size()}"
                )

The flow is:

  1. Check whether torch.distributed is already initialized via torch.distributed.is_initialized(); if not, call torch.distributed.init_process_group(**init_process_group_kwargs) to initialize it. Note that the TCPStore passed in from pretrain is used here. To keep the NCCL communicator from becoming invalid across in-process restarts, one NCCL initialization is then forced explicitly.

  2. If the device count is greater than 0, check whether model parallelism is already initialized; if not, call mpu.initialize_model_parallel to initialize it (a simplified sketch of the resulting group layout follows).
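Before reading the full function, it helps to see what it ultimately produces. The sketch below reproduces the group layout from the docstring example further down (16 GPUs, tensor_model_parallel_size=2, pipeline_model_parallel_size=4, no context or expert parallelism, so data_parallel_size = 16 / (2 * 4) = 2) with plain arithmetic instead of Megatron's RankGenerator; with cp = ep = 1 the default 'tp-cp-ep-dp-pp' order reduces to tp-dp-pp:

# 16 GPUs, tp=2, pp=4, cp=1 -> dp = 16 // (2 * 4 * 1) = 2
world_size, tp, pp, cp = 16, 2, 4, 1
dp = world_size // (tp * pp * cp)

# With tp varying fastest, then dp, then pp:
def rank(tp_rank, dp_rank, pp_rank):
    return tp_rank + tp * (dp_rank + dp * pp_rank)

tp_groups = [[rank(t, d, p) for t in range(tp)] for p in range(pp) for d in range(dp)]
dp_groups = [[rank(t, d, p) for d in range(dp)] for p in range(pp) for t in range(tp)]
pp_groups = [[rank(t, d, p) for p in range(pp)] for d in range(dp) for t in range(tp)]

print(tp_groups)  # [[0, 1], [2, 3], [4, 5], ..., [14, 15]]               8 groups
print(dp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], ...]        8 groups
print(pp_groups)  # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]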

The mpu.initialize_model_parallel code is as follows:

# pylint: disable=C0301
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_comm_backend: Optional[str] = None,
use_sharp: bool = False,
context_parallel_size: int = 1,
hierarchical_context_parallel_sizes: Optional[List[int]] = None,
expert_model_parallel_size: int = 1,
num_distributed_optimizer_instances: int = 1,
expert_tensor_parallel_size: Optional[int] = None,
nccl_communicator_config_path: Optional[str] = None,
distributed_timeout_minutes: int = 30,
order: str = "tp-cp-ep-dp-pp",
get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
create_gloo_process_groups: bool = True,
high_priority_stream_groups: Optional[List[str]] = None,
sharp_enabled_group: Optional[str] = None,
) -> None:
"""Initialize model data parallel groups.

Args:
tensor_model_parallel_size (int, default = 1):
The number of GPUs to split individual tensors across.

pipeline_model_parallel_size (int, default = 1):
The number of tensor parallel GPU groups to split the
Transformer layers across. For example, if
tensor_model_parallel_size is 4 and
pipeline_model_parallel_size is 2, the model will be split
into 2 groups of 4 GPUs.

virtual_pipeline_model_parallel_size (int, optional):
The number of stages that each pipeline group will have,
interleaving as necessary. If None, no interleaving is
performed. For example, if tensor_model_parallel_size is 1,
pipeline_model_parallel_size is 4,
virtual_pipeline_model_parallel_size is 2, and there are
16 transformer layers in the model, the model will be
split into 8 stages with two layers each and each GPU
would get 2 stages as such (layer number starting with 1):

GPU 0: [1, 2] [9, 10]
GPU 1: [3, 4] [11, 12]
GPU 2: [5, 6] [13, 14]
GPU 3: [7, 8] [15, 16]

pipeline_model_parallel_comm_backend (str, optional):
The backend to use for pipeline parallel communication.
If None, the default backend will be used.

use_sharp (bool, default = False):
Set the use of SHARP for the collective communications of
data-parallel process groups. When `True`, run barrier
within each data-parallel process group, which specifies
the SHARP application target groups.

context_parallel_size (int, default = 1):
The number of tensor parallel GPU groups to split the
network input sequence length across. Compute of attention
module requires tokens of full sequence length, so GPUs
in a context parallel group need to communicate with each
other to exchange information of other sequence chunks.
Each GPU and its counterparts in other tensor parallel
groups compose a context parallel group.

For example, assume we have 8 GPUs, if tensor model parallel
size is 4 and context parallel size is 2, the network input
will be split into two sequence chunks, which are processed
by 2 different groups of 4 GPUs. One chunk is processed by
GPU0-3, the other chunk is processed by GPU4-7. Four groups
are build to do context parallel communications: [GPU0, GPU4],
[GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].

Context parallelism partitions sequence length, so it has no
impact on weights, which means weights are duplicated among
GPUs in a context parallel group. Hence, weight gradients
all-reduce is required in backward. For simplicity, we piggyback
GPUs of context parallelism on data parallel group for
weight gradient all-reduce.

expert_model_parallel_size (int, default = 1):
The number of Mixture of Experts parallel GPUs in each expert
parallel group.

num_distributed_optimizer_instances (int, default = 1):
The number of distributed optimizer replicas across the data-
parallel domain.

expert_tensor_parallel_size (int, default = tp_size):
The number of GPUs to split individual tensors of expert.

nccl_communicator_config_path (str, default = None):
Path to the yaml file of NCCL communicator configurations.
`min_ctas`, `max_ctas`, and `cga_cluster_size` can be set
for each communicator.

distributed_timeout_minutes (int, default = 30): Timeout, in
minutes,for operations executed against distributed
process groups. See PyTorch documentation at
https://pytorch.org/docs/stable/distributed.html for
caveats.

order (str, default=tp-dp-pp):
The rank initialization order of parallelism. Now we support
tp-dp-pp and tp-pp-dp orders.

get_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None):
A function that takes in a list of ranks for a pipeline group and returns
those ranks that should have embeddings.

get_position_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None):
A function that takes in a list of ranks for a pipeline group, and returns
those ranks that should have position embeddings.

create_gloo_process_groups (bool, default = True):
Create Gloo process groups if set to True. If set to False, Gloo process groups are
not created and calls to get Gloo process groups will result in assertion errors.

high_priority_stream_groups (List[str], default = None):
Specify which communicator groups should use high priority streams during creation.
Assigning high priority to communication streams ensures that communication kernels
are scheduled with higher priority, minimizing the exposed communication when it is
overlapped with other computation kernels.
Example: initialize_parallel_groups(..., high_priority_stream_groups=['dp_cp','ep_dp'])

sharp_enabled_group (str, default = None):
Specify which communicator group should use SHARP communication.
This option is only valid when use_sharp is True.
By default (None), it is enabled from dp group.
Available options (choose one): [dp, dp_replica]

Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# NCCL restricts IB SHARP usage to a single communicator group—the first one created
# with NCCL_COLLNET_ENABLE=1. After this group is created, NCCL_COLLNET_ENABLE must be
# set to 0 for subsequent groups.
if "NCCL_COLLNET_ENABLE" in os.environ:
del os.environ["NCCL_COLLNET_ENABLE"]

if use_sharp:
if sharp_enabled_group is None:
# By default, SHARP is enabled from dp group.
sharp_enabled_group = "dp"
else:
# Currently, only dp and dp_replica groups are supported for SHARP.
assert sharp_enabled_group in ["dp", "dp_replica"], "Invalid sharp_enabled_group"
if sharp_enabled_group == "dp_replica":
assert (
num_distributed_optimizer_instances > 1
), "dp_replica group requires num_distributed_optimizer_instances > 1"
else:
assert (
sharp_enabled_group is None
), "sharp_enabled_group is only valid when use_sharp is True"

if get_embedding_ranks is None:
get_embedding_ranks = default_embedding_ranks

if get_position_embedding_ranks is None:
get_position_embedding_ranks = default_position_embedding_ranks

# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()

model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size

if world_size % model_size != 0:
raise RuntimeError(f"world_size ({world_size}) is not divisible by {model_size}")

data_parallel_size: int = world_size // model_size

if virtual_pipeline_model_parallel_size is not None:
if not pipeline_model_parallel_size > 1:
raise RuntimeError(
"pipeline-model-parallel size should be greater than 1 with interleaved schedule"
)
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size

rank = torch.distributed.get_rank()

nccl_comm_cfgs = {}
if nccl_communicator_config_path is not None:
try:
import yaml
except ImportError:
raise RuntimeError(
"Cannot import `yaml`. Setting custom nccl communicator configs "
"requires the yaml package."
)

with open(nccl_communicator_config_path, "r") as stream:
nccl_comm_cfgs = yaml.safe_load(stream)

# Set is_high_priority_stream flag to the nccl_comm_cfgs if it is in high_priority_stream_groups
high_priority_stream_groups = high_priority_stream_groups or []
for pg_name in high_priority_stream_groups:
overwrite_nccl_comm_cfgs(nccl_comm_cfgs, pg_name, ("is_high_priority_stream", True))

decoder_rank_generator = RankGenerator(
tp=tensor_model_parallel_size,
ep=1,
dp=data_parallel_size,
pp=pipeline_model_parallel_size,
cp=context_parallel_size,
order=order,
rank_offset=0,
)

# Build expert rank generator
if expert_tensor_parallel_size is None:
expert_tensor_parallel_size = tensor_model_parallel_size
expert_tensor_model_pipeline_parallel_size = (
expert_tensor_parallel_size * expert_model_parallel_size * pipeline_model_parallel_size
)
expert_data_parallel_size = world_size // expert_tensor_model_pipeline_parallel_size
if world_size % expert_tensor_model_pipeline_parallel_size != 0:
raise RuntimeError(
f"world_size ({world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})"
)

# TODO: support expert specific ordering
expert_decoder_rank_generator = RankGenerator(
tp=expert_tensor_parallel_size,
ep=expert_model_parallel_size,
dp=expert_data_parallel_size,
pp=pipeline_model_parallel_size,
cp=1,
order=order,
rank_offset=0,
)

assert (
order.endswith("pp")
or pipeline_model_parallel_size == 1
or expert_data_parallel_size == data_parallel_size
), "When not using pp-last rank ordering, the data parallel size of the attention and moe layers must be the same"

assert decoder_rank_generator.get_ranks("pp") == expert_decoder_rank_generator.get_ranks(
"pp"
), f"Pipeline parallel groups are expected to be the same for Non-Expert and Expert part, \
but got {decoder_rank_generator.get_ranks('pp')} and {expert_decoder_rank_generator.get_ranks('pp')}"

timeout = timedelta(minutes=distributed_timeout_minutes)

# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GROUP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS
global _DATA_PARALLEL_GROUP_WITH_CP
global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
global _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP
global _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"

assert (
data_parallel_size * context_parallel_size
) % num_distributed_optimizer_instances == 0, (
"Data parallel size should be divisible by partial DistOpt shard factor"
)
intra_partial_data_parallel_size = (
data_parallel_size * context_parallel_size
) // num_distributed_optimizer_instances

# Set NCCL_COLLNET_ENABLE to 1 to enable SHARP for the dp group.
if sharp_enabled_group == "dp":
os.environ["NCCL_COLLNET_ENABLE"] = "1"

# In case of using SHARP, the dp-cp group requires to use NCCL COLLNET feature.
# Due to the hardware limitation, only the initially created communication group
# is eligible for using the NCCL COLLNET feature.
# Therefore, dp-cp group, which potentially requires SHARP-enablement,
# need to be created before all the other groups
for ranks_with_cp in decoder_rank_generator.get_ranks('dp-cp'):
group_with_cp = create_group(
ranks_with_cp,
timeout=timeout,
pg_options=get_nccl_options("dp_cp", nccl_comm_cfgs),
group_desc="DATA_PARALLEL_GROUP_WITH_CP",
)
if create_gloo_process_groups:
group_with_cp_gloo = create_group(
ranks_with_cp,
timeout=timeout,
backend="gloo",
group_desc="DATA_PARALLEL_GROUP_WITH_CP_GLOO",
)
else:
group_with_cp_gloo = None
if rank in ranks_with_cp:
_DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp

if num_distributed_optimizer_instances > 1:
# Create groups for intra-partial DP domain
for i in range(num_distributed_optimizer_instances):
intra_partial_dp_ranks_with_cp = ranks_with_cp[
(i * intra_partial_data_parallel_size) : (
(i + 1) * intra_partial_data_parallel_size
)
]
intra_partial_dp_group_with_cp = create_group(
intra_partial_dp_ranks_with_cp,
timeout=timeout,
pg_options=get_nccl_options("intra_dp_cp", nccl_comm_cfgs),
group_desc="INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP",
)
if create_gloo_process_groups:
intra_partial_dp_group_with_cp_gloo = create_group(
intra_partial_dp_ranks_with_cp,
timeout=timeout,
backend="gloo",
group_desc="INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO",
)
else:
intra_partial_dp_group_with_cp_gloo = None
if rank in intra_partial_dp_ranks_with_cp:
_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = intra_partial_dp_group_with_cp
_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO = (
intra_partial_dp_group_with_cp_gloo
)
else:
_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = _DATA_PARALLEL_GROUP_WITH_CP
_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO = _DATA_PARALLEL_GROUP_WITH_CP_GLOO

# Apply SHARP to the dp group.
if sharp_enabled_group == "dp":
if rank == 0:
print(
"The number of process groups to use SHARP with depends on the type "
"of the network switch. Nvidia QM1 switch supports SAHRP up to 8 "
"process groups and QM2 supports up to 256 process groups. We apply "
"SHARP to the communications of the data-parallel domain. If the "
"number of data-parallel process groups is larger than the max "
"process groups that the network switch supports, the communication "
"will fall back to non-SHARP operators. To enable SHARP, "
"`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
)
# PyTorch is performing lazy initialization of the communicator group.
# Therefore, we need to perform a nccl call to ensure that the communicator group is created.
torch.distributed.barrier(
group=get_data_parallel_group(with_context_parallel=True),
device_ids=[torch.cuda.current_device()],
)
torch.cuda.synchronize()
# Set `NCCL_COLLNET_ENABLE=0` to restrict SHARP application to the dp group.
if "NCCL_COLLNET_ENABLE" in os.environ:
del os.environ["NCCL_COLLNET_ENABLE"]

for ranks in decoder_rank_generator.get_ranks('dp'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("dp", nccl_comm_cfgs),
group_desc="DATA_PARALLEL_GROUP",
)
if create_gloo_process_groups:
group_gloo = create_group(
ranks, timeout=timeout, backend="gloo", group_desc="DATA_PARALLEL_GROUP_GLOO"
)
else:
group_gloo = None
if rank in ranks:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_GLOO = group_gloo
_DATA_PARALLEL_GLOBAL_RANKS = ranks

# Build the context-parallel groups.
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_GLOBAL_RANKS
assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
for ranks in decoder_rank_generator.get_ranks('cp'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("cp", nccl_comm_cfgs),
group_desc="CONTEXT_PARALLEL_GROUP",
)
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
_CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
if hierarchical_context_parallel_sizes:
assert np.prod(hierarchical_context_parallel_sizes) == context_parallel_size
global _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS
hierarchical_groups, _ = create_hierarchical_groups(
rank,
ranks,
hierarchical_context_parallel_sizes,
create_gloo_process_groups=False,
pg_options=get_nccl_options("hcp", nccl_comm_cfgs),
timeout=timeout,
group_desc="CONTEXT_PARALLEL_GROUP",
)
if rank in ranks:
_HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = hierarchical_groups

# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
global _MODEL_PARALLEL_GLOBAL_RANKS
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
for ranks in decoder_rank_generator.get_ranks('tp-pp'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("mp", nccl_comm_cfgs),
group_desc="MODEL_PARALLEL_GROUP",
)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
_MODEL_PARALLEL_GLOBAL_RANKS = ranks

# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
assert (
_TENSOR_MODEL_PARALLEL_GROUP is None
), 'tensor model parallel group is already initialized'
for ranks in decoder_rank_generator.get_ranks('tp'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("tp", nccl_comm_cfgs),
group_desc="TENSOR_MODEL_PARALLEL_GROUP",
)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks

# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert (
_PIPELINE_MODEL_PARALLEL_GROUP is None
), "pipeline model parallel group is already initialized"
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
global _POSITION_EMBEDDING_GROUP
global _POSITION_EMBEDDING_GLOBAL_RANKS
assert _POSITION_EMBEDDING_GROUP is None, "position embedding group is already initialized"
if pipeline_model_parallel_comm_backend == "ucc":
# The UCC backend provides two key benefits:
# 1) Achieves better bandwidth utilization than NCCL when using InfiniBand links.
# 2) Does not use GPU SM resources (Zero-SM), mitigating performance interference
# with overlapping compute kernels.

# The UCC backend is recommended in the following cases:
# 1) When the exposed pipeline-parallel (PP) communications are significant.
        #    - E.g., Pipeline parallelism with very few gradient accumulation steps.
# - It may provide better performance due to improved bandwidth utilization.
# 2) When the critical-path pipeline stage has substantial PP-communication overlap.
# - E.g., Uneven pipeline parallelism.
# - It may provide better performance due to zero SM resource usage.
if "CUDA_DEVICE_MAX_CONNECTIONS" in os.environ:
# UCC backend requires CUDA_DEVICE_MAX_CONNECTIONS variable to be larger than 1,
            # to guarantee the overlapped UCC communications. If this environment variable is set to 1,
# all the UCC communication will be serialized.
assert (
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] != "1"
), "UCC-backend requires CUDA_DEVICE_MAX_CONNECTIONS > 1"

# Setting up required environment variables for ucc backend
#
# "TORCH_UCC_BLOCKING_WAIT=none" allows non-blocking waits of the communiction handle
# "UCC_EC_CUDA_STREAM_TASK_MODE" controls how CUDA execution engines (EC)
# schedule tasks on CUDA streams.
# "UCX_TLS" controls transport layer selection
# "NSYS_UCP_COMM_PARAMS=1" enables capturing ucx tracing in nsys profiling
# "UCX_RNDV_THRESH" controls threshold threshold for switching between
# eager and rendezvous (RNDV) communication protocols.
# "UCX_NET_DEVICES" select which network interfaces UCX should use.
# "UCC_CL_BASIC_TLS" controls which Transport Layers are used by
# the Basic Collective libraray

os.environ["TORCH_UCC_BLOCKING_WAIT"] = (
os.environ["TORCH_UCC_BLOCKING_WAIT"]
if "TORCH_UCC_BLOCKING_WAIT" in os.environ
else "none"
)
os.environ["UCC_EC_CUDA_STREAM_TASK_MODE"] = (
os.environ["UCC_EC_CUDA_STREAM_TASK_MODE"]
if "UCC_EC_CUDA_STREAM_TASK_MODE" in os.environ
else "driver"
)
os.environ["UCX_TLS"] = (
os.environ["UCX_TLS"] if "UCX_TLS" in os.environ else "ib,cuda_copy"
) # cuda_ipc (i.e., NVLink-enablement) will be later supported
os.environ["NSYS_UCP_COMM_PARAMS"] = "1"
os.environ["UCX_RNDV_THRESH"] = "0"
os.environ["UCX_NET_DEVICES"] = "all"
os.environ["UCC_CL_BASIC_TLS"] = "^sharp,nccl"

for ranks in decoder_rank_generator.get_ranks('pp'):
group = create_group(
ranks,
timeout=timeout,
backend=pipeline_model_parallel_comm_backend,
pg_options=(
None
if pipeline_model_parallel_comm_backend == "ucc"
else get_nccl_options("pp", nccl_comm_cfgs)
),
group_desc="PIPELINE_MODEL_PARALLEL_GROUP",
)
assert (
pipeline_model_parallel_comm_backend == None
or pipeline_model_parallel_comm_backend == "nccl"
or pipeline_model_parallel_comm_backend == "ucc"
), f'"{pipeline_model_parallel_comm_backend}" backend for PP communication is currently not supported'

if rank in ranks:
if _PIPELINE_MODEL_PARALLEL_GROUP is None:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
elif isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
_PIPELINE_MODEL_PARALLEL_GROUP.append(group)
_PIPELINE_GLOBAL_RANKS.append(ranks)
else:
_PIPELINE_MODEL_PARALLEL_GROUP = [_PIPELINE_MODEL_PARALLEL_GROUP, group]
_PIPELINE_GLOBAL_RANKS = [_PIPELINE_GLOBAL_RANKS, ranks]

embedding_ranks = get_embedding_ranks(ranks)
group = create_group(
embedding_ranks,
timeout=timeout,
pg_options=get_nccl_options("embd", nccl_comm_cfgs),
group_desc="EMBEDDING_GROUP",
)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
_EMBEDDING_GLOBAL_RANKS = embedding_ranks

position_embedding_ranks = get_position_embedding_ranks(ranks)
group = create_group(
position_embedding_ranks,
timeout=timeout,
pg_options=get_nccl_options("pos_embd", nccl_comm_cfgs),
group_desc="POSITION_EMBEDDING_GROUP",
)
if rank in position_embedding_ranks:
_POSITION_EMBEDDING_GROUP = group
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

# Build the tensor + data parallel groups.
global _TENSOR_AND_DATA_PARALLEL_GROUP
global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is None
), 'Tensor + data parallel group is already initialized'
for ranks in decoder_rank_generator.get_ranks('tp-dp-cp'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("tp_dp_cp", nccl_comm_cfgs),
group_desc="TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP",
)
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
for ranks in decoder_rank_generator.get_ranks('tp-dp'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("tp_dp", nccl_comm_cfgs),
group_desc="TENSOR_AND_DATA_PARALLEL_GROUP",
)
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP = group

global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
assert (
_TENSOR_AND_CONTEXT_PARALLEL_GROUP is None
), 'Tensor + context parallel group is already initialized'
for ranks in decoder_rank_generator.get_ranks('tp-cp'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("tp_cp", nccl_comm_cfgs),
group_desc="TENSOR_AND_CONTEXT_PARALLEL_GROUP",
)
if rank in ranks:
_TENSOR_AND_CONTEXT_PARALLEL_GROUP = group

### Expert-related parallel groups initialization
# Build the expert model parallel group
global _EXPERT_MODEL_PARALLEL_GROUP
assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized'
for ranks in expert_decoder_rank_generator.get_ranks('ep'):
group = create_group(
ranks,
pg_options=get_nccl_options("ep", nccl_comm_cfgs),
group_desc="EXPERT_MODEL_PARALLEL_GROUP",
)
if rank in ranks:
_EXPERT_MODEL_PARALLEL_GROUP = group

# Build the expert tensor parallel group
global _EXPERT_TENSOR_PARALLEL_GROUP
assert (
_EXPERT_TENSOR_PARALLEL_GROUP is None
), 'Expert tensor model parallel group is already initialized'
for ranks in expert_decoder_rank_generator.get_ranks('tp'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("ep_tp", nccl_comm_cfgs),
group_desc="EXPERT_TENSOR_PARALLEL_GROUP",
)
if rank in ranks:
_EXPERT_TENSOR_PARALLEL_GROUP = group

# Build the tensor + expert parallel groups
global _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP
assert (
_EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP is None
), 'Expert tensor + model parallel group is already initialized'
for ranks in expert_decoder_rank_generator.get_ranks('tp-ep'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("tp_ep_mp", nccl_comm_cfgs),
group_desc="EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP",
)
if rank in ranks:
_EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = group

# Build the expert+tensor+pipeline parallel groups
global _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP
assert (
_EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP is None
), 'The expert_tensor_model_pipeline parallel group is already initialized'
for ranks in expert_decoder_rank_generator.get_ranks('tp-ep-pp'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("tp_ep_pp", nccl_comm_cfgs),
group_desc="EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP",
)
if rank in ranks:
_EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = group

# Build the expert data parallel group
global _EXPERT_DATA_PARALLEL_GROUP
assert _EXPERT_DATA_PARALLEL_GROUP is None, "Expert data group is already initialized"
global _EXPERT_DATA_PARALLEL_GROUP_GLOO
assert _EXPERT_DATA_PARALLEL_GROUP_GLOO is None, "Expert data group-gloo is already initialized"
global _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP
assert (
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP is None
), "Intra partial expert data group is already initialized"
global _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO
assert (
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO is None
), "Intra partial expert data group-gloo is already initialized"
global _INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP
assert (
_INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP is None
), "Inter partial expert data group is already initialized"

assert (
expert_data_parallel_size % num_distributed_optimizer_instances == 0
), "Expert data parallel size should be divisible by partial DistOpt shard factor"
intra_partial_expert_data_parallel_size = (
expert_data_parallel_size // num_distributed_optimizer_instances
)

for ranks in expert_decoder_rank_generator.get_ranks('dp'):
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options("ep_dp", nccl_comm_cfgs),
group_desc="EXPERT_DATA_PARALLEL_GROUP",
)
if create_gloo_process_groups:
group_gloo = create_group(
ranks, backend="gloo", group_desc="EXPERT_DATA_PARALLEL_GROUP_GLOO"
)
else:
group_gloo = None
if rank in ranks:
_EXPERT_DATA_PARALLEL_GROUP = group
_EXPERT_DATA_PARALLEL_GROUP_GLOO = group_gloo

if num_distributed_optimizer_instances > 1:
# Create groups for Partial DistOpt, one for intra-partial DP domain
# Another for inter-partial DP domain

# Set NCCL_COLLNET_ENABLE to 1 to enable SHARP for the dp_replica group.
if sharp_enabled_group == "dp_replica":
os.environ["NCCL_COLLNET_ENABLE"] = "1"
hierarchical_groups, hierarchical_groups_gloo = create_hierarchical_groups(
rank,
ranks,
[intra_partial_expert_data_parallel_size, num_distributed_optimizer_instances],
create_gloo_process_groups=create_gloo_process_groups,
pg_options=[
get_nccl_options("intra_ep_dp", nccl_comm_cfgs),
get_nccl_options("inter_ep_dp", nccl_comm_cfgs),
],
timeout=timeout,
group_desc="EXPERT_DATA_PARALLEL_GROUP",
)
if rank in ranks:
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = hierarchical_groups[0]
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO = hierarchical_groups_gloo[0]
_INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = hierarchical_groups[1]

if sharp_enabled_group == "dp_replica":
# PyTorch is performing lazy initialization of the communicator group.
# Therefore, we need to perform a nccl call to ensure that the communicator group is created.
if _INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP is not None:
torch.distributed.barrier(
group=_INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP,
device_ids=[torch.cuda.current_device()],
)
torch.cuda.synchronize()
# Set NCCL_COLLNET_ENABLE to 0 to restrict SHARP application to the dp_replica group.
if "NCCL_COLLNET_ENABLE" in os.environ:
del os.environ["NCCL_COLLNET_ENABLE"]
else:
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = _EXPERT_DATA_PARALLEL_GROUP
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO = _EXPERT_DATA_PARALLEL_GROUP_GLOO
### End of expert related parallel groups initialization

# build the intra distributed optimizer instance group
global _INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP
assert (
_INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP is None
), "Intra distributed optimizer instance group is already initialized"

model_parallel_group_id = 0
intra_dist_opt_ranks = []
for ranks in expert_decoder_rank_generator.get_ranks('tp-ep-pp'):
model_parallel_group_id += 1
intra_dist_opt_ranks.extend(ranks)
if model_parallel_group_id % intra_partial_expert_data_parallel_size == 0:
intra_dist_opt_instance_group = create_group(
intra_dist_opt_ranks,
timeout=timeout,
pg_options=get_nccl_options("intra_dist_opt_instance", nccl_comm_cfgs),
group_desc="INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP",
)
if rank in intra_dist_opt_ranks:
_INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP = intra_dist_opt_instance_group
intra_dist_opt_ranks = []

# Initialize global memory buffer
# This isn't really "parallel state" but there isn't another good place to
# put this. If we end up with a more generic initialization of megatron-core
# we could stick it there
_set_global_memory_buffer()

  • The core purpose of mpu.initialize_model_parallel is to create the various parallel communication groups according to the configured parallelism strategy, i.e. to map each worker's rank onto its parallel groups. This covers TP, PP, DP, Context Parallel (CP) and Expert Parallel (EP).

  • It first builds a RankGenerator, which is the core of the rank-to-parallel-group mapping.

    • The relevant RankGenerator code is as follows:
    class RankGenerator(object):
    """A class for generating rank groups for different modes of parallelism."""

    def __init__(
    self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0
    ) -> None:
    assert (
    ep == 1 or cp == 1
    ), "Both EP and CP > 1 in not allow in one rank generator. \
    CP is only included in default RankGenerator, and EP only in expert RankGenerator."

    self.tp = tp
    self.ep = ep
    self.dp = dp
    self.pp = pp
    self.cp = cp
    self.rank_offset = rank_offset
    self.world_size = tp * dp * pp * cp * ep

    self.name_to_size = {
    "tp": self.tp,
    "pp": self.pp,
    "dp": self.dp,
    "ep": self.ep,
    "cp": self.cp,
    }
    self.order = order
    order = order.lower()

    for name in self.name_to_size.keys():
    if name not in order and self.name_to_size[name] != 1:
    raise RuntimeError(
    f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't"
    f"specified the order ({self.order})."
    )
    elif name not in order:
    order = order + "-" + name

    self.order = order
    self.ordered_size = []

    for token in order.split("-"):
    self.ordered_size.append(self.name_to_size[token])

    def get_mask(self, order: str, token: str):
    """Create a mask for the specified tokens based on the given order.

    Args:
    order (str): The order of parallelism types (e.g., 'tp-dp-pp').
    token (str): The specific parallelism types to include in the mask,
    separated by hyphens (e.g., 'tp-dp').
    """
    ordered_token = order.split("-")
    token_list = token.split("-")
    mask = [False] * len(ordered_token)
    for t in token_list:
    mask[ordered_token.index(t)] = True
    return mask

    def get_ranks(self, token):
    """Get rank group by input token.

    Args:
    token (str):
    Specify the ranks type that want to get. If we want
    to obtain multiple parallel types, we can use a hyphen
    '-' to separate them. For example, if we want to obtain
    the TP_DP group, the token should be 'tp-dp'.
    """
    mask = self.get_mask(self.order, token)
    ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask)
    if self.rank_offset > 0:
    for rank_group in ranks:
    for i in range(len(rank_group)):
    rank_group[i] += self.rank_offset
    return ranks

    def generate_masked_orthogonal_rank_groups(
    world_size: int, parallel_size: List[int], mask: List[bool]
    ) -> List[List[int]]:
    r"""Generate orthogonal parallel groups based on the parallel size and mask.

    Arguments:
    world_size (int): world size

    parallel_size (List[int]):
    The parallel size of each orthogonal parallel type. For example, if
    tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4,
    and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4].

    mask (List[bool]):
    The mask controls which parallel methods the generated groups represent. If mask[i] is
    True, it means the generated group contains the i-th parallelism method. For example,
    if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then
    the generated group is the `tp-dp` group, if the mask = [False, True, False], then the
    generated group is the `pp` group.

    Algorithm:
    For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and
    local_rank satisfy the following equation:
    global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1)
    tp_rank \in [0, tp_size)
    dp_rank \in [0, dp_size)
    pp_rank \in [0, pp_size)

    If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each.
    For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the
    dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].)
    The tp_rank and pp_rank will be combined to form the `dp_group_index`.
    dp_group_index = tp_rank + pp_rank * tp_size (2)

    So, Given that tp_rank and pp_rank satisfy equation (2), and dp_rank in
    range(0, dp_size), the ranks in dp_group[dp_group_index] satisfies the
    equation (1).

    This function solve this math problem.

    For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4],
    and the mask = [False, True, False]. Then,
    dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
    dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
    ...
    dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2

    dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
    dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
    ...
    dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
    """

    def prefix_product(a: List[int], init=1) -> List[int]:
    r = [init]
    for v in a:
    init = init * v
    r.append(init)
    return r

    def inner_product(a: List[int], b: List[int]) -> int:
    return sum([x * y for x, y in zip(a, b)])

    def decompose(index, shape, stride=None):
    """
    This function solve the math problem below:
    There is an equation:
    index = sum(idx[i] * stride[i])
    And given the value of index, stride.
    Return the idx.
    This function will be used to get the pp/dp/pp_rank
    from group_index and rank_in_group.
    """
    if stride is None:
    stride = prefix_product(shape)
    idx = [(index // d) % s for s, d in zip(shape, stride)]
    # stride is a prefix_product result. And the value of stride[-1]
    # is not used.
    assert (
    sum([x * y for x, y in zip(idx, stride[:-1])]) == index
    ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)
    return idx

    masked_shape = [s for s, m in zip(parallel_size, mask) if m]
    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]

    global_stride = prefix_product(parallel_size)
    masked_stride = [d for d, m in zip(global_stride, mask) if m]
    unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]

    group_size = prefix_product(masked_shape)[-1]
    num_of_group = world_size // group_size

    ranks = []
    for group_index in range(num_of_group):
    # get indices from unmaksed for group_index.
    decomposed_group_idx = decompose(group_index, unmasked_shape)
    rank = []
    for rank_in_group in range(group_size):
    # get indices from masked for rank_in_group.
    decomposed_rank_idx = decompose(rank_in_group, masked_shape)
    rank.append(
    inner_product(decomposed_rank_idx, masked_stride)
    + inner_product(decomposed_group_idx, unmasked_stride)
    )
    ranks.append(rank)
    return ranks

    • RankGenerator first needs the parallel size of each parallelism method, plus an order string for rank enumeration, such as "tp-dp-pp", meaning tp is counted first, then dp, then pp. During initialization, RankGenerator also completes and parses this order (dimensions of size 1 that are missing from the order are appended automatically).

    • RankGenerator provides two main methods:

      • get_mask returns a boolean mask derived from order and token. For example, with order 'tp-dp-pp' and token 'tp-dp' it returns [True, False, True].

      • get_ranks returns the rank groups corresponding to a token. For example, with order 'tp-pp-dp' and tp_size=2, pp_size=2, dp_size=2, the global rank is tp_rank + pp_rank*tp_size + dp_rank*tp_size*pp_size. With token 'dp' we want, for each fixed (tp_rank, pp_rank), the set of ranks that differ only in dp_rank, i.e. tp_rank + pp_rank*tp_size + range(dp_size)*tp_size*pp_size with range(dp_size)={0,1}; in other words, these are the ranks that must communicate within DP to share gradients for the same model-parameter shard. In this example the result is [[0,4],[1,5],[2,6],[3,7]]; ranks in the same inner list belong to the same dp_group (shown with matching colors in the original post's figure).

      • Likewise, if the order is 'tp-dp-pp' with tp_size=2, dp_size=3, pp_size=4, the global rank is tp_rank + dp_rank*tp_size + pp_rank*tp_size*dp_size. With token 'dp', each dp group is tp_rank + range(dp_size)*tp_size + pp_rank*tp_size*dp_size, i.e. [[0,2,4],[1,3,5]…[19,21,23]]. A small runnable sketch reproducing both groupings is given after this list.

  • With the RankGenerator, all communication groups can then be created. The creation flow basically follows the pattern below: take the groups for a given parallel strategy, iterate over them, create a torch.distributed.new_group (via create_group) for each group, and if this process's rank belongs to the group, store the group in the corresponding global variable. Note that new_group is called for every group even though this process may not keep the result: all workers must execute the same sequence of new_group calls so that distributed communicator creation stays consistent and deadlocks are avoided.

    for ranks in rank_generator.get_ranks('xxx'):
        group = create_group(ranks, backend=..., pg_options=...)
        if rank in ranks:
            GLOBAL_GROUP = group
            GLOBAL_RANKS = ranks
    • Mapping of the non-expert (decoder / attention) parallel groups

    • Mapping of the expert (MoE) related parallel groups

    • Mapping of the distributed-optimizer related groups
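
To make the two get_ranks examples concrete, here is a small standalone sketch. It is plain Python, not the Megatron API: get_groups, prefix_product and decompose below are illustrative re-implementations of the masked orthogonal grouping idea from generate_masked_orthogonal_rank_groups, used only to reproduce the dp groups for both orderings.

# Standalone sketch (not the Megatron API): get_groups and its helpers are
# made-up names re-implementing the masked orthogonal grouping shown above,
# only to verify the two dp-group examples.
from typing import List

def prefix_product(a: List[int], init: int = 1) -> List[int]:
    r = [init]
    for v in a:
        init *= v
        r.append(init)
    return r

def get_groups(sizes: List[int], mask: List[bool]) -> List[List[int]]:
    world_size = 1
    for s in sizes:
        world_size *= s
    stride = prefix_product(sizes)
    # keep (size, global_stride) pairs for masked / unmasked dimensions
    masked = [(s, d) for s, d, m in zip(sizes, stride, mask) if m]
    unmasked = [(s, d) for s, d, m in zip(sizes, stride, mask) if not m]

    def decompose(index, dims):
        # unravel `index` into per-dimension indices over the sizes in `dims`
        local_stride = prefix_product([s for s, _ in dims])
        return [(index // d) % s for (s, _), d in zip(dims, local_stride)]

    group_size = 1
    for s, _ in masked:
        group_size *= s

    groups = []
    for group_index in range(world_size // group_size):
        group_idx = decompose(group_index, unmasked)
        ranks = []
        for rank_in_group in range(group_size):
            rank_idx = decompose(rank_in_group, masked)
            ranks.append(
                sum(i * d for i, (_, d) in zip(rank_idx, masked))
                + sum(i * d for i, (_, d) in zip(group_idx, unmasked))
            )
        groups.append(ranks)
    return groups

# order 'tp-pp-dp', tp=2, pp=2, dp=2, token 'dp' -> mask [False, False, True]
print(get_groups([2, 2, 2], [False, False, True]))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
# order 'tp-dp-pp', tp=2, dp=3, pp=4, token 'dp' -> mask [False, True, False]
print(get_groups([2, 3, 4], [False, True, False]))  # [[0, 2, 4], [1, 3, 5], ..., [19, 21, 23]]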

setup_model_and_optimizer

The code of setup_model_and_optimizer is as follows:

def setup_model_and_optimizer(
model_provider_func,
model_type,
no_wd_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0,
checkpointing_context=None,
):
"""Setup model and optimizer."""
args = get_args()
timers = get_timers()
one_logger = get_one_logger()

model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model)

one_logger and one_logger.log_metrics({"app_build_optimzer_start_time": one_logger_utils.get_timestamp_in_ms()})
kwargs = {}
for f in dataclasses.fields(OptimizerConfig):
if hasattr(args, f.name):
kwargs[f.name] = getattr(args, f.name)
config = OptimizerConfig(**kwargs)
config.timers = timers
optimizer = get_megatron_optimizer(
config,
model,
no_wd_decay_cond,
scale_lr_cond,
lr_mult,
use_gloo_process_groups=args.enable_gloo_process_groups,
# If the user is asking for a non-zero embedding init std, skip weight decay for embeddings
# to avoid embeddings from shrinking to zero as recommended in https://arxiv.org/abs/2312.16903
default_skip_embedding_weight_decay=args.embedding_init_method_std is not None,
)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
one_logger and one_logger.log_metrics({"app_build_optimzer_finish_time": one_logger_utils.get_timestamp_in_ms()})

if args.moe_use_upcycling:
torch.distributed.barrier()
assert not checkpoint_exists(args.save), (
"The upcycling destination directory already exists. "
"Please check if --moe-use-upcycling is mistakenly enabled. "
"Upcycling should only be set for the first run when converting the dense model. "
"All subsequent runs should remove this flag. "
)
# before changing moe related global args, save them in local variables
num_experts = args.num_experts
expert_model_parallel_size = args.expert_model_parallel_size
moe_ffn_hidden_size = args.ffn_hidden_size

# set dense model related args in to global args before getting dense model
args.num_experts = None
args.expert_model_parallel_size = 1
args.ffn_hidden_size = moe_ffn_hidden_size * args.moe_upcycling_granularity

# get dense model
dense_model_for_upcycling = get_model(model_provider_func, model_type)

# recover moe upcycling related args in global args before executing upcycling
args.num_experts = num_experts
args.expert_model_parallel_size = expert_model_parallel_size
args.ffn_hidden_size = moe_ffn_hidden_size

# execute upcycling
_, args.num_floating_point_operations_so_far = upcycling_utils.load_and_upcycle_model(
load_checkpoint,
unwrapped_model,
dense_model_for_upcycling,
load_kwargs={
'model': dense_model_for_upcycling,
'optimizer': None,
'opt_param_scheduler': None,
},
)
args.iteration = 1
save_checkpoint(
args.iteration, model, None, None, args.num_floating_point_operations_so_far
)
torch.distributed.barrier()
del dense_model_for_upcycling
if (args.fp16 or args.bf16) and optimizer is not None:
optimizer.reload_model_params()
print_rank_0(f'Upcycled checkpoint saved to {args.save}')

if (
args.load is not None or args.pretrained_checkpoint is not None
) and not args.moe_use_upcycling:
one_logger and one_logger.log_metrics(
{'load_checkpoint_start_time': one_logger_utils.get_timestamp_in_ms()}
)
timers('load-checkpoint', log_level=0).start(barrier=True)

args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
model,
optimizer,
opt_param_scheduler,
checkpointing_context=checkpointing_context,
skip_load_to_model_and_opt=HAVE_FSDP2
and getattr(args, "use_torch_fsdp2", False)
and args.ckpt_format == "torch_dist",
)
timers('load-checkpoint').stop(barrier=True)
timers.log(['load-checkpoint'])
one_logger and one_logger.log_metrics(
{
'load_checkpoint_finish_time': one_logger_utils.get_timestamp_in_ms(),
'load_checkpoint_time': timers('load-checkpoint').active_time(),
}
)
else:
args.iteration = 0
args.num_floating_point_operations_so_far = 0

# get model without FP16 and/or DDP wrappers
if (
args.iteration == 0
and len(unwrapped_model) == 1
and hasattr(unwrapped_model[0], 'init_state_dict_from_bert')
):
print_rank_0("Initializing ICT from pretrained BERT model")
unwrapped_model[0].init_state_dict_from_bert()
if args.fp16:
optimizer.reload_model_params()

# Convert checkpoint format.
if args.ckpt_convert_format is not None:
load_ckpt_format = args.ckpt_format
args.ckpt_format = args.ckpt_convert_format
args.save = os.path.join(args.ckpt_convert_save, args.ckpt_convert_format)
update_use_dist_ckpt(args)

save_checkpoint(
args.iteration,
model,
optimizer,
opt_param_scheduler,
args.num_floating_point_operations_so_far,
preprocess_common_state_dict_fn=preprocess_common_state_dict,
)

print_rank_0("> converted checkpoint: %s -> %s." % (load_ckpt_format, args.ckpt_format))
torch.distributed.barrier()
exit()

return model, optimizer, opt_param_scheduler

Its overall flow is as follows:

  1. Obtain this worker's model via get_model.

  2. Obtain the raw model underneath the DDP/FP16 wrappers via unwrap_model.

  3. Build the optimizer via get_megatron_optimizer; the OptimizerConfig is filled in from args (see the sketch after this list).

  4. Build the optimizer's learning-rate/weight-decay scheduler via get_optimizer_param_scheduler.

  5. If args.moe_use_upcycling is set, perform MoE upcycling: a dense FFN model is built and loaded from its checkpoint, upcycled into an MoE model, saved as a new checkpoint, and the optimizer's model params are reloaded for fp16/bf16.

  6. If args.moe_use_upcycling is not set and a checkpoint (args.load or args.pretrained_checkpoint) is configured, load the model, optimizer, etc. from the checkpoint.

  7. Handle the special iteration == 0 initialization (initializing ICT from a pretrained BERT model), calling optimizer.reload_model_params() under fp16.

  8. If args.ckpt_convert_format is set, convert the checkpoint format: the checkpoint loaded in the old format is re-saved in the new format, and the process then exits.

  9. Finally, return model, optimizer, opt_param_scheduler.
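
As an aside on step 3: OptimizerConfig is populated by copying whichever of its dataclass fields also exist on args. Below is a minimal sketch of that generic pattern, with a made-up DummyConfig and made-up argument values rather than Megatron's real fields.

# Sketch of the "copy matching fields from args into a dataclass config" pattern
# used above to build OptimizerConfig; DummyConfig and the args values are made up.
import dataclasses
from types import SimpleNamespace

@dataclasses.dataclass
class DummyConfig:
    lr: float = 1e-4
    weight_decay: float = 0.0
    use_distributed_optimizer: bool = False

args = SimpleNamespace(lr=3e-4, weight_decay=0.1, unrelated_flag=True)

kwargs = {}
for f in dataclasses.fields(DummyConfig):
    if hasattr(args, f.name):          # only copy fields that exist on args
        kwargs[f.name] = getattr(args, f.name)

config = DummyConfig(**kwargs)
print(config)  # DummyConfig(lr=0.0003, weight_decay=0.1, use_distributed_optimizer=False)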

get_model

The get_model code is as follows:

def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model."""
args = get_args()
args.model_type = model_type

# Build model.
def build_model():
if (
mpu.get_pipeline_model_parallel_world_size() > 1
and args.virtual_pipeline_model_parallel_size is not None
):
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i)
post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)
this_model = model_provider_func(
pre_process=pre_process, post_process=post_process, vp_stage=i)
this_model.model_type = model_type
this_model.vp_stage = i
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model = model_provider_func(pre_process=pre_process, post_process=post_process)
model.model_type = model_type
return model

if args.init_model_with_meta_device:
with torch.device('meta'):
model = build_model()
else:
model = build_model()

if not isinstance(model, list):
model = [model]

# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

# Print number of parameters.
num_parameters = sum(
[sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
)
if mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0:
print(
' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
num_parameters,
),
flush=True,
)

# GPU allocation.
# For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
# in the fully_shard function of FSDP2 instead.
if (
not (args.use_torch_fsdp2 and args.use_cpu_initialization)
and not args.init_model_with_meta_device
):
for model_module in model:
model_module.cuda(torch.cuda.current_device())

# Fp16 conversion.
if args.fp16 or args.bf16:
config = get_model_config(model[0])
model = [Float16Module(config, model_module) for model_module in model]

# Before TE2.x: The model_module.bfloat16()/model_module.half() above will call the inplace
# copy of TE's Float8Tensor, which will write an unwanted value (amax calculated
# from the current fp8 param) to its amax_history. The below function will correct
# the amax_history back.
# After TE2.x: Below function is an empty function and does nothing.
correct_amax_history_if_needed(model)

if wrap_with_ddp:
if args.use_torch_fsdp2:
assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
DP = torch_FSDP
elif args.use_megatron_fsdp:
DP = megatron_FSDP
else:
DP = DDP

config = get_model_config(model[0])

if getattr(args, "use_torch_fsdp2", False):
reshard_after_forward = getattr(args, "torch_fsdp2_reshard_after_forward", True)
ddp_config = TorchFullyShardedDataParallelConfig(reshard_after_forward=reshard_after_forward)
else:
kwargs = {}
for f in dataclasses.fields(DistributedDataParallelConfig):
if hasattr(args, f.name):
kwargs[f.name] = getattr(args, f.name)
kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
kwargs['check_for_large_grads'] = args.check_for_large_grads
if args.ddp_num_buckets is not None:
assert args.ddp_bucket_size is None, \
"Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
assert args.ddp_num_buckets > 0, \
"--ddp-num-buckets must be greater than 0"
kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets
else:
kwargs['bucket_size'] = args.ddp_bucket_size
kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
kwargs['average_in_collective'] = args.ddp_average_in_collective
if args.use_megatron_fsdp and args.use_precision_aware_optimizer:
kwargs["preserve_fp32_weights"] = False
ddp_config = DistributedDataParallelConfig(**kwargs)

# In the Megatron FSDP and DDP use path, we need to initialize the bucket size.
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
ddp_config.bucket_size = max(
40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
)
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None

with torch.cuda.stream(torch.cuda.Stream()):
model = [
DP(
config=config,
ddp_config=ddp_config,
module=model_chunk,
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0)
or args.overlap_param_gather_with_optimizer_step,
)
for (model_chunk_idx, model_chunk) in enumerate(model)
]

# Broadcast params from data parallel src rank to other data parallel ranks.
if args.data_parallel_random_init:
for model_module in model:
model_module.broadcast_params()

return model

Its overall flow is as follows:

  1. A build_model() helper is defined and used to obtain the model (or model chunks) that this worker should hold. Based on whether the current stage is the first/last pipeline stage and on the PP and virtual-pipeline (VP) settings, it builds the portion of the model living on this worker. For example, with 8 layers, pp_size=4 and vp_size=2, the four workers receive [embedding+layer_0, layer_4], [layer_1, layer_5], [layer_2, layer_6] and [layer_3, layer_7+LM head+loss] respectively (see the layer-assignment sketch after this list).

  2. Set default tensor-parallel attributes on every model parameter:

_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
    "partition_dim": -1,
    "partition_stride": 1,
}
  • Unless Torch FSDP2 together with CPU initialization is used, and provided init_model_with_meta_device is not set, the model is moved onto GPU memory (FSDP2 allocates GPU memory later inside fully_shard).

  • If fp16 or bf16 is configured, the model is wrapped with Float16Module for low-precision conversion.

  • Then, if wrap_with_ddp is enabled, the model is wrapped for data parallelism:

    1. First choose the DP wrapper class; three variants are supported: torch_FSDP, megatron_FSDP and DDP.

    2. Then construct ddp_config:

      1. If Torch FSDP2 is used, a TorchFullyShardedDataParallelConfig is used directly as ddp_config.

      2. Otherwise a DistributedDataParallelConfig is built by filling in its fields from args and used as ddp_config.

    3. Each model chunk on the current rank is then wrapped with DP.

    4. If data_parallel_random_init is used, broadcast_params is additionally called inside DDP so that parameters are identical across data-parallel ranks.
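
The layer assignment in the example of step 1 follows the interleaved (virtual pipeline) schedule. Below is a minimal illustrative sketch under the assumption of an even split; layer_split is a made-up helper, not the actual Megatron function, which additionally handles embeddings, uneven splits, and so on.

# Simplified illustration of interleaved (virtual) pipeline layer assignment;
# layer_split is a made-up helper, not the actual Megatron function.
def layer_split(num_layers: int, pp_size: int, vp_size: int):
    assert num_layers % (pp_size * vp_size) == 0
    per_chunk = num_layers // (pp_size * vp_size)
    assignment = {}
    for pp_rank in range(pp_size):
        chunks = []
        for vp_stage in range(vp_size):
            start = (vp_stage * pp_size + pp_rank) * per_chunk
            chunks.append(list(range(start, start + per_chunk)))
        assignment[pp_rank] = chunks
    return assignment

print(layer_split(8, 4, 2))
# {0: [[0], [4]], 1: [[1], [5]], 2: [[2], [6]], 3: [[3], [7]]}
# pp rank 0 additionally owns the embedding (pre_process), pp rank 3 the LM head/loss (post_process).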

get_megatron_optimizer

The get_megatron_optimizer code is as follows:

def get_megatron_optimizer(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
no_weight_decay_cond: Optional[Callable] = None,
scale_lr_cond: Optional[Callable] = None,
lr_mult: float = 1.0,
use_gloo_process_groups: bool = True,
default_skip_embedding_weight_decay: bool = False,
grad_comm_pgs: Optional[GradCommProcessGroups] = None,
model_comm_pgs: Optional[ModelCommProcessGroups] = None,
) -> MegatronOptimizer:
"""Retrieve the Megatron optimizer for model chunks.

We use separate optimizers for expert parameters and non-expert parameters.

Args:
config (OptimizerConfig): optimizer configuration object.
model_chunks (List[MegatronModule]): model chunks to get optimizer for.
no_weight_decay_cond (func, optional): function to determine whether a parameter
should not perform weight decay. Defaults to None.
scale_lr_cond (func, optional): function to determine whether a parameter
should have a scaled learning rate. Defaults to None.
lr_mult (float, optional): learning rate multiplier for parameters that
satisfy scale_lr_cond. Defaults to 1.0.
use_gloo_process_groups (bool): if false, disable use of Gloo process groups
in underlying Megatron optimizers.
default_skip_embedding_weight_decay (bool): whether to skip weight decay for
embedding parameters by default, if no_weight_decay_cond is not provided.
This is useful if you do not want embeddings to shrink to zero in training
as recommended in https://arxiv.org/abs/2312.16903
grad_comm_pgs (Optional[GradCommProcessGroups]): gradient communication process groups.
If None, uses default parallel_state groups.
model_comm_pgs (Optional[ModelCommProcessGroups]): model communication process groups.
If None, uses default parallel_state groups.

Returns:
Instance of MegatronOptimizer.
"""

log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')

# Separate out first model chunk if overlapping param AG with optimizer step.
if config.overlap_param_gather_with_optimizer_step:
all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]]
overlap_param_gather_with_optimizer_step_flags = [True, False]
else:
all_dense_model_chunks = [model_chunks]
overlap_param_gather_with_optimizer_step_flags = [False]

if grad_comm_pgs is None and model_comm_pgs is None:
# Gradient communication groups
dp_cp_group = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)
intra_dp_cp_group = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=True
)

intra_expt_dp_group = parallel_state.get_expert_data_parallel_group(
partial_expert_data_parallel=True
)

# Gloo groups
if use_gloo_process_groups:
intra_dp_cp_group_gloo = parallel_state.get_data_parallel_group_gloo(
with_context_parallel=True, partial_data_parallel=True
)
intra_expt_dp_group_gloo = parallel_state.get_expert_data_parallel_group_gloo(
partial_expert_data_parallel=True
)
else:
intra_dp_cp_group_gloo = None
intra_expt_dp_group_gloo = None

# Model communication groups
mp_group = parallel_state.get_model_parallel_group()
expt_tp_pp_group = parallel_state.get_expert_tensor_model_pipeline_parallel_group()
elif grad_comm_pgs is not None and model_comm_pgs is not None:
# 1. dp group - this is always required
if not hasattr(grad_comm_pgs, 'dp'):
raise ValueError("dp process group is required but not provided in grad_comm_pgs")
dp_group = grad_comm_pgs.dp

# 2. dp_cp group:
# - If provided in grad_comm_pgs, use it
# - Otherwise check context_parallel_size
# - If cp_size is 1, use same as dp
# - If cp_size > 1, raise error as dp_cp is needed
if hasattr(grad_comm_pgs, 'dp_cp'):
dp_cp_group = grad_comm_pgs.dp_cp
else:
model_config = get_model_config(model_chunks[0])
cp_size = getattr(model_config, 'context_parallel_size', 1)
if cp_size == 1:
# If no context parallelism, dp_cp is same as dp
dp_cp_group = dp_group
else:
raise ValueError(
"dp_cp process group is required when context_parallel_size > 1 "
"but not provided in grad_comm_pgs"
)

# 3. Handle expert data parallel group
assert hasattr(grad_comm_pgs, 'expt_dp'), (
"expt_dp process group is required but not provided in grad_comm_pgs",
"please explicitly set it to None if you don't need it",
)
expt_dp_group = grad_comm_pgs.expt_dp

# 4. Handle intra_dp_cp, intra_expt_dp, and inter_dist_opt
# based on optimizer instances:
# Get ddp_config from model chunks to determine optimizer instances
ddp_config = model_chunks[0].ddp_config
if ddp_config.num_distributed_optimizer_instances == 1:
# With a single optimizer instance:
# - intra_dp_cp is same as dp_cp
# - intra_expt_dp is same as expt_dp
# - inter_dist_opt is not needed (set to None)
intra_dp_cp_group = dp_cp_group
intra_expt_dp_group = expt_dp_group
else:
# With multiple optimizer instances, both groups must be provided
if not (
hasattr(grad_comm_pgs, 'intra_dp_cp')
and hasattr(grad_comm_pgs, 'intra_expt_dp')
and hasattr(grad_comm_pgs, 'inter_dist_opt')
):
raise ValueError(
"intra_dp_cp, intra_expt_dp, and inter_dist_opt "
"process groups are required when using multiple optimizer "
"instances (>1) but not provided in grad_comm_pgs"
)
intra_dp_cp_group = grad_comm_pgs.intra_dp_cp
intra_expt_dp_group = grad_comm_pgs.intra_expt_dp

# 5. Model communication groups
assert hasattr(model_comm_pgs, 'mp'), (
"mp process group is required but not provided in model_comm_pgs",
"please explicitly set it to None if you don't need it",
)
mp_group = model_comm_pgs.mp

# Expert tensor-model-pipeline group for MoE
assert hasattr(model_comm_pgs, 'tp_ep_pp'), (
"tp_ep_pp process group is required but not provided in model_comm_pgs",
"please explicitly set it to None if you don't need it",
)
expt_tp_pp_group = model_comm_pgs.tp_ep_pp

# Set up gloo groups - these might not be provided in process groups config
# so we need to create them or set to None
assert not use_gloo_process_groups, (
"Gloo process groups are not supported when grad_comm_pgs and model_comm_pgs are "
"provided. Please set use_gloo_process_groups to False."
)
intra_dp_cp_group_gloo = None
intra_expt_dp_group_gloo = None

else:
raise ValueError("Grad and model comm process groups must be provided or both must be None")

model_parallel_rank = get_pg_rank(mp_group)

if get_pg_size(dp_cp_group) > get_pg_size(intra_dp_cp_group):
if grad_comm_pgs is not None:
inter_dist_opt_group = grad_comm_pgs.inter_dist_opt
else:
inter_dist_opt_group = parallel_state.get_inter_distributed_optimizer_instance_group()
distributed_optimizer_instance_id = get_pg_rank(inter_dist_opt_group)
else:
distributed_optimizer_instance_id = 0

optimizers = []
model_chunk_offset = 0
ddp_config = model_chunks[0].ddp_config # Use the first model chunk's DDP config
if ddp_config.use_megatron_fsdp:
for model_chunk, overlap_param_gather_with_optimizer_step in zip(
all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
):
param_groups, buffers = _get_param_groups_and_buffers(
model_chunk,
model_chunk_offset=model_chunk_offset,
config=config,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
filter_fn=lambda g: True,
buffer_name='buffers',
default_skip_embedding_weight_decay=default_skip_embedding_weight_decay,
)

optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
model_chunks=model_chunk,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mp_group,
data_parallel_group=dp_cp_group,
data_parallel_group_gloo=intra_dp_cp_group_gloo,
data_parallel_group_idx=model_parallel_rank,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
)
)
model_chunk_offset += 1

if len(optimizers) == 1:
return optimizers[0]

return ChainedOptimizer(optimizers)

for dense_model_chunks, overlap_param_gather_with_optimizer_step in zip(
all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
):
param_groups, buffers = _get_param_groups_and_buffers(
dense_model_chunks,
model_chunk_offset=model_chunk_offset,
config=config,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
filter_fn=lambda g: not g['is_expert_parallel'],
buffer_name='buffers',
default_skip_embedding_weight_decay=default_skip_embedding_weight_decay,
)
for model_chunk in dense_model_chunks:
model_chunk.overlap_param_gather_with_optimizer_step = (
overlap_param_gather_with_optimizer_step
)

# Pass Gloo process groups into optimizer only if needed.
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
model_chunks=dense_model_chunks,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mp_group,
data_parallel_group=intra_dp_cp_group,
data_parallel_group_gloo=intra_dp_cp_group_gloo,
data_parallel_group_idx=model_parallel_rank,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
)
)
model_chunk_offset += 1

moe_param_groups, moe_buffers = _get_param_groups_and_buffers(
model_chunks,
model_chunk_offset=0,
config=config,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
filter_fn=lambda g: g['is_expert_parallel'],
buffer_name='expert_parallel_buffers',
default_skip_embedding_weight_decay=default_skip_embedding_weight_decay,
)
if len(moe_param_groups) > 0:
expt_model_parallel_rank = get_pg_rank(expt_tp_pp_group)
# Pass Gloo process groups into optimizer only if needed.
if use_gloo_process_groups:
expt_data_parallel_group_gloo = intra_expt_dp_group_gloo
else:
expt_data_parallel_group_gloo = None
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
model_chunks=model_chunks,
param_groups=moe_param_groups,
per_model_buffers=moe_buffers,
model_parallel_group=expt_tp_pp_group,
data_parallel_group=intra_expt_dp_group,
data_parallel_group_gloo=expt_data_parallel_group_gloo,
data_parallel_group_idx=expt_model_parallel_rank,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
)
)

return ChainedOptimizer(optimizers)

get_optimizer_param_scheduler

get_optimizer_param_scheduler is responsible for computing the learning rate (and weight decay) the optimizer should use at each step. The code is as follows:

def get_optimizer_param_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()

# Iteration-based training.
if args.train_iters:
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
lr_decay_steps = args.lr_decay_iters * args.global_batch_size
wd_incr_steps = args.train_iters * args.global_batch_size
wsd_decay_steps = None
if args.lr_wsd_decay_iters is not None:
wsd_decay_steps = args.lr_wsd_decay_iters * args.global_batch_size
if args.lr_warmup_fraction is not None:
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else:
lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
# Sample-based training.
elif args.train_samples:
# We need to set training iters for later use. Technically
# we need to adjust the training samples too (due to last
# batch being incomplete) but we leave it as is for now.
update_train_iters(args)
if args.lr_decay_samples is None:
args.lr_decay_samples = args.train_samples
lr_decay_steps = args.lr_decay_samples
wd_incr_steps = args.train_samples
wsd_decay_steps = args.lr_wsd_decay_samples
if args.lr_warmup_fraction is not None:
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else:
lr_warmup_steps = args.lr_warmup_samples
else:
raise Exception('either train-iters or train-samples should be provided.')

opt_param_scheduler = OptimizerParamScheduler(
optimizer,
init_lr=args.lr_warmup_init,
max_lr=args.lr,
min_lr=args.min_lr,
lr_warmup_steps=lr_warmup_steps,
lr_decay_steps=lr_decay_steps,
lr_decay_style=args.lr_decay_style,
start_wd=args.start_weight_decay,
end_wd=args.end_weight_decay,
wd_incr_steps=wd_incr_steps,
wd_incr_style=args.weight_decay_incr_style,
use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
override_opt_param_scheduler=args.override_opt_param_scheduler,
wsd_decay_steps=wsd_decay_steps,
lr_wsd_decay_style=args.lr_wsd_decay_style,
)

return opt_param_scheduler

Two ways of describing the training length are supported:

  • Iteration-based training, i.e. something like --train-iters 500000 is set.

    • Since Megatron's OptimizerParamScheduler internally uses the number of processed samples as its x-axis, 1 iteration = global_batch_size samples.

    • It also supports configuring the warm-up steps and the WSD (warmup-stable-decay) decay steps.

  • Sample-based training, i.e. something like --train-samples 300B is set.

    • It first derives train_iters back from the sample count via update_train_iters(args), since the iteration count may still be needed elsewhere.

    • The step counts can then be taken directly: lr_decay_steps = args.lr_decay_samples and wd_incr_steps = args.train_samples.

  • Finally, an OptimizerParamScheduler is constructed from these parameters (a worked example of the step bookkeeping follows below).
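
A worked example of the step bookkeeping above for iteration-based training; the numbers are made up, and every quantity is counted in samples because OptimizerParamScheduler uses processed samples as its axis.

# Worked example of the iteration-based branch above; numeric values are made up.
train_iters = 500_000
global_batch_size = 1024
lr_decay_iters = None            # defaults to train_iters when not set
lr_warmup_fraction = 0.01        # assume --lr-warmup-fraction is used

lr_decay_iters = lr_decay_iters or train_iters
lr_decay_steps = lr_decay_iters * global_batch_size      # 512,000,000 samples
wd_incr_steps = train_iters * global_batch_size          # 512,000,000 samples
lr_warmup_steps = lr_warmup_fraction * lr_decay_steps    # 5,120,000 samples

print(lr_decay_steps, wd_incr_steps, lr_warmup_steps)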

train

The train code is as follows:

def train(
forward_step_func,
model,
optimizer,
opt_param_scheduler,
train_data_iterator,
valid_data_iterator,
process_non_loss_data_func,
config,
checkpointing_context,
non_loss_data_func,
):
"""Training function: run train_step desired number of times, run validation, checkpoint."""
args = get_args()
timers = get_timers()
energy_monitor = get_energy_monitor()
one_logger = get_one_logger()

if args.run_workload_inspector_server:
try:
from workload_inspector.utils.webserver import run_server
import threading

threading.Thread(
target=run_server, daemon=True, args=(torch.distributed.get_rank(),)
).start()
except ModuleNotFoundError:
print_rank_0("workload inspector module not found.")

# Write args to tensorboard
write_args_to_tensorboard()

# Turn on training mode which enables dropout.
for model_module in model:
model_module.train()

# Tracking loss.
total_loss_dict = {}

# Iterations.
iteration = args.iteration
# Make sure rerun_state_machine has the right iteration loaded from checkpoint.
rerun_state_machine = get_rerun_state_machine()
if rerun_state_machine.current_iteration != iteration:
print_rank_0(f"Overwriting rerun_state_machine.current_iteration from "
f"{rerun_state_machine.current_iteration} to {iteration}...")
rerun_state_machine.current_iteration = iteration

# Track E2E metrics at the start of training.
one_logger_utils.on_train_start(
iteration=iteration,
consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples,
seq_length=args.seq_length,
train_iters=args.train_iters,
save=args.save,
async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far,
)

num_floating_point_operations_so_far = args.num_floating_point_operations_so_far

# Setup some training config params.
config.grad_scale_func = optimizer.scale_loss
config.timers = timers
if isinstance(model[0], (megatron_FSDP, DDP)) and args.overlap_grad_reduce:
assert config.no_sync_func is None, (
'When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce'
)
config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
if len(model) == 1:
config.no_sync_func = config.no_sync_func[0]
if args.align_grad_reduce:
config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model]
if len(model) == 1:
config.grad_sync_func = config.grad_sync_func[0]
if args.overlap_param_gather and args.align_param_gather:
config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
if len(model) == 1:
config.param_sync_func = config.param_sync_func[0]
config.finalize_model_grads_func = finalize_model_grads

if args.log_energy:
energy_monitor.setup()
energy_monitor.resume()

timers('interval-time', log_level=0).start(barrier=True)
print_datetime('before the start of training step')
report_memory_flag = True
pre_hook_enabled = False
should_exit = False
exit_code = 0

if args.manual_gc:
# Disable the default garbage collector and perform the collection manually.
# This is to align the timing of garbage collection across ranks.
assert (
args.manual_gc_interval >= 0
), 'Manual garbage collection interval should be larger than or equal to 0'
gc.disable()
gc.collect()

# Singleton initialization of straggler detector.
if args.log_straggler:
global stimer
world = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
mmcnt = args.straggler_minmax_count
stimer.configure(
world,
rank,
mmcnt=mmcnt,
enabled=not args.disable_straggler_on_startup,
port=args.straggler_ctrlr_port,
)
num_floating_point_operations_since_last_log_event = 0.0

num_microbatches = get_num_microbatches()
eval_duration = 0.0
eval_iterations = 0
# Wrap forward_backward_func for Full iteration CUDA graph
forward_backward_func = get_forward_backward_func()
if args.enable_cuda_graph and args.cuda_graph_scope=="full_iteration":
forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)

def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics."""
num_floating_point_operations_since_current_train_start = (
num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
)
return {
'iteration': iteration,
'train_duration': timers('interval-time').active_time(),
'eval_duration': eval_duration,
'eval_iterations': eval_iterations,
'total_flops_since_current_train_start': num_floating_point_operations_since_current_train_start,
'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size,
'seq_length': args.seq_length,
}

# Cache into one-logger for callback.
if one_logger:
with one_logger.get_context_manager():
one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)

prof = None
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
and args.use_pytorch_profiler
):
prof = torch.profiler.profile(
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start - 1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end - args.profile_step_start,
repeat=1,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
with_stack=True,
)
prof.start()

start_iteration = iteration
# Disable forward pre-hook to start training to ensure that errors in checkpoint loading
# or random initialization don't propagate to all ranks in first all-gather (which is a
# no-op if things work correctly).
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model, param_sync=False)
# Also remove param_sync_func temporarily so that sync calls made in
# `forward_backward_func` are no-ops.
param_sync_func = config.param_sync_func
config.param_sync_func = None
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(
model, cross_check=True
), "Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")

# Capture CUDA Graphs.
if args.external_cuda_graph:
cuda_graph_helper = TECudaGraphHelper(
model=model,
config=config,
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
optimizers=[optimizer],
)
cuda_graph_helper.create_cudagraphs()

# Run training iterations till done.
while iteration < args.train_iters:
if args.profile and torch.distributed.get_rank() in args.profile_ranks:
if args.use_pytorch_profiler:
prof.step()
elif iteration == args.profile_step_start:
torch.cuda.cudart().cudaProfilerStart()
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()

ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=False)
ft_integration.on_checkpointing_end(is_async_finalization=True)

# Update number of microbatches first without consistency check to decide if a
# checkpoint should be saved. If the number of microbatches is different
# from the previous iteration, save a checkpoint. Then run consistency check
# to make sure training configuration is still valid.
update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, (
f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}"
)
if args.save is not None:
save_checkpoint_and_time(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)

# Completely skip iteration if needed.
if iteration in args.iterations_to_skip:
# Dummy train_step to fast forward train_data_iterator.
dummy_train_step(train_data_iterator)
iteration += 1
batch_size = (
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
)
args.consumed_train_samples += batch_size
args.skipped_train_samples += batch_size
continue

# Run training step.
args.curr_iteration = iteration
ft_integration.on_training_step_start()
(
loss_dict,
skipped_iter,
should_checkpoint,
should_exit,
exit_code,
grad_norm,
num_zeros_in_grad,
) = train_step(
forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func
)
ft_integration.on_training_step_end()
if should_checkpoint:
save_checkpoint_and_time(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
)
if should_exit:
break

# Enable forward pre-hooks after first set of forward and backward passes.
# When running in fp16, skip all NaN iterations until steady-state loss scaling value
# is reached.
if iteration == start_iteration:
if skipped_iter:
# Only enable forward pre-hook after a training step has successfully run. Relevant
# for fp16 codepath where first XX iterations are skipped until steady-state loss
# scale value is reached.
start_iteration = iteration + 1
else:
# Enable forward pre-hook after training step has successfully run. All subsequent
# forward passes will use the forward pre-hook / `param_sync_func` in
# `forward_backward_func`.
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
config.param_sync_func = param_sync_func
pre_hook_enabled = True
# Set the manual hooks when CUDA Graphs are used.
if args.external_cuda_graph:
cuda_graph_helper.cuda_graph_set_manual_hooks()

iteration += 1
batch_size = (
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
)
args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (
get_current_global_batch_size() - get_current_running_global_batch_size()
)
if args.decrease_batch_size_if_needed:
assert num_skipped_samples_in_batch >= 0
else:
assert num_skipped_samples_in_batch == 0
args.skipped_train_samples += num_skipped_samples_in_batch
num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size)
num_floating_point_operations_so_far += num_floating_point_operations_in_batch
num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch

# Logging.
if not optimizer.is_stub_optimizer:
loss_scale = optimizer.get_loss_scale().item()
else:
loss_scale = 1.0
params_norm = None

if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
learning_rate = None
decoupled_learning_rate = None
for param_group in optimizer.param_groups:
if len(param_group['params']) == 0:
continue
if param_group['is_decoupled_lr']:
decoupled_learning_rate = param_group['lr']
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(
loss_dict,
total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration,
loss_scale,
report_memory_flag,
skipped_iter,
grad_norm,
params_norm,
num_zeros_in_grad,
)

# Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
if args.log_energy:
energy_monitor.pause()
timers('interval-time').stop()
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
pre_hook_enabled = False
if args.manual_gc and args.manual_gc_eval:
# Collect all objects.
gc.collect()
prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(
prefix,
forward_step_func,
valid_data_iterator,
model,
iteration,
process_non_loss_data_func,
config,
verbose=False,
write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func,
)
eval_duration += timers('eval-time').elapsed()
eval_iterations += sum(args.eval_iters) if isinstance(args.eval_iters, list) else args.eval_iters
timers('eval-time').stop()
one_logger_utils.track_e2e_metrics()

if args.manual_gc and args.manual_gc_eval:
# Collect only the objects created and used in evaluation.
gc.collect(generation=0)
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
pre_hook_enabled = True
timers('interval-time', log_level=0).start(barrier=True)
if args.log_energy:
energy_monitor.resume()

# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(
model,
optimizer,
opt_param_scheduler,
iteration,
prof,
num_floating_point_operations_since_last_log_event,
)

# Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(
model,
optimizer,
opt_param_scheduler,
iteration,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator,
)
if should_exit:
break

one_logger_utils.track_e2e_metrics()

# Flush TensorBoard, WandB writers and one-logger.
writer = get_tensorboard_writer()
if writer:
writer.flush()

# Close out pre-hooks if using distributed optimizer and overlapped param gather.
if pre_hook_enabled:
disable_forward_pre_hook(model)

ft_integration.on_checkpointing_start()
# This will finalize all unfinalized async request and terminate
# a persistent async worker if persistent ckpt worker is enabled
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client().shutdown_workload_monitoring()

if args.log_energy:
energy_monitor.lap()
total_energy = energy_monitor.get_total()
print_rank_0(f"Total training energy (GPU): {total_energy / 1e6} MJ")
energy_monitor.shutdown()

# If any exit conditions (signal handler, duration, iterations) have been reached, exit.
if should_exit:
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
one_logger_utils.finish()
sys.exit(exit_code)

return iteration, num_floating_point_operations_so_far

Its main flow is:

  1. Fetch the global configuration

  2. Switch the model to train() mode

  3. Align the iteration count with the rerun_state_machine

  4. Inject the hooks for overlapped distributed communication/synchronization into config; config is later used by forward_backward_func / train_step to control no_sync (gradient accumulation), overlapped gradient synchronization, overlapped parameter all-gather, and so on

  5. Apply a few optional settings, such as manual GC control and enabling the profiler

  6. Enter the main loop: while iteration < args.train_iters

    1. Update the number of microbatches, since the batch size may ramp up dynamically

    2. Run one training iteration via train_step

    3. Save a checkpoint when needed

    4. Enable the forward pre-hook only after the first iteration has run successfully

    5. Update the iteration count, the sample counters, and the FLOPs counters (see the example after this list)

    6. Run evaluation when needed

  7. Wrap up training: flush the writers (TensorBoard/WandB/one-logger), disable the forward pre-hook if it is still enabled (so that no hooks or background syncs are left hanging on exit), and finally finalize any pending async checkpoints
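
As a concrete illustration of the per-iteration accounting in step 6.5, here is a minimal, self-contained sketch; the numbers are made up, and in the real loop they come from mpu.get_data_parallel_world_size(), args and get_num_microbatches():

# Sketch of the sample accounting done after each training step.
data_parallel_size = 8        # assumed value; read from mpu in the real code
micro_batch_size = 2          # assumed value; args.micro_batch_size
num_microbatches = 16         # assumed value; get_num_microbatches()

# One iteration consumes exactly one global batch.
batch_size = data_parallel_size * micro_batch_size * num_microbatches
consumed_train_samples = 0
consumed_train_samples += batch_size
print(batch_size)               # 256 samples per iteration
print(consumed_train_samples)   # running total, later used for the LR schedule and resume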

train_step

The forward_backward_func argument

A key argument of train_step is forward_backward_func, which is obtained as follows:

# Wrap forward_backward_func for Full iteration CUDA graph
forward_backward_func = get_forward_backward_func()
if args.enable_cuda_graph and args.cuda_graph_scope=="full_iteration":
forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)

get_forward_backward_func is defined as follows:

def get_forward_backward_func():
"""Retrieves the appropriate forward_backward function given the
configuration of parallel_state.

Returns a function that will perform all of the forward and
backward passes of the model given the pipeline model parallel
world size and virtual pipeline model parallel world size in the
global parallel_state.

Note that if using sequence parallelism, the sequence length component of
the tensor shape is updated to original_sequence_length /
tensor_model_parallel_world_size.

The function returned takes the following arguments:

forward_step_func (required): A function that takes a data
iterator and a model as its arguments and return the model's
forward output and the loss function. The loss function should
take one torch.Tensor and return a torch.Tensor of loss and a
dictionary of string -> torch.Tensor.

A third argument, checkpoint_activations_microbatch, indicates
that the activations for this microbatch should be
checkpointed. A None value for this argument indicates that
the default from the configuration should be used. This is
used when the
num_microbatches_with_partial_activation_checkpoints is used.

For example:

def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])

return loss, {'lm loss': averaged_loss[0]}

def forward_step(data_iterator, model):
data, loss_mask = next(data_iterator)
output = model(data)
return output, partial(loss_func, loss_mask)

forward_backward_func(forward_step_func=forward_step, ...)

data_iterator (required): an iterator over the data, will be
passed as is to forward_step_func. Expected to be a list of
iterators in the case of interleaved pipeline parallelism.

model (required): the actual model. Expected to be a list of modules in the case of interleaved
pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.

num_microbatches (int, required):
The number of microbatches to go through

seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
in the config is True. Otherwise, each microbatch in the current global batch size must use
this sequence length.

micro_batch_size (int, required): The number of sequences in a microbatch.

decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
transformer. This is ignored for a single-stack transformer.

forward_only (optional, default = False): Perform only the forward step

collect_non_loss_data (optional, bool, default=False): TODO

first_val_step (bool, optional): Is the first step of the validation phase. Used by
Transformer Engine modules to only update their fp8 weights only on the first validation
step.

adjust_tensor_shapes_fn (Callable, optional): A function that adjusts the receive and send
tensor shapes. Only applicable in forward_backward_pipelining_without_interleaving for now.
Takes in a list of receive shapes and a list of send shapes and returns the adjusted
respective list of shapes. Thus it is not used in the other forward-backward functions
which have different shape handling.

"""
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
if pipeline_model_parallel_size > 1:
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func

Depending on the pipeline-parallel configuration, it dispatches to one of three schedules (a small sketch of this dispatch follows the list):

  • forward_backward_pipelining_with_interleaving: PP enabled together with virtual pipeline (VPP)

  • forward_backward_pipelining_without_interleaving: PP enabled but no VPP

  • forward_backward_no_pipelining: PP disabled
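
A minimal, self-contained sketch of that dispatch logic, with pp_size and vpp_size standing in for the values read from parallel_state:

# Mirrors the branching in get_forward_backward_func(); returns the schedule name.
def pick_schedule(pp_size: int, vpp_size):
    if pp_size > 1:
        if vpp_size is not None:
            return "forward_backward_pipelining_with_interleaving"
        return "forward_backward_pipelining_without_interleaving"
    return "forward_backward_no_pipelining"

print(pick_schedule(1, None))   # no pipeline parallelism
print(pick_schedule(4, None))   # PP without interleaving
print(pick_schedule(4, 2))      # PP with virtual pipeline stages (interleaved)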

Let us start with the simplest one, forward_backward_no_pipelining, shown below:

def forward_backward_no_pipelining(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int, # unused
micro_batch_size: int, # unused
decoder_seq_length: Optional[int] = None, # unused
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: Optional[bool] = None,
adjust_tensor_shapes_fn: Optional[Callable] = None, # unused
grad_finalize_pgs: Optional[GradFinalizeProcessGroups] = None,
):
"""Run forward and backward passes with no pipeline parallelism"""

if grad_finalize_pgs is None:
tp_group = parallel_state.get_tensor_model_parallel_group()
cp_group = parallel_state.get_context_parallel_group()
embd_group = parallel_state.get_embedding_group(check_initialized=False)
pp_group = parallel_state.get_pipeline_model_parallel_group()
pos_emb_group = parallel_state.get_position_embedding_group(check_initialized=False)
grad_finalize_pgs = GradFinalizeProcessGroups()
grad_finalize_pgs.tp = tp_group
grad_finalize_pgs.cp = cp_group
grad_finalize_pgs.embd = embd_group
grad_finalize_pgs.pos_embd = pos_emb_group
grad_finalize_pgs.pp = pp_group
grad_finalize_pgs.dp_cp = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)

elif grad_finalize_pgs is not None:
assert hasattr(grad_finalize_pgs, 'tp')
assert hasattr(grad_finalize_pgs, 'cp')
assert hasattr(grad_finalize_pgs, 'embd'), (
"grad_finalize_pgs must have a embd. In previous version, it is used default "
"`parallel_state.default_embedding_ranks` to create the process group. If you are "
"using the default process group, please use `parallel_state.get_embedding_group()` "
"to get the process group. If you don't need explicitly set it to None."
)
assert hasattr(grad_finalize_pgs, 'pos_embd'), (
"grad_finalize_pgs must have a pos_embd. In previous version, it is used default "
"`parallel_state.default_position_embedding_ranks` to create the process group. "
"If you are using the default process group, "
"please use `parallel_state.get_position_embedding_group()` "
"to get the process group. If you don't need explicitly set it to None."
)
assert hasattr(grad_finalize_pgs, 'pp')
assert hasattr(grad_finalize_pgs, 'dp_cp')

if isinstance(model, list):
assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking"
model = model[0]
if isinstance(data_iterator, list):
assert (
len(data_iterator) == 1
), "non-pipeline-parallel schedule does not support model chunking"
data_iterator = data_iterator[0]
assert (
adjust_tensor_shapes_fn is None
), "adjust_tensor_shapes_fn is not supported for non-pipeline-parallel schedule"

config = get_model_config(model)
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

no_sync_func = config.no_sync_func
if no_sync_func is None:
no_sync_func = contextlib.nullcontext

model_type = get_model_type(model)

forward_data_store = []
input_tensor, output_tensor_grad = None, None
total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda")

if config.overlap_moe_expert_parallel_comm and not forward_only:
forward_data_store, total_num_tokens = combined_1f1b_schedule_for_no_pipelining(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
output_tensor_grad,
forward_data_store,
config,
collect_non_loss_data,
first_val_step,
forward_only,
no_sync_func,
total_num_tokens,
partial(check_first_val_step, first_val_step, forward_only),
)
else:
with no_sync_func():
for i in range(num_microbatches - 1):
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
grad_finalize_pgs.cp.size(),
collect_non_loss_data,
is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=i,
)
total_num_tokens += num_tokens
if not forward_only:
backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
grad_finalize_pgs.cp.size(),
collect_non_loss_data,
is_first_microbatch=check_first_val_step(
first_val_step, forward_only, num_microbatches == 1
),
current_microbatch=num_microbatches - 1,
)

total_num_tokens += num_tokens

if not forward_only:
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

if config.finalize_model_grads_func is not None and not forward_only:
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism and layernorm all-reduce for sequence parallelism).
config.finalize_model_grads_func(
[model],
total_num_tokens if config.calculate_per_token_loss else None,
grad_finalize_pgs=grad_finalize_pgs,
)

if config.timers is not None:
config.timers('forward-backward').stop()

if (
hasattr(config, 'enable_cuda_graph')
and config.enable_cuda_graph
and config.cuda_graph_scope != "full_iteration"
):
create_cudagraphs()

return forward_data_store

Its flow is:

  1. If grad_finalize_pgs is not provided, build the defaults, so that gradient finalization knows exactly which reduction to run in which communication group, e.g. the parameter-gradient all-reduce or reduce-scatter over DP / DP×CP

  2. Run a few argument checks

  3. Branch into the MoE-overlap path and the regular path:

    1. The MoE-overlap path calls combined_1f1b_schedule_for_no_pipelining

    2. The regular path runs forward_step num_microbatches times, each followed by backward_step unless forward_only is set; note that the last micro-batch runs outside the no_sync context, so that gradient synchronization/reduction happens at the tail of the iteration (the classic gradient-accumulation pattern; see the sketch after this list)

  4. Finally call finalize_model_grads_func, which typically covers:

    • the full gradient all-reduce / reduce-scatter over DP (depending on the optimizer / ZeRO-style strategy)

    • LayerNorm-related reductions for sequence parallelism

    • the dedicated reductions for the embedding / position-embedding groups
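
To make the no_sync pattern in step 3.2 concrete, here is a minimal, runnable sketch of the same idea in plain PyTorch; it is an analogy only, not Megatron's no_sync_func implementation:

import contextlib
import torch

def run_microbatches(model, microbatches, no_sync_func=None):
    """Gradient accumulation: suppress grad sync for all but the last micro-batch."""
    if no_sync_func is None:
        # Stand-in; with DistributedDataParallel this would be model.no_sync.
        no_sync_func = contextlib.nullcontext
    with no_sync_func():
        for x in microbatches[:-1]:
            model(x).sum().backward()   # under DDP's no_sync, no all-reduce happens here
    # Last micro-batch outside the context: the gradient all-reduce fires here.
    model(microbatches[-1]).sum().backward()

model = torch.nn.Linear(4, 4)
run_microbatches(model, [torch.randn(2, 4) for _ in range(3)])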

The forward_step for a single micro-batch is shown below; note that this is where the user-provided forward_step_func finally gets called:

def forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
cp_group_size,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
vp_stage=None,
is_last_stage=True,
):
"""Forward step for passed-in model.

If it is the first stage, the input tensor is obtained from the data_iterator.
Otherwise, the passed-in input_tensor is used.

Args:
forward_step_func (callable):
The forward step function for the model that takes the
data iterator as the first argument, and model as the second.
This user's forward step is expected to output a tuple of two elements:

1. The output object from the forward step. This output object needs to be a
tensor or some kind of collection of tensors. The only hard requirement
for this object is that it needs to be acceptible as input into the second
function.
2. A function to reduce (optionally) the output from the forward step. This
could be a reduction over the loss from the model, it could be a function that
grabs the output from the model and reformats, it could be a function that just
passes through the model output. This function must have one of the following
patterns, and depending on the pattern different things happen internally:

a. A tuple of reduced loss and some other data. Note that in this case
the first argument is divided by the number of global microbatches,
assuming it is a loss, so that the loss is stable as a function of
the number of devices the step is split across.
b. A triple of reduced loss, number of tokens, and some other data. This
is similar to case (a), but the loss is further averaged across the
number of tokens in the batch. If the user is not already averaging
across the number of tokens, this pattern is useful to use.
c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
of tensors, etc in the case of inference). To trigger case 3 you need
to specify `collect_non_loss_data=True` and you may also want to
specify `forward_only=True` in the call to the parent forward_backward
function.
data_iterator (iterator):
The data iterator.
model (nn.Module):
The model to perform the forward step on.
num_microbatches (int):
The number of microbatches.
input_tensor (Tensor or list[Tensor]):
The input tensor(s) for the forward step.
forward_data_store (list):
The list to store the forward data. If you go down path 2.a or
2.b for the return of your forward reduction function then this will store only the
final dimension of the output, for example the metadata output by the loss function.
If you go down the path of 2.c then this will store the entire output of the forward
reduction function applied to the model output.
config (object):
The configuration object.
collect_non_loss_data (bool, optional):
Whether to collect non-loss data. Defaults to False.
This is the path to use if you want to collect arbitrary output from the model forward,
such as with inference use cases. Defaults to False.
checkpoint_activations_microbatch (int, optional):
The microbatch to checkpoint activations.
Defaults to None.
is_first_microbatch (bool, optional):
Whether it is the first microbatch. Defaults to False.
current_microbatch (int, optional):
The current microbatch. Defaults to None.
vp_stage (int, optional):
The virtual pipeline stage. Defaults to None.
is_last_stage (bool, optional):
Whether it is the last stage. Defaults to True.
Also considering virtual stages.
In case of PP/VPP, is_last_stage/is_vp_last_stage.

Returns:
Tensor or list[Tensor]: The output object(s) from the forward step.
Tensor: The number of tokens.
"""
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

if config.timers is not None:
config.timers('forward-compute', log_level=2).start()

if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
model.set_is_first_microbatch()
if current_microbatch is not None:
set_current_microbatch(model, current_microbatch)

unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True

set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)

if config.enable_autocast:
context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
else:
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, model)
else:
output_tensor, loss_func = forward_step_func(
data_iterator, model, checkpoint_activations_microbatch
)
output_tensor, num_tokens = forward_step_calc_loss(
model,
output_tensor,
loss_func,
config,
vp_stage,
collect_non_loss_data,
num_microbatches,
forward_data_store,
cp_group_size,
is_last_stage,
)

if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens

train_step code & flow

train_step is the function inside the training loop that actually performs one parameter update (or skips it). It chains together "zero gradients → forward/backward (possibly re-run) → optimizer.step → LR scheduler step → loss aggregation/reduction", and is interleaved with the fault-tolerance rerun machinery, the distributed-optimizer/communication overlap, and the vision-pretraining special-case logic.


Its flow is as follows (a simplified sketch of this control flow follows the list):

  1. The forward/backward computation is wrapped in while rerun_state_machine.should_run_forward_backward(data_iterator), so that on certain transient errors, or when a batch needs to be re-fetched, the step can be re-run for fault tolerance. Inside this loop:

    1. Zero the gradients of the model and the optimizer

    2. Handle a few special cases, including the tensor-shape adjustment function (ModelOpt distillation + PP only) and the parameter-buffer copy needed when mxfp8 parameters reuse the grad buffer

    3. Call forward_backward_func to run the forward and backward passes and collect the losses

  2. Decide, in one place, whether to checkpoint or to exit

  3. Optionally empty the CUDA cache and handle the vision-pretraining special logic

  4. Call optimizer.step() to update the model parameters

  5. Step opt_param_scheduler to adjust the learning rate

  6. If this is the pipeline last stage, aggregate the losses across microbatches, mainly for logging
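
Since the full listing of train_step is omitted here, the following is a minimal, hedged sketch of its control flow; every helper passed in is a toy stand-in, not Megatron's real API:

# Self-contained sketch of the train_step control flow described above.
def train_step_sketch(forward_backward, optimizer_step, scheduler_step,
                      should_rerun, zero_grads, increment=1):
    while should_rerun():                    # rerun_state_machine loop
        zero_grads()                         # clear model + optimizer grads
        losses = forward_backward()          # forward/backward over all microbatches
    update_successful = optimizer_step()     # may be skipped (e.g. fp16 loss-scale NaN)
    if update_successful:
        scheduler_step(increment)            # advance the LR scheduler
    return losses, not update_successful     # (loss dict, skipped_iter)

# Toy usage: one pass through the rerun loop, then a successful update.
reruns = iter([True, False])
print(train_step_sketch(
    forward_backward=lambda: {"lm loss": 2.0},
    optimizer_step=lambda: True,
    scheduler_step=lambda inc: None,
    should_rerun=lambda: next(reruns),
    zero_grads=lambda: None,
))   # ({'lm loss': 2.0}, False)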

evaluate_and_print_results

The code of evaluate_and_print_results is as follows:

def evaluate_and_print_results(
prefix,
forward_step_func,
data_iterator,
model,
iteration,
process_non_loss_data_func,
config,
verbose=False,
write_to_tensorboard=True,
non_loss_data_func=None,
):
"""Helper function to evaluate and dump results on screen."""
args = get_args()
if write_to_tensorboard:
writer = get_tensorboard_writer()
else:
writer = None

wandb_writer = get_wandb_writer()

data_iterators = data_iterator if args.multiple_validation_sets else [data_iterator]

if not args.multiple_validation_sets:
eval_iters = [args.eval_iters]
else:
eval_iters = args.eval_iters

if args.full_validation:
assert len(eval_iters) == len(data_iterators)

# with full validation we need to distribute eval_iters to all ranks
if mpu.get_tensor_model_parallel_rank() == 0:
eval_iters = torch.tensor(args.eval_iters, dtype=torch.long, device='cuda')
else:
eval_iters = torch.tensor([0] * len(eval_iters), dtype=torch.long, device='cuda')
torch.distributed.broadcast(eval_iters, 0)
eval_iters = eval_iters.tolist()
args.eval_iters = eval_iters[0] if not args.multiple_validation_sets else eval_iters

for index, (iterator, iterations) in enumerate(zip(data_iterators, eval_iters)):
suffix = ""
if args.multiple_validation_sets:
suffix = f"-{index}"
total_loss_dict, collected_non_loss_data, timelimit = evaluate(
forward_step_func,
iterator,
model,
process_non_loss_data_func,
config,
verbose,
non_loss_data_func,
eval_iters=iterations,
)
# Timelimit hit during evaluation
if timelimit:
return
string = f' validation{suffix} loss at {prefix} | '
for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
ppl = math.exp(min(20, total_loss_dict[key].item()))
string += '{} PPL: {:.6E} | '.format(key, ppl)
if writer:
writer.add_scalar('{} validation{}'.format(key, suffix), total_loss_dict[key].item(), iteration)
writer.add_scalar(
'{} validation{} vs samples'.format(key, suffix),
total_loss_dict[key].item(),
args.consumed_train_samples,
)
if args.log_validation_ppl_to_tensorboard:
writer.add_scalar('{} validation{} ppl'.format(key, suffix), ppl, iteration)
writer.add_scalar(
'{} validation{} ppl vs samples'.format(key, suffix), ppl, args.consumed_train_samples
)
if wandb_writer and is_last_rank():
wandb_writer.log(
{'{} validation{}'.format(key, suffix): total_loss_dict[key].item()}, iteration
)

if process_non_loss_data_func is not None and writer and is_last_rank():
process_non_loss_data_func(collected_non_loss_data, iteration, writer)

length = len(string) + 1
print_rank_last('-' * length)
print_rank_last(string)
print_rank_last('-' * length)

Its flow is:

  1. Fetch the TensorBoard and WandB writers

  2. If multiple validation sets are configured, build the corresponding list of eval_iters

  3. If full_validation is set, broadcast the eval_iters values from TP rank 0 to all ranks

  4. Call evaluate(...) for each validation set

  5. Print the validation loss and perplexity and write them to TensorBoard/WandB (see the example after this list)

  6. Call process_non_loss_data_func for custom handling of the non-loss data
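
As a concrete note on step 5, the perplexity logged here is simply the exponential of the validation loss, clamped to avoid overflow for very large losses:

import math

validation_loss = 2.0
ppl = math.exp(min(20, validation_loss))   # same clamp as in the source; caps PPL at e^20
print(f"loss={validation_loss:.4f}  PPL={ppl:.4f}")   # loss=2.0000  PPL=7.3891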

evaluate

The code of evaluate is as follows:

def evaluate(
forward_step_func,
data_iterator,
model,
process_non_loss_data_func,
config,
verbose=False,
non_loss_data_func=None,
eval_iters=None,
):
"""Evaluation."""
args = get_args()
timers = get_timers()

timers('evaluate', log_level=0).start(barrier=True)

if args.vision_pretraining and args.vision_pretraining_type == "dino":
from megatron.legacy.model.vision.knn_monitor import compute_feature_bank

compute_feature_bank(model)

# Turn on evaluation mode which disables dropout.
for model_module in model:
model_module.eval()

# Disable result validation during evaluation
rerun_state_machine = get_rerun_state_machine()
rerun_mode = rerun_state_machine.get_mode()
rerun_state_machine.set_mode(RerunMode.DISABLED)

total_loss_dict = {}

# make validation batch size independent from training batch size
eval_batch_size = args.global_batch_size
eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size)
forward_backward_func = get_forward_backward_func()
if args.enable_cuda_graph and args.cuda_graph_scope=="full_iteration":
forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)

if eval_iters is None:
eval_iters = args.eval_iters

with torch.no_grad():
iteration = 0
if verbose:
print_rank_0(f'Evaluating on {eval_iters * eval_batch_size} samples')
while iteration < eval_iters:
iteration += 1
if verbose:
print_rank_0(f'Evaluating iter {iteration}/{eval_iters}')

# Don't care about timing during evaluation
config.timers = None
ft_integration.on_eval_step_start()
loss_dicts = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=eval_num_microbatches,
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=True,
)
ft_integration.on_eval_step_end()
config.timers = get_timers()

# Empty unused memory
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()

if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Reduce across processes.
for key in loss_dicts[0].keys():
if key not in total_loss_dict:
total_loss_dict[key] = torch.tensor(
[0.0, 0.0], dtype=torch.float
).cuda()
val = [x[key].view(-1) for x in loss_dicts]

if val[0].numel() == 2:
if args.sft:
# normalize over micro batch instead of global
val = torch.vstack(val)
val = val[:, 0] / val[:, 1]
val = val.mean()
torch.distributed.all_reduce(
val,
group=mpu.get_data_parallel_group(with_context_parallel=True)
)
val /= torch.distributed.get_world_size(
group=mpu.get_data_parallel_group(with_context_parallel=True)
)
total_loss_dict[key][0] += val
total_loss_dict[key][1] += 1
else :
val = torch.vstack(val).sum(dim=0)
torch.distributed.all_reduce(
val,
group=mpu.get_data_parallel_group(with_context_parallel=True)
)
total_loss_dict[key] += val
elif val[0].numel() == 1:
val = torch.cat(val).sum()
total_loss_dict[key][0] += val
total_loss_dict[key][1] += len(loss_dicts)
else:
raise ValueError(f"Invalid value shape: {val[0].shape} for key {key}")

args.consumed_valid_samples += eval_batch_size

if args.exit_duration_in_mins:
train_time = (time.time() - _TRAIN_START_TIME) / 60.0
done_cuda = torch.tensor(
[train_time > args.exit_duration_in_mins], dtype=torch.int, device='cuda'
)
torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX)
done = done_cuda.item()
if done:
rerun_state_machine.set_mode(rerun_mode)
print_rank_0('Exiting during evaluation, timelimit reached')
return None, None, True

collected_non_loss_data = None
if non_loss_data_func is not None:
collected_non_loss_data = non_loss_data_func(model)
elif process_non_loss_data_func is not None and is_last_rank():
collected_non_loss_data = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=True,
collect_non_loss_data=True,
)

# Move model back to the train mode.
for model_module in model:
model_module.train()

for key in total_loss_dict:
numerator, denominator = total_loss_dict[key]
total_loss_dict[key] = numerator / denominator

timers('evaluate').stop()
timers.log(['evaluate'])

rerun_state_machine.set_mode(rerun_mode)

return total_loss_dict, collected_non_loss_data, False

Its overall flow is:

  1. Switch the model to eval() and temporarily disable the rerun_state_machine

  2. Compute the number of microbatches to use for evaluation

  3. Obtain forward_backward_func, just as in training

  4. Under torch.no_grad(), run the eval iterations one by one:

    1. Call forward_backward_func, noting that forward_only=True is set explicitly

    2. Collect and reduce the losses on the pipeline last stage (see the sketch after this list)

    3. Update the consumed validation sample count and check the time limit to decide whether to exit early

    4. After the eval loop, run non_loss_data_func, or collect non-loss data via an extra forward_backward_func call with collect_non_loss_data=True

  5. Switch the model back to train mode and convert each accumulated numerator/denominator pair into a scalar loss

  6. Restore the rerun_state_machine and return the results
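
The loss bookkeeping in steps 4.2 and 5 boils down to keeping a [sum, count] pair per loss key and dividing at the end; a minimal single-process sketch (omitting the cross-rank all_reduce over the DP×CP group) is:

import torch

# Per-key accumulator: index 0 holds the summed loss, index 1 the contribution count.
total_loss_dict = {}
microbatch_losses = [{"lm loss": torch.tensor(2.1)}, {"lm loss": torch.tensor(1.9)}]

for loss_dict in microbatch_losses:
    for key, val in loss_dict.items():
        acc = total_loss_dict.setdefault(key, torch.zeros(2))
        acc[0] += val   # numerator: sum of losses (the real code also all_reduces this)
        acc[1] += 1     # denominator: number of contributions

for key, (numerator, denominator) in total_loss_dict.items():
    print(key, (numerator / denominator).item())   # lm loss ≈ 2.0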

