【Verl源码分析(二)】Verl中的数据流动

注意查看的是0.4.1.x版本的Verl代码:https://github.com/verl-project/verl/tree/v0.4.1.x

下面给出verl/trainer/ppo/ray_trainer.py中RayPPOTrainer.fit简化后的核心流程代码,也就是Verl中训练的核心流程,下面以此为核心从数据流动的视角进行介绍:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def fit(self):
# ------------------------------------------------------------
# 0. 初始化阶段
# ------------------------------------------------------------

# 在任何训练 / rollout 之前加载 checkpoint
# 包括 actor / critic / optimizer / scheduler 等状态
self._load_checkpoint()

# ------------------------------------------------------------
# 1. 训练前验证(baseline & sanity check)
# ------------------------------------------------------------
# 如果提供了 validation reward_fn,并且配置允许,
# 则在训练开始前跑一次完整 validation,
# 用于:
# - 建立初始性能 baseline
# - 验证 reward_fn / rollout / decoding 是否正常
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
self._validate()
# val_only 用于 debug / evaluation-only 场景
if self.config.trainer.get("val_only", False):
return

# PPO 的 step 从 1 开始(与 logging / scheduler 对齐)
self.global_steps = 1

# ------------------------------------------------------------
# 2. 训练主循环(epoch × batch)
# ------------------------------------------------------------
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
# 将 dataloader 输出转换为 DataProto,
# 这是 verl 内部统一的数据承载结构
batch: DataProto = DataProto.from_single_dict(batch_dict)

# ----------------------------------------------------
# 2.1 构造 rollout 输入(prompt-only batch)
# ----------------------------------------------------
# 从 batch 中移除 generation 不需要的字段,
# 得到一个只包含 prompt 的 gen_batch
gen_batch = batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids"],
)

# 是否是最后一个训练 step(用于提前退出 & validation)
is_last_step = self.global_steps >= self.total_training_steps

# ----------------------------------------------------
# 2.2 Actor rollout(采样 response)
# ----------------------------------------------------
# 使用当前 actor policy 对 prompt 进行采样生成,
# 得到 response / logprob / timing 等信息
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

# ----------------------------------------------------
# 2.3 REMAX baseline(可选)
# ----------------------------------------------------
# 如果使用 REMAX 优势估计:
# - 对同一 prompt 再生成一次 greedy response
# - 用 greedy reward 作为 baseline
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False # greedy decoding
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

# 将 greedy response 拼回 batch 以便计算 reward
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)

# 存储 baseline reward,后续 advantage = R_sample - R_baseline
batch.batch["reward_baselines"] = reward_baseline_tensor

# ----------------------------------------------------
# 2.4 对齐 rollout 维度 & 合并采样结果
# ----------------------------------------------------
# rollout.n > 1 时,同一个 prompt 会生成多个 response,
# 这里 repeat 原始 batch,使其与 rollout 输出对齐
batch = batch.repeat(
repeat_times=self.config.actor_rollout_ref.rollout.n,
interleave=True
)
batch = batch.union(gen_batch_output)

# 构造 response mask,用于区分 prompt / response token
batch.batch["response_mask"] = compute_response_mask(batch)

# ----------------------------------------------------
# 2.5 Reward 计算
# ----------------------------------------------------
if self.use_rm:
# 使用 reward model 计算 token-level / sequence-level reward
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

# 结合 reward_fn 进行后处理(加权 / 规则奖励等)
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

# ----------------------------------------------------
# 2.6 Log-probability 计算
# ----------------------------------------------------
# 计算当前 actor policy 下的 log_prob(old_log_prob)
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)

# 如果使用 reference policy(KL 约束)
if self.use_reference_policy:
# ref policy 可以是独立模型,也可以复用 actor
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)

batch = batch.union(ref_log_prob)

# ----------------------------------------------------
# 2.7 Critic value 预测(可选)
# ----------------------------------------------------
if self.use_critic:
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)

# token-level reward,用于 advantage / loss 计算
batch.batch["token_level_scores"] = reward_tensor

# ----------------------------------------------------
# 2.8 KL penalty(reward-level)
# ----------------------------------------------------
# 在 reward 层面引入 KL 惩罚(而不是 loss 层)
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(
batch,
kl_ctrl=self.kl_ctrl_in_reward,
kl_penalty=self.config.algorithm.kl_penalty
)

# ----------------------------------------------------
# 2.9 Advantage 计算
# ----------------------------------------------------
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
multi_turn=self.config.actor_rollout_ref.rollout.multi_turn.enable,
config=self.config.algorithm,
)

# ----------------------------------------------------
# 2.10 Critic 更新
# ----------------------------------------------------
if self.use_critic:
critic_output = self.critic_wg.update_critic(batch)

# ----------------------------------------------------
# 2.11 Actor 更新
# ----------------------------------------------------
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
actor_output = self.actor_rollout_wg.update_actor(batch)

# ----------------------------------------------------
# 2.12 Validation & step 管理
# ----------------------------------------------------
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
val_metrics: dict = self._validate()

self.global_steps += 1
if is_last_step:
return

数据流动

原始数据集处理

整体概览如下:

这里以经典的gsm8k数据集(https://huggingface.co/datasets/openai/gsm8k)为例,介绍Verl对其一开始做了什么处理。

  • gsm8k数据集分为训练集和测试集,每个数据集中包含2个字段,分别是question和answer,其中answer会以####{最终答案}的形式来记录最终答案。

  • Verl提供了examples/data_preprocess/gsm8k.py来处理该原始数据集,将其进行格式转换,最终存储为test.parquettrain.parquet

    • 其核心处理是给每个问题后面添加了Let\'s think step by step and output the final answer after "####".来引导模型用####{最终答案}的方式来输出。

    • 然后将其格式转化为了如下的样式:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    data = {
    "data_source": data_source,
    "prompt": [
    {
    "role": "user",
    "content": question,
    }
    ],
    "ability": "math",
    "reward_model": {"style": "rule", "ground_truth": solution},
    "extra_info": {
    "split": split,
    "index": idx,
    "answer": answer_raw,
    "question": question_raw,
    },
    }
    • 这里取其中一个问题为例进行介绍:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    {
    "data_source": "openai/gsm8k",
    "prompt": [
    {
    "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\"s think step by step and output the final answer after "####".",
    "role": "user"
    }
    ],
    "ability": "math",
    "reward_model": {
    "ground_truth": "72",
    "style": "rule"
    },
    "extra_info": {
    "answer": "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72",
    "index": 0,
    "question": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?",
    "split": "train"
    }
    }
    • 注意转换为parquet格式的一个好处在于parquet是列存储的,其可以很方便地直接读取到某一属性的所有值。
  • 进一步在训练前会进行数据的预处理将其转化为train_dataloaderval_dataloader

    1. 首先分别读取test.parquettrain.parquet,并分别包装为RLHFDataset(Dataset)类型的train_datasetval_dataset,其主要作用是记录了一些关键配置,读取并保存对应的parquet文件以及tokenizer和多模态处理的processor

    2. 然后再创建一个train_sampler负责采样训练数据

    3. 然后结合train_datasettrain_sampler构建出一个torchdata.stateful_dataloader类型的train_dataloadertorchdata.stateful_dataloader类型的主要特点是可以从checkpoint中恢复上次数据读取的位置,从而保证数据采样的轨迹还是和之前相同。

    4. 进一步的还会创建出val_dataloader负责加载验证集数据

训练迭代中的数据处理

总体概览如下:

  • 在训练迭代过程中会读取dataloader来获取数据进行训练与验证

    • 在训练过程中读取train_dataloader的相关代码如下:
    1
    2
    3
    4
    for epoch in range(self.config.trainer.total_epochs):
    for batch_dict in self.train_dataloader:

    batch: DataProto = DataProto.from_single_dict(batch_dict)
    • 其获取到的batch_dict数据如下,可以看到其既包含原始数据,也包含了tokenizer处理后的数据,以及相关的attention_mask和position_ids等

    • 其后会进一步转换为的DataProto类的batchDataProto是Verl核心的数据处理类,其负责将tensor类型的变量转化为TensorDict,从而方便做到batch级的操作,如读取第几个batch的数据,或者将其切分为好几个小batch,还负责将非tensor内容放置到non_tensor_batch中。

      • 如下所示,tensor类型的变量包括了训练的核心数据input_idattention_maskposition_ids,非tensor类型的变量是一些额外的信息

      • 最终得到的batch如下所示:

    • 然后会从batchpop弹出获取到input_idattention_maskposition_ids这3类数据并进一步组成DataProto类的gen_batch

    • 然后gen_batch会被送给actor_rollout_wg然后在各个worker上执行generate_sequences来获取采样数据gen_batch_output

      • 注意generate_sequences函数的dispatch方式是DP_COMPUTE_PROTO

        • 其分发方法为dispatch_dp_compute_data_proto,相关代码如下,可以看到其将数据划分为了worker_group.world_size份,每一份后续才会作为worker的输入,worker在读取时如果这个数据在远程就会通过ray进行拉取。
        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21
        22
        23
        24
        25
        26
        27
        28
        29
        30
        31
        32
        33
        34
        35
        36
        37
        38
        39
        40
        41
        42
        43
        44
        45
        46
        47
        48
        49
        50
        def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
        from verl.single_controller.base.worker_group import WorkerGroup

        assert isinstance(worker_group, WorkerGroup)
        # Note: enable auto padding for dp compute DatapProto
        splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding(
        worker_group.world_size,
        *args,
        **kwargs,
        )
        return splitted_args, splitted_kwargs

        def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs):
        from verl.protocol import DataProto, DataProtoFuture

        splitted_args = []
        splitted_kwargs = {}

        data_proto_len = None
        padding_size = None
        for arg in args:
        assert isinstance(arg, (DataProto, DataProtoFuture))
        if isinstance(arg, DataProto) and arg.is_padding_enabled():
        # for padding, we only support DataProto with same length
        if data_proto_len is None:
        data_proto_len = len(arg)
        padding_size = (chunks - (data_proto_len % chunks)) if (data_proto_len % chunks > 0) else 0
        splitted_kwargs[_padding_size_key] = padding_size
        else:
        assert data_proto_len == len(arg), f"expecting all arg share same length of {data_proto_len}, but got {len(arg)}"
        data_proto_len = len(arg)
        arg.padding(padding_size=padding_size)

        splitted_args.append(arg.chunk(chunks=chunks))

        for key, val in kwargs.items():
        assert isinstance(val, (DataProto, DataProtoFuture))
        if isinstance(val, DataProto) and val.is_padding_enabled():
        # for padding, we only support DataProto with same length
        if data_proto_len is None:
        data_proto_len = len(val)
        padding_size = chunks - (data_proto_len % chunks)
        splitted_kwargs[_padding_size_key] = padding_size
        else:
        assert data_proto_len == len(val), f"expecting all arg share same length of {data_proto_len}, but got {len(val)}"
        data_proto_len = len(val)
        splitted_kwargs[key] = val.chunk(chunks=chunks)

        return splitted_args, splitted_kwargs

        • 其收集方法是collect_dp_compute_data_proto,相关代码如下,可以看到其若采取的是sync模式,则输出会是DataProto类型,那么就直接将其拼接起来,如果是异步模式,即得到的是ray.ObjectRef类型,那么就只会将这些ref拼接起来。
        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21
        22
        23
        24
        25
        26
        27
        28
        29
        30
        31
        32
        33
        34
        35
        def collect_dp_compute_data_proto(worker_group, output):
        import ray

        from verl.protocol import DataProto

        for o in output:
        assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"

        output = collect_dp_compute(worker_group, output)
        return _concat_data_proto_or_future(output)

        def collect_dp_compute(worker_group, output):
        from verl.single_controller.base.worker_group import WorkerGroup

        assert isinstance(worker_group, WorkerGroup)
        assert len(output) == worker_group.world_size
        return output

        def _concat_data_proto_or_future(output: List):
        import ray

        from verl.protocol import DataProto, DataProtoFuture

        # make sure all the elements in output has the same type
        for o in output:
        assert type(o) is type(output[0])

        o = output[0]

        if isinstance(o, DataProto):
        return DataProto.concat(output)
        elif isinstance(o, ray.ObjectRef):
        return DataProtoFuture.concat(output)
        else:
        raise NotImplementedError
      • 并且generate_sequences函数的执行模式是Execute.ALL,并且实例采用的是blocking模式,即会等其执行完后再把数据获取回来。

    • 因为在generate_sequences时一个输入可能会输出self.config.actor_rollout_ref.rollout.n个结果,所以为了对齐generate_sequences的结果,需要将batch中的数据对应重复self.config.actor_rollout_ref.rollout.n个,然后再合并gen_batch_output,此时batch中除了原本的input_idattention_maskposition_ids外还有了promptresponsesresponse_mask

      • prompt是在生成时额外在后续添加的提升内容,如
      1
      'system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.\nuser\nTom receives a $12 allowance per month. In the first week, he spends a third of it; in the second week, he spends a quarter of what he has left. How much money does he have left to finish the month? Let\'s think step by step and output the final answer after "####".\nassistant\n'
      • responsesh是实际模型生成的内容,如:
      1
      "To find out how much money Tom has left after the stipulated transactions, let's work through the steps:\n\n1. **Step 1: Tom receives $12 for his first allowance.**\n\n2. **Step 2: Tom spends a third of his allowance in the first week.**\n   - Amount spent = $\\frac{1}{3} \\times $12 = $4$\n\n3. **Step 3: Calculate the remaining amount of money after spending in the first week:**\n   - Remaining = $12 - $4 = $8\n\n4. **Step 4: Tom spends a quarter of what he has left in the second week.**\n   - Amount spent in the second week = $\\frac{1}{4} \\times $8 = $2\n\n5. **Step 5: Calculate the remaining amount of money after spending in the second week:**\n   - Remaining is $8 - $2 = $6\n\nSo, after he has both runs of spending, Tom has **$6** left to finish the month."
    • 如果有配置还会进行token级别的在各个DP rank间调整序列顺序以平衡计算量

    • 然后会调用reward模型或者reward func来计算reward,这使得batch中又额外添加了token_level_scores

    • 还会按需使用actor重新计算old_log_prob以避免actor与rollout的精度不统一,这使得batch中又额外添加了old_log_prob

    • 如果配置了use_reference_policy,还会计算ref的log_prob,使得batch中又额外添加了ref_log_prob

    • 如果配置了use_critic,还会调用critic计算values,使得batch中又额外添加了values元素

    • 然后会计算advantage,Verl支持GAE、GRPO等计算方式,计算后会给batch添加advantagesreturns

      • 在PPO中advantages就是优势,returnsadvantages+values,policy loss 用 advantages,value loss 用 returns

      • 在GRPO中advantagesreturns相同

    • 然后会更新critic,在更新critic过程中会使用到batchinput_idsresponsesattention_maskposition_idsvaluesreturns

      • 其中会使用到input_idsresponsesattention_maskposition_ids来在critic中重新前向传播得到最新预测的vpreds

      • 然后依据vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)计算clip后的value防止更新过猛

      • 进一步与returns比较得到loss

    • 然后会更新actor,在更新actor过程中会会使用到batchinput_idsresponsesattention_maskposition_idsold_log_probsadvantages

      • 其中会使用到input_idsresponsesattention_maskposition_ids来在actor中重新前向传播得到最新模型的log_prob,用于结合old_log_probsadvantages进行重要性采样

      • 然后使用经典的PPO clip来计算重要性采样下的advantage,从而更新actor


【Verl源码分析(二)】Verl中的数据流动
http://example.com/2026/02/03/Verl-Data-Process/
作者
滑滑蛋
发布于
2026年2月3日
许可协议