[pytorch-fsdp Source Code Reading (2)] - Parameter Flow

Initialization

  1. Collect the tensors under the module that need to be flattened.

  2. Put these tensors into a list and concatenate them with cat. The flattened data is stored in FlatParamHandle.flat_param (a standalone sketch of the padding logic follows the source excerpt below).

    def flatten_tensors(
        self,
        tensors: list[Tensor],
        aligned_numel: int,
    ) -> Tensor:
        """
        Flatten ``tensors`` into a single flat tensor.

        The flattening optionally includes
        padding if ``aligned_numel`` is greater than 0, where ``aligned_numel``
        gives the numel required to have address alignment.

        NOTE: The padding alignment algorithm must be kept in sync with
        :meth:`_init_flat_param_metadata`. We separate the two methods because
        the initialization happens once, whereas this method may be called
        multiple times throughout training (e.g. for checkpointing).
        """
        if len(tensors) == 0:
            raise ValueError("Expects non-empty `tensors`")
        if aligned_numel < 0:
            raise ValueError(
                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
            )
        dtype, _, device = self._validate_tensors_to_flatten(tensors)
        flat_tensors: list[Tensor] = []
        if aligned_numel > 0:
            total_numel = 0
            for tensor in tensors:
                numel_to_pad = aligned_numel - (total_numel % aligned_numel)
                if numel_to_pad > 0 and numel_to_pad < aligned_numel:
                    padding_tensor = _construct_padding_tensor(
                        numel_to_pad, dtype, False, device
                    )
                    flat_tensors.append(padding_tensor)
                    total_numel += numel_to_pad
                flat_tensors.append(
                    torch.flatten(_detach_if_needed(tensor))
                    if _is_truly_contiguous(tensor)
                    else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,))
                )
                total_numel += tensor.numel()
            numel_to_pad = self.world_size - (total_numel % self.world_size)
            if numel_to_pad > 0 and numel_to_pad < self.world_size:
                padding_tensor = _construct_padding_tensor(
                    numel_to_pad, dtype, False, device
                )
                flat_tensors.append(padding_tensor)
                total_numel += numel_to_pad
        else:
            flat_tensors = [
                torch.flatten(_detach_if_needed(tensor))
                if _is_truly_contiguous(tensor)
                else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,))
                for tensor in tensors
            ]
        return torch.cat(flat_tensors, dim=0)

    def flatten_tensors_into_flat_param(
        self,
        tensors: list[Tensor],
        aligned_numel: int,
        requires_grad: bool,
    ) -> FlatParameter:
        flat_param_data = self.flatten_tensors(tensors, aligned_numel)
        return FlatParameter(flat_param_data, requires_grad=requires_grad)


self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
    params_to_flatten,
    aligned_numel=0,
    requires_grad=flat_param_requires_grad,
)
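
To make the padding behavior concrete, here is a standalone sketch of the same alignment logic with made-up tensor sizes; the private padding/detach helpers are replaced by plain torch calls, so this is an illustration rather than the FSDP implementation:

import torch

def flatten_with_alignment(tensors, aligned_numel, world_size):
    # Simplified stand-in for `flatten_tensors`: pad before each tensor to reach
    # `aligned_numel` alignment, then pad the tail to a multiple of `world_size`.
    flat_pieces, total_numel = [], 0
    for t in tensors:
        if aligned_numel > 0:
            numel_to_pad = aligned_numel - (total_numel % aligned_numel)
            if 0 < numel_to_pad < aligned_numel:
                flat_pieces.append(torch.zeros(numel_to_pad, dtype=t.dtype))
                total_numel += numel_to_pad
        flat_pieces.append(t.detach().flatten())
        total_numel += t.numel()
    numel_to_pad = world_size - (total_numel % world_size)
    if 0 < numel_to_pad < world_size:
        flat_pieces.append(torch.zeros(numel_to_pad, dtype=tensors[0].dtype))
    return torch.cat(flat_pieces, dim=0)

weight = torch.randn(3, 5)  # 15 elements
bias = torch.randn(3)       # 3 elements
flat = flatten_with_alignment([weight, bias], aligned_numel=8, world_size=4)
print(flat.numel())  # 20: 15 + 1 (align bias to 16) + 3 + 1 (pad to world_size)

Note that the call site above passes aligned_numel=0, so in that path the tensors are simply flattened and concatenated without intra-tensor padding; the sketch uses aligned_numel=8 only to exercise the padding branch.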
  • Split FlatParamHandle.flat_param into views according to each tensor's numel (a minimal sketch follows the source excerpt below):
@no_type_check
def _get_unflat_views_unaligned(
    self,
    tensor: Optional[torch.Tensor] = None,
) -> Iterator[Tensor]:
    """
    Return unflattened ``Tensor`` views into ``tensor``.

    If ``tensor`` is ``None``, ``flat_param`` is used. The unflattening is based
    on ``flat_param``'s metadata.

    Examples for ``tensor`` include ``flat_param.grad`` or unsharded
    tensor optimizer state.
    """
    flat_param = self.flat_param
    if tensor is None:
        tensor = flat_param
    views = (
        _ext_post_unflatten_transform(
            subtensor.view(shape)
            if contiguous
            else subtensor.as_strided(shape, stride),
            param_extension,
            self._fsdp_extension,
        )
        for (subtensor, shape, stride, contiguous, param_extension) in zip(
            torch.split(tensor, flat_param._numels, dim=0),
            flat_param._shapes,
            flat_param._strides,
            flat_param._contiguities,
            flat_param._param_extensions,
        )
    )
    return views
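
The view mechanism is essentially torch.split by the recorded numels followed by view/as_strided with the recorded shapes. A minimal sketch with made-up shapes:

import torch

# Suppose two parameters with these shapes were flattened in this order.
shapes = [(3, 5), (3,)]
numels = [15, 3]
flat = torch.arange(18, dtype=torch.float32)  # stands in for flat_param

views = [
    subtensor.view(shape)
    for subtensor, shape in zip(torch.split(flat, numels, dim=0), shapes)
]
print([tuple(v.shape) for v in views])         # [(3, 5), (3,)]
print(views[0].data_ptr() == flat.data_ptr())  # True: views alias flat storage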
  • Set these views as attributes on the module, i.e. replace the original parameters (a hand-rolled sketch follows the snippet below):
param_var: Tensor = view
self._setattr_tensor(module, param_name, param_var)
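
Conceptually, _setattr_tensor deletes the registered Parameter and re-points the attribute at a plain Tensor view of the flat parameter. A hand-rolled sketch of that replacement (not the FSDP helper itself):

import torch
import torch.nn as nn

lin = nn.Linear(5, 3)
flat = torch.cat([lin.weight.detach().flatten(), lin.bias.detach().flatten()])
w_view, b_view = torch.split(flat, [15, 3])

# Drop the original Parameters and point the attributes at views of `flat`.
delattr(lin, "weight")
setattr(lin, "weight", w_view.view(3, 5))
delattr(lin, "bias")
setattr(lin, "bias", b_view.view(3))

out = lin(torch.randn(2, 5))  # the forward now reads from the flat storage
print(out.shape)              # torch.Size([2, 3])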

_lazy_init

This calls init_flat_param_attributes(), which:

  1. Sets flat_param._local_shard = flat_param.data.

  2. Sets flat_param._full_param_padded to a torch.empty tensor of padded_unsharded_numel elements.

  3. Sets flat_param._padded_unsharded_size.

  4. Frees the underlying storage of flat_param._full_param_padded (a toy sketch of steps 2-4 follows the source excerpt below).

if self.uses_sharded_strategy:
    # We maintain a padded unsharded tensor that serves as the
    # all-gather destination and owns the original parameter storages.
    unsharded_param_dtype = (
        self._fwd_bwd_param_dtype
        if self._uses_param_mixed_precision
        else flat_param.dtype
    )  # use low precision if parameter mixed precision is enabled
    padded_unsharded_numel = flat_param.numel() * self.world_size
    flat_param._full_param_padded = torch.empty(
        padded_unsharded_numel,
        device=self.device,
        dtype=unsharded_param_dtype,
    )
    flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
    _free_storage(flat_param._full_param_padded)

    if self._uses_param_mixed_precision:
        # For parameter mixed precision, we maintain a full precision
        # padded unsharded tensor for when we force full precision.
        flat_param._full_prec_full_param_padded = torch.empty(
            padded_unsharded_numel,
            device=self.device,
            dtype=flat_param.dtype,  # full precision
        )
        _free_storage(flat_param._full_prec_full_param_padded)
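
A toy CPU-only sketch of what steps 2-4 amount to for a hypothetical world size of 2, using the same private storage calls as the quoted helpers:

import torch

world_size = 2
local_shard = torch.randn(10)  # stands in for flat_param.data (this rank's shard)

padded_unsharded_numel = local_shard.numel() * world_size  # 20
full_param_padded = torch.empty(padded_unsharded_numel)    # all-gather destination
padded_unsharded_size = full_param_padded.size()           # kept as metadata

# Free the backing storage until unshard actually needs it.
full_param_padded._typed_storage()._resize_(0)
print(full_param_padded._typed_storage()._size())  # 0: no memory held
print(padded_unsharded_size)                        # torch.Size([20])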

Shard

_post_forward_reshard

  1. Note that parameters are resharded here only for non-root handles whose sharding strategy is in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES:
def _post_forward_reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
) -> None:
    """Reshards parameters in the post-forward."""
    if not handle:
        return
    # Do not free the root's parameters in the post-forward for `FULL_SHARD`
    # with the intention that they are immediately used for backward
    # computation (though this may not be true)
    free_unsharded_flat_param = (
        not state._is_root
        and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
    )
    _reshard(state, handle, free_unsharded_flat_param)
  • Set FlatParamHandle.flat_param.data to FlatParamHandle.flat_param._local_shard (a tiny sketch of this .data swap follows the snippet below):
def reshard(self, free_unsharded_flat_param: bool):
    """
    Run the reshard logic.

    This includes freeing the unsharded flat
    parameter if ``free_unsharded_flat_param`` and switching to using the
    sharded flat parameter. Note that this also implicitly offloads
    the sharded flat parameter (if CPU offload is enabled) by pointing
    it to the ``_local_shard`` attribute which resides on CPU.
    """
    # Switch to the sharded `FlatParameter` before freeing to prevent
    # "use-after-free"-type bugs with external profiling tools, where for
    # `use_orig_params=True`, the `param` does not point to valid memory
    # when setting `param.data = ...` in `_use_sharded_views()`.
    self._use_sharded_flat_param()
    if free_unsharded_flat_param:
        self._free_unsharded_flat_param()

flat_param.data = flat_param._local_shard  # type: ignore[attr-defined]
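
The point to note is that only the .data pointer of the same FlatParameter object is swapped; a tiny sketch of that semantics:

import torch
import torch.nn as nn

param = nn.Parameter(torch.zeros(8))  # stands in for the unsharded flat_param
local_shard = torch.ones(4)           # stands in for flat_param._local_shard

old_id = id(param)
param.data = local_shard              # switch to the sharded storage
print(id(param) == old_id)            # True: still the same Parameter object
print(param.shape)                    # torch.Size([4]): now shard-sized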
  • Obtain unsharded_flat_param, i.e. FlatParamHandle.flat_param._full_param_padded:
unsharded_flat_param = flat_param._full_param_padded  # type: ignore[attr-defined]

def _free_unsharded_flat_param(self):
    """
    Free the padded unsharded flat parameter. We allow this
    function to be called even when storage is not allocated.

    The tensor to free depends
    on the calling context since the unshard may have forced full
    precision, in which case a different tensor is used.
    """
    self._check_sharded_strategy()
    unsharded_flat_param = self._get_padded_unsharded_flat_param()
    self._check_on_compute_device(unsharded_flat_param)
    # Do not free the memory until all ops in the current stream finish
    _no_dispatch_record_stream(
        unsharded_flat_param, self._device_handle.current_stream()
    )
    _free_storage(unsharded_flat_param)
  • Free this storage:

def _free_storage(tensor: torch.Tensor):
    """
    Frees the underlying storage of ``tensor``.

    Returns:
        bool: ``True`` if the method freed the storage and ``False`` if the
        storage was already freed.
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_freed = tensor._typed_storage()._size() == 0
            if not already_freed:
                _p_assert(
                    tensor.storage_offset() == 0,
                    "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
                    f"storage offset: {tensor.storage_offset()}\n"
                    f"storage size: {tensor._typed_storage()._size()}\n"
                    f"tensor shape: {tensor.shape}",
                )
        tensor._typed_storage()._resize_(0)

Unshard

_pre_forward_unshard

  1. Get FlatParamHandle.flat_param._full_param_padded, which is a tensor:
flat_param = self.flat_param
unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined]
return unsharded_flat_param
  • Check whether the storage has really been freed.

  • Allocate storage for FlatParamHandle.flat_param._full_param_padded (a free/alloc round-trip sketch follows the source excerpt below).

_alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)

def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
    """
    Allocate storage for ``tensor`` with the given size.

    Returns:
        bool: ``True`` if this method allocated storage and ``False`` if the
        storage was already allocated.
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_allocated = tensor._typed_storage()._size() == size.numel()
            if not already_allocated:
                tensor_storage_size = tensor._typed_storage()._size()
                _p_assert(
                    tensor_storage_size == 0,
                    f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
                )
        tensor._typed_storage()._resize_(size.numel())
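
Taken together with _free_storage above, the pair boils down to resizing the underlying storage to zero and back. A small round-trip sketch using the same private storage calls:

import torch

size = torch.Size([20])
buf = torch.empty(size)                      # e.g. flat_param._full_param_padded
print(buf._typed_storage()._size())          # 20

buf._typed_storage()._resize_(0)             # _free_storage: drop the backing memory
print(buf._typed_storage()._size())          # 0 (the tensor's shape metadata remains)

buf._typed_storage()._resize_(size.numel())  # _alloc_storage: reallocate for unshard
print(buf._typed_storage()._size())          # 20, ready as the all-gather destination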
  • Use all-gather to collect each GPU's FlatParamHandle.flat_param.data into FlatParamHandle.flat_param._full_param_padded (a single-process simulation of the semantics follows the source excerpt below):
def _all_gather_flat_param(
    self,
    padded_unsharded_flat_param: Tensor,
) -> Tensor:
    """
    All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.

    Then switch to use the all-gathered tensor.
    """
    _p_assert(
        hasattr(self, "process_group") and hasattr(self, "world_size"),
        "Expects a process group and world size to have been set via `shard()`",
    )
    sharded_flat_param = self.flat_param.data
    expected_numel = sharded_flat_param.numel() * self.world_size
    _p_assert(
        padded_unsharded_flat_param.numel() == expected_numel,
        f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
    )

    pg = (
        self._fake_process_group
        if self._use_fake_all_gather
        else self.process_group
    )

    # HACK this should be handled by C10D
    if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
        tensor_list = list(
            torch.chunk(
                padded_unsharded_flat_param,
                dist.get_world_size(pg),  # type: ignore[arg-type]
            )
        )
        dist.all_gather(tensor_list, sharded_flat_param, group=pg)
    else:
        dist.all_gather_into_tensor(
            padded_unsharded_flat_param,
            sharded_flat_param,
            pg,
        )

    if self._offload_params:
        # In case of offloading, `flat_param.data` (i.e. sharded param) is
        # created on the pre-unshard stream. We need to hand it over to the
        # unshard stream for all-gather
        _no_dispatch_record_stream(
            sharded_flat_param,
            self._device_handle.current_stream(),  # unshard_stream
        )
    return padded_unsharded_flat_param
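
The effect of the collective can be simulated in a single process: each rank contributes its contiguous chunk, and every rank ends up with the chunks concatenated in rank order (a plain-CPU simulation of the semantics, not the dist call itself):

import torch

world_size = 2
# Pretend this is the padded unsharded flat parameter (numel divisible by world_size).
original = torch.arange(20, dtype=torch.float32)

# `shard()` leaves each rank with one contiguous chunk as flat_param.data.
shards = list(torch.chunk(original, world_size))

# dist.all_gather_into_tensor(dest, shard, pg) then reassembles the chunks in
# rank order on every rank, recovering the padded unsharded tensor.
padded_unsharded = torch.cat(shards, dim=0)
print(torch.equal(padded_unsharded, original))  # True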

  • Using the gathered FlatParamHandle.flat_param._full_param_padded, update self.flat_param.data to point at it, then call _use_unsharded_views to obtain the views and assign them back to the individual params (a short sketch follows the snippet below):
unsharded_size = self.flat_param._unpadded_unsharded_size
flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()]
# slicing [:] is not visible to autograd because of .data
self.flat_param.data = flat_param_part
self._use_unsharded_views(as_params=False)
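
A short sketch of this final step with hypothetical sizes: the padded buffer is sliced down to the unpadded unsharded numel and flat_param.data is pointed at that slice without any copy:

import torch
import torch.nn as nn

flat_param = nn.Parameter(torch.zeros(5))  # shard-sized before unshard
padded_unsharded = torch.arange(12, dtype=torch.float32)
unpadded_unsharded_numel = 10              # from flat_param._unpadded_unsharded_size

# Assigning through `.data` keeps autograd unaware of the switch.
flat_param.data = padded_unsharded[:unpadded_unsharded_numel]
print(flat_param.shape)                                      # torch.Size([10])
print(flat_param.data_ptr() == padded_unsharded.data_ptr())  # True: a view, no copy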

