[Megatron-LM Source Code Analysis (6)] Pipeline Parallelism: 1F1B

Theoretical Background

Pipeline-Parallelism Code

Fetching Training Data

In the get_batch function of pretrain_gpt, only a worker on the first or last pipeline-parallel (pp) stage actually tries to fetch data; workers on intermediate stages return None for every field, since their inputs come from the neighboring pp workers instead.

def get_batch(data_iterator):
    """Generate a batch."""

    # TODO: this is pretty hacky, find a better way
    if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
        not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    ):
        return None, None, None, None, None

    # get batches based on the TP rank you are on
    batch = get_batch_on_this_tp_rank(data_iterator)

    # slice batch along sequence dimension for context parallelism
    batch = get_batch_on_this_cp_rank(batch)

    return batch.values()
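Since get_batch returns batch.values(), its caller unpacks the dict's entries positionally. A minimal standalone sketch of that unpacking (the key names follow the usual Megatron GPT batch layout and are assumptions here, with strings standing in for tensors):

```python
# Minimal sketch: unpacking a Megatron-style batch dict, as get_batch's caller does.
# The five keys mirror the five None values returned on intermediate pp stages;
# treat the names as illustrative rather than the exact upstream contract.
batch = {
    "tokens": "tokens_tensor",
    "labels": "labels_tensor",
    "loss_mask": "loss_mask_tensor",
    "attention_mask": "attention_mask_tensor",
    "position_ids": "position_ids_tensor",
}

# Python dicts preserve insertion order, so .values() unpacks positionally,
# mirroring `tokens, labels, loss_mask, attention_mask, position_ids = get_batch(...)`.
tokens, labels, loss_mask, attention_mask, position_ids = batch.values()
print(tokens)  # tokens_tensor
```

This is why the order in which the batch dict is built matters: reordering the keys would silently reorder the unpacked values.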

Model Construction

The model is built in get_model, called during training; the main code is shown below:

def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model."""
    args = get_args()
    args.model_type = model_type

    # Build model.
    def build_model():
        if (
            mpu.get_pipeline_model_parallel_world_size() > 1
            and args.virtual_pipeline_model_parallel_size is not None
        ):
            model = []
            for i in range(args.virtual_pipeline_model_parallel_size):
                # Set pre_process and post_process only after virtual rank is set.
                pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i)
                post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)
                this_model = model_provider_func(
                    pre_process=pre_process, post_process=post_process, vp_stage=i)
                this_model.model_type = model_type
                this_model.vp_stage = i
                model.append(this_model)
        else:
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            model = model_provider_func(pre_process=pre_process, post_process=post_process)
            model.model_type = model_type
        return model

    if args.init_model_with_meta_device:
        with torch.device('meta'):
            model = build_model()
    else:
        model = build_model()
  • Note the special handling for virtual pipeline parallelism (vp) here: the user-provided model_provider function is called once per vp stage, so a single pp rank ends up holding multiple model chunks produced by the further vp split.
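The pre_process/post_process flags computed per (pp_rank, vp_stage) pair follow the interleaved convention: the first chunk of the first pp rank holds the embedding, and the last chunk of the last pp rank holds the output head. A standalone illustration of that convention (not the mpu implementation itself):

```python
def chunk_flags(pp_rank: int, pp_size: int, vp_stage: int, vp_size: int):
    """Illustrative stand-in for mpu.is_pipeline_first/last_stage with
    ignore_virtual=False: only one chunk in the whole pipeline embeds,
    and only one computes logits/loss."""
    pre_process = (pp_rank == 0) and (vp_stage == 0)
    post_process = (pp_rank == pp_size - 1) and (vp_stage == vp_size - 1)
    return pre_process, post_process

# With pp=2 and vp=2, each pp rank holds two model chunks:
flags = {(r, v): chunk_flags(r, 2, v, 2) for r in range(2) for v in range(2)}
print(flags[(0, 0)])  # (True, False): first chunk of pp rank 0 owns the embedding
print(flags[(1, 1)])  # (False, True): last chunk of pp rank 1 owns the output head
```

All other chunks get (False, False) and neither embed tokens nor compute the loss; they only run transformer layers.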

The user-provided model_provider function is shown below; it mainly builds the GPTModel slice assigned to the current pp stage.

def model_provider(
    pre_process=True, post_process=True, vp_stage: Optional[int] = None
) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If you set use_legacy_models to True, it will return the legacy GPT model; otherwise the mcore GPT model.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
    """
    args = get_args()

    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
        return model_provider_modelopt(pre_process, post_process)

    use_te = args.transformer_impl == "transformer_engine"

    if args.record_memory_history:
        torch.cuda.memory._record_memory_history(
            True,
            # keep 100,000 alloc/free events from before the snapshot
            trace_alloc_max_entries=100000,
            # record stack information for the trace events
            trace_alloc_record_context=True,
        )

        def oom_observer(device, alloc, device_alloc, device_free):
            # snapshot right after an OOM happened
            print('saving allocated state during OOM')
            snapshot = torch.cuda.memory._snapshot()
            from pickle import dump

            dump(
                snapshot,
                open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'),
            )

        torch._C._cuda_attach_out_of_memory_observer(oom_observer)

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:  # using core models
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if args.num_experts:
                # Define the decoder block spec
                transformer_layer_spec = get_gpt_decoder_block_spec(
                    config, use_transformer_engine=use_te, normalization=args.normalization, qk_l2_norm=args.qk_l2_norm, vp_stage=vp_stage
                )
            elif args.heterogeneous_layers_config_path is not None:
                transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
            else:
                # Define the decoder layer spec
                transformer_layer_spec = _get_transformer_layer_spec(use_te, config)
        mtp_block_spec = None
        if args.mtp_num_layers is not None:
            if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0:
                # Get the decoder layer spec explicitly if no decoder layer in the last stage.
                # Only happens with block spec (TransformerBlockSubmodules) when using MoE.
                transformer_layer_spec_for_mtp = _get_transformer_layer_spec(use_te, config)
            else:
                transformer_layer_spec_for_mtp = transformer_layer_spec
            mtp_block_spec = get_gpt_mtp_block_spec(
                config, transformer_layer_spec_for_mtp, use_transformer_engine=use_te, vp_stage=vp_stage
            )

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base,
            rope_scaling=args.use_rope_scaling,
            mtp_block_spec=mtp_block_spec,
            vp_stage=vp_stage,
        )

    return model

When building the model, model_provider constructs a GPTModel, which in turn constructs a TransformerBlock. The TransformerBlock initialization code is shown below:

def __init__(
    self,
    config: TransformerConfig,
    spec: Union[TransformerBlockSubmodules, ModuleSpec],
    post_layer_norm: bool = True,
    pre_process: bool = True,
    post_process: bool = True,
    model_comm_pgs: ModelCommProcessGroups = None,
    vp_stage: Optional[int] = None,
):
    super().__init__(config=config)

    self.submodules = _get_block_submodules(config, spec, vp_stage)
    self.post_layer_norm = post_layer_norm
    self.pre_process = pre_process
    self.post_process = post_process
    self.vp_stage = vp_stage

    # required for pipeline parallel schedules
    self.input_tensor = None

    self.checkpoint_core_attention = (
        self.config.recompute_granularity == 'selective'
        and "core_attn" in self.config.recompute_modules
    )

    if get_cpu_offload_context is not None:
        (self.offload_context, self.group_prefetch_offload_commit_async) = (
            get_cpu_offload_context(
                self.config.cpu_offloading,
                self.config.cpu_offloading_num_layers,
                self.config.num_layers,
                self.config.cpu_offloading_activations,
                self.config.cpu_offloading_weights,
                self.config.cpu_offloading_double_buffering,
            )
        )
        self.config._cpu_offloading_context = (
            self.offload_context if self.config.cpu_offloading else None
        )
    else:
        assert (
            self.config.cpu_offloading is False
        ), "CPU Offloading is enabled when TE is not present"

        self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None
        self.config._cpu_offloading_context = None

    if model_comm_pgs is None:
        model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups()
    self.model_comm_pgs = model_comm_pgs

    self._build_layers()
    self.num_layers_per_pipeline_rank = len(self.layers)

  • It first obtains self.submodules via _get_block_submodules(config, spec, vp_stage).

    • The code of _get_block_submodules is shown below:
    def _get_block_submodules(
        config: TransformerConfig,
        spec: Union[TransformerBlockSubmodules, ModuleSpec],
        vp_stage: Optional[int] = None,
    ) -> TransformerBlockSubmodules:
        """
        Retrieve or construct TransformerBlockSubmodules based on the provided specification.

        Args:
            config (TransformerConfig): Configuration object for the transformer model.
            spec (Union[TransformerBlockSubmodules, ModuleSpec]): Specification for the
                transformer block submodules. Can be either a TransformerBlockSubmodules
                instance or a ModuleSpec.
            vp_stage (Optional[int]): Virtual pipeline stage number.

        Returns:
            TransformerBlockSubmodules: The submodules for the transformer block.
        """

        # Transformer block submodules.
        if isinstance(spec, TransformerBlockSubmodules):
            return spec

        # ModuleSpec here is generally assumed to be for a transformer layer that
        # is implemented in `transformer_layer.py` or if it subclasses
        # `BaseTransformerLayer` from the `transformer_layer.py` file.
        elif isinstance(spec, ModuleSpec):
            if issubclass(spec.module, TransformerBlock):
                return spec.submodules
            elif issubclass(spec.module, BaseTransformerLayer):
                num_layers = get_num_layers_to_build(config, vp_stage)
                return TransformerBlockSubmodules(
                    layer_specs=[spec] * num_layers, layer_norm=LayerNormImpl
                )
            else:
                raise Exception(f"specialize for {spec.module.__name__}.")
        else:
            raise Exception(f"specialize for {type(spec).__name__}.")

    • In the common case it first computes num_layers = get_num_layers_to_build(config, vp_stage), the number of layers this pp rank needs to build.

      • The code of get_num_layers_to_build is shown below:
      def get_num_layers_to_build(config: TransformerConfig, vp_stage: Optional[int] = None) -> int:
          """
          Determine the number of transformer layers to build for the current pipeline stage.

          Args:
              config (TransformerConfig): Configuration object containing transformer model parameters.
              vp_stage (Optional[int]): Virtual pipeline stage number.

          Returns:
              int: The number of layers to be built for the current pipeline stage.
          """
          # If we have a custom PP layout, straightforwardly
          # return the number of decoders in the layout array.
          if config.pipeline_model_parallel_layout is not None:
              return config.pipeline_model_parallel_layout.get_num_layers_to_build(
                  layer_type=LayerType.decoder, vp_stage=vp_stage
              )

          if (
              config.num_layers_in_first_pipeline_stage is not None
              or config.num_layers_in_last_pipeline_stage is not None
          ):
              assert not (
                  config.account_for_embedding_in_pipeline_split
                  or config.account_for_loss_in_pipeline_split
              ), "Does not support standalone embedding stage and standalone loss stage with uneven pp"
              # Number of layers to distribute over rest of pipeline stages
              layers_to_distribute = config.num_layers
              # Number of pipeline stages left for distributing transformer layers
              pipeline_stages_left = parallel_state.get_pipeline_model_parallel_world_size()

              # If the uneven first (last) pipeline stage is enabled, remove the specified number
              # of layers to calculate the number of layers on each middle pipeline stage.
              if config.num_layers_in_first_pipeline_stage is not None:
                  layers_to_distribute -= config.num_layers_in_first_pipeline_stage
                  pipeline_stages_left -= 1

              if config.num_layers_in_last_pipeline_stage is not None:
                  layers_to_distribute -= config.num_layers_in_last_pipeline_stage
                  pipeline_stages_left -= 1

              # If pp_size <= 2, we do not have any intermediate pipeline stages, and we do not
              # need to check if the left over layers are divisible by the left over stages.
              if pipeline_stages_left > 0:
                  assert (
                      layers_to_distribute % pipeline_stages_left == 0
                  ), "With uneven pipelining the left over layers must be divisible by left over stages"
                  num_layers_per_pipeline_rank = layers_to_distribute // pipeline_stages_left
              else:
                  num_layers_per_pipeline_rank = 0

              # If the uneven first (last) pipeline stage is enabled, return the specified number
              # of layers for all virtual pipeline parallel stages within the first (last) pipeline
              # parallel stage.
              if (
                  parallel_state.is_pipeline_first_stage(ignore_virtual=True)
                  and config.num_layers_in_first_pipeline_stage is not None
              ):
                  num_layers_per_pipeline_rank = config.num_layers_in_first_pipeline_stage

              if (
                  parallel_state.is_pipeline_last_stage(ignore_virtual=True)
                  and config.num_layers_in_last_pipeline_stage is not None
              ):
                  num_layers_per_pipeline_rank = config.num_layers_in_last_pipeline_stage
          else:
              # Include the embedding layer and loss layer into pipeline parallelism partition
              num_layers = config.num_layers
              if config.account_for_embedding_in_pipeline_split:
                  num_layers += 1

              if config.account_for_loss_in_pipeline_split:
                  num_layers += 1

              assert (
                  num_layers % config.pipeline_model_parallel_size == 0
              ), "num_layers should be divisible by pipeline_model_parallel_size"
              num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size

          if (
              parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None
              and config.pipeline_model_parallel_size > 1
          ):
              # Interleaved pipeline parallelism:
              # Number of layers in each model chunk is the number of layers in the stage,
              # divided by the number of model chunks in a stage.
              # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
              # layers to stages like (each list is a model chunk):
              # Stage 0: [0] [2] [4] [6]
              # Stage 1: [1] [3] [5] [7]
              # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
              # layers to stages like (each list is a model chunk):
              # Stage 0: [0, 1] [4, 5]
              # Stage 1: [2, 3] [6, 7]
              vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()

              assert (
                  num_layers_per_pipeline_rank % vp_size == 0
              ), f"num_layers_per_pipeline_rank {num_layers_per_pipeline_rank} \
                  should be divisible by vp_size {vp_size}"
              num_layers_per_virtual_stage = num_layers_per_pipeline_rank // vp_size

              num_layers_to_build = num_layers_per_virtual_stage

          else:
              # Non-interleaved pipeline parallelism:
              # Each stage gets a contiguous set of layers.
              num_layers_to_build = num_layers_per_pipeline_rank

          # The embedding (or loss) layer cannot function as a standalone transformer layer.
          # Reduce the number of layers to construct by 1 on the first (or last) stage if the
          # embedding (or loss) layer is included in the pipeline parallelism partition and placement.
          if (
              parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage)
              and config.account_for_embedding_in_pipeline_split
          ):
              num_layers_to_build -= 1
              assert num_layers_to_build >= 0, "Not enough layers in the first virtual pipeline stage"

          if (
              parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage)
              and config.account_for_loss_in_pipeline_split
          ):
              num_layers_to_build -= 1
              assert num_layers_to_build >= 0, "Not enough layers in the last virtual pipeline stage"

          return num_layers_to_build

      • It supports both a user-specified per-stage layer layout and automatic computation, i.e. splitting the layers evenly along the pp dimension. In the automatic path, account_for_embedding_in_pipeline_split and account_for_loss_in_pipeline_split are worth noting: they control whether the embedding or loss layer is counted as one layer when the split is computed. Custom layer counts for the first and last pipeline stages are also supported.

      • With vp enabled, the returned value is further divided along the vp dimension: num_layers_per_pipeline_rank // vp_size.

    • After obtaining num_layers, it constructs and returns TransformerBlockSubmodules(layer_specs=[spec] * num_layers, layer_norm=LayerNormImpl).
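Putting the even-split path of the counting together, a worked example as a pure-Python sketch of the arithmetic (not the upstream function; it omits uneven stages and the embedding/loss accounting):

```python
def layers_to_build(num_layers: int, pp_size: int, vp_size) -> int:
    """Sketch of get_num_layers_to_build's even-split path: divide layers
    evenly across pp ranks, then across vp chunks within each rank."""
    assert num_layers % pp_size == 0, "num_layers should be divisible by pipeline_model_parallel_size"
    per_pp_rank = num_layers // pp_size
    if vp_size is not None and pp_size > 1:
        assert per_pp_rank % vp_size == 0
        return per_pp_rank // vp_size  # layers per model chunk (virtual stage)
    return per_pp_rank

# 8 layers, pp=2, no vp: each pp rank builds 4 contiguous layers.
print(layers_to_build(8, 2, None))  # 4
# 8 layers, pp=2, vp=4: each of the 4 chunks per rank holds a single layer,
# matching the docstring example Stage 0: [0] [2] [4] [6], Stage 1: [1] [3] [5] [7].
print(layers_to_build(8, 2, 4))  # 1
```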

  • It then calls _build_layers.

    • The code of _build_layers is as follows:
    def _build_layers(self):
        # Transformer layers.
        # @jcasper can we improve how we deal with layer_number?
        # currently it's only used in CoreAttention?
        # if self.apply_query_key_layer_scaling:
        #     coeff = self.layer_number
        #     self.norm_factor *= coeff
        def build_layer(layer_spec, layer_number):
            global_layer_number = layer_number + get_transformer_layer_offset(
                self.config, self.vp_stage
            )  # 1-based index
            if self.config.heterogeneous_block_specs:
                layer_config = self.config.get_config_for_layer(global_layer_number)
            else:
                layer_config = self.config

            fp8_init_context = get_fp8_context(layer_config, global_layer_number - 1, is_init=True)
            with fp8_init_context:
                module = build_module(
                    layer_spec,
                    config=layer_config,
                    layer_number=layer_number,
                    model_comm_pgs=self.model_comm_pgs,
                    vp_stage=self.vp_stage,
                )
            return module

        # offset is implicit in TransformerLayer
        self.layers = torch.nn.ModuleList(
            [
                build_layer(layer_spec, i + 1)
                for i, layer_spec in enumerate(self.submodules.layer_specs)
            ]
        )

        # @TODO: add back account_for_embedding_in_pipeline_split (see issue #293)
        # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline
        # self.post_process and self.post_layer_norm guide this behavior
        if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
            self.final_layernorm = build_module(
                self.submodules.layer_norm,
                config=self.config,
                hidden_size=self.config.hidden_size,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.final_layernorm = None  # Either this or nn.Identity

    • Its main job is to actually instantiate each layer, assigning each one its global_layer_number along the way.
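The global_layer_number depends on get_transformer_layer_offset. For the even-split interleaved case, the offset arithmetic can be sketched as follows; this is a simplified reconstruction under stated assumptions (the real function also handles uneven stages and embedding/loss accounting):

```python
def layer_offset(pp_rank: int, pp_size: int, vp_stage: int, vp_size: int, num_layers: int) -> int:
    """Simplified sketch of the interleaved layer offset: each vp stage spans
    num_layers // vp_size consecutive layers across all pp ranks, and within
    a vp stage each pp rank takes one chunk of layers_per_chunk layers."""
    layers_per_chunk = num_layers // (pp_size * vp_size)
    return vp_stage * (num_layers // vp_size) + pp_rank * layers_per_chunk

# 8 layers, pp=2, vp=2 -> chunks: Stage 0: [0, 1] [4, 5], Stage 1: [2, 3] [6, 7]
print(layer_offset(0, 2, 0, 2, 8))  # 0
print(layer_offset(1, 2, 0, 2, 8))  # 2
print(layer_offset(0, 2, 1, 2, 8))  # 4
print(layer_offset(1, 2, 1, 2, 8))  # 6
```

Adding the local 1-based layer_number to this offset yields the global_layer_number used for per-layer configs and fp8 contexts above.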

P2P Communication: P2PCommunicator

Code Flow Overview

P2PCommunicator handles the point-to-point communication of pipeline parallelism; its code is shown below:

class P2PCommunicator:
    """P2P (Point-to-Point) Communicator for pipeline parallelism.

    This class handles communication between pipeline stages by managing
    tensor exchanges between consecutive stages in the pipeline.
    """

    def __init__(self, pp_group: dist.ProcessGroup, config: ModelParallelConfig):
        # Basic attrs
        self.pp_group = pp_group
        self.config = config

        world_size = self.pp_group.size()
        curr_rank_in_pg = self.pp_group.rank()

        next_rank_pg = (curr_rank_in_pg + 1) % world_size
        prev_rank_pg = (curr_rank_in_pg - 1) % world_size

        self.next_rank: int | None = dist.get_global_rank(self.pp_group, next_rank_pg)
        self.prev_rank: int | None = dist.get_global_rank(self.pp_group, prev_rank_pg)
        self.virtual_pipeline_model_parallel_size = (
            config.virtual_pipeline_model_parallel_size
            if config.virtual_pipeline_model_parallel_size is not None
            else None
        )

    def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, recv_next):
        """Communicate tensor shapes between stages. Used to communicate
        tensor shapes before the actual tensor communication happens.
        This is required when the sequence lengths across micro batches
        are not uniform.

        Args:
            tensor_send_next: tensor to send to next rank (no tensor sent if
                set to None).
            tensor_send_prev: tensor to send to prev rank (no tensor sent if
                set to None).
            recv_prev: boolean for whether tensor should be received from
                previous rank.
            recv_next: boolean for whether tensor should be received from
                next rank.
        Returns:
            (recv_prev_shape, recv_next_shape)
        """
        config = self.config
        recv_prev_shape_tensor = None
        recv_next_shape_tensor = None
        send_prev_shape_tensor = None
        send_next_shape_tensor = None
        if recv_prev:
            recv_prev_shape_tensor = torch.empty(
                (3,), device=torch.cuda.current_device(), dtype=torch.int64
            )
        if recv_next:
            recv_next_shape_tensor = torch.empty(
                (3,), device=torch.cuda.current_device(), dtype=torch.int64
            )
        if tensor_send_prev is not None:
            send_prev_shape_tensor = torch.tensor(
                tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64
            )
        if tensor_send_next is not None:
            send_next_shape_tensor = torch.tensor(
                tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64
            )

        if config.use_ring_exchange_p2p:
            torch.distributed.ring_exchange(
                tensor_send_prev=send_prev_shape_tensor,
                tensor_recv_prev=recv_prev_shape_tensor,
                tensor_send_next=send_next_shape_tensor,
                tensor_recv_next=recv_next_shape_tensor,
                group=self.pp_group,
            )
        else:
            ops = []
            if send_prev_shape_tensor is not None:
                send_prev_op = torch.distributed.P2POp(
                    torch.distributed.isend, send_prev_shape_tensor, self.prev_rank
                )
                ops.append(send_prev_op)
            if recv_prev_shape_tensor is not None:
                recv_prev_op = torch.distributed.P2POp(
                    torch.distributed.irecv, recv_prev_shape_tensor, self.prev_rank
                )
                ops.append(recv_prev_op)
            if send_next_shape_tensor is not None:
                send_next_op = torch.distributed.P2POp(
                    torch.distributed.isend, send_next_shape_tensor, self.next_rank
                )
                ops.append(send_next_op)
            if recv_next_shape_tensor is not None:
                recv_next_op = torch.distributed.P2POp(
                    torch.distributed.irecv, recv_next_shape_tensor, self.next_rank
                )
                ops.append(recv_next_op)
            if len(ops) > 0:
                reqs = torch.distributed.batch_isend_irecv(ops)
                for req in reqs:
                    req.wait()

            # To protect against race condition when using batch_isend_irecv().
            # should take this out once the bug with batch_isend_irecv is resolved.
            torch.cuda.synchronize()

        recv_prev_shape = [0, 0, 0]
        if recv_prev_shape_tensor is not None:
            recv_prev_shape = recv_prev_shape_tensor.tolist()

        recv_next_shape = [0, 0, 0]
        if recv_next_shape_tensor is not None:
            recv_next_shape = recv_next_shape_tensor.tolist()

        return recv_prev_shape, recv_next_shape

    def _communicate(
        self,
        *,
        tensor_send_next: Optional[torch.Tensor],
        tensor_send_prev: Optional[torch.Tensor],
        recv_prev: bool,
        recv_next: bool,
        tensor_shape: Shape,
        wait_on_reqs: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Communicate tensors between stages. Used as helper method in other
        communication methods that are used in megatron/schedules.py.

        Args:
            tensor_send_next (torch.Tensor, optional):
                Tensor to send to next rank (no tensor sent if None)

            tensor_send_prev (torch.Tensor, optional):
                Tensor to send to prev rank (no tensor sent if None)

            recv_prev (boolean, required):
                whether tensor should be received from previous rank.

            recv_next (boolean, required):
                whether tensor should be received from next rank.

            tensor_shape (List[int] or torch.Size, required):
                shape of tensor to receive (this method assumes that all
                tensors sent and received in a single function call are
                the same shape).

            wait_on_reqs (boolean, optional, default=True):
                For non-batched p2p communication, wait on each request
                before returning.

        Returns:
            tuple containing

            - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
            - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.

        """

        config = self.config
        tensor_recv_prev_func = None
        tensor_recv_next_func = None

        if not config.variable_seq_lengths:
            recv_prev_shape = tensor_shape
            recv_next_shape = tensor_shape
        else:
            recv_prev_shape, recv_next_shape = self._communicate_shapes(
                tensor_send_next, tensor_send_prev, recv_prev, recv_next
            )

        def create_tensor_recv_prev():
            return torch.empty(
                recv_prev_shape,
                requires_grad=True,
                device=torch.cuda.current_device(),
                dtype=config.pipeline_dtype,
            )

        def create_tensor_recv_next():
            return torch.empty(
                recv_next_shape,
                requires_grad=True,
                device=torch.cuda.current_device(),
                dtype=config.pipeline_dtype,
            )

        if recv_prev:
            if config.pipeline_dtype is None:
                raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
            if tensor_shape is None:
                raise RuntimeError(
                    "tensor_shape must be specified if recv_prev is True. "
                    "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
                )
            tensor_recv_prev_func = create_tensor_recv_prev

        if recv_next:
            if config.pipeline_dtype is None:
                raise RuntimeError("dtype must be provided if recv_next is True")
            if tensor_shape is None:
                raise RuntimeError(
                    "tensor_shape must be specified if recv_next is True. "
                    "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
                )
            tensor_recv_next_func = create_tensor_recv_next

        # Send tensors in both the forward and backward directions as appropriate.
        if config.use_ring_exchange_p2p:

            def _ring_exchange_wrapper(**kwargs):
                torch.distributed.ring_exchange(**kwargs)
                return []

            p2p_func = _ring_exchange_wrapper
        elif config.batch_p2p_comm:
            assert wait_on_reqs
            p2p_func = _batched_p2p_ops
        else:
            p2p_func = _p2p_ops

        pp_group = self.pp_group
        next_rank = self.next_rank
        prev_rank = self.prev_rank

        if config.use_ring_exchange_p2p or config.batch_p2p_comm:
            reqs = []
        else:
            reqs = {}

        tensor_recv_prev = None
        tensor_recv_next = None
        if tensor_recv_prev_func is not None:
            tensor_recv_prev = tensor_recv_prev_func()

        if tensor_recv_next_func is not None:
            tensor_recv_next = tensor_recv_next_func()

        p2p_reqs = p2p_func(
            tensor_send_prev=tensor_send_prev,
            tensor_recv_prev=tensor_recv_prev,
            tensor_send_next=tensor_send_next,
            tensor_recv_next=tensor_recv_next,
            group=pp_group,
            prev_pipeline_rank=prev_rank,
            next_pipeline_rank=next_rank,
        )
        if isinstance(p2p_reqs, list):
            reqs.extend(p2p_reqs)
        else:
            reqs.update(p2p_reqs)

        if wait_on_reqs and len(reqs) > 0:
            for req in reqs if isinstance(reqs, list) else reqs.values():
                req.wait()
            reqs = None

        if config.batch_p2p_comm and config.batch_p2p_sync:
            # To protect against race condition when using batch_isend_irecv().
            # User should assert that we have a modern enough PyTorch to not need this
            torch.cuda.synchronize()

        return tensor_recv_prev, tensor_recv_next, reqs

    @nvtx_decorator()
    def recv_forward(
        self, tensor_shapes, is_first_stage: bool
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Receive tensor from previous rank in pipeline (forward receive)."""
        unwrap_tensor_shapes = False
        if is_single_shape(tensor_shapes):
            unwrap_tensor_shapes = True
            tensor_shapes = [tensor_shapes]
        input_tensors = []
        config = self.config
        for tensor_shape in tensor_shapes:
            if is_first_stage:
                input_tensor = None
            else:
                if config.timers is not None:
                    config.timers('forward-recv', log_level=2).start()
                input_tensor, _, _ = self._communicate(
                    tensor_send_next=None,
                    tensor_send_prev=None,
                    recv_prev=True,
                    recv_next=False,
                    tensor_shape=tensor_shape,
                )
                if config.timers is not None:
                    config.timers('forward-recv').stop()
            input_tensors.append(input_tensor)
        if unwrap_tensor_shapes:
            return input_tensors[0]
        return input_tensors

    @nvtx_decorator()
    def recv_backward(
        self, tensor_shapes, is_last_stage: bool
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Receive tensor from next rank in pipeline (backward receive)."""
        unwrap_tensor_shapes = False
        if is_single_shape(tensor_shapes):
            unwrap_tensor_shapes = True
            tensor_shapes = [tensor_shapes]
        config = self.config
        output_tensor_grads = []
        for tensor_shape in tensor_shapes:
            if is_last_stage:
                output_tensor_grad = None
            else:
                if config.timers is not None:
                    config.timers('backward-recv', log_level=2).start()
                _, output_tensor_grad, _ = self._communicate(
                    tensor_send_next=None,
                    tensor_send_prev=None,
                    recv_prev=False,
                    recv_next=True,
                    tensor_shape=tensor_shape,
                )
                if config.timers is not None:
                    config.timers('backward-recv').stop()
            output_tensor_grads.append(output_tensor_grad)
        if unwrap_tensor_shapes:
            return output_tensor_grads[0]
        return output_tensor_grads

    @nvtx_decorator()
    def send_forward(self, output_tensors, is_last_stage: bool) -> None:
        """Send tensor to next rank in pipeline (forward send)."""
        config = self.config
        if not isinstance(output_tensors, list):
            output_tensors = [output_tensors]

        for output_tensor in output_tensors:
            if not is_last_stage:
                if config.timers is not None:
                    config.timers('forward-send', log_level=2).start()
                self._communicate(
                    tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_prev=False,
                    recv_next=False,
                    tensor_shape=None,
                )
                if config.timers is not None:
                    config.timers('forward-send').stop()

    @nvtx_decorator()
    def send_backward(self, input_tensor_grads, is_first_stage: bool) -> None:
        """Send tensor to previous rank in pipeline (backward send)."""
        if not isinstance(input_tensor_grads, list):
            input_tensor_grads = [input_tensor_grads]
        config = self.config
        for input_tensor_grad in input_tensor_grads:
            if not is_first_stage:
                if config.timers is not None:
                    config.timers('backward-send', log_level=2).start()
                self._communicate(
                    tensor_send_next=None,
                    tensor_send_prev=input_tensor_grad,
                    recv_prev=False,
                    recv_next=False,
                    tensor_shape=None,
                )
                if config.timers is not None:
                    config.timers('backward-send').stop()

    @nvtx_decorator()
    def send_forward_recv_backward(
        self, output_tensors, tensor_shapes, is_last_stage: bool
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Batched send and recv with next rank in pipeline."""
        config = self.config
        unwrap_output_tensors = False
        if not isinstance(output_tensors, list):
            unwrap_output_tensors = True
            output_tensors = [output_tensors]
        if not isinstance(tensor_shapes, list):
            tensor_shapes = [tensor_shapes]
        output_tensor_grads = []
        for output_tensor, tensor_shape in zip(output_tensors, tensor_shapes):
            if is_last_stage:
                output_tensor_grad = None
            else:
                if config.timers is not None:
                    config.timers('forward-send-backward-recv', log_level=2).start()
                _, output_tensor_grad, _ = self._communicate(
                    tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_prev=False,
                    recv_next=True,
                    tensor_shape=tensor_shape,
                )
                if config.timers is not None:
                    config.timers('forward-send-backward-recv').stop()
            output_tensor_grads.append(output_tensor_grad)
        if unwrap_output_tensors:
            return output_tensor_grads[0]
        return output_tensor_grads

    @nvtx_decorator()
    def send_backward_recv_forward(
        self, input_tensor_grads, tensor_shapes, is_first_stage: bool
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Batched send and recv with previous rank in pipeline."""
        config = self.config
        unwrap_input_tensor_grads = False
        if not isinstance(input_tensor_grads, list):
            unwrap_input_tensor_grads = True
            input_tensor_grads = [input_tensor_grads]
        if not isinstance(tensor_shapes, list):
            tensor_shapes = [tensor_shapes]
        input_tensors = []
        for input_tensor_grad, tensor_shape in zip(input_tensor_grads, tensor_shapes):
            if is_first_stage:
                input_tensor = None
            else:
                if config.timers is not None:
                    config.timers('backward-send-forward-recv', log_level=2).start()
                input_tensor, _, _ = self._communicate(
                    tensor_send_next=None,
                    tensor_send_prev=input_tensor_grad,
                    recv_prev=True,
                    recv_next=False,
                    tensor_shape=tensor_shape,
                )
                if config.timers is not None:
                    config.timers('backward-send-forward-recv').stop()
            input_tensors.append(input_tensor)
        if unwrap_input_tensor_grads:
            return input_tensors[0]
        return input_tensors

@nvtx_decorator()
def send_forward_recv_forward(
self,
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
overlap_p2p_comm: bool = False,
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline."""
config = self.config
if config.timers is not None:
config.timers('forward-send-forward-recv', log_level=2).start()
input_tensor, _, wait_handles = self._communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape,
wait_on_reqs=(not overlap_p2p_comm),
)
if config.timers is not None:
config.timers('forward-send-forward-recv').stop()
if overlap_p2p_comm:
return input_tensor, wait_handles
return input_tensor

@nvtx_decorator()
def send_backward_recv_backward(
self,
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
overlap_p2p_comm: bool = False,
) -> torch.Tensor:
"""Batched recv from next rank and send to previous rank in pipeline."""
config = self.config
if config.timers is not None:
config.timers('backward-send-backward-recv', log_level=2).start()
_, output_tensor_grad, wait_handles = self._communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape,
wait_on_reqs=(not overlap_p2p_comm),
)
if config.timers is not None:
config.timers('backward-send-backward-recv').stop()
if overlap_p2p_comm:
return output_tensor_grad, wait_handles
return output_tensor_grad

@nvtx_decorator()
def send_forward_backward_recv_forward_backward(
self,
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
) -> torch.Tensor:
"""Batched send and recv with previous and next ranks in pipeline."""
config = self.config
if config.timers is not None:
config.timers('forward-backward-send-forward-backward-recv', log_level=2).start()
input_tensor, output_tensor_grad, _ = self._communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
)
if config.timers is not None:
config.timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad

  • In PP parallelism, each worker's P2PCommunicator computes its neighboring workers from its own rank at initialization:

    • next_rank_pg = (curr_rank_in_pg + 1) % world_size, then self.next_rank: int | None = dist.get_global_rank(self.pp_group, next_rank_pg)

    • prev_rank_pg = (curr_rank_in_pg - 1) % world_size, then self.prev_rank: int | None = dist.get_global_rank(self.pp_group, prev_rank_pg)
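The wraparound arithmetic above can be sketched in plain Python (pipeline_neighbors is a hypothetical helper for illustration, not part of Megatron-LM):

```python
def pipeline_neighbors(rank_in_group: int, world_size: int):
    """Return (prev_rank, next_rank) inside the PP group, wrapping at both ends."""
    next_rank = (rank_in_group + 1) % world_size
    prev_rank = (rank_in_group - 1) % world_size
    return prev_rank, next_rank
```

For example, rank 0 of a 4-stage pipeline wraps around to prev=3, and rank 3 wraps to next=0; in the real code these in-group ranks are then mapped to global ranks via dist.get_global_rank.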

  • P2PCommunicator's core primitive is _communicate; all the other functions are built on top of it:

    1. First, compute recv_prev_shape and recv_next_shape:

      • If the sequence length is fixed, the shapes are simply the tensor_shape passed in.

      • If the sequence length is variable, _communicate_shapes is called to exchange the shapes to be received.

        • Tensors here are assumed to have shape [S, B, H], where only S varies, so a tensor of 3 ints is enough to describe one shape. Except for the first and last workers, every stage receives from prev and sends to next in the forward pass, and receives from next and sends to prev in the backward pass, so there are 4 shapes to obtain.

        • Two underlying implementations for sending/receiving shapes:

          • If config.use_ring_exchange_p2p:

            • torch.distributed.ring_exchange completes the prev/next bidirectional exchange in one call.
          • Otherwise:

            • Build a list of P2POp(isend/irecv, …)

            • Launch them with batch_isend_irecv(ops)

            • wait() on each req

            • Finally call torch.cuda.synchronize() as extra protection (against a historical race bug)

    2. Depending on whether data is expected from prev or next, allocate the corresponding empty tensors.

    3. Select the p2p mode and obtain p2p_func; there are 3 modes in total:

      1. use_ring_exchange_p2p

        • Uses ring_exchange to perform the prev/next send/recv in one shot.

        • It is written as a wrapper that returns an empty req list (because ring_exchange is a synchronous interface).

        • Suitable when you want minimal code or when ring_exchange is more stable on certain backends.

      2. batch_p2p_comm

        • Uses _batched_p2p_ops:

          • Builds P2POp(isend/irecv, …)

          • Launches them all at once with batch_isend_irecv(ops)

          • Returns reqs: List[Work]

        • Note the assert wait_on_reqs here: batched mode requires synchronous waiting by default (unless another overlap layer is wrapped around it; as seen here it forces the wait, which is the most robust implementation).

      3. Default _p2p_ops (non-batched):

        • Issues torch.distributed.isend/irecv one by one, returning a dict reqs = {"send_next": Work, …}
    4. Call p2p_func to perform the actual data transfer, obtaining p2p_reqs.

    5. If wait_on_reqs is set, wait via req.wait(). It defaults to True, i.e. always wait; only send_forward_recv_forward and send_backward_recv_backward decide based on overlap_p2p_comm: if overlap_p2p_comm is false (no overlap), they still wait; otherwise they skip the wait.

    6. If config.batch_p2p_comm and config.batch_p2p_sync are both set, call torch.cuda.synchronize().

    7. Finally return tensor_recv_prev, tensor_recv_next, reqs.
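The control flow of steps 2–7 can be condensed into a pure-Python sketch with the transport injected (communicate_sketch and fake_p2p are illustrative names, not Megatron-LM API; real buffers are CUDA tensors, not lists):

```python
def communicate_sketch(send_prev, send_next, recv_prev, recv_next,
                       tensor_shape, p2p_func, wait_on_reqs=True):
    """Mirror _communicate's control flow with the transport injected."""
    # Step 2: allocate empty receive buffers only for the expected directions.
    buf_prev = [None] * tensor_shape if recv_prev else None
    buf_next = [None] * tensor_shape if recv_next else None
    # Step 4: launch the actual transfer; p2p_func returns request handles.
    reqs = p2p_func(send_prev, buf_prev, send_next, buf_next)
    # Step 5: by default, wait on every request and drop the handles.
    if wait_on_reqs:
        for req in reqs:
            req()  # stands in for req.wait()
        reqs = None
    # Step 7: hand back received tensors plus any outstanding requests.
    return buf_prev, buf_next, reqs

def fake_p2p(send_prev, buf_prev, send_next, buf_next):
    """Toy transport: fill recv buffers in place, return no-op request handles."""
    if buf_prev is not None:
        buf_prev[:] = [1.0] * len(buf_prev)
    if buf_next is not None:
        buf_next[:] = [2.0] * len(buf_next)
    return [lambda: None]
```

A recv_forward-style call then looks like communicate_sketch(None, None, True, False, shape, fake_p2p): only the prev buffer is allocated and filled, and reqs is None because everything was waited on.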

  • On top of _communicate, several helper functions are built:

    • Single send/receive functions:

      • recv_forward: if not the first stage, obtain and return input_tensor via

      input_tensor, _, _ = self._communicate(tensor_send_next=None, tensor_send_prev=None, recv_prev=True, recv_next=False, tensor_shape=tensor_shape)

      • recv_backward: similar, but with recv_next=True instead.

      • send_forward: if not the last stage, send output_tensor via

      self._communicate(tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=False, recv_next=False, tensor_shape=None)

      • send_backward: similar, but with tensor_send_prev=input_tensor_grad instead.
    • Overlapped send-and-receive functions:

      • send_forward_recv_backward: if not the last stage, send output_tensor and receive output_tensor_grad via

      _, output_tensor_grad, _ = self._communicate(tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=False, recv_next=True, tensor_shape=tensor_shape), then return it.

      • send_backward_recv_forward: similar.

      • send_forward_recv_forward: similar, but note the argument wait_on_reqs=(not overlap_p2p_comm).

      • send_backward_recv_backward: similar, but note the argument wait_on_reqs=(not overlap_p2p_comm).

    • Full-duplex communication:

      • send_forward_backward_recv_forward_backward: sends and receives both forward and backward results through the following code
input_tensor, output_tensor_grad, _ = self._communicate(
    tensor_send_next=output_tensor,
    tensor_send_prev=input_tensor_grad,
    recv_prev=recv_prev,
    recv_next=recv_next,
    tensor_shape=tensor_shape,
)

The 3 P2P communication modes

ring_exchange communication

  • Communication goes directly through torch.distributed.ring_exchange(**kwargs), as shown below:
def _ring_exchange_wrapper(**kwargs):
torch.distributed.ring_exchange(**kwargs)
return []
  • ring_exchange is semantically a "ring neighbor exchange"; it can specify, at the same time:

    • tensor_send_prev / tensor_recv_prev

    • tensor_send_next / tensor_recv_next

    • group=pp_group

  • It describes all four directions in a single call; the backend can usually schedule the communication more efficiently (closer to an "atomic operation" than hand-issuing four isend/irecv calls).

  • It is less prone to deadlock or reordering problems caused by a wrongly written send/receive order (some backend implementations are sensitive to ordering).

  • It fits "simultaneous send and receive" pipeline scenarios well (typically: send_forward_recv_backward, send_backward_recv_forward).

  • However, it depends on the quality of the PyTorch/backend ring_exchange implementation; on some combinations its performance is worse than batched isend/irecv, and it is also harder to overlap.
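The "all four directions in one call" semantics can be emulated for a simulated ring of ranks (a toy model of the data movement, not the real collective; simulate_ring_exchange is a hypothetical name):

```python
def simulate_ring_exchange(send_next_vals, send_prev_vals):
    """Each rank i sends send_next_vals[i] to rank i+1 and send_prev_vals[i]
    to rank i-1 (with wraparound). Returns, for every rank, what it receives
    from its prev neighbor and from its next neighbor."""
    n = len(send_next_vals)
    recv_prev = [send_next_vals[(i - 1) % n] for i in range(n)]  # prev's send_next arrives
    recv_next = [send_prev_vals[(i + 1) % n] for i in range(n)]  # next's send_prev arrives
    return recv_prev, recv_next
```

With forward activations in send_next_vals and backward gradients in send_prev_vals, one call covers exactly the send_forward_recv_backward / send_backward_recv_forward pattern on every rank at once.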

_p2p_ops communication

  • It manually builds the needed isend/irecv calls to complete the P2P communication, as shown below:
def _p2p_ops(
*,
tensor_send_prev: Optional[torch.Tensor],
tensor_recv_prev: Optional[torch.Tensor],
tensor_send_next: Optional[torch.Tensor],
tensor_recv_next: Optional[torch.Tensor],
group: torch.distributed.ProcessGroup,
prev_pipeline_rank: int,
next_pipeline_rank: int,
):
reqs = {}
even_send_odd_recv_group = group
if group.size() == 2 and torch.distributed.get_backend(group) != 'ucc':
# Use the global process group for one of the two p2p communications
# to allow the overlap of the independent communications.
# Using the global process group is compatible because the pipeline-parallel
# communications set the source and destination by global rank.
# The only exception occurs when using the 'ucc' backend.
# Because the global communicator always uses the 'nccl' backend,
# we must ensure the else path is followed for the 'ucc' backend.
even_recv_odd_send_group = torch.distributed.group.WORLD
else:
even_recv_odd_send_group = group

if group.rank() % 2 == 0:
if tensor_send_next is not None:
send_next_req = torch.distributed.isend(
tensor=tensor_send_next, dst=next_pipeline_rank, group=even_send_odd_recv_group
)
reqs["send_next"] = send_next_req

if tensor_recv_prev is not None:
recv_prev_req = torch.distributed.irecv(
tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_recv_odd_send_group
)
reqs["recv_prev"] = recv_prev_req

if tensor_send_prev is not None:
send_prev_req = torch.distributed.isend(
tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_send_odd_recv_group
)
reqs["send_prev"] = send_prev_req

if tensor_recv_next is not None:
recv_next_req = torch.distributed.irecv(
tensor=tensor_recv_next, src=next_pipeline_rank, group=even_recv_odd_send_group
)
reqs["recv_next"] = recv_next_req

else:
if tensor_recv_prev is not None:
recv_prev_req = torch.distributed.irecv(
tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_send_odd_recv_group
)
reqs["recv_prev"] = recv_prev_req

if tensor_send_next is not None:
send_next_req = torch.distributed.isend(
tensor=tensor_send_next, dst=next_pipeline_rank, group=even_recv_odd_send_group
)
reqs["send_next"] = send_next_req

if tensor_recv_next is not None:
recv_next_req = torch.distributed.irecv(
tensor=tensor_recv_next, src=next_pipeline_rank, group=even_send_odd_recv_group
)
reqs["recv_next"] = recv_next_req

if tensor_send_prev is not None:
send_prev_req = torch.distributed.isend(
tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_recv_odd_send_group
)
reqs["send_prev"] = send_prev_req
return reqs

  • When the PP dimension is 2, it applies special handling so that even_recv_odd_send_group becomes the dedicated torch.distributed.group.WORLD.

  • To avoid deadlocks from everyone waiting on, or sending toward, the same direction, as well as congestion or wait chains on some backends when everyone sends in the same direction simultaneously, it uses the parity of the current rank to decide whether send or recv is handled first:

if group.rank() % 2 == 0:
    send_next -> recv_prev -> send_prev -> recv_next
else:
    recv_prev -> send_next -> recv_next -> send_prev
  • The handles of all operations are returned at the end, which makes overlap easier; but since multiple communication calls may be needed, the Python overhead tends to be larger.
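The parity-based ordering can be captured in a small helper (p2p_op_order is hypothetical, for illustration). The key property is that an even rank's send in each direction lines up with its odd neighbor's matching recv, so no pair of neighbors ever blocks on each other:

```python
def p2p_op_order(rank_in_group: int):
    """Order in which _p2p_ops issues its four operations for this rank."""
    if rank_in_group % 2 == 0:
        return ["send_next", "recv_prev", "send_prev", "recv_next"]
    return ["recv_prev", "send_next", "recv_next", "send_prev"]
```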

_batched_p2p_ops communication

The code for _batched_p2p_ops is as follows:

def _batched_p2p_ops(
*,
tensor_send_prev: Optional[torch.Tensor],
tensor_recv_prev: Optional[torch.Tensor],
tensor_send_next: Optional[torch.Tensor],
tensor_recv_next: Optional[torch.Tensor],
group: torch.distributed.ProcessGroup,
prev_pipeline_rank: int,
next_pipeline_rank: int,
):
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next, next_pipeline_rank, group
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group
)
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
else:
reqs = []
return reqs

  • It first builds several P2POp objects, launches them all at once, and finally returns the list of handles for each operation.

  • Note that calling _batched_p2p_ops also comes with the restriction that wait_on_reqs must be set.

  • In addition, torch.cuda.synchronize() is usually enabled as well: some PyTorch versions of batch_isend_irecv had race conditions, so batch_p2p_sync is provided to force synchronization; this hurts performance.

elif config.batch_p2p_comm:
assert wait_on_reqs
p2p_func = _batched_p2p_ops

#...

if wait_on_reqs and len(reqs) > 0:
for req in reqs if isinstance(reqs, list) else reqs.values():
req.wait()
reqs = None

if config.batch_p2p_comm and config.batch_p2p_sync:
# To protect against race condition when using batch_isend_irecv().
# User should assert that we have a modern enough PyTorch to not need this
torch.cuda.synchronize()
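The wait loop above has to handle both handle containers, since _batched_p2p_ops returns a list while _p2p_ops returns a dict. A minimal sketch of just that dispatch (wait_all is a hypothetical name; req() stands in for req.wait()):

```python
def wait_all(reqs):
    """Wait on every handle, whether reqs is a list (_batched_p2p_ops)
    or a dict keyed by direction (_p2p_ops)."""
    if reqs is None or len(reqs) == 0:
        return
    for req in (reqs if isinstance(reqs, list) else reqs.values()):
        req()  # stands in for req.wait()
```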

PP scheduling

In the actual train function, get_forward_backward_func returns the forward_backward_func scheduling function that runs one batch's forward and backward passes. The choice depends mainly on whether PP parallelism is used and, if so, whether it also includes VP partitioning.

def get_forward_backward_func():
"""Retrieves the appropriate forward_backward function given the
configuration of parallel_state.

Returns a function that will perform all of the forward and
backward passes of the model given the pipeline model parallel
world size and virtual pipeline model parallel world size in the
global parallel_state.

Note that if using sequence parallelism, the sequence length component of
the tensor shape is updated to original_sequence_length /
tensor_model_parallel_world_size.

The function returned takes the following arguments:

forward_step_func (required): A function that takes a data
iterator and a model as its arguments and return the model's
forward output and the loss function. The loss function should
take one torch.Tensor and return a torch.Tensor of loss and a
dictionary of string -> torch.Tensor.

A third argument, checkpoint_activations_microbatch, indicates
that the activations for this microbatch should be
checkpointed. A None value for this argument indicates that
the default from the configuration should be used. This is
used when the
num_microbatches_with_partial_activation_checkpoints is used.

For example:

def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])

return loss, {'lm loss': averaged_loss[0]}

def forward_step(data_iterator, model):
data, loss_mask = next(data_iterator)
output = model(data)
return output, partial(loss_func, loss_mask)

forward_backward_func(forward_step_func=forward_step, ...)

data_iterator (required): an iterator over the data, will be
passed as is to forward_step_func. Expected to be a list of
iterators in the case of interleaved pipeline parallelism.

model (required): the actual model. Expected to be a list of modules in the case of interleaved
pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.

num_microbatches (int, required):
The number of microbatches to go through

seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
in the config is True. Otherwise, each microbatch in the current global batch size must use
this sequence length.

micro_batch_size (int, required): The number of sequences in a microbatch.

decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
transformer. This is ignored for a single-stack transformer.

forward_only (optional, default = False): Perform only the forward step

collect_non_loss_data (optional, bool, default=False): TODO

first_val_step (bool, optional): Is the first step of the validation phase. Used by
Transformer Engine modules to only update their fp8 weights only on the first validation
step.

adjust_tensor_shapes_fn (Callable, optional): A function that adjusts the receive and send
tensor shapes. Only applicable in forward_backward_pipelining_without_interleaving for now.
Takes in a list of receive shapes and a list of send shapes and returns the adjusted
respective list of shapes. Thus it is not used in the other forward-backward functions
which have different shape handling.

"""
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
if pipeline_model_parallel_size > 1:
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
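The dispatch reduces to a tiny decision function (a sketch using plain values in place of the parallel_state queries; the return strings name the real functions):

```python
def pick_forward_backward_func(pp_world_size: int, vp_world_size):
    """Mirror get_forward_backward_func's branching."""
    if pp_world_size > 1:
        if vp_world_size is not None:
            return "forward_backward_pipelining_with_interleaving"
        return "forward_backward_pipelining_without_interleaving"
    return "forward_backward_no_pipelining"
```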

The returned forward_backward_func is then called inside train_step, as shown below; forward_backward_func is responsible for scheduling the forward and backward passes of the multiple micro-batches within one batch.

def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func):
"""Single training step."""
args = get_args()
timers = get_timers()

rerun_state_machine = get_rerun_state_machine()
while rerun_state_machine.should_run_forward_backward(data_iterator):
# Set grad to zero.
for model_chunk in model:
model_chunk.zero_grad_buffer()
optimizer.zero_grad()

if has_nvidia_modelopt:
# [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors
adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation(
model, args.seq_length, args.micro_batch_size, args.decoder_seq_length
)
else:
adjust_tensor_shapes_fn = None

# For the mxfp8_param with reuse_grad_buf_for_mxfp8_param_ag and dp_ag_overlap,
# we need to call the _copy_main_params_to_param_buffer() after the grad buffer
# is zeroed by zero_grad_buffer() because param and grad buffer are shared.
if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather:
for optim_instance in optimizer.chained_optimizers:
if isinstance(optim_instance, DistributedOptimizer):
optim_instance._copy_main_params_to_param_buffer()

# Forward pass.
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=False,
adjust_tensor_shapes_fn=adjust_tensor_shapes_fn,
)
should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
if should_exit:
return {}, True, should_checkpoint, should_exit, exit_code, None, None

# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()

# Vision gradients.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

# Update parameters.

timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()

# when freezing sub-models we may have a mixture of successful and unsucessful ranks,
# so we must gather across mp ranks
update_successful = logical_and_across_model_parallel_group(update_successful)
# grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
# so we must gather across mp ranks
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
if args.log_num_zeros_in_grad:
num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)

# Vision momentum.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.update_momentum(args.curr_iteration)

# Update learning rate.
if update_successful:
increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1

# Empty unused memory.
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()

if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Average loss across microbatches.
loss_reduced = {}

for key in losses_reduced[0].keys():
val = [x[key].view(-1) for x in losses_reduced]
if val[0].numel() == 2:
if args.sft:
# in mcore the normalization happens on micro batch instead of global
val = torch.vstack(val)
val = val[:, 0] / val[:, 1]
val = val.mean()
torch.distributed.all_reduce(
val,
group=mpu.get_data_parallel_group(with_context_parallel=True)
)
val /= torch.distributed.get_world_size(
group=mpu.get_data_parallel_group(with_context_parallel=True)
)
loss_reduced[key] = val
else:
# there is one dict per microbatch. in new reporting, we average
# over the total number of tokens across the global batch.
val = torch.vstack(val).sum(dim=0)
torch.distributed.all_reduce(
val,
group=mpu.get_data_parallel_group(with_context_parallel=True)
)
loss_reduced[key] = val[0] / val[1]
elif val[0].numel() == 1:
# legacy behavior, we average over the number of microbatches
val = torch.cat(val).mean()
loss_reduced[key] = val
else:
raise ValueError(f"Invalid value shape: {val[0].shape} for key {key}")
return (
loss_reduced,
skipped_iter,
should_checkpoint,
should_exit,
exit_code,
grad_norm,
num_zeros_in_grad,
)
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
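The numel() == 2 branch above implements token-weighted averaging: each microbatch reports a (loss_sum, num_tokens) pair, the pairs are summed, and the final loss is val[0] / val[1]. A sketch without torch (reduce_microbatch_losses is a hypothetical name):

```python
def reduce_microbatch_losses(per_microbatch):
    """per_microbatch: list of (loss_sum, num_tokens) pairs, one per microbatch.
    Returns the token-weighted average, matching torch.vstack(val).sum(dim=0)
    followed by val[0] / val[1]."""
    total_loss = sum(pair[0] for pair in per_microbatch)
    total_tokens = sum(pair[1] for pair in per_microbatch)
    return total_loss / total_tokens
```

This differs from the legacy numel() == 1 path, which averages over the number of microbatches regardless of how many tokens each contained.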

The forward_step_func argument is typically the user-provided function that runs the forward pass of one micro-batch:

def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = False):
"""Forward training step.

Args:
data_iterator : Input data iterator
model (GPTModel): The GPT Model
return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor
"""
args = get_args()
timers = get_timers()

# Get the batch.
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
timers('batch-generator').stop()

with stimer:
if args.use_legacy_models:
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
else:
if return_schedule_plan:
assert args.overlap_moe_expert_parallel_comm, \
"overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
schedule_plan = model.build_schedule_plan(
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)
return schedule_plan, partial(loss_func, loss_mask, model=model)
else:
output_tensor = model(
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)

# [ModelOpt]: model is needed to access ModelOpt distillation losses
return output_tensor, partial(loss_func, loss_mask, model=model)

forward_backward_pipelining_without_interleaving

We first look at forward_backward_pipelining_without_interleaving, i.e. the variant without VP partitioning and without interleaved 1F1B scheduling.

1F1B theory analysis

The blog post mentioned earlier in the theory section also covers 1F1B; here is a brief recap.

  • The 1F1B flow is shown in the figure below, with a PP degree of 4 and 8 micro-batches per batch.

Note that it differs from the figure in Megatron-LM; judging from the code, the figure below is the accurate one.

  • The 1F1B flow consists of several phases: pipeline warmup → steady-state alternating forward/backward → cooldown, finishing the remaining backwards.

  • After a batch finishes, 1F1B is also responsible for the gradient synchronization under data parallelism.
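The three phases can be enumerated for a single stage with a toy scheduler; the warmup count min(pp_size - rank - 1, num_microbatches) matches the formula in forward_backward_pipelining_without_interleaving, while schedule_1f1b itself is just an illustrative helper:

```python
def schedule_1f1b(pp_size: int, rank: int, num_microbatches: int):
    """Enumerate one stage's ops: warmup forwards, steady 1F1B pairs, cooldown backwards."""
    warmup = min(pp_size - rank - 1, num_microbatches)
    remaining = num_microbatches - warmup
    ops = ["F"] * warmup              # pipeline warmup: forwards only
    for _ in range(remaining):
        ops += ["F", "B"]             # steady state: one forward, one backward
    ops += ["B"] * warmup             # cooldown: drain the outstanding backwards
    return ops
```

The last stage (rank = pp_size - 1) has zero warmup microbatches, so it runs strict F/B alternation from the start; earlier stages keep more forwards in flight and therefore hold more activations.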

Code analysis

The code of forward_backward_pipelining_without_interleaving is as follows:

def forward_backward_pipelining_without_interleaving(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: Optional[int] = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: Optional[bool] = None,
adjust_tensor_shapes_fn: Optional[Callable] = None,
p2p_communicator: Optional[P2PCommunicator] = None,
grad_finalize_pgs: Optional[GradFinalizeProcessGroups] = None,
):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages. Returns dictionary with losses if the last stage, empty dict otherwise."""

if isinstance(model, list):
assert (
len(model) == 1
), "non-interleaved pipeline-parallel schedule does not support model chunking"
model = model[0]
if isinstance(data_iterator, list):
assert (
len(data_iterator) == 1
), "non-interleaved pipeline-parallel schedule does not support model chunking"
data_iterator = data_iterator[0]

config = get_model_config(model)
if config.overlap_p2p_comm:
raise ValueError(
"Non-interleaved pipeline parallelism does not support overlapping p2p communication"
)

if p2p_communicator is None and grad_finalize_pgs is None:
p2p_communicator = P2PCommunicator(
pp_group=parallel_state.get_pipeline_model_parallel_group(), config=config
)
tp_group = parallel_state.get_tensor_model_parallel_group()
cp_group = parallel_state.get_context_parallel_group()
embd_group = parallel_state.get_embedding_group(check_initialized=False)
pos_emb_group = parallel_state.get_position_embedding_group(check_initialized=False)
pp_group = parallel_state.get_pipeline_model_parallel_group()

grad_finalize_pgs = GradFinalizeProcessGroups()
grad_finalize_pgs.tp = tp_group
grad_finalize_pgs.pp = pp_group
grad_finalize_pgs.embd = embd_group
grad_finalize_pgs.pos_embd = pos_emb_group
grad_finalize_pgs.cp = cp_group
grad_finalize_pgs.dp_cp = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)
elif p2p_communicator is not None and grad_finalize_pgs is not None:
model_type = get_model_type(model)
assert model_type != ModelType.encoder_and_decoder, (
"encoder PP stages not yet supported when passing custom process groups. "
"support coming soon!"
)
assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config"
assert hasattr(grad_finalize_pgs, 'tp'), "grad_finalize_pgs must have tp_group"
assert hasattr(grad_finalize_pgs, 'cp'), "grad_finalize_pgs must have cp_group"
assert hasattr(grad_finalize_pgs, 'embd'), (
"grad_finalize_pgs must have a embd. In previous version, it is used default "
"`parallel_state.default_embedding_ranks` to create the process group. "
" If you are using the default process group, please use "
" `parallel_state.get_embedding_group()` "
"If you don't need embd_group, you need to explicitly set it to None."
)
assert hasattr(grad_finalize_pgs, 'pos_embd'), (
"grad_finalize_pgs must have a pos_embd. In previous version, it is used default "
"`parallel_state.default_position_embedding_ranks` to create the process group. "
" If you are using the default process group, please use "
" `parallel_state.get_position_embedding_group()` "
"If you don't need pos_embd_group, you need to explicitly set it to None."
)
assert hasattr(grad_finalize_pgs, 'pp'), "grad_finalize_pgs must have pp_group"
assert hasattr(grad_finalize_pgs, 'dp_cp'), "grad_finalize_pgs must have dp_cp_group"
tp_group = grad_finalize_pgs.tp
cp_group = grad_finalize_pgs.cp
else:
raise ValueError(
"Invalid combination of p2p_communicator, grad_finalize_pgs "
"provide none or provide all the process groups"
)

# Needed only when gradients are finalized in M-Core
if config.finalize_model_grads_func is not None and not forward_only:
embedding_module = clear_embedding_activation_buffer(
config, model, is_pp_last_stage(p2p_communicator.pp_group)
)

if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    # Disable async grad reductions
    no_sync_func = config.no_sync_func
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Compute number of warmup microbatches.
    num_warmup_microbatches = (
        p2p_communicator.pp_group.size() - p2p_communicator.pp_group.rank() - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1

    model_type = get_model_type(model)

    rank = p2p_communicator.pp_group.rank()
    recv_tensor_shapes = get_tensor_shapes(
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
        tp_group=tp_group,
        cp_group=cp_group,
    )
    send_tensor_shapes = get_tensor_shapes(
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
        tp_group=tp_group,
        cp_group=cp_group,
    )
    if adjust_tensor_shapes_fn is not None:
        recv_tensor_shapes, send_tensor_shapes = adjust_tensor_shapes_fn(
            recv_tensor_shapes, send_tensor_shapes
        )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda")

    if not forward_only:
        input_tensors = []
        output_tensors = []
    forward_data_store = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                i % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        input_tensor = p2p_communicator.recv_forward(
            recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group)
        )
        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            cp_group_size=grad_finalize_pgs.cp.size(),
            collect_non_loss_data=collect_non_loss_data,
            checkpoint_activations_microbatch=checkpoint_activations_microbatch,
            is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
            current_microbatch=i,
            is_last_stage=is_pp_last_stage(p2p_communicator.pp_group),
        )
        p2p_communicator.send_forward(output_tensor, is_pp_last_stage(p2p_communicator.pp_group))
        total_num_tokens += num_tokens

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communicator.recv_forward(
            recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group)
        )

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = i == (num_microbatches_remaining - 1)

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                (i + num_warmup_microbatches) % max_outstanding_backprops
            ) >= config.num_microbatches_with_partial_activation_checkpoints
        else:
            checkpoint_activations_microbatch = None

        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            cp_group_size=grad_finalize_pgs.cp.size(),
            collect_non_loss_data=collect_non_loss_data,
            checkpoint_activations_microbatch=checkpoint_activations_microbatch,
            is_first_microbatch=check_first_val_step(
                first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0)
            ),
            current_microbatch=i + num_warmup_microbatches,
            is_last_stage=is_pp_last_stage(p2p_communicator.pp_group),
        )
        total_num_tokens += num_tokens

        if forward_only:
            p2p_communicator.send_forward(
                output_tensor, is_pp_last_stage(p2p_communicator.pp_group)
            )
            if not last_iteration:
                input_tensor = p2p_communicator.recv_forward(
                    recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group)
                )
        else:
            output_tensor_grad = p2p_communicator.send_forward_recv_backward(
                output_tensor, send_tensor_shapes, is_pp_last_stage(p2p_communicator.pp_group)
            )

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            # Enable grad sync for the last microbatch in the batch if the full
            # backward pass completes in the 1F1B stage.
            if num_warmup_microbatches == 0 and last_iteration:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type,
                config,
                p2p_communicator.pp_group.size(),
            )

            if last_iteration:
                input_tensor = None
                p2p_communicator.send_backward(
                    input_tensor_grad, is_pp_first_stage(p2p_communicator.pp_group)
                )
            else:
                input_tensor = p2p_communicator.send_backward_recv_forward(
                    input_tensor_grad,
                    recv_tensor_shapes,
                    is_pp_first_stage(p2p_communicator.pp_group),
                )

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):

            # Enable async grad reduction in the last backward pass
            # Note: If grad sync function is provided, only enable
            # async grad reduction in first pipeline stage. Other
            # pipeline stages do grad reduction during pipeline
            # bubble.
            if i == num_warmup_microbatches - 1:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communicator.recv_backward(
                send_tensor_shapes, is_pp_last_stage(p2p_communicator.pp_group)
            )

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type,
                config,
                pipeline_model_parallel_size=p2p_communicator.pp_group.size(),
            )

            p2p_communicator.send_backward(
                input_tensor_grad, is_pp_first_stage(p2p_communicator.pp_group)
            )

    # Launch any remaining grad reductions.
    if no_sync_context is not None:
        enable_grad_sync()
        if config.grad_sync_func is not None:
            config.grad_sync_func(model.parameters())

    if config.finalize_model_grads_func is not None and not forward_only:

        # If defer_embedding_wgrad_compute is enabled we need to do the
        # weight gradient GEMM's here.
        finish_embedding_wgrad_compute(
            config, embedding_module, is_pp_last_stage(p2p_communicator.pp_group), tp_group
        )

        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism, layernorm all-reduce for sequence parallelism, and
        # embedding all-reduce for pipeline parallelism).
        config.finalize_model_grads_func(
            [model],
            total_num_tokens if config.calculate_per_token_loss else None,
            grad_finalize_pgs=grad_finalize_pgs,
        )

    if config.timers is not None:
        config.timers('forward-backward').stop()

    if (
        hasattr(config, 'enable_cuda_graph')
        and config.enable_cuda_graph
        and config.cuda_graph_scope != "full_iteration"
    ):
        create_cudagraphs()

    return forward_data_store

The overall flow is:

  1. Parameter checks: since virtual-pipeline interleaving is not used here, there must be exactly one model and one data_iterator.

  2. If no p2p_communicator is passed in, build the corresponding P2PCommunicator; likewise, if no grad_finalize_pgs is passed in, build the corresponding GradFinalizeProcessGroups. If they are passed in, verify that they are complete.

  3. If no_sync is enabled (i.e., gradient synchronization is not triggered immediately after each backward pass), build the corresponding no_sync_context. In a pipeline, firing the DP all-reduce right after most backward passes would add extra synchronization and hurt overlap; as seen in the earlier pre_train walkthrough, Megatron-LM only synchronizes gradients on the last micro-batch.
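
The disable_grad_sync / enable_grad_sync bookkeeping, which holds a single no-sync context open across microbatches, can be sketched in plain Python. The no_sync() manager below is a stand-in for DDP's real one; sync_log is only there to make the toggling visible:

```python
import contextlib

sync_log = []          # records when gradient syncing is toggled
no_sync_context = None

@contextlib.contextmanager
def no_sync():
    # Stand-in for DDP's model.no_sync(): while this context is open,
    # backward passes only accumulate gradients locally (no all-reduce).
    sync_log.append("sync disabled")
    yield
    sync_log.append("sync re-enabled")

def disable_grad_sync():
    # Mirror of the schedule's closure: enter the context once and hold it open.
    global no_sync_context
    if no_sync_context is None:
        no_sync_context = no_sync()
        no_sync_context.__enter__()

def enable_grad_sync():
    # Exit the held context so the next backward triggers the all-reduce.
    global no_sync_context
    if no_sync_context is not None:
        no_sync_context.__exit__(None, None, None)
        no_sync_context = None

disable_grad_sync()
# ... backward passes for all but the last microbatch would run here ...
enable_grad_sync()
print(sync_log)  # ['sync disabled', 'sync re-enabled']
```

The key design point is that the context is entered once and kept open for the whole batch, rather than wrapping each backward pass individually.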

  4. Compute the number of warmup / steady-state microbatches:

    • num_warmup_microbatches = min(pp_size - pp_rank - 1, num_microbatches). Taking the earlier figure as an example, the first worker (pp_rank=0) gets num_warmup_microbatches=3, while the last worker (pp_rank=3) gets num_warmup_microbatches=0.

    • num_microbatches_remaining = num_microbatches - num_warmup_microbatches. Again with that figure, the first worker (pp_rank=0) gets num_microbatches_remaining=5 and the last worker (pp_rank=3) gets num_microbatches_remaining=8.
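
A quick sanity check of these two formulas for the pp_size=4, 8-microbatch example (plain Python, not Megatron code):

```python
pp_size = 4
num_microbatches = 8

phases = []
for pp_rank in range(pp_size):
    # The same two formulas the schedule uses.
    num_warmup = min(pp_size - pp_rank - 1, num_microbatches)
    num_remaining = num_microbatches - num_warmup
    phases.append((pp_rank, num_warmup, num_remaining))

print(phases)  # [(0, 3, 5), (1, 2, 6), (2, 1, 7), (3, 0, 8)]
```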

  5. This relates to the selective activation recomputation method from the earlier post (https://slipegg.github.io/2025/12/08/Reducing-Activation-Paper-Note/): if partial activation checkpointing is configured to cut peak activation memory, only part of the layers' activations are checkpointed. When it is configured, `max_outstanding_backprops = num_warmup_microbatches + 1` is set.

  6. Derive the shape of the tensors transferred between PP stages:

    1. The relevant code is shown below.

    2. Basically, if context parallelism is used, effective_seq_length = effective_seq_length // cp_group.size(); if sequence parallelism is also used, then further effective_seq_length = effective_seq_length // tp_group.size().

    3. The derived shape is [effective_seq_length, micro_batch_size, hidden_size].

    Here you can also see that context parallelism is an outright split of the sequence, while sequence parallelism is a further split on top of context parallelism.

    def get_tensor_shapes(
        *,
        seq_length: int,
        micro_batch_size: int,
        decoder_seq_length: int,
        config,
        tp_group: torch.distributed.ProcessGroup,
        cp_group: torch.distributed.ProcessGroup,
    ):
        """
        Determine right tensor sizes (based on position of rank with respect to split rank) and
        model size.
        """

        tensor_shapes = []
        # Use decoder_seq_length if provided, otherwise use seq_length
        effective_seq_length = decoder_seq_length if decoder_seq_length is not None else seq_length
        effective_seq_length = effective_seq_length // cp_group.size()

        if config.sequence_parallel:
            effective_seq_length = effective_seq_length // tp_group.size()

        tensor_shapes.append((effective_seq_length, micro_batch_size, config.hidden_size))
        return tensor_shapes
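
    To make the shape derivation concrete, here is a standalone re-implementation with hypothetical sizes (seq_length=2048, micro_batch_size=2, hidden_size=1024, cp_size=2, tp_size=2, sequence parallelism on); p2p_tensor_shape is my own helper, not a Megatron function:

```python
def p2p_tensor_shape(seq_length, micro_batch_size, hidden_size,
                     cp_size, tp_size, sequence_parallel):
    # Context parallelism splits the sequence first ...
    effective_seq_length = seq_length // cp_size
    # ... and sequence parallelism splits the result again across TP ranks.
    if sequence_parallel:
        effective_seq_length = effective_seq_length // tp_size
    return (effective_seq_length, micro_batch_size, hidden_size)

print(p2p_tensor_shape(2048, 2, 1024, cp_size=2, tp_size=2, sequence_parallel=True))
# (512, 2, 1024)
```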
  7. If not in forward_only mode, initialize the input_tensors and output_tensors lists. They hold the activations transferred across stages (each stage's inputs and outputs), which must be retrieved later to compute gradients during backward; they are the core caches of the pipeline schedule.

  8. Enter the warmup phase. Each worker runs num_warmup_microbatches iterations, where each iteration:

    • receives the input_tensor sent by the previous PP stage via p2p_communicator.recv_forward (skipped automatically on the first stage);

    • runs the local model's forward pass on the received input (or, on the first stage, on data fetched from the data loader) to produce output_tensor;

      1. In forward_step, the key calls set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor") and set_input_tensor(input_tensor) store the received input_tensor into the model as its input_tensor.
      def forward_step(
          forward_step_func,
          data_iterator,
          model,
          num_microbatches,
          input_tensor,
          forward_data_store,
          config,
          cp_group_size,
          collect_non_loss_data=False,
          checkpoint_activations_microbatch=None,
          is_first_microbatch=False,
          current_microbatch=None,
          vp_stage=None,
          is_last_stage=True,
      ):
          """Forward step for passed-in model.

          If it is the first stage, the input tensor is obtained from the data_iterator.
          Otherwise, the passed-in input_tensor is used.

          Args:
              forward_step_func (callable):
                  The forward step function for the model that takes the
                  data iterator as the first argument, and model as the second.
                  This user's forward step is expected to output a tuple of two elements:

                      1. The output object from the forward step. This output object needs to be a
                          tensor or some kind of collection of tensors. The only hard requirement
                          for this object is that it needs to be acceptible as input into the second
                          function.
                      2. A function to reduce (optionally) the output from the forward step. This
                          could be a reduction over the loss from the model, it could be a function that
                          grabs the output from the model and reformats, it could be a function that just
                          passes through the model output. This function must have one of the following
                          patterns, and depending on the pattern different things happen internally:

                              a. A tuple of reduced loss and some other data. Note that in this case
                                  the first argument is divided by the number of global microbatches,
                                  assuming it is a loss, so that the loss is stable as a function of
                                  the number of devices the step is split across.
                              b. A triple of reduced loss, number of tokens, and some other data. This
                                  is similar to case (a), but the loss is further averaged across the
                                  number of tokens in the batch. If the user is not already averaging
                                  across the number of tokens, this pattern is useful to use.
                              c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
                                  of tensors, etc in the case of inference). To trigger case 3 you need
                                  to specify `collect_non_loss_data=True` and you may also want to
                                  specify `forward_only=True` in the call to the parent forward_backward
                                  function.
              data_iterator (iterator):
                  The data iterator.
              model (nn.Module):
                  The model to perform the forward step on.
              num_microbatches (int):
                  The number of microbatches.
              input_tensor (Tensor or list[Tensor]):
                  The input tensor(s) for the forward step.
              forward_data_store (list):
                  The list to store the forward data. If you go down path 2.a or
                  2.b for the return of your forward reduction function then this will store only the
                  final dimension of the output, for example the metadata output by the loss function.
                  If you go down the path of 2.c then this will store the entire output of the forward
                  reduction function applied to the model output.
              config (object):
                  The configuration object.
              collect_non_loss_data (bool, optional):
                  Whether to collect non-loss data. Defaults to False.
                  This is the path to use if you want to collect arbitrary output from the model forward,
                  such as with inference use cases. Defaults to False.
              checkpoint_activations_microbatch (int, optional):
                  The microbatch to checkpoint activations.
                  Defaults to None.
              is_first_microbatch (bool, optional):
                  Whether it is the first microbatch. Defaults to False.
              current_microbatch (int, optional):
                  The current microbatch. Defaults to None.
              vp_stage (int, optional):
                  The virtual pipeline stage. Defaults to None.
              is_last_stage (bool, optional):
                  Whether it is the last stage. Defaults to True.
                  Also considering virtual stages.
                  In case of PP/VPP, is_last_stage/is_vp_last_stage.

          Returns:
              Tensor or list[Tensor]: The output object(s) from the forward step.
              Tensor: The number of tokens.
          """
          from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

          if config.timers is not None:
              config.timers('forward-compute', log_level=2).start()

          if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
              model.set_is_first_microbatch()
          if current_microbatch is not None:
              set_current_microbatch(model, current_microbatch)

          unwrap_output_tensor = False
          if not isinstance(input_tensor, list):
              input_tensor = [input_tensor]
              unwrap_output_tensor = True

          set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
          set_input_tensor(input_tensor)

          if config.enable_autocast:
              context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
          else:
              context_manager = contextlib.nullcontext()
          with context_manager:
              if checkpoint_activations_microbatch is None:
                  output_tensor, loss_func = forward_step_func(data_iterator, model)
              else:
                  output_tensor, loss_func = forward_step_func(
                      data_iterator, model, checkpoint_activations_microbatch
                  )
          output_tensor, num_tokens = forward_step_calc_loss(
              model,
              output_tensor,
              loss_func,
              config,
              vp_stage,
              collect_non_loss_data,
              num_microbatches,
              forward_data_store,
              cp_group_size,
              is_last_stage,
          )

          if unwrap_output_tensor:
              return output_tensor, num_tokens
          return [output_tensor], num_tokens

      • As the TransformerBlock forward code below shows, when pre_process is False (i.e., this is not the first stage), the block reads its own input_tensor directly; on the first stage it instead consumes the data fetched externally from the data loader.
      class TransformerBlock(MegatronModule):
          """Transformer class."""

          def forward(
              self,
              hidden_states: Union[Tensor, WrappedTensor],
              attention_mask: Optional[Tensor],
              context: Optional[Tensor] = None,
              context_mask: Optional[Tensor] = None,
              rotary_pos_emb: Optional[Tensor] = None,
              rotary_pos_cos: Optional[Tensor] = None,
              rotary_pos_sin: Optional[Tensor] = None,
              attention_bias: Optional[Tensor] = None,
              inference_context: Optional[BaseInferenceContext] = None,
              packed_seq_params: Optional[PackedSeqParams] = None,
              sequence_len_offset: Optional[Tensor] = None,
              *,
              inference_params: Optional[BaseInferenceContext] = None,
          ):
              """
              Perform the forward pass through the transformer block.

              This method handles the core computation of the transformer, including
              self-attention, optional cross-attention, and feed-forward operations.

              Args:
                  hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h]
                      where s is the sequence length, b is the batch size, and h is the hidden size.
                      Can be passed as a WrappedTensor during inference to avoid an obsolete
                      reference in the calling function.
                  attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
                      self-attention.
                  context (Tensor, optional): Context tensor for cross-attention.
                  context_mask (Tensor, optional): Mask for cross-attention context
                  rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
                  attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable
                      to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].
                      Used as an alternative to apply attention mask for TE cuDNN attention.
                  inference_context (BaseInferenceContext, optional): Parameters for inference-time
                      optimizations.
                  packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence
                      processing.

              Returns:
                  Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
                  [s, b, h], and optionally the updated context tensor if cross-attention is used.
              """

              inference_context = deprecate_inference_params(inference_context, inference_params)

              # Delete the obsolete reference to the initial input tensor if necessary
              if isinstance(hidden_states, WrappedTensor):
                  hidden_states = hidden_states.unwrap()

              if not self.pre_process:
                  # See set_input_tensor()
                  hidden_states = self.input_tensor
    • sends the result output_tensor to the next stage via p2p_communicator.send_forward (skipped automatically on the last stage);

    • if not in forward_only mode, appends input_tensor and output_tensor to the lists, also freeing output_tensor's storage as needed (deallocate_output_tensor).

  9. If num_microbatches_remaining is greater than 0, first receive the input_tensor sent by the previous PP stage via p2p_communicator.recv_forward (skipped automatically on the first stage). If there are too few micro-batches, there may be no steady-state phase at all.

  10. Enter the steady-state (1F1B) phase:

    • Run the local model's forward pass on the received input (or on data fetched from the data loader) to produce output_tensor.

    • In forward_only mode:

      1. Send output_tensor to the next PP stage via p2p_communicator.send_forward.

      2. If this is not the last iteration, receive the next input_tensor from the previous PP stage via p2p_communicator.recv_forward (skipped automatically on the first stage). Overall this mirrors the warmup phase.

    • Otherwise (training mode):

      1. Send output_tensor and receive the backward result output_tensor_grad via p2p_communicator.send_forward_recv_backward.

      2. Append input_tensor and output_tensor to the lists, freeing output_tensor's storage as needed.

      3. Only on the last pipeline stage (where num_warmup_microbatches == 0) and on the last steady-state iteration is grad sync enabled, because at that point the batch is complete and gradient synchronization is actually needed.

      4. Pop the earliest input_tensor and output_tensor from the head of the lists and, together with the received output_tensor_grad, run the backward pass to obtain input_tensor_grad.

      5. If this is not the last steady-state iteration, send input_tensor_grad and receive a new input via p2p_communicator.send_backward_recv_forward; on the last steady-state iteration no new input is needed, so just send input_tensor_grad via p2p_communicator.send_backward.

  11. Enter the cooldown phase, which runs the same number of iterations as warmup, i.e. num_warmup_microbatches:

    • On the last cooldown iteration, enable grad sync.

    • Pop the earliest input_tensor and output_tensor from the head of the lists, and receive output_tensor_grad via p2p_communicator.recv_backward.

    • Run the backward pass with input_tensor, output_tensor, and output_tensor_grad to obtain input_tensor_grad.

    • Send input_tensor_grad to the previous PP stage via p2p_communicator.send_backward.

  12. If config.finalize_model_grads_func is defined, invoke it. It performs the final gradient collection and synchronization, typically including:

    • the full grad all-reduce / reduce-scatter for data parallelism

    • the layernorm all-reduce under sequence parallelism

    • the embedding all-reduce (embedding group) under pipeline parallelism

  13. Finally, if CUDA graphs are enabled and the scope is not full_iteration, create_cudagraphs() is called.
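
Putting warmup, steady state, and cooldown together, the per-rank schedule can be simulated in a few lines (my own sketch, independent of Megatron). It also checks the point above about activation caching: the queue never holds more than num_warmup_microbatches + 1 entries:

```python
def one_f_one_b_events(pp_rank, pp_size, num_microbatches):
    """Emit ('F', i) / ('B', i) events for one rank under the 1F1B schedule."""
    num_warmup = min(pp_size - pp_rank - 1, num_microbatches)
    num_remaining = num_microbatches - num_warmup
    events, queue, max_depth = [], [], 0  # queue models input_/output_tensors

    for i in range(num_warmup):           # warmup: forward passes only
        events.append(('F', i))
        queue.append(i)
        max_depth = max(max_depth, len(queue))
    for i in range(num_remaining):        # steady state: one forward, one backward
        events.append(('F', num_warmup + i))
        queue.append(num_warmup + i)
        max_depth = max(max_depth, len(queue))
        events.append(('B', queue.pop(0)))
    for _ in range(num_warmup):           # cooldown: drain the remaining backwards
        events.append(('B', queue.pop(0)))
    return events, max_depth

# Last stage (rank 3 of 4) alternates F/B from the start; rank 0 warms up with 3 forwards.
events3, depth3 = one_f_one_b_events(3, 4, 8)
events0, depth0 = one_f_one_b_events(0, 4, 8)
print(events3[:4])  # [('F', 0), ('B', 0), ('F', 1), ('B', 1)]
print(events0[:4])  # [('F', 0), ('F', 1), ('F', 2), ('F', 3)]
print(depth0)       # 4, i.e. num_warmup + 1 cached activations at the peak
```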

1F1B Pipeline-Parallel Experiment

The experiment uses the GPT-3 857M model; the launch script is shown below. Note that the pipeline-parallel size is set to 4, the global batch size to 16, and the micro batch size to 2, so one batch contains 8 micro-batches, which matches the situation in the earlier figure.

#!/bin/bash

# Runs the "857m" parameter model

export CUDA_DEVICE_MAX_CONNECTIONS=1

GPUS_PER_NODE=4
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NUM_NODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))

CHECKPOINT_PATH=$1 #<Specify path>
TENSORBOARD_LOGS_PATH=$2 #<Specify path>
VOCAB_FILE=$3 #<Specify path to file>/gpt2-vocab.json
MERGE_FILE=$4 #<Specify path to file>/gpt2-merges.txt
DATA_PATH=$5 #<Specify path and file prefix>_text_document
USE_NSYS=0
if [[ ${6:-} == "--nsys" ]]; then
USE_NSYS=1
fi

DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NUM_NODES
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)

GPT_MODEL_ARGS=(
--num-layers 24
--hidden-size 1024
--num-attention-heads 16
--seq-length 2048
--max-position-embeddings 2048
--attention-backend auto # Can use (flash/fused/unfused/local)
)

TRAINING_ARGS=(
--micro-batch-size 2
--global-batch-size 16
# --rampup-batch-size 16 16 5859375
--train-iters 20000
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.95
--init-method-std 0.006
--clip-grad 1.0
--fp16
--lr 6.0e-5
--lr-decay-style cosine
--min-lr 6.0e-6
--lr-warmup-fraction .001
--lr-decay-iters 20000
)

MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 4
)

DATA_ARGS=(
--data-path $DATA_PATH
--vocab-file $VOCAB_FILE
--merge-file $MERGE_FILE
--split 949,50,1
)

EVAL_AND_LOGGING_ARGS=(
--log-interval 200
--save-interval 10000
--eval-interval 1000
--save $CHECKPOINT_PATH
--load $CHECKPOINT_PATH
--eval-iters 10
--tensorboard-dir $TENSORBOARD_LOGS_PATH
)

PROFILER_ARGS=(
--profile
--use-pytorch-profiler
--profile-step-start 110
--profile-step-end 112
--profile-ranks 0
)

# Build command as an array (no string concatenation)
CMD=(
torchrun
"${DISTRIBUTED_ARGS[@]}"
pretrain_gpt.py
"${GPT_MODEL_ARGS[@]}"
"${TRAINING_ARGS[@]}"
"${MODEL_PARALLEL_ARGS[@]}"
"${DATA_ARGS[@]}"
"${EVAL_AND_LOGGING_ARGS[@]}"
"${PROFILER_ARGS[@]}"
)

if [[ "$USE_NSYS" -eq 1 ]]; then
NSIGHT_PREFIX="./nsight_profile/gpt3_857m"
echo "Running with Nsight profiling, output prefix: ${NSIGHT_PREFIX}"
exec nsys profile \
-s none -t nvtx,cuda \
--cudabacktrace=all \
--cuda-graph-trace=node \
--python-backtrace=cuda \
--wait all \
-o "${NSIGHT_PREFIX}" \
--force-overwrite true \
--capture-range=cudaProfilerApi \
--capture-range-end=stop \
"${CMD[@]}"
else
exec "${CMD[@]}"
fi

The run command is:

1
bash examples/gpt3/train_gpt3_857m_distributed.sh     /workspace/megatron-lm/model_ckpt/gpt3_857m_pp4     /workspace/megatron-lm/tb_logs/gpt3_857m_profiler_pp4     /workspace/megatron-lm/data/tokenizer/gpt2-vocab.json     /workspace/megatron-lm/data/tokenizer/gpt2-merges.txt     /workspace/megatron-lm/data/TinyStoriesV2-GPT4-train_text_document     > gpt3_857m_pp4.log 2>&1 &

Part of the resulting run log is shown below:

W0108 08:56:58.737000 2180992 torch/distributed/run.py:766] 
W0108 08:56:58.737000 2180992 torch/distributed/run.py:766] *****************************************
W0108 08:56:58.737000 2180992 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0108 08:56:58.737000 2180992 torch/distributed/run.py:766] *****************************************
using world size: 4, data-parallel size: 1, context-parallel size: 1, hierarchical context-parallel sizes: None, tensor-model-parallel size: 1, pipeline-model-parallel size: 4
WARNING: Setting args.overlap_p2p_comm and args.align_param_gather to False since non-interleaved schedule does not support overlapping p2p communication and aligned param AG
Number of virtual stages per pipeline stage: None
WARNING: Setting args.check_for_nan_in_loss_and_grad to False since dynamic loss scaling is being used
using torch.float16 for parameters ...
------------------------ arguments ------------------------
account_for_embedding_in_pipeline_split ......... False
account_for_loss_in_pipeline_split .............. False
accumulate_allreduce_grads_in_fp32 .............. False
adam_beta1 ...................................... 0.9
adam_beta2 ...................................... 0.95
adam_eps ........................................ 1e-08
add_bias_linear ................................. True
add_position_embedding .......................... True
add_qkv_bias .................................... True
adlr_autoresume ................................. False
adlr_autoresume_interval ........................ 1000
align_grad_reduce ............................... True
align_param_gather .............................. False
app_tag_run_name ................................ None
app_tag_run_version ............................. 0.0.0
apply_layernorm_1p .............................. False
apply_query_key_layer_scaling ................... False
apply_residual_connection_post_layernorm ........ False
apply_rope_fusion ............................... False
async_save ...................................... None
async_tensor_model_parallel_allreduce ........... True
attention_backend ............................... AttnBackend.auto
attention_dropout ............................... 0.1
attention_softmax_in_fp32 ....................... False
auto_detect_ckpt_format ......................... False
barrier_with_L1_time ............................ True
bert_binary_head ................................ True
bert_embedder_type .............................. megatron
bert_load ....................................... None
bf16 ............................................ False
bias_dropout_fusion ............................. True
bias_gelu_fusion ................................ True
bias_swiglu_fusion .............................. True
biencoder_projection_dim ........................ 0
biencoder_shared_query_context_model ............ False
block_data_path ................................. None
cache_mla_latents ............................... False
calc_ft_timeouts ................................ False
calculate_per_token_loss ........................ False
check_for_large_grads ........................... False
check_for_nan_in_loss_and_grad .................. False
check_for_spiky_loss ............................ False
check_weight_hash_across_dp_replicas_interval ... None
ckpt_assume_constant_structure .................. False
ckpt_convert_format ............................. None
ckpt_convert_save ............................... None
ckpt_convert_update_legacy_dist_opt_format ...... False
ckpt_format ..................................... torch_dist
ckpt_fully_parallel_load ........................ False
ckpt_fully_parallel_save ........................ True
ckpt_fully_parallel_save_deprecated ............. False
ckpt_step ....................................... None
classes_fraction ................................ 1.0
clip_grad ....................................... 1.0
clone_scatter_output_in_embedding ............... True
config_logger_dir ...............................
consumed_train_samples .......................... 0
consumed_valid_samples .......................... 0
context_parallel_size ........................... 1
cp_comm_type .................................... ['p2p']
create_attention_mask_in_dataloader ............. True
cross_entropy_fusion_impl ....................... native
cross_entropy_loss_fusion ....................... False
cuda_graph_scope ................................ full
cuda_graph_warmup_steps ......................... 3
data_args_path .................................. None
data_cache_path ................................. None
data_parallel_random_init ....................... False
data_parallel_sharding_strategy ................. no_shard
data_parallel_size .............................. 1
data_path ....................................... ['/workspace/megatron-lm/data/TinyStoriesV2-GPT4-train_text_document']
data_per_class_fraction ......................... 1.0
data_sharding ................................... True
dataloader_type ................................. single
ddp_average_in_collective ....................... False
ddp_bucket_size ................................. None
ddp_num_buckets ................................. None
ddp_pad_buckets_for_high_nccl_busbw ............. False
decoder_first_pipeline_num_layers ............... None
decoder_last_pipeline_num_layers ................ None
decoder_num_layers .............................. None
decoder_seq_length .............................. None
decoupled_lr .................................... None
decoupled_min_lr ................................ None
decrease_batch_size_if_needed ................... False
defer_embedding_wgrad_compute ................... False
delay_wgrad_compute ............................. False
deprecated_use_mcore_models ..................... False
deterministic_mode .............................. False
dino_bottleneck_size ............................ 256
dino_freeze_last_layer .......................... 1
dino_head_hidden_size ........................... 2048
dino_local_crops_number ......................... 10
dino_local_img_size ............................. 96
dino_norm_last_layer ............................ False
dino_teacher_temp ............................... 0.07
dino_warmup_teacher_temp ........................ 0.04
dino_warmup_teacher_temp_epochs ................. 30
disable_bf16_reduced_precision_matmul ........... False
disable_mamba_mem_eff_path ...................... False
disable_straggler_on_startup .................... False
dist_ckpt_format_deprecated ..................... None
dist_ckpt_strictness ............................ assume_ok_unexpected
distribute_saved_activations .................... False
distributed_backend ............................. nccl
distributed_timeout_minutes ..................... 10
embedding_init_method_std ....................... None
embedding_path .................................. None
empty_unused_memory_level ....................... 0
enable_cuda_graph ............................... False
enable_experimental ............................. False
enable_ft_package ............................... False
enable_full_sharding_in_hsdp .................... False
enable_gloo_process_groups ...................... True
enable_msc ...................................... True
enable_one_logger ............................... True
encoder_num_layers .............................. 24
encoder_seq_length .............................. 2048
end_weight_decay ................................ 0.1
eod_mask_loss ................................... False
error_injection_rate ............................ 0
error_injection_type ............................ transient_error
eval_interval ................................... 1000
eval_iters ...................................... 10
evidence_data_path .............................. None
exit_duration_in_mins ........................... None
exit_interval ................................... None
exit_on_missing_checkpoint ...................... False
exit_signal_handler ............................. False
exp_avg_dtype ................................... torch.float32
exp_avg_sq_dtype ................................ torch.float32
expert_model_parallel_size ...................... 1
expert_tensor_parallel_size ..................... 1
export_force_local_attention .................... False
export_kd_cfg ................................... None
export_kd_teacher_ckpt_format ................... None
export_kd_teacher_load .......................... None
export_kv_cache_quant ........................... False
export_legacy_megatron .......................... False
export_model_type ............................... GPTModel
export_moe_apply_probs_on_input ................. False
export_qk_l2_norm ............................... False
export_quant_cfg ................................ None
export_real_quant_cfg ........................... None
export_te_mcore_model ........................... False
external_cuda_graph ............................. False
ffn_hidden_size ................................. 4096
finetune ........................................ False
finetune_data_split ............................. train
finetune_hf_dataset ............................. None
first_last_layers_bf16 .......................... False
flash_decode .................................... False
fp16 ............................................ True
fp16_lm_cross_entropy ........................... False
fp32_residual_connection ........................ False
fp8 ............................................. None
fp8_amax_compute_algo ........................... most_recent
fp8_amax_history_len ............................ 1
fp8_interval .................................... 1
fp8_margin ...................................... 0
fp8_param_gather ................................ False
fp8_recipe ...................................... delayed
fp8_wgrad ....................................... True
fsdp_double_buffer .............................. False
full_validation ................................. False
global_batch_size ............................... 16
grad_reduce_in_bf16 ............................. False
gradient_accumulation_fusion .................... True
gradient_reduce_div_fusion ...................... True
group_query_attention ........................... False
head_lr_mult .................................... 1.0
heterogeneous_layers_config_encoded_json ........ None
heterogeneous_layers_config_path ................ None
hidden_dropout .................................. 0.1
hidden_size ..................................... 1024
hierarchical_context_parallel_sizes ............. None
high_priority_stream_groups ..................... []
hybrid_attention_ratio .......................... 0.0
hybrid_mlp_ratio ................................ 0.0
hybrid_override_pattern ......................... None
hysteresis ...................................... 2
ict_head_size ................................... None
ict_load ........................................ None
img_h ........................................... 224
img_w ........................................... 224
indexer_batch_size .............................. 128
indexer_log_interval ............................ 1000
inference_batch_times_seqlen_threshold .......... -1
inference_dynamic_batching ...................... False
inference_dynamic_batching_buffer_guaranteed_fraction 0.2
inference_dynamic_batching_buffer_overflow_factor None
inference_dynamic_batching_buffer_size_gb ....... 40.0
inference_dynamic_batching_chunk_size ........... 256
inference_dynamic_batching_max_requests_override None
inference_dynamic_batching_max_tokens_override .. None
inference_dynamic_batching_num_cuda_graphs ...... 16
inference_max_batch_size ........................ 8
inference_max_seq_length ........................ 2560
inference_rng_tracker ........................... False
init_method_std ................................. 0.006
init_method_xavier_uniform ...................... False
init_model_with_meta_device ..................... False
initial_loss_scale .............................. 4294967296
inprocess_active_world_size ..................... 4
inprocess_barrier_timeout ....................... 120
inprocess_completion_timeout .................... 120
inprocess_empty_cuda_cache ...................... False
inprocess_granularity ........................... node
inprocess_hard_timeout .......................... 90
inprocess_heartbeat_interval .................... 30
inprocess_heartbeat_timeout ..................... 60
inprocess_last_call_wait ........................ 1
inprocess_max_iterations ........................ None
inprocess_monitor_process_interval .............. 1.0
inprocess_monitor_thread_interval ............... 1.0
inprocess_progress_watchdog_interval ............ 1.0
inprocess_restart ............................... False
inprocess_soft_timeout .......................... 60
inprocess_termination_grace_time ................ 1
is_hybrid_model ................................. False
iter_per_epoch .................................. 1250
iterations_to_skip .............................. []
keep_fp8_transpose_cache ........................ False
kitchen_config_file ............................. None
kitchen_recipe_number ........................... None
kv_channels ..................................... 64
kv_lora_rank .................................... 32
lazy_mpu_init ................................... None
load ............................................ /workspace/megatron-lm/model_ckpt/gpt3_857m_pp4
load_main_params_from_ckpt ...................... None
load_model_opt_format ........................... False
local_rank ...................................... 0
log_energy ...................................... False
log_interval .................................... 200
log_loss_scale_to_tensorboard ................... True
log_memory_to_tensorboard ....................... False
log_num_zeros_in_grad ........................... False
log_params_norm ................................. False
log_progress .................................... False
log_straggler ................................... False
log_throughput .................................. False
log_timers_to_tensorboard ....................... False
log_validation_ppl_to_tensorboard ............... False
log_world_size_to_tensorboard ................... False
logging_level ................................... None
loss_scale ...................................... None
loss_scale_window ............................... 1000
lr .............................................. 6e-05
lr_decay_iters .................................. 20000
lr_decay_samples ................................ None
lr_decay_style .................................. cosine
lr_warmup_fraction .............................. 0.001
lr_warmup_init .................................. 0.0
lr_warmup_iters ................................. 0
lr_warmup_samples ............................... 0
lr_wsd_decay_iters .............................. None
lr_wsd_decay_samples ............................ None
lr_wsd_decay_style .............................. exponential
main_grads_dtype ................................ torch.float32
main_params_dtype ............................... torch.float32
make_vocab_size_divisible_by .................... 128
mamba_head_dim .................................. 64
mamba_num_groups ................................ 8
mamba_num_heads ................................. None
mamba_state_dim ................................. 128
manual_gc ....................................... False
manual_gc_eval .................................. True
manual_gc_interval .............................. 0
mask_factor ..................................... 1.0
mask_prob ....................................... 0.15
mask_type ....................................... random
masked_softmax_fusion ........................... True
max_position_embeddings ......................... 2048
max_tokens_to_oom ............................... 12000
memory_snapshot_path ............................ snapshot.pickle
merge_file ...................................... /workspace/megatron-lm/data/tokenizer/gpt2-merges.txt
micro_batch_size ................................ 2
microbatch_group_size_per_vp_stage .............. None
mid_level_dataset_surplus ....................... 0.005
min_loss_scale .................................. 1.0
min_lr .......................................... 6e-06
mlp_chunks_for_prefill .......................... 1
mmap_bin_files .................................. True
mock_data ....................................... False
moe_apply_probs_on_input ........................ False
moe_aux_loss_coeff .............................. 0.0
moe_deepep_num_sms .............................. 20
moe_enable_deepep ............................... False
moe_expert_capacity_factor ...................... None
moe_extended_tp ................................. False
moe_ffn_hidden_size ............................. None
moe_grouped_gemm ................................ False
moe_input_jitter_eps ............................ None
moe_layer_freq .................................. 1
moe_layer_recompute ............................. False
moe_pad_expert_input_to_capacity ................ False
moe_per_layer_logging ........................... False
moe_permute_fusion .............................. False
moe_router_bias_update_rate ..................... 0.001
moe_router_dtype ................................ None
moe_router_enable_expert_bias ................... False
moe_router_force_load_balancing ................. False
moe_router_fusion ............................... False
moe_router_group_topk ........................... None
moe_router_load_balancing_type .................. aux_loss
moe_router_num_groups ........................... None
moe_router_padding_for_fp8 ...................... False
moe_router_pre_softmax .......................... False
moe_router_score_function ....................... softmax
moe_router_topk ................................. 2
moe_router_topk_scaling_factor .................. None
moe_shared_expert_intermediate_size ............. None
moe_shared_expert_overlap ....................... False
moe_token_dispatcher_type ....................... allgather
moe_token_drop_policy ........................... probs
moe_upcycling_granularity ....................... 1
moe_use_legacy_grouped_gemm ..................... False
moe_use_upcycling ............................... False
moe_z_loss_coeff ................................ None
mrope_section ................................... None
mscale .......................................... 1.0
mscale_all_dim .................................. 0.0
mtp_loss_scaling_factor ......................... 0.1
mtp_num_layers .................................. None
multi_latent_attention .......................... False
multiple_validation_sets ........................ False
nccl_all_reduce_for_prefill ..................... False
nccl_communicator_config_path ................... None
nccl_ub ......................................... False
no_load_optim ................................... None
no_load_rng ..................................... None
no_persist_layer_norm ........................... False
no_rope_freq .................................... None
no_save_optim ................................... None
no_save_rng ..................................... None
non_persistent_ckpt_type ........................ None
non_persistent_global_ckpt_dir .................. None
non_persistent_local_ckpt_algo .................. fully_parallel
non_persistent_local_ckpt_dir ................... None
non_persistent_save_interval .................... None
norm_epsilon .................................... 1e-05
normalization ................................... LayerNorm
num_attention_heads ............................. 16
num_channels .................................... 3
num_classes ..................................... 1000
num_dataset_builder_threads ..................... 1
num_distributed_optimizer_instances ............. 1
num_experts ..................................... None
num_layers ...................................... 24
num_layers_at_end_in_bf16 ....................... 1
num_layers_at_start_in_bf16 ..................... 1
num_layers_per_virtual_pipeline_stage ........... None
num_query_groups ................................ 1
num_virtual_stages_per_pipeline_rank ............ None
num_workers ..................................... 2
object_storage_cache_path ....................... None
one_logger_async ................................ False
one_logger_project .............................. megatron-lm
one_logger_run_name ............................. None
onnx_safe ....................................... None
openai_gelu ..................................... False
optimizer ....................................... adam
optimizer_cpu_offload ........................... False
optimizer_offload_fraction ...................... 1.0
output_bert_embeddings .......................... False
overlap_cpu_optimizer_d2h_h2d ................... False
overlap_grad_reduce ............................. False
overlap_moe_expert_parallel_comm ................ False
overlap_p2p_comm ................................ False
overlap_p2p_comm_warmup_flush ................... False
overlap_param_gather ............................ False
overlap_param_gather_with_optimizer_step ........ False
override_opt_param_scheduler .................... False
padded_vocab_size ............................... None
params_dtype .................................... torch.float16
patch_dim ....................................... 16
per_split_data_args_path ........................ None
perform_initialization .......................... True
pin_cpu_grads ................................... True
pin_cpu_params .................................. True
pipeline_model_parallel_comm_backend ............ None
pipeline_model_parallel_layout .................. None
pipeline_model_parallel_size .................... 4
position_embedding_type ......................... learned_absolute
pretrained_checkpoint ........................... None
profile ......................................... True
profile_ranks ................................... [0]
profile_step_end ................................ 112
profile_step_start .............................. 110
q_lora_rank ..................................... None
qk_head_dim ..................................... 128
qk_l2_norm ...................................... False
qk_layernorm .................................... False
qk_pos_emb_head_dim ............................. 64
query_in_block_prob ............................. 0.1
rampup_batch_size ............................... None
rank ............................................ 0
recompute_granularity ........................... None
recompute_method ................................ None
recompute_modules ............................... None
recompute_num_layers ............................ None
record_memory_history ........................... False
relative_attention_max_distance ................. 128
relative_attention_num_buckets .................. 32
replication ..................................... False
replication_factor .............................. 2
replication_jump ................................ None
rerun_mode ...................................... validate_results
reset_attention_mask ............................ False
reset_position_ids .............................. False
result_rejected_tracker_filename ................ None
retriever_report_topk_accuracies ................ []
retriever_score_scaling ......................... False
retriever_seq_length ............................ 256
retro_add_retriever ............................. False
retro_attention_gate ............................ 1
retro_cyclic_train_iters ........................ None
retro_encoder_attention_dropout ................. 0.1
retro_encoder_hidden_dropout .................... 0.1
retro_encoder_layers ............................ 2
retro_num_neighbors ............................. 2
retro_num_retrieved_chunks ...................... 2
retro_project_dir ............................... None
retro_verify_neighbor_count ..................... True
reuse_grad_buf_for_mxfp8_param_ag ............... False
rope_scaling_factor ............................. 8.0
rope_type ....................................... None
rotary_base ..................................... 10000
rotary_interleaved .............................. False
rotary_percent .................................. 1.0
rotary_scaling_factor ........................... 1.0
rotary_seq_len_interpolation_factor ............. None
run_workload_inspector_server ................... False
sample_rate ..................................... 1.0
save ............................................ /workspace/megatron-lm/model_ckpt/gpt3_857m_pp4
save_interval ................................... 10000
save_retain_interval ............................ None
scatter_gather_tensors_in_pipeline .............. True
seed ............................................ 1234
seq_length ...................................... 2048
sequence_parallel ............................... False
sft ............................................. False
sft_tokenizer_prompt_format ..................... nemotron-h-aligned
sgd_momentum .................................... 0.9
sharp_enabled_group ............................. None
short_seq_prob .................................. 0.1
skip_train ...................................... False
skipped_train_samples ........................... 0
spec ............................................ None
split ........................................... 949,50,1
squared_relu .................................... False
start_weight_decay .............................. 0.1
straggler_ctrlr_port ............................ 65535
straggler_minmax_count .......................... 1
strict_fsdp_dtensor_load ........................ True
suggested_communication_unit_size ............... None
swiglu .......................................... False
swin_backbone_type .............................. tiny
symmetric_ar_type ............................... None
te_rng_tracker .................................. False
tensor_model_parallel_size ...................... 1
tensorboard_dir ................................. /workspace/megatron-lm/tb_logs/gpt3_857m_profiler_pp4
tensorboard_log_interval ........................ 1
tensorboard_queue_size .......................... 1000
test_data_path .................................. None
test_mode ....................................... False
tiktoken_num_special_tokens ..................... 1000
tiktoken_pattern ................................ None
tiktoken_special_tokens ......................... None
timing_log_level ................................ 0
timing_log_option ............................... minmax
titles_data_path ................................ None
tokenizer_model ................................. None
tokenizer_type .................................. GPT2BPETokenizer
torch_fsdp2_reshard_after_forward ............... True
tp_comm_bootstrap_backend ....................... nccl
tp_comm_bulk_dgrad .............................. True
tp_comm_bulk_wgrad .............................. True
tp_comm_overlap ................................. False
tp_comm_overlap_ag .............................. True
tp_comm_overlap_cfg ............................. None
tp_comm_overlap_rs .............................. True
tp_comm_overlap_rs_dgrad ........................ False
tp_comm_split_ag ................................ True
tp_comm_split_rs ................................ True
train_data_path ................................. None
train_iters ..................................... 20000
train_samples ................................... None
train_sync_interval ............................. None
transformer_impl ................................ transformer_engine
transformer_pipeline_model_parallel_size ........ 4
untie_embeddings_and_output_weights ............. False
use_checkpoint_args ............................. False
use_checkpoint_opt_param_scheduler .............. False
use_cpu_initialization .......................... None
use_dist_ckpt ................................... True
use_dist_ckpt_deprecated ........................ False
use_distributed_optimizer ....................... False
use_flash_attn .................................. False
use_fused_weighted_squared_relu ................. False
use_legacy_models ............................... False
use_megatron_fsdp ............................... False
use_mp_args_from_checkpoint_args ................ False
use_one_sent_docs ............................... False
use_persistent_ckpt_worker ...................... False
use_precision_aware_optimizer ................... False
use_pytorch_profiler ............................ True
use_ring_exchange_p2p ........................... False
use_rope_scaling ................................ False
use_rotary_position_embeddings .................. False
use_sharp ....................................... False
use_tokenizer_model_from_checkpoint_args ........ True
use_torch_fsdp2 ................................. False
use_torch_optimizer_for_cpu_offload ............. False
use_tp_pp_dp_mapping ............................ False
v_head_dim ...................................... 128
valid_data_path ................................. None
variable_seq_lengths ............................ False
virtual_pipeline_model_parallel_size ............ None
vision_backbone_type ............................ vit
vision_pretraining .............................. False
vision_pretraining_type ......................... classify
vocab_extra_ids ................................. 0
vocab_file ...................................... /workspace/megatron-lm/data/tokenizer/gpt2-vocab.json
vocab_size ...................................... None
wandb_exp_name ..................................
wandb_project ...................................
wandb_save_dir ..................................
weight_decay .................................... 0.1
weight_decay_incr_style ......................... constant
wgrad_deferral_limit ............................ 0
world_size ...................................... 4
yaml_cfg ........................................ None
-------------------- end of arguments ---------------------
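Two numbers in the log that follows can be cross-checked against the argument dump above: the constant microbatch count and the padded vocabulary size. A minimal sketch, with the values copied from this run's printout:

```python
# Derived quantities from the argument dump above (values from this run).

# Number of microbatches: the global batch is first split across
# data-parallel replicas, then into microbatches for the pipeline schedule.
global_batch_size = 16
micro_batch_size = 2
data_parallel_size = 1
num_microbatches = global_batch_size // (micro_batch_size * data_parallel_size)
print(num_microbatches)  # 8, matching "setting number of microbatches to constant 8"

# Padded vocab size: the raw GPT-2 vocab (50257) is rounded up so that it is
# divisible by make_vocab_size_divisible_by * tensor_model_parallel_size.
vocab_size = 50257
divisor = 128 * 1  # make_vocab_size_divisible_by * tensor_model_parallel_size
padded_vocab = ((vocab_size + divisor - 1) // divisor) * divisor
print(padded_vocab)  # 50304, matching "padded vocab ... (new size: 50304)"
```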
INFO:megatron.core.num_microbatches_calculator:setting number of microbatches to constant 8
> building GPT2BPETokenizer tokenizer ...
> padded vocab (size: 50257) with 47 dummy tokens (new size: 50304)
WARNING:megatron.core.rerun_state_machine:RerunStateMachine initialized in mode RerunMode.VALIDATE_RESULTS
> initializing torch distributed ...
> initialized tensor model parallel with size 1
> initialized pipeline model parallel with size 4
> setting random seeds to 1234 ...
> compiling dataset index builder ...
make: Entering directory '/workspace/megatron-lm/megatron/core/datasets'
> setting tensorboard ...
WARNING: one_logger package is required to enable e2e metrics tracking. please go to https://confluence.nvidia.com/display/MLWFO/Package+Repositories for details to install it
[rank1]:[W108 08:57:05.977603534 ProcessGroupNCCL.cpp:4751] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
[rank2]:[W108 08:57:05.007203225 ProcessGroupNCCL.cpp:4751] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
make: Nothing to be done for 'default'.
make: Leaving directory '/workspace/megatron-lm/megatron/core/datasets'
>>> done with dataset index builder. Compilation time: 0.255 seconds
> compiling and loading fused kernels ...
[rank3]:[W108 08:57:05.064460267 ProcessGroupNCCL.cpp:4751] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
[rank0]:[W108 08:57:05.131019652 ProcessGroupNCCL.cpp:4751] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
>>> done with compiling and loading fused kernels. Compilation time: 0.321 seconds
time to initialize megatron (seconds): 2.282
[after megatron is initialized] datetime: 2026-01-08 08:57:07
building GPT model ...
> number of parameters on (tensor, pipeline) model parallel rank (0, 1): 75577344
> number of parameters on (tensor, pipeline) model parallel rank (0, 2): 75577344
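The per-stage parameter count above can be reproduced from the model dimensions in the argument dump. With num_layers=24 and pipeline_model_parallel_size=4, each stage holds 6 transformer layers; a back-of-the-envelope sketch, with the layer structure inferred from the bucket listing below (qkv/proj/fc1/fc2 linears plus the two fused layernorms):

```python
# Reproduce the 75,577,344 parameters reported for each middle pipeline
# stage (hidden_size=1024, ffn_hidden_size=4096, num_layers=24, pp=4).
hidden = 1024
ffn = 4096
layers_per_stage = 24 // 4  # 6 layers on each of the 4 stages

per_layer = (
    2 * hidden             # layernorm fused into linear_qkv (weight + bias)
    + 3 * hidden * hidden  # linear_qkv weight (q, k, v, each hidden x hidden)
    + 3 * hidden           # linear_qkv bias
    + hidden * hidden      # linear_proj weight
    + hidden               # linear_proj bias
    + 2 * hidden           # layernorm fused into linear_fc1 (weight + bias)
    + hidden * ffn         # linear_fc1 weight
    + ffn                  # linear_fc1 bias
    + ffn * hidden         # linear_fc2 weight
    + hidden               # linear_fc2 bias
)
print(per_layer * layers_per_stage)  # 75577344, matching the log
```

Only the middle stages (pipeline ranks 1 and 2) show this count: the first and last stages additionally hold embedding/output-layer parameters, so their totals would differ.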
INFO:megatron.core.distributed.param_and_grad_buffer:Number of buckets for gradient all-reduce / reduce-scatter: 1
Params for bucket 1 (75577344 elements, 75577344 padded size):
module.decoder.layers.5.mlp.linear_fc1.bias
module.decoder.layers.2.self_attention.linear_qkv.bias
module.decoder.layers.1.self_attention.linear_qkv.bias
module.decoder.layers.0.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.0.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.5.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.5.self_attention.linear_proj.bias
module.decoder.layers.5.self_attention.linear_proj.weight
module.decoder.layers.2.mlp.linear_fc1.bias
module.decoder.layers.1.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.0.self_attention.linear_proj.bias
module.decoder.layers.5.mlp.linear_fc1.weight
module.decoder.layers.5.self_attention.linear_qkv.bias
module.decoder.layers.4.mlp.linear_fc1.bias
module.decoder.layers.4.self_attention.linear_proj.weight
module.decoder.layers.3.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.3.self_attention.linear_qkv.bias
module.decoder.layers.3.self_attention.linear_qkv.weight
module.decoder.layers.2.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.2.self_attention.linear_proj.bias
module.decoder.layers.1.mlp.linear_fc2.bias
module.decoder.layers.4.mlp.linear_fc2.weight
module.decoder.layers.4.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.4.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.3.mlp.linear_fc2.bias
module.decoder.layers.3.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.2.self_attention.linear_proj.weight
module.decoder.layers.1.mlp.linear_fc2.weight
module.decoder.layers.1.mlp.linear_fc1.bias
module.decoder.layers.1.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.0.mlp.linear_fc1.weight
module.decoder.layers.0.self_attention.linear_proj.weight
module.decoder.layers.5.self_attention.linear_qkv.weight
module.decoder.layers.5.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.4.mlp.linear_fc2.bias
module.decoder.layers.3.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.3.self_attention.linear_proj.bias
module.decoder.layers.1.self_attention.linear_proj.bias
module.decoder.layers.0.mlp.linear_fc2.bias
module.decoder.layers.0.self_attention.linear_qkv.weight
module.decoder.layers.1.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.5.mlp.linear_fc2.bias
module.decoder.layers.4.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.4.self_attention.linear_proj.bias
module.decoder.layers.3.mlp.linear_fc2.weight
module.decoder.layers.3.mlp.linear_fc1.bias
module.decoder.layers.2.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.1.mlp.linear_fc1.weight
module.decoder.layers.1.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.0.mlp.linear_fc2.weight
module.decoder.layers.0.self_attention.linear_qkv.bias
module.decoder.layers.0.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.5.mlp.linear_fc2.weight
module.decoder.layers.5.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.4.self_attention.linear_qkv.bias
module.decoder.layers.4.self_attention.linear_qkv.weight
module.decoder.layers.3.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.3.self_attention.linear_proj.weight
module.decoder.layers.2.self_attention.linear_qkv.weight
module.decoder.layers.2.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.1.self_attention.linear_qkv.weight
module.decoder.layers.1.self_attention.linear_proj.weight
module.decoder.layers.5.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.4.mlp.linear_fc1.weight
module.decoder.layers.4.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.3.mlp.linear_fc1.weight
module.decoder.layers.2.mlp.linear_fc2.bias
module.decoder.layers.2.mlp.linear_fc2.weight
module.decoder.layers.2.mlp.linear_fc1.weight
module.decoder.layers.2.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.0.mlp.linear_fc1.bias
module.decoder.layers.0.mlp.linear_fc1.layer_norm_weight
INFO:megatron.core.distributed.param_and_grad_buffer:Number of buckets for gradient all-reduce / reduce-scatter: 1
Params for bucket 1 (75577344 elements, 75577344 padded size):
module.decoder.layers.5.mlp.linear_fc1.bias
module.decoder.layers.5.self_attention.linear_qkv.weight
module.decoder.layers.4.mlp.linear_fc1.weight
module.decoder.layers.3.mlp.linear_fc1.bias
module.decoder.layers.3.mlp.linear_fc1.weight
module.decoder.layers.3.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.2.self_attention.linear_proj.weight
module.decoder.layers.1.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.1.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.0.mlp.linear_fc2.weight
module.decoder.layers.4.self_attention.linear_qkv.weight
module.decoder.layers.4.self_attention.linear_proj.bias
module.decoder.layers.3.mlp.linear_fc2.bias
module.decoder.layers.3.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.2.self_attention.linear_qkv.weight
module.decoder.layers.2.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.2.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.1.self_attention.linear_qkv.bias
module.decoder.layers.1.self_attention.linear_qkv.weight
module.decoder.layers.0.self_attention.linear_qkv.bias
module.decoder.layers.5.mlp.linear_fc2.weight
module.decoder.layers.5.self_attention.linear_qkv.bias
module.decoder.layers.5.self_attention.linear_proj.weight
module.decoder.layers.4.mlp.linear_fc2.weight
module.decoder.layers.3.self_attention.linear_qkv.bias
module.decoder.layers.2.mlp.linear_fc2.bias
module.decoder.layers.2.mlp.linear_fc1.bias
module.decoder.layers.1.self_attention.linear_proj.weight
module.decoder.layers.0.mlp.linear_fc2.bias
module.decoder.layers.0.mlp.linear_fc1.bias
module.decoder.layers.0.self_attention.linear_qkv.weight
module.decoder.layers.5.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.5.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.3.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.3.self_attention.linear_proj.bias
module.decoder.layers.2.self_attention.linear_qkv.bias
module.decoder.layers.1.mlp.linear_fc2.bias
module.decoder.layers.1.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.0.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.5.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.5.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.5.self_attention.linear_proj.bias
module.decoder.layers.4.mlp.linear_fc1.bias
module.decoder.layers.4.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.4.self_attention.linear_proj.weight
module.decoder.layers.3.self_attention.linear_qkv.weight
module.decoder.layers.3.self_attention.linear_proj.weight
module.decoder.layers.2.mlp.linear_fc1.weight
module.decoder.layers.2.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.1.mlp.linear_fc1.weight
module.decoder.layers.5.mlp.linear_fc2.bias
module.decoder.layers.4.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.3.mlp.linear_fc2.weight
module.decoder.layers.3.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.1.self_attention.linear_proj.bias
module.decoder.layers.0.self_attention.linear_proj.bias
module.decoder.layers.0.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.5.mlp.linear_fc1.weight
module.decoder.layers.4.mlp.linear_fc2.bias
module.decoder.layers.4.self_attention.linear_qkv.bias
module.decoder.layers.4.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.2.mlp.linear_fc2.weight
module.decoder.layers.1.mlp.linear_fc1.bias
module.decoder.layers.1.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.0.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.4.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.2.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.2.self_attention.linear_proj.bias
module.decoder.layers.1.mlp.linear_fc2.weight
module.decoder.layers.0.mlp.linear_fc1.weight
module.decoder.layers.0.self_attention.linear_proj.weight
module.decoder.layers.0.self_attention.linear_qkv.layer_norm_bias
> number of parameters on (tensor, pipeline) model parallel rank (0, 0): 129185792
> number of parameters on (tensor, pipeline) model parallel rank (0, 3): 127090688
INFO:megatron.core.distributed.param_and_grad_buffer:Number of buckets for gradient all-reduce / reduce-scatter: 1
Params for bucket 1 (127090688 elements, 127090688 padded size):
module.decoder.layers.5.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.5.self_attention.linear_proj.weight
module.decoder.layers.4.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.4.self_attention.linear_proj.bias
module.decoder.layers.3.self_attention.linear_qkv.weight
module.decoder.layers.3.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.3.self_attention.linear_proj.bias
module.decoder.layers.1.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.2.mlp.linear_fc1.weight
module.decoder.layers.5.self_attention.linear_qkv.weight
module.decoder.layers.4.mlp.linear_fc2.bias
module.decoder.layers.3.mlp.linear_fc2.bias
module.decoder.layers.3.mlp.linear_fc1.weight
module.decoder.layers.2.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.1.mlp.linear_fc1.weight
module.decoder.layers.1.self_attention.linear_qkv.weight
module.decoder.layers.0.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.0.self_attention.linear_proj.weight
module.decoder.layers.5.mlp.linear_fc2.weight
module.decoder.final_layernorm.bias
module.decoder.layers.4.mlp.linear_fc1.bias
module.decoder.layers.4.mlp.linear_fc1.weight
module.decoder.layers.3.self_attention.linear_proj.weight
module.decoder.layers.2.mlp.linear_fc2.bias
module.decoder.layers.2.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.2.self_attention.linear_qkv.weight
module.decoder.layers.2.self_attention.linear_proj.weight
module.decoder.layers.1.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.1.mlp.linear_fc2.bias
module.output_layer.weight
module.decoder.layers.5.mlp.linear_fc1.weight
module.decoder.layers.3.self_attention.linear_qkv.bias
module.decoder.layers.2.mlp.linear_fc2.weight
module.decoder.layers.2.mlp.linear_fc1.bias
module.decoder.layers.1.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.1.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.0.mlp.linear_fc1.weight
module.decoder.layers.4.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.3.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.3.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.3.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.2.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.1.mlp.linear_fc1.bias
module.decoder.layers.1.self_attention.linear_proj.weight
module.decoder.layers.0.mlp.linear_fc1.bias
module.decoder.layers.0.self_attention.linear_qkv.bias
module.decoder.final_layernorm.weight
module.decoder.layers.5.mlp.linear_fc2.bias
module.decoder.layers.5.mlp.linear_fc1.bias
module.decoder.layers.5.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.5.self_attention.linear_proj.bias
module.decoder.layers.4.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.4.self_attention.linear_qkv.bias
module.decoder.layers.4.self_attention.linear_proj.weight
module.decoder.layers.3.mlp.linear_fc2.weight
module.decoder.layers.3.mlp.linear_fc1.bias
module.decoder.layers.2.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.1.self_attention.linear_proj.bias
module.decoder.layers.5.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.5.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.4.self_attention.linear_qkv.weight
module.decoder.layers.4.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.2.self_attention.linear_proj.bias
module.decoder.layers.1.self_attention.linear_qkv.bias
module.decoder.layers.0.mlp.linear_fc2.bias
module.decoder.layers.0.mlp.linear_fc2.weight
module.decoder.layers.0.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.0.self_attention.linear_proj.bias
module.decoder.layers.0.self_attention.linear_qkv.weight
module.decoder.layers.5.self_attention.linear_qkv.bias
module.decoder.layers.4.mlp.linear_fc2.weight
module.decoder.layers.2.self_attention.linear_qkv.bias
module.decoder.layers.1.mlp.linear_fc2.weight
module.decoder.layers.0.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.0.mlp.linear_fc1.layer_norm_weight
INFO:megatron.core.distributed.distributed_data_parallel:Setting up DistributedDataParallel with config DistributedDataParallelConfig(grad_reduce_in_fp32=False, overlap_grad_reduce=False, overlap_param_gather=False, align_param_gather=False, use_distributed_optimizer=False, num_distributed_optimizer_instances=1, check_for_nan_in_grad=False, check_for_large_grads=False, bucket_size=None, pad_buckets_for_high_nccl_busbw=False, average_in_collective=False, fp8_param_gather=False, reuse_grad_buf_for_mxfp8_param_ag=False, use_megatron_fsdp=False, use_custom_fsdp=False, data_parallel_sharding_strategy='no_shard', gradient_reduce_div_fusion=True, suggested_communication_unit_size=None, preserve_fp32_weights=True, keep_fp8_transpose_cache=False, nccl_ub=False, fsdp_double_buffer=False, outer_dp_sharding_strategy='no_shard', disable_symmetric_registration=False, delay_wgrad_compute=False)
INFO:megatron.core.distributed.param_and_grad_buffer:Number of buckets for gradient all-reduce / reduce-scatter: 1
Params for bucket 1 (129185792 elements, 129185792 padded size):
module.decoder.layers.5.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.5.self_attention.linear_proj.bias
module.decoder.layers.4.self_attention.linear_qkv.bias
module.decoder.layers.4.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.4.self_attention.linear_proj.weight
module.decoder.layers.3.mlp.linear_fc1.weight
module.decoder.layers.2.mlp.linear_fc1.weight
module.decoder.layers.2.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.1.self_attention.linear_qkv.weight
module.decoder.layers.1.self_attention.linear_proj.weight
module.decoder.layers.5.mlp.linear_fc1.weight
module.decoder.layers.5.self_attention.linear_qkv.bias
module.decoder.layers.3.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.2.self_attention.linear_qkv.weight
module.decoder.layers.1.mlp.linear_fc2.bias
module.decoder.layers.0.mlp.linear_fc1.bias
module.decoder.layers.0.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.0.self_attention.linear_qkv.bias
module.decoder.layers.0.self_attention.linear_proj.weight
module.decoder.layers.5.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.5.self_attention.linear_proj.weight
module.decoder.layers.4.mlp.linear_fc1.bias
module.decoder.layers.4.self_attention.linear_proj.bias
module.decoder.layers.3.mlp.linear_fc2.bias
module.decoder.layers.3.self_attention.linear_proj.weight
module.decoder.layers.2.mlp.linear_fc2.weight
module.decoder.layers.1.self_attention.linear_qkv.bias
module.decoder.layers.0.self_attention.linear_qkv.weight
module.decoder.layers.0.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.0.self_attention.linear_proj.bias
module.embedding.position_embeddings.weight
module.decoder.layers.5.mlp.linear_fc2.bias
module.decoder.layers.5.mlp.linear_fc1.bias
module.decoder.layers.5.mlp.linear_fc2.weight
module.decoder.layers.4.mlp.linear_fc1.weight
module.decoder.layers.4.self_attention.linear_qkv.weight
module.decoder.layers.3.mlp.linear_fc2.weight
module.decoder.layers.2.mlp.linear_fc1.bias
module.decoder.layers.2.self_attention.linear_proj.bias
module.decoder.layers.0.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.4.mlp.linear_fc2.bias
module.decoder.layers.4.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.4.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.3.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.2.mlp.linear_fc2.bias
module.decoder.layers.2.self_attention.linear_proj.weight
module.decoder.layers.1.mlp.linear_fc1.bias
module.decoder.layers.1.mlp.linear_fc1.weight
module.decoder.layers.0.mlp.linear_fc2.bias
module.embedding.word_embeddings.weight
module.decoder.layers.5.self_attention.linear_qkv.weight
module.decoder.layers.5.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.3.mlp.linear_fc1.bias
module.decoder.layers.3.self_attention.linear_qkv.bias
module.decoder.layers.2.self_attention.linear_qkv.bias
module.decoder.layers.1.mlp.linear_fc2.weight
module.decoder.layers.5.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.4.mlp.linear_fc2.weight
module.decoder.layers.3.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.3.self_attention.linear_qkv.weight
module.decoder.layers.2.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.1.mlp.linear_fc1.layer_norm_weight
module.decoder.layers.1.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.1.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.1.self_attention.linear_proj.bias
module.decoder.layers.4.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.3.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.3.self_attention.linear_proj.bias
module.decoder.layers.2.self_attention.linear_qkv.layer_norm_bias
module.decoder.layers.2.self_attention.linear_qkv.layer_norm_weight
module.decoder.layers.1.mlp.linear_fc1.layer_norm_bias
module.decoder.layers.0.mlp.linear_fc2.weight
module.decoder.layers.0.mlp.linear_fc1.weight
module.decoder.layers.0.mlp.linear_fc1.layer_norm_bias
INFO:megatron.core.optimizer:Setting up optimizer with config OptimizerConfig(optimizer='adam', lr=6e-05, min_lr=6e-06, decoupled_lr=None, decoupled_min_lr=None, weight_decay=0.1, fp8_recipe='delayed', fp16=True, bf16=False, reuse_grad_buf_for_mxfp8_param_ag=False, params_dtype=torch.float16, use_precision_aware_optimizer=False, store_param_remainders=True, main_grads_dtype=torch.float32, main_params_dtype=torch.float32, exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.float32, loss_scale=None, initial_loss_scale=4294967296, min_loss_scale=1.0, loss_scale_window=1000, hysteresis=2, adam_beta1=0.9, adam_beta2=0.95, adam_eps=1e-08, sgd_momentum=0.9, use_distributed_optimizer=False, overlap_param_gather=False, overlap_param_gather_with_optimizer_step=False, use_megatron_fsdp=False, optimizer_cpu_offload=False, optimizer_offload_fraction=1.0, use_torch_optimizer_for_cpu_offload=False, overlap_cpu_optimizer_d2h_h2d=False, pin_cpu_grads=True, pin_cpu_params=True, clip_grad=1.0, log_num_zeros_in_grad=False, barrier_with_L1_time=True, timers=<megatron.core.timers.Timers object at 0x7ff3d6145b50>, config_logger_dir='')
INFO:megatron.core.optimizer_param_scheduler:> learning rate decay style: cosine
WARNING: could not find the metadata file /workspace/megatron-lm/model_ckpt/gpt3_857m_pp4/latest_checkpointed_iteration.txt
will not load any checkpoints and will start from random
(min, max) time across ranks (ms):
load-checkpoint ................................: (0.41, 0.43)
[after model, optimizer, and learning rate scheduler are built] datetime: 2026-01-08 08:57:07
> building train, validation, and test datasets ...
> datasets target sizes (minimum size):
train: 320000
validation: 3360
test: 160
INFO:megatron.core.datasets.blended_megatron_dataset_config:Let split_matrix = [(0, 0.949), (0.949, 0.999), (0.999, 1.0)]
> building train, validation, and test datasets for GPT ...
INFO:megatron.core.datasets.blended_megatron_dataset_builder:Building GPTDataset splits with sizes=(320000, 3360, 160) and config=GPTDatasetConfig(random_seed=1234, sequence_length=2048, blend=(['/workspace/megatron-lm/data/TinyStoriesV2-GPT4-train_text_document'], None), blend_per_split=None, multiple_validation_sets=False, full_validation=False, split='949,50,1', split_matrix=[(0, 0.949), (0.949, 0.999), (0.999, 1.0)], num_dataset_builder_threads=1, path_to_cache=None, mmap_bin_files=True, mock=False, tokenizer=<megatron.training.tokenizer.tokenizer._GPT2BPETokenizer object at 0x7ff3d6425760>, mid_level_dataset_surplus=0.005, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False, create_attention_mask=True, drop_last_partial_validation_sequence=True, add_extra_token_to_sequence=True, object_storage_cache_path=None)
INFO:megatron.core.datasets.indexed_dataset:Load the _IndexReader from /workspace/megatron-lm/data/TinyStoriesV2-GPT4-train_text_document.idx
INFO:megatron.core.datasets.indexed_dataset: Extract the sequence lengths
INFO:megatron.core.datasets.indexed_dataset: Extract the sequence pointers
INFO:megatron.core.datasets.indexed_dataset: Extract the document indices
INFO:megatron.core.datasets.indexed_dataset:> total number of sequences: 14548094
INFO:megatron.core.datasets.indexed_dataset:> total number of documents: 14548094
INFO:megatron.core.datasets.gpt_dataset:Load the GPTDataset train indices
INFO:megatron.core.datasets.gpt_dataset: Load the document index from c376e20e5de541283d4ccc974c960cb8-GPTDataset-train-document_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the sample index from c376e20e5de541283d4ccc974c960cb8-GPTDataset-train-sample_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the shuffle index from c376e20e5de541283d4ccc974c960cb8-GPTDataset-train-shuffle_index.npy
INFO:megatron.core.datasets.gpt_dataset:> total number of samples: 521301
INFO:megatron.core.datasets.gpt_dataset:Load the GPTDataset valid indices
INFO:megatron.core.datasets.gpt_dataset: Load the document index from f975a8258f34477c465b869135b1a202-GPTDataset-valid-document_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the sample index from f975a8258f34477c465b869135b1a202-GPTDataset-valid-sample_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the shuffle index from f975a8258f34477c465b869135b1a202-GPTDataset-valid-shuffle_index.npy
INFO:megatron.core.datasets.gpt_dataset:> total number of samples: 13728
INFO:megatron.core.datasets.gpt_dataset:Load the GPTDataset test indices
INFO:megatron.core.datasets.gpt_dataset: Load the document index from b12f62104fc19a6d3c6c6402fedd7e04-GPTDataset-test-document_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the sample index from b12f62104fc19a6d3c6c6402fedd7e04-GPTDataset-test-sample_index.npy
INFO:megatron.core.datasets.gpt_dataset: Load the shuffle index from b12f62104fc19a6d3c6c6402fedd7e04-GPTDataset-test-shuffle_index.npy
INFO:megatron.core.datasets.gpt_dataset:> total number of samples: 273
> finished creating GPT datasets ...
[after dataloaders are built] datetime: 2026-01-08 08:57:08
done with setup ...
(min, max) time across ranks (ms):
model-and-optimizer-setup ......................: (166.42, 186.26)
train/valid/test-data-iterators-setup ..........: (67.88, 194.17)
training ...
Overwriting rerun_state_machine.current_iteration from -1 to 0...
[before the start of training step] datetime: 2026-01-08 08:57:08
Number of parameters in transformer block in billions: 0.30
Number of parameters in embedding layers in billions: 0.05
Total number of parameters in billions: 0.35
Number of parameters in most loaded shard in billions: 0.1270
Number of parameters in other shards in billions: 0.0755
Theoretical memory footprints: weight and optimizer=2180.68 MB
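The per-stage parameter counts in the log can be cross-checked by hand. The sketch below uses hyperparameters inferred from the bucket sizes (hidden size 1024, FFN size 4096, 6 layers per pipeline stage, sequence length 2048, padded vocabulary 50304) — these are assumptions, not values printed in this excerpt:

```python
# Cross-check of the per-stage parameter counts reported in the log.
# Hyperparameters are inferred assumptions, not values from the log itself.
H, FFN, LAYERS_PER_STAGE, SEQ, VOCAB = 1024, 4096, 6, 2048, 50304

def layer_params(h: int, ffn: int) -> int:
    """Parameters of one transformer layer, biases and LayerNorms included."""
    attn = (3 * h * h + 3 * h) + (h * h + h)   # fused QKV + output projection
    mlp = (h * ffn + ffn) + (ffn * h + h)      # linear_fc1 + linear_fc2
    norms = 4 * h                              # two LayerNorms, weight + bias
    return attn + mlp + norms

block = LAYERS_PER_STAGE * layer_params(H, FFN)   # middle stages (ranks 1, 2)
first = block + VOCAB * H + SEQ * H               # + word/position embeddings
last = block + 2 * H + VOCAB * H                  # + final LayerNorm + output layer

print(block, first, last)  # → 75577344 129185792 127090688
```

These reproduce the logged bucket sizes exactly: 75577344 elements on the middle stages, 129185792 on pipeline rank 0 (embeddings included), and 127090688 on rank 3 (final LayerNorm and output layer included) — i.e. the "most loaded shard" (0.1270B) and "other shards" (0.0755B) figures above.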
[2026-01-08 08:59:01] iteration 200/ 20000 | consumed samples: 3200 | elapsed time per iteration (ms): 565.5 | learning rate: 5.999146E-05 | global batch size: 16 | lm loss: 5.657457E+00 | loss scale: 8192.0 | grad norm: 0.912 | number of skipped iterations: 20 | number of nan iterations: 0 |
[Rank 2] (after 200 iterations) memory (MB) | allocated: 1461.5283203125 | max allocated: 3028.9892578125 | reserved: 3156.0 | max reserved: 3156.0
[Rank 1] (after 200 iterations) memory (MB) | allocated: 1461.5283203125 | max allocated: 3894.86767578125 | reserved: 4022.0 | max reserved: 4022.0
[Rank 0] (after 200 iterations) memory (MB) | allocated: 2468.0283203125 | max allocated: 5586.99609375 | reserved: 5734.0 | max reserved: 5734.0
[Rank 3] (after 200 iterations) memory (MB) | allocated: 2460.318359375 | max allocated: 4044.673828125 | reserved: 4670.0 | max reserved: 4670.0
[2026-01-08 09:00:43] iteration 400/ 20000 | consumed samples: 6400 | elapsed time per iteration (ms): 510.8 | learning rate: 5.995675E-05 | global batch size: 16 | lm loss: 3.729681E+00 | loss scale: 8192.0 | grad norm: 0.920 | number of skipped iterations: 0 | number of nan iterations: 0 |

Profiler Trace

Note that the new Perfetto UI (https://ui.perfetto.dev/) is used here to view the profiler trace.

The trace matches the earlier code analysis. Keep in mind this is the trace for rank 0 (the first pipeline stage): the warmup phase runs 3 forward passes, the steady phase then fires 5 consecutive forward+backward pairs (1F1B), and the final cooldown phase drains the remaining 3 backward passes.
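The 3/5/3 pattern follows directly from the 1F1B schedule bookkeeping. A minimal sketch, simplified from the phase-count logic in megatron/core/pipeline_parallel/schedules.py (the helper name is mine):

```python
def one_f_one_b_phases(pp_size: int, pp_rank: int, num_microbatches: int):
    """Return (warmup forwards, steady 1F1B pairs, cooldown backwards)."""
    # Earlier stages must run ahead so the last stage can start its first
    # backward pass; the last stage needs no warmup at all.
    num_warmup = min(pp_size - pp_rank - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup
    num_cooldown = num_warmup  # drain the backwards matching the warmup forwards
    return num_warmup, num_steady, num_cooldown

# A global batch of 16 with an (assumed) micro batch size of 2 gives 8
# microbatches; rank 0 of a 4-stage pipeline then gets 3 warmup forwards,
# 5 steady forward/backward pairs, and 3 cooldown backwards, as in the trace.
print(one_f_one_b_phases(4, 0, 8))  # → (3, 5, 3)
```

The last stage (rank 3) gets `(0, 8, 0)` — it alternates forward and backward from the very first microbatch.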

Looking more closely at the warmup phase: between two forward passes, the P2P communication issued by send_forward goes through _batched_p2p_ops and is immediately followed by a CUDA sync.

In the steady phase, send_forward_recv_backward likewise goes through _batched_p2p_ops.

In the final cooldown phase, recv_backward also goes through _batched_p2p_ops.
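This is why each communication call shows up as a single batched kernel in the trace: _batched_p2p_ops collects all pending sends and receives of one call and issues them together via torch.distributed.batch_isend_irecv. A torch-free sketch of that op-list construction (tensor payloads replaced by tags; the function name is mine, not Megatron's):

```python
def build_p2p_ops(send_prev, recv_prev, send_next, recv_next, rank, pp_size):
    """Collect the P2P ops of one pipeline communication call so they can be
    issued as a single batch, mirroring the grouping that _batched_p2p_ops
    performs with torch.distributed.batch_isend_irecv."""
    prev_rank = (rank - 1) % pp_size
    next_rank = (rank + 1) % pp_size
    ops = []
    if send_prev: ops.append(("isend", prev_rank))
    if recv_prev: ops.append(("irecv", prev_rank))
    if send_next: ops.append(("isend", next_rank))
    if recv_next: ops.append(("irecv", next_rank))
    return ops  # issued as one batch, then waited on (the CUDA sync above)

# send_forward_recv_backward on rank 1 of 4: one send and one recv, both
# paired with the next stage (rank 2).
print(build_p2p_ops(False, False, True, True, rank=1, pp_size=4))
```

Batching the ops lets NCCL schedule the send and receive together instead of serializing two separate point-to-point calls, which is what the fused "batched_p2p" block in the Perfetto timeline reflects.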


【Megatron-LM源码分析(六)】-流水线并行-1F1B
http://example.com/2026/01/09/megatron-lm-pp/
Author: 滑滑蛋
Published: January 9, 2026