[Megatron-LM Source Code Analysis (4)] DDP Data Parallelism
Data parallelism comes in several flavors. The simplest is DDP: every DP rank holds a complete copy of the model, gradients are synchronized across DP ranks at the end of each batch, and the optimizer update is then applied uniformly. More sophisticated forms of data parallelism add ZeRO, which shards the model parameters, gradients, and optimizer states.
This post focuses on the simplest case, DDP, and on three questions: how the DP parallel groups are partitioned, how each DP rank reads different data, and how gradients are synchronized across DP ranks during training.
DP Parallel Groups
Looking at the initialize_model_parallel function in megatron/core/parallel_state.py, the model size is computed as model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size. data_parallel_size is not specified directly; it is derived as data_parallel_size: int = world_size // model_size.
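For example, with hypothetical sizes (not taken from the source) the arithmetic works out as follows:

```python
# Hypothetical sizes: 16 GPUs, tp=2, pp=2, cp=1.
world_size = 16
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 2
context_parallel_size = 1

# Number of ranks that together hold one full model replica.
model_size = (
    tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
)  # 4
data_parallel_size = world_size // model_size  # 16 // 4 = 4 DP replicas
print(model_size, data_parallel_size)  # 4 4
```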
The DP parallel groups are built via decoder_rank_generator: all ranks that share the same coordinates except for the DP dimension are collected into one dp group (communication group); a minimal sketch of the idea is shown below.
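The actual group construction goes through Megatron's rank generators. As a hedged illustration of the idea only (not the real implementation), the sketch below enumerates, for a made-up (pp, dp, cp, tp) rank layout with tp varying fastest, the ranks that end up in the same DP group:

```python
def data_parallel_rank_groups(world_size, tp, pp, cp):
    """Enumerate DP groups for a hypothetical (pp, dp, cp, tp) rank layout with tp
    varying fastest; the real layout is produced by Megatron's RankGenerator."""
    dp = world_size // (tp * pp * cp)
    groups = []
    for p in range(pp):
        for c in range(cp):
            for t in range(tp):
                # ranks that differ only in their dp coordinate form one DP group
                groups.append(
                    [p * dp * cp * tp + d * cp * tp + c * tp + t for d in range(dp)]
                )
    return groups

# e.g. world_size=8, tp=2, pp=2, cp=1 -> [[0, 2], [1, 3], [4, 6], [5, 7]]
print(data_parallel_rank_groups(8, tp=2, pp=2, cp=1))
```

Each such rank list would then be turned into a process group (e.g. via torch.distributed.new_group) that gradient all-reduce later runs over.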
Data Splitting
Under DP parallelism, each dp rank should receive different data.
Dataset Reader Construction Flow
The overall construction flow of the dataset readers is:

- The user provides the raw dataset files, the merge and vocab files needed for tokenization, the split ratios for train/valid/test, and sets micro-batch-size and global-batch-size.
- Via BlendedMegatronDatasetBuilder, the files are turned into the datasets train_ds, valid_ds, test_ds.
- These datasets are then turned into train_dataloader, valid_dataloaders, test_dataloader, which support iterating over the data batch by batch.
The concrete construction order of the dataset readers in Megatron-LM is as follows:
- The user defines a train_valid_test_datasets_provider. Inside it, a BlendedMegatronDatasetBuilder is constructed to obtain and return train_ds, valid_ds, test_ds; these ds objects wrap the raw dataset files and support reading samples from them.
The custom train_valid_test_datasets_provider is passed as an argument to the core pretrain training function. Inside pretrain, build_train_valid_test_data_iterators then constructs the dataset iterators train_data_iterator, valid_data_iterator, test_data_iterator. Note that if vp is enabled within pp parallelism, multiple iterators are needed (one per model chunk).
```python
if __name__ == "__main__":
# Temporary for transition to core datasets
train_valid_test_datasets_provider.is_distributed = True
# Optionally enable inprocess restart on pretrain
pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain)
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
store=store,
)
```

Inside pretrain:

```python
def pretrain(...):
...
if args.virtual_pipeline_model_parallel_size is not None:
train_data_iterator = []
valid_data_iterator = []
test_data_iterator = []
for i in range(len(model)):
iterators = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
train_data_iterator.append(iterators[0])
valid_data_iterator.append(iterators[1])
test_data_iterator.append(iterators[2])
else:
train_data_iterator, valid_data_iterator, test_data_iterator = (
build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
)
...
```

build_train_valid_test_data_iterators first obtains train_dataloader, valid_dataloaders, test_dataloader through build_train_valid_test_data_loaders:

```python
def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provider):
"""Build pretraining data iterators."""
args = get_args()
# Build loaders.
train_dataloader, valid_dataloaders, test_dataloader = build_train_valid_test_data_loaders(
build_train_valid_test_datasets_provider
)
# Build iterators.
dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic', 'external']
def _get_iterator(dataloader_type, dataloader):
"""Return dataset iterator."""
if dataloader_type == "single":
return RerunDataIterator(iter(dataloader))
elif dataloader_type == "cyclic":
return RerunDataIterator(iter(cyclic_iter(dataloader)))
elif dataloader_type == "external":
# External dataloader is passed through. User is expected to define how to iterate.
if isinstance(dataloader, list):
return [RerunDataIterator(d) for d in dataloader]
else:
return RerunDataIterator(dataloader)
else:
raise RuntimeError("unexpected dataloader type")
if train_dataloader is not None:
train_data_iterator = _get_iterator(dl_type, train_dataloader)
else:
train_data_iterator = None
# when using full validation, we need to override eval iters with the correct
# number of iterations on tp rank 0 so that it can be distributed to the other
# ranks later
if args.full_validation:
if args.multiple_validation_sets:
if valid_dataloaders[0] is None:
args.eval_iters = [None]*len(valid_dataloaders)
else:
args.eval_iters = [len(dl) for dl in valid_dataloaders]
else:
args.eval_iters = len(valid_dataloaders[0])
if args.multiple_validation_sets:
if valid_dataloaders[0] is None:
valid_data_iterators = [None] * len(valid_dataloaders)
else:
valid_dl_type = "cyclic" if args.full_validation else dl_type
print(
f"[VALID DATA LOADER LENGTHS] "
", ".join(f"{idx}: {len(dl)}" for idx, dl in enumerate(valid_dataloaders))
)
valid_data_iterators = [
_get_iterator(valid_dl_type, dl) for dl in valid_dataloaders
]
elif valid_dataloaders[0] is not None:
valid_data_iterators = _get_iterator(dl_type, valid_dataloaders[0])
else:
valid_data_iterators = None
if test_dataloader is not None:
test_data_iterator = _get_iterator(dl_type, test_dataloader)
else:
test_data_iterator = None
return train_data_iterator, valid_data_iterators, test_data_iterator
```

The build_train_valid_test_data_loaders function looks like this:

```python
def build_train_valid_test_data_loaders(build_train_valid_test_datasets_provider):
"""Build pretraining data loaders."""
args = get_args()
(train_dataloader, valid_dataloaders, test_dataloader) = (None, None, None)
print_rank_0('> building train, validation, and test datasets ...')
# Backward compatibility, assume fixed batch size.
if args.iteration > 0 and args.consumed_train_samples == 0:
assert (
args.train_samples is None
), 'Only backward compatiblity support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0:
if args.train_samples is None:
args.consumed_valid_samples = (
(args.iteration // args.eval_interval) * args.eval_iters * args.global_batch_size
)
# Rely on distributed-aware core datasets, temporary
is_distributed = getattr(build_train_valid_test_datasets_provider, "is_distributed", False)
# Construct the data pipeline
if is_distributed or mpu.get_tensor_model_parallel_rank() == 0:
# Build datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
build_train_valid_test_datasets_provider
)
valid_ds = [valid_ds] if not isinstance(valid_ds, list) else valid_ds
# Build dataloders.
train_dataloader = build_pretraining_data_loader(train_ds, args.consumed_train_samples)
valid_dataloaders = []
for valid_d in valid_ds:
if args.skip_train or args.full_validation:
valid_dataloaders.append(build_pretraining_data_loader(valid_d, 0))
else:
if args.multiple_validation_sets:
# TODO(bnorick): for multiple validation sets without full validation, args.consumed_valid_samples is not
# correct and needs to be calculated/set per validation set
raise NotImplementedError("--multiple-validation-sets currently requires --full-validation")
valid_dataloaders.append(build_pretraining_data_loader(valid_d, args.consumed_valid_samples))
if not args.multiple_validation_sets:
assert len(valid_dataloaders) == 1
test_dataloader = build_pretraining_data_loader(test_ds, 0)
# Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and args.train_iters > 0
do_valid = valid_dataloaders is not None and (args.full_validation or args.eval_iters > 0)
do_test = test_dataloader is not None and (args.full_validation or args.eval_iters > 0)
flags = torch.tensor(
[int(do_train), int(do_valid), int(do_test)], dtype=torch.long, device='cuda'
)
else:
flags = torch.tensor([0, 0, 0], dtype=torch.long, device='cuda')
torch.distributed.broadcast(flags, 0)
args.do_train = getattr(args, "do_train", False) or flags[0].item()
args.do_valid = getattr(args, "do_valid", False) or flags[1].item()
args.do_test = getattr(args, "do_test", False) or flags[2].item()
return train_dataloader, valid_dataloaders, test_dataloader
```

It first back-fills the number of train and valid samples already consumed, so that resuming from a checkpoint does not replay the same data. It then reads the is_distributed attribute of build_train_valid_test_datasets_provider; the datasets are built only if it is true or the current rank is tp rank 0. It then calls build_train_valid_test_datasets, which computes the number of samples needed for train, valid, and test and uses the user-defined train_valid_test_datasets_provider to obtain train_ds, valid_ds, test_ds:
```python
def build_train_valid_test_datasets(build_train_valid_test_datasets_provider):
"""Build pretraining datasets."""
train_valid_test_num_samples = get_train_valid_test_num_samples()
print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_valid_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_valid_test_num_samples[1]))
print_rank_0(' test: {}'.format(train_valid_test_num_samples[2]))
return build_train_valid_test_datasets_provider(train_valid_test_num_samples)
def get_train_valid_test_num_samples():
"""Train/valid/test num samples."""
args = get_args()
# Number of train/valid/test samples.
if args.train_samples:
train_samples = args.train_samples
else:
train_samples = args.train_iters * args.global_batch_size
if args.full_validation:
eval_samples = None
else:
eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters
eval_samples = eval_iters * args.global_batch_size
test_iters = args.eval_iters
return (train_samples, eval_samples, test_iters * args.global_batch_size)
```

train_ds, valid_ds, test_ds are then combined with the numbers of already-consumed train/valid samples and passed to build_pretraining_data_loader to construct train_dataloader, valid_dataloaders, test_dataloader; if the corresponding dataloader is not None, args.do_train, args.do_valid, args.do_test are set. The build_pretraining_data_loader code is shown below:

```python
def build_pretraining_data_loader(dataset, consumed_samples):
"""Build dataloader given an input dataset."""
if dataset is None:
return None
args = get_args()
if hasattr(dataset,'split'):
split = dataset.split
elif hasattr(dataset,'index_split'):
split = dataset.index_split
else:
split = None
if split == Split.valid and args.full_validation:
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=0,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'single':
# Megatron sampler
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
dataset,
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
data_sharding=args.data_sharding)
elif args.dataloader_type == "external":
# External dataloaders are passed through. User is expected to provide a
# torch-compatible dataloader and define samplers, if needed.
return dataset
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True,
persistent_workers=True if args.num_workers > 0 else False,
)
```

Two samplers appear here: MegatronPretrainingSampler reads samples in order and is resumable, while MegatronPretrainingRandomSampler samples randomly (seeded by the current epoch) and can cycle indefinitely.

Finally, a standard torch.utils.data.DataLoader is returned from the sampler and the dataset. Note that for the valid split with full_validation enabled (i.e., running through all of the validation data), a MegatronPretrainingSampler with consumed_samples=0 is built. A custom dataloader is also supported via args.dataloader_type == "external".

Back in build_train_valid_test_data_iterators: if a dataloader is not None, _get_iterator wraps it according to dataloader_type into a RerunDataIterator, which supports fault-tolerant re-runs. Note that for valid_dataloaders, if full_validation is configured, eval_iters is updated to cover all iterations.

train_data_iterator and valid_data_iterator are passed into the train function for training; test_data_iterator is used for the final test after training, if args.do_test is set.
Key Classes in Dataset Construction
BlendedMegatronDatasetBuilder
BlendedMegatronDatasetBuilder's main job is to support dataset blending, for example mixing a commonsense dataset with a code dataset.
For now, though, we set aside dataset blending and look at how a single dataset is handled.
With a single dataset, execution enters the single-dataset branch of _build_blended_dataset_splits, which hands off to _build_megatron_dataset_splits.
The processing flow inside _build_megatron_dataset_splits is as follows:
- If the current rank does not need to build the dataset (per is_dataset_built_on_rank), it just waits at the synchronization point: only the first and last pp stages, and only tp rank 0 within them, actually build the dataset, which avoids wasting resources. A minimal sketch of this predicate is shown below.
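For reference, a predicate with this behavior might look like the minimal sketch below, modeled on the is_dataset_built_on_rank helper used by the GPT pretraining script; treat the exact form as an approximation:

```python
from megatron.core import mpu

def is_dataset_built_on_rank():
    # Build the dataset only on the first/last pipeline stages and, within
    # those, only on tensor-parallel rank 0; every other rank just waits.
    return (
        mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    ) and mpu.get_tensor_model_parallel_rank() == 0
```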
Then self.cls.build_low_level_dataset(dataset_path, self.config) builds the low-level dataset. Here we look at GPTDataset's implementation:

```python
@staticmethod
def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> IndexedDataset:
"""Abstract method implementation
Args:
dataset_path (str): The real path prefix to the IndexedDataset .bin and .idx files
config (GPTDatasetConfig): The config
Returns:
IndexedDataset: The underlying IndexedDataset
"""
if is_object_storage_path(dataset_path):
assert config.object_storage_cache_path is not None
return IndexedDataset(
dataset_path,
multimodal=False,
mmap=config.mmap_bin_files,
object_storage_config=ObjectStorageConfig(
path_to_idx_cache=config.object_storage_cache_path
),
)
return IndexedDataset(dataset_path, multimodal=False, mmap=config.mmap_bin_files)
```

As the code shows, this essentially constructs an IndexedDataset (its full code appears further below). It operates on a pair of files:

- path_prefix.idx: the index file. For each sample (sequence) it records the length, the byte offset (pointer) into the .bin file, and optionally a multimodal mode; it also records document boundaries (document_indices).
- path_prefix.bin: the actual token data, stored as one contiguous array of a fixed dtype.

Its core capabilities are:

- O(1) random access: the .idx file gives the offset and length of the i-th sample in the .bin file, from which the corresponding token sequence is read.
- Slice reads: dataset[start:stop] reads several contiguous samples at once (step must be 1), avoiding frequent per-sample IO.
- Partial reads: get(idx, offset, length) fetches only a span of tokens from a sample (useful for truncation, windowing, etc.).
- Selectable read strategy: mmap=True/False chooses between memory mapping (_MMapBinReader) and plain file reads (_FileBinReader); if the data lives on object storage (S3/MSC), a dedicated reader pulls the .bin in chunks and caches the .idx locally.
- **exists(path_prefix)**: checks whether the .idx/.bin pair exists (locally or on object storage).
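Before going through the full class, a hypothetical usage sketch based on the API described above (the path prefix is made up):

```python
from megatron.core.datasets.indexed_dataset import IndexedDataset

prefix = "/data/my-gpt2_text_document"       # hypothetical path prefix of the .idx/.bin pair
if IndexedDataset.exists(prefix):
    ds = IndexedDataset(prefix, multimodal=False, mmap=True)
    print(len(ds))                           # number of sequences in the index
    tokens = ds[0]                           # all token ids of sequence 0
    first_two = ds[0:2]                      # contiguous slice read (step must be 1)
    window = ds.get(0, offset=8, length=32)  # partial read: 32 tokens starting at offset 8
```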
```python
class IndexedDataset(torch.utils.data.Dataset):
"""The low-level interface dataset class
Args:
path_prefix (str): The index (.idx) and data (.bin) prefix
multimodal (bool): Whether the dataset is multimodal. Defaults to False.
mmap (bool): Whether to mmap the .bin files. Defaults to True.
object_storage_config (Optional[ObjectStorageConfig]): Supplied only for data stored on S3
or MSC. IndexedDataset downloads the index (.idx) file to
`object_storage_config.path_to_idx_cache` and streams data from the data (.bin) file
in `object_storage_config.bin_chunk_nbytes` blocks. Note that `mmap` must be disabled
for S3 data loading. Defaults to None.
"""
def __init__(
self,
path_prefix: str,
multimodal: bool = False,
mmap: bool = True,
object_storage_config: Optional[ObjectStorageConfig] = None,
s3_config: Optional[S3Config] = None,
) -> None:
super().__init__()
self.path_prefix: str
self.multimodal: bool
self.mmap: bool
self.object_storage_config: Optional[ObjectStorageConfig]
self.bin_reader: _BinReader
self.index: _IndexReader
# Deprecated: s3_config is deprecated, use object_storage_config instead
object_storage_config = object_storage_config or s3_config
# Cache the index file if it is stored on object storage
if is_object_storage_path(path_prefix) and object_storage_config is not None:
idx_path = get_idx_path(path_prefix)
cache_idx_path = get_index_cache_path(idx_path, object_storage_config)
cache_index_file(idx_path, cache_idx_path)
self.initialize(path_prefix, multimodal, mmap, object_storage_config)
def initialize(
self,
path_prefix: str,
multimodal: bool,
mmap: bool,
object_storage_config: Optional[ObjectStorageConfig],
) -> None:
"""Initialize the dataset
This method is called by IndexedDataset.__init__ during object creation and by
IndexedDataset.__setstate__ during un-pickling
Args:
path_prefix (str): The index (.idx) and data (.bin) prefix
multimodal (bool): Whether the dataset is multimodal
mmap (bool): Whether to mmap the .bin file
object_storage_config (Optional[ObjectStorageConfig]): See IndexedDataset docstring
for details.
"""
idx_path = get_idx_path(path_prefix)
bin_path = get_bin_path(path_prefix)
if object_storage_config is None:
assert os.path.exists(idx_path) and os.path.exists(
bin_path
), "One or both of the .idx and .bin files cannot be found at the "
f"path prefix {path_prefix}"
self.path_prefix = path_prefix
self.multimodal = multimodal
self.mmap = mmap
self.object_storage_config = object_storage_config
if mmap:
assert not object_storage_config
self.bin_reader = _MMapBinReader(bin_path)
elif object_storage_config:
assert not mmap
self.bin_reader = OBJECT_STORAGE_BIN_READERS[get_object_storage_access(path_prefix)](
bin_path, object_storage_config
)
idx_path = get_index_cache_path(get_idx_path(path_prefix), object_storage_config)
else:
self.bin_reader = _FileBinReader(bin_path)
self.index = _IndexReader(idx_path, self.multimodal)
def __getstate__(self) -> Tuple[str, bool, bool, Optional[ObjectStorageConfig]]:
"""Get the state during pickling
Returns:
Tuple[str, bool, bool, Optional[ObjectStorageConfig]]: The state tuple
"""
return self.path_prefix, self.multimodal, self.mmap, self.object_storage_config
def __setstate__(self, state: Tuple[str, bool, bool, Optional[ObjectStorageConfig]]) -> None:
"""Set the state during un-pickling
Args:
state (Tuple[str, bool, bool, Optional[ObjectStorageConfig]]): The state tuple
"""
path_prefix, multimodal, mmap, object_storage_config = state
self.initialize(path_prefix, multimodal, mmap, object_storage_config)
def __del__(self) -> None:
"""Clean up the object"""
del self.bin_reader
del self.index
def __len__(self) -> int:
"""Return the length of the dataset i.e. the number of sequences in the index
Returns:
int: The length of the dataset
"""
return len(self.index)
def __getitem__(
self, idx: Union[int, numpy.integer, slice]
) -> Union[
numpy.ndarray,
Tuple[numpy.ndarray, numpy.number],
List[numpy.ndarray],
Tuple[List[numpy.ndarray], numpy.ndarray],
]:
"""Return from the dataset
Args:
idx (Union[int, numpy.integer, slice]): The index or index slice into the dataset
Raises:
ValueError: When the index slice is non-contiguous
TypeError: When the index is of an unexpected type
Returns:
Union[
numpy.ndarray,
Tuple[numpy.ndarray, numpy.number],
List[numpy.ndarray],
Tuple[List[numpy.ndarray], numpy.ndarray],
]: The sequence tokens and modes at the index or index slice
"""
if isinstance(idx, (int, numpy.integer)):
sequence_pointer, sequence_length, sequence_mode = self.index[idx]
sequence = self.bin_reader.read(
dtype=self.index.dtype, count=sequence_length, offset=sequence_pointer
)
return (sequence, sequence_mode) if sequence_mode is not None else sequence
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
sequence_lengths = self.index.sequence_lengths[idx]
sequence_modes = (
self.index.sequence_modes[idx] if self.multimodal else None # type: ignore[index]
)
sequence_offsets = list(accumulate(sequence_lengths))
sequences = numpy.split(
self.bin_reader.read(
dtype=self.index.dtype,
count=sum(sequence_lengths),
offset=self.index.sequence_pointers[start],
),
sequence_offsets[:-1],
)
return (sequences, sequence_modes) if sequence_modes is not None else sequences
else:
raise TypeError("Unexpected type received for idx: {}".format(type(idx)))
def get(
self, idx: int, offset: int = 0, length: Optional[int] = None
) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.number]]:
"""Retrieve a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
Args:
idx (Union[int, numpy.integer]): The index into the dataset
offset (int): The integer token offset in the sequence
length (int): The number of tokens to grab from the sequence
Returns:
Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.number]]: The sequence tokens and mode
at the index
"""
sequence_pointer, sequence_length, sequence_mode = self.index[idx]
if length is None:
length = sequence_length - offset
sequence_pointer += offset * DType.size(self.index.dtype)
sequence = self.bin_reader.read(
dtype=self.index.dtype, count=length, offset=sequence_pointer
)
return (sequence, sequence_mode) if sequence_mode is not None else sequence
@property
def sequence_lengths(self) -> numpy.ndarray:
"""Get the sequence lengths
Returns:
numpy.ndarray: The sequence lengths
"""
return self.index.sequence_lengths
@property
def document_indices(self) -> numpy.ndarray:
"""Get the document indices
Returns:
numpy.ndarray: The document indices
"""
return self.index.document_indices
def get_document_indices(self) -> numpy.ndarray:
"""Get the document indices
This method is slated for deprecation.
Returns:
numpy.ndarray: The document indices
"""
return self.index.document_indices
def set_document_indices(self, document_indices: numpy.ndarray) -> None:
"""Set the document indices
This method is slated for deprecation.
Args:
document_indices (numpy.ndarray): The document indices
"""
self.index.document_indices = document_indices
@property
def sequence_modes(self) -> numpy.ndarray:
"""Get the sequence modes
Returns:
numpy.ndarray: The sequence modes
"""
assert self.index.sequence_modes
return self.index.sequence_modes
@staticmethod
def exists(path_prefix: str) -> bool:
"""Return whether the IndexedDataset exists on disk at the prefix
Args:
path_prefix (str): The prefix to the index (.idx) and data (.bin) files
Returns:
bool: Whether the IndexedDataset exists on disk at the prefix
"""
if is_object_storage_path(path_prefix):
return dataset_exists(path_prefix, get_idx_path(path_prefix), get_bin_path(path_prefix))
return os.path.exists(get_idx_path(path_prefix)) and os.path.exists(
get_bin_path(path_prefix)
)
```

Back in _build_megatron_dataset_splits, the number of samples in this low-level dataset is read, and the actual index range of each split is computed from the split ratios.

build_generic_dataset is then called to build the mid-level dataset for each range and return it; these are the train_ds, valid_ds, test_ds we end up with. It first lets rank 0 build the mid-level dataset (the actual GPTDataset), makes the other ranks wait, and then lets the ranks that are not rank 0 but have is_built_on_rank set build their GPTDataset, returning the corresponding object.

Letting rank 0 build first turns "shared cache writes that may happen during dataset construction" into "single-process write + multi-process read", which avoids races, improves cache hits, and keeps the barrier synchronization and control flow consistent.
GPTDataset builds, on top of the low-level IndexedDataset (the token sequences stored in .bin/.idx), a PyTorch Dataset that can be used directly for GPT autoregressive training. Its main responsibilities are:

- Concatenate/slice raw sequences into fixed-length samples: by building document_index / sample_index / shuffle_index, many variable-length sequences are concatenated in document order and cut into training samples of length sequence_length (optionally with 1 extra token).
- Provide the tensors needed for training: __getitem__ returns tokens/labels and generates (or reuses cached) attention_mask / loss_mask / position_ids for left-to-right (causal) language-model training.
- Support reproducible shuffling and multi-epoch sampling: shuffle_index controls the random sample order; when num_samples exceeds the samples available in one epoch, it computes how many epochs to repeat.
- Support index caching: the built document_index.npy / sample_index.npy / shuffle_index.npy are written to path_to_cache and loaded directly on the next run, avoiding recomputation.
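A toy illustration (numbers are made up) of how the per-epoch sample count behind sample_index is derived; the full GPTDataset code follows:

```python
# Hypothetical numbers: three documents with 7, 5 and 9 tokens.
sequence_length = 4
add_extra_token_to_sequence = 1          # one extra token so labels can be shifted by one
doc_lengths = [7, 5, 9]

num_tokens_per_epoch = sum(doc_lengths)  # 21
num_samples_per_epoch = (
    num_tokens_per_epoch - add_extra_token_to_sequence
) // sequence_length                     # (21 - 1) // 4 = 5 fixed-length samples
# sample_index stores, for each of those samples, the (document, offset) pair where it
# starts; shuffle_index is a reproducible permutation over the resulting sample ids.
print(num_samples_per_epoch)  # 5
```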
```python
class GPTDataset(MegatronDataset):
"""The base GPT dataset
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around which to build the GPTDataset
dataset_path (Optional[str]): The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray): The set of the documents indices to expose
num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When
None, build as many samples as correspond to one epoch.
index_split (Split): The indexed_indices Split
config (GPTDatasetConfig): The config
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: Optional[str],
indexed_indices: numpy.ndarray,
num_samples: Optional[int],
index_split: Split,
config: GPTDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
self.masks_and_position_ids_are_cacheable = not any(
[
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
]
)
self.masks_and_position_ids_are_cached = False
self.cached_attention_mask = None
self.cached_loss_mask = None
self.cached_position_ids = None
try:
self._pad_token_id = self.config.tokenizer.pad
except Exception:
self._pad_token_id = _PAD_TOKEN_ID
(self.document_index, self.sample_index, self.shuffle_index) = (
self._build_document_sample_shuffle_indices()
)
@staticmethod
def numel_low_level_dataset(low_level_dataset: IndexedDataset) -> int:
"""Abstract method implementation
For GPT, the underlying IndexedDataset should be split by sequence, as opposed to, say,
BERT, which should be split by document
Args:
low_level_dataset (IndexedDataset): The underlying IndexedDataset
Returns:
int: The number of unique elements in the underlying IndexedDataset
"""
return low_level_dataset.sequence_lengths.shape[0]
@staticmethod
def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> IndexedDataset:
"""Abstract method implementation
Args:
dataset_path (str): The real path prefix to the IndexedDataset .bin and .idx files
config (GPTDatasetConfig): The config
Returns:
IndexedDataset: The underlying IndexedDataset
"""
if is_object_storage_path(dataset_path):
assert config.object_storage_cache_path is not None
return IndexedDataset(
dataset_path,
multimodal=False,
mmap=config.mmap_bin_files,
object_storage_config=ObjectStorageConfig(
path_to_idx_cache=config.object_storage_cache_path
),
)
return IndexedDataset(dataset_path, multimodal=False, mmap=config.mmap_bin_files)
def __len__(self) -> int:
"""Abstract method implementation
Returns:
int: The length of the dataset
"""
return self.sample_index.shape[0] - 1
def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
"""Abstract method implementation
Args:
idx (Optioal[int]): The index into the dataset
Returns:
Dict[str, torch.Tensor]: The sample information wrapped in a dictionary
"""
if idx is None:
# Batch padding sequence so the index does not matter
text, _ = self._query_document_sample_shuffle_indices(0)
else:
text, _ = self._query_document_sample_shuffle_indices(idx)
text = torch.from_numpy(text).long()
if self.config.add_extra_token_to_sequence:
tokens = text[:-1].contiguous()
labels = text[1:].contiguous()
else:
tokens = text
labels = torch.roll(text, shifts=-1, dims=0)
labels[-1] = self._pad_token_id
if (
not self.masks_and_position_ids_are_cacheable
or not self.masks_and_position_ids_are_cached
):
attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
tokens,
self.config.tokenizer.eod,
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
self.config.create_attention_mask,
)
if self.masks_and_position_ids_are_cacheable:
self.cached_attention_mask = attention_mask
self.cached_loss_mask = loss_mask
self.cached_position_ids = position_ids
self.masks_and_position_ids_are_cached = True
else:
attention_mask = self.cached_attention_mask
loss_mask = self.cached_loss_mask
position_ids = self.cached_position_ids
# For padded sequences, mask the loss
loss_mask[labels == self._pad_token_id] = 0.0
# For padded sequences, ensure the embedding layer can map the token ID
tokens[tokens == self._pad_token_id] = 0
labels[labels == self._pad_token_id] = 0
# Batch padding sequence so we mask the loss
if idx is None:
loss_mask = torch.zeros_like(loss_mask)
if self.config.create_attention_mask:
return {
"tokens": tokens,
"labels": labels,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"position_ids": position_ids,
}
else:
return {
"tokens": tokens,
"labels": labels,
"loss_mask": loss_mask,
"position_ids": position_ids,
}
def _query_document_sample_shuffle_indices(
self, idx: int
) -> Tuple[numpy.ndarray, numpy.ndarray]:
"""Get the text (token ids) and document ids for a given index
Args:
idx (int): The index into the dataset
Returns:
Tuple[numpy.ndarray, numpy.ndarray]: The text ids and document ids
"""
# Do the shuffle mapping
idx = self.shuffle_index[idx]
# Get the beginning and end documents and offsets
doc_index_beg, doc_index_beg_offset = self.sample_index[idx]
doc_index_end, doc_index_end_offset = self.sample_index[idx + 1]
document_ids = []
sample_parts = []
# Sample spans a single document
if doc_index_beg == doc_index_end:
# Add the document id
document_ids.append(self.document_index[doc_index_beg])
# Add the entire sample
sample_parts.append(
self.dataset.get(
self.document_index[doc_index_beg],
offset=doc_index_beg_offset,
length=doc_index_end_offset
- doc_index_beg_offset
+ self.config.add_extra_token_to_sequence,
)
)
# Sample spans multiple documents
else:
for i in range(doc_index_beg, doc_index_end + 1):
# Add the document id
document_ids.append(self.document_index[i])
# Add the sample part
offset = 0 if i > doc_index_beg else doc_index_beg_offset
length = (
None
if i < doc_index_end
else doc_index_end_offset + self.config.add_extra_token_to_sequence
)
sample_parts.append(
self.dataset.get(self.document_index[i], offset=offset, length=length)
)
assert len(document_ids) == len(
sample_parts
), f"len(document_ids) ({len(document_ids)}) != len(sample_parts) ({len(sample_parts)})"
length = sum(map(len, sample_parts))
# Pad the sample if necessary
if length < (self.config.sequence_length + self.config.add_extra_token_to_sequence):
sample_parts.append(
[self._pad_token_id]
* (self.config.sequence_length + self.config.add_extra_token_to_sequence - length)
)
return (
numpy.concatenate(sample_parts, dtype=numpy.int64),
numpy.array(document_ids, dtype=numpy.int64),
)
def _build_document_sample_shuffle_indices(
self,
) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
"""Build the document index, the sample index, and the shuffle index
The document index:
-- 1-D
-- An ordered array of document ids
The sample index:
-- 2-D
-- The document indices and offsets which mark the start of every sample
The shuffle index:
-- 1-D
-- A random permutation of index range of the sample index
Returns:
Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample
index, and the shuffle index
"""
path_to_cache = self.config.path_to_cache
if path_to_cache is None and not self.config.mock:
path_to_cache = os.path.join(
self.dataset.path_prefix, "cache", f"{type(self).__name__}_indices"
)
if path_to_cache:
base = f"{self.unique_description_hash}-{type(self).__name__}-{self.index_split.name}"
get_path_to = lambda affix: os.path.join(path_to_cache, f"{base}-{affix}")
path_to_description = get_path_to("description.txt")
path_to_document_index = get_path_to("document_index.npy")
path_to_sample_index = get_path_to("sample_index.npy")
path_to_shuffle_index = get_path_to("shuffle_index.npy")
cache_hit = all(
map(
os.path.isfile,
[
path_to_description,
path_to_document_index,
path_to_sample_index,
path_to_shuffle_index,
],
)
)
else:
cache_hit = False
if not path_to_cache or (
not cache_hit
and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0)
):
log_single_rank(
logger,
logging.INFO,
f"Build and save the {type(self).__name__} {self.index_split.name} indices",
)
t_beg = time.time()
sequence_length = self.config.sequence_length
num_tokens_per_epoch = self._get_num_tokens_per_epoch()
num_epochs = self._get_num_epochs(num_tokens_per_epoch)
if num_epochs == 1:
separate_final_epoch = False
else:
# Get the number of samples for the last epoch
num_samples_sans_final_epoch = (
(num_epochs - 1) * num_tokens_per_epoch
- self.config.add_extra_token_to_sequence
) // sequence_length
num_samples_from_final_epoch = self.num_samples - num_samples_sans_final_epoch
num_samples_per_epoch = (
num_tokens_per_epoch - self.config.add_extra_token_to_sequence
) // sequence_length
# num_samples_from_final_epoch should be non-negative
assert num_samples_from_final_epoch >= 0
# num_samples_from_final_epoch should not exceed max value
assert num_samples_from_final_epoch <= num_samples_per_epoch + 1
# Separate the final epoch if it falls below the threshold
threshold = 0.80
separate_final_epoch = num_samples_from_final_epoch < int(
threshold * num_samples_per_epoch
)
log_single_rank(
logger,
logging.DEBUG,
f"> num_samples_from_final_epoch: {num_samples_from_final_epoch}",
)
log_single_rank(logger, logging.DEBUG, f"> threshold: {threshold}")
log_single_rank(
logger, logging.DEBUG, f"> num_samples_per_epoch: {num_samples_per_epoch}"
)
log_single_rank(
logger, logging.DEBUG, f"> separate_final_epoch: {separate_final_epoch}"
)
numpy_random_state = numpy.random.RandomState(self.config.random_seed)
# Build the document index
document_index = _build_document_index(
self.indices, num_epochs, numpy_random_state, separate_final_epoch
)
# Build the sample index
from megatron.core.datasets import helpers
if self.index_split == Split.valid:
drop_last_partial_sequence = self.config.drop_last_partial_validation_sequence
else:
drop_last_partial_sequence = True
assert document_index.dtype == numpy.int32
assert self.dataset.sequence_lengths.dtype == numpy.int32
if len(document_index) * 2 > len(self.dataset.sequence_lengths):
# If "access density" of sequence_lengths is high, force load the mmap-ed array
# into memory by making a copy.
#
# System performance benefits come from two aspects:
# 1. We sequentially pre-load the whole file, most of which we expect to read
# 2. The GIL is held when entering the c++ program, improving the speed of which
# improves parallelism
sequence_lengths_for_cpp = self.dataset.sequence_lengths.copy()
else:
sequence_lengths_for_cpp = self.dataset.sequence_lengths
sample_index = helpers.build_sample_idx(
sequence_lengths_for_cpp,
document_index,
sequence_length,
num_epochs,
num_tokens_per_epoch,
drop_last_partial_sequence,
self.config.add_extra_token_to_sequence,
)
# Build the shuffle index
if separate_final_epoch:
shuffle_index = _build_shuffle_index(
num_samples_sans_final_epoch, sample_index.shape[0] - 1, numpy_random_state
)
else:
shuffle_index = _build_shuffle_index(
sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state
)
if path_to_cache:
os.makedirs(path_to_cache, exist_ok=True)
# Write the description
with open(path_to_description, "wt") as writer:
writer.write(self.unique_description)
numpy.save(path_to_document_index, document_index, allow_pickle=True)
numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
numpy.save(path_to_shuffle_index, shuffle_index, allow_pickle=True)
else:
log_single_rank(
logger,
logging.WARNING,
f"Unable to save {type(self).__name__} indexes because path_to_cache is None",
)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
)
log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")
return document_index, sample_index, shuffle_index
log_single_rank(
logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices"
)
log_single_rank(
logger,
logging.INFO,
f"\tLoad the document index from {os.path.basename(path_to_document_index)}",
)
t_beg = time.time()
document_index = numpy.load(path_to_document_index, allow_pickle=True, mmap_mode="r")
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger,
logging.INFO,
f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}",
)
t_beg = time.time()
sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode="r")
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger,
logging.INFO,
f"\tLoad the shuffle index from {os.path.basename(path_to_shuffle_index)}",
)
t_beg = time.time()
shuffle_index = numpy.load(path_to_shuffle_index, allow_pickle=True, mmap_mode="r")
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
)
return document_index, sample_index, shuffle_index
def _get_num_tokens_per_epoch(self) -> int:
"""Calculate the number of tokens in a single epoch
Returns:
int: The number of tokens in a single epoch
"""
return int(numpy.sum(self.dataset.sequence_lengths[self.indices]))
def _get_num_epochs(self, num_tokens_per_epoch: int) -> int:
"""Calculate the number of epochs
Args:
num_tokens_per_epoch (int): The number of tokens in a single epoch
Returns:
int: The number of epochs
"""
num_epochs = 1
num_tokens = num_tokens_per_epoch
if self.num_samples is None:
return num_epochs
else:
num_tokens_requested = (
self.num_samples * self.config.sequence_length
) + self.config.add_extra_token_to_sequence
while num_tokens < num_tokens_requested:
num_epochs += 1
num_tokens += num_tokens_per_epoch
return num_epochs
```

MegatronPretrainingSampler
Once we have the mid-level datasets (train_ds, valid_ds, test_ds) from which individual GPT training samples can be fetched, MegatronPretrainingSampler wraps them. It acts as a batch sampler: instead of returning individual sample indices, it yields a group of indices (one micro-batch) at a time, to be consumed directly by PyTorch's DataLoader(…, batch_sampler=…). Its core goals are:

- In data-parallel training, split the "global" group of indices read per step (= micro_batch_size * data_parallel_size) across DP ranks, so each rank gets its own micro-batch of indices.
- Support resuming from a given consumed_samples (for resume / checkpoint restart / skipping already-trained samples).
Its code does the following:
- During initialization it records the DP rank it belongs to and computes, via self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size, how many samples must actually be read per step under data parallelism.
- When iterated, it walks over the range (self.consumed_samples, self.total_samples), collects micro_batch_times_data_parallel_size sample indices at a time, and then, based on its own DP rank, keeps the micro_batch_size indices that this rank actually needs.

A hedged sketch of such a batch sampler is shown below.
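For concreteness, a simplified sketch that mirrors the behavior described above (validation checks and the drop_last handling of the real class are omitted, so treat the details as assumptions):

```python
class PretrainingSamplerSketch:
    """Yields, per step, the micro_batch_size indices belonging to this DP rank."""

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size

    def get_start_end_idx(self):
        # This rank's slice inside the global group of indices read per step.
        start_idx = self.data_parallel_rank * self.micro_batch_size
        return start_idx, start_idx + self.micro_batch_size

    def __iter__(self):
        batch = []
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

# Example: total_samples=8, micro_batch_size=2, data_parallel_size=2:
# rank 0 yields [0, 1], [4, 5]; rank 1 yields [2, 3], [6, 7].
```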
DataLoader
In build_pretraining_data_loader, the mid-level dataset (GPTDataset) and the MegatronPretrainingSampler are finally combined into a torch.utils.data.DataLoader, through which one micro_batch worth of training samples can be fetched at a time.
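A hypothetical usage sketch of the resulting loader (the variable name train_dataloader is assumed): each iteration yields one micro-batch for this DP rank, collated from the dictionaries returned by GPTDataset.__getitem__:

```python
# Assuming train_dataloader was returned by build_pretraining_data_loader(...)
train_iter = iter(train_dataloader)
batch = next(train_iter)                  # one micro-batch for this DP rank
tokens = batch["tokens"]                  # shape: [micro_batch_size, seq_length]
labels = batch["labels"]
loss_mask = batch["loss_mask"]
position_ids = batch["position_ids"]
```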
Dataset Usage Flow
- The user-provided forward_step function receives a data_iterator built on top of this torch.utils.data.DataLoader.
- The get_batch function that fetches training data checks whether the current rank is in the first or last pp stage; only then does it actually fetch data, relying on get_batch_on_this_tp_rank and get_batch_on_this_cp_rank (a condensed sketch follows).
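A condensed sketch of that logic, simplified from the GPT pretraining script; the import paths reflect recent Megatron-LM versions and may differ in yours:

```python
from megatron.core import mpu
from megatron.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
)

def get_batch(data_iterator):
    """Fetch a micro-batch, but only on ranks that actually need the data."""
    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        return None, None, None, None, None
    # tp rank 0 reads from the iterator and broadcasts to the other tp ranks.
    batch = get_batch_on_this_tp_rank(data_iterator)
    # Slice the sequence dimension across context-parallel ranks.
    batch = get_batch_on_this_cp_rank(batch)
    return batch.values()
```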
- First, get_batch_on_this_tp_rank: since, as described above, the dataset is only built on tp rank 0, this function checks the current rank. The worker with tp rank 0 fetches one micro_batch via next(data_iterator) and then broadcasts the tensors to the workers with tp rank ≠ 0 (a simplified sketch of this broadcast pattern follows).
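As a hedged illustration of the broadcast pattern only (the real get_batch_on_this_tp_rank handles more fields and pipeline-stage-specific subsets), a sketch with only tokens and position_ids:

```python
import torch
from megatron.core import mpu
from megatron.training import get_args

def broadcast_tokens_from_tp_rank0(data_iterator):
    """Sketch: tp rank 0 reads the micro-batch; other tp ranks receive it."""
    args = get_args()
    group = mpu.get_tensor_model_parallel_group()
    src = mpu.get_tensor_model_parallel_src_rank()
    shape = (args.micro_batch_size, args.seq_length)
    if mpu.get_tensor_model_parallel_rank() == 0:
        data = next(data_iterator)
        tokens = data["tokens"].cuda(non_blocking=True)
        position_ids = data["position_ids"].cuda(non_blocking=True)
    else:
        # Non-source tp ranks allocate buffers of the known shape and receive.
        tokens = torch.empty(shape, dtype=torch.int64, device="cuda")
        position_ids = torch.empty(shape, dtype=torch.int64, device="cuda")
    torch.distributed.broadcast(tokens, src, group=group)
    torch.distributed.broadcast(position_ids, src, group=group)
    return {"tokens": tokens, "position_ids": position_ids}
```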
- Then get_batch_on_this_cp_rank is executed. In context parallelism, a zig-zag partitioning is usually used for GPU load balancing: the core idea is to split the sequence into 2*CP chunks and have the i-th cp rank take the i-th chunk from the front and the i-th chunk from the back, which balances the compute per rank (a hedged sketch of this slicing follows).

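A hedged sketch of that zig-zag slicing, close in spirit to the real get_batch_on_this_cp_rank but simplified, with the sequence-dimension layout stated as an assumption:

```python
import torch
from megatron.core import mpu

def slice_batch_for_cp_rank(batch):
    """Each CP rank keeps chunk cp_rank and chunk 2*cp_size-1-cp_rank of the sequence."""
    cp_size = mpu.get_context_parallel_world_size()
    if cp_size == 1:
        return batch
    cp_rank = mpu.get_context_parallel_rank()
    for key, val in batch.items():
        if val is None:
            continue
        seq_dim = 1  # assumption: tensors are laid out as [batch, seq, ...]
        # View the sequence dimension as 2*cp_size equal chunks.
        val = val.view(
            *val.shape[0:seq_dim],
            2 * cp_size,
            val.shape[seq_dim] // (2 * cp_size),
            *val.shape[(seq_dim + 1):],
        )
        index = torch.tensor(
            [cp_rank, 2 * cp_size - cp_rank - 1], device=val.device
        )
        val = val.index_select(seq_dim, index)
        # Merge the two kept chunks back into one contiguous sequence slice.
        batch[key] = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2):])
    return batch
```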
- Once the batch's tokens, labels, loss_mask, attention_mask, position_ids are available, the model's forward pass can be run. The net effect: each dp rank uses a different micro_batch; within one dp replica, only the first and last pp stages read the dataset; the tp ranks within those stages all get the same copy of the data; and if cp is enabled, the data is further sliced along the sequence dimension.
Data-Parallel Training
So far we have seen how the dataset is partitioned across workers under data parallelism and how data is fetched within a training step. Now let's look at how gradients are synchronized across DP ranks during training.
Within one training iteration, train_step performs a complete step by calling forward_backward_func.
There are several variants of forward_backward_func depending on pp parallelism; here we look at forward_backward_no_pipelining, the variant without pipeline parallelism (a simplified sketch of its schedule follows).
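A simplified sketch of that schedule (hypothetical signatures; the real function also manages forward outputs, grad scaling, and token counting):

```python
import contextlib

def forward_backward_no_pipelining_sketch(
    forward_step_func,
    backward_step_func,
    data_iterator,
    model,
    num_microbatches,
    forward_only=False,
    no_sync_func=contextlib.nullcontext,
):
    """Run all micro-batches on one stage, allowing grad sync only on the last one."""
    outputs = []
    # First num_microbatches - 1 micro-batches: gradients accumulate locally and
    # DP gradient communication is suppressed by the no_sync context.
    with no_sync_func():
        for _ in range(num_microbatches - 1):
            out = forward_step_func(data_iterator, model)
            outputs.append(out)
            if not forward_only:
                backward_step_func(out)
    # Last micro-batch runs outside no_sync, so grad sync is allowed to start
    # (with plain DDP it actually happens later in finalize_model_grads).
    out = forward_step_func(data_iterator, model)
    outputs.append(out)
    if not forward_only:
        backward_step_func(out)
    return outputs
```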
- For the first num_microbatches-1 micro-batches, training is wrapped in with no_sync_func(), and as long as we are not in forward_only mode, backward_step is called. Much of backward_step is pipeline-parallel plumbing; the key point is that it uses torch.autograd.backward to propagate gradients, and the gradients produced by each micro_batch are accumulated.
The last (num_microbatches-th) micro-batch is not wrapped in with no_sync_func(), which allows synchronization operations; the actual inter-DP gradient synchronization, however, happens later in config.finalize_model_grads_func.
```python
if config.finalize_model_grads_func is not None and not forward_only:
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism and layernorm all-reduce for sequence parallelism).
config.finalize_model_grads_func(
[model],
total_num_tokens if config.calculate_per_token_loss else None,
grad_finalize_pgs=grad_finalize_pgs,
)
```

config comes from get_model_config(model), and config.finalize_model_grads_func is attached in the train function:

```python
def get_model_config(model):
"""Returns the config attribute, allowed to return None"""
return get_attr_wrapped_model(model, "config", allow_none=False)
```

- The attached finalize_model_grads code is as follows:
```python
def finalize_model_grads(
model: List[torch.nn.Module],
num_tokens: Optional[torch.Tensor] = None,
grad_finalize_pgs: Optional[GradFinalizeProcessGroups] = None,
):
"""
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
embedding grads across first and last pipeline stages (if not tied),
scale gradients by `num_tokens`.
"""
config = get_model_config(model[0])
if grad_finalize_pgs is not None:
assert hasattr(grad_finalize_pgs, 'tp')
assert hasattr(grad_finalize_pgs, 'pp')
assert hasattr(grad_finalize_pgs, 'embd'), (
"grad_finalize_pgs must have a embd. In previous version, it is used default "
"`parallel_state.default_embedding_ranks` to create the process group."
" If you are using the default process group, please use"
" `parallel_state.get_embedding_group()` "
"If you don't need embd_group, you need to explicitly set it to None."
)
assert hasattr(grad_finalize_pgs, 'pos_embd'), (
"grad_finalize_pgs must have a pos_embd. In previous version, it is used default "
"`parallel_state.default_position_embedding_ranks` to create the process group."
" If you are using the default process group, please use "
" `parallel_state.get_position_embedding_group()` "
"If you don't need pos_embd_group, you need to explicitly set it to None."
)
assert hasattr(grad_finalize_pgs, 'dp_cp')
tp_group = grad_finalize_pgs.tp
pp_group = grad_finalize_pgs.pp
embd_group = grad_finalize_pgs.embd
pos_emb_group = grad_finalize_pgs.pos_embd
dp_cp_group = grad_finalize_pgs.dp_cp
else:
tp_group = parallel_state.get_tensor_model_parallel_group()
pp_group = parallel_state.get_pipeline_model_parallel_group()
embd_group = parallel_state.get_embedding_group(check_initialized=False)
pos_emb_group = parallel_state.get_position_embedding_group(check_initialized=False)
dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
# All-reduce / reduce-scatter across DP replicas.
if config.timers is not None:
config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
for model_chunk in model:
model_chunk.finish_grad_sync()
if config.timers is not None:
config.timers('all-grads-sync').stop()
# All-reduce t_embedder grads (for pp & vpp of DiT).
if config.timers is not None:
config.timers('conditional-embedder-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_conditional_embedding_grads(model, config, pp_group)
if config.timers is not None:
config.timers('conditional-embedder-grads-all-reduce').stop()
# All-reduce layer-norm grads (for sequence parallelism) and non-tensor parallel modules.
if config.timers is not None:
config.timers('non-tensor-parallel-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_non_tensor_model_parallel_grads(model, config, tp_group)
if config.timers is not None:
config.timers('non-tensor-parallel-grads-all-reduce').stop()
# All-reduce embedding grads (for pipeline parallelism).
if config.timers is not None:
config.timers('embedding-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_word_embedding_grads(model, config, embd_group, pp_group)
_allreduce_position_embedding_grads(model, config, pos_emb_group, pp_group)
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()
if config.moe_router_enable_expert_bias:
_update_router_expert_bias(model, config)
# normalize gradients for per-token loss normalization.
# if we are using by the number of tokens, then we use that as a divisor. this number
# will be the total number of non-padded tokens in the global batch.
if num_tokens is not None:
# the number of tokens is only present on the last stage, so broadcast it
# to the other ranks in the pipeline parallel group.
assert not isinstance(pp_group, list)
last_rank = get_pp_last_rank(pp_group)
torch.distributed.broadcast(num_tokens, src=last_rank, group=pp_group)
# all-reduce across DP ranks.
torch.distributed.all_reduce(num_tokens, group=dp_cp_group)
for model_chunk in model:
if num_tokens > 0:
scaling = 1.0 / num_tokens
model_chunk.scale_gradients(scaling)
```

- The most important part is the following code, which triggers finish_grad_sync (gradient synchronization) on each model_chunk:

```python
# All-reduce / reduce-scatter across DP replicas.
if config.timers is not None:
config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
for model_chunk in model:
model_chunk.finish_grad_sync()
if config.timers is not None:
config.timers('all-grads-sync').stop()
```

- How this function executes depends on the type of model, which is decided by the get_model function shown below. Here we take the simplest case, DDP (DistributedDataParallel), as the example:

```python
if wrap_with_ddp:
if args.use_torch_fsdp2:
assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
DP = torch_FSDP
elif args.use_megatron_fsdp:
DP = megatron_FSDP
else:
DP = DDP
...
with torch.cuda.stream(torch.cuda.Stream()):
model = [
DP(
config=config,
ddp_config=ddp_config,
module=model_chunk,
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0)
or args.overlap_param_gather_with_optimizer_step,
)
for (model_chunk_idx, model_chunk) in enumerate(model)
]
```

DistributedDataParallel's finish_grad_sync function is shown below; it calls finish_grad_sync on every bucket_group:

```python
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.finish_grad_sync()
```

- bucket_group is an instance of _ParamAndGradBucketGroup; its finish_grad_sync is shown below. Consider the common case where self.ddp_config.overlap_grad_reduce is False, i.e. start_grad_sync is called directly:

```python
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous
communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
makes synchronous call.
"""
self.param_gather_dispatched = False
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
if not self.ddp_config.overlap_grad_reduce:
self.start_grad_sync()
return
# When using multiple DistOpt instances, we don't need to sync here as we launch
# communications on a separate communication stream.
if self.ddp_config.num_distributed_optimizer_instances > 1:
torch.cuda.default_stream().wait_stream(self.communication_stream)
return
assert self.grad_reduce_handle is not None, (
f"Communication call has not been issued for this bucket "
f"({len(self.params_with_grad)}/{len(self.params)} params have grad available)"
)
self.grad_reduce_handle.wait()
self.grad_reduce_handle = None
```

start_grad_sync is shown below. Its core job is to trigger data-parallel gradient synchronization for a group of buckets (slices of the contiguous grad buffer), with support for overlapping communication with the backward pass, using the distributed optimizer (reduce-scatter), multiple DistOpt instances (two-level communication), and gradient checking/scaling:

```python
def start_grad_sync(self):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous
communication call. When ddp_config.overlap_grad_reduce is set to False, makes
synchronous call.
"""
assert (
self.grad_reduce_handle is None
), "Should not have multiple communication calls outstanding at once"
if self.ddp_config.check_for_nan_in_grad or self.ddp_config.check_for_large_grads:
self.check_grads(
check_for_nan_or_inf=self.ddp_config.check_for_nan_in_grad,
check_for_large=self.ddp_config.check_for_large_grads,
)
# gradient_scaling_factor already takes into account whether we are computing
# an average or sum in the data-parallel collective.
for bucket in self.buckets:
if bucket.gradient_scaling_factor != 1.0:
bucket.grad_data *= bucket.gradient_scaling_factor
# Decide reduce_op.
reduce_op = torch.distributed.ReduceOp.SUM
if self.ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG
# We use the following stream synchronization for the gradient reduction
# within and across DistOpt instances.
# Compute Stream: -------------Gradient compute-------------------
# Comm. Stream: ------(wait for NCCL)-----(wait for NCCL)-------
# NCCL Stream: -------RS------ -------AR------
# Use async communications only when overlap_grad_reduce is True.
async_op = (
self.ddp_config.overlap_grad_reduce
and self.ddp_config.num_distributed_optimizer_instances == 1
)
if (
self.ddp_config.num_distributed_optimizer_instances > 1
and self.ddp_config.overlap_grad_reduce
):
# Assign a communication stream if we have multiple DistOpt instances and we
# need to overlap communication.
stream_context = torch.cuda.stream(self.communication_stream)
# The RS/AR communication stream needs to wait for the default stream
# to complete its gradient computation before launching the next
# gradient reduction collective.
self.communication_stream.wait_stream(torch.cuda.default_stream())
else:
stream_context = nullcontext()
if self.ddp_config.use_distributed_optimizer:
communication_group = self.intra_distributed_optimizer_instance_group
else:
communication_group = self.data_parallel_group
# Coalesce communication kernels across buckets in the bucket group.
with stream_context, _coalescing_manager(communication_group, async_ops=async_op) as cm:
for idx, bucket in enumerate(self.buckets):
if self.ddp_config.use_distributed_optimizer:
if self.cached_grad_buffer_shard_list[idx] is None:
self.cached_grad_buffer_shard_list[idx] = shard_buffer(
bucket.grad_data, self.intra_distributed_optimizer_instance_size
)
local_data_view = self.cached_grad_buffer_shard_list[idx][
self.intra_distributed_optimizer_instance_rank
]
dist_reduce_scatter_func(
local_data_view,
bucket.grad_data,
op=reduce_op,
group=communication_group,
async_op=async_op,
)
else:
torch.distributed.all_reduce(
bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op
)
# With multiple DistOpt instances, we need to all-reduce across instances.
if (
self.ddp_config.use_distributed_optimizer
and self.ddp_config.num_distributed_optimizer_instances > 1
):
assert self.inter_distributed_optimizer_instance_group is not None
# Create a new coalescing manager for the inter-instance all-reduce.
with (
stream_context,
_coalescing_manager(
self.inter_distributed_optimizer_instance_group, async_ops=async_op
) as cm,
):
for idx, bucket in enumerate(self.buckets):
if self.cached_grad_buffer_shard_list[idx] is None:
self.cached_grad_buffer_shard_list[idx] = shard_buffer(
bucket.grad_data, self.intra_distributed_optimizer_instance_size
)
local_data_view = self.cached_grad_buffer_shard_list[idx][
self.intra_distributed_optimizer_instance_rank
]
torch.distributed.all_reduce(
local_data_view,
op=reduce_op,
group=self.inter_distributed_optimizer_instance_group,
async_op=async_op,
)
if async_op:
self.grad_reduce_handle = cm
else:
# When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
# `cm` is not None, which is different from when `_coalescing_manager` is not used in
# which case the torch.distributed._reduce_scatter_base() will return None. In order to
# maintain consistency with prior code, we need to manually set communication handle to
# None.
self.grad_reduce_handle = None- 添加的
A torch profiler trace of DDP data-parallel training confirms this chain of calls as well.

A follow-up could look at how data parallelism is implemented under FSDP.