【pytorch-fsdp 源代码阅读(一)】-全流程概览

专有名词解释

  • warp:对模型进行包裹,使其具备fsdp的相关的分布式能力

  • shard: 对参数进行切分,得到每个rank sharded的参数

  • unshard: 将切分的参数allgather,得到完整的参数

  • reshard:将完整的参数释放,只保留每个rank的sharded的参数

  • sharded:切分后的参数

  • unsharded:完整的参数

fsdp概览

如下图所示,首先对于Zero算法来说:

  • Zero-1切分了优化器状态

  • Zero-2切分了优化器状态和梯度

  • Zero-3切分了优化器状态和梯度和参数

对于fsdp来说,它实际上就是Zero-3。传统的数据并行会在每个GPU上维护一份模型参数,梯度,优化器状态的副本,但是FSDP将这些状态分片到所有的数据并行worker中,并且可以选择将分片的模型参数卸载到CPU上,从而使得若现在有 $n$个GPU,某一层的参数量为 $m$,那么每个GPU会维护这一层 $m/n$个参数。

通常,模型层以嵌套方式用 FSDP 包装,因此在前向或后向计算期间,只有单个 FSDP 实例中的层需要将完整参数收集到单个设备。聚合到的完整参数会在计算后立即释放,释放的内存可以用于下一层的计算。通过这种方式,可以节省峰值 GPU 内存,因此可以扩展训练以使用更大的模型大小或更大的批量大小。为了进一步最大化内存效率,当实例在计算中不活动时,FSDP 可以将参数、梯度和优化器状态卸载到 CPU。

使用示例

如下是一个使用示例,简单来说有这几个关键步骤:

  1. 定义自动 wrap 策略:只 wrap nn.Linear层

  2. 将模型用FSDP进行包裹,从而转变为fsdp_model

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# torchrun --nproc_per_node=2 --master_port=47123 test.py

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.fsdp import ShardingStrategy
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from functools import partial

def my_wrap_criteria(module):
return isinstance(module, nn.Linear)

# 定义模型
class Net(nn.Module):
def __init__(self, H):
super(Net, self).__init__()
self.fc0 = nn.Linear(H, H, bias=False)
self.fc1 = nn.Linear(H, H, bias=False)
self.fc2 = nn.Linear(H, H, bias=False)
self.fc3 = nn.Linear(H, H, bias=False)
self.fc4 = nn.Linear(H, H, bias=False)

def forward(self, x):
x = self.fc0(x)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)
return x

# 启动函数
def main():
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])

dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)

H = 512
model = Net(H).cuda()

# 定义自动 wrap 策略:只 wrap nn.Linear
policy = partial(lambda_auto_wrap_policy, lambda_fn=my_wrap_criteria)

fsdp_model = FSDP(
model,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
auto_wrap_policy=policy
)
print("fsdp_model:", fsdp_model)

optimizer = optim.Adam(fsdp_model.parameters(), lr=1e-3)

# 模拟数据
x = torch.randn(32, H).cuda()
target = torch.randn(32, H).cuda()
criterion = nn.MSELoss()

fsdp_model.train()
for epoch in range(3):
optimizer.zero_grad()
output = fsdp_model(x)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(f"[Rank {rank}] Epoch {epoch} Loss: {loss.item()}")

dist.destroy_process_group()

if __name__ == "__main__":
main()

fsdp初始化

1
2
class FullyShardedDataParallel(nn.Module, _FSDPState):

注意fsdp的初始化过程是“惰性”的(lazy),只有在forward调用的时候才会进行初始化,从而对模型进行shard。

Warp

具体的_auto_wrap的调用是在使用FullyShardedDataParallel包裹module后进行初始化的时候实现的,如下:

1
2
3
4
5
6
7
8
_auto_wrap(
module,
auto_wrap_policy,
self._ignored_modules,
self._ignored_params,
root_kwargs,
FullyShardedDataParallel,
)

判断模块是否需要划分的函数

pytorch中提供了多种对模型进行自动切分和包装的方法,下面介绍几个常用的:

CustomPolicy

这支持自定义包装策略,关键是允许通过lambda_fn来进行自定义。lambda_fn可以返回bool值,这代表和root执行同样的分片参数,也可以返回args,这代表自定义的分片参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82

class CustomPolicy(_Policy):
"""
功能:
这是一个高度灵活的 FSDP(Fully Sharded Data Parallel)自动包装策略。
它允许用户通过提供一个自定义的 lambda 函数来精确控制对哪个模块应用 FSDP 包装,
以及使用什么样的参数。

工作机制:
策略的核心是用户传入的 `lambda_fn`。FSDP 会遍历模型中的每一个模块,并用该模块
作为参数调用 `lambda_fn`。根据 `lambda_fn` 的返回值,决定如何操作:
- 返回 `False`:不包装当前模块。
- 返回 `True`:使用 FSDP 的默认参数包装当前模块。
- 返回一个非空字典:包装当前模块,并使用该字典中的键值对来覆盖或补充 FSDP 的默认参数。
这允许为特定模块设置不同的分片策略(ShardingStrategy)或其他配置。

使用场景:
当你需要比 `ModuleWrapPolicy`(基于类名包装)更精细的控制时,此策略非常有用。
例如,你可能想为模型的大部分 Transformer 层使用默认包装,但为最后的输出层(如 lm_head)
指定一个不同的分片策略,或者完全不包装某个特定的层。
"""

def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, dict[str, Any]]]):
"""
构造函数。

参数:
- lambda_fn (Callable): 一个函数,它接受一个 `nn.Module` 实例作为输入,
并返回一个布尔值或一个字典,用于决定包装行为。
"""
self._lambda_fn = lambda_fn

def _run_policy(
self,
root_module: nn.Module,
ignored_modules: set[nn.Module],
root_kwargs: dict[str, Any],
) -> dict[nn.Module, dict[str, Any]]:
"""
(内部方法)执行策略的核心逻辑,遍历所有模块并应用 lambda 函数来决定包装方案。

参数:
- root_module (nn.Module): 整个模型。
- ignored_modules (set[nn.Module]): 需要忽略的模块集合。
- root_kwargs (dict[str, Any]): 应用于根 FSDP 模块的参数,作为包装子模块时的默认参数。

返回:
一个字典,键是需要被包装的目标模块实例,值是应用于该模块的 FSDP 参数。
"""
target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {}
# 遍历整个模型的所有子模块
for module in root_module.modules():
# 如果模块在忽略列表中,则跳过
if module in ignored_modules:
continue

# 对当前模块调用用户提供的 lambda 函数
res = self._lambda_fn(module)

# 验证 lambda 函数的返回值是否合法(必须是布尔型或字典)
if not isinstance(res, (dict, bool)):
raise ValueError(
f"传递给 CustomPolicy 的 lambda_fn 应返回 "
f"False/True 或一个 kwarg 字典,但它返回了 {res}"
)

# 如果返回 False 或一个空字典,表示不包装该模块,直接跳过
if not res:
continue

# 如果需要包装,首先浅拷贝根 FSDP 的参数作为默认值
# 这样做是为了防止不同 FSDP 实例间共享和意外修改同一份配置
kwargs = copy.copy(root_kwargs)

# 如果 lambda 函数返回的是一个字典,用它的内容更新(覆盖)默认参数
if isinstance(res, dict):
kwargs.update(res)

# 将最终确定要包装的模块及其配置参数存入结果字典
target_module_to_kwargs[module] = kwargs

return target_module_to_kwargs

使用示例:

1
2
3
4
5
6
7
8
9
model = init_transformer_model(...)
def lambda_fn(module: nn.Module):
if module is model.lm_head:
return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
elif isinstance(module, TransformerBlock):
return True
return False
policy = CustomPolicy(lambda_fn)
fsdp_model = FSDP(model, auto_wrap_policy=policy)
Module结构学习

对于module,其记录子结构的变量为_modules: dict[str, Optional["Module"]],即对于函数:

1
self.block1 = SomeModule()

底层实际做了如下事情(在 setattr 中):

1
self._modules["block1"] = SomeModule()

也就是说所有子模块都保存在 self._modules 中(有顺序的字典)。

这里调用了root_module.modules()来获取root_module的子modules(),该函数实际上就是用深度优先遍历的方式去遍历self._modules ,具体的方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
def modules(self) -> Iterator["Module"]:
r"""返回一个遍历网络中所有模块的迭代器。

功能:
- 这是一个便捷方法,用于获取模型及其所有子模块的实例,而不需要它们的名称。
- 它在内部调用 `self.named_modules()`,但忽略了每个元组中的名称部分。

产生:
Module: 网络中的一个模块。

注意:
与 `named_modules` 一样,重复的模块实例默认只返回一次。在下面的例子中,
`l` 只会被返回一次。

示例::

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
... print(idx, '->', m)

0 -> Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)

"""
# 核心实现:
# 1. 调用 `self.named_modules()`,这个方法会返回一个 (名称, 模块) 元组的迭代器。
# 2. 在 for 循环中,使用 `_` 来接收并“丢弃”元组中的第一个元素(即模块的名称)。
# 3. `module` 变量接收元组中的第二个元素(即模块对象本身)。
# 4. `yield module` 将模块对象作为生成器的下一个值返回。
# 这种实现方式非常优雅,因为它将所有复杂的遍历逻辑(如递归、处理重复)
# 全部委托给了 `named_modules` 方法,自身保持了极度的简洁。
for _, module in self.named_modules():
yield module

def named_modules(
self,
memo: Optional[set["Module"]] = None, # 用于记录已访问模块的集合,防止重复处理
prefix: str = "", # 当前模块的名称前缀
remove_duplicate: bool = True, # 是否移除重复的模块实例
):
r"""返回一个迭代器,该迭代器遍历网络中的所有模块,同时产生模块的名称和模块本身。

这是一个深度优先(pre-order,先序)的遍历。

参数:
memo: 用于存储已添加到结果中的模块集合的备忘录。主要用于内部递归调用。
prefix: 将被添加到模块名称前面的前缀。主要用于内部递归调用。
remove_duplicate: 是否在结果中移除重复的模块实例。

产生:
(str, Module): (名称, 模块) 的元组。

注意:
默认情况下,重复的模块只返回一次。在下面的例子中,`l` 只会被返回一次。

示例::

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l) # net 中有两个对同一 l 实例的引用
>>> for idx, m in enumerate(net.named_modules()):
... print(idx, '->', m)

0 -> ('', Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
# 注意:尽管有两个 l,但 ('1', l) 不会再次出现,因为 l 已经被访问过。

"""
# 1. 初始化备忘录(memo)
# 如果是顶层调用(非递归调用),memo 为 None,此时创建一个新的集合。
# 在递归调用中,memo 会被传递下去。
if memo is None:
memo = set()

# 2. 处理当前模块(self)
# 检查当前模块实例是否已经被访问过。
if self not in memo:
# 如果 remove_duplicate 为 True,则将当前模块添加到备忘录中,
# 以确保后续遇到同一个实例时不再处理。
if remove_duplicate:
memo.add(self)

# 3. 产生当前模块的名称和实例
# 这是先序遍历的体现:先访问根节点(当前模块)。
yield prefix, self

# 4. 递归遍历所有子模块
# self._modules 是一个有序字典,存储了所有直接子模块(例如 self.layer1, self.conv2)。
for name, module in self._modules.items():
# 如果某个子模块是 None(例如,通过 delattr 删除后),则跳过。
if module is None:
continue

# 5. 构建子模块的完整名称
# 在当前前缀的基础上,添加子模块的名称。
# 例如,如果当前 prefix 是 'encoder',子模块 name 是 'layer1',
# 那么 submodule_prefix 就是 'encoder.layer1'。
submodule_prefix = prefix + ("." if prefix else "") + name

# 6. 递归调用
# 使用 `yield from` 将递归调用产生的生成器内容直接“转发”出去。
# 将 memo 和新构建的 submodule_prefix 传递给下一次递归。
yield from module.named_modules(
memo, submodule_prefix, remove_duplicate
)

transformer_auto_wrap_policy

这是transformer模型中比较常用的策略,主要就是自定义要划分的模块的类型,然后自动划分。具体使用的时候需要用functools.partial进行包装。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def transformer_auto_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
transformer_layer_cls: Set[Type[nn.Module]],
) -> bool:
"""
功能:
这是一个专门为 Transformer 模型设计的便捷包装策略。它本质上是 `_module_wrap_policy` 的一个别名或封装。

目的:
提供一个语义上更清晰的函数名 (`transformer_auto_wrap_policy`),让用户在处理 Transformer 模型时,
能更直观地理解其作用。它特别适用于包装 Transformer 的编码器/解码器层(例如 `TransformerEncoderLayer`)。

要点:
- 它直接调用 `_module_wrap_policy`,并将 `transformer_layer_cls` 作为 `module_classes` 传递过去。
- 正确地包装共享参数(如词嵌入层)非常重要,因为它们必须位于同一个 FSDP 实例中。
这个策略通过将所有指定的层(通常是 Transformer block)包装起来,有助于确保模型中其他部分(如共享的嵌入层)
最终被包含在更高层级的 FSDP 实例中,从而被正确处理。
"""
# 直接调用通用的模块包装策略函数,实现完全相同的功能。
return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls)

def _module_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
module_classes: Set[Type[nn.Module]],
) -> bool:
"""
功能:
这是一个核心的辅助函数,用于实现基于模块类的自动包装策略。
FSDP 的自动包装过程是一个从上到下(top-down)的遍历,但包装决策是从下到上(bottom-up)做出的。
此函数在这个过程中被调用,以决定是否应该包装当前的模块。

工作机制:
该函数的行为取决于 `recurse` 参数:
1. 当 `recurse=True` 时:表示 FSDP 正在递归地深入模块树(DFS 过程)。
在这种情况下,函数总是返回 `True`,告诉 FSDP 继续向下遍历,直到到达叶子模块或一个已经被包装的子模块。
2. 当 `recurse=False` 时:表示 FSDP 已经完成对当前模块所有子模块的遍历,现在需要对当前模块本身做出决策。
这时,函数会检查 `module` 是否是 `module_classes` 中指定的任何一个类的实例。
- 如果是,则返回 `True`,表示“请包装我这个模块”。
- 如果不是,则返回 `False`,表示“不要包装我,继续向上返回”。

参数:
module (nn.Module): 当前正在被考虑的模块。
recurse (bool): 控制函数行为的标志。`True` 表示继续递归,`False` 表示需要做出包装决策。
nonwrapped_numel (int): 尚未被包装的参数数量(在此函数中未使用,但在其他更复杂的策略中可能有用)。
module_classes (Set[Type[nn.Module]]): 一个包含模块类的集合。任何属于这些类的模块都将被包装。

返回:
一个布尔值。如果 `recurse=True`,总是返回 `True`。如果 `recurse=False`,返回是否应该包装 `module`。
"""
# 如果标志为 True,意味着我们仍在递归地深入模块树,所以总是返回 True 以继续递归。
if recurse:
return True

# 如果标志为 False,意味着已经到达决策点。
# 检查当前模块的类型是否在用户指定的需要包装的类型列表中。
# isinstance 的第二个参数需要是元组,所以我们将集合转换为元组。
return isinstance(module, tuple(module_classes))

使用示例:

1
2
3
4
5
6
7
fsdp_m = FSDP(
m,
auto_wrap_policy=functools.partial(
transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
),
use_orig_params=True,
)

具体进行自动划分

注意观察下面的函数中的fsdp_fn为FullyShardedDataParallel,即这个fsdp_fn的作用是把module包装成FullyShardedDataParallel类型。

这里有两种划分的调用方式:

  1. 如果 policy_Policy的实例(推荐方式),则使用策略对象来决定哪些模块需要被包装。

  2. 如果 policy 是一个可调用对象(旧版方式),则使用递归的方式进行包装。

暂时先只看第一种_Policy的实例的方法,其执行顺序如下:

  1. 执行_run_policy得到root_module下所有符合包装规则的module以及args

  2. 如果配置了混合精度就特殊处理一下

  3. 验证要包装的模块中的冻结参数(即 requires_grad=False 的参数)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def _auto_wrap(
root_module: nn.Module,
policy: Union[Callable, _Policy], # 包装策略,可以是_Policy对象或一个可调用函数
ignored_modules: set[nn.Module], # 应该忽略不进行包装的模块集合
ignored_params: set[nn.Parameter], # 应该忽略不进行包装的参数集合
root_kwargs: dict[str, Any], # FSDP的根配置参数
fsdp_fn: Callable, # FSDP的包装函数,例如 `FullyShardedDataParallel` 类或 `fully_shard` 函数
):
"""
根据 `policy`,以后序遍历的方式自动包装 `root_module` 模块树中的模块。

此函数是 FSDP 自动包装功能的核心入口。
它根据传入的 `policy` 类型,选择不同的包装逻辑:
1. 如果 `policy` 是 `_Policy` 的实例(推荐方式),则使用策略对象来决定哪些模块需要被包装。
2. 如果 `policy` 是一个可调用对象(旧版方式),则使用递归的方式进行包装。

前提条件: `root_kwargs` 应该包含除 `module` 之外的所有FSDP构造函数参数。
"""
# 检查模块是否已经被FSDP包装过。自动包装不支持对已包装的模块再次进行包装(嵌套包装)。
_check_nested_wrapping(root_module)

# --- 分支1:基于 _Policy 对象的策略化自动包装 (推荐方式) ---
if isinstance(policy, _Policy):
# 运行策略,获取一个从目标模块到其对应FSDP参数的映射字典
target_module_to_kwargs = policy._run_policy(
root_module, ignored_modules, root_kwargs
)
# 如果配置了混合精度(mixed_precision)
if root_kwargs.get("mixed_precision") is not None:
# 运行混合精度覆盖策略,这可能会修改 target_module_to_kwargs,
# 例如,将某些模块的混合精度设置为特定的类型或禁用它。
target_module_to_kwargs = _run_mixed_precision_override_policy(
root_module,
root_kwargs["mixed_precision"]._module_classes_to_ignore,
ignored_modules,
root_kwargs,
target_module_to_kwargs,
)
# 对指定的不使用混合精度的模块类别,通过注册前向钩子来覆盖其行为
overridden_module_classes = _override_module_mixed_precision(
root_module, root_kwargs["mixed_precision"]._module_classes_to_ignore
)
# 如果有模块的混合精度设置被覆盖,发出警告
_warn_on_overridden_mixed_precision(overridden_module_classes)

# 验证要包装的模块中的冻结参数(即 requires_grad=False 的参数)
# 确保所有参数的 `requires_grad` 状态在所有进程中是一致的
_validate_frozen_params(
root_module,
set(target_module_to_kwargs.keys()),
ignored_params,
root_kwargs.get("use_orig_params", False),
)
# 根据 target_module_to_kwargs 构建一个包装函数
wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
# 以后序遍历的方式,将包装函数应用到模块树上
_post_order_apply(root_module, wrap_fn)
return # 完成包装,直接返回

# --- 分支2:基于可调用函数的递归自动包装 (旧版方式) ---
# 准备递归包装所需的参数
recursive_wrap_kwargs = {
"module": root_module,
"auto_wrap_policy": policy,
"wrapper_cls": fsdp_fn,
"ignored_modules": ignored_modules,
"ignored_params": ignored_params,
"only_wrap_children": True, # 表示只对子模块进行递归包装
}
# 如果配置了混合精度
if root_kwargs.get("mixed_precision") is not None:
# 对指定的不使用混合精度的模块类别,通过注册前向钩子来覆盖其行为
overridden_module_classes = _override_module_mixed_precision(
root_module, root_kwargs["mixed_precision"]._module_classes_to_ignore
)
# 创建一个组合策略:它会同时应用用户提供的原始策略,
# 以及一个单独包装被忽略混合精度模块的策略。
policy = functools.partial(
_or_policy, # _or_policy 会依次尝试列表中的每个策略
policies=[
policy, # 用户原始策略
partial( # 一个新策略,用于单独包装需要忽略混合精度的模块
_wrap_module_cls_individually,
module_classes=root_kwargs["mixed_precision"]._module_classes_to_ignore,
),
],
)
recursive_wrap_kwargs["auto_wrap_policy"] = policy
# 如果有模块的混合精度设置被覆盖,发出警告
_warn_on_overridden_mixed_precision(overridden_module_classes)

# 执行递归包装。将根配置和递归配置合并后传递给 _recursive_wrap
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]

_construct_wrap_fn&_post_order_apply

首先构建一个warp函数,该函数对于target model且不是root的会调用fsdp_fn执行

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def _construct_wrap_fn(
root_module: nn.Module, # 整个模型的根模块
target_module_to_kwargs: dict[nn.Module, dict[str, Any]], # 一个字典,映射了需要被包装的模块到它们的FSDP参数
fsdp_fn: Callable, # 实际执行包装的函数,例如 `FullyShardedDataParallel`
) -> Callable[[nn.Module], Optional[nn.Module]]:
"""
此函数是一个高阶函数,它的作用是构建并返回一个"包装函数"(`fn`)。
这个返回的函数将被传递给 `_post_order_apply`,用于在后序遍历中实际应用FSDP包装。

功能:
- 它利用闭包捕获 `root_module`、`target_module_to_kwargs` 和 `fsdp_fn`。
- 返回的 `fn` 封装了决定是否包装一个特定模块以及如何包装它的全部逻辑。

参数:
- root_module (nn.Module): 模型的根模块,用于在包装时进行排除检查。
- target_module_to_kwargs (dict): 由包装策略生成的字典,指明了哪些模块需要被包装以及它们各自的FSDP配置。
- fsdp_fn (Callable): FSDP的包装器,如 `FullyShardedDataParallel` 类。

返回:
- 一个可调用对象 `fn`,该函数接受一个模块作为输入,如果该模块需要被包装,则返回包装后的模块,否则返回 `None`。
"""

def fn(module: nn.Module) -> Optional[nn.Module]:
"""
这个内部函数是实际执行替换逻辑的单元。
它会被 `_post_order_apply` 在遍历模型树时对每个模块调用。
"""
# 检查当前遍历到的 `module` 是否在我们的目标包装列表 `target_module_to_kwargs` 中
# 同时,显式地避免包装根模块 `root_module`,因为根模块的包装通常由调用FSDP的用户在最外层手动完成。
if module in target_module_to_kwargs and module is not root_module:
# 如果模块是目标模块,就从字典中获取其对应的FSDP参数
kwargs = target_module_to_kwargs[module]
# 使用传入的 `fsdp_fn` (例如 `FullyShardedDataParallel`) 和对应的参数来包装当前模块
# `_post_order_apply` 会用这里返回的新模块替换掉原始模块
return fsdp_fn(module, **kwargs)

# 如果模块不是目标包装模块,或者它就是根模块,则返回 None。
# `_post_order_apply` 看到返回 `None` 时,不会对原始模块做任何改动。
return None

return fn

然后对于_post_order_apply函数,他就是会执行上面构造的fn,然后以后序遍历的方式去遍历所有的子模块(注意不会替换root 模块),替换的方法就是setattr。

为什么要避免替换root呢,因为root这时候已经是FullyShardedDataParallel类型了,不需要重复进行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# 注意:我们有意保持此函数简单,并将复杂性隔离到 `fn` 中,
# 以便能够通用地使用此函数。我们将来可能会将其移动到非 FSDP 特定的文件夹和/或使其公开。
def _post_order_apply(
root_module: nn.Module,
fn: Callable[[nn.Module], Optional[nn.Module]],
):
"""
此函数遵循后序遍历(post-order traversal)将 `fn` 应用于 `root_module` 模块树中的每个模块。
如果 `fn` 返回一个 :class:`nn.Module`,那么它将在树中用新返回的模块替换原始模块。
否则,`fn` 应返回 `None`,在这种情况下,模块不会被更改。

后序遍历意味着,对于任何给定的模块,函数 `fn` 会先被应用于其所有子模块,然后再应用于该模块本身。

参数:
- root_module (nn.Module): 整个模型层级结构的根模块。
- fn (Callable[[nn.Module], Optional[nn.Module]]): 一个可调用对象,它接收一个模块作为输入。
- 如果需要替换该模块,则返回一个新的 nn.Module 实例。
- 如果不需要改变该模块,则返回 None。
"""
# 跟踪已访问的模块,以避免多次访问共享的模块实例。
visited_modules: set[nn.Module] = {root_module}

# 定义一个内部辅助函数来进行递归的后序遍历和应用。
def _post_order_apply_inner(
module: nn.Module, # 当前正在处理的模块
module_name: str, # 当前模块在其父模块中的属性名
parent_module: Optional[nn.Module], # 当前模块的父模块
):
# 1. 遍历当前模块的所有直接子模块。
for child_module_name, child_module in module.named_children():
# 如果子模块还没有被访问过(防止重复处理共享模块)
if child_module not in visited_modules:
visited_modules.add(child_module) # 标记为已访问
# 2. 对子模块进行递归调用。这是实现后序遍历的关键:先深入子树。
_post_order_apply_inner(child_module, child_module_name, module)

# 3. 在所有子模块都处理完毕后,对当前模块应用 `fn` 函数。
# 这就是“后序”的含义:先处理子节点,再处理父节点。
optional_module = fn(module)

# 4. 如果 `fn` 返回了一个新的模块实例(而不是 None),则替换原始模块。
if optional_module is not None:
# 断言确保非根模块必须有一个父模块,否则替换操作无法进行。
assert isinstance(parent_module, nn.Module), (
f"非根模块应该设置其父模块,但对于 {module} 得到了 {parent_module}"
)
# 断言确保非根模块必须有一个名称,否则无法通过名称在父模块中找到并替换它。
assert module_name, (
f"非根模块应该设置其模块名称,但对于 {module} 得到了一个空模块名"
)
# 断言确保 `fn` 的返回值要么是 None,要么是 nn.Module 的实例。
assert isinstance(optional_module, nn.Module), (
f"fn 应返回 None 或 nn.Module,但得到了 {optional_module}"
)
# 使用 setattr 动态地将父模块中的 `module_name` 属性设置为新的 `optional_module`,
# 从而完成模块的替换。
setattr(parent_module, module_name, optional_module)

# 从根模块开始启动整个后序应用过程。
# 根模块没有父模块(None)和名称("")。
_post_order_apply_inner(root_module, "", None)

初始化param_handle

在Warp之后,初始化param_handle的顶层调用:

1
2
3
4
5
6
7
_init_param_handle_from_module(
self,
module,
device_id,
param_init_fn,
sync_module_states,
)

_init_param_handle_from_module

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
@no_type_check
def _init_param_handle_from_module(
state: _FSDPState, # FSDP 状态对象,用于跟踪和管理 FSDP 实例的各种状态
fully_sharded_module: nn.Module, # 需要被 FSDP 完全分片的模块
device_id: Optional[Union[int, torch.device]], # 目标设备 ID
param_init_fn: Optional[Callable[[nn.Module], None]], # 可选的参数初始化函数,用于在物化(materialize)模块时调用
sync_module_states: bool, # 是否在初始化时同步模块的状态(参数和缓冲区)
) -> _FSDPState:
"""从一个模块 `fully_sharded_module` 初始化一个 `FlatParamHandle`。

`FlatParamHandle` 是 FSDP 的核心组件,它将模块的多个原始参数展平(flatten)
并合并成一个单一的、连续的 `FlatParameter`。这个函数负责完成这一过程。
"""""
# 1. 检查和准备设备
# 确保模块的所有参数都在同一个设备上,或者在 CPU 上,或者在 'meta' 设备上
_check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
# 根据传入的 device_id 获取实际的 torch.device 对象
device_from_device_id = _get_device_from_device_id(
device_id, state.rank, state._device_handle
)

# 2. 模块物化(Materialization)
# 检查模块是否需要在设备上进行物化。如果模块的参数在 'meta' 设备上,
# 或者使用了 torchdistX 的延迟初始化,就需要进行物化,即为参数分配实际的内存。
is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
fully_sharded_module, state._ignored_params, state._ignored_modules
)
# 如果需要物化并且用户提供了自定义的参数初始化函数 `param_init_fn`
if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None:
# 使用用户提供的函数来初始化并物化模块
_materialize_with_param_init_fn(
fully_sharded_module, param_init_fn, state._ignored_modules
)
# 如果是 'meta' 模块,但没有提供初始化函数
elif is_meta_module:
# 使用默认方式在目标设备上物化模块
_materialize_meta_module(
fully_sharded_module,
device_id,
state._ignored_modules,
state._device_handle,
)
# 如果使用了 torchdistX 的延迟初始化
elif is_torchdistX_deferred_init:
# 使用 torchdistX 的 API 来物化模块
deferred_init.materialize_module(
fully_sharded_module,
# 检查函数确保我们不会重复物化已经被 FSDP 管理或被忽略的子模块
check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None
and submodule not in state._ignored_modules,
)

# 3. 将模块移动到目标设备
# 收集所有被忽略的模块中的缓冲区,这些缓冲区将不会被移动
ignored_buffers = {
buffer
for ignored_module in state._ignored_modules
for buffer in ignored_module.buffers()
}
# 将整个模块(包括其参数和缓冲区)移动到目标设备
_move_module_to_device(
fully_sharded_module,
state._ignored_params, # 忽略指定参数
ignored_buffers, # 忽略指定缓冲区
device_from_device_id,
)
# 确定计算设备(通常是 GPU),并更新 FSDP 状态
state.compute_device = _get_compute_device(
fully_sharded_module,
state._ignored_params,
device_from_device_id,
state.rank,
state._device_handle,
)

# 4. 参数同步
# 获取 FSDP 需要管理的所有原始参数
managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
# 验证这些参数是否符合 FSDP 的要求
_verify_managed_params(fully_sharded_module, managed_params)
# 如果设置了 `sync_module_states`,则在所有 rank 之间同步参数和缓冲区
# 这确保了在训练开始前,所有进程上的模型状态是完全一致的
if sync_module_states:
_sync_module_params_and_buffers(
fully_sharded_module, managed_params, state.process_group
)
# 对于混合分片策略,还需要在跨节点的进程组中进行同步
if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
_sync_module_params_and_buffers(
fully_sharded_module, managed_params, state._inter_node_pg
)

# 5. 创建 FlatParamHandle
# 这是最后一步,使用准备好的参数列表来实际创建和初始化 FlatParamHandle
_init_param_handle_from_params(state, managed_params, fully_sharded_module)

# 返回更新后的 FSDP 状态
return state

_get_orig_params

得到fully_sharded_module中所有的参数,也就是tensor矩阵。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def _get_orig_params(
module: nn.Module, # 要从中提取参数的模块
ignored_params: set[nn.Parameter], # 一个包含应被忽略的参数的集合
) -> Iterator[nn.Parameter]: # 返回一个参数的迭代器(生成器)
"""
返回一个遍历 `module` 中原始参数的迭代器。

这个迭代器不会返回以下几种参数:
1. 在 `ignored_params` 集合中的参数。
2. 任何 `FlatParameter` 实例(这可能因为嵌套使用 FSDP 而出现)。
3. 任何已经被展平的原始参数(这只在 `use_orig_params=True` 模式下有意义)。
"""
# 获取模块所有参数的生成器
param_gen = module.parameters()
try:
# 使用一个无限循环来手动迭代生成器
# 这样做是为了清晰地处理 StopIteration 异常
while True:
# 从生成器中获取下一个参数
param = next(param_gen)

# 这是核心的过滤逻辑:
# 1. `param not in ignored_params`:确保该参数不是用户明确指定要忽略的参数。
# 2. `not _is_fsdp_flattened(param)`:检查该参数是否已经被 FSDP 处理过。
# `_is_fsdp_flattened` 会检查参数是否是 `FlatParameter` 的实例,
# 或者是否已经被合并到另一个 `FlatParameter` 中。这对于处理嵌套 FSDP
# 至关重要,可以防止重复包装和管理同一个参数。
if param not in ignored_params and not _is_fsdp_flattened(param):
# 如果参数通过了所有检查,就将其 yield 出去
yield param
except StopIteration:
# 当 `module.parameters()` 迭代完成并抛出 StopIteration 异常时,
# 捕获它并正常退出循环。`pass` 表示什么也不做。
pass

_init_param_handle_from_params

注意初始化了FlatParamHandle后立刻进行了shard

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@no_type_check
def _init_param_handle_from_params(
state: _FSDPState, # FSDP 状态对象,用于跟踪和管理 FSDP 实例的各种状态
params: list[nn.Parameter], # 从模块中收集到的、需要被 FSDP 管理的原始参数列表
fully_sharded_module: nn.Module, # 这些参数所属的、需要被 FSDP 完全分片的模块
):
# 如果没有需要管理的参数,则直接返回,不做任何操作
if len(params) == 0:
return

# 1. 实例化 FlatParamHandle
# FlatParamHandle 是 FSDP 的核心,它负责将 `params` 列表中的多个参数
# “展平”(flatten)并合并成一个单一的、连续的张量(FlatParameter)。
# 这里传入了所有必要的配置,如分片策略、混合精度设置、进程组等。
handle = FlatParamHandle(
params, # 原始参数列表
fully_sharded_module, # 所属模块
state.compute_device, # 计算设备 (例如, 'cuda:0')
SHARDING_STRATEGY_MAP[state.sharding_strategy], # 分片策略
state.cpu_offload.offload_params, # 是否启用 CPU offload
state.mixed_precision.param_dtype, # 参数的数据类型 (例如, torch.float16)
state.mixed_precision.reduce_dtype, # all-reduce 操作的数据类型
state.mixed_precision.keep_low_precision_grads, # 是否保留低精度梯度
state.process_group, # 分布式通信的进程组
state._use_orig_params, # 是否使用原始参数的视图(一种优化)
fsdp_extension=state._fsdp_extension, # FSDP 扩展
)

# 2. 对 FlatParameter 进行分片
# 调用 .shard() 方法,根据指定的分片策略,将完整的 FlatParameter 分割成
# 多个分片,每个 rank 只保留自己负责的那一部分。这是实现显存优化的关键。
handle.shard()

# 3. 更新 FSDP 状态
# 确保当前 FSDP 实例还没有关联任何 handle
assert not state._handle
# 将新创建的 FlatParameter 添加到 FSDP 实例的参数列表中,以便优化器可以找到它
state.params.append(handle.flat_param)
# 将新创建的 handle 保存到 FSDP 状态中
state._handle = handle
# 建立从模块到其对应 handle 的映射关系
state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle

# 4. 处理 CPU Offload
# 如果启用了 CPU offload,并且分片后的 FlatParameter 当前不在 CPU 上
cpu_device = torch.device("cpu")
if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device:
# 将该分片移动到 CPU,以释放 GPU 显存
handle.flat_param_to(cpu_device)

FlatParamHandle

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class FlatParamHandle:
"""
一个管理扁平化参数(:class:`FlatParameter`)的句柄。

这包括分片和视图管理。

参数:
params (Sequence[nn.Parameter]): 要被展平到扁平化参数中的参数序列。
fully_sharded_module (nn.Module): 被 FSDP 包装的模块。
device (torch.device): 计算和通信设备,通常是 GPU。
sharding_strategy (ShardingStrategy): 应用于此句柄的 `FlatParameter` 的分片策略。
offload_params (bool): 是否将此句柄的 `FlatParameter` 卸载到 CPU。
mp_param_dtype (Optional[torch.dtype]): 用于参数的混合精度类型。
mp_reduce_dtype (Optional[torch.dtype]): 用于梯度归约的混合精度类型。
keep_low_precision_grads (bool): 是否保持低精度的梯度。
use_orig_params (bool): 如果为 True,FSDP 会保留原始参数变量,并从 `named_parameters()` 返回它们。
这允许在同一个 `FlatParameter` 内对不同原始参数使用不同的优化器超参数。
如果为 False,FSDP 会在每次迭代中重新构建参数。
"""

##################
# INITIALIZATION #
##################
def __init__(
self,
params: Sequence[Union[nn.Parameter, Tensor]],
fully_sharded_module: nn.Module,
device: torch.device,
sharding_strategy: HandleShardingStrategy,
offload_params: bool,
mp_param_dtype: Optional[torch.dtype],
mp_reduce_dtype: Optional[torch.dtype],
keep_low_precision_grads: bool,
process_group: dist.ProcessGroup,
use_orig_params: bool,
*,
fsdp_extension: Optional[FSDPExtensions] = None,
):
super().__init__()
# 确保传入的参数列表不为空
params = list(params)
if len(params) == 0:
raise ValueError(
f"不能用一个空的参数列表来构造 {self.__class__.__name__}"
)

# 初始化一些内部函数
self._init_setattr_fns()

# 从环境变量中读取一些用于调试和性能分析的高级配置
self._skip_writeback_check = (
os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1"
)
self._use_full_prec_in_eval = (
os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
)
# 这些 "fake" 选项用于性能分析,它们会跳过实际的通信操作
self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1"
self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1"

# ... (处理上述环境变量的警告信息) ...

# 是否对齐内存地址,目前仅在 use_orig_params=True 时启用
align_addresses = use_orig_params
self._init_get_unflat_views_fn(align_addresses)

# --- 初始化核心属性 ---
self.device = device # 计算设备 (e.g., 'cuda:0')
self._device_handle = _FSDPDeviceHandle.from_device(self.device)
self.process_group = process_group # 分布式通信组
self.rank = process_group.rank() # 当前进程的排名
self.world_size = process_group.size() # 总进程数

# --- 存储 FSDP 的主要配置 ---
self._sharding_strategy = sharding_strategy # 分片策略 (e.g., SHARD_GRAD_OP)
self._offload_params = offload_params # 是否卸载到 CPU
self._use_orig_params = use_orig_params # 是否使用原始参数
self._keep_low_precision_grads = keep_low_precision_grads # 是否保留低精度梯度

# --- 初始化状态变量 ---
self._training_state = HandleTrainingState.IDLE # 初始状态为空闲
self._debug_level = dist.get_debug_level()
self._fully_sharded_module = fully_sharded_module # 关联的模块

# ... (初始化一些用于 prefetch 和执行顺序跟踪的内部状态变量) ...
self._handle_index: Optional[int] = None
self._needs_pre_forward_unshard = False
self._needs_pre_backward_unshard = False
self._prefetched = False

# --- 初始化数据类型 (dtype) ---
# 原始参数的数据类型
self._orig_param_dtype = params[0].dtype
# 初始化用于前向/后向传播和梯度计算的混合精度数据类型
self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
assert self._fwd_bwd_param_dtype is not None # mypy

# 计算对齐所需的元素数量
self._aligned_numel = (
_get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
if align_addresses
else 0
)
self._fsdp_extension = fsdp_extension

# --- 最关键的步骤:创建扁平化参数和元数据 ---
# 这个方法会执行以下操作:
# 1. 计算所有参数的总元素数量。
# 2. 创建一个大的、一维的 `FlatParameter` 来容纳所有参数。
# 3. 将原始参数的数据复制到这个 `FlatParameter` 中。
# 4. 记录每个原始参数在 `FlatParameter` 中的位置、形状等元数据。
self._init_flat_param_and_metadata(
params,
fully_sharded_module,
self._aligned_numel,
use_orig_params, # type: ignore[arg-type]
)

# --- 最后一步:设置参数视图 ---
# 让原始模块的参数成为 `FlatParameter` 的“视图”(view)。
# 这意味着对原始参数的任何修改都会反映在 `FlatParameter` 上,反之亦然。
self._use_unsharded_views(as_params=False)

_init_flat_param_and_metadata

这个方法是 FSDP 魔法的起点。它像一个高效的管家,将一堆零散的参数( params )整齐地排列、打包,并贴上详细的标签(元数据),最终形成一个易于管理的单一实体( FlatParameter )。这个过程不仅处理了复杂的共享参数和内存对齐问题,还为后续的分布式操作(如 reduce-scatter )做好了准备。一旦这个方法执行完毕, FlatParamHandle 就拥有了一个完整的、随时可以被分片和恢复的扁平化参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def _init_flat_param_and_metadata(
self,
params: list[Union[Tensor, nn.Parameter]],
module: nn.Module,
aligned_numel: int,
use_orig_params: bool,
) -> None:
"""
初始化 ``FlatParameter`` 及其元数据。

注意:此方法只应在构造时调用一次,之后 ``FlatParameter`` 的元数据被假定为静态的。
"""
# --- 1. 输入验证 ---
if len(params) == 0:
raise ValueError("期望非空的 `params`")
if aligned_numel < 0:
raise ValueError(
f"期望非负的 `aligned_numel` 但得到了 {aligned_numel}"
)
# 验证所有待展平的张量具有相同的 dtype、requires_grad 和 device
(
dtype,
flat_param_requires_grad,
device,
) = self._validate_tensors_to_flatten(params)
params_set = set(params) # 转换为集合以提高查找效率

# --- 2. 初始化用于存储元数据的列表 ---
param_infos: list[ParamInfo] = [] # 参数信息 (名称, 所属模块, 模块名)
numels: list[int] = [] # 每个参数的元素数量
shapes: list[torch.Size] = [] # 每个参数的原始形状
strides: list[tuple[int, ...]] = [] # 每个参数的原始步长
fqns: list[str] = [] # 每个参数的完全限定名 (e.g., 'layer1.0.conv.weight')
shared_param_infos: list[SharedParamInfo] = [] # 共享参数的信息
# 用于跟踪已处理参数,以识别共享参数
shared_param_memo: dict[
Union[Tensor, nn.Parameter], tuple[nn.Module, str, str]
] = {}
params_to_flatten: list[Union[Tensor, nn.Parameter]] = [] # 最终要展平的张量列表(包括填充)
shared_params: list[Union[Tensor, nn.Parameter]] = [] # 识别出的共享参数列表
is_padding_mask: list[bool] = [] # 标记 `params_to_flatten` 中哪些是填充
total_numel = total_numel_without_padding = 0 # 计数器

# --- 3. 遍历模块,收集参数和元数据 ---
# 遍历模块的所有子模块和参数,以确保参数的顺序是确定的
for submodule_name, submodule in module.named_modules(remove_duplicate=False):
for param_name, param in _named_parameters_with_duplicates(
submodule, recurse=False
):
if param not in params_set:
continue # 只处理在输入 `params` 列表中的参数

# 如果参数已经在 memo 中,说明它是共享参数
if param in shared_param_memo:
# ... 记录共享参数信息 ...
shared_params.append(param)
# ...
else: # 这是一个新的、未见过的参数
# --- 3a. 处理内存对齐填充 ---
if aligned_numel > 0:
numel_to_pad = aligned_numel - (total_numel % aligned_numel)
if numel_to_pad > 0 and numel_to_pad < aligned_numel:
# 如果需要,插入一个填充张量
padding_tensor = _construct_padding_tensor(
numel_to_pad, dtype, False, device
)
params_to_flatten.append(padding_tensor)
is_padding_mask.append(True)
numels.append(numel_to_pad)
total_numel += numel_to_pad

# --- 3b. 记录主参数的元数据 ---
shared_param_memo[param] = (submodule, submodule_name, param_name)
params_to_flatten.append(param)
is_padding_mask.append(False)
param_infos.append(ParamInfo(param_name, submodule, submodule_name))
numels.append(param.numel())
shapes.append(param.shape)
strides.append(param.stride())
# ... 记录其他元数据 ...
total_numel += param.numel()
total_numel_without_padding += param.numel()

# --- 4. 处理 reduce-scatter 的填充 ---
# 为了让 reduce-scatter 操作更高效,需要确保总元素数能被 world_size 整除
if aligned_numel > 0:
numel_to_pad = self.world_size - (total_numel % self.world_size)
if numel_to_pad > 0 and numel_to_pad < self.world_size:
# ... 如果需要,再次插入填充张量 ...
padding_tensor = _construct_padding_tensor(
numel_to_pad, dtype, False, device
)
params_to_flatten.append(padding_tensor)
is_padding_mask.append(True)
numels.append(numel_to_pad)
total_numel += numel_to_pad

# --- 5. 执行展平操作 ---
# 调用 `flatten_tensors_into_flat_param` 将 `params_to_flatten` 列表中的所有张量
# 合并成一个大的、一维的 `FlatParameter`
self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
params_to_flatten,
aligned_numel=0, # 此时已手动处理完对齐,故传 0
requires_grad=flat_param_requires_grad,
)

# --- 6. 将元数据附加到 FlatParameter 上 ---
# 调用 `FlatParameter` 的静态方法,将之前收集的所有元数据(形状、步长、名称等)
# 作为属性附加到新创建的 `self.flat_param` 对象上。
FlatParameter._init_metadata(
self.flat_param,
param_infos,
numels,
shapes,
strides,
# ... 传递所有其他元数据列表 ...
)
flatten_tensors_into_flat_param

最后的结果就是进行了张量展开,获得了一个扁平的张量,形状为:[参数数量,参数长度]。即每个参数param都变成了一维,最后各个param都拼接在了一起。所以这些参数最后在物理地址上都是连续的,方便操作。

1
2
3
4
5
6
7
8
def flatten_tensors_into_flat_param(
self,
tensors: list[Tensor],
aligned_numel: int,
requires_grad: bool,
) -> FlatParameter:
flat_param_data = self.flatten_tensors(tensors, aligned_numel)
return FlatParameter(flat_param_data, requires_grad=requires_grad)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def flatten_tensors(
self,
tensors: list[Tensor], # 输入:一个待展平的张量列表
aligned_numel: int, # 输入:用于内存对齐的元素数量。如果为0,则不进行对齐填充
) -> Tensor:
"""
将 `tensors` 展平为单个扁平张量。

如果 `aligned_numel` 大于 0,展平过程会包含可选的填充,
其中 `aligned_numel` 给出了实现地址对齐所需的元素数量。

注意:填充对齐算法必须与 `_init_flat_param_metadata` 方法保持同步。
我们分离这两个方法是因为初始化只发生一次,而此方法可能在训练过程中
被多次调用(例如,用于保存检查点)。
"""
# --- 1. 输入校验 ---
if len(tensors) == 0:
raise ValueError("期望 `tensors` 列表不为空")
if aligned_numel < 0:
raise ValueError(
f"期望 `aligned_numel` 为非负数,但得到了 {aligned_numel}"
)
# 校验所有输入张量的数据类型(dtype)和设备(device)是否一致
dtype, _, device = self._validate_tensors_to_flatten(tensors)

flat_tensors: list[Tensor] = [] # 用于存储最终要拼接的张量(包括填充)

# --- 2. 处理对齐填充 (如果需要) ---
if aligned_numel > 0:
total_numel = 0 # 记录当前已处理的元素总数
for tensor in tensors:
# 计算在添加当前张量之前需要多少填充,以使其起始位置按 `aligned_numel` 对齐
numel_to_pad = aligned_numel - (total_numel % aligned_numel)
if numel_to_pad > 0 and numel_to_pad < aligned_numel:
# 创建一个指定大小的填充张量
padding_tensor = _construct_padding_tensor(
numel_to_pad, dtype, False, device
)
flat_tensors.append(padding_tensor)
total_numel += numel_to_pad

# 添加实际的张量。首先将其展平为一维
# 如果张量是内存连续的,直接 flatten;否则使用 as_strided 避免额外拷贝
flat_tensors.append(
torch.flatten(_detach_if_needed(tensor))
if _is_truly_contiguous(tensor)
else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,))
)
total_numel += tensor.numel()

# --- 3. 处理分片填充 (为了能被 world_size 整除) ---
numel_to_pad = self.world_size - (total_numel % self.world_size)
if numel_to_pad > 0 and numel_to_pad < self.world_size:
padding_tensor = _construct_padding_tensor(
numel_to_pad, dtype, False, device
)
flat_tensors.append(padding_tensor)
total_numel += numel_to_pad
else:
# --- 4. 无对齐填充的简单情况 ---
# 直接将每个张量展平并放入列表
flat_tensors = [
torch.flatten(_detach_if_needed(tensor))
if _is_truly_contiguous(tensor)
else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,))
for tensor in tensors
]

# --- 5. 最终拼接 ---
# 将列表中的所有张量(包括原始张量和所有填充张量)拼接成一个最终的扁平化张量
return torch.cat(flat_tensors, dim=0)

fsdp模型forward

最外部的代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
handle = self._handle
with torch.autograd.profiler.record_function(
"FullyShardedDataParallel.forward"
):
args, kwargs = _root_pre_forward(self, self, args, kwargs)
unused = None
args, kwargs = _pre_forward(
self,
handle,
_pre_forward_unshard,
self._fsdp_wrapped_module,
args,
kwargs,
)
if handle:
_p_assert(
handle.flat_param.device == self.compute_device,
"Expected `FlatParameter` to be on the compute device "
f"{self.compute_device} but got {handle.flat_param.device}",
)
output = self._fsdp_wrapped_module(*args, **kwargs)
return _post_forward(
self, handle, _post_forward_reshard, self, unused, output
)

_pre_forward

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
@no_type_check
def _pre_forward(
state: _FSDPState,
handle: Optional[FlatParamHandle],
unshard_fn: Callable,
module: nn.Module,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""
执行前向传播前的逻辑。这包括:
1. 对当前分片的参数进行反分片(unshard),使其恢复为完整参数。
2. 为这些参数注册后向传播钩子(post-backward hooks)。
3. 将前向传播的输入(args, kwargs)转换为指定的计算精度。
"""
# 使用 PyTorch profiler 记录函数执行,便于性能分析
with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"):
# 这是一个针对梯度检查点(gradient checkpointing)的特殊处理。
# 在梯度检查点的重计算阶段,模块会再次执行前向传播,但此时参数已经 unshard 过了,
# 无需重复执行 unshard 和注册 hook 等操作,直接返回即可。
if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
return args, kwargs

# 1. 更新 FSDP 状态,标记当前正处于前向或后向传播阶段。
state.training_state = TrainingState.FORWARD_BACKWARD
# 记录当前模块的执行顺序,这对于后续的预取(prefetching)和梯度同步至关重要。
state._exec_order_data.record_pre_forward(handle, module.training)
if handle:
# 更新当前参数句柄(handle)的状态为“正在前向传播”。
handle._training_state = HandleTrainingState.FORWARD

# 2. 执行核心操作:反分片(Unsharding)。
# 这是最关键的一步。如果 unshard_fn 存在,就调用它。
# 这个函数内部会触发 all-gather 操作,从所有 GPU 上收集参数分片,
# 在当前设备上重建完整的、未分片的参数,以供模块的 forward 方法使用。
if unshard_fn is not None:
unshard_fn(state, handle)

# 3. 注册后向传播钩子(Post-Backward Hook)。
# 这个钩子会在反向传播计算完当前参数的梯度之后被触发。
# 它的主要作用是:
# a. 将参数重新分片(reshard),释放完整参数占用的内存。
# b. 对计算出的完整梯度进行 reduce-scatter 操作,完成梯度同步。
# 因为计算图(grad_fn)每次都可能变化,所以这个钩子需要在每次前向传播时都重新注册。
_register_post_backward_hook(state, handle)

# 针对 CPU Offload 的特殊处理:如果优化器在反向传播中将 CPU 上的梯度清空了,
# 这里需要重新分配一块内存空间给它,为下一次梯度累积做准备。
if handle and handle._offload_params and handle.flat_param._cpu_grad is None:
handle.flat_param._cpu_grad = torch.zeros_like(
handle.flat_param._local_shard, device=torch.device("cpu")
).pin_memory()

# 检查是否需要将模型输入转换为低精度(如 fp16)。
should_cast_forward_inputs = (
state._handle and not state._handle._force_full_precision
)

# 4. 如果启用了混合精度(Mixed Precision)并且设置了 cast_forward_inputs,
# 就将 `args` 和 `kwargs` 中的张量递归地转换为指定的参数数据类型(param_dtype)。
if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs:
input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)

# 注册一个只做 reshard 的后向钩子,用于不需要计算梯度的场景。
_register_post_backward_reshard_only_hook(state, handle, args, kwargs)

# 返回处理过(可能已转换精度)的输入,传递给原始模块的 forward 方法。
return args, kwargs

_pre_forward_unshard

对于_pre_forward,可以看到其传入的是_pre_forward_unshard函数,该函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def _pre_forward_unshard(
state: _FSDPState,
handle: Optional[FlatParamHandle],
) -> None:
"""Unshards parameters in the pre-forward."""
if not handle:
return
# If the handles have been prefetched, then there is no need to call
# `_unshard()` again
if not handle._prefetched:
_unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
handle._needs_pre_forward_unshard = False
# Don't wait during trace
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
current_stream = state._device_handle.current_stream()
if state._unshard_event is not None:
current_stream.wait_event(state._unshard_event)
state._unshard_event = None
else:
current_stream.wait_stream(state._unshard_stream)
with torch.profiler.record_function(
"FullyShardedDataParallel._pre_forward_prefetch"
):
_prefetch_handle(state, handle, _PrefetchMode.FORWARD)

@no_type_check
def _unshard(
state: _FSDPState,
handle: FlatParamHandle,
unshard_stream: torch.Stream,
pre_unshard_stream: torch.Stream,
) -> None:
"""
Unshards the handles in ``handles``. If the handles are in
:meth:`summon_full_params` and are using mixed precision, then they are
forced to full precision.

Postcondition: handle's ``FlatParameter`` 's data is the padded
unsharded flat parameter on the compute device.
"""
if not handle:
return
with state._device_handle.stream(pre_unshard_stream):
ran_pre_unshard = handle.pre_unshard()
if ran_pre_unshard:
unshard_stream.wait_stream(pre_unshard_stream)
if state.limit_all_gathers:
event = state._free_event_queue.dequeue_if_needed()
if event:
with torch.profiler.record_function(
"FullyShardedDataParallel.rate_limiter"
):
event.synchronize()
with state._device_handle.stream(unshard_stream):
handle.unshard()
handle.post_unshard()

其中最关键的是unshard函数,该函数用来将参数重新收集回来:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// ... existing code ...
def unshard(self):
"""
Run the unshard logic.

This includes all-gathering the flat parameter
and switching to using the unsharded flat parameter. If the handle does
not need unsharding, then this only switches to using the unsharded
flat parameter. For ``NO_SHARD``, this is a no-op.

If FSDP is in :meth:`summon_full_params` and the handle uses parameter
mixed precision, then the parameter is forced to full precision.
"""
# 1. 检查是否真的需要执行 unshard 操作。
if not self.needs_unshard():
# 如果不需要(例如,sharding 策略是 NO_SHARD,或者参数已经被 unshard),
# 也要确保后续计算使用的是 unsharded 参数的视图。
unsharded_flat_param = (
self._get_padded_unsharded_flat_param()
if self.uses_sharded_strategy
else self.flat_param
)
self._use_unsharded_flat_param(unsharded_flat_param)
return

# 2. 如果需要 unshard,则执行以下核心步骤:
# 2a. 分配内存:为即将聚合的完整参数张量分配空间。
unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
# 2b. 执行 All-Gather:调用我们之前分析过的 `_all_gather_flat_param` 方法,
# 从所有进程收集参数分片,并填充到刚刚分配的内存中。
padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
# 2c. 切换状态:将模块内部的参数指针切换为指向这个刚刚聚合好的完整参数,
# 以便后续的前向或后向计算可以使用它。
self._use_unsharded_flat_param(padded_unsharded_flat_param)

_alloc_padded_unsharded_flat_param

该函数负责获取收集全参数的变量,并给他分配足够的内存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// ... existing code ...
def _alloc_padded_unsharded_flat_param(self):
"""
Allocate the *padded* unsharded flat parameter.

The unpadded unsharded
flat parameter is always a view into the padded one. This padded
parameter is saved to a different attribute on the ``FlatParameter``
depending on if we force full precision.
"""
# 1. 检查:确保当前正在使用分片策略。
self._check_sharded_strategy()
flat_param = self.flat_param

# 2. 获取目标张量:获取将要用于存储完整参数的那个张量对象。
# 此时它可能还只是一个没有分配实际存储空间的“空壳”。
unsharded_flat_param = self._get_padded_unsharded_flat_param()

# 3. 检查存储:确保这个张量之前的存储已经被释放,防止内存泄漏。
self._check_storage_freed(unsharded_flat_param)

# 4. 分配存储:这是核心操作。调用 `_alloc_storage` 为这个张量分配实际的内存空间。
# 分配的大小是 `_padded_unsharded_size`,即所有分片聚合后的总大小,可能还包含一些为了对齐而增加的 padding。
_alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined]

# 5. 返回张量:返回这个已经分配好内存、准备好被填充的张量。
return unsharded_flat_param
_get_padded_unsharded_flat_param

注意这里判断了是否是强制使用全精度,如果是,就会返回flat_param._full_prec_full_param_padded,并且释放掉flat_param._full_param_padded,如果不是,就直接返回flat_param._full_param_padded

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// ... existing code ...
def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
"""
Return a reference to the padded unsharded flat parameter depending on the calling context.

This should only be called if using a sharded strategy.
"""
self._check_sharded_strategy()
flat_param = self.flat_param
# 关键的逻辑判断:是否需要强制使用全精度?
if self._force_full_precision and self._uses_param_mixed_precision:
# --- 情况1: 需要强制全精度 ---
# 当 FSDP 进入一个需要全精度参数的上下文(例如 `summon_full_params`),
# 并且当前参数本身是使用混合精度(如 bfloat16)存储的。

# 1. 选择一个专门用于存储全精度(如 float32)参数的张量作为目标。
unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined]
_p_assert(
// ... existing code ...
)
# 2. 释放可能存在的、旧的、低精度的完整参数的存储。
# 因为全精度版本接下来可能会被修改,这会导致低精度版本失效。
# 释放它是为了确保下次需要时,会重新执行 all-gather 获取最新的数据。
if flat_param._full_param_padded.untyped_storage().size() > 0:
_free_storage(flat_param._full_param_padded)
else:
# --- 情况2: 标准情况 ---
# 在常规的前向/后向传播中,直接使用默认的、与参数计算类型一致的张量即可。
# 这个张量的数据类型通常是低精度(如 bfloat16)。
unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined]
return unsharded_flat_param
_alloc_storage

该函数调用了底层存储对象的 resize 方法,将其大小调整为所需的元素数量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
"""
Allocate storage for ``tensor`` with the given size.

Returns:
bool: ``True`` if this method allocated storage and ``False`` if the
storage was already allocated.
"""
with torch.no_grad():
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
already_allocated = tensor._typed_storage()._size() == size.numel()
if not already_allocated:
tensor_storage_size = tensor._typed_storage()._size()
_p_assert(
tensor_storage_size == 0,
"Tensor storage should have been resized to be 0 but got PLACEHOLDEr",
)
tensor._typed_storage()._resize_(size.numel())

_all_gather_flat_param

这个函数是通过all_gather操作来进行实际的参数收集:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// ... existing code ...
def _all_gather_flat_param(
self,
padded_unsharded_flat_param: Tensor,
) -> Tensor:
"""
All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.

Then switch to use the all-gathered tensor.
"""
# 1. 断言检查:确保分布式环境已初始化,并且目标张量的大小足以容纳所有分片。
_p_assert(
// ... existing code ...
)

# 2. 获取用于通信的进程组(Process Group)。
pg = (
self._fake_process_group
if self._use_fake_all_gather
else self.process_group
)

# 3. 根据张量是在 CPU 还是 GPU 上,执行不同的 all-gather 操作。
# HACK this should be handled by C10D
if sharded_flat_param.is_cpu: # type: ignore[attr-defined]
# 对于 CPU,将数据收集到一个 tensor 列表中。
tensor_list = list(
torch.chunk(
padded_unsharded_flat_param,
dist.get_world_size(pg), # type: ignore[arg-type]
)
)
dist.all_gather(tensor_list, sharded_flat_param, group=pg)
else:
# 对于 GPU,使用更高效的 all_gather_into_tensor 直接填充目标张量。
dist.all_gather_into_tensor(
padded_unsharded_flat_param,
sharded_flat_param,
pg,
)

# 4. 处理参数卸载(Offloading)的特殊情况。
if self._offload_params:
# 如果参数被卸载到 CPU,需要确保 CUDA stream 正确同步,防止数据竞争。
_no_dispatch_record_stream(
sharded_flat_param,
self._device_handle.current_stream(), # unshard_stream
)
# 5. 返回填充了完整参数的张量。
return padded_unsharded_flat_param
all_gather_into_tensor学习

注意到这里对于GPU使用到了dist.all_gather_into_tensor操作。这个操作的示意图如下,即将各个GPU上的分片按序收集给各个GPU上,使得每个GPU都有一个整体:

这里有一个示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.distributed as dist
import os

def run():
dist.init_process_group(backend="nccl") # or "gloo" for CPU
rank = dist.get_rank()
world_size = dist.get_world_size()

# 每个进程构造自己的 input tensor
input_tensor = torch.ones(2, device='cuda') * (rank + 1)

# 所有数据拼接的输出 tensor
output_tensor = torch.empty(2 * world_size, device='cuda')

# All-gather into output tensor
dist.all_gather_into_tensor(output_tensor, input_tensor)

print(f"[rank {rank}] output_tensor: {output_tensor.cpu().tolist()}")

if __name__ == "__main__":
torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) # torchrun 自动设置
run()

# 2机,每机2卡的运行指令
# 机器1: torchrun --nproc_per_node=2 --nnodes=2 --node_rank=0 --master_addr=fdbd:dc03:16:266::86 --master_port=12345 test.py
# 机器2: torchrun --nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=fdbd:dc03:16:266::86 --master_port=12345 test.py

# 运行结果
# [rank 0] output_tensor: [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
# [rank 1] output_tensor: [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
# [rank 2] output_tensor: [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
# [rank 3] output_tensor: [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]

all_gather学习

此外注意到这里这里对CPU使用了all_gather,它是 PyTorch 最经典的分布式通信原语之一。

它负责把每个进程上的 tensor 收集起来,按 rank 顺序填入 tensor_list 中。

  • tensor_list[i] 就是第 i 个进程的 tensor。

一个简单的示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

import torch
import torch.distributed as dist
import os

def run():
dist.init_process_group(backend="gloo") # "nccl" for GPU
rank = dist.get_rank()
world_size = dist.get_world_size()

# 当前进程持有的 tensor,标记自己的 rank
input_tensor = torch.full((2,), rank, dtype=torch.int)

# tensor_list 是一个 list,会存放所有进程的 tensor
tensor_list = [torch.empty_like(input_tensor) for _ in range(world_size)]

# 执行 all_gather:每个进程收集所有人的 tensor
dist.all_gather(tensor_list, input_tensor)

print(f"[rank {rank}] tensor_list = {[t.tolist() for t in tensor_list]}")

if __name__ == "__main__":
run()

# 单机2卡的运行指令
# torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr=fdbd:dc03:16:266::86 --master_port=12346 test.py

# 运行结果
# [rank 0] tensor_list = [[0, 0], [1, 1]]
# [rank 1] tensor_list = [[0, 0], [1, 1]]
record_stream学习

这里对于off_load模型,即将参数卸载到CPU上的操作,会使用 record_stream(stream) 注册 stream 依赖,从而告诉系统这个 tensor 来自 CPU,是通过 stream X 拷贝到 GPU 的,请不要在这个 stream 执行完之前把它删掉。

1
2
3
4
5
6
7
8
if self._offload_params:
# In case of offloading, `flat_param.data` (i.e. sharded param) is
# created on the pre-unshard stream. We need to hand it over to the
# unshard stream for all-gather
_no_dispatch_record_stream(
sharded_flat_param,
self._device_handle.current_stream(), # unshard_stream
)

_register_post_backward_hook

主要用于在梯度计算完后对参数进行重新分片

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def _register_post_backward_hook(
state: _FSDPState,
handle: Optional[FlatParamHandle],
) -> None:
"""
在 FlatParameter 的 AccumulateGrad 对象上注册一个后向钩子(post-backward hook),
用于在梯度计算完成后执行梯度的 reduce-scatter 操作以及参数的重新分片(reshard)。

AccumulateGrad 对象是完成 FlatParameter 梯度计算的最后一个函数,
因此钩子能确保在参数的整个梯度计算完成后才运行。

我们只在 FlatParameter 参与的 *第一个* 前向传播中注册一次钩子。
这依赖于 AccumulateGrad 对象在多次前向传播中被保留的特性。
"""
# 如果不需要计算梯度,则无需注册后向钩子。
if not torch.is_grad_enabled():
return
if not handle:
return
flat_param = handle.flat_param

# 根据是否在 TorchDynamo 编译模式下,选择不同的钩子注册方式。
if torch.distributed._functional_collectives.is_torchdynamo_compiling():
# 检查钩子是否已注册,或者参数是否需要梯度。
already_registered = hasattr(flat_param, "_post_backward_hook_handle")
if already_registered or not flat_param.requires_grad:
return
# 使用 functools.partial 包装钩子函数,传入 FSDP 状态和句柄。
hook = functools.partial(_post_backward_hook, state, handle)
# 使用为编译模式设计的专用 API 注册钩子。
hook_handle = flat_param.register_post_accumulate_grad_hook(hook)
# 保存钩子句柄,用于状态检查和可能的卸载。
flat_param._post_backward_hook_handle = hook_handle # type: ignore[attr-defined]
else: # Eager mode (常规执行模式)
# 检查钩子是否已注册。
already_registered = hasattr(flat_param, "_post_backward_hook_state")
if already_registered or not flat_param.requires_grad:
return

# --- 获取 AccumulateGrad 对象 --- #
# 创建一个临时的、与 flat_param 相同大小的张量,这会创建一个简单的计算图,
# 从而使我们能够访问其 grad_fn。
temp_flat_param = flat_param.expand_as(flat_param)
_p_assert(
temp_flat_param.grad_fn is not None,
"需要 grad_fn 来访问 AccumulateGrad 对象并注册后向钩子",
)
# AccumulateGrad 对象是与参数直接关联的梯度累积函数。
acc_grad = temp_flat_param.grad_fn.next_functions[0][0] # type: ignore[union-attr]
assert acc_grad is not None

# --- 注册钩子 --- #
# 在 AccumulateGrad 对象上注册钩子,确保在梯度累积完成后执行。
hook_handle = acc_grad.register_hook(
functools.partial(_post_backward_hook, state, handle)
)
# 保存 AccumulateGrad 对象和钩子句柄,用于状态检查和后续管理。
flat_param._post_backward_hook_state = (acc_grad, hook_handle) # type: ignore[attr-defined]

acc_grad.register_hook学习

为了在 反向传播过程中准确地知道哪些参数梯度已经计算完成,它会在 每个参数的 AccumulateGrad 节点上注册钩子(hook)

首先需要获取 grad_fn,构建 autograd 路径

  • flat_param 是一个 leaf tensor(即 requires_grad=True 且没有 grad_fn);

  • 使用 expand_as() 创建一个临时 view tensor,**这个 view 有 grad_fn**;

  • 这个 grad_fn 会链接到 AccumulateGrad 节点,而这个节点才允许注册 hook。

需要找到 AccumulateGrad

  • grad_fn.next_functions 是 PyTorch autograd 中的下游节点;

  • .next_functions[0][0] 正好是与参数 flat_param 直接绑定的 AccumulateGrad 节点。

1
2
3
4
5
  grad_fn (ExpandBackward0)

next_functions[0][0]

AccumulateGrad ← hook 就挂在这里!

然后我们需要注册钩子:

  • 使用 functools.partial() 固定住状态(state, handle);

  • 注册的 _post_backward_hook 会在梯度写入 .grad 前后触发;

  • 这是 FSDP 判断“这个参数的梯度已经完成了,可以 reshard / reduce / offload”的触发点。

然后记录下这个钩子对应的对象和句柄,方便:

  • 判断是否已注册(防止重复);

  • 后续移除 hook;

  • debug 或控制生命周期。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x * 2).sum()

print("y=",y)

temp_x = x.expand_as(x)
acc_grad = temp_x.grad_fn.next_functions[0][0]

def my_hook(grad_input, grad_output):
print(f"HOOK: grad_input={grad_input}, grad_output={grad_output}")
g = grad_output[0]
g = g * 0.5
print("Modified grad:", g)
# 不要 return!Node hook 不能返回任何值

acc_grad.register_hook(my_hook)

y.backward()
print("x.grad=",x.grad)

# 运行结果
# y= tensor(12., grad_fn=<SumBackward0>)
# HOOK: grad_input=(), grad_output=(tensor([2., 2., 2.]),)
# Modified grad: tensor([1., 1., 1.])
# x.grad= tensor([2., 2., 2.])

_post_backward_hook

这是 FSDP 的核心反向传播钩子,负责在本地梯度计算完成后,进行跨 GPU 的梯度同步(reduce-scatter)和参数重新分片(reshard)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@no_type_check
@torch.no_grad()
def _post_backward_hook(
state: _FSDPState,
handle: FlatParamHandle,
flat_param, # Note: this is a positional argument passed by the hook
*unused: Any,
):
"""
对 `handle` 的 `FlatParameter` 的梯度执行 Reduce-scatter 操作。

这是 FSDP 的核心反向传播钩子,负责在本地梯度计算完成后,
进行跨 GPU 的梯度同步(reduce-scatter)和参数重新分片(reshard)。

前置条件:
- `FlatParameter` 的 `.grad` 属性包含了本地批次(local batch)的完整(unsharded)梯度。

后置条件:
- 如果使用 `NO_SHARD` 策略,`.grad` 属性将是经过 all-reduce 后的完整梯度。
- 否则,`_saved_grad_shard` 属性将是经过 reduce-scatter 后的分片梯度(会与已有的梯度累加)。
"""
_log_post_backward_hook(state, handle, logger)
flat_param = handle.flat_param
# 标记该参数的后向钩子已被调用
flat_param._post_backward_called = True
with torch.autograd.profiler.record_function(
"FullyShardedDataParallel._post_backward_hook"
):
_assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
# 当对共享相同 `FlatParameter` 的子模块多次使用可重入的激活检查点(AC)时,
# 后向钩子可能会在一次反向传播中运行多次。在这种情况下,我们允许句柄的状态
# 已经是 `BACKWARD_POST`。
_p_assert(
handle._training_state
in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
)
handle._training_state = HandleTrainingState.BACKWARD_POST

# 如果没有梯度,直接返回
if flat_param.grad is None:
return
# FSDP 不支持对梯度本身再求梯度
if flat_param.grad.requires_grad:
raise RuntimeError("FSDP does not support gradients of gradients")

# 关键步骤1:在进行梯度通信之前,先尝试重新分片参数,以尽早释放内存
_post_backward_reshard(state, handle)

# 如果不进行梯度同步(例如在 `no_sync()` 上下文中),则直接返回
if not state._sync_gradients:
if handle._use_orig_params:
# 如果使用了原始(未合并的)参数,需要将梯度视图指向正确的 unsharded grad
handle._use_unsharded_grad_views()
return

# 关键步骤2:等待当前计算流中的所有操作(如梯度计算)完成,
# 然后再开始 reduce-scatter 梯度。这确保了我们拥有完整的本地梯度。
# TorchDynamo 编译模式下跳过此步。
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
state._post_backward_stream.wait_stream(
state._device_handle.current_stream()
)

# 在专用的后向流中执行梯度通信
with state._device_handle.stream(state._post_backward_stream):
autograd_computed_grad = flat_param.grad.data
# 如果开启了低精度训练,且梯度类型与通信类型不符,则进行类型转换以降低通信开销
if (
not _low_precision_hook_enabled(state)
and flat_param.grad.dtype != handle._reduce_dtype
# 如果强制全精度(例如在 eval 模式下),则不降低梯度精度
and not handle._force_full_precision
):
flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype)

# 根据分片策略执行梯度规约
if handle.uses_sharded_strategy:
_reduce_grad(state, handle) # Reduce-scatter
else:
_reduce_grad_no_shard(state, handle) # All-reduce

# 由于未分片的梯度是在计算流中产生的,但在后向流中消耗,
# 我们需要通知缓存分配器,以避免内存被过早回收。
_no_dispatch_record_stream(
autograd_computed_grad, state._post_backward_stream
)
_post_backward_reshard

在梯度计算好后我们可以提前将之前unshard的参数进行reshard,从而释放内存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def _post_backward_reshard(
state: _FSDPState,
handle: FlatParamHandle,
*unused: Any,
) -> None:
"""
在反向传播后执行参数的重新分片(reshard)和预取(prefetch)操作。

这个函数是后向钩子(post-backward hook)的核心逻辑之一,负责在梯度计算和
聚合之后,管理参数的内存状态,并为下一次迭代(的第一个前向传播)做准备。
"""
# 1. 决定在反向传播后是否应该释放当前 handle 的未分片(unsharded)参数内存
free_unsharded_flat_param = _should_free_in_backward(state, handle)

# 2. 执行重新分片操作。如果 `free_unsharded_flat_param` 为 True,则会释放内存
_reshard(state, handle, free_unsharded_flat_param)

# TODO: 当前的后向预取(Post-backward prefetching)不支持一个模块包含多个 handle 的情况,
# 因为后向钩子是按 handle 触发的,而不是按一组 handle 触发的。
with torch.profiler.record_function(
"FullyShardedDataParallel._post_backward_prefetch"
):
# 3. 为下一次迭代预取参数。这里的模式是 BACKWARD,意味着这个预取是在
# 反向传播阶段触发的,目的是为下一次迭代的第一个前向传播做准备,
# 从而实现计算和通信的重叠。
_prefetch_handle(state, handle, _PrefetchMode.BACKWARD)

此外为了计算和通行的重叠,会为了下一次迭代提前开启unshard。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
@no_type_check
def _prefetch_handle(
state: _FSDPState,
current_handle: Optional[FlatParamHandle],
prefetch_mode: _PrefetchMode,
) -> None:
"""
根据需要(异步地)预取下一个 handle 的参数。

这个函数是 FSDP 实现计算和通信重叠的关键。它会在当前 handle 计算的同时,
提前将下一个 handle 所需的参数从分片状态(sharded)通过 all-gather 恢复为
完整状态(unsharded)。
"""
if not current_handle:
return

# 1. 根据当前 handle 和预取模式(前向或后向),确定下一个需要预取的 handle
handle = _get_handle_to_prefetch(state, current_handle)
if not handle:
return

# 2. 临时模拟训练状态,以确保 `_unshard` 能够正确工作。
# 例如,在 `_unshard` 内部调用的 `_use_unsharded_views()` 需要根据正确的训练状态
# 来设置参数视图。
prev_training_state = handle._training_state
if prefetch_mode == _PrefetchMode.BACKWARD:
# 在后向钩子中预取,是为下一次前向传播做准备
handle._training_state = HandleTrainingState.BACKWARD_PRE
elif prefetch_mode == _PrefetchMode.FORWARD:
# 在前向钩子中预取,是为下一次前向传播做准备
handle._training_state = HandleTrainingState.FORWARD
else:
raise ValueError(f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}")

# 3. 异步地执行 unshard (all-gather) 操作,但不同步等待操作完成。
# 这使得 all-gather 通信可以与当前流中的计算(例如,前向/后向计算)重叠。
# 同步操作(`wait()`)会被推迟到真正需要使用该参数之前执行。
_unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)

# 4. 恢复 handle 原始的训练状态
handle._training_state = prev_training_state
# 5. 标记该 handle 的参数已经被预取
handle._prefetched = True

_register_post_backward_reshard_only_hook

对于那些不需要梯度计算的参数,注册一个梯度计算结束后进行重分片的勾子函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def _register_post_backward_reshard_only_hook(
state: _FSDPState,
handle: Optional[FlatParamHandle],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> None:
"""
为那些不需要计算梯度(requires_grad=False)的扁平化参数(FlatParameter)注册一个
仅用于重新分片(reshard)的后向钩子。

我们通过在模块的输入激活(input activations)上注册一个多重梯度钩子(multi-post-grad hook)
来做到这一点。这么做的原因是,对于 requires_grad=False 的参数,我们无法像之前一样
在其自身的 AccumulateGrad 对象上注册钩子(因为它不存在)。

通过在输入张量上挂钩,我们可以确保在所有可能依赖于该参数的梯度都计算完毕后,
才执行重新分片操作,从而安全地释放内存。
"""
# 如果当前上下文不计算梯度,则无需执行任何后向逻辑。
if not torch.is_grad_enabled():
return

# `inp_tensors` 会被懒加载,以避免在所有参数都计算梯度的常规情况下产生不必要的CPU开销。
inp_tensors: Optional[list[torch.Tensor]] = None
if not handle:
return
flat_param = handle.flat_param

# 检查钩子是否已经注册过。
if torch.distributed._functional_collectives.is_torchdynamo_compiling():
already_registered = hasattr(flat_param, "_post_backward_hook_handle")
else:
already_registered = hasattr(flat_param, "_post_backward_hook_state")

# 如果钩子已注册,或者参数需要梯度(此函数只处理不需要梯度的参数),则直接返回。
if already_registered or flat_param.requires_grad:
return

# --- 查找需要梯度的输入张量 --- #
# 这是此函数的关键逻辑:找到所有需要计算梯度的输入张量,并将钩子挂在它们上面。
if inp_tensors is None:
# 将所有输入参数扁平化为一个列表。
args_flat = pytree.arg_tree_leaves(*args, **kwargs)
# 筛选出其中是张量(Tensor)且需要梯度(requires_grad=True)的对象。
inp_tensors = [
obj for obj in args_flat if torch.is_tensor(obj) and obj.requires_grad
]
assert inp_tensors is not None # mypy

# --- 注册多重梯度钩子 --- #
# `register_multi_grad_hook` 会注册一个钩子,该钩子只有在 `inp_tensors` 列表
# 中所有张量的梯度都计算完毕后才会触发。
hook_handle = register_multi_grad_hook(
inp_tensors, functools.partial(_post_backward_reshard_only_hook, state, handle)
)

# 保存钩子句柄,以防止重复注册。
if torch.distributed._functional_collectives.is_torchdynamo_compiling():
flat_param._post_backward_hook_handle = hook_handle # type: ignore[attr-defined, assignment]
else:
flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined, assignment]

_post_backward_reshard_only_hook

这里是注册的勾子函数,该函数会在梯度计算完成后计算,其作用是对于不需要梯度计算的参数,也在反向传播完成后将参数进行分片。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def _post_backward_reshard_only_hook(
state: _FSDPState,
handle: FlatParamHandle,
*unused: Any,
) -> None:
"""
仅用于重新分片的后向钩子(post-backward hook)。

这个钩子专门为那些不需要梯度(`requires_grad=False`)的参数服务。
它的主要作用是在反向传播完成后,安全地将完整的(unsharded)参数重新分片(reshard),
从而释放内存。
"""
with torch.profiler.record_function(
"FullyShardedDataParallel._post_backward_hook_reshard_only"
):
# 如果前向传播的输出不需要梯度,`_pre_backward_hook` 可能不会被执行。
# 因此,这里需要显式地更新状态,以确保后续的后向预取(post-backward prefetching)逻辑能够正确运行。
state.training_state = TrainingState.FORWARD_BACKWARD
handle._training_state = HandleTrainingState.BACKWARD_POST
# 调用核心的重新分片逻辑
_post_backward_reshard(state, handle)


def _post_backward_reshard(
state: _FSDPState,
handle: FlatParamHandle,
*unused: Any,
) -> None:
"""
执行后向传播后的重新分片和预取操作。

这个函数是后向钩子的核心部分,负责在梯度计算和聚合之后管理参数内存和为下一次迭代做准备。
"""
# 决定在反向传播后是否应该释放未分片的扁平参数(flat_param)
free_unsharded_flat_param = _should_free_in_backward(state, handle)
# 执行重新分片操作,根据上面的标志决定是否释放内存
_reshard(state, handle, free_unsharded_flat_param)

# TODO: 当前的后向预取不支持一个模块有多个 handle 的情况,
# 因为后向钩子是按 handle 触发的,而不是按 handle 组触发的。
with torch.profiler.record_function(
"FullyShardedDataParallel._post_backward_prefetch"
):
# 为下一次迭代的(前向)传播预取下一个 handle 的参数
_prefetch_handle(state, handle, _PrefetchMode.BACKWARD)


@no_type_check
def _should_free_in_backward(
state: _FSDPState,
handle: FlatParamHandle,
) -> bool:
"""
决定 FSDP 是否应该在后向钩子中释放未分片的扁平参数。

返回:
bool: 如果应该释放则返回 True,否则返回 False。
"""
# 如果未使用分片策略,则不释放
if not handle.uses_sharded_strategy:
return False
# 如果不进行梯度同步(例如,在使用 `no_sync()` 上下文时),
# 并且参数的分片策略是在前向传播后不重新分片(reshard),
# 那么我们选择不释放参数。这是一种启发式策略,
# 目的是用较高的内存占用换取更高的吞吐量(因为避免了额外的 all-gather 操作)。
# 否则,如果需要同步梯度,或者策略本身就需要重新分片,则释放参数以节省内存。
return (
state._sync_gradients
or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
)

_post_forward

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
@no_type_check
def _post_forward(
state: _FSDPState,
handle: Optional[FlatParamHandle],
reshard_fn: Callable,
module: nn.Module,
input: Any,
output: Any,
) -> Any:
"""
运行前向传播后的逻辑。这包括一个机会来重新分片(reshard)当前未分片的参数
(例如在当前前向传播中使用的参数),并在前向传播的输出上注册 pre-backward 钩子。

功能:
- 这是 FSDP 前向传播钩子的核心实现,在每个 FSDP 包装的模块的 `forward` 方法之后执行。
- 主要负责在前向计算完成后,将不再需要的完整参数重新分片,以释放 GPU 内存。
- 同时,在输出张量上注册 pre-backward 钩子,以便在反向传播开始时,能够及时地将分片参数恢复为完整参数,用于梯度计算。

Args:
state (_FSDPState): FSDP 的全局状态。
handle (Optional[FlatParamHandle]): 当前前向传播中使用的参数句柄。
reshard_fn (Callable): 一个可调用对象,用于重新分片当前未分片的参数。如果为 `None`,则不执行任何重新分片操作。
module (nn.Module): 刚刚执行完 `forward` 的模块。
input (Any): 模块的输入(未使用,仅为满足钩子签名要求)。
output (Any): 前向传播的输出。Pre-backward 钩子会注册在该输出中需要梯度的张量上。

后置条件:
- 每个 `FlatParameter` 的 `data` 属性将指向分片后的扁平化参数,从而释放内存。
- 输出张量上已注册 pre-backward 钩子。

主要逻辑:
1. **处理激活检查点(Activation Checkpointing)**:如果与 `fully_shard` 和 `checkpoint` 一起使用,在重新计算的前向传播中会跳过此后向钩子逻辑,因为参数状态由激活检查点管理。
2. **记录执行顺序**:记录当前 handle 的前向传播完成事件,用于后续的乱序执行优化。
3. **重新分片 (Resharding)**:如果提供了 `reshard_fn`,则调用它来执行参数的重新分片,将完整的参数转换回分片状态,释放内存。
4. **注册 Pre-Backward 钩子**:调用 `_register_pre_backward_hooks`,遍历 `output` 中的张量,为那些需要梯度的张量注册一个钩子。这个钩子将在反向传播到达该张量时触发,执行参数的 unshard 操作(all-gather)。
5. **更新状态**:将 FSDP 实例和 handle 的训练状态更新为 `IDLE`,表示前向传播阶段已完成。
"""
with torch.profiler.record_function("FullyShardedDataParallel._post_forward"):
# 对于 `fully_shard` + `checkpoint`,在重新计算的前向传播中跳过 post-forward 逻辑
if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
return output

state._exec_order_data.record_post_forward(handle)
if reshard_fn is not None:
reshard_fn(state, handle)
# 注册 pre-backward 钩子,以便为梯度计算(如果需要)unshard 扁平化参数
output = _register_pre_backward_hooks(state, module, output, handle)
state.training_state = TrainingState.IDLE
if handle:
handle._training_state = HandleTrainingState.IDLE
return output

_post_forward_reshard

_post_forward中使用的reshard_fn就是_post_forward_reshard。

这是在前向传播后触发重新分片的入口函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
@no_type_check
def _post_forward_reshard(
state: _FSDPState,
handle: FlatParamHandle,
) -> None:
"""在前向传播后重新分片参数。"""
# 功能:
# - 作为前向传播钩子的一部分,决定是否以及如何重新分片(reshard)刚刚在前向计算中使用过的参数。
# - 重新分片的目的是及时释放未分片(unsharded)参数占用的 GPU 内存。
#
# 主要逻辑:
# 1. 检查 handle 是否存在,如果不存在则直接返回。
# 2. 决定是否要释放未分片的扁平化参数(`free_unsharded_flat_param`)。
# - 通常情况下,参数在使用后会被立即释放以节省内存。
# - 一个重要的例外是根(root)FSDP 模块。在 `FULL_SHARD` 策略下,根模块的参数在
# 前向传播后不会被立即释放,因为它们很可能马上就要用于反向传播的计算。
# 这是一种性能优化,避免了在前向后和反向前进行不必要的 `reshard` 和 `unshard` 操作。
# - `RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES` 包含了需要这种行为的分片策略。
# 3. 调用 `_reshard` 函数,传入计算出的 `free_unsharded_flat_param` 标志,执行实际的重新分片操作。
if not handle:
return
# 对于 `FULL_SHARD`,不要在 post-forward 中释放根模块的参数,
# 意图是它们能立即用于反向计算(尽管这可能不总是真的)
free_unsharded_flat_param = (
not state._is_root
and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
)
_reshard(state, handle, free_unsharded_flat_param)

这个函数调用参数句柄(handle)来执行实际的重新分片逻辑。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
@no_type_check
def _reshard(
state: _FSDPState,
handle: FlatParamHandle,
free_unsharded_flat_param: bool,
):
"""
重新分片句柄。`free_unsharded_flat_param` 指示是否释放
句柄的带填充的未分片扁平参数。
"""
# 功能:
# - 协调参数句柄(handle)的重新分片过程。
#
# 主要逻辑:
# 1. 调用 `handle.reshard()` 方法,将 `free_unsharded_flat_param` 标志传递下去,
# 由 handle 对象自己管理其内部状态和内存。
# 2. 如果 `limit_all_gathers` 选项被启用并且参数被释放,它会使用一个 CUDA 事件队列(`_free_event_queue`)
# 来确保在释放内存前,所有在当前流上的操作都已经完成,这是一种更精细的同步机制。
# 3. 调用 `handle.post_reshard()` 来执行任何 reshard 后的清理工作。
# 4. 将 `handle._prefetched` 标志设置为 `False`,表示参数现在是分片状态,
# 下次访问时需要通过 all-gather(即 unshard)来获取完整数据。
handle.reshard(free_unsharded_flat_param)
if state.limit_all_gathers and free_unsharded_flat_param:
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
# 在 torch compile 模式下,我们目前不为释放操作运行事件队列
# 但也许我们需要?TODO(voz): 研究一下
free_event = state._device_handle.Event()
free_event.record()
state._free_event_queue.enqueue(free_event)
handle.post_reshard()
# 无论扁平参数是否被释放,我们总是在下次访问时“unshard”参数
# 以获取其正确的形状。
handle._prefetched = False

执行重分片逻辑,需要先转为使用本地参数,再安全释放收集的数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def reshard(self, free_unsharded_flat_param: bool):
"""
运行重新分片逻辑。

这包括如果 `free_unsharded_flat_param` 为真,则释放未分片的扁平参数,
并切换到使用分片的扁平参数。注意,这也隐式地将分片的扁平参数
卸载到 CPU(如果启用了 CPU offload),通过将其指向位于 CPU 上的 `_local_shard` 属性。
"""
# 功能:
# - 在参数句柄(handle)级别上执行重新分片的核心操作。
#
# 主要逻辑:
# 1. **切换指针**:首先调用 `_use_sharded_flat_param()`。这是一个关键步骤,它将 `FlatParameter`
# 的内部 `data` 指针重新指向分片后的张量(`_sharded_flat_param`)。
# 这样做可以防止在释放内存后发生“悬空指针”或“use-after-free”的 bug。
# 2. **释放内存**:如果 `free_unsharded_flat_param` 为 `True`,则调用 `_free_unsharded_flat_param()`
# 来释放之前未分片的、完整的参数所占用的内存。
# 在释放之前切换到分片的 `FlatParameter`,以防止外部性能分析工具出现“use-after-free”类型的 bug,
# 其中对于 `use_orig_params=True`,当在 `_use_sharded_views()` 中设置 `param.data = ...` 时,
# `param` 不会指向有效的内存。
self._use_sharded_flat_param()
if free_unsharded_flat_param:
self._free_unsharded_flat_param()

主要作用是将 self.flat_param (一个 nn.Parameter) 的 .data 属性从指向完整的、未分片的张量,切换为指向本地的分片张量 (self.flat_param._local_shard)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def _use_sharded_flat_param(self) -> None:
"""切换到使用分片的扁平参数。"""
# 功能:
# - 这是 reshard(重新分片)过程中的关键步骤。
# - 主要作用是将 self.flat_param (一个 nn.Parameter) 的 .data 属性从指向完整的、
# 未分片的张量,切换为指向本地的分片张量 (self.flat_param._local_shard)。
# - 这个切换是实现内存优化的核心:一旦 .data 指向了分片,之前完整张量所占用的
# 内存就可以被安全地释放。
# - 如果 `use_orig_params` 为 True,此方法还负责更新原始模型参数,使其成为
# 分片张量的“视图”(view),并处理其梯度的视图。

flat_param = self.flat_param
if self._use_orig_params:
# --- 特殊情况处理:决定是否跳过更新原始参数视图 --- #
# 在某些策略下(如 NO_SHARD),我们不在前向传播后立即重新分片。这是一种优化,
# 避免在前向和后向之间进行不必要的 unshard/reshard。
# `skip_use_sharded_views` 用于标识这种情况。
in_forward = self._training_state == HandleTrainingState.FORWARD
skip_use_sharded_views = (
torch.is_grad_enabled()
and in_forward
and self._sharding_strategy
in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
)
# 如果需要跳过,提前保存未分片参数的引用
if skip_use_sharded_views:
unsharded_flat_param = flat_param.data

if self._offload_params:
# --- CPU Offload 断言 --- #
# 如果启用了参数的 CPU 卸载,那么此时的本地分片理应在 CPU 上。
device = flat_param._local_shard.device # type: ignore[attr-defined]
_p_assert(
device == torch.device("cpu"),
f"期望本地分片在 CPU 上,但实际在 {device}",
)

# --- 核心操作:切换 .data 指针 --- #
# 这是此方法最核心的一行。它将 FlatParameter 的数据指针指向本地分片。
# 如果启用了 CPU Offload,_local_shard 就在 CPU 上,这个操作也完成了数据到 CPU 的“卸载”。
flat_param.data = flat_param._local_shard # type: ignore[attr-defined]

if self._use_orig_params:
# --- 更新原始参数及其梯度视图 --- #
if skip_use_sharded_views: # type: ignore[possibly-undefined]
# 如果跳过了视图更新,只需保存未分片的参数引用即可。
self._unsharded_flat_param_for_skipped_views = unsharded_flat_param # type: ignore[possibly-undefined]
else:
# 否则,调用 _use_sharded_views(),将原始参数的 .data 更新为分片张量的视图。
self._use_sharded_views()

# 在前向传播后的 reshard 中,我们可能尝试使用分片的梯度视图
# (或者,如果在 no_sync() 中累积了梯度,则使用未分片的梯度视图),
# 但在后向传播后的 reshard 中,我们将此调用推迟到 reduce-scatter 之后。
if (
in_forward # type: ignore[possibly-undefined]
# 如果跳过了使用分片视图,则跳过使用梯度视图,
# 因为向用户暴露未分片的参数和分片的梯度可能会引起困惑
and not self._skipped_use_sharded_views
):
# 检查在 no_sync() 上下文中是否累积了完整的梯度
accumulated_grad_in_no_sync = (
flat_param.grad is not None
and self.uses_sharded_strategy
and flat_param.grad.shape == flat_param._unpadded_unsharded_size
)
if accumulated_grad_in_no_sync:
# 如果有完整的梯度,则原始参数的梯度视图也应指向这个完整的梯度。
self._use_unsharded_grad_views()
else:
# 否则,梯度视图应指向分片后的梯度。
self._use_sharded_grad_views()

释放放带填充的未分片扁平参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def _free_unsharded_flat_param(self):
"""
释放带填充的未分片扁平参数。我们允许在存储未分配时也调用此函数。

要释放的张量取决于调用上下文,因为 unshard 可能强制使用了全精度,
在这种情况下,会使用一个不同的张量。
"""
# 功能:
# - 定位到未分片的、完整的扁平化参数,并准备释放其内存。
#
# 主要逻辑:
# 1. 获取正确的未分片参数张量 `unsharded_flat_param`。
# 2. 检查该张量是否在计算设备上(例如 GPU)。
# 3. **同步流**:调用 `_no_dispatch_record_stream()`,确保在释放张量内存之前,
# 当前 CUDA 流中所有使用该张量的操作都已完成。这是一个重要的同步步骤,
# 防止在 GPU 操作完成前就释放了其正在使用的内存。
# 4. 调用底层的 `_free_storage()` 工具函数来执行实际的内存释放。
self._check_sharded_strategy()
unsharded_flat_param = self._get_padded_unsharded_flat_param()
self._check_on_compute_device(unsharded_flat_param)
# 在当前流中的所有操作完成之前,不要释放内存
_no_dispatch_record_stream(
unsharded_flat_param, self._device_handle.current_stream()
)
_free_storage(unsharded_flat_param)

首先获取到之前带填充的、未分片的扁平参数的引用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
"""
根据调用上下文,返回对带填充的、未分片的扁平参数的引用。

功能:
- 此方法是获取用于 all-gather 操作的目标张量的核心逻辑。
- 它处理了混合精度训练中的一个重要情况:当需要强制使用全精度参数时,它会返回一个不同的、高精度的张量,并释放可能存在的旧的、低精度的张量,以确保数据一致性。

主要逻辑:
1. **检查分片策略**:确保此方法仅在使用了分片策略(如 `FULL_SHARD` 或 `SHARD_GRAD_OP`)时被调用。
2. **处理强制全精度和混合精度**:
- 如果 `_force_full_precision`(例如,在 `summon_full_params` 中)和 `_uses_param_mixed_precision` 都为 `True`,则意味着我们需要一个全精度的参数副本进行操作。
- 在这种情况下,返回 `_full_prec_full_param_padded`,这是一个专门用于存储全精度参数的张量。
- **关键操作**:如果低精度的 `_full_param_padded` 张量仍然占用内存(意味着它可能来自上一次前向传播且未被释放),则必须将其释放。这是因为对全精度参数的修改会使这个低精度副本失效。释放后,下一次计算将强制执行新的 all-gather 来获取最新的数据,而不是使用过时的低精度缓存。
3. **标准情况**:
- 在其他所有情况下(例如,不强制全精度或不使用混合精度),直接返回标准的 `_full_param_padded` 张量,该张量将作为 all-gather 的目标。
"""
# 确认当前正在使用分片策略,因为此方法与获取未分片参数相关
self._check_sharded_strategy()
flat_param = self.flat_param
# 检查是否需要强制使用全精度参数,并且参数混合精度已启用
if self._force_full_precision and self._uses_param_mixed_precision:
# 当启用参数混合精度时,我们使用一个不同的张量作为 all-gather 的目标,
# 以保持 `_full_param_padded` 始终是低精度这一不变性。
unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined]
# 断言确保我们获取的确实是全精度张量,其类型不应与前向/后向传播中使用的低精度类型相同
_p_assert(
unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
f"期望全精度但得到了 {self._fwd_bwd_param_dtype}",
)
# 对于在 forward 后不重新分片的策略,`_full_param_padded` 可能仍被分配了内存。
# 由于我们在这里强制使用全精度,全精度副本可能会被修改,从而使现有的低精度副本失效。
# 因此,我们在这里释放它,以确保下一次前向/后向计算会进行新的 all-gather,以持久化修改。
if flat_param._full_param_padded.untyped_storage().size() > 0:
_free_storage(flat_param._full_param_padded)
else:
# 在标准情况下,直接使用 `_full_param_padded` 作为未分片的参数
unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined]
return unsharded_flat_param

这是一个通用的底层工具函数,通过将张量的存储大小调整为 0 来释放其内存。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def _free_storage(tensor: torch.Tensor):
"""
释放 `tensor` 的底层存储。

返回:
bool: 如果方法释放了存储,则返回 `True`;如果存储已被释放,则返回 `False`。
"""
# 功能:
# - 这是实际执行内存释放的最低级函数。
#
# 主要逻辑:
# 1. 在 `torch.no_grad()` 上下文中操作,避免不必要的梯度跟踪。
# 2. 检查存储是否已经被释放(大小是否为 0)。
# 3. **安全检查**:断言(`_p_assert`)张量的 `storage_offset()` 为 0。这是一个重要的安全措施,
# 确保我们正在释放的张量是其底层存储的唯一所有者。如果一个存储被多个张量视图(view)共享,
# 释放它是不安全的。
# 4. **释放操作**:调用 `tensor._typed_storage()._resize_(0)`。这个内部方法会将张量的底层存储
# 大小调整为 0,从而有效地将内存返回给 PyTorch 的缓存分配器,使其可以被重用。
with torch.no_grad():
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
already_freed = tensor._typed_storage()._size() == 0
if not already_freed:
_p_assert(
tensor.storage_offset() == 0,
"当张量不是其存储的唯一占用者时,释放它的存储是不安全的\n"
f"storage offset: {tensor.storage_offset()}\n"
f"storage size: {tensor._typed_storage()._size()}\n"
f"tensor shape: {tensor.shape}",
)
tensor._typed_storage()._resize_(0)

_register_pre_backward_hooks

主要是注册反向传播前置钩子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@no_type_check
def _register_pre_backward_hooks(
state: _FSDPState,
module: nn.Module,
outputs: Any,
handle: FlatParamHandle,
) -> None:
"""
在 `outputs`(前向传播的输出)中需要梯度的张量上注册反向传播前置钩子(pre-backward hooks)。
这些输出是使用 `handle` 的 `FlatParameter` 计算得出的。

功能:
- 这是 FSDP 实现自动、即时(just-in-time)参数 un-sharding 的核心机制。
- 通过在模块的输出张量上注册钩子,FSDP 可以在反向传播到达该模块之前,精确地触发相应参数的 all-gather 操作。
- 这样可以确保在计算梯度时,完整的、未分片的参数是可用的,同时在其他时间保持分片状态以节省内存。

主要逻辑:
1. **检查梯度计算**:如果当前没有启用梯度计算(例如,在 `torch.no_grad()` 上下文中),则无需注册任何钩子,直接返回。
2. **重置状态**:
- 对于根模块,重置 `_post_backward_callback_queued` 标志,为新的反向传播做准备。
- 对于当前 `handle`,重置 `_needs_pre_backward_unshard` 和 `_ran_pre_backward_hook` 标志,以确保钩子逻辑的正确执行。
3. **定义钩子注册函数 `_register_hook`**:
- 此内部函数负责在单个张量上注册钩子。
- **条件**:仅当张量 `requires_grad` 时才注册,因为只有这些张量会参与反向传播。
- **注册**:使用 `t.register_hook()` 将 `_pre_backward_hook`(通过 `functools.partial` 包装)附加到张量上。
- **标记需求**:注册钩子后,将 `handle._needs_pre_backward_unshard` 设为 `True`,表明该 `handle` 对应的参数在反向传播中需要被 un-shard。
4. **递归应用钩子**:
- 使用 `_apply_to_tensors` 工具函数,将 `_register_hook` 应用于 `outputs` 中的所有张量。这可以处理复杂的输出结构(如元组、列表、字典等)。
"""
# 如果没有启用梯度计算(例如在 `torch.no_grad()` 中),则不需要反向传播逻辑
if not torch.is_grad_enabled():
return outputs
# 如果是根 FSDP 实例,重置 post-backward 回调已排队的标志
if state._is_root:
state._post_backward_callback_queued = False # 此标志仅在根节点上定义

if handle:
# 初始化标志,表示此 handle 尚不需要在反向传播前进行 un-shard
handle._needs_pre_backward_unshard = False
# 由于此 handle 的 FlatParameter 参与了前向传播,我们保守地假设
# 它将在反向传播中使用。重置此标志,用于跟踪 pre-backward 钩子是否已运行。
handle._ran_pre_backward_hook = False

def _register_hook(t: torch.Tensor) -> torch.Tensor:
# 只在需要计算梯度的张量上注册钩子
if t.requires_grad:
# 注册一个不可序列化的钩子。`_pre_backward_hook` 将在反向传播到此张量时被调用。
t.register_hook(
torch.utils.hooks.unserializable_hook(
functools.partial(_pre_backward_hook, state, module, handle)
)
)
# 如果注册了钩子,说明这个 handle 对应的参数将需要 un-shard
if handle:
handle._needs_pre_backward_unshard = True
return t

# 递归地将 _register_hook 函数应用于 `outputs` 中的所有张量
return _apply_to_tensors(_register_hook, outputs)

具体注册的勾子函数为_pre_backward_hook,该函数主要是执行_unshard来通过 all-gather 操作获取完整的参数,并且为了重叠计算和通信,它会立即触发下一个(在反向传播顺序中)模块参数的 prefetching(预取)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@no_type_check
def _pre_backward_hook(
state: _FSDPState,
module: nn.Module,
handle: FlatParamHandle,
grad,
*unused: Any,
) -> Any:
"""
为梯度计算准备 `handle` 的 `FlatParameter`。

功能:
- 这是 FSDP 的核心反向传播钩子,由 `_register_pre_backward_hooks` 注册。
- 当反向传播的梯度流到达一个模块的输出张量时,这个钩子被触发。
- 它的主要职责是:
1. **Un-shard 参数**:执行 all-gather 操作,将当前模块所需的 `FlatParameter` 从分片状态恢复为完整的、未分片的张量,以便进行梯度计算。
2. **预取下一个参数**:为了重叠计算和通信,它会立即触发下一个(在反向传播顺序中)模块参数的 prefetching(预取)。
3. **状态管理**:管理 FSDP 的内部状态,例如标记钩子已运行,以及为根模块注册最终的 post-backward 回调。

主要逻辑:
1. **钩子执行保护**:检查 `_ran_pre_backward_hook` 标志,确保对于同一次前向计算涉及的同一组参数,此钩子只执行一次。
2. **根模块初始化**:如果是根 FSDP 模块,并且是反向传播的第一次调用,它会注册一个 `_post_backward_final_callback`。这个回调将在整个反向传播结束后执行,用于最终的清理工作(如梯度 reshard)。
3. **状态转换**:将 FSDP 状态机切换到 `FORWARD_BACKWARD` 和 `BACKWARD_PRE`,用于调试和断言。
4. **参数 Un-shard**:
- 检查 `_needs_pre_backward_unshard` 标志。
- 如果需要 un-shard 且参数尚未被预取(`_prefetched` 为 False),则调用 `_unshard` 执行 all-gather。
- 使用 `wait_stream` 确保计算流等待 un-shard 操作完成。
5. **反向预取(Backward Prefetch)**:
- 调用 `_prefetch_handle` 并传入 `_PrefetchMode.BACKWARD`,以启动下一个句柄的参数 un-sharding。这是 FSDP 的关键性能优化。
6. **梯度准备**:调用 `handle.prepare_gradient_for_backward()`,为即将到来的梯度计算做准备。
7. **标记完成**:设置 `_ran_pre_backward_hook = True`。
"""
# 对于同一次模块前向计算中涉及的同一组句柄,只运行一次 pre-backward 钩子
if (
handle
and hasattr(handle, "_ran_pre_backward_hook")
and handle._ran_pre_backward_hook
):
return grad

with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"):
# 为根 FSDP 实例排队一次 post-backward 回调,将其附加到最外层的反向图任务上,
# 以便在所有反向调用完成后调用它。
if state._is_root and not state._post_backward_callback_queued:
_register_post_backward_final_callback(state, module)
_reset_flat_param_grad_info_if_needed(state._all_handles)
elif handle:
# 断言 FSDP 模块处于正确的训练状态
allowed_states = [TrainingState.IDLE]
if _is_composable(state):
allowed_states.append(TrainingState.FORWARD_BACKWARD)
_assert_in_training_states(state, allowed_states)
# 更新训练状态为正在进行反向传播
state.training_state = TrainingState.FORWARD_BACKWARD
# 排队 post-backward 回调是 pre-backward 钩子中唯一不是按句柄处理的逻辑,
# 因此如果没有句柄,我们可以在这里提前返回。
if not handle:
return grad
# 更新句柄的训练状态为反向传播前
handle._training_state = HandleTrainingState.BACKWARD_PRE

if handle._needs_pre_backward_unshard:
# 如果句柄已经被预取,则无需再次调用 `_unshard()`
if not handle._prefetched:
_unshard(
state,
handle,
state._unshard_stream, # 用于 unshard 的 CUDA 流
state._pre_unshard_stream, # 用于 pre-unshard 的 CUDA 流
)
# 在 tracing 期间不要等待,以避免图中断
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
# 确保计算流等待 unshard 操作完成
state._device_handle.current_stream().wait_stream(state._unshard_stream)

# 将此标志设置为 `False`,以确保目标错误的预取不会实际 unshard 这些句柄
handle._needs_pre_backward_unshard = False
with torch.profiler.record_function(
"FullyShardedDataParallel._pre_backward_prefetch"
):
# 预取下一个在反向传播中需要的句柄的参数
_prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
# 为反向传播准备梯度
handle.prepare_gradient_for_backward()
# 标记此句柄的 pre-backward 钩子已运行
handle._ran_pre_backward_hook = True
return grad

总结

勾子函数及运行流程

参考资料


【pytorch-fsdp 源代码阅读(一)】-全流程概览
http://example.com/2025/07/02/pytorch-fsdp-1/
作者
滑滑蛋
发布于
2025年7月2日
许可协议