[Verl Source Code Analysis (3)] Colocating the Training and Inference Engines in Verl (Using FSDP and vLLM as an Example)

When Verl runs reinforcement learning training, it needs an inference engine for rollout sampling and a training engine for model updates, so two kinds of engines are involved. This post walks through how Verl handles this, using the FSDP training engine and the vLLM inference engine as examples.

Note: the analysis is based on the v0.4.1.x branch of Verl: https://github.com/verl-project/verl/tree/v0.4.1.x

Initialization

As described earlier (https://slipegg.github.io/2026/01/30/Verl-Resource-Management/), when Verl initializes on a resource pool it merges the classes of all roles into a `WorkerDict` (actor, rollout, and ref share the `ActorRolloutRefWorker` class, while critic uses the `CriticWorker` class), then spawns multiple `WorkerDict` instances on the resource pool. Each role is then exposed as a role-specific handle derived from the `WorkerDict`, through which the corresponding methods are invoked on every instance.

Instance Initialization

The code that creates the merged `WorkerDict` class, wraps it as a Ray actor class, and instantiates it is as follows:

for resource_pool, class_dict in self.resource_pool_to_cls.items():
    worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
    wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, device_name=self.device_name, **wg_kwargs)
    spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
    all_wg.update(spawn_wg)

The instantiation happens inside self.ray_worker_group_cls(...). ray_worker_group_cls is usually RayWorkerGroup; once initialized, it creates a corresponding WorkerDict instance on each GPU of the resource pool.

What WorkerDict does when it is instantiated is shown below: it directly instantiates each role's class, temporarily setting DISABLE_WORKER_INIT=1 so that the base Worker initialization (which the WorkerDict itself has already performed) returns immediately inside each role class.

class WorkerDict(worker_cls):
    def __init__(self):
        super().__init__()
        self.worker_dict = {}
        for key, user_defined_cls in cls_dict.items():
            user_defined_cls = _unwrap_ray_remote(user_defined_cls)
            # directly instantiate the class without remote
            # in worker class, e.g. <verl.single_controller.base.worker.Worker>
            # when DISABLE_WORKER_INIT == 1 it will return immediately
            with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}):
                self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {}))

The role class we care about for training and inference is ActorRolloutRefWorker. Its __init__ sets up a number of configurations; the main steps are:

  1. Call torch.distributed.init_process_group to set up the NCCL process group.

  2. Build a device_mesh from the fsdp_size config; this determines the overall FSDP layout, which may be 1D or 2D (see the sketch after this list):

    • 1D: (world_size,) → FSDP FULL_SHARD (similar to ZeRO-3)

    • 2D: (ddp, fsdp) → FSDP HYBRID_SHARD (DDP on the outer dim, FSDP on the inner dim)

  3. Which roles the current ActorRolloutRefWorker plays is determined by its role name, as shown below:

self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
self._is_ref = self.role in ["ref", "actor_rollout_ref"]

  • If the worker is an actor, both offload_param and offload_optimizer can be configured; if it is a ref, only offload_param is supported.

  • micro_batch_size and related settings are then derived from the device_mesh and the config.
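To make the 1D/2D layout above concrete, here is a minimal sketch of building such a device mesh with torch's init_device_mesh. This is not verl's exact code; fsdp_size stands in for the value read from the FSDP config.

# Minimal sketch of the 1D vs. 2D FSDP device-mesh layout described above.
# Assumes torch.distributed is already initialized; `fsdp_size` is a hypothetical
# stand-in for the value taken from the actor's FSDP config.
import torch
from torch.distributed.device_mesh import init_device_mesh

world_size = torch.distributed.get_world_size()
fsdp_size = 8  # hypothetical config value

if fsdp_size <= 0 or fsdp_size >= world_size:
    # 1D mesh: (world_size,) -> every rank shards everything (FULL_SHARD, ZeRO-3 style)
    device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
else:
    # 2D mesh: (ddp, fsdp) -> replicate across the outer dim, shard within the inner dim (HYBRID_SHARD)
    device_mesh = init_device_mesh(
        "cuda",
        mesh_shape=(world_size // fsdp_size, fsdp_size),
        mesh_dim_names=("ddp", "fsdp"),
    )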

Model Initialization

After the WorkerDict instances are initialized, actor_rollout_wg is extracted and actor_rollout_wg.init_model(...) is called to initialize the model. init_model does the following:

  1. Run _build_model_optimizer. If the current role is an actor it builds the model, optimizer, and lr_scheduler; otherwise only the model is built.

    1. The model is built with the transformers from_pretrained API, loading from the downloaded files. Note that it is first loaded into CPU memory and is then wrapped with FSDP:

      1. If the role is not an actor, cpu_offload is force-enabled so that the FSDP model is kept on the CPU to save memory; if it is an actor, the model is loaded onto the GPU owned by the current worker, and the _is_offload_param config later decides whether to offload it back to the CPU (a hedged sketch of this pattern follows this list).

      2. Depending on self.config.actor.strategy, both the fsdp and fsdp2 sharding strategies are supported.

    2. The optimizer defaults to AdamW. It is created on the same device as the model and may later be offloaded to CPU according to _is_offload_optimizer.

  2. If the current role is a rollout, _build_rollout is executed. The rollout engine supports hf, vllm, and sglang; here we only look at the vllm engine:

    1. Its initialization code is shown below. It mainly initializes the vLLM engine and uses the self.actor_module_fsdp obtained from _build_model_optimizer to initialize rollout_sharding_manager.

elif rollout_name == "vllm":
    from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout
    from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager

    log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
    local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False))
    lora_kwargs = {"lora_kwargs": {"enable_lora": True, "max_loras": 1, "max_lora_rank": self._lora_rank}} if self._is_lora else {}
    # lora_kwargs = {}
    if vllm_mode == "customized":
        rollout = vLLMRollout(actor_module=self.actor_module_fsdp, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, trust_remote_code=trust_remote_code, **lora_kwargs)
    elif vllm_mode == "spmd":
        from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout

        vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
        rollout = vllm_rollout_cls(model_path=local_path, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, device_mesh=rollout_device_mesh, trust_remote_code=trust_remote_code, **lora_kwargs)
    else:
        raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")

    log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
    full_params = torch.distributed.get_world_size() == 1
    rollout_sharding_manager = FSDPVLLMShardingManager(
        module=self.actor_module_fsdp,
        inference_engine=rollout.inference_engine,
        model_config=self.actor_model_config,
        full_params=full_params,
        device_mesh=rollout_device_mesh,
        offload_param=self._is_offload_param,
        load_format=self.config.rollout.load_format,
        layered_summon=self.config.rollout.get("layered_summon", False),
    )
    log_gpu_memory_usage("After building sharding manager", logger=logger)
      • Taking the synchronous vLLMRollout class as an example, its initialization code is shown below. At its core, self.inference_engine loads the model from the given path onto the GPU during construction, because the vLLM workers need the full model structure (config/tokenizer/layer definitions) to set up the scheduler and the communication topology. Right after that, sleep(level=1) is called to release the GPU memory occupied by the vLLM initialization: at level 1, sleep clears the KV cache and offloads the model weights to CPU.
class vLLMRollout(BaseRollout):
    def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):
        """A vLLM rollout. It requires the module is supported by the vllm.

        Args:
            module: module here follows huggingface APIs
            config: DictConfig
            tokenizer: the task/model tokenizer
            model_hf_config: the huggingface config to initiallize the generating model in vllm
            **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
        """
        super().__init__()
        self.config = config
        assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine"

        tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
        assert tensor_parallel_size <= torch.distributed.get_world_size(), "tensor parallel size should be less than or equal to the world size"
        max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192)

        if kwargs.get("train_tp") is not None:
            # deployed with megatron
            import os

            os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0"
            os.environ["MEGATRON_IMPORT_TIMERS"] = "0"
            if vllm_version in (
                "0.5.4",
                "0.6.3",
            ):
                train_tp = kwargs.get("train_tp")
                num_tp_per_train_tp = train_tp // tensor_parallel_size
                vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp)
            else:
                vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size)

        rope_scaling_config = getattr(model_hf_config, "rope_scaling", None)
        if not rope_scaling_config:
            max_position_embeddings = None
            if hasattr(model_hf_config, "max_position_embeddings"):
                max_position_embeddings = model_hf_config.max_position_embeddings
            elif hasattr(model_hf_config, "llm_config") and hasattr(model_hf_config.llm_config, "max_position_embeddings"):
                max_position_embeddings = model_hf_config.llm_config.max_position_embeddings
            elif hasattr(model_hf_config, "text_config") and hasattr(model_hf_config.text_config, "max_position_embeddings"):
                max_position_embeddings = model_hf_config.text_config.max_position_embeddings
            if max_position_embeddings is None:
                raise ValueError("max_position_embeddings not found in model_hf_config")

            assert max_position_embeddings >= config.prompt_length + config.response_length, "model context length should be greater than total sequence length"

        max_model_len = int(config.max_model_len or config.prompt_length + config.response_length)

        if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill:
            raise ValueError(
                "Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \
                please increase max_num_batched_tokens or disable chunked prefill"
            )

        trust_remote_code = kwargs.get("trust_remote_code", False)
        load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format

        lora_kwargs = kwargs.pop("lora_kwargs", {})
        self.lora_kwargs = lora_kwargs
        # copy it to avoid secretly modifying the engine config
        engine_kwargs = {} if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm))
        # For each vLLM engine parameter,
        # - `None` means not setting it, so we pop it, and leave it to vLLM default value
        #   (which can vary across different vLLM versions);
        # - Otherwise it's the desired value we want to explicitly set.
        engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}
        if config.get("limit_images", None):  # support for multi-image data
            engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")}

        self.inference_engine = LLM(
            model=model_path,
            enable_sleep_mode=True,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend="external_launcher",
            dtype=config.dtype,
            enforce_eager=config.enforce_eager,
            gpu_memory_utilization=config.gpu_memory_utilization,
            disable_custom_all_reduce=True,
            skip_tokenizer_init=False,
            max_model_len=max_model_len,
            load_format=load_format,
            disable_log_stats=config.disable_log_stats,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=config.enable_chunked_prefill,
            enable_prefix_caching=True,
            trust_remote_code=trust_remote_code,
            seed=config.get("seed", 0),
            **lora_kwargs,
            **engine_kwargs,
        )

        # Offload vllm model to reduce peak memory usage
        self.inference_engine.sleep(level=1)

        kwargs = dict(
            n=1,
            logprobs=0,  # can be set to 0 and let actor to recompute
            max_tokens=config.response_length,
        )

        # # we may detokenize the result all together later
        if vllm_version != "0.3.1":
            kwargs["detokenize"] = False

        # supporting adding any sampling params from the config file
        for k in config.keys():
            if hasattr(SamplingParams(), str(k)):
                kwargs[k] = config.get(k)

        print(f"kwargs: {kwargs}")
        self.sampling_params = SamplingParams(**kwargs)

        self.pad_token_id = tokenizer.pad_token_id

      • As for rollout_sharding_manager, its core job is to gather the sharded parameters of the freshly trained actor model and load/update them into vLLM.

  3. Finally, if the role is an actor or a rollout, an FSDPCheckpointManager is created to manage checkpoints.
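The offload behaviour described in step 1.1 can be illustrated with a minimal, hedged sketch of the "load with from_pretrained, then wrap with FSDP" pattern. This is not verl's actual _build_model_optimizer code: the model path, the is_actor flag, and the hyperparameters are placeholders, and a real setup would additionally pass an auto-wrap policy and mixed-precision config.

# Minimal sketch of "load the HF model on CPU, then wrap it with FSDP", not verl's code.
# Assumes torch.distributed / a process group is already initialized.
import torch
from transformers import AutoModelForCausalLM
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, ShardingStrategy

is_actor = True                           # hypothetical role flag
model_path = "Qwen/Qwen2.5-7B-Instruct"   # hypothetical checkpoint

# 1) load the HF model into CPU memory first
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)

# 2) wrap with FSDP; non-actor roles (e.g. ref) force CPU offload to save GPU memory,
#    while the actor keeps its parameters on the local GPU (possibly offloaded later
#    according to the offload_param config)
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # HYBRID_SHARD when a 2D mesh is used
    cpu_offload=None if is_actor else CPUOffload(offload_params=True),
    device_id=torch.cuda.current_device(),
)

# 3) only the actor additionally gets an AdamW optimizer, living on the same device as the model
if is_actor:
    optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-6)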

Model Conversion: rollout_sharding_manager

During RL training, the actor parameters are continually updated, so subsequent rollouts must use the latest parameters. This requires syncing the model parameters from the training engine to the inference engine, which is exactly what rollout_sharding_manager does. Here we look at the rollout_sharding_manager that converts between FSDP and vLLM; its code is as follows:

class FSDPVLLMShardingManager(BaseShardingManager):
    @check_device_is_available()
    def __init__(self, module: FSDP, inference_engine: LLM, model_config, full_params: bool = False, device_mesh: DeviceMesh = None, offload_param: bool = False, load_format: str = "dummy_hf", layered_summon: bool = True):
        self.module = module
        # For AsyncLLM, inference_engine and model_runner are defer initialized in vLLMAsyncRollout.load_model
        self.inference_engine = inference_engine
        # self.model_runner = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if inference_engine else None

        if "vllm_v_0_6_3" in str(type(self.inference_engine)) or "vllm_v_0_5_4" in str(type(self.inference_engine)):
            # vLLM <= v0.6.3
            self.model_runner = self.inference_engine.llm_engine.model_executor.worker.model_runner if self.inference_engine else None
        else:
            # vLLM > v0.6.3
            self.model_runner = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if self.inference_engine else None

        self.model_config = model_config
        self.device_mesh = device_mesh
        self.offload_param = offload_param
        self.load_format = load_format
        self.layered_summon = layered_summon

        # Full params
        self.full_params = full_params
        if full_params and fsdp_version(self.module) == 1:
            FSDP.set_state_dict_type(self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig())
        elif fsdp_version(self.module) == 1:
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        self.tp_size = self.device_mesh["infer_tp"].size()
        self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()

        # Note that torch_random_states may be different on each dp rank
        self.torch_random_states = get_torch_device().get_rng_state()
        # get a random rng states
        if self.device_mesh is not None:
            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
            self.gen_random_states = get_torch_device().get_rng_state()
            get_torch_device().set_rng_state(self.torch_random_states)
        else:
            self.gen_random_states = None

        self.base_sync_done: bool = "dummy" not in load_format
        if is_version_ge(pkg="vllm", minver="0.7.3"):
            VLLMHijack.hijack()

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def __enter__(self):
        def __collect_lora_params() -> OrderedDict:
            """
            collect lora params or full params if base model is not ready in vllm
            work with if isinstance(self.module._fsdp_wrapped_module, PeftModel)
            """
            from peft.utils.save_and_load import get_peft_model_state_dict

            lora_params = OrderedDict()
            peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module)
            if fsdp_version(self.module) > 0:
                if self.layered_summon:
                    if not self.base_sync_done:
                        raise ValueError("To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let rollout.load_format=safetensors")
                    lora_params = layered_summon_lora_params(self.module)
                else:
                    with FSDP.summon_full_params(self.module, writeback=False):
                        if self.base_sync_done:
                            lora_params = get_peft_model_state_dict(peft_model)
                            lora_params = {name: param.full_tensor().detach().cpu() if hasattr(param, "full_tensor") else param.detach().cpu() for name, param in lora_params.items()}
                        else:
                            model = peft_model.base_model.model
                            orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name()
                            model = model.to("cpu")
                            for name, param in model.state_dict().items():
                                if any(x in name for x in ["_flat_param", "lora_"]):
                                    continue
                                name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "")
                                lora_params[name] = param.full_tensor().detach().cpu() if hasattr(param, "full_tensor") else param.detach().cpu()
                            model = model.to(orig_dev)
                    get_torch_device().empty_cache()
            else:
                if self.base_sync_done:
                    lora_params = get_peft_model_state_dict(peft_model)
                else:
                    model = peft_model.base_model.model
                    orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name()
                    model = model.to("cpu")
                    for name, param in model.state_dict().items():
                        if any(x in name for x in ["_flat_param", "lora_"]):
                            continue
                        name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "")
                        lora_params[name] = param.detach().cpu()
                    model = model.to(orig_dev)
            return lora_params

        # NOTE: Basically, we only need `get_torch_device().empty_cache()` before vllm wake_up and
        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
        # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory
        # to speed up memory allocations.
        #
        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
        self.timing = {}
        with simple_timer("reshard", self.timing):
            get_torch_device().empty_cache()

            log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
            if self.offload_param:
                load_fsdp_model_to_gpu(self.module)

            peft_config = None
            peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module)
            if hasattr(peft_model, "peft_config"):
                peft_config = peft_model.peft_config.get("default", None)
                params = __collect_lora_params()
            else:
                params = self.module.state_dict()
            params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
            log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)

            # Copy, not share memory
            load_format = "hf" if self.full_params else "dtensor"

            if vllm_version in (
                "0.5.4",
                "0.6.3",
            ):
                self.inference_engine.sync_model_weights(params, load_format=load_format)
                log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
                del params
            else:
                if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
                    self.inference_engine.wake_up(tags=["weights"])
                else:
                    self.inference_engine.wake_up()

                # update model params
                self.update_params(params, peft_config=peft_config)
                log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
                del params
                if self.offload_param:
                    offload_fsdp_model_to_cpu(self.module)
                get_torch_device().empty_cache()

                if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
                    self.inference_engine.wake_up(tags=["kv_cache"])

            log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

            # important: need to manually set the random states of each tp to be identical.
            if self.device_mesh is not None:
                self.torch_random_states = get_torch_device().get_rng_state()
                get_torch_device().set_rng_state(self.gen_random_states)

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def __exit__(self, exc_type, exc_value, traceback):
        # TODO(ZSL): check this
        if vllm_version in (
            "0.5.4",
            "0.6.3",
        ):
            self.inference_engine.offload_model_weights()
        else:
            self.inference_engine.sleep(level=1)

        self.module.train()

        # add empty cache after each compute
        get_torch_device().empty_cache()

        # restore random states
        if self.device_mesh is not None:
            self.gen_random_states = get_torch_device().get_rng_state()
            get_torch_device().set_rng_state(self.torch_random_states)

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def preprocess_data(self, data: DataProto) -> DataProto:
        """All gather across tp group to make each rank has identical input."""
        if self.tp_size == 1:
            return data

        # TODO: Current impl doesn't consider FSDP with torch micro-dp
        if vllm_version in (
            "0.5.4",
            "0.6.3",
        ):
            group = vllm_ps.get_tensor_model_parallel_group()
        else:
            group = vllm_ps.get_tensor_model_parallel_group().device_group

        all_gather_data_proto(data=data, process_group=group)
        return data

    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
    def postprocess_data(self, data: DataProto) -> DataProto:
        """Get chunk data of this tp rank since we do all gather in preprocess."""
        if self.tp_size == 1:
            return data

        return data.chunk(chunks=self.tp_size)[self.tp_rank]

    def update_params(self, updated_params, peft_config=None):
        model = self.model_runner.model
        if peft_config:
            if self.base_sync_done:
                lora_int_id = int(time.time_ns() % 0x7FFFFFFF)
                lora_reqest = TensorLoRARequest(
                    lora_name=f"{lora_int_id}",
                    lora_int_id=lora_int_id,
                    lora_path="simon_lora_path",
                    peft_config=asdict(peft_config),
                    lora_tensors=updated_params,
                )
                self.inference_engine.llm_engine.add_lora(lora_reqest)
                logger.info(f"vLLM load weights, loaded_params: {len(updated_params)}")
                return
            else:

                def replace_lora_wrapper(k):
                    stacked_params = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
                    if any([k.endswith(f"{s}.weight") for s in stacked_params]):
                        return k.replace(".weight", ".base_layer.weight")
                    if any([k.endswith(f"{s}.bias") for s in stacked_params]):
                        return k.replace(".bias", ".base_layer.bias")
                    return k

                updated_params = {replace_lora_wrapper(k): v for k, v in updated_params.items()}

        patch_vllm_moe_model_weight_loader(model)
        device = get_device_id()  # used when fsdp2 set cpu_offload_policy
        loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in updated_params.items()))

        self.base_sync_done = True
        logger.info(f"vLLM load weights, loaded_params: {len(loaded_params) if loaded_params else -1}")

Usage

rollout_sharding_manager is used when the rollout executes generate_sequences; the calling code is as follows:

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@DistProfiler.annotate(color="red")
def generate_sequences(self, prompts: DataProto):
    # ...
    with self.rollout_sharding_manager:
        log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)

        prompts = self.rollout_sharding_manager.preprocess_data(prompts)
        with simple_timer("generate_sequences", timing_generate):
            output = self.rollout.generate_sequences(prompts=prompts)

        log_gpu_memory_usage("After rollout generation", logger=logger)

        output = self.rollout_sharding_manager.postprocess_data(output)

    # ...

    # clear kv cache
    get_torch_device().empty_cache()
    return output

The overall flow is:

  1. Entering with self.rollout_sharding_manager triggers __enter__.

  2. preprocess_data is then executed.

  3. Only then is the inference engine's generate_sequences called.

  4. postprocess_data is then executed.

  5. Finally, leaving the scope of with self.rollout_sharding_manager automatically triggers __exit__.

Initialization

The call that constructs FSDPVLLMShardingManager is shown below:

full_params = torch.distributed.get_world_size() == 1
rollout_sharding_manager = FSDPVLLMShardingManager(
    module=self.actor_module_fsdp,
    inference_engine=rollout.inference_engine,
    model_config=self.actor_model_config,
    full_params=full_params,
    device_mesh=rollout_device_mesh,
    offload_param=self._is_offload_param,
    load_format=self.config.rollout.load_format,
    layered_summon=self.config.rollout.get("layered_summon", False),
)

Configurations worth noting at initialization:

  • Because older vLLM versions did not support SPMD, some version-specific handling is done.

  • Based on the full_params argument (i.e. whether the torch world size is 1), FSDP.set_state_dict_type is configured:

    • If there is only one GPU and the FSDP version is 1, state_dict is set to return the full parameters directly, essentially a special case for single-GPU rollout.

    • Otherwise, in the usual case, state_dict returns sharded DTensors.

  • It also extracts the tp and dp information from the inference device_mesh (see the sketch right below).
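The tp/dp information mentioned in the last bullet comes from the rollout device mesh. Below is a hedged sketch of what such a mesh and the derived ranks look like; the mesh construction and the infer_tp value are assumptions for illustration, and only the "dp"/"infer_tp" dimension names are taken from the class code above.

# Illustrative sketch of a rollout device mesh with "dp" and "infer_tp" dimensions.
# Assumes torch.distributed is already initialized; `infer_tp` is hypothetical.
import torch
from torch.distributed.device_mesh import init_device_mesh

world_size = torch.distributed.get_world_size()
infer_tp = 2                    # hypothetical rollout tensor-parallel size
dp = world_size // infer_tp     # remaining ranks form the data-parallel dimension

rollout_device_mesh = init_device_mesh(
    "cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=("dp", "infer_tp")
)

tp_size = rollout_device_mesh["infer_tp"].size()            # used by pre/postprocess_data
tp_rank = rollout_device_mesh["infer_tp"].get_local_rank()  # this rank's position in its TP group
dp_rank = rollout_device_mesh["dp"].get_local_rank()        # used to seed the generation RNG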

__enter__ and __exit__

Here we focus on the handling of a regular (non-LoRA) model with a recent vLLM version, and skip LoRA and the older vLLM code paths for now.

__enter__ proceeds as follows:

  1. First, clear torch's CUDA cache to free up GPU memory.

  2. In offload_param mode, load the FSDP model from CPU back onto the GPU.

  3. Call self.module.state_dict() to obtain the model parameters params. Per the earlier FSDP.set_state_dict_type setting, this returns the full parameters when the rollout runs on a single GPU, and sharded DTensors otherwise.

  4. Call vLLM's wake_up() to re-activate vLLM (on newer vLLM versions, only the weight buffers are woken up first via wake_up(tags=["weights"])).

  5. Push the collected params into vLLM.

    1. The key loading line is: loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in updated_params.items()))

    2. If a parameter is a DTensor, full_tensor() is called to gather the complete parameter through the DTensor machinery; if it is not a DTensor it is already a full parameter and can be loaded into vLLM directly. The TP split inside vLLM is handled by vLLM itself (see the sketch after this list).

    3. Loading into vLLM may involve some weight-format conversion; for example, in the vLLM versions verl previously customized, loading merges w_q, w_k, and w_v into one large tensor so that inference can compute them in a single pass.

  6. In offload_param mode, offload the FSDP model from GPU back to CPU, clear torch's cache again, and finally wake up the KV cache (wake_up(tags=["kv_cache"]) on newer vLLM versions).
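To illustrate step 5, here is a hedged sketch of materializing sharded parameters before handing them to an inference engine's weight loader. The helper names and the generic inference_model are made up for illustration; in verl the consumer is the vLLM model's load_weights, which also performs fused-layout conversions such as stacking w_q, w_k, and w_v into a single qkv weight.

# Minimal sketch mirroring the load_weights line quoted above; not verl's code.
# DTensor lives in torch.distributed.tensor on recent torch versions (older: torch.distributed._tensor).
import torch
from torch.distributed.tensor import DTensor

def to_full(name: str, param: torch.Tensor, device):
    # DTensor shards are all-gathered into a complete tensor via full_tensor();
    # plain tensors are already complete and are just moved to the target device.
    if isinstance(param, DTensor):
        return name, param.to(device, non_blocking=True).full_tensor()
    return name, param.to(device, non_blocking=True)

def sync_weights(sharded_state_dict: dict, inference_model, device):
    # The inference engine receives an iterator of (name, full_tensor) pairs and
    # performs its own TP split / weight fusion internally. `inference_model` is a
    # hypothetical stand-in for any object exposing a compatible load_weights().
    weight_iter = (to_full(name, p, device) for name, p in sharded_state_dict.items())
    return inference_model.load_weights(weight_iter)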

__exit__ proceeds as follows:

  1. Call vLLM's sleep with level=1, which offloads vLLM's model weights to CPU and discards the KV cache (see the sketch after this list).

  2. Switch the FSDP model back to train mode.

  3. Clear torch's cache again.

  4. Save the generation random state and restore the training random state.
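For reference, here is a hedged sketch of the vLLM sleep/wake-up cycle that __enter__ and __exit__ rely on. The model path is a placeholder, and the two-stage wake_up with tags only exists in newer vLLM versions, which is why the sharding manager inspects the signature first.

# Minimal sketch of the vLLM memory life-cycle used by the sharding manager.
# The model path is a placeholder; requires a vLLM version with sleep mode support.
from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)

# level=1: offload the model weights to CPU and drop the KV cache
# (level=2 would also discard the weight buffers, so they must be re-synced afterwards)
llm.sleep(level=1)

# ... FSDP training runs here, using the GPU memory freed by vLLM ...

# wake_up() re-allocates GPU memory; newer vLLM versions allow doing it in two stages
# so weights can be updated before the KV cache is re-created:
#   llm.wake_up(tags=["weights"])   -> restore/update the weights first
#   llm.wake_up(tags=["kv_cache"])  -> then re-allocate the KV cache
llm.wake_up()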

preprocess_data and postprocess_data

Before the rollout runs generate_sequences, the data needs extra handling when there are multiple workers.

preprocess_data proceeds as follows:

  1. If tp_size is 1, return data directly without any processing.

  2. If tp_size is greater than 1, the data must be processed: the batch is dispatched evenly across workers, but because of TP, workers in the same TP group must see identical data, so an all-gather collects all of the tensor and non-tensor data (see the sketch at the end of this section).

postprocess_data proceeds as follows:

  1. If tp_size is greater than 1, the output is chunked and each TP rank keeps only its own slice.
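The idea can be sketched with plain torch collectives instead of verl's DataProto helpers: every rank in a TP group first gathers the full batch so all ranks feed vLLM identical prompts, and afterwards each rank keeps only its own chunk of the output. The tensor-only batch and the tp_group argument are simplifications of what verl actually passes around.

# Minimal sketch of the all-gather / chunk dance around generation, using plain
# torch.distributed instead of verl's DataProto utilities. Assumes equal-sized
# slices per rank and that `tp_group` is the inference tensor-parallel group.
import torch
import torch.distributed as dist

def preprocess(batch: torch.Tensor, tp_group) -> torch.Tensor:
    """Gather each rank's slice so every rank in the TP group sees the same batch."""
    tp_size = dist.get_world_size(group=tp_group)
    if tp_size == 1:
        return batch
    gathered = [torch.empty_like(batch) for _ in range(tp_size)]
    dist.all_gather(gathered, batch, group=tp_group)
    return torch.cat(gathered, dim=0)

def postprocess(output: torch.Tensor, tp_group) -> torch.Tensor:
    """Every TP rank produced identical output; keep only this rank's chunk."""
    tp_size = dist.get_world_size(group=tp_group)
    if tp_size == 1:
        return output
    tp_rank = dist.get_rank(group=tp_group)
    return output.chunk(tp_size, dim=0)[tp_rank]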
