【Megatron-LM Source Code Analysis (4)】- DDP Data Parallelism

There are several flavors of data parallelism. The simplest is DDP: every data-parallel (DP) replica holds a complete copy of the model, gradients are synchronized across DP ranks at the end of each batch, and the optimizer update is then applied uniformly. More sophisticated data parallelism layers ZeRO-style techniques on top, sharding the model parameters, gradients, and optimizer states.
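Conceptually, the per-batch gradient synchronization boils down to an all-reduce over the DP group. The following is a minimal sketch using raw torch.distributed calls; Megatron's actual DistributedDataParallel wrapper buckets gradients and overlaps communication with the backward pass, so this is illustrative only:

import torch.distributed as dist

def allreduce_gradients(model, dp_group):
    # Sum each parameter's gradient across all data-parallel replicas, then average.
    dp_world_size = dist.get_world_size(group=dp_group)
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=dp_group)
            param.grad.div_(dp_world_size)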

Here we focus on the simplest case, DDP, and look at three things: how the DP process groups are partitioned, how each DP rank reads different data, and how gradients are synchronized during training.

DP Parallel Groups

Looking at the initialize_model_parallel function in megatron/core/parallel_state.py, the model size is computed as model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size. The data-parallel size is not specified directly; it is derived as data_parallel_size: int = world_size // model_size.
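As a quick sanity check with hypothetical values (16 GPUs, TP=2, PP=2, CP=1):

# Hypothetical configuration for illustration only
world_size = 16
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 2
context_parallel_size = 1

model_size = (
    tensor_model_parallel_size
    * pipeline_model_parallel_size
    * context_parallel_size
)  # 4
data_parallel_size = world_size // model_size  # 16 // 4 = 4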

The code that builds the DP parallel groups is shown below. Via decoder_rank_generator, all ranks that share the same coordinates in every other dimension and differ only in their DP coordinate are collected into one DP communication group:

for ranks in decoder_rank_generator.get_ranks('dp'):
    group = create_group(
        ranks,
        timeout=timeout,
        pg_options=get_nccl_options("dp", nccl_comm_cfgs),
        group_desc="DATA_PARALLEL_GROUP",
    )
    if create_gloo_process_groups:
        group_gloo = create_group(
            ranks, timeout=timeout, backend="gloo", group_desc="DATA_PARALLEL_GROUP_GLOO"
        )
    else:
        group_gloo = None
    if rank in ranks:
        _DATA_PARALLEL_GROUP = group
        _DATA_PARALLEL_GROUP_GLOO = group_gloo
        _DATA_PARALLEL_GLOBAL_RANKS = ranks

Data Sharding

Under data parallelism, each DP rank should receive different data.

Dataset Reader Construction Flow

The overall flow for constructing the dataset readers is:

  1. The user supplies the raw dataset files plus the merge and vocab files needed for tokenization, provides a split ratio for train/valid/test, and sets micro-batch-size and global-batch-size (a note on how these batch sizes relate follows this list).

  2. The user defines a BlendedMegatronDatasetBuilder that turns those files into the datasets train_ds, valid_ds, test_ds.

  3. These datasets are then converted into train_dataloader, valid_dataloaders, test_dataloader, which support iterating over batches of data.
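For reference, the batch-size knobs are tied together by global_batch_size = micro_batch_size × data_parallel_size × num_micro_batches: each DP rank processes micro_batch_size samples per micro-batch, and gradient accumulation over num_micro_batches micro-batches makes up one global batch.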

The concrete construction order in Megatron-LM is as follows:

  1. The user defines a train_valid_test_datasets_provider, which builds a BlendedMegatronDatasetBuilder and returns train_ds, valid_ds, test_ds. These ds objects wrap the raw dataset files and support reading samples from them.
def is_dataset_built_on_rank():
    return (
        parallel_state.is_pipeline_first_stage(ignore_virtual=True)
        or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    ) and parallel_state.get_tensor_model_parallel_rank() == 0

def core_gpt_dataset_config_from_args(args):
    tokenizer = get_tokenizer()

    # Sometimes --data-path is too long, instead we parse it from a file.
    blend: Optional[Tuple[List[str], Optional[List[float]]]]
    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
    blend, blend_per_split = get_blend_and_blend_per_split(args)

    return GPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=blend,
        blend_per_split=blend_per_split,
        split=args.split,
        multiple_validation_sets=args.multiple_validation_sets,
        full_validation=args.full_validation,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
        mmap_bin_files=args.mmap_bin_files,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        create_attention_mask=args.create_attention_mask_in_dataloader,
        object_storage_cache_path=args.object_storage_cache_path,
        mid_level_dataset_surplus=args.mid_level_dataset_surplus,
    )

def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train test and validation datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples in train test and validation.
    """
    args = get_args()

    config = core_gpt_dataset_config_from_args(args)

    if args.sft:
        dataset_type = SFTDataset
    else:
        if args.mock_data:
            dataset_type = MockGPTDataset
        else:
            dataset_type = GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config
    ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds

  • The custom train_valid_test_datasets_provider is passed to the core pretrain training function, which uses build_train_valid_test_data_iterators to construct the data iterators: train_data_iterator, valid_data_iterator, test_data_iterator. Note that if virtual pipeline parallelism (vp) is enabled within pp, multiple iterators are needed, one per model chunk.

    if __name__ == "__main__":

        # Temporary for transition to core datasets
        train_valid_test_datasets_provider.is_distributed = True

        # Optionally enable inprocess restart on pretrain
        pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain)

        pretrain(
            train_valid_test_datasets_provider,
            model_provider,
            ModelType.encoder_or_decoder,
            forward_step,
            args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
            extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
            store=store,
        )

    def pretrain(...):
        ...
        if args.virtual_pipeline_model_parallel_size is not None:
            train_data_iterator = []
            valid_data_iterator = []
            test_data_iterator = []
            for i in range(len(model)):
                iterators = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
                train_data_iterator.append(iterators[0])
                valid_data_iterator.append(iterators[1])
                test_data_iterator.append(iterators[2])
        else:
            train_data_iterator, valid_data_iterator, test_data_iterator = (
                build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            )
        ...

    1. build_train_valid_test_data_iterators first calls build_train_valid_test_data_loaders to obtain train_dataloader, valid_dataloaders, test_dataloader.

      def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provider):
          """Build pretraining data iterators."""

          args = get_args()

          # Build loaders.
          train_dataloader, valid_dataloaders, test_dataloader = build_train_valid_test_data_loaders(
              build_train_valid_test_datasets_provider
          )

          # Build iterators.
          dl_type = args.dataloader_type
          assert dl_type in ['single', 'cyclic', 'external']

          def _get_iterator(dataloader_type, dataloader):
              """Return dataset iterator."""
              if dataloader_type == "single":
                  return RerunDataIterator(iter(dataloader))
              elif dataloader_type == "cyclic":
                  return RerunDataIterator(iter(cyclic_iter(dataloader)))
              elif dataloader_type == "external":
                  # External dataloader is passed through. User is expected to define how to iterate.
                  if isinstance(dataloader, list):
                      return [RerunDataIterator(d) for d in dataloader]
                  else:
                      return RerunDataIterator(dataloader)
              else:
                  raise RuntimeError("unexpected dataloader type")

          if train_dataloader is not None:
              train_data_iterator = _get_iterator(dl_type, train_dataloader)
          else:
              train_data_iterator = None

          # when using full validation, we need to override eval iters with the correct
          # number of iterations on tp rank 0 so that it can be distributed to the other
          # ranks later
          if args.full_validation:
              if args.multiple_validation_sets:
                  if valid_dataloaders[0] is None:
                      args.eval_iters = [None] * len(valid_dataloaders)
                  else:
                      args.eval_iters = [len(dl) for dl in valid_dataloaders]
              else:
                  args.eval_iters = len(valid_dataloaders[0])

          if args.multiple_validation_sets:
              if valid_dataloaders[0] is None:
                  valid_data_iterators = [None] * len(valid_dataloaders)
              else:
                  valid_dl_type = "cyclic" if args.full_validation else dl_type
                  print(
                      f"[VALID DATA LOADER LENGTHS] "
                      ", ".join(f"{idx}: {len(dl)}" for idx, dl in enumerate(valid_dataloaders))
                  )
                  valid_data_iterators = [
                      _get_iterator(valid_dl_type, dl) for dl in valid_dataloaders
                  ]
          elif valid_dataloaders[0] is not None:
              valid_data_iterators = _get_iterator(dl_type, valid_dataloaders[0])
          else:
              valid_data_iterators = None

          if test_dataloader is not None:
              test_data_iterator = _get_iterator(dl_type, test_dataloader)
          else:
              test_data_iterator = None

          return train_data_iterator, valid_data_iterators, test_data_iterator

      1. build_train_valid_test_data_loaders is shown below:
      def build_train_valid_test_data_loaders(build_train_valid_test_datasets_provider):
          """Build pretraining data loaders."""

          args = get_args()

          (train_dataloader, valid_dataloaders, test_dataloader) = (None, None, None)

          print_rank_0('> building train, validation, and test datasets ...')

          # Backward compatibility, assume fixed batch size.
          if args.iteration > 0 and args.consumed_train_samples == 0:
              assert (
                  args.train_samples is None
              ), 'Only backward compatiblity support for iteration-based training'
              args.consumed_train_samples = args.iteration * args.global_batch_size
          if args.iteration > 0 and args.consumed_valid_samples == 0:
              if args.train_samples is None:
                  args.consumed_valid_samples = (
                      (args.iteration // args.eval_interval) * args.eval_iters * args.global_batch_size
                  )

          # Rely on distributed-aware core datasets, temporary
          is_distributed = getattr(build_train_valid_test_datasets_provider, "is_distributed", False)

          # Construct the data pipeline
          if is_distributed or mpu.get_tensor_model_parallel_rank() == 0:

              # Build datasets.
              train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
                  build_train_valid_test_datasets_provider
              )
              valid_ds = [valid_ds] if not isinstance(valid_ds, list) else valid_ds

              # Build dataloders.
              train_dataloader = build_pretraining_data_loader(train_ds, args.consumed_train_samples)

              valid_dataloaders = []
              for valid_d in valid_ds:
                  if args.skip_train or args.full_validation:
                      valid_dataloaders.append(build_pretraining_data_loader(valid_d, 0))
                  else:
                      if args.multiple_validation_sets:
                          # TODO(bnorick): for multiple validation sets without full validation, args.consumed_valid_samples is not
                          # correct and needs to be calculated/set per validation set
                          raise NotImplementedError("--multiple-validation-sets currently requires --full-validation")
                      valid_dataloaders.append(build_pretraining_data_loader(valid_d, args.consumed_valid_samples))
              if not args.multiple_validation_sets:
                  assert len(valid_dataloaders) == 1
              test_dataloader = build_pretraining_data_loader(test_ds, 0)

              # Flags to know if we need to do training/validation/testing.
              do_train = train_dataloader is not None and args.train_iters > 0
              do_valid = valid_dataloaders is not None and (args.full_validation or args.eval_iters > 0)
              do_test = test_dataloader is not None and (args.full_validation or args.eval_iters > 0)
              flags = torch.tensor(
                  [int(do_train), int(do_valid), int(do_test)], dtype=torch.long, device='cuda'
              )
          else:
              flags = torch.tensor([0, 0, 0], dtype=torch.long, device='cuda')

          torch.distributed.broadcast(flags, 0)

          args.do_train = getattr(args, "do_train", False) or flags[0].item()
          args.do_valid = getattr(args, "do_valid", False) or flags[1].item()
          args.do_test = getattr(args, "do_test", False) or flags[2].item()

          return train_dataloader, valid_dataloaders, test_dataloader

      • It first back-fills the number of train and valid samples already consumed, so that after resuming from a checkpoint the same data is not replayed.

      • It then reads the is_distributed attribute of build_train_valid_test_datasets_provider; the data pipeline is constructed on this rank only if that attribute is true or this is tensor-parallel rank 0.

      • Next it calls build_train_valid_test_datasets, which does a simple computation of the number of samples needed for train, valid, and test and then uses the user-defined train_valid_test_datasets_provider to obtain train_ds, valid_ds, test_ds.

      def build_train_valid_test_datasets(build_train_valid_test_datasets_provider):
          """Build pretraining datasets."""
          train_valid_test_num_samples = get_train_valid_test_num_samples()
          print_rank_0(' > datasets target sizes (minimum size):')
          print_rank_0(' train: {}'.format(train_valid_test_num_samples[0]))
          print_rank_0(' validation: {}'.format(train_valid_test_num_samples[1]))
          print_rank_0(' test: {}'.format(train_valid_test_num_samples[2]))
          return build_train_valid_test_datasets_provider(train_valid_test_num_samples)

      def get_train_valid_test_num_samples():
          """Train/valid/test num samples."""

          args = get_args()

          # Number of train/valid/test samples.
          if args.train_samples:
              train_samples = args.train_samples
          else:
              train_samples = args.train_iters * args.global_batch_size
          if args.full_validation:
              eval_samples = None
          else:
              eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters
              eval_samples = eval_iters * args.global_batch_size
          test_iters = args.eval_iters

          return (train_samples, eval_samples, test_iters * args.global_batch_size)
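      As a worked example with hypothetical settings train_iters=1000, global_batch_size=32, eval_interval=100, eval_iters=10 and full_validation disabled: train_samples = 1000 * 32 = 32000, eval_samples = (1000 // 100 + 1) * 10 * 32 = 3520, and the test split gets 10 * 32 = 320 samples.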
      • The resulting train_ds, valid_ds, test_ds are then combined with the consumed train/valid sample counts and passed to build_pretraining_data_loader to construct train_dataloader, valid_dataloaders, test_dataloader. If the corresponding dataloader is not None, the flags args.do_train, args.do_valid, args.do_test are set.

        1. build_pretraining_data_loader is shown below:

          def build_pretraining_data_loader(dataset, consumed_samples):
              """Build dataloader given an input dataset."""

              if dataset is None:
                  return None
              args = get_args()

              if hasattr(dataset, 'split'):
                  split = dataset.split
              elif hasattr(dataset, 'index_split'):
                  split = dataset.index_split
              else:
                  split = None

              if split == Split.valid and args.full_validation:
                  batch_sampler = MegatronPretrainingSampler(
                      total_samples=len(dataset),
                      consumed_samples=0,
                      micro_batch_size=args.micro_batch_size,
                      data_parallel_rank=mpu.get_data_parallel_rank(),
                      data_parallel_size=mpu.get_data_parallel_world_size())
              elif args.dataloader_type == 'single':
                  # Megatron sampler
                  batch_sampler = MegatronPretrainingSampler(
                      total_samples=len(dataset),
                      consumed_samples=consumed_samples,
                      micro_batch_size=args.micro_batch_size,
                      data_parallel_rank=mpu.get_data_parallel_rank(),
                      data_parallel_size=mpu.get_data_parallel_world_size())
              elif args.dataloader_type == 'cyclic':
                  batch_sampler = MegatronPretrainingRandomSampler(
                      dataset,
                      total_samples=len(dataset),
                      consumed_samples=consumed_samples,
                      micro_batch_size=args.micro_batch_size,
                      data_parallel_rank=mpu.get_data_parallel_rank(),
                      data_parallel_size=mpu.get_data_parallel_world_size(),
                      data_sharding=args.data_sharding)
              elif args.dataloader_type == "external":
                  # External dataloaders are passed through. User is expected to provide a
                  # torch-compatible dataloader and define samplers, if needed.
                  return dataset
              else:
                  raise Exception('{} dataloader type is not supported.'.format(
                      args.dataloader_type))

              # Torch dataloader.
              return torch.utils.data.DataLoader(dataset,
                                                 batch_sampler=batch_sampler,
                                                 num_workers=args.num_workers,
                                                 pin_memory=True,
                                                 persistent_workers=True if args.num_workers > 0 else False,
                                                 )
        2. Two samplers appear here: MegatronPretrainingSampler reads sequentially and is resumable, while MegatronPretrainingRandomSampler samples randomly (seeded by the current epoch) and can loop indefinitely. A sketch of how such a DP-aware sampler splits data across ranks follows this list.

        3. Finally, the sampler and the dataset are combined into a standard torch.utils.data.DataLoader.

        4. Note that for a validation dataset with full_validation enabled (i.e. running through the whole dataset once), a MegatronPretrainingSampler with consumed_samples=0 is constructed. A custom dataloader can also be supplied via args.dataloader_type == "external".
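        The following is a minimal sketch of the core idea behind such a DP-aware sequential sampler (names and structure are simplified assumptions, not the exact MegatronPretrainingSampler code): every window of micro_batch_size * data_parallel_size sample indices is split contiguously, and rank r keeps the r-th micro-batch of the window, so different DP ranks never see the same sample.

        def dp_batch_sampler(total_samples, consumed_samples, micro_batch_size,
                             data_parallel_rank, data_parallel_size):
            # Yield the micro-batch of indices belonging to this DP rank.
            window = micro_batch_size * data_parallel_size
            batch = []
            for idx in range(consumed_samples, total_samples):
                batch.append(idx)
                if len(batch) == window:
                    start = data_parallel_rank * micro_batch_size
                    yield batch[start:start + micro_batch_size]
                    batch = []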

    2. If a dataloader is not None, _get_iterator wraps it according to dataloader_type into a RerunDataIterator, which supports fault-tolerant re-runs (see the cyclic_iter sketch after this list).

    3. Note that for valid_dataloaders, if full_validation is configured, eval_iters is overridden with the full number of iterations.
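    For reference, cyclic_iter is essentially an endless loop over the dataloader, roughly:

    def cyclic_iter(iterable):
        # Restart the underlying dataloader whenever it is exhausted.
        while True:
            for item in iterable:
                yield item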

  • train_data_iterator and valid_data_iterator are passed into the train function for training; test_data_iterator is used for the final test after training, provided args.do_test is set.

Key Classes in Dataset Construction

BlendedMegatronDatasetBuilder

BlendedMegatronDatasetBuilder mainly implements dataset blending, e.g. mixing a general-knowledge corpus with a code corpus; roughly speaking, normalized blend weights decide how many samples are drawn from each constituent dataset (see the sketch after the class listing). Its code is as follows:

class BlendedMegatronDatasetBuilder(object):
"""Builder class for the BlendedDataset and MegatronDataset classes

Args:
cls (Type[MegatronDataset]): The class to instantiate, must inherit from MegatronDataset

sizes (List[Optional[int]]): The minimum total number of samples to draw, or None, per split

is_built_on_rank (Callable): A callable which returns True if the dataset should be built on
the current rank and False otherwise. It should be Megatron Core parallelism aware i.e.
global rank, local group rank, and virtual rank may inform its return value.

config (BlendedMegatronDatasetConfig): The config object which informs dataset creation
"""

def __init__(
self,
cls: Type[MidLevelDataset],
sizes: List[int],
is_built_on_rank: Callable,
config: BlendedMegatronDatasetConfig,
):
self.cls = cls
self.sizes = sizes
self.is_built_on_rank = is_built_on_rank
self.config = config

log_single_rank(
logger,
logging.INFO,
f"Building {cls.__name__} splits with sizes={self.sizes} and config={self.config}",
)

if not self.config.mock:
for split in Split:
size_is_none = self.sizes[split.value] is None
if self.config.blend_per_split is None:
weights_are_none = self.config.blend[1] is None
else:
if self.config.blend_per_split[split.value] is None:
continue
weights_are_none = self.config.blend_per_split[split.value][1] is None
if size_is_none:
assert weights_are_none, f"""size_is_none => weights_are_none fails
for {split.name} split
This can occur with multiple validation sets if datasets have weights"""

if torch.distributed.is_initialized():
gb_rank = torch.distributed.get_rank()
if gb_rank == 0:
assert (
self.is_built_on_rank()
), "is_built_on_rank must return True when global rank = 0"

def build(self) -> List[Optional[TopLevelDataset]]:
"""Build all dataset splits according to the provided blend(s)

This method is distributed-aware and must be called on all ranks.

The dataset splits returned can vary according to the config. Supply config.blend and
config.split to build BlendedDataset and/or MegatronDataset splits from the same
distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset
splits from separate distributions. In either case, for each split, handle the following
cases:

(1) The split is None
- do nothing

(2) The split has one contributing dataset, and...

(a) 'size' is not None
- Build a mid-level dataset with low-level dataset sampling in proportion to the
size

(b) 'size' is None
- Build mid-level datasets with no excess low-level dataset sampling

(3) The split has multiple contributing datasets, and...

(a) 'weights' is not None and 'size' is not None
- Build mid-level datasets with low-level dataset sampling in proportion to their
weights and the size
- Build a top-level dataset of length marginally greater than 'size' with mid-level
dataset sampling in proportion to their weights and the size

(b) 'weights' is not None and 'size' is None
- Error

(c) 'weights' is None and 'size' is not None
- Build mid-level datasets with no excess low-level dataset sampling
- Build a top-level dataset of length 'size' (capped at the sum of the mid-level
dataset lengths) with mid-level dataset sampling in proportion to their lengths
and the size

(d) 'weights' is None and 'size' is None
- Build mid-level datasets with no excess low-level dataset sampling
- Build a top-level dataset with no excess mid-level dataset sampling

Returns:
List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per
split
"""
datasets = self._build_blended_dataset_splits()

for dataset in datasets:
if dataset is not None and len(dataset) > 0:
if isinstance(dataset, BlendedDataset):
assert dataset.size is None or dataset.size == len(dataset)
elif isinstance(dataset, MegatronDataset):
assert dataset.num_samples is None or dataset.num_samples <= len(dataset)

return datasets

def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]:
"""Build all dataset splits according to the provided blend(s)

See the BlendedMegatronDatasetBuilder.build alias for more information.

Returns:
List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per
split
"""
##
# Return fake "mock" datasets
##
if self.config.mock:
split = self.config.split_matrix
try:
return self._build_megatron_dataset_splits(None, split, self.sizes)
except Exception as error:
raise Exception(
f"{self.cls.__name__} failed to build as a mock data generator"
) from error

##
# All splits come from the same distribution
##
elif self.config.blend:
prefixes, weights = self.config.blend
if weights is not None:
weights = normalize(weights)

split = self.config.split_matrix

# Blend consists of a single prefix
if len(prefixes) == 1 and weights is None:
return self._build_megatron_dataset_splits(prefixes[0], split, self.sizes)

# Build the mid-level datasets
if weights is None:
# Build only one "epoch"
sizes_per_dataset_buffer = [[None for split in Split] for prefix in prefixes]
else:
# The number of samples we plan to use per dataset
sizes_per_dataset_target = _get_size_per_split_per_dataset(weights, self.sizes)
# The number of samples we plan to build per dataset
sizes_per_dataset_buffer = _get_size_per_split_per_dataset(
weights, self.sizes, surplus=self.config.mid_level_dataset_surplus
)

# Build each dataset in parallel
megatron_datasets = self._build_megatron_datasets_parallel(
prefixes, split, sizes_per_dataset_buffer
)

# Build the top-level datasets
blended_datasets = [None] * len(Split)
for i in range(len(Split)):
if split[i] is not None:
weights_i = weights
if weights_i is not None and self.sizes[i] is not None:
# Blend according to client-specified weights and client-specified size
size_per_dataset = list(zip(*sizes_per_dataset_target))[i]
size_i = sum(size_per_dataset)
elif weights_i is None:
# Blend according to dataset sizes as-is and (maybe) client-specified size
try:
weights_i = [
len(megatron_dataset) for megatron_dataset in megatron_datasets[i]
]
except TypeError:
weights_i = [0 for _ in prefixes]
if self.sizes[i] is not None:
size_i = min(self.sizes[i], sum(weights_i))
else:
# Build exhaustive indices
size_i = None
else:
raise ValueError(
"Using client-specified weights requires client-specified size"
)
blended_datasets[i] = self.build_generic_dataset(
BlendedDataset,
self.is_built_on_rank,
True, # synchronize_ranks, default behavior to build on rank-0 first
megatron_datasets[i],
weights_i,
size_i,
self.config,
)

return blended_datasets

##
# Each split comes from a separate distribution
##
else:
blended_datasets = [None] * len(Split)
for i in range(len(Split)):
split_spoof = [None] * len(Split)
split_spoof[i] = (0.0, 1.0)
sizes_spoof = [0] * len(Split)
sizes_spoof[i] = self.sizes[i]

# Blend is provided for the split
blend = self.config.blend_per_split[i]
if blend is not None:
prefixes, weights = blend
if weights is not None:
weights = normalize(weights)

# Blend consists of a sigle prefix
if len(prefixes) == 1:
blended_datasets[i] = self._build_megatron_dataset_splits(
prefixes[0], split_spoof, sizes_spoof
)[i]
continue
elif self.config.multiple_validation_sets and i == Split.valid.value:
# handle multiple validation sets
validation_datasets = []
if self.config.full_validation:
# verify that size is None, which causes a single epoch dataset
# to be built
assert sizes_spoof[i] is None
for prefix in prefixes:
ds = self._build_megatron_dataset_splits(
prefix, split_spoof, sizes_spoof
)[i]
validation_datasets.append(ds)
blended_datasets[i] = validation_datasets
continue

# Build mid-level datasets
if weights is None:
sizes_per_dataset_buffer = [
[None for split in Split] for prefix in prefixes
]
else:
# The number of samples we plan to use per dataset
sizes_per_dataset_target = _get_size_per_split_per_dataset(
weights, sizes_spoof
)
# The number of samples we plan to build per dataset
sizes_per_dataset_buffer = _get_size_per_split_per_dataset(
weights, sizes_spoof, surplus=self.config.mid_level_dataset_surplus
)

# Build each dataset in parallel
megatron_datasets = self._build_megatron_datasets_parallel(
prefixes, split_spoof, sizes_per_dataset_buffer
)[i]

# Build top-level dataset
if weights is not None and self.sizes[i] is not None:
# Blend according to client-specified weights and client-specified size
size_per_dataset = list(zip(*sizes_per_dataset_target))[i]
size = sum(size_per_dataset)
elif weights is None:
# Blend according to dataset sizes as-is and (maybe) client-specified size
try:
weights = [
len(megatron_dataset) for megatron_dataset in megatron_datasets
]
except TypeError:
weights = [0 for _ in prefixes]
if self.sizes[i] is not None:
size = min(self.sizes[i], sum(weights))
else:
# Build exhaustive indices
size = None
else:
raise RuntimeError
blended_datasets[i] = self.build_generic_dataset(
BlendedDataset,
self.is_built_on_rank,
True, # synchronize_ranks, default behavior to build on rank-0 first
megatron_datasets,
weights,
size,
self.config,
)

return blended_datasets

def _build_megatron_datasets_parallel(
self, prefixes: List[str], split: List[float], sizes_per_dataset: List[List[int]]
) -> List[List[Optional[MegatronDataset]]]:
"""Build the megatron datasets for a list of prefixes in parallel

Args:
prefixes (List[str]): The list of prefix strings

split (List[float]): The dataset split ratios (must sum to 1.00)

sizes_per_dataset (List[List[int]]): The number of samples to request
per MegatronDataset per spilt

Returns:
List[List[Optional[MegatronDataset]]]: For each split, have a list of
MegatronDataset per prefix
"""

# Helper function to wrap the threading logic
def _threading_helper(
megatron_datasets: List[List[Optional[MegatronDataset]]],
num_workers: int,
prefixes: List[str],
split: List[float],
sizes_per_dataset: List[List[int]],
) -> None:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
all_futures = []
for i in range(len(prefixes)):
all_futures.append(
executor.submit(
self._build_megatron_dataset_splits,
prefixes[i],
split,
sizes_per_dataset[i],
False, # synchronize_ranks, barrier is called in this function
)
)
for future in all_futures:
try:
megatron_datasets_split = future.result()
for j in range(len(megatron_datasets_split)):
megatron_datasets[j].append(megatron_datasets_split[j])
except Exception as err:
raise err

megatron_datasets = [[] for _ in range(len(Split))]
num_dataset_builder_threads = self.config.num_dataset_builder_threads

if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
# First, build on rank 0
if rank == 0:
num_workers = num_dataset_builder_threads
if num_workers > 1:
# since only rank 0 is running, scale up the thread count
# but not too much to avoid overloading storage on miss path.
# if user set num_dataset_builder_threads to 1,
# i.e. meant for serial build, do not scale up.
num_workers *= min(2, max(1, torch.cuda.device_count()))
_threading_helper(
megatron_datasets, num_workers, prefixes, split, sizes_per_dataset
)

torch.distributed.barrier()

# Then, build on other ranks; guaranteed to be data_cache hit
if rank != 0:
_threading_helper(
megatron_datasets,
num_dataset_builder_threads,
prefixes,
split,
sizes_per_dataset,
)
else:
_threading_helper(
megatron_datasets, num_dataset_builder_threads, prefixes, split, sizes_per_dataset
)

return megatron_datasets

def _build_megatron_dataset_splits(
self,
dataset_path: Optional[str],
split: List[float],
sizes: List[int],
synchronize_ranks: bool = True,
) -> List[Optional[MidLevelDataset]]:
"""Build each MidLevelDataset split from a single LowLevelDataset

Args:
dataset_path (Optional[str]): The path on disk which defines the underlying
LowLevelDataset, or None for mock dataset classes

split (List[Tuple[float, float]]): The dataset split matrix

sizes (List[int]): The number of total samples to draw from each split

synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks
behavior. Set to False when we enforce this behavior at higher level.

Returns:
List[Optional[MidLevelDataset]]: The MidLevelDataset (or None) per split
"""
# short-cut if we are not building on this rank
if torch.distributed.is_initialized() and not self.is_built_on_rank():
for i in range(len(Split)):
if split[i] is not None and synchronize_ranks:
torch.distributed.barrier()
return [None] * len(Split)

# Build the low level dataset
low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config)

# Build the split indices for the low level dataset
num_elements = self.cls.numel_low_level_dataset(low_level_dataset)
split_indices = []
for i, _ in enumerate(Split):
if split[i] is not None:
beg = int(round(split[i][0] * float(num_elements)))
end = int(round(split[i][1] * float(num_elements)))
split_indices.append(numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32))
else:
split_indices.append(None)

# Build the mid level dataset
mid_level_datasets = []
for i, _split in enumerate(Split):
if split[i] is None:
mid_level_datasets.append(None)
else:
mid_level_datasets.append(
self.build_generic_dataset(
self.cls,
self.is_built_on_rank,
synchronize_ranks,
low_level_dataset,
dataset_path,
split_indices[i],
sizes[i],
_split,
self.config,
)
)

return mid_level_datasets

@staticmethod
def build_generic_dataset(
cls: Union[Type[DistributedDataset], Callable],
is_built_on_rank: Callable,
synchronize_ranks: bool,
*args: Any,
) -> Optional[Union[DistributedDataset, Iterable]]:
"""Build the DistributedDataset

Return None if and only if the underlying dataset class is not built on the current rank
and torch.distributed is initialized.

Args:
cls (Union[Type[DistributedDataset], Callable]): The DistributedDataset class to be
built. In special cases, e.g. when we are building the low level dataset for a
RawMegatronDataset instance, we can accept a Callable which returns an Iterable.

synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks
behavior. Set to False when we enforce this behavior at higher level.

args (Tuple[Any]): The positional arguments used to build the provided
DistributedDataset class

Raises:
Exception: When the dataset constructor raises an OSError

Returns:
Optional[Union[DistributedDataset, Iterable]]: The DistributedDataset instantion, the
Iterable instantiation, or None
"""
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()

dataset = None

# First, build on rank 0
if rank == 0 and is_built_on_rank():
try:
dataset = cls(*args)
except OSError as err:
log = (
f"Failed to write dataset materials to the data cache directory. Please "
f"supply a directory to which you have write access via the path_to_cache "
f"attribute in BlendedMegatronDatasetConfig and retry. Refer to the "
f"preserved traceback above for more information."
)
raise Exception(log) from err

if synchronize_ranks:
torch.distributed.barrier()

# After, build on other ranks
if rank != 0 and is_built_on_rank():
dataset = cls(*args)

return dataset

return cls(*args)
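
As an illustrative aside (this is not the real normalize / _get_size_per_split_per_dataset code, which also adds a surplus margin), blending boils down to allocating a split's sample budget across the constituent datasets in proportion to their normalized weights:

import math

def samples_per_dataset(weights, split_size):
    # Hypothetical helper: L1-normalize the weights and allocate samples proportionally.
    total = sum(weights)
    return [math.ceil(w / total * split_size) for w in weights]

print(samples_per_dataset([0.7, 0.3], 10_000))  # [7000, 3000]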

For now, though, we set dataset blending aside and look at how a single dataset is handled.

With a single dataset, _build_blended_dataset_splits takes the following branch:

# Blend consists of a single prefix
if len(prefixes) == 1 and weights is None:
    return self._build_megatron_dataset_splits(prefixes[0], split, self.sizes)

_build_megatron_dataset_splits proceeds as follows:

  1. If the current rank does not need to build the dataset (is_dataset_built_on_rank returns False), it only participates in the synchronization barriers. Only the first and last pipeline stages, and only tensor-parallel rank 0 within them, actually build the dataset, which avoids wasting resources.
def is_dataset_built_on_rank():
    return (
        parallel_state.is_pipeline_first_stage(ignore_virtual=True)
        or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    ) and parallel_state.get_tensor_model_parallel_rank() == 0

  • It then builds the low-level dataset via self.cls.build_low_level_dataset(dataset_path, self.config). Here we look at GPTDataset, whose implementation is shown below.


    @staticmethod
    def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> IndexedDataset:
        """Abstract method implementation

        Args:
            dataset_path (str): The real path prefix to the IndexedDataset .bin and .idx files

            config (GPTDatasetConfig): The config

        Returns:
            IndexedDataset: The underlying IndexedDataset
        """
        if is_object_storage_path(dataset_path):
            assert config.object_storage_cache_path is not None
            return IndexedDataset(
                dataset_path,
                multimodal=False,
                mmap=config.mmap_bin_files,
                object_storage_config=ObjectStorageConfig(
                    path_to_idx_cache=config.object_storage_cache_path
                ),
            )
        return IndexedDataset(dataset_path, multimodal=False, mmap=config.mmap_bin_files)

    1. As the code shows, this essentially constructs an IndexedDataset, whose code is shown further below.

    2. It operates on a pair of files:

      • path_prefix.idx: the index file, recording each sample's (sequence) length, its byte offset (pointer) into the .bin file, and optionally a multimodal mode; it also records document boundaries (document_indices).

      • path_prefix.bin: the actual token data (a contiguous array of fixed-width dtype values).

    3. Its core capabilities are:

      • O(1) random access: the .idx file gives the offset and length of the i-th sample in the .bin file, from which the corresponding token sequence is read (a toy sketch follows this list).

      • Slice reads: dataset[start:stop] reads several contiguous samples at once (step must be 1), avoiding the frequent IO of per-sample calls.

      • Partial reads: get(idx, offset, length) fetches only a segment of a sample's tokens (useful for truncation, sliding windows, etc.).

      • Selectable read strategies: mmap=True/False chooses between memory mapping (_MMapBinReader) and plain file reads (_FileBinReader); if the data lives in object storage (S3/MSC), a dedicated reader pulls it in chunks and caches the .idx file locally.

      • exists(path_prefix): checks whether the .idx/.bin pair exists (locally or in object storage).
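      As a toy illustration of the .idx/.bin idea (this is not the real file format or reader code): the index supplies a byte offset and a token count per sequence, so any sample can be fetched with one seek and one read.

      import numpy as np

      def read_sequence(bin_path, offsets, lengths, i, dtype=np.int32):
          # offsets/lengths stand in for what the .idx parser would provide in the real code
          with open(bin_path, "rb") as f:
              f.seek(offsets[i])
              return np.fromfile(f, dtype=dtype, count=lengths[i])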

    class IndexedDataset(torch.utils.data.Dataset):
    """The low-level interface dataset class

    Args:
    path_prefix (str): The index (.idx) and data (.bin) prefix

    multimodal (bool): Whether the dataset is multimodal. Defaults to False.

    mmap (bool): Whether to mmap the .bin files. Defaults to True.

    object_storage_config (Optional[ObjectStorageConfig]): Supplied only for data stored on S3
    or MSC. IndexedDataset downloads the index (.idx) file to
    `object_storage_config.path_to_idx_cache` and streams data from the data (.bin) file
    in `object_storage_config.bin_chunk_nbytes` blocks. Note that `mmap` must be disabled
    for S3 data loading. Defaults to None.
    """

    def __init__(
    self,
    path_prefix: str,
    multimodal: bool = False,
    mmap: bool = True,
    object_storage_config: Optional[ObjectStorageConfig] = None,
    s3_config: Optional[S3Config] = None,
    ) -> None:
    super().__init__()
    self.path_prefix: str
    self.multimodal: bool
    self.mmap: bool
    self.object_storage_config: Optional[ObjectStorageConfig]

    self.bin_reader: _BinReader
    self.index: _IndexReader

    # Deprecated: s3_config is deprecated, use object_storage_config instead
    object_storage_config = object_storage_config or s3_config

    # Cache the index file if it is stored on object storage
    if is_object_storage_path(path_prefix) and object_storage_config is not None:
    idx_path = get_idx_path(path_prefix)
    cache_idx_path = get_index_cache_path(idx_path, object_storage_config)
    cache_index_file(idx_path, cache_idx_path)

    self.initialize(path_prefix, multimodal, mmap, object_storage_config)

    def initialize(
    self,
    path_prefix: str,
    multimodal: bool,
    mmap: bool,
    object_storage_config: Optional[ObjectStorageConfig],
    ) -> None:
    """Initialize the dataset

    This method is called by IndexedDataset.__init__ during object creation and by
    IndexedDataset.__setstate__ during un-pickling

    Args:
    path_prefix (str): The index (.idx) and data (.bin) prefix

    multimodal (bool): Whether the dataset is multimodal

    mmap (bool): Whether to mmap the .bin file

    object_storage_config (Optional[ObjectStorageConfig]): See IndexedDataset docstring
    for details.
    """
    idx_path = get_idx_path(path_prefix)
    bin_path = get_bin_path(path_prefix)
    if object_storage_config is None:
    assert os.path.exists(idx_path) and os.path.exists(
    bin_path
    ), "One or both of the .idx and .bin files cannot be found at the "
    f"path prefix {path_prefix}"
    self.path_prefix = path_prefix
    self.multimodal = multimodal
    self.mmap = mmap
    self.object_storage_config = object_storage_config
    if mmap:
    assert not object_storage_config
    self.bin_reader = _MMapBinReader(bin_path)
    elif object_storage_config:
    assert not mmap
    self.bin_reader = OBJECT_STORAGE_BIN_READERS[get_object_storage_access(path_prefix)](
    bin_path, object_storage_config
    )
    idx_path = get_index_cache_path(get_idx_path(path_prefix), object_storage_config)
    else:
    self.bin_reader = _FileBinReader(bin_path)
    self.index = _IndexReader(idx_path, self.multimodal)

    def __getstate__(self) -> Tuple[str, bool, bool, Optional[ObjectStorageConfig]]:
    """Get the state during pickling

    Returns:
    Tuple[str, bool, bool, Optional[ObjectStorageConfig]]: The state tuple
    """
    return self.path_prefix, self.multimodal, self.mmap, self.object_storage_config

    def __setstate__(self, state: Tuple[str, bool, bool, Optional[ObjectStorageConfig]]) -> None:
    """Set the state during un-pickling

    Args:
    state (Tuple[str, bool, bool, Optional[ObjectStorageConfig]]): The state tuple
    """
    path_prefix, multimodal, mmap, object_storage_config = state
    self.initialize(path_prefix, multimodal, mmap, object_storage_config)

    def __del__(self) -> None:
    """Clean up the object"""
    del self.bin_reader
    del self.index

    def __len__(self) -> int:
    """Return the length of the dataset i.e. the number of sequences in the index

    Returns:
    int: The length of the dataset
    """
    return len(self.index)

    def __getitem__(
    self, idx: Union[int, numpy.integer, slice]
    ) -> Union[
    numpy.ndarray,
    Tuple[numpy.ndarray, numpy.number],
    List[numpy.ndarray],
    Tuple[List[numpy.ndarray], numpy.ndarray],
    ]:
    """Return from the dataset

    Args:
    idx (Union[int, numpy.integer, slice]): The index or index slice into the dataset

    Raises:
    ValueError: When the index slice is non-contiguous

    TypeError: When the index is of an unexpected type

    Returns:
    Union[
    numpy.ndarray,
    Tuple[numpy.ndarray, numpy.number],
    List[numpy.ndarray],
    Tuple[List[numpy.ndarray], numpy.ndarray],
    ]: The sequence tokens and modes at the index or index slice
    """
    if isinstance(idx, (int, numpy.integer)):
    sequence_pointer, sequence_length, sequence_mode = self.index[idx]
    sequence = self.bin_reader.read(
    dtype=self.index.dtype, count=sequence_length, offset=sequence_pointer
    )
    return (sequence, sequence_mode) if sequence_mode is not None else sequence
    elif isinstance(idx, slice):
    start, stop, step = idx.indices(len(self))
    if step != 1:
    raise ValueError("Slices into indexed_dataset must be contiguous")
    sequence_lengths = self.index.sequence_lengths[idx]
    sequence_modes = (
    self.index.sequence_modes[idx] if self.multimodal else None # type: ignore[index]
    )
    sequence_offsets = list(accumulate(sequence_lengths))
    sequences = numpy.split(
    self.bin_reader.read(
    dtype=self.index.dtype,
    count=sum(sequence_lengths),
    offset=self.index.sequence_pointers[start],
    ),
    sequence_offsets[:-1],
    )
    return (sequences, sequence_modes) if sequence_modes is not None else sequences
    else:
    raise TypeError("Unexpected type received for idx: {}".format(type(idx)))

    def get(
    self, idx: int, offset: int = 0, length: Optional[int] = None
    ) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.number]]:
    """Retrieve a single item from the dataset with the option to only
    return a portion of the item.

    get(idx) is the same as [idx] but get() does not support slicing.

    Args:
    idx (Union[int, numpy.integer]): The index into the dataset

    offset (int): The integer token offset in the sequence

    length (int): The number of tokens to grab from the sequence

    Returns:
    Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.number]]: The sequence tokens and mode
    at the index
    """
    sequence_pointer, sequence_length, sequence_mode = self.index[idx]
    if length is None:
    length = sequence_length - offset
    sequence_pointer += offset * DType.size(self.index.dtype)
    sequence = self.bin_reader.read(
    dtype=self.index.dtype, count=length, offset=sequence_pointer
    )
    return (sequence, sequence_mode) if sequence_mode is not None else sequence

    @property
    def sequence_lengths(self) -> numpy.ndarray:
    """Get the sequence lengths

    Returns:
    numpy.ndarray: The sequence lengths
    """
    return self.index.sequence_lengths

    @property
    def document_indices(self) -> numpy.ndarray:
    """Get the document indices

    Returns:
    numpy.ndarray: The document indices
    """
    return self.index.document_indices

    def get_document_indices(self) -> numpy.ndarray:
    """Get the document indices

    This method is slated for deprecation.

    Returns:
    numpy.ndarray: The document indices
    """
    return self.index.document_indices

    def set_document_indices(self, document_indices: numpy.ndarray) -> None:
    """Set the document indices

    This method is slated for deprecation.

    Args:
    document_indices (numpy.ndarray): The document indices
    """
    self.index.document_indices = document_indices

    @property
    def sequence_modes(self) -> numpy.ndarray:
    """Get the sequence modes

    Returns:
    numpy.ndarray: The sequence modes
    """
    assert self.index.sequence_modes
    return self.index.sequence_modes

    @staticmethod
    def exists(path_prefix: str) -> bool:
    """Return whether the IndexedDataset exists on disk at the prefix

    Args:
    path_prefix (str): The prefix to the index (.idx) and data (.bin) files

    Returns:
    bool: Whether the IndexedDataset exists on disk at the prefix
    """
    if is_object_storage_path(path_prefix):
    return dataset_exists(path_prefix, get_idx_path(path_prefix), get_bin_path(path_prefix))

    return os.path.exists(get_idx_path(path_prefix)) and os.path.exists(
    get_bin_path(path_prefix)
    )

  • It reads how many samples the low-level dataset contains in total, then uses split to compute the actual index range of each proportional split.

  • It then calls build_generic_dataset to build a mid-level dataset for each range and returns them; these are the train_ds, valid_ds, test_ds we end up with.

    1. First, rank 0 builds the mid-level dataset, i.e. the actual GPTDataset, while all other ranks wait.

    2. Then every rank other than 0 that satisfies is_built_on_rank builds its GPTDataset, and the resulting GPTDataset is returned.

    Having rank 0 build first turns the shared-cache writes that dataset construction may produce into a single-writer, many-readers pattern, which avoids races, improves cache hits, and keeps the barrier synchronization and overall flow consistent; a sketch of the pattern follows.
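    Stripped of dataset specifics, the pattern looks roughly like this (a sketch, assuming torch.distributed is initialized; build_fn is a hypothetical constructor callable):

    import torch

    def build_with_rank0_priority(build_fn, is_built_on_rank):
        rank = torch.distributed.get_rank()
        dataset = None
        if rank == 0 and is_built_on_rank():
            dataset = build_fn()      # rank 0 may write index caches to disk
        torch.distributed.barrier()   # everyone waits until the cache exists
        if rank != 0 and is_built_on_rank():
            dataset = build_fn()      # guaranteed cache hit, read-only
        return dataset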

    • The role of GPTDataset is to build, on top of the underlying IndexedDataset (the token sequences stored in .bin/.idx), a PyTorch Dataset that can be used directly for GPT autoregressive training. Its main responsibilities are:

      • Concatenating/slicing raw sequences into fixed-length samples: by building document_index / sample_index / shuffle_index, many variable-length sequences are concatenated in document order and then cut into training samples of length sequence_length (optionally with 1 extra token); see the lookup sketch after this list.

      • Providing the tensors needed for training: __getitem__ returns tokens/labels and generates (or reuses cached) attention_mask / loss_mask / position_ids for left-to-right (causal) language model training.

      • Supporting reproducible shuffling and multi-epoch sampling: shuffle_index controls the random sample order; when num_samples exceeds the samples available in one epoch, the required number of epochs is computed.

      • Supporting index caching: the constructed document_index.npy / sample_index.npy / shuffle_index.npy are written to path_to_cache and loaded directly on the next run, avoiding recomputation.
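      The three indices are consulted in sequence when fetching a sample. Below is a simplified sketch of the lookup performed by _query_document_sample_shuffle_indices (names abbreviated, token reading omitted):

      def locate_sample(idx, shuffle_index, sample_index, document_index):
          # 1) map the requested position to a shuffled sample id
          idx = shuffle_index[idx]
          # 2) the sample spans documents doc_beg..doc_end with token offsets into each
          doc_beg, beg_offset = sample_index[idx]
          doc_end, end_offset = sample_index[idx + 1]
          # 3) translate document positions back to real document ids
          docs = [document_index[d] for d in range(doc_beg, doc_end + 1)]
          return docs, beg_offset, end_offset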

      class GPTDataset(MegatronDataset):
      """The base GPT dataset

      Args:
      indexed_dataset (IndexedDataset): The IndexedDataset around which to build the GPTDataset

      dataset_path (Optional[str]): The real path on disk to the dataset, for bookkeeping

      indexed_indices (numpy.ndarray): The set of the documents indices to expose

      num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When
      None, build as many samples as correspond to one epoch.

      index_split (Split): The indexed_indices Split

      config (GPTDatasetConfig): The config
      """

      def __init__(
      self,
      indexed_dataset: IndexedDataset,
      dataset_path: Optional[str],
      indexed_indices: numpy.ndarray,
      num_samples: Optional[int],
      index_split: Split,
      config: GPTDatasetConfig,
      ) -> None:
      super().__init__(
      indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
      )
      self.masks_and_position_ids_are_cacheable = not any(
      [
      self.config.reset_position_ids,
      self.config.reset_attention_mask,
      self.config.eod_mask_loss,
      ]
      )
      self.masks_and_position_ids_are_cached = False
      self.cached_attention_mask = None
      self.cached_loss_mask = None
      self.cached_position_ids = None

      try:
      self._pad_token_id = self.config.tokenizer.pad
      except Exception:
      self._pad_token_id = _PAD_TOKEN_ID

      (self.document_index, self.sample_index, self.shuffle_index) = (
      self._build_document_sample_shuffle_indices()
      )

      @staticmethod
      def numel_low_level_dataset(low_level_dataset: IndexedDataset) -> int:
      """Abstract method implementation

      For GPT, the underlying IndexedDataset should be split by sequence, as opposed to, say,
      BERT, which should be split by document

      Args:
      low_level_dataset (IndexedDataset): The underlying IndexedDataset

      Returns:
      int: The number of unique elements in the underlying IndexedDataset
      """
      return low_level_dataset.sequence_lengths.shape[0]

      @staticmethod
      def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> IndexedDataset:
      """Abstract method implementation

      Args:
      dataset_path (str): The real path prefix to the IndexedDataset .bin and .idx files

      config (GPTDatasetConfig): The config

      Returns:
      IndexedDataset: The underlying IndexedDataset
      """
      if is_object_storage_path(dataset_path):
      assert config.object_storage_cache_path is not None
      return IndexedDataset(
      dataset_path,
      multimodal=False,
      mmap=config.mmap_bin_files,
      object_storage_config=ObjectStorageConfig(
      path_to_idx_cache=config.object_storage_cache_path
      ),
      )
      return IndexedDataset(dataset_path, multimodal=False, mmap=config.mmap_bin_files)

      def __len__(self) -> int:
      """Abstract method implementation

      Returns:
      int: The length of the dataset
      """
      return self.sample_index.shape[0] - 1

      def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
      """Abstract method implementation

      Args:
idx (Optional[int]): The index into the dataset

      Returns:
      Dict[str, torch.Tensor]: The sample information wrapped in a dictionary
      """
      if idx is None:
      # Batch padding sequence so the index does not matter
      text, _ = self._query_document_sample_shuffle_indices(0)
      else:
      text, _ = self._query_document_sample_shuffle_indices(idx)

      text = torch.from_numpy(text).long()
      if self.config.add_extra_token_to_sequence:
      tokens = text[:-1].contiguous()
      labels = text[1:].contiguous()
      else:
      tokens = text
      labels = torch.roll(text, shifts=-1, dims=0)
      labels[-1] = self._pad_token_id

      if (
      not self.masks_and_position_ids_are_cacheable
      or not self.masks_and_position_ids_are_cached
      ):
      attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
      tokens,
      self.config.tokenizer.eod,
      self.config.reset_position_ids,
      self.config.reset_attention_mask,
      self.config.eod_mask_loss,
      self.config.create_attention_mask,
      )
      if self.masks_and_position_ids_are_cacheable:
      self.cached_attention_mask = attention_mask
      self.cached_loss_mask = loss_mask
      self.cached_position_ids = position_ids
      self.masks_and_position_ids_are_cached = True
      else:
      attention_mask = self.cached_attention_mask
      loss_mask = self.cached_loss_mask
      position_ids = self.cached_position_ids

      # For padded sequences, mask the loss
      loss_mask[labels == self._pad_token_id] = 0.0

      # For padded sequences, ensure the embedding layer can map the token ID
      tokens[tokens == self._pad_token_id] = 0
      labels[labels == self._pad_token_id] = 0

      # Batch padding sequence so we mask the loss
      if idx is None:
      loss_mask = torch.zeros_like(loss_mask)

      if self.config.create_attention_mask:
      return {
      "tokens": tokens,
      "labels": labels,
      "attention_mask": attention_mask,
      "loss_mask": loss_mask,
      "position_ids": position_ids,
      }
      else:
      return {
      "tokens": tokens,
      "labels": labels,
      "loss_mask": loss_mask,
      "position_ids": position_ids,
      }

      def _query_document_sample_shuffle_indices(
      self, idx: int
      ) -> Tuple[numpy.ndarray, numpy.ndarray]:
      """Get the text (token ids) and document ids for a given index

      Args:
      idx (int): The index into the dataset

      Returns:
      Tuple[numpy.ndarray, numpy.ndarray]: The text ids and document ids
      """
      # Do the shuffle mapping
      idx = self.shuffle_index[idx]

      # Get the beginning and end documents and offsets
      doc_index_beg, doc_index_beg_offset = self.sample_index[idx]
      doc_index_end, doc_index_end_offset = self.sample_index[idx + 1]

      document_ids = []
      sample_parts = []

      # Sample spans a single document
      if doc_index_beg == doc_index_end:
      # Add the document id
      document_ids.append(self.document_index[doc_index_beg])

      # Add the entire sample
      sample_parts.append(
      self.dataset.get(
      self.document_index[doc_index_beg],
      offset=doc_index_beg_offset,
      length=doc_index_end_offset
      - doc_index_beg_offset
      + self.config.add_extra_token_to_sequence,
      )
      )

      # Sample spans multiple documents
      else:
      for i in range(doc_index_beg, doc_index_end + 1):
      # Add the document id
      document_ids.append(self.document_index[i])

      # Add the sample part
      offset = 0 if i > doc_index_beg else doc_index_beg_offset
      length = (
      None
      if i < doc_index_end
      else doc_index_end_offset + self.config.add_extra_token_to_sequence
      )
      sample_parts.append(
      self.dataset.get(self.document_index[i], offset=offset, length=length)
      )
      assert len(document_ids) == len(
      sample_parts
      ), f"len(document_ids) ({len(document_ids)}) != len(sample_parts) ({len(sample_parts)})"

      length = sum(map(len, sample_parts))

      # Pad the sample if necessary
      if length < (self.config.sequence_length + self.config.add_extra_token_to_sequence):
      sample_parts.append(
      [self._pad_token_id]
      * (self.config.sequence_length + self.config.add_extra_token_to_sequence - length)
      )

      return (
      numpy.concatenate(sample_parts, dtype=numpy.int64),
      numpy.array(document_ids, dtype=numpy.int64),
      )

      def _build_document_sample_shuffle_indices(
      self,
      ) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
      """Build the document index, the sample index, and the shuffle index

      The document index:
      -- 1-D
      -- An ordered array of document ids

      The sample index:
      -- 2-D
      -- The document indices and offsets which mark the start of every sample

      The shuffle index:
      -- 1-D
      -- A random permutation of index range of the sample index

      Returns:
      Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample
      index, and the shuffle index
      """
      path_to_cache = self.config.path_to_cache
      if path_to_cache is None and not self.config.mock:
      path_to_cache = os.path.join(
      self.dataset.path_prefix, "cache", f"{type(self).__name__}_indices"
      )

      if path_to_cache:
      base = f"{self.unique_description_hash}-{type(self).__name__}-{self.index_split.name}"
      get_path_to = lambda affix: os.path.join(path_to_cache, f"{base}-{affix}")
      path_to_description = get_path_to("description.txt")
      path_to_document_index = get_path_to("document_index.npy")
      path_to_sample_index = get_path_to("sample_index.npy")
      path_to_shuffle_index = get_path_to("shuffle_index.npy")
      cache_hit = all(
      map(
      os.path.isfile,
      [
      path_to_description,
      path_to_document_index,
      path_to_sample_index,
      path_to_shuffle_index,
      ],
      )
      )
      else:
      cache_hit = False

      if not path_to_cache or (
      not cache_hit
      and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0)
      ):
      log_single_rank(
      logger,
      logging.INFO,
      f"Build and save the {type(self).__name__} {self.index_split.name} indices",
      )
      t_beg = time.time()

      sequence_length = self.config.sequence_length
      num_tokens_per_epoch = self._get_num_tokens_per_epoch()
      num_epochs = self._get_num_epochs(num_tokens_per_epoch)

      if num_epochs == 1:
      separate_final_epoch = False
      else:
      # Get the number of samples for the last epoch
      num_samples_sans_final_epoch = (
      (num_epochs - 1) * num_tokens_per_epoch
      - self.config.add_extra_token_to_sequence
      ) // sequence_length
      num_samples_from_final_epoch = self.num_samples - num_samples_sans_final_epoch
      num_samples_per_epoch = (
      num_tokens_per_epoch - self.config.add_extra_token_to_sequence
      ) // sequence_length

      # num_samples_from_final_epoch should be non-negative
      assert num_samples_from_final_epoch >= 0

      # num_samples_from_final_epoch should not exceed max value
      assert num_samples_from_final_epoch <= num_samples_per_epoch + 1

      # Separate the final epoch if it falls below the threshold
      threshold = 0.80
      separate_final_epoch = num_samples_from_final_epoch < int(
      threshold * num_samples_per_epoch
      )

      log_single_rank(
      logger,
      logging.DEBUG,
      f"> num_samples_from_final_epoch: {num_samples_from_final_epoch}",
      )
      log_single_rank(logger, logging.DEBUG, f"> threshold: {threshold}")
      log_single_rank(
      logger, logging.DEBUG, f"> num_samples_per_epoch: {num_samples_per_epoch}"
      )

      log_single_rank(
      logger, logging.DEBUG, f"> separate_final_epoch: {separate_final_epoch}"
      )

      numpy_random_state = numpy.random.RandomState(self.config.random_seed)

      # Build the document index
      document_index = _build_document_index(
      self.indices, num_epochs, numpy_random_state, separate_final_epoch
      )

      # Build the sample index
      from megatron.core.datasets import helpers

      if self.index_split == Split.valid:
      drop_last_partial_sequence = self.config.drop_last_partial_validation_sequence
      else:
      drop_last_partial_sequence = True

      assert document_index.dtype == numpy.int32
      assert self.dataset.sequence_lengths.dtype == numpy.int32
      if len(document_index) * 2 > len(self.dataset.sequence_lengths):
      # If "access density" of sequence_lengths is high, force load the mmap-ed array
      # into memory by making a copy.
      #
      # System performance benefits come from two aspects:
      # 1. We sequentially pre-load the whole file, most of which we expect to read
      # 2. The GIL is held when entering the c++ program, improving the speed of which
      # improves parallelism
      sequence_lengths_for_cpp = self.dataset.sequence_lengths.copy()
      else:
      sequence_lengths_for_cpp = self.dataset.sequence_lengths
      sample_index = helpers.build_sample_idx(
      sequence_lengths_for_cpp,
      document_index,
      sequence_length,
      num_epochs,
      num_tokens_per_epoch,
      drop_last_partial_sequence,
      self.config.add_extra_token_to_sequence,
      )

      # Build the shuffle index
      if separate_final_epoch:
      shuffle_index = _build_shuffle_index(
      num_samples_sans_final_epoch, sample_index.shape[0] - 1, numpy_random_state
      )
      else:
      shuffle_index = _build_shuffle_index(
      sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state
      )

      if path_to_cache:
      os.makedirs(path_to_cache, exist_ok=True)
      # Write the description
      with open(path_to_description, "wt") as writer:
      writer.write(self.unique_description)
      numpy.save(path_to_document_index, document_index, allow_pickle=True)
      numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
      numpy.save(path_to_shuffle_index, shuffle_index, allow_pickle=True)
      else:
      log_single_rank(
      logger,
      logging.WARNING,
      f"Unable to save {type(self).__name__} indexes because path_to_cache is None",
      )

      t_end = time.time()
      log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

      log_single_rank(
      logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
      )
      log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")

      return document_index, sample_index, shuffle_index

      log_single_rank(
      logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices"
      )

      log_single_rank(
      logger,
      logging.INFO,
      f"\tLoad the document index from {os.path.basename(path_to_document_index)}",
      )
      t_beg = time.time()
      document_index = numpy.load(path_to_document_index, allow_pickle=True, mmap_mode="r")
      t_end = time.time()
      log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

      log_single_rank(
      logger,
      logging.INFO,
      f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}",
      )
      t_beg = time.time()
      sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode="r")
      t_end = time.time()
      log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

      log_single_rank(
      logger,
      logging.INFO,
      f"\tLoad the shuffle index from {os.path.basename(path_to_shuffle_index)}",
      )
      t_beg = time.time()
      shuffle_index = numpy.load(path_to_shuffle_index, allow_pickle=True, mmap_mode="r")
      t_end = time.time()
      log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

      log_single_rank(
      logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
      )

      return document_index, sample_index, shuffle_index

      def _get_num_tokens_per_epoch(self) -> int:
      """Calculate the number of tokens in a single epoch

      Returns:
      int: The number of tokens in a single epoch
      """
      return int(numpy.sum(self.dataset.sequence_lengths[self.indices]))

      def _get_num_epochs(self, num_tokens_per_epoch: int) -> int:
      """Calculate the number of epochs

      Args:
      num_tokens_per_epoch (int): The number of tokens in a single epoch

      Returns:
      int: The number of epochs
      """
      num_epochs = 1
      num_tokens = num_tokens_per_epoch
      if self.num_samples is None:
      return num_epochs
      else:
      num_tokens_requested = (
      self.num_samples * self.config.sequence_length
      ) + self.config.add_extra_token_to_sequence
      while num_tokens < num_tokens_requested:
      num_epochs += 1
      num_tokens += num_tokens_per_epoch
      return num_epochs
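
To make the token/label construction in __getitem__ above concrete, here is a tiny illustration with made-up token ids (toy code, not part of Megatron): with add_extra_token_to_sequence enabled, sequence_length + 1 tokens are read and the labels are simply the tokens shifted left by one position.

import torch

text = torch.tensor([11, 12, 13, 14, 15])   # sequence_length = 4, plus one extra token
tokens = text[:-1].contiguous()             # [11, 12, 13, 14] -> model input
labels = text[1:].contiguous()              # [12, 13, 14, 15] -> next-token targets
print(tokens.tolist(), labels.tolist())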

MegatronPretrainingSampler

Once we have the mid-level datasets (train_ds, valid_ds, test_ds), from which individual GPT training samples can be fetched, MegatronPretrainingSampler wraps them as a batch sampler: instead of returning single sample indices, it yields a whole list of indices (one micro-batch) at a time, ready to be consumed by PyTorch's DataLoader(…, batch_sampler=…). Its two main jobs are:

  • In data-parallel training, split every group of micro_batch_size * data_parallel_size sample indices across DP ranks, so that each rank gets the indices of its own micro-batch.

  • Support resuming from a given consumed_samples offset (used for resume / checkpoint restart / skipping samples that have already been trained on).

The code is shown below:

class MegatronPretrainingSampler:

def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, drop_last=True):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.drop_last = drop_last

# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)

def __len__(self):
return self.total_samples

def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx

def __iter__(self):
batch = []
# Last batch will be dropped if drop_last is not set False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []

# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]

  • During initialization it records the DP rank it belongs to and computes self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size, i.e. how many samples the whole data-parallel group consumes per micro-batch step.

  • When iterated, it walks over range(self.consumed_samples, self.total_samples), collecting micro_batch_times_data_parallel_size sample indices at a time, and then uses its own DP rank to slice out the micro_batch_size indices this rank actually needs (see the toy walk-through below).
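
As a sanity check on this slicing logic, here is a tiny self-contained walk-through (toy numbers chosen for illustration, not Megatron code) showing which sample indices each DP rank ends up with:

# Toy walk-through (hypothetical numbers): 2 DP ranks, micro_batch_size=3,
# total_samples=12, consumed_samples=0.
micro_batch_size, data_parallel_size, total_samples = 3, 2, 12

def sampler_for_rank(data_parallel_rank):
    stride = micro_batch_size * data_parallel_size   # samples consumed per step across all DP ranks
    start = data_parallel_rank * micro_batch_size
    end = start + micro_batch_size
    batch = []
    for idx in range(total_samples):
        batch.append(idx)
        if len(batch) == stride:
            yield batch[start:end]                   # this rank's micro-batch of indices
            batch = []

print(list(sampler_for_rank(0)))   # [[0, 1, 2], [6, 7, 8]]
print(list(sampler_for_rank(1)))   # [[3, 4, 5], [9, 10, 11]]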

DataLoader

build_pretraining_data_loader finally combines the mid-level dataset (GPTDataset) with the MegatronPretrainingSampler into a torch.utils.data.DataLoader, so that one micro-batch of training samples can be fetched from it at a time (a toy sketch of this wiring follows the snippet below).

return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True,
persistent_workers=True if args.num_workers > 0 else False,
)
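
The wiring itself is plain PyTorch: any iterable of index lists works as batch_sampler, and each next() on the resulting iterator yields one collated micro-batch. A minimal toy sketch (ToyDataset and the hard-coded index lists are made up for illustration, not Megatron code):

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return {"tokens": torch.tensor([idx, idx + 1, idx + 2])}

# Each inner list plays the role of one micro-batch yielded by the batch sampler.
batch_sampler = [[0, 1], [2, 3], [4, 5], [6, 7]]
loader = DataLoader(ToyDataset(), batch_sampler=batch_sampler, num_workers=0)

data_iterator = iter(loader)
print(next(data_iterator)["tokens"])   # shape [2, 3]: the first micro-batch, already collated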

Dataset usage flow

  1. The user-provided forward_step function receives a data_iterator built on top of the torch.utils.data.DataLoader above, as shown below:
def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = False):
"""Forward training step.

Args:
data_iterator : Input data iterator
model (GPTModel): The GPT Model
return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor
"""
args = get_args()
timers = get_timers()

# Get the batch.
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
timers('batch-generator').stop()

with stimer:
if args.use_legacy_models:
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
else:
if return_schedule_plan:
assert args.overlap_moe_expert_parallel_comm, \
"overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
schedule_plan = model.build_schedule_plan(
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)
return schedule_plan, partial(loss_func, loss_mask, model=model)
else:
output_tensor = model(
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)

# [ModelOpt]: model is needed to access ModelOpt distillation losses
return output_tensor, partial(loss_func, loss_mask, model=model)

  • get_batch, the function that actually fetches training data, first checks whether the current rank is on the first or last pipeline stage; only those stages fetch data at all. The fetching itself relies on get_batch_on_this_tp_rank and get_batch_on_this_cp_rank, as shown below.
def get_batch(data_iterator):
"""Generate a batch."""

# TODO: this is pretty hacky, find a better way
if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
):
return None, None, None, None, None

# get batches based on the TP rank you are on
batch = get_batch_on_this_tp_rank(data_iterator)

# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)

return batch.values()

  • First, get_batch_on_this_tp_rank: since, as described earlier, the dataset is only built on tp rank 0, this function checks the current rank; the worker with tp rank 0 fetches one micro-batch via next(data_iterator) and then broadcasts each tensor to the tp rank ≠ 0 workers in its tensor-parallel group (a standalone toy demo of this broadcast pattern follows the code). The code is shown below:
def get_batch_on_this_tp_rank(data_iterator):

args = get_args()

def _broadcast(item):
if item is not None:
torch.distributed.broadcast(
item,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group(),
)

if mpu.get_tensor_model_parallel_rank() == 0:

if data_iterator is not None:
data = next(data_iterator)
else:
data = None

batch = {
'tokens': data["tokens"].cuda(non_blocking=True),
'labels': data["labels"].cuda(non_blocking=True),
'loss_mask': data["loss_mask"].cuda(non_blocking=True),
'attention_mask': (
None
if "attention_mask" not in data
else data["attention_mask"].cuda(non_blocking=True)
),
'position_ids': data["position_ids"].cuda(non_blocking=True),
}

if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])

elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])

elif mpu.is_pipeline_last_stage():
# Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
# Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
# to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
if args.mtp_num_layers is not None:
_broadcast(batch['tokens'])
_broadcast(batch['position_ids'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])

else:

tokens = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)
labels = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)
loss_mask = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.float32,
device=torch.cuda.current_device(),
)
if args.create_attention_mask_in_dataloader:
attention_mask = torch.empty(
(args.micro_batch_size, 1, args.seq_length, args.seq_length),
dtype=torch.bool,
device=torch.cuda.current_device(),
)
else:
attention_mask = None
position_ids = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)

if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)

elif mpu.is_pipeline_first_stage():
labels = None
loss_mask = None

_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)

elif mpu.is_pipeline_last_stage():
# Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
# Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
# to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
if args.mtp_num_layers is not None:
_broadcast(tokens)
_broadcast(position_ids)
else:
tokens = None
position_ids = None

_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)

batch = {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids,
}

return batch
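
The broadcast pattern itself can be reproduced with a small standalone script (assumptions: two CPU processes on one machine and the gloo backend; toy demo, not Megatron code): rank 0 builds the tensor, the other rank allocates an empty buffer of the same shape and dtype, and torch.distributed.broadcast fills it in.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29512"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    if rank == 0:
        tokens = torch.arange(8).view(2, 4)                 # "tp rank 0" reads the micro-batch
    else:
        tokens = torch.empty((2, 4), dtype=torch.int64)     # other ranks only allocate a buffer
    dist.broadcast(tokens, src=0)                           # now every rank holds the same tokens
    print(f"rank {rank}: {tokens.tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)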

  • Next comes get_batch_on_this_cp_rank. With context parallelism and a causal mask, a zig-zag split is usually used to balance the load across GPUs: the sequence is divided into 2*CP chunks, and cp rank i takes the i-th chunk from the front plus the i-th chunk from the back. The code is shown below, followed by a toy example.

def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
"""Slice batch input along sequence dimension into multiple chunks,
which are parallelized across GPUs in a context parallel group.
"""

# With causal masking, each token only attends to its prior tokens. Simply split
# sequence into CP chunks can result in severe load imbalance. That's to say, chunks
# at the end of sequence have bigger workload than others. To address this issue,
# we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0
# and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so
# that we can get balanced workload among GPUs in a context parallel group.
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
cp_rank = parallel_state.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != "attention_mask" else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.zeros(2, dtype=torch.int64, device=val.device)
index[0].fill_(cp_rank)
index[1].fill_(2 * cp_size - cp_rank - 1)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val

return batch
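
To see the zig-zag assignment concretely, here is a small self-contained example (toy sequence of length 8 and CP size 2, mirroring the view/index_select logic above; toy code, not Megatron's):

import torch

seq = torch.arange(8).unsqueeze(0)    # shape [1, 8]: one row of fake token ids
cp_size, seq_dim = 2, 1

for cp_rank in range(cp_size):
    # View the sequence as 2*cp_size chunks, pick chunk cp_rank and its mirror chunk.
    val = seq.view(seq.shape[0], 2 * cp_size, seq.shape[seq_dim] // (2 * cp_size))
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    val = val.index_select(seq_dim, index).view(seq.shape[0], -1)
    print(f"cp_rank={cp_rank}: {val.tolist()}")

# cp_rank=0: [[0, 1, 6, 7]]    -> chunk_0 and chunk_3
# cp_rank=1: [[2, 3, 4, 5]]    -> chunk_1 and chunk_2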

  • Once the batch's tokens, labels, loss_mask, attention_mask and position_ids are available, the model can run its forward pass. The net effect: every DP rank consumes a different micro-batch; within a DP rank only the first and last pipeline stages read the dataset; all TP ranks on those stages receive the same data; and if CP is enabled, that data is further sliced along the sequence dimension.

Data-parallel training

We have seen how the dataset is partitioned across workers under data parallelism and how a training step fetches its data; next, let's look at how gradients are synchronized between DP ranks during training.

Within one training iteration, train_step performs a complete step by calling forward_backward_func, as shown below:

losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=False,
adjust_tensor_shapes_fn=adjust_tensor_shapes_fn,
)

forward_backward_func comes in several variants depending on pipeline parallelism; here we look at forward_backward_no_pipelining, the variant used when there is no pipeline parallelism:

def forward_backward_no_pipelining(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int, # unused
micro_batch_size: int, # unused
decoder_seq_length: Optional[int] = None, # unused
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: Optional[bool] = None,
adjust_tensor_shapes_fn: Optional[Callable] = None, # unused
grad_finalize_pgs: Optional[GradFinalizeProcessGroups] = None,
):
"""Run forward and backward passes with no pipeline parallelism"""

if grad_finalize_pgs is None:
tp_group = parallel_state.get_tensor_model_parallel_group()
cp_group = parallel_state.get_context_parallel_group()
embd_group = parallel_state.get_embedding_group(check_initialized=False)
pp_group = parallel_state.get_pipeline_model_parallel_group()
pos_emb_group = parallel_state.get_position_embedding_group(check_initialized=False)
grad_finalize_pgs = GradFinalizeProcessGroups()
grad_finalize_pgs.tp = tp_group
grad_finalize_pgs.cp = cp_group
grad_finalize_pgs.embd = embd_group
grad_finalize_pgs.pos_embd = pos_emb_group
grad_finalize_pgs.pp = pp_group
grad_finalize_pgs.dp_cp = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)

elif grad_finalize_pgs is not None:
assert hasattr(grad_finalize_pgs, 'tp')
assert hasattr(grad_finalize_pgs, 'cp')
assert hasattr(grad_finalize_pgs, 'embd'), (
"grad_finalize_pgs must have a embd. In previous version, it is used default "
"`parallel_state.default_embedding_ranks` to create the process group. If you are "
"using the default process group, please use `parallel_state.get_embedding_group()` "
"to get the process group. If you don't need explicitly set it to None."
)
assert hasattr(grad_finalize_pgs, 'pos_embd'), (
"grad_finalize_pgs must have a pos_embd. In previous version, it is used default "
"`parallel_state.default_position_embedding_ranks` to create the process group. "
"If you are using the default process group, "
"please use `parallel_state.get_position_embedding_group()` "
"to get the process group. If you don't need explicitly set it to None."
)
assert hasattr(grad_finalize_pgs, 'pp')
assert hasattr(grad_finalize_pgs, 'dp_cp')

if isinstance(model, list):
assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking"
model = model[0]
if isinstance(data_iterator, list):
assert (
len(data_iterator) == 1
), "non-pipeline-parallel schedule does not support model chunking"
data_iterator = data_iterator[0]
assert (
adjust_tensor_shapes_fn is None
), "adjust_tensor_shapes_fn is not supported for non-pipeline-parallel schedule"

config = get_model_config(model)
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

no_sync_func = config.no_sync_func
if no_sync_func is None:
no_sync_func = contextlib.nullcontext

model_type = get_model_type(model)

forward_data_store = []
input_tensor, output_tensor_grad = None, None
total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda")

if config.overlap_moe_expert_parallel_comm and not forward_only:
forward_data_store, total_num_tokens = combined_1f1b_schedule_for_no_pipelining(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
output_tensor_grad,
forward_data_store,
config,
collect_non_loss_data,
first_val_step,
forward_only,
no_sync_func,
total_num_tokens,
partial(check_first_val_step, first_val_step, forward_only),
)
else:
with no_sync_func():
for i in range(num_microbatches - 1):
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
grad_finalize_pgs.cp.size(),
collect_non_loss_data,
is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=i,
)
total_num_tokens += num_tokens
if not forward_only:
backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
grad_finalize_pgs.cp.size(),
collect_non_loss_data,
is_first_microbatch=check_first_val_step(
first_val_step, forward_only, num_microbatches == 1
),
current_microbatch=num_microbatches - 1,
)

total_num_tokens += num_tokens

if not forward_only:
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

if config.finalize_model_grads_func is not None and not forward_only:
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism and layernorm all-reduce for sequence parallelism).
config.finalize_model_grads_func(
[model],
total_num_tokens if config.calculate_per_token_loss else None,
grad_finalize_pgs=grad_finalize_pgs,
)

if config.timers is not None:
config.timers('forward-backward').stop()

if (
hasattr(config, 'enable_cuda_graph')
and config.enable_cuda_graph
and config.cuda_graph_scope != "full_iteration"
):
create_cudagraphs()

return forward_data_store

  1. The first num_microbatches - 1 micro-batches run inside with no_sync_func(), and, unless we are in forward_only mode, each of them goes through backward_step. Most of that function is pipeline-parallelism plumbing; the key part is the call to torch.autograd.backward, which propagates gradients and accumulates the contribution of each micro-batch into the grad buffers (a minimal sketch of this no_sync pattern follows the backward_step code below).
def backward_step(
input_tensor,
output_tensor,
output_tensor_grad,
model_type,
config,
pipeline_model_parallel_size=1,
):
"""Backward step through passed-in output tensor.

If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.

Returns gradient of loss with respect to input tensor (None if first
stage)."""

# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.

if config.timers is not None:
config.timers('backward-compute', log_level=2).start()

# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_input_tensor_grad = True
for x in input_tensor:
if x is not None:
x.retain_grad()

if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]

# Backward pass.
if output_tensor_grad[0] is None and config.grad_scale_func is not None:
output_tensor[0] = config.grad_scale_func(output_tensor[0])

# In multi-modal models like VLM, some batches may not have images.
# When no image is present, the vision encoder (as a separate pipeline stage)
# will not participate in the computation.
# This results in a tensor that does not require gradients.
# In such cases, we intentionally skip the backward pass while preserving zero gradients.
if output_tensor[0].requires_grad:
if config.deallocate_pipeline_outputs:
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

# Collect the grad of the input_tensor.
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = []
for x in input_tensor:
if x is None:
input_tensor_grad.append(None)
else:
input_tensor_grad.append(x.grad)

if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]

if config.timers is not None:
config.timers('backward-compute').stop()

return input_tensor_grad
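
The interaction between no_sync_func and the final micro-batch is easiest to see with torch's own DistributedDataParallel, whose no_sync() context Megatron's wrapper mirrors. Below is a minimal single-process sketch (world_size=1 on the gloo backend, toy model and data chosen for illustration): the first num_microbatches-1 backward passes only accumulate local gradients, and the last backward pass, issued outside no_sync, is the one that triggers the data-parallel gradient reduction.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as TorchDDP

def main():
    # Single-process "distributed" setup so the snippet runs on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29513")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = TorchDDP(torch.nn.Linear(4, 2))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    num_microbatches = 4
    microbatches = [torch.randn(8, 4) for _ in range(num_microbatches)]

    # First num_microbatches-1 micro-batches: accumulate grads, skip the all-reduce.
    with model.no_sync():
        for x in microbatches[:-1]:
            model(x).sum().backward()
    # Last micro-batch outside no_sync: this backward triggers the DP grad reduction.
    model(microbatches[-1]).sum().backward()

    optimizer.step()
    optimizer.zero_grad()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()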

  • The last micro-batch is run outside with no_sync_func(), which is what allows the synchronization to happen.

  • The actual gradient synchronization across data-parallel ranks happens afterwards, in config.finalize_model_grads_func:

    if config.finalize_model_grads_func is not None and not forward_only:
    # Finalize model grads (perform full grad all-reduce / reduce-scatter for
    # data parallelism and layernorm all-reduce for sequence parallelism).
    config.finalize_model_grads_func(
    [model],
    total_num_tokens if config.calculate_per_token_loss else None,
    grad_finalize_pgs=grad_finalize_pgs,
    )
    1. config comes from get_model_config(model), and config.finalize_model_grads_func is assigned inside the train function:

      def get_model_config(model):
      """Returns the config attribute, allowed to return None"""
      return get_attr_wrapped_model(model, "config", allow_none=False)

      1. The finalize_model_grads function assigned there looks like this:
      def finalize_model_grads(
      model: List[torch.nn.Module],
      num_tokens: Optional[torch.Tensor] = None,
      grad_finalize_pgs: Optional[GradFinalizeProcessGroups] = None,
      ):
      """
      All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
      embedding grads across first and last pipeline stages (if not tied),
      scale gradients by `num_tokens`.
      """

      config = get_model_config(model[0])
      if grad_finalize_pgs is not None:
      assert hasattr(grad_finalize_pgs, 'tp')
      assert hasattr(grad_finalize_pgs, 'pp')
      assert hasattr(grad_finalize_pgs, 'embd'), (
      "grad_finalize_pgs must have a embd. In previous version, it is used default "
      "`parallel_state.default_embedding_ranks` to create the process group."
      " If you are using the default process group, please use"
      " `parallel_state.get_embedding_group()` "
      "If you don't need embd_group, you need to explicitly set it to None."
      )
      assert hasattr(grad_finalize_pgs, 'pos_embd'), (
      "grad_finalize_pgs must have a pos_embd. In previous version, it is used default "
      "`parallel_state.default_position_embedding_ranks` to create the process group."
      " If you are using the default process group, please use "
      " `parallel_state.get_position_embedding_group()` "
      "If you don't need pos_embd_group, you need to explicitly set it to None."
      )
      assert hasattr(grad_finalize_pgs, 'dp_cp')
      tp_group = grad_finalize_pgs.tp
      pp_group = grad_finalize_pgs.pp
      embd_group = grad_finalize_pgs.embd
      pos_emb_group = grad_finalize_pgs.pos_embd
      dp_cp_group = grad_finalize_pgs.dp_cp
      else:
      tp_group = parallel_state.get_tensor_model_parallel_group()
      pp_group = parallel_state.get_pipeline_model_parallel_group()
      embd_group = parallel_state.get_embedding_group(check_initialized=False)
      pos_emb_group = parallel_state.get_position_embedding_group(check_initialized=False)
      dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)

      # All-reduce / reduce-scatter across DP replicas.
      if config.timers is not None:
      config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
      for model_chunk in model:
      model_chunk.finish_grad_sync()
      if config.timers is not None:
      config.timers('all-grads-sync').stop()

      # All-reduce t_embedder grads (for pp & vpp of DiT).
      if config.timers is not None:
      config.timers('conditional-embedder-grads-all-reduce', log_level=1).start(
      barrier=config.barrier_with_L1_time
      )
      _allreduce_conditional_embedding_grads(model, config, pp_group)
      if config.timers is not None:
      config.timers('conditional-embedder-grads-all-reduce').stop()

      # All-reduce layer-norm grads (for sequence parallelism) and non-tensor parallel modules.
      if config.timers is not None:
      config.timers('non-tensor-parallel-grads-all-reduce', log_level=1).start(
      barrier=config.barrier_with_L1_time
      )
      _allreduce_non_tensor_model_parallel_grads(model, config, tp_group)
      if config.timers is not None:
      config.timers('non-tensor-parallel-grads-all-reduce').stop()

      # All-reduce embedding grads (for pipeline parallelism).
      if config.timers is not None:
      config.timers('embedding-grads-all-reduce', log_level=1).start(
      barrier=config.barrier_with_L1_time
      )
      _allreduce_word_embedding_grads(model, config, embd_group, pp_group)
      _allreduce_position_embedding_grads(model, config, pos_emb_group, pp_group)

      if config.timers is not None:
      config.timers('embedding-grads-all-reduce').stop()

      if config.moe_router_enable_expert_bias:
      _update_router_expert_bias(model, config)

      # normalize gradients for per-token loss normalization.
      # if we are using by the number of tokens, then we use that as a divisor. this number
      # will be the total number of non-padded tokens in the global batch.
      if num_tokens is not None:

      # the number of tokens is only present on the last stage, so broadcast it
      # to the other ranks in the pipeline parallel group.
      assert not isinstance(pp_group, list)
      last_rank = get_pp_last_rank(pp_group)
      torch.distributed.broadcast(num_tokens, src=last_rank, group=pp_group)

      # all-reduce across DP ranks.
      torch.distributed.all_reduce(num_tokens, group=dp_cp_group)
      for model_chunk in model:
      if num_tokens > 0:
      scaling = 1.0 / num_tokens
      model_chunk.scale_gradients(scaling)

      • The most important part is the code below, which triggers gradient synchronization by calling finish_grad_sync on every model_chunk:
      # All-reduce / reduce-scatter across DP replicas.
      if config.timers is not None:
      config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
      for model_chunk in model:
      model_chunk.finish_grad_sync()
      if config.timers is not None:
      config.timers('all-grads-sync').stop()
      • What finish_grad_sync does depends on the wrapper type of model, which is decided in get_model, as shown below; here we take the simplest case, DDP (DistributedDataParallel), as the example:
      if wrap_with_ddp:
      if args.use_torch_fsdp2:
      assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
      DP = torch_FSDP
      elif args.use_megatron_fsdp:
      DP = megatron_FSDP
      else:
      DP = DDP

      ...

      with torch.cuda.stream(torch.cuda.Stream()):
      model = [
      DP(
      config=config,
      ddp_config=ddp_config,
      module=model_chunk,
      # Turn off bucketing for model_chunk 2 onwards, since communication for these
      # model chunks is overlapped with compute anyway.
      disable_bucketing=(model_chunk_idx > 0)
      or args.overlap_param_gather_with_optimizer_step,
      )
      for (model_chunk_idx, model_chunk) in enumerate(model)
      ]
      • DistributedDataParallel.finish_grad_sync is shown below; it calls finish_grad_sync on every bucket_group:
      def finish_grad_sync(self):
      """
      Finishes grad sync (all-reduce or reduce-scatter) communication operations
      for all model gradients.

      When overlap_grad_reduce is set to True, waits for asynchronous communication
      calls to complete. When overlap_grad_reduce is set to False, calls synchronous
      communication ops.
      """
      for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
      bucket_group.finish_grad_sync()
      • bucket_group is a _ParamAndGradBucketGroup; its finish_grad_sync is shown below. In the common case, i.e. when self.ddp_config.overlap_grad_reduce is False, it directly calls start_grad_sync:
      def finish_grad_sync(self):
      """
      Finishes grad sync (all-reduce or reduce-scatter) communication operations
      for all buckets in the bucket group.

      When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous
      communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
      makes synchronous call.
      """
      self.param_gather_dispatched = False
      # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
      if not self.ddp_config.overlap_grad_reduce:
      self.start_grad_sync()
      return
      # When using multiple DistOpt instances, we don't need to sync here as we launch
      # communications on a separate communication stream.
      if self.ddp_config.num_distributed_optimizer_instances > 1:
      torch.cuda.default_stream().wait_stream(self.communication_stream)
      return
      assert self.grad_reduce_handle is not None, (
      f"Communication call has not been issued for this bucket "
      f"({len(self.params_with_grad)}/{len(self.params)} params have grad available)"
      )
      self.grad_reduce_handle.wait()
      self.grad_reduce_handle = None
      • start_grad_sync is shown below. Its core job is to launch the data-parallel gradient synchronization (all-reduce or reduce-scatter) for a group of buckets (contiguous slices of the grad buffer), with switches for overlapping with the backward pass, using the distributed optimizer (reduce-scatter), using multiple DistOpt instances (two-level communication), and gradient checking/scaling; a stripped-down recap follows the code.
      def start_grad_sync(self):
      """
      Initiates grad sync (all-reduce or reduce-scatter) communication operations
      for all buckets in the bucket group.

      When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous
      communication call. When ddp_config.overlap_grad_reduce is set to False, makes
      synchronous call.
      """
      assert (
      self.grad_reduce_handle is None
      ), "Should not have multiple communication calls outstanding at once"

      if self.ddp_config.check_for_nan_in_grad or self.ddp_config.check_for_large_grads:
      self.check_grads(
      check_for_nan_or_inf=self.ddp_config.check_for_nan_in_grad,
      check_for_large=self.ddp_config.check_for_large_grads,
      )

      # gradient_scaling_factor already takes into account whether we are computing
      # an average or sum in the data-parallel collective.
      for bucket in self.buckets:
      if bucket.gradient_scaling_factor != 1.0:
      bucket.grad_data *= bucket.gradient_scaling_factor

      # Decide reduce_op.
      reduce_op = torch.distributed.ReduceOp.SUM
      if self.ddp_config.average_in_collective:
      reduce_op = torch.distributed.ReduceOp.AVG

      # We use the following stream synchronization for the gradient reduction
      # within and across DistOpt instances.

      # Compute Stream: -------------Gradient compute-------------------
      # Comm. Stream: ------(wait for NCCL)-----(wait for NCCL)-------
      # NCCL Stream: -------RS------ -------AR------

      # Use async communications only when overlap_grad_reduce is True.
      async_op = (
      self.ddp_config.overlap_grad_reduce
      and self.ddp_config.num_distributed_optimizer_instances == 1
      )
      if (
      self.ddp_config.num_distributed_optimizer_instances > 1
      and self.ddp_config.overlap_grad_reduce
      ):
      # Assign a communication stream if we have multiple DistOpt instances and we
      # need to overlap communication.
      stream_context = torch.cuda.stream(self.communication_stream)

      # The RS/AR communication stream needs to wait for the default stream
      # to complete its gradient computation before launching the next
      # gradient reduction collective.
      self.communication_stream.wait_stream(torch.cuda.default_stream())
      else:
      stream_context = nullcontext()

      if self.ddp_config.use_distributed_optimizer:
      communication_group = self.intra_distributed_optimizer_instance_group
      else:
      communication_group = self.data_parallel_group

      # Coalesce communication kernels across buckets in the bucket group.
      with stream_context, _coalescing_manager(communication_group, async_ops=async_op) as cm:
      for idx, bucket in enumerate(self.buckets):
      if self.ddp_config.use_distributed_optimizer:
      if self.cached_grad_buffer_shard_list[idx] is None:
      self.cached_grad_buffer_shard_list[idx] = shard_buffer(
      bucket.grad_data, self.intra_distributed_optimizer_instance_size
      )
      local_data_view = self.cached_grad_buffer_shard_list[idx][
      self.intra_distributed_optimizer_instance_rank
      ]
      dist_reduce_scatter_func(
      local_data_view,
      bucket.grad_data,
      op=reduce_op,
      group=communication_group,
      async_op=async_op,
      )
      else:
      torch.distributed.all_reduce(
      bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op
      )

      # With multiple DistOpt instances, we need to all-reduce across instances.
      if (
      self.ddp_config.use_distributed_optimizer
      and self.ddp_config.num_distributed_optimizer_instances > 1
      ):
      assert self.inter_distributed_optimizer_instance_group is not None
      # Create a new coalescing manager for the inter-instance all-reduce.
      with (
      stream_context,
      _coalescing_manager(
      self.inter_distributed_optimizer_instance_group, async_ops=async_op
      ) as cm,
      ):
      for idx, bucket in enumerate(self.buckets):
      if self.cached_grad_buffer_shard_list[idx] is None:
      self.cached_grad_buffer_shard_list[idx] = shard_buffer(
      bucket.grad_data, self.intra_distributed_optimizer_instance_size
      )
      local_data_view = self.cached_grad_buffer_shard_list[idx][
      self.intra_distributed_optimizer_instance_rank
      ]

      torch.distributed.all_reduce(
      local_data_view,
      op=reduce_op,
      group=self.inter_distributed_optimizer_instance_group,
      async_op=async_op,
      )

      if async_op:
      self.grad_reduce_handle = cm
      else:
      # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
      # `cm` is not None, which is different from when `_coalescing_manager` is not used in
      # which case the torch.distributed._reduce_scatter_base() will return None. In order to
      # maintain consistency with prior code, we need to manually set communication handle to
      # None.
      self.grad_reduce_handle = None
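
Stripped of the overlap, distributed-optimizer and multi-instance branches, the synchronous path above boils down to scaling the flat grad buffer and all-reducing it over the data-parallel group. The snippet below is a simplified recap of that idea rather than Megatron's actual code (the single-process group only exists so it runs on CPU; the 1/dp_size factor stands in for gradient_scaling_factor):

import os
import torch
import torch.distributed as dist

def sync_grad_buffer(grad_data: torch.Tensor, dp_group) -> None:
    # Pre-scale (Megatron folds this into gradient_scaling_factor, or uses ReduceOp.AVG
    # when average_in_collective is enabled), then sum-reduce across the DP group.
    grad_data *= 1.0 / dist.get_world_size(dp_group)
    dist.all_reduce(grad_data, op=dist.ReduceOp.SUM, group=dp_group)

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29514")
    dist.init_process_group("gloo", rank=0, world_size=1)
    flat_grad_buffer = torch.randn(16)        # stand-in for bucket.grad_data
    sync_grad_buffer(flat_grad_buffer, dist.group.WORLD)
    dist.destroy_process_group()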

A torch profiler trace of a DDP data-parallel run confirms this chain of calls.

A follow-up post could look at how data parallelism is implemented with FSDP.

