【Verl源码分析(三)】Verl中训练引擎与推理引擎共置处理(以FSDP、vLLM为例)
Verl在进行强化学习训练时,既需要使用推理引擎执行推理采样,也需要训练引擎进行模型更新,所以需要使用两类引擎,故这里以FSDP训练引擎和vLLM推理引擎为例对Verl的相关处理进行介绍。
注意查看的是0.4.1.x版本的Verl代码:https://github.com/verl-project/verl/tree/v0.4.1.x
初始化
实例初始化
创建合并的类WorkDict并将其包裹为ray的类然后进行实例化的相关代码如下:
1 | |
在self.ray_worker_group_cls(...)中就会进行实例化,其中ray_worker_group_cls往往是RayWorkerGroup,其初始化后会在资源池的各个GPU节点上初始化对应的WorkDict实例。
WorkDict实例化时执行的内容如下所示,如果环境变量没有配置DISABLE_WORKER_INIT就会直接将各个角色的类也进行初始化,一般而言都会直接初始化。
1 | |
我们这里关注的训练与推理的角色是ActorRolloutRefWorker类,该类在执行__init__初始化时会初始化一些配置,主要的流程如下:
使用
torch.distributed.init_process_group生成NCCL通信按
fsdp_size配置生成device_mesh,这决定了FSDP整体的布局,可能是1D也可能是2D的布局:1D:
(world_size,)→ FSDP FULL_SHARD(类似 ZeRO-3)2D:
(ddp, fsdp)→ FSDP HYBRID_SHARD(外层 DDP、内层 FSDP)
当前
ActorRolloutRefWorker的是什么角色是依据role名决定的,如下所示:
1 | |
如果是actor,那么就支持配置是否需要offload_param以及是否需要offload_optimizer,如果是ref,那么就只支持配置是否需要offload_param
然后还需要根据
device_mesh以及config来计算是否需要micro_batch_size等配置
模型初始化
在初始化了WorkDict实例后,会提取出actor_rollout_wg,然后会调用actor_rollout_wg.init_model(...)函数来进行模型的初始化。init_model执行内容如下:
执行
_build_model_optimizer函数,如果当前角色是actor那么就会构建出model、optimizer和lr_scheduler,如果不是actor就只会构建出model。构建model是使用transformers库的
from_pretrained这一个API,从下载的文件中加载出来,注意目前是直接加载到CPU内存中,然后model加载后还会使用FSDP进行包装:如果不是actor,那么就会强制启用
cpu_offload来将FSDP模型先放入到CPU中以节省内存;如果是actor,那么就会直接加载到当前worker所属的GPU上,后面还会依据_is_offload_param配置决定是否需要将模型卸载到CPU上这里依据
self.config.actor.strategy配置支持fsdp与fsdp2两类分割方法
构建optimizer默认使用的是AdamW优化器,初始化时其占用的空间是与model所在的设备相同,并且后面会依据
_is_offload_optimizer来将优化器卸载到CPU上
如果当前角色是rollout就会执行
_build_rollout,rollout引擎支持hf、vllm以及sglang,单看vllm引擎的执行:其初始化的代码如下所示,其主要初始化了vllm引擎,以及使用刚刚通过
_build_model_optimizer获取到的self.actor_module_fsdp来初始化rollout_sharding_manager1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31elif rollout_name == "vllm":
from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout
from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False))
lora_kwargs = {"lora_kwargs": {"enable_lora": True, "max_loras": 1, "max_lora_rank": self._lora_rank}} if self._is_lora else {}
# lora_kwargs = {}
if vllm_mode == "customized":
rollout = vLLMRollout(actor_module=self.actor_module_fsdp, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, trust_remote_code=trust_remote_code, **lora_kwargs)
elif vllm_mode == "spmd":
from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
rollout = vllm_rollout_cls(model_path=local_path, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, device_mesh=rollout_device_mesh, trust_remote_code=trust_remote_code, **lora_kwargs)
else:
raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
full_params = torch.distributed.get_world_size() == 1
rollout_sharding_manager = FSDPVLLMShardingManager(
module=self.actor_module_fsdp,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
full_params=full_params,
device_mesh=rollout_device_mesh,
offload_param=self._is_offload_param,
load_format=self.config.rollout.load_format,
layered_summon=self.config.rollout.get("layered_summon", False),
)
log_gpu_memory_usage("After building sharding manager", logger=logger)- 以同步执行的
vLLMRollout类为例,其初始化代码如下,核心而言,其self.inference_engine在初始化时会从路径加载模型到GPU上,因为vLLM worker 需要完整模型结构(config/tokenizer/层定义)才能初始化调度器和通信拓扑,而随之会调用sleep(level=1)函数来卸载vLLM初始化时所占用的GPU显存,对于level=1级别的sleep,其会清空KVCache,然后将模型权重卸载到CPU上
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117class vLLMRollout(BaseRollout):
def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):
"""A vLLM rollout. It requires the module is supported by the vllm.
Args:
module: module here follows huggingface APIs
config: DictConfig
tokenizer: the task/model tokenizer
model_hf_config: the huggingface config to initiallize the generating model in vllm
**kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
"""
super().__init__()
self.config = config
assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine"
tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
assert tensor_parallel_size <= torch.distributed.get_world_size(), "tensor parallel size should be less than or equal to the world size"
max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192)
if kwargs.get("train_tp") is not None:
# deployed with megatron
import os
os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0"
os.environ["MEGATRON_IMPORT_TIMERS"] = "0"
if vllm_version in (
"0.5.4",
"0.6.3",
):
train_tp = kwargs.get("train_tp")
num_tp_per_train_tp = train_tp // tensor_parallel_size
vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp)
else:
vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size)
rope_scaling_config = getattr(model_hf_config, "rope_scaling", None)
if not rope_scaling_config:
max_position_embeddings = None
if hasattr(model_hf_config, "max_position_embeddings"):
max_position_embeddings = model_hf_config.max_position_embeddings
elif hasattr(model_hf_config, "llm_config") and hasattr(model_hf_config.llm_config, "max_position_embeddings"):
max_position_embeddings = model_hf_config.llm_config.max_position_embeddings
elif hasattr(model_hf_config, "text_config") and hasattr(model_hf_config.text_config, "max_position_embeddings"):
max_position_embeddings = model_hf_config.text_config.max_position_embeddings
if max_position_embeddings is None:
raise ValueError("max_position_embeddings not found in model_hf_config")
assert max_position_embeddings >= config.prompt_length + config.response_length, "model context length should be greater than total sequence length"
max_model_len = int(config.max_model_len or config.prompt_length + config.response_length)
if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill:
raise ValueError(
"Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \
please increase max_num_batched_tokens or disable chunked prefill"
)
trust_remote_code = kwargs.get("trust_remote_code", False)
load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format
lora_kwargs = kwargs.pop("lora_kwargs", {})
self.lora_kwargs = lora_kwargs
# copy it to avoid secretly modifying the engine config
engine_kwargs = {} if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm))
# For each vLLM engine parameter,
# - `None` means not setting it, so we pop it, and leave it to vLLM default value
# (which can vary across different vLLM versions);
# - Otherwise it's the desired value we want to explicitly set.
engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}
if config.get("limit_images", None): # support for multi-image data
engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")}
self.inference_engine = LLM(
model=model_path,
enable_sleep_mode=True,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="external_launcher",
dtype=config.dtype,
enforce_eager=config.enforce_eager,
gpu_memory_utilization=config.gpu_memory_utilization,
disable_custom_all_reduce=True,
skip_tokenizer_init=False,
max_model_len=max_model_len,
load_format=load_format,
disable_log_stats=config.disable_log_stats,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=config.enable_chunked_prefill,
enable_prefix_caching=True,
trust_remote_code=trust_remote_code,
seed=config.get("seed", 0),
**lora_kwargs,
**engine_kwargs,
)
# Offload vllm model to reduce peak memory usage
self.inference_engine.sleep(level=1)
kwargs = dict(
n=1,
logprobs=0, # can be set to 0 and let actor to recompute
max_tokens=config.response_length,
)
# # we may detokenize the result all together later
if vllm_version != "0.3.1":
kwargs["detokenize"] = False
# supporting adding any sampling params from the config file
for k in config.keys():
if hasattr(SamplingParams(), str(k)):
kwargs[k] = config.get(k)
print(f"kwargs: {kwargs}")
self.sampling_params = SamplingParams(**kwargs)
self.pad_token_id = tokenizer.pad_token_id- 对于
rollout_sharding_manager,其核心功能是将已经被分片的训练好的actor模型收集起来,然后加载并更新到vllm中
- 以同步执行的
之后如果是actor或者是rollout就会创建
FSDPCheckpointManager来管理检查点
模型转换rollout_sharding_manager
在整个强化学习过程中,随着训练进行,训练模型参数会被更新,故而在后续的rollout过程中需要使用最新的模型参数,这就涉及到了模型参数从训练引擎向推理引擎更新的过程。而这主要就是由rollout_sharding_manager完成。这里我们查看fsdp与vllm之间进行转化的rollout_sharding_manager,其相关代码如下:
1 | |
使用方式
rollout_sharding_manager在rollout执行generate_sequences时会被调用,调用代码如下
1 | |
其整体流程为:
随着
with self.rollout_sharding_manager的调用,执行__enter__然后再执行
preprocess_data然后才调用推理引擎执行
generate_sequences然后再执行postprocess_data
最后随着退出
with self.rollout_sharding_manager的范围,会默认执行__exit__
初始化
FSDPVLLMShardingManager初始化时的调用代码如下所示:
1 | |
在初始化时需要注意的配置有:
因为之前的vllm不支持spmd,所以依据vllm版本进行了一些定制
依据
full_param参数(即当前torch环境中卡的数量是否是1),调用设置FSDP.set_state_dict_type如果只有一张卡并且fsdp version是1,那么就设置在获取state_dict的时候直接就获取全参数,相当于是为单卡推理情景专门做了些定制
不然一般情况都是获取分片后的DTensor
之后还获取到了推理
device_mesh中的tp、dp信息等
__enter__与__exit__
这里直接看一般模型以及新版本vllm的处理方式,暂时不管lora以及旧版本vllm的处理方式。
__enter__的处理流程如下:
首先是清除torch中cache缓存,留出显存空间
如果是
offload_param模式,需要将FSDP模型从CPU中加载到GPU中调用
self.module.state_dict()来获取到模型参数params,注意这里会依据之前FSDP.set_state_dict_type的设置,如果是rollout是单卡就直接获取到全参数,如果不是就获取到的是分片后的DTensor。调用vllm的
wake_up(),激活vllm将获取到的模型参数
params更新给vllm。其中关键加载的代码为:
loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in updated_params.items()))这里如果是参数是DTensor类型,就会调用
full_tensor()借助DTensor来获取全部参数,如果不是DTensor就说明已经是全参数了,直接加载进vllm即可,而vllm中TP的切分在vllm内部中自行完成加载进vllm中可能会涉及到一些参数格式的转换,观察到之前verl定制的vllm的加载,其中会涉及到类似将
w_q、w_k、w_v都合并为一个大Tensor,这样后续推理时能一次性计算出来
如果是
offload_param模式,需要将FSDP模型从GPU中再加载到CPU中,并且再清除torch中cache缓存
__exit__的处理流程如下:
调用vllm的level为1的sleep,其会将vllm的模型参数卸载到CPU中并清除掉kv cache
将fsdp模型转变为train状态
再清除torch的cache
存储随机状态
preprocess_data与postprocess_data
在rollout进行generate_sequences前对于多worker的形式需要进行数据处理
preprocess_data的处理流程如下:
如果tp_size是1,就直接返回data,不需要处理
如果tp_size不是1,就需要处理数据,因为当前是将数据均匀分发给各个worker,但是由于TP的存在,同组的TP worker应该使用的是同样的数据,所以这里通过all gather来获取到全部的Tensor和非Tensor数据
postprocess_data的处理流程如下:
- 如果tp_size不是1,就需要将数据进行分片,然后每个tp worker取自己的一部分